"""JAX implementations of the backend numpy-style ops.

Almost every function here forwards its arguments unchanged to the
same-named function in ``jax.numpy``; see the ``jax.numpy`` documentation
for the exact semantics of each. Only a few functions (``add``,
``subtract``, ``matmul``, ``multiply``, ``divide``) first pass their
operands through ``convert_to_tensor``; the rest hand inputs directly to
``jnp``.

NOTE: this module deliberately defines functions named ``max``, ``min``,
``sum``, ``abs``, ``all``, ``any`` and ``round``. Inside this module those
names refer to the wrappers below, NOT to Python builtins — do not use them
for plain Python arithmetic in this file.
"""
import jax.numpy as jnp

from keras_core.backend.jax.core import convert_to_tensor


def add(x1, x2):
    x1 = convert_to_tensor(x1)
    x2 = convert_to_tensor(x2)
    return jnp.add(x1, x2)


def bincount(x, weights=None, minlength=0):
    """Count occurrences of each integer value in ``x``.

    Args:
        x: Integer tensor. When ``x`` is 2-D, counting is done independently
            for each row and the per-row results are stacked.
        weights: Optional weights with the same shape as ``x``; each element
            of ``x`` contributes its matching weight instead of 1. For 2-D
            ``x``, row ``i`` of ``weights`` is paired with row ``i`` of ``x``.
        minlength: Minimum length of the counts axis.

    Returns:
        For 1-D ``x``, the result of ``jnp.bincount``. For 2-D ``x``, a 2-D
        tensor with one row of counts per input row; all rows share the same
        length so they can be stacked.
    """
    if len(x.shape) == 2:
        # jnp.stack needs every row to have the same number of bins, but
        # jnp.bincount sizes each row from that row's own maximum. Use one
        # shared length derived from the global maximum instead. (Builtin
        # `max` is shadowed in this module, hence the explicit comparison.)
        if x.size > 0:
            needed = int(jnp.max(x)) + 1
            if needed > minlength:
                minlength = needed
        if weights is None:
            bincounts = [jnp.bincount(row, minlength=minlength) for row in x]
        else:
            # Pair each row of `x` with its own row of weights; passing the
            # whole 2-D `weights` to every row would mismatch shapes.
            bincounts = [
                jnp.bincount(row, weights=row_weights, minlength=minlength)
                for row, row_weights in zip(x, weights)
            ]
        return jnp.stack(bincounts)
    return jnp.bincount(x, weights=weights, minlength=minlength)


def einsum(subscripts, *operands, **kwargs):
    return jnp.einsum(subscripts, *operands, **kwargs)


def subtract(x1, x2):
    x1 = convert_to_tensor(x1)
    x2 = convert_to_tensor(x2)
    return jnp.subtract(x1, x2)


def matmul(x1, x2):
    x1 = convert_to_tensor(x1)
    x2 = convert_to_tensor(x2)
    return jnp.matmul(x1, x2)


def multiply(x1, x2):
    x1 = convert_to_tensor(x1)
    x2 = convert_to_tensor(x2)
    return jnp.multiply(x1, x2)


def mean(x, axis=None, keepdims=False):
    return jnp.mean(x, axis=axis, keepdims=keepdims)


# Shadows builtin `max` within this module.
def max(x, axis=None, keepdims=False, initial=None):
    return jnp.max(x, axis=axis, keepdims=keepdims, initial=initial)


def ones(shape, dtype="float32"):
    return jnp.ones(shape, dtype=dtype)


def zeros(shape, dtype="float32"):
    return jnp.zeros(shape, dtype=dtype)


def absolute(x):
    return jnp.absolute(x)


# Alias of `absolute`; shadows builtin `abs` within this module.
def abs(x):
    return absolute(x)


# Shadows builtin `all` within this module.
def all(x, axis=None, keepdims=False):
    return jnp.all(x, axis=axis, keepdims=keepdims)


# Shadows builtin `any` within this module.
def any(x, axis=None, keepdims=False):
    return jnp.any(x, axis=axis, keepdims=keepdims)


def amax(x, axis=None, keepdims=False):
    return jnp.amax(x, axis=axis, keepdims=keepdims)


def amin(x, axis=None, keepdims=False):
    return jnp.amin(x, axis=axis, keepdims=keepdims)


def append(
    x1,
    x2,
    axis=None,
):
    return jnp.append(x1, x2, axis=axis)


def arange(start, stop=None, step=None, dtype=None):
    return jnp.arange(start, stop, step=step, dtype=dtype)


def arccos(x):
    return jnp.arccos(x)


def arcsin(x):
    return jnp.arcsin(x)


def arctan(x):
    return jnp.arctan(x)


def arctan2(x1, x2):
    return jnp.arctan2(x1, x2)


def argmax(x, axis=None):
    return jnp.argmax(x, axis=axis)


def argmin(x, axis=None):
    return jnp.argmin(x, axis=axis)


def argsort(x, axis=-1):
    return jnp.argsort(x, axis=axis)


def array(x, dtype=None):
    return jnp.array(x, dtype=dtype)


def average(x, axis=None, weights=None):
    return jnp.average(x, weights=weights, axis=axis)


def broadcast_to(x, shape):
    return jnp.broadcast_to(x, shape)


def ceil(x):
    return jnp.ceil(x)


def clip(x, x_min, x_max):
    return jnp.clip(x, x_min, x_max)


def concatenate(xs, axis=0):
    return jnp.concatenate(xs, axis=axis)


def conjugate(x):
    return jnp.conjugate(x)


# Alias of `conjugate`.
def conj(x):
    return conjugate(x)


def copy(x):
    return jnp.copy(x)


def cos(x):
    return jnp.cos(x)


def count_nonzero(x, axis=None):
    return jnp.count_nonzero(x, axis=axis)


def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):
    return jnp.cross(
        x1,
        x2,
        axisa=axisa,
        axisb=axisb,
        axisc=axisc,
        axis=axis,
    )


def cumprod(x, axis=None):
    return jnp.cumprod(x, axis=axis)


def cumsum(x, axis=None):
    return jnp.cumsum(x, axis=axis)


def diag(x, k=0):
    return jnp.diag(x, k=k)


def diagonal(x, offset=0, axis1=0, axis2=1):
    return jnp.diagonal(
        x,
        offset=offset,
        axis1=axis1,
        axis2=axis2,
    )


def dot(x, y):
    return jnp.dot(x, y)


def empty(shape, dtype="float32"):
    return jnp.empty(shape, dtype=dtype)


def equal(x1, x2):
    return jnp.equal(x1, x2)


def exp(x):
    return jnp.exp(x)


def expand_dims(x, axis):
    return jnp.expand_dims(x, axis)


def expm1(x):
    return jnp.expm1(x)


def flip(x, axis=None):
    return jnp.flip(x, axis=axis)


def floor(x):
    return jnp.floor(x)


def full(shape, fill_value, dtype=None):
    return jnp.full(shape, fill_value, dtype=dtype)


def full_like(x, fill_value, dtype=None):
    return jnp.full_like(x, fill_value, dtype=dtype)


def greater(x1, x2):
    return jnp.greater(x1, x2)


def greater_equal(x1, x2):
    return jnp.greater_equal(x1, x2)


def hstack(xs):
    return jnp.hstack(xs)


def identity(n, dtype="float32"):
    return jnp.identity(n, dtype=dtype)


def imag(x):
    return jnp.imag(x)


def isclose(x1, x2):
    return jnp.isclose(x1, x2)


def isfinite(x):
    return jnp.isfinite(x)


def isinf(x):
    return jnp.isinf(x)


def isnan(x):
    return jnp.isnan(x)


def less(x1, x2):
    return jnp.less(x1, x2)


def less_equal(x1, x2):
    return jnp.less_equal(x1, x2)


def linspace(
    start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0
):
    return jnp.linspace(
        start,
        stop,
        num=num,
        endpoint=endpoint,
        retstep=retstep,
        dtype=dtype,
        axis=axis,
    )


def log(x):
    return jnp.log(x)


def log10(x):
    return jnp.log10(x)


def log1p(x):
    return jnp.log1p(x)


def log2(x):
    return jnp.log2(x)


def logaddexp(x1, x2):
    return jnp.logaddexp(x1, x2)


def logical_and(x1, x2):
    return jnp.logical_and(x1, x2)


def logical_not(x):
    return jnp.logical_not(x)


def logical_or(x1, x2):
    return jnp.logical_or(x1, x2)


def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):
    return jnp.logspace(
        start,
        stop,
        num=num,
        endpoint=endpoint,
        base=base,
        dtype=dtype,
        axis=axis,
    )


def maximum(x1, x2):
    return jnp.maximum(x1, x2)


def meshgrid(*x, indexing="xy"):
    return jnp.meshgrid(*x, indexing=indexing)


# Shadows builtin `min` within this module.
def min(x, axis=None, keepdims=False, initial=None):
    return jnp.min(x, axis=axis, keepdims=keepdims, initial=initial)


def minimum(x1, x2):
    return jnp.minimum(x1, x2)


def mod(x1, x2):
    return jnp.mod(x1, x2)


def moveaxis(x, source, destination):
    return jnp.moveaxis(x, source=source, destination=destination)


def nan_to_num(x):
    return jnp.nan_to_num(x)


def ndim(x):
    return jnp.ndim(x)


def nonzero(x):
    return jnp.nonzero(x)


def not_equal(x1, x2):
    return jnp.not_equal(x1, x2)


def ones_like(x, dtype=None):
    return jnp.ones_like(x, dtype=dtype)


def zeros_like(x, dtype=None):
    return jnp.zeros_like(x, dtype=dtype)


def outer(x1, x2):
    return jnp.outer(x1, x2)


def pad(x, pad_width, mode="constant"):
    return jnp.pad(x, pad_width, mode=mode)


def prod(x, axis=None, keepdims=False, dtype=None):
    return jnp.prod(x, axis=axis, keepdims=keepdims, dtype=dtype)


def ravel(x):
    return jnp.ravel(x)


def real(x):
    return jnp.real(x)


def reciprocal(x):
    return jnp.reciprocal(x)


def repeat(x, repeats, axis=None):
    return jnp.repeat(x, repeats, axis=axis)


def reshape(x, new_shape):
    return jnp.reshape(x, new_shape)


def roll(x, shift, axis=None):
    return jnp.roll(x, shift, axis=axis)


def sign(x):
    return jnp.sign(x)


def sin(x):
    return jnp.sin(x)


def size(x):
    return jnp.size(x)


def sort(x, axis=-1):
    return jnp.sort(x, axis=axis)


def split(x, indices_or_sections, axis=0):
    return jnp.split(x, indices_or_sections, axis=axis)


def stack(x, axis=0):
    return jnp.stack(x, axis=axis)


def std(x, axis=None, keepdims=False):
    return jnp.std(x, axis=axis, keepdims=keepdims)


def swapaxes(x, axis1, axis2):
    return jnp.swapaxes(x, axis1=axis1, axis2=axis2)


def take(x, indices, axis=None):
    return jnp.take(x, indices, axis=axis)


def take_along_axis(x, indices, axis=None):
    return jnp.take_along_axis(x, indices, axis=axis)


def tan(x):
    return jnp.tan(x)


def tensordot(x1, x2, axes=2):
    return jnp.tensordot(x1, x2, axes=axes)


# Shadows builtin `round` within this module.
def round(x, decimals=0):
    return jnp.round(x, decimals=decimals)


def tile(x, repeats):
    return jnp.tile(x, repeats)


def trace(x, offset=0, axis1=0, axis2=1):
    return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2)


def tri(N, M=None, k=0, dtype="float32"):
    return jnp.tri(N, M=M, k=k, dtype=dtype)


def tril(x, k=0):
    return jnp.tril(x, k=k)


def triu(x, k=0):
    return jnp.triu(x, k=k)


def vdot(x1, x2):
    return jnp.vdot(x1, x2)


def vstack(xs):
    return jnp.vstack(xs)


def where(condition, x1, x2):
    return jnp.where(condition, x1, x2)


def divide(x1, x2):
    x1 = convert_to_tensor(x1)
    x2 = convert_to_tensor(x2)
    return jnp.divide(x1, x2)


def true_divide(x1, x2):
    return jnp.true_divide(x1, x2)


def power(x1, x2):
    return jnp.power(x1, x2)


def negative(x):
    return jnp.negative(x)


def square(x):
    return jnp.square(x)


def sqrt(x):
    return jnp.sqrt(x)


def squeeze(x, axis=None):
    return jnp.squeeze(x, axis=axis)


def transpose(x, axes=None):
    return jnp.transpose(x, axes=axes)


def var(x, axis=None, keepdims=False):
    return jnp.var(x, axis=axis, keepdims=keepdims)


# Shadows builtin `sum` within this module.
def sum(x, axis=None, keepdims=False):
    return jnp.sum(x, axis=axis, keepdims=keepdims)


def eye(N, M=None, k=0, dtype="float32"):
    return jnp.eye(N, M=M, k=k, dtype=dtype)