keras/keras_core/backend/jax/numpy.py
Chen Qian f6df67f2d2 Add numpy module in jax/ and tensorflow/ (#13)
* Add jax/numpy and tensorflow/numpy

* refactor code

* more

* even better
2023-04-18 18:45:30 -07:00


import jax.numpy as jnp


def add(x1, x2):
    return jnp.add(x1, x2)


def subtract(x1, x2):
    return jnp.subtract(x1, x2)


def matmul(x1, x2):
    return jnp.matmul(x1, x2)


def multiply(x1, x2):
    return jnp.multiply(x1, x2)


def mean(x, axis=None, keepdims=False):
    return jnp.mean(x, axis=axis, keepdims=keepdims)


def max(x, axis=None, keepdims=False):
    return jnp.max(x, axis=axis, keepdims=keepdims)
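A short usage sketch of the reduction and matmul wrappers above. It copies `matmul`, `mean`, and `max` so it runs standalone rather than importing them from the package. The NumPy fallback import is an assumption added only so the sketch runs where JAX is not installed; it is not part of this module.

```python
# Usage sketch for the wrappers defined in this module. The three wrappers
# are copied here so the example runs on its own.
try:
    import jax.numpy as jnp  # the backend this module targets
except ImportError:
    # Assumption: NumPy stands in when JAX is missing; jax.numpy mirrors
    # NumPy's axis/keepdims semantics for these calls.
    import numpy as jnp


def matmul(x1, x2):
    return jnp.matmul(x1, x2)


def mean(x, axis=None, keepdims=False):
    return jnp.mean(x, axis=axis, keepdims=keepdims)


def max(x, axis=None, keepdims=False):  # shadows the builtin, as in the module
    return jnp.max(x, axis=axis, keepdims=keepdims)


x = jnp.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])  # shape (2, 3)

overall = mean(x)                            # reduce over every element: 3.5
per_column = mean(x, axis=0)                 # shape (3,): one mean per column
row_means = mean(x, axis=1, keepdims=True)   # shape (2, 1): reduced axis kept
row_max = max(x, axis=1)                     # shape (2,): largest value per row
gram = matmul(x, x.T)                        # (2, 3) @ (3, 2) -> shape (2, 2)

print(overall, per_column, row_means.shape, row_max, gram.shape)
```

Because each wrapper forwards `axis` and `keepdims` unchanged, the reductions behave exactly like their `jax.numpy` counterparts; `keepdims=True` is what keeps `row_means` broadcastable against `x`.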