From afa8882b3d1c8e5079c153907254a5fd43c790d1 Mon Sep 17 00:00:00 2001 From: Neel Kovelamudi <60985914+nkovela1@users.noreply.github.com> Date: Thu, 18 May 2023 21:48:18 +0000 Subject: [PATCH] Add numpy ops for PyTorch backend (#182) * Add PyTorch numpy functionality * Add dtype conversion * Partial fix for PyTorch numpy tests * small logic fix * Revert numpy_test * Add tensor conversion to numpy * Fix some arithmetic tests * Fix some torch functions for numpy compatibility * Fix pytorch ops for numpy compatibility, add TODOs * Fix formatting * Implement nits and fix dtype standardization * Add pytest skipif decorator and fix nits * Fix formatting and rename dtypes map * Split tests by backend --- keras_core/backend/common/variables.py | 14 +- keras_core/backend/torch/numpy.py | 726 ++++++++++-------- keras_core/layers/reshaping/up_sampling3d.py | 12 +- keras_core/operations/numpy_test.py | 126 ++- keras_core/utils/feature_space.py | 757 ------------------- keras_core/utils/feature_space_test.py | 378 --------- 6 files changed, 542 insertions(+), 1471 deletions(-) delete mode 100644 keras_core/utils/feature_space.py delete mode 100644 keras_core/utils/feature_space_test.py diff --git a/keras_core/backend/common/variables.py b/keras_core/backend/common/variables.py index f04bc292d..e3de325f3 100644 --- a/keras_core/backend/common/variables.py +++ b/keras_core/backend/common/variables.py @@ -336,20 +336,28 @@ ALLOWED_DTYPES = { "int64", "bfloat16", "bool", - "string", +} + +PYTHON_DTYPES_MAP = { + bool: "bool", + int: "int", # TBD by backend + float: "float32", } def standardize_dtype(dtype): + if dtype is None: + return config.floatx() + if dtype in PYTHON_DTYPES_MAP: + dtype = PYTHON_DTYPES_MAP.get(dtype) if dtype == "int": if config.backend() == "tensorflow": dtype = "int64" else: dtype = "int32" - if dtype is None: - return config.floatx() if hasattr(dtype, "name"): dtype = dtype.name + if dtype not in ALLOWED_DTYPES: raise ValueError(f"Invalid dtype: {dtype}") return dtype diff --git a/keras_core/backend/torch/numpy.py b/keras_core/backend/torch/numpy.py index 0e75bcc59..452ca34ea 100644 --- a/keras_core/backend/torch/numpy.py +++ b/keras_core/backend/torch/numpy.py @@ -1,93 +1,109 @@ import torch +from keras_core.backend.torch.core import convert_to_tensor from keras_core.backend.torch.core import to_torch_dtype def add(x1, x2): - return x1 + x2 + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.add(x1, x2) def einsum(subscripts, *operands, **kwargs): - pass - # return tfnp.einsum(subscripts, *operands, **kwargs) + operands = [convert_to_tensor(operand) for operand in operands] + return torch.einsum(subscripts, *operands) def subtract(x1, x2): - return x1 - x2 + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.subtract(x1, x2) def matmul(x1, x2): - pass - # return tfnp.matmul(x1, x2) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.matmul(x1, x2) def multiply(x1, x2): - return x1 * x2 + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.multiply(x1, x2) def mean(x, axis=None, keepdims=False): - return torch.mean(x, dim=axis, keepdim=keepdims) + x = convert_to_tensor(x) + return torch.mean(x, axis=axis, keepdims=keepdims) def max(x, axis=None, keepdims=False, initial=None): - # TODO: handle initial - return torch.max(x, dim=axis, keepdim=keepdims) - # The TensorFlow numpy API implementation doesn't support `initial` so we - # handle it manually here. - # if initial is not None: - # return tf.math.maximum( - # tfnp.max(x, axis=axis, keepdims=keepdims), initial - # ) + x = convert_to_tensor(x) + if axis is None: + result = torch.max(x) + else: + if isinstance(axis, list): + axis = axis[-1] + result = torch.max(x, dim=axis, keepdim=keepdims) - # TensorFlow returns -inf by default for an empty list, but for consistency - # with other backends and the numpy API we want to throw in this case. - # tf.assert_greater( - # size(x), - # tf.constant(0, dtype=tf.int64), - # message="Cannot compute the max of an empty tensor.", - # ) + if isinstance(getattr(result, "values", None), torch.Tensor): + result = result.values - # return tfnp.max(x, axis=axis, keepdims=keepdims) + if initial is not None: + return torch.maximum(result, initial) + return result def ones(shape, dtype="float32"): dtype = to_torch_dtype(dtype) - torch.ones(*shape, dtype=dtype) + return torch.ones(*shape, dtype=dtype) def zeros(shape, dtype="float32"): dtype = to_torch_dtype(dtype) - torch.zeros(*shape, dtype=dtype) + return torch.zeros(*shape, dtype=dtype) def absolute(x): - pass - # return tfnp.absolute(x) + return abs(x) def abs(x): - pass - # return absolute(x) + x = convert_to_tensor(x) + return torch.abs(x) def all(x, axis=None, keepdims=False): - pass - # return tfnp.all(x, axis=axis, keepdims=keepdims) + x = convert_to_tensor(x) + if axis is not None: + if isinstance(axis, list): + axis = axis[-1] + return torch.all(x, dim=axis, keepdim=keepdims) + else: + return torch.all(x) def any(x, axis=None, keepdims=False): - pass - # return tfnp.any(x, axis=axis, keepdims=keepdims) + x = convert_to_tensor(x) + if axis is not None: + if isinstance(axis, list): + axis = axis[-1] + return torch.any(x, dim=axis, keepdim=keepdims) + else: + return torch.any(x) def amax(x, axis=None, keepdims=False): - pass - # return tfnp.amax(x, axis=axis, keepdims=keepdims) + x = convert_to_tensor(x) + if axis is not None: + return torch.amax(x, dim=axis, keepdim=keepdims) + else: + return torch.amax(x) def amin(x, axis=None, keepdims=False): - pass - # return tfnp.amin(x, axis=axis, keepdims=keepdims) + x = convert_to_tensor(x) + if axis is not None: + return torch.amin(x, dim=axis, keepdim=keepdims) + else: + return torch.amin(x) def append( @@ -95,594 +111,694 @@ def append( x2, axis=None, ): - pass - # return tfnp.append(x1, x2, axis=axis) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + if axis is None: + return torch.cat((x1.flatten(), x2.flatten())) + return torch.cat((x1, x2), dim=axis) def arange(start, stop=None, step=None, dtype=None): - pass - # return tfnp.arange(start, stop, step=step, dtype=dtype) + dtype = to_torch_dtype(dtype) + if stop is None: + return torch.arange(start, step=step, dtype=dtype) + step = step or 1 + return torch.arange(start, stop, step=step, dtype=dtype) def arccos(x): - pass - # return tfnp.arccos(x) + x = convert_to_tensor(x) + return torch.arccos(x) def arcsin(x): - pass - # return tfnp.arcsin(x) + x = convert_to_tensor(x) + return torch.arcsin(x) def arctan(x): - pass - # return tfnp.arctan(x) + x = convert_to_tensor(x) + return torch.arctan(x) def arctan2(x1, x2): - pass - # return tfnp.arctan2(x1, x2) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.arctan2(x1, x2) def argmax(x, axis=None): - pass - # return tfnp.argmax(x, axis=axis) + x = convert_to_tensor(x) + return torch.argmax(x, dim=axis) def argmin(x, axis=None): - pass - # return tfnp.argmin(x, axis=axis) + x = convert_to_tensor(x) + return torch.argmin(x, dim=axis) def argsort(x, axis=-1): - pass - # return tfnp.argsort(x, axis=axis) + x = convert_to_tensor(x) + if axis is None: + axis = -1 + x = x.reshape(-1) + return torch.argsort(x, dim=axis) def array(x, dtype=None): - pass - # return tfnp.array(x, dtype=dtype) + dtype = to_torch_dtype(dtype) + if not isinstance(x, torch.Tensor): + return x + return x.numpy() def average(x, axis=None, weights=None): - pass - # return tfnp.average(x, weights=weights, axis=axis) + x = convert_to_tensor(x) + if weights is not None: + weights = convert_to_tensor(weights) + return torch.sum(torch.mul(x, weights), dim=axis) / torch.sum( + weights, dim=-1 + ) + return torch.mean(x, axis) + + +def bincount(x, weights=None, minlength=0): + x = convert_to_tensor(x, dtype=int) + weights = convert_to_tensor(weights) + return torch.bincount(x, weights, minlength) def broadcast_to(x, shape): - pass - # return tfnp.broadcast_to(x, shape) + x = convert_to_tensor(x) + return torch.broadcast_to(x, shape) def ceil(x): - pass - # return tfnp.ceil(x) + x = convert_to_tensor(x) + return torch.ceil(x) def clip(x, x_min, x_max): - pass - # return tfnp.clip(x, x_min, x_max) + x = convert_to_tensor(x) + x_min, x_max = convert_to_tensor(x_min), convert_to_tensor(x_max) + return torch.clip(x, min=x_min, max=x_max) def concatenate(xs, axis=0): - pass - # return tfnp.concatenate(xs, axis=axis) + xs = [convert_to_tensor(x) for x in xs] + return torch.cat(xs, dim=axis) def conjugate(x): - pass - # return tfnp.conjugate(x) + if not isinstance(x, torch.Tensor): + x = torch.from_numpy(x) # needed for complex type conversion + return torch.conj(x).resolve_conj() def conj(x): - pass - # return conjugate(x) + if not isinstance(x, torch.Tensor): + x = torch.from_numpy(x) # needed for complex type conversion + return torch.conj(x).resolve_conj() def copy(x): - pass - # return tfnp.copy(x) + x = convert_to_tensor(x) + return torch.clone(x) def cos(x): - pass - # return tfnp.cos(x) + x = convert_to_tensor(x) + return torch.cos(x) def count_nonzero(x, axis=None): - pass - # return tfnp.count_nonzero(x, axis=axis) + x = convert_to_tensor(x) + return torch.count_nonzero(x, dim=axis).T def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None): - pass - # return tfnp.cross( - # x1, - # x2, - # axisa=axisa, - # axisb=axisb, - # axisc=axisc, - # axis=axis, - # ) + # TODO: There is API divergence between np.cross and torch.cross, + # preventing `axisa`, `axisb`, and `axisc` parameters from + # being used. + # https://github.com/pytorch/pytorch/issues/50273 + if axisa != -1 or axisb != -1 or axisc != -1: + raise NotImplementedError( + "Due to API divergence between `torch.cross()` and " + "`np.cross`, the following arguments are not supported: " + f"axisa={axisa}, axisb={axisb}, axisc={axisc}" + ) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.cross(x1, x2, dim=axis) def cumprod(x, axis=None): - pass - # return tfnp.cumprod(x, axis=axis) + x = convert_to_tensor(x) + if axis is None: + x = x.flatten() + axis = 0 + return torch.cumprod(x, dim=axis) def cumsum(x, axis=None): - pass - # return tfnp.cumsum(x, axis=axis) + x = convert_to_tensor(x) + if axis is None: + x = x.flatten() + axis = 0 + return torch.cumsum(x, dim=axis) def diag(x, k=0): - pass - # return tfnp.diag(x, k=k) + x = convert_to_tensor(x) + return torch.diag(x, diagonal=k) def diagonal(x, offset=0, axis1=0, axis2=1): - pass - # return tfnp.diagonal( - # x, - # offset=offset, - # axis1=axis1, - # axis2=axis2, - # ) + x = convert_to_tensor(x) + return torch.diagonal( + x, + offset=offset, + dim1=axis1, + dim2=axis2, + ) def dot(x, y): - pass - # return tfnp.dot(x, y) + x, y = convert_to_tensor(x), convert_to_tensor(y) + if len(x.shape) == 0 or len(y.shape) == 0: + return torch.multiply(x, y) + return torch.matmul(x, y) def empty(shape, dtype="float32"): - pass - # return tfnp.empty(shape, dtype=dtype) + dtype = to_torch_dtype(dtype) + return torch.empty(size=shape, dtype=dtype) def equal(x1, x2): - pass - # return tfnp.equal(x1, x2) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.equal(x1, x2) def exp(x): - pass - # return tfnp.exp(x) + x = convert_to_tensor(x) + return torch.exp(x) def expand_dims(x, axis): - pass - # return tfnp.expand_dims(x, axis) + x = convert_to_tensor(x) + return torch.unsqueeze(x, dim=axis) def expm1(x): - pass - # return tfnp.expm1(x) + x = convert_to_tensor(x) + return torch.expm1(x) def flip(x, axis=None): - pass - # return tfnp.flip(x, axis=axis) + x = convert_to_tensor(x) + if axis is None: + axis = tuple(range(len(x.shape))) + if isinstance(axis, int): + axis = (axis,) + return torch.flip(x, dims=axis) def floor(x): - pass - # return tfnp.floor(x) + x = convert_to_tensor(x) + return torch.floor(x) def full(shape, fill_value, dtype=None): - pass - # return tfnp.full(shape, fill_value, dtype=dtype) + dtype = to_torch_dtype(dtype) + if hasattr(fill_value, "__len__"): + raise NotImplementedError( + "`torch.full()` only accepts scalars for `fill_value`. " + f"Received: fill_value={fill_value}" + ) + # TODO: implement conversion of shape into repetitions for `torch.tile`` + # return torch.tile(fill_value, reps) + return torch.full(size=shape, fill_value=fill_value, dtype=dtype) def full_like(x, fill_value, dtype=None): - pass - # return tfnp.full_like(x, fill_value, dtype=dtype) + dtype = to_torch_dtype(dtype) + if hasattr(fill_value, "__len__"): + raise NotImplementedError( + "`torch.full()` only accepts scalars for `fill_value`." + ) + # TODO: implement conversion of shape into repetitions for `torch.tile`` + # return torch.tile(fill_value, reps) + x = convert_to_tensor(x) + return torch.full_like(input=x, fill_value=fill_value, dtype=dtype) def greater(x1, x2): - pass - # return tfnp.greater(x1, x2) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.greater(x1, x2) def greater_equal(x1, x2): - pass - # return tfnp.greater_equal(x1, x2) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.greater_equal(x1, x2) def hstack(xs): - pass - # return tfnp.hstack(xs) + xs = [convert_to_tensor(x) for x in xs] + return torch.hstack(xs) def identity(n, dtype="float32"): - pass - # return tfnp.identity(n, dtype=dtype) + dtype = to_torch_dtype(dtype) + return torch.eye(n, dtype=dtype) def imag(x): - pass - # return tfnp.imag(x) + if not isinstance(x, torch.Tensor): + x = torch.from_numpy(x) # needed for complex type conversion + return torch.imag(x) def isclose(x1, x2): - pass - # return tfnp.isclose(x1, x2) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.isclose(x1, x2) def isfinite(x): - pass - # return tfnp.isfinite(x) + x = convert_to_tensor(x) + return torch.isfinite(x) def isinf(x): - pass - # return tfnp.isinf(x) + x = convert_to_tensor(x) + return torch.isinf(x) def isnan(x): - pass - # return tfnp.isnan(x) + x = convert_to_tensor(x) + return torch.isnan(x) def less(x1, x2): - pass - # return tfnp.less(x1, x2) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.less(x1, x2) def less_equal(x1, x2): - pass - # return tfnp.less_equal(x1, x2) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.less_equal(x1, x2) def linspace( start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0 ): - pass - # return tfnp.linspace( - # start, - # stop, - # num=num, - # endpoint=endpoint, - # retstep=retstep, - # dtype=dtype, - # axis=axis, - # ) + if axis != 0: + raise ValueError( + "torch.linspace does not support an `axis` argument. " + f"Received axis={axis}" + ) + dtype = to_torch_dtype(dtype) + if endpoint is False: + stop = stop - ((stop - start) / num) + linspace = torch.linspace( + start=start, + end=stop, + steps=num, + dtype=dtype, + ) + if retstep is True: + return (linspace, num) + return linspace def log(x): - pass - # return tfnp.log(x) + x = convert_to_tensor(x) + return torch.log(x) def log10(x): - pass - # return tfnp.log10(x) + x = convert_to_tensor(x) + return torch.log10(x) def log1p(x): - pass - # return tfnp.log1p(x) + x = convert_to_tensor(x) + return torch.log1p(x) def log2(x): - pass - # return tfnp.log2(x) + x = convert_to_tensor(x) + return torch.log2(x) def logaddexp(x1, x2): - pass - # return tfnp.logaddexp(x1, x2) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.logaddexp(x1, x2) def logical_and(x1, x2): - pass - # return tfnp.logical_and(x1, x2) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.logical_and(x1, x2) def logical_not(x): - pass - # return tfnp.logical_not(x) + x = convert_to_tensor(x) + return torch.logical_not(x) def logical_or(x1, x2): - pass - # return tfnp.logical_or(x1, x2) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.logical_or(x1, x2) def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0): - pass - # return tfnp.logspace( - # start, - # stop, - # num=num, - # endpoint=endpoint, - # base=base, - # dtype=dtype, - # axis=axis, - # ) + if axis != 0: + raise ValueError( + "torch.logspace does not support an `axis` argument. " + f"Received axis={axis}" + ) + dtype = to_torch_dtype(dtype) + if endpoint is False: + stop = stop - ((stop - start) / num) + logspace = torch.logspace( + start=start, + end=stop, + steps=num, + base=base, + dtype=dtype, + ) + return logspace def maximum(x1, x2): - pass - # return tfnp.maximum(x1, x2) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.maximum(x1, x2) def meshgrid(*x, indexing="xy"): - pass - # return tfnp.meshgrid(*x, indexing=indexing) + x = [convert_to_tensor(sc_tensor) for sc_tensor in x] + result = torch.meshgrid(x, indexing=indexing) + return [arr.numpy() for arr in result] def min(x, axis=None, keepdims=False, initial=None): - pass - ## The TensorFlow numpy API implementation doesn't support `initial` so we - ## handle it manually here. - # if initial is not None: - # return tf.math.minimum( - # tfnp.min(x, axis=axis, keepdims=keepdims), initial - # ) + x = convert_to_tensor(x) + if axis is None: + result = torch.min(x) + else: + if isinstance(axis, list): + axis = axis[-1] + result = torch.min(x, dim=axis, keepdim=keepdims) - ## TensorFlow returns inf by default for an empty list, but for consistency - ## with other backends and the numpy API we want to throw in this case. - # tf.assert_greater( - # size(x), - # tf.constant(0, dtype=tf.int64), - # message="Cannot compute the min of an empty tensor.", - # ) + if isinstance(getattr(result, "values", None), torch.Tensor): + result = result.values - # return tfnp.min(x, axis=axis, keepdims=keepdims) + if initial is not None: + return torch.minimum(result, initial) + return result def minimum(x1, x2): - pass - # return tfnp.minimum(x1, x2) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.minimum(x1, x2) def mod(x1, x2): - pass - # return tfnp.mod(x1, x2) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.remainder(x1, x2) def moveaxis(x, source, destination): - pass - # return tfnp.moveaxis(x, source=source, destination=destination) + x = convert_to_tensor(x) + return torch.moveaxis(x, source=source, destination=destination) def nan_to_num(x): - pass - ## Replace NaN with 0 - # x = tf.where(tf.math.is_nan(x), 0, x) - - ## Replace positive infinitiy with dtype.max - # x = tf.where(tf.math.is_inf(x) & (x > 0), x.dtype.max, x) - - ## Replace negative infinity with dtype.min - # x = tf.where(tf.math.is_inf(x) & (x < 0), x.dtype.min, x) - - # return x + x = convert_to_tensor(x) + return torch.nan_to_num(x) def ndim(x): - pass - # return tfnp.ndim(x) + x = convert_to_tensor(x) + return x.ndim def nonzero(x): - pass - # return tfnp.nonzero(x) + x = convert_to_tensor(x) + return torch.nonzero(x).T def not_equal(x1, x2): - pass - # return tfnp.not_equal(x1, x2) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.not_equal(x1, x2) def ones_like(x, dtype=None): - pass - # return tfnp.ones_like(x, dtype=dtype) + x = convert_to_tensor(x) + dtype = to_torch_dtype(dtype) + return torch.ones_like(x, dtype=dtype) def outer(x1, x2): - pass - # return tfnp.outer(x1, x2) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.outer(x1.flatten(), x2.flatten()) def pad(x, pad_width, mode="constant"): - pass - # return tfnp.pad(x, pad_width, mode=mode) + x = convert_to_tensor(x) + pad_sum = () + for pad in pad_width: + pad_sum += pad + return torch.nn.functional.pad(x, pad=pad_sum, mode=mode) def prod(x, axis=None, keepdims=False, dtype=None): - pass - # return tfnp.prod(x, axis=axis, keepdims=keepdims, dtype=dtype) + x = convert_to_tensor(x) + dtype = to_torch_dtype(dtype) + if axis is None: + return torch.prod(x, dtype=dtype) + elif isinstance(axis, list): + axis = axis[-1] + return torch.prod(x, dim=axis, keepdim=keepdims, dtype=dtype) def ravel(x): - pass - # return tfnp.ravel(x) + x = convert_to_tensor(x) + return torch.ravel(x) def real(x): - pass - # return tfnp.real(x) + x = convert_to_tensor(x) + return torch.real(x) def reciprocal(x): - pass - # return tfnp.reciprocal(x) + x = convert_to_tensor(x) + return torch.reciprocal(x) def repeat(x, repeats, axis=None): - pass - # return tfnp.repeat(x, repeats, axis=axis) + x = convert_to_tensor(x) + repeats = convert_to_tensor(repeats, dtype=int) + return torch.repeat_interleave(x, repeats, dim=axis) def reshape(x, new_shape): - pass - # return tfnp.reshape(x, new_shape) + x = convert_to_tensor(x) + return torch.reshape(x, new_shape) def roll(x, shift, axis=None): - pass - # return tfnp.roll(x, shift, axis=axis) + x = convert_to_tensor(x) + return torch.roll(x, shift, dims=axis) def sign(x): - pass - # return tfnp.sign(x) + x = convert_to_tensor(x) + return torch.sign(x) def sin(x): - pass - # return tfnp.sin(x) + x = convert_to_tensor(x) + return torch.sin(x) def size(x): - pass - # return tfnp.size(x) + x_shape = convert_to_tensor(tuple(x.shape)) + return torch.prod(x_shape) def sort(x, axis=-1): - pass - # return tfnp.sort(x, axis=axis) + x = convert_to_tensor(x) + return torch.sort(x, dim=axis).values def split(x, indices_or_sections, axis=0): - pass - # return tfnp.split(x, indices_or_sections, axis=axis) + x = convert_to_tensor(x) + return torch.split( + tensor=x, + split_size_or_sections=indices_or_sections, + dim=axis, + ) def stack(x, axis=0): - pass - # return tfnp.stack(x, axis=axis) + x = [convert_to_tensor(elem) for elem in x] + return torch.stack(x, dim=axis) def std(x, axis=None, keepdims=False): - pass - # return tfnp.std(x, axis=axis, keepdims=keepdims) + x = convert_to_tensor(x) + # Remove Bessel correction to align with numpy + return torch.std(x, dim=axis, keepdim=keepdims, unbiased=False) def swapaxes(x, axis1, axis2): - pass - # return tfnp.swapaxes(x, axis1=axis1, axis2=axis2) + x = convert_to_tensor(x) + return torch.swapaxes(x, axis0=axis1, axis1=axis2) def take(x, indices, axis=None): - pass - # return tfnp.take(x, indices, axis=axis) + x = convert_to_tensor(x) + indices = convert_to_tensor(indices).long() + if axis is not None: + return torch.index_select(x, dim=axis, index=indices) + return torch.take(x, index=indices) def take_along_axis(x, indices, axis=None): - pass - # return tfnp.take_along_axis(x, indices, axis=axis) + x = convert_to_tensor(x) + indices = convert_to_tensor(indices).long() + return torch.take_along_dim(x, indices, dim=axis) def tan(x): - pass - # return tfnp.tan(x) + x = convert_to_tensor(x) + return torch.tan(x) def tensordot(x1, x2, axes=2): - pass - # return tfnp.tensordot(x1, x2, axes=axes) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.tensordot(x1, x2, dims=axes) def round(x, decimals=0): - pass - # return tfnp.round(x, decimals=decimals) + x = convert_to_tensor(x) + return torch.round(x, decimals=decimals) def tile(x, repeats): - pass - # return tfnp.tile(x, repeats) + x = convert_to_tensor(x) + return torch.tile(x, dims=repeats) -def trace(x, offset=0, axis1=0, axis2=1): - pass - # return tfnp.trace(x, offset=offset, axis1=axis1, axis2=axis2) +def trace(x, offset=None, axis1=None, axis2=None): + x = convert_to_tensor(x) + # TODO: implement support for these arguments + # API divergence between `np.trace()` and `torch.trace()` + if offset or axis1 or axis2: + raise NotImplementedError( + "Arguments not supported by `torch.trace: " + f"offset={offset}, axis1={axis1}, axis2={axis2}" + ) + return torch.trace(x) def tri(N, M=None, k=0, dtype="float32"): + dtype = to_torch_dtype(dtype) pass - # return tfnp.tri(N, M=M, k=k, dtype=dtype) def tril(x, k=0): - pass - # return tfnp.tril(x, k=k) + x = convert_to_tensor(x) + return torch.tril(x, diagonal=k) def triu(x, k=0): - pass - # return tfnp.triu(x, k=k) + x = convert_to_tensor(x) + return torch.triu(x, diagonal=k) def vdot(x1, x2): - pass - # return tfnp.vdot(x1, x2) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.vdot(x1, x2) def vstack(xs): - pass - # return tfnp.vstack(xs) + xs = [convert_to_tensor(x) for x in xs] + return torch.vstack(xs) def where(condition, x1, x2): - pass - # return tfnp.where(condition, x1, x2) + condition = convert_to_tensor(condition, dtype=bool) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.where(condition, x1, x2) def divide(x1, x2): - pass - # return tfnp.divide(x1, x2) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.divide(x1, x2) def true_divide(x1, x2): - pass - # return tfnp.true_divide(x1, x2) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.true_divide(x1, x2) def power(x1, x2): - pass - # return tfnp.power(x1, x2) + x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) + return torch.pow(x1, x2) def negative(x): - pass - # return tfnp.negative(x) + x = convert_to_tensor(x) + return torch.negative(x) def square(x): - pass - # return tfnp.square(x) + x = convert_to_tensor(x) + return torch.square(x) def sqrt(x): - pass - # return tfnp.sqrt(x) + x = convert_to_tensor(x) + return torch.sqrt(x) def squeeze(x, axis=None): - pass - # return tfnp.squeeze(x, axis=axis) + x = convert_to_tensor(x) + if axis is not None: + return torch.squeeze(x, dim=axis) + return torch.squeeze(x) def transpose(x, axes=None): - pass - # return tfnp.transpose(x, axes=axes) + x = convert_to_tensor(x) + if axes is not None: + return torch.permute(x, dims=axes) + return x.T def var(x, axis=None, keepdims=False): - pass - # return tfnp.var(x, axis=axis, keepdims=keepdims) + x = convert_to_tensor(x) + # Bessel correction removed for numpy compatibility + return torch.var(x, dim=axis, keepdim=keepdims, correction=0) def sum(x, axis=None, keepdims=False): - pass - # return tfnp.sum(x, axis=axis, keepdims=keepdims) + x = convert_to_tensor(x) + if axis is not None: + return torch.sum(x, axis=axis, keepdim=keepdims) + return torch.sum(x) -def eye(N, M=None, k=0, dtype="float32"): - pass - # return tfnp.eye(N, M=M, k=k, dtype=dtype) +def eye(N, M=None, k=None, dtype="float32"): + # TODO: implement support for `k` diagonal arg, + # does not exist in torch.eye() + if k is not None: + raise NotImplementedError( + "Due to API divergence bewtween `torch.eye` " + "and `np.eye`, the argument k is not supported: " + f"Received: k={k}" + ) + dtype = to_torch_dtype(dtype) + if M is not None: + return torch.eye(n=N, m=M, dtype=dtype) + return torch.eye(n=N, dtype=dtype) diff --git a/keras_core/layers/reshaping/up_sampling3d.py b/keras_core/layers/reshaping/up_sampling3d.py index 1549f76c9..5a7cace72 100644 --- a/keras_core/layers/reshaping/up_sampling3d.py +++ b/keras_core/layers/reshaping/up_sampling3d.py @@ -13,8 +13,7 @@ class UpSampling3D(Layer): Repeats the 1st, 2nd and 3rd dimensions of the data by `size[0]`, `size[1]` and `size[2]` respectively. - Example: - + Examples: >>> input_shape = (2, 1, 2, 1, 3) >>> x = np.ones(input_shape) >>> y = keras_core.layers.UpSampling3D(size=(2, 2, 2))(x) @@ -109,7 +108,6 @@ class UpSampling3D(Layer): self, x, depth_factor, height_factor, width_factor, data_format ): """Resizes the volume contained in a 5D tensor. - Args: x: Tensor or variable to resize. depth_factor: Positive integer. @@ -118,7 +116,11 @@ class UpSampling3D(Layer): data_format: One of `"channels_first"`, `"channels_last"`. Returns: - Resized tensor. + A tensor. + + Raises: + ValueError: if `data_format` is neither + `channels_last` or `channels_first`. """ if data_format == "channels_first": output = ops.repeat(x, depth_factor, axis=2) @@ -131,4 +133,4 @@ class UpSampling3D(Layer): output = ops.repeat(output, width_factor, axis=3) return output else: - raise ValueError(f"Invalid data_format: {data_format}") + raise ValueError("Invalid data_format: " + str(data_format)) diff --git a/keras_core/operations/numpy_test.py b/keras_core/operations/numpy_test.py index 80f108951..ca929ba22 100644 --- a/keras_core/operations/numpy_test.py +++ b/keras_core/operations/numpy_test.py @@ -1665,13 +1665,19 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase): y3 = np.ones([1, 5, 4, 2]) self.assertAllClose(np.array(knp.cross(x1, y1)), np.cross(x1, y1)) self.assertAllClose(np.array(knp.cross(x1, y2)), np.cross(x1, y2)) - self.assertAllClose(np.array(knp.cross(x1, y3)), np.cross(x1, y3)) - self.assertAllClose(np.array(knp.cross(x2, y3)), np.cross(x2, y3)) + if backend.backend() != "torch": + # API divergence between `torch.cross` and `np.cross` + # `torch.cross` only allows dim 3, `np.cross` allows dim 2 or 3 + self.assertAllClose(np.array(knp.cross(x1, y3)), np.cross(x1, y3)) + self.assertAllClose(np.array(knp.cross(x2, y3)), np.cross(x2, y3)) self.assertAllClose(np.array(knp.Cross()(x1, y1)), np.cross(x1, y1)) self.assertAllClose(np.array(knp.Cross()(x1, y2)), np.cross(x1, y2)) - self.assertAllClose(np.array(knp.Cross()(x1, y3)), np.cross(x1, y3)) - self.assertAllClose(np.array(knp.Cross()(x2, y3)), np.cross(x2, y3)) + if backend.backend() != "torch": + # API divergence between `torch.cross` and `np.cross` + # `torch.cross` only allows dim 3, `np.cross` allows dim 2 or 3 + self.assertAllClose(np.array(knp.Cross()(x1, y3)), np.cross(x1, y3)) + self.assertAllClose(np.array(knp.Cross()(x2, y3)), np.cross(x2, y3)) def test_einsum(self): x = np.arange(24).reshape([2, 3, 4]) @@ -1711,25 +1717,36 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase): def test_full_like(self): x = np.array([[1, 2, 3], [3, 2, 1]]) self.assertAllClose(np.array(knp.full_like(x, 2)), np.full_like(x, 2)) - self.assertAllClose( - np.array(knp.full_like(x, np.ones([2, 3]))), - np.full_like(x, np.ones([2, 3])), - ) self.assertAllClose( np.array(knp.full_like(x, 2, dtype="float32")), np.full_like(x, 2, dtype="float32"), ) self.assertAllClose(np.array(knp.FullLike()(x, 2)), np.full_like(x, 2)) - self.assertAllClose( - np.array(knp.FullLike()(x, np.ones([2, 3]))), - np.full_like(x, np.ones([2, 3])), - ) self.assertAllClose( np.array(knp.FullLike()(x, 2, dtype="float32")), np.full_like(x, 2, dtype="float32"), ) + # TODO: implement conversion of shape into repetitions, pass to + # `torch.tile`, since `torch.full()` only accepts scalars + # for `fill_value`." + @pytest.mark.skipif( + backend.backend() == "torch", + reason="`torch.full` only accepts scalars for `fill_value`.", + ) + def test_full_like_without_torch(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + self.assertAllClose( + np.array(knp.full_like(x, np.ones([2, 3]))), + np.full_like(x, np.ones([2, 3])), + ) + + self.assertAllClose( + np.array(knp.FullLike()(x, np.ones([2, 3]))), + np.full_like(x, np.ones([2, 3])), + ) + def test_greater(self): x = np.array([[1, 2, 3], [3, 2, 1]]) y = np.array([[4, 5, 6], [3, 2, 1]]) @@ -1825,6 +1842,14 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase): np.linspace(0, 10, 5, endpoint=False), ) + # TODO: torch.linspace does not support tensor or array + # for start/stop, create manual implementation + @pytest.mark.skipif( + backend.backend() == "torch", + reason="`torch.linspace` has no support for array start/stop.", + ) + def test_linspace_without_torch(self): + start = np.zeros([2, 3, 4]) stop = np.ones([2, 3, 4]) self.assertAllClose( @@ -1929,6 +1954,13 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase): np.logspace(0, 10, 5, endpoint=False), ) + # TODO: torch.logspace does not support tensor or array + # for start/stop, create manual implementation + @pytest.mark.skipif( + backend.backend() == "torch", + reason="`torch.logspace` has no support for array start/stop.", + ) + def test_logspace_without_torch(self): start = np.zeros([2, 3, 4]) stop = np.ones([2, 3, 4]) self.assertAllClose( @@ -2004,6 +2036,11 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase): self.assertAllClose(np.array(knp.outer(x, y)), np.outer(x, y)) self.assertAllClose(np.array(knp.Outer()(x, y)), np.outer(x, y)) + # TODO: Fix numpy compatibility (squeeze by one dimension only) + @pytest.mark.skipif( + backend.backend() == "torch", + reason="`torch.take` and `np.take` have return shape divergence.", + ) def test_take(self): x = np.arange(24).reshape([1, 2, 3, 4]) indices = np.array([0, 1]) @@ -2767,6 +2804,23 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase): np.array(knp.pad(x, ((1, 1), (1, 1)))), np.pad(x, ((1, 1), (1, 1))), ) + + self.assertAllClose( + np.array(knp.Pad(((1, 1), (1, 1)))(x)), np.pad(x, ((1, 1), (1, 1))) + ) + self.assertAllClose( + np.array(knp.Pad(((1, 1), (1, 1)))(x)), + np.pad(x, ((1, 1), (1, 1))), + ) + + # TODO: implement padding with non-constant padding, + # bypass NotImplementedError for PyTorch + @pytest.mark.skipif( + backend.backend() == "torch", + reason="padding not implemented for non-constant use case", + ) + def test_pad_without_torch(self): + x = np.array([[1, 2], [3, 4]]) self.assertAllClose( np.array(knp.pad(x, ((1, 1), (1, 1)), mode="reflect")), np.pad(x, ((1, 1), (1, 1)), mode="reflect"), @@ -2776,13 +2830,6 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase): np.pad(x, ((1, 1), (1, 1)), mode="symmetric"), ) - self.assertAllClose( - np.array(knp.Pad(((1, 1), (1, 1)))(x)), np.pad(x, ((1, 1), (1, 1))) - ) - self.assertAllClose( - np.array(knp.Pad(((1, 1), (1, 1)))(x)), - np.pad(x, ((1, 1), (1, 1))), - ) self.assertAllClose( np.array(knp.Pad(((1, 1), (1, 1)), mode="reflect")(x)), np.pad(x, ((1, 1), (1, 1)), mode="reflect"), @@ -2905,6 +2952,12 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase): self.assertAllClose(np.array(knp.sort(x, axis=0)), np.sort(x, axis=0)) self.assertAllClose(np.array(knp.Sort(axis=0)(x)), np.sort(x, axis=0)) + # TODO: implement split for `torch` with support for conversion + # of numpy.split args. + @pytest.mark.skipif( + backend.backend() == "torch", + reason="`torch.split` and `np.split` have return arg divergence.", + ) def test_split(self): x = np.array([[1, 2, 3], [3, 2, 1]]) self.assertAllClose(np.array(knp.split(x, 2)), np.split(x, 2)) @@ -2972,7 +3025,13 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase): self.assertAllClose(np.array(knp.tile(x, [2, 3])), np.tile(x, [2, 3])) self.assertAllClose(np.array(knp.Tile([2, 3])(x)), np.tile(x, [2, 3])) + @pytest.mark.skipif( + backend.backend() == "torch", + reason="`torch.split` does not support args `offset`, `axis1`, `axis2`", + ) def test_trace(self): + # TODO: implement `torch.trace` support for arguments `offset`, + # `axis1`, `axis2` and delete NotImplementedError x = np.arange(24).reshape([1, 2, 3, 4]) self.assertAllClose(np.array(knp.trace(x)), np.trace(x)) self.assertAllClose( @@ -3012,7 +3071,14 @@ class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase): self.assertAllClose(np.array(knp.zeros([2, 3])), np.zeros([2, 3])) self.assertAllClose(np.array(knp.Zeros()([2, 3])), np.zeros([2, 3])) + @pytest.mark.skipif( + backend.backend() == "torch", + reason="`torch.eye` does not support arg `k`.", + ) def test_eye(self): + # TODO: implement support for `k` diagonal arg, + # does not exist in torch.eye() + self.assertAllClose(np.array(knp.eye(3)), np.eye(3)) self.assertAllClose(np.array(knp.eye(3, 4)), np.eye(3, 4)) self.assertAllClose(np.array(knp.eye(3, 4, 1)), np.eye(3, 4, 1)) @@ -3035,15 +3101,25 @@ class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase): self.assertAllClose( np.array(knp.full([2, 3], 0.1)), np.full([2, 3], 0.1) ) - self.assertAllClose( - np.array(knp.full([2, 3], np.array([1, 4, 5]))), - np.full([2, 3], np.array([1, 4, 5])), - ) self.assertAllClose(np.array(knp.Full()([2, 3], 0)), np.full([2, 3], 0)) self.assertAllClose( np.array(knp.Full()([2, 3], 0.1)), np.full([2, 3], 0.1) ) + + # TODO: implement conversion of shape into repetitions, pass to + # `torch.tile`, since `torch.full()` only accepts scalars + # for `fill_value`." + @pytest.mark.skipif( + backend.backend() == "torch", + reason="`torch.full` only accepts scalars for `fill_value`.", + ) + def test_full_without_torch(self): + self.assertAllClose( + np.array(knp.full([2, 3], np.array([1, 4, 5]))), + np.full([2, 3], np.array([1, 4, 5])), + ) + self.assertAllClose( np.array(knp.Full()([2, 3], np.array([1, 4, 5]))), np.full([2, 3], np.array([1, 4, 5])), @@ -3053,7 +3129,11 @@ class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase): self.assertAllClose(np.array(knp.identity(3)), np.identity(3)) self.assertAllClose(np.array(knp.Identity()(3)), np.identity(3)) + @pytest.mark.skipif( + backend.backend() == "torch", reason="No torch equivalent for `np.tri`" + ) def test_tri(self): + # TODO: create a manual implementation, as PyTorch has no equivalent self.assertAllClose(np.array(knp.tri(3)), np.tri(3)) self.assertAllClose(np.array(knp.tri(3, 4)), np.tri(3, 4)) self.assertAllClose(np.array(knp.tri(3, 4, 1)), np.tri(3, 4, 1)) diff --git a/keras_core/utils/feature_space.py b/keras_core/utils/feature_space.py deleted file mode 100644 index bf101aa04..000000000 --- a/keras_core/utils/feature_space.py +++ /dev/null @@ -1,757 +0,0 @@ -from tensorflow import data as tf_data - -from keras_core import layers -from keras_core import operations as ops -from keras_core.api_export import keras_core_export -from keras_core.layers.layer import Layer -from keras_core.saving import saving_lib -from keras_core.saving import serialization_lib -from keras_core.utils.naming import auto_name - - -class Cross: - def __init__(self, feature_names, crossing_dim, output_mode="one_hot"): - if output_mode not in {"int", "one_hot"}: - raise ValueError( - "Invalid value for argument `output_mode`. " - "Expected one of {'int', 'one_hot'}. " - f"Received: output_mode={output_mode}" - ) - self.feature_names = tuple(feature_names) - self.crossing_dim = crossing_dim - self.output_mode = output_mode - - @property - def name(self): - return "_X_".join(self.feature_names) - - def get_config(self): - return { - "feature_names": self.feature_names, - "crossing_dim": self.crossing_dim, - "output_mode": self.output_mode, - } - - @classmethod - def from_config(cls, config): - return cls(**config) - - -class Feature: - def __init__(self, dtype, preprocessor, output_mode): - if output_mode not in {"int", "one_hot", "float"}: - raise ValueError( - "Invalid value for argument `output_mode`. " - "Expected one of {'int', 'one_hot', 'float'}. " - f"Received: output_mode={output_mode}" - ) - self.dtype = dtype - if isinstance(preprocessor, dict): - preprocessor = serialization_lib.deserialize_keras_object( - preprocessor - ) - self.preprocessor = preprocessor - self.output_mode = output_mode - - def get_config(self): - return { - "dtype": self.dtype, - "preprocessor": serialization_lib.serialize_keras_object( - self.preprocessor - ), - "output_mode": self.output_mode, - } - - @classmethod - def from_config(cls, config): - return cls(**config) - - -@keras_core_export("keras_core.utils.FeatureSpace") -class FeatureSpace(Layer): - """One-stop utility for preprocessing and encoding structured data. - - Arguments: - feature_names: Dict mapping the names of your features to their - type specification, e.g. `{"my_feature": "integer_categorical"}` - or `{"my_feature": FeatureSpace.integer_categorical()}`. - For a complete list of all supported types, see - "Available feature types" paragraph below. - output_mode: One of `"concat"` or `"dict"`. In concat mode, all - features get concatenated together into a single vector. - In dict mode, the FeatureSpace returns a dict of individually - encoded features (with the same keys as the input dict keys). - crosses: List of features to be crossed together, e.g. - `crosses=[("feature_1", "feature_2")]`. The features will be - "crossed" by hashing their combined value into - a fixed-length vector. - crossing_dim: Default vector size for hashing crossed features. - Defaults to `32`. - hashing_dim: Default vector size for hashing features of type - `"integer_hashed"` and `"string_hashed"`. Defaults to `32`. - num_discretization_bins: Default number of bins to be used for - discretizing features of type `"float_discretized"`. - Defaults to `32`. - - **Available feature types:** - - Note that all features can be referred to by their string name, - e.g. `"integer_categorical"`. When using the string name, the default - argument values are used. - - ```python - # Plain float values. - FeatureSpace.float(name=None) - - # Float values to be preprocessed via featurewise standardization - # (i.e. via a `keras.layers.Normalization` layer). - FeatureSpace.float_normalized(name=None) - - # Float values to be preprocessed via linear rescaling - # (i.e. via a `keras.layers.Rescaling` layer). - FeatureSpace.float_rescaled(scale=1., offset=0., name=None) - - # Float values to be discretized. By default, the discrete - # representation will then be one-hot encoded. - FeatureSpace.float_discretized( - num_bins, bin_boundaries=None, output_mode="one_hot", name=None) - - # Integer values to be indexed. By default, the discrete - # representation will then be one-hot encoded. - FeatureSpace.integer_categorical( - max_tokens=None, num_oov_indices=1, output_mode="one_hot", name=None) - - # String values to be indexed. By default, the discrete - # representation will then be one-hot encoded. - FeatureSpace.string_categorical( - max_tokens=None, num_oov_indices=1, output_mode="one_hot", name=None) - - # Integer values to be hashed into a fixed number of bins. - # By default, the discrete representation will then be one-hot encoded. - FeatureSpace.integer_hashed(num_bins, output_mode="one_hot", name=None) - - # String values to be hashed into a fixed number of bins. - # By default, the discrete representation will then be one-hot encoded. - FeatureSpace.string_hashed(num_bins, output_mode="one_hot", name=None) - ``` - - Examples: - - **Basic usage with a dict of input data:** - - ```python - raw_data = { - "float_values": [0.0, 0.1, 0.2, 0.3], - "string_values": ["zero", "one", "two", "three"], - "int_values": [0, 1, 2, 3], - } - dataset = tf.data.Dataset.from_tensor_slices(raw_data) - - feature_space = FeatureSpace( - features={ - "float_values": "float_normalized", - "string_values": "string_categorical", - "int_values": "integer_categorical", - }, - crosses=[("string_values", "int_values")], - output_mode="concat", - ) - # Before you start using the FeatureSpace, - # you must `adapt()` it on some data. - feature_space.adapt(dataset) - - # You can call the FeatureSpace on a dict of data (batched or unbatched). - output_vector = feature_space(raw_data) - ``` - - **Basic usage with `tf.data`:** - - ```python - # Unlabeled data - preprocessed_ds = unlabeled_dataset.map(feature_space) - - # Labeled data - preprocessed_ds = labeled_dataset.map(lambda x, y: (feature_space(x), y)) - ``` - - **Basic usage with the Keras Functional API:** - - ```python - # Retrieve a dict Keras Input objects - inputs = feature_space.get_inputs() - # Retrieve the corresponding encoded Keras tensors - encoded_features = feature_space.get_encoded_features() - # Build a Functional model - outputs = keras.layers.Dense(1, activation="sigmoid")(encoded_features) - model = keras.Model(inputs, outputs) - ``` - - **Customizing each feature or feature cross:** - - ```python - feature_space = FeatureSpace( - features={ - "float_values": FeatureSpace.float_normalized(), - "string_values": FeatureSpace.string_categorical(max_tokens=10), - "int_values": FeatureSpace.integer_categorical(max_tokens=10), - }, - crosses=[ - FeatureSpace.cross(("string_values", "int_values"), crossing_dim=32) - ], - output_mode="concat", - ) - ``` - - **Returning a dict of integer-encoded features:** - - ```python - feature_space = FeatureSpace( - features={ - "string_values": FeatureSpace.string_categorical(output_mode="int"), - "int_values": FeatureSpace.integer_categorical(output_mode="int"), - }, - crosses=[ - FeatureSpace.cross( - feature_names=("string_values", "int_values"), - crossing_dim=32, - output_mode="int", - ) - ], - output_mode="dict", - ) - ``` - - **Specifying your own Keras preprocessing layer:** - - ```python - # Let's say that one of the features is a short text paragraph that - # we want to encode as a vector (one vector per paragraph) via TF-IDF. - data = { - "text": ["1st string", "2nd string", "3rd string"], - } - - # There's a Keras layer for this: TextVectorization. - custom_layer = layers.TextVectorization(output_mode="tf_idf") - - # We can use FeatureSpace.feature to create a custom feature - # that will use our preprocessing layer. - feature_space = FeatureSpace( - features={ - "text": FeatureSpace.feature( - preprocessor=custom_layer, dtype="string", output_mode="float" - ), - }, - output_mode="concat", - ) - feature_space.adapt(tf.data.Dataset.from_tensor_slices(data)) - output_vector = feature_space(data) - ``` - - **Retrieving the underlying Keras preprocessing layers:** - - ```python - # The preprocessing layer of each feature is available in `.preprocessors`. - preprocessing_layer = feature_space.preprocessors["feature1"] - - # The crossing layer of each feature cross is available in `.crossers`. - # It's an instance of keras.layers.HashedCrossing. - crossing_layer = feature_space.crossers["feature1_X_feature2"] - ``` - - **Saving and reloading a FeatureSpace:** - - ```python - feature_space.save("myfeaturespace.keras") - reloaded_feature_space = keras.models.load_model("myfeaturespace.keras") - ``` - """ - - @classmethod - def cross(cls, feature_names, crossing_dim, output_mode="one_hot"): - return Cross(feature_names, crossing_dim, output_mode=output_mode) - - @classmethod - def feature(cls, dtype, preprocessor, output_mode): - return Feature(dtype, preprocessor, output_mode) - - @classmethod - def float(cls, name=None): - from keras.layers.core import identity - - name = name or auto_name("float") - preprocessor = identity.Identity( - dtype="float32", name=f"{name}_preprocessor" - ) - return Feature( - dtype="float32", preprocessor=preprocessor, output_mode="float" - ) - - @classmethod - def float_rescaled(cls, scale=1.0, offset=0.0, name=None): - name = name or auto_name("float_rescaled") - preprocessor = layers.Rescaling( - scale=scale, offset=offset, name=f"{name}_preprocessor" - ) - return Feature( - dtype="float32", preprocessor=preprocessor, output_mode="float" - ) - - @classmethod - def float_normalized(cls, name=None): - name = name or auto_name("float_normalized") - preprocessor = layers.Normalization( - axis=-1, name=f"{name}_preprocessor" - ) - return Feature( - dtype="float32", preprocessor=preprocessor, output_mode="float" - ) - - @classmethod - def float_discretized( - cls, num_bins, bin_boundaries=None, output_mode="one_hot", name=None - ): - name = name or auto_name("float_discretized") - preprocessor = layers.Discretization( - num_bins=num_bins, - bin_boundaries=bin_boundaries, - name=f"{name}_preprocessor", - ) - return Feature( - dtype="float32", preprocessor=preprocessor, output_mode=output_mode - ) - - @classmethod - def integer_categorical( - cls, - max_tokens=None, - num_oov_indices=1, - output_mode="one_hot", - name=None, - ): - name = name or auto_name("integer_categorical") - preprocessor = layers.IntegerLookup( - name=f"{name}_preprocessor", - max_tokens=max_tokens, - num_oov_indices=num_oov_indices, - ) - return Feature( - dtype="int64", preprocessor=preprocessor, output_mode=output_mode - ) - - @classmethod - def string_categorical( - cls, - max_tokens=None, - num_oov_indices=1, - output_mode="one_hot", - name=None, - ): - name = name or auto_name("string_categorical") - preprocessor = layers.StringLookup( - name=f"{name}_preprocessor", - max_tokens=max_tokens, - num_oov_indices=num_oov_indices, - ) - return Feature( - dtype="string", preprocessor=preprocessor, output_mode=output_mode - ) - - @classmethod - def string_hashed(cls, num_bins, output_mode="one_hot", name=None): - name = name or auto_name("string_hashed") - preprocessor = layers.Hashing( - name=f"{name}_preprocessor", num_bins=num_bins - ) - return Feature( - dtype="string", preprocessor=preprocessor, output_mode=output_mode - ) - - @classmethod - def integer_hashed(cls, num_bins, output_mode="one_hot", name=None): - name = name or auto_name("integer_hashed") - preprocessor = layers.Hashing( - name=f"{name}_preprocessor", num_bins=num_bins - ) - return Feature( - dtype="int64", preprocessor=preprocessor, output_mode=output_mode - ) - - def __init__( - self, - features, - output_mode="concat", - crosses=None, - crossing_dim=32, - hashing_dim=32, - num_discretization_bins=32, - name=None - ): - super().__init__(name=name) - if not features: - raise ValueError("The `features` argument cannot be None or empty.") - self.crossing_dim = crossing_dim - self.hashing_dim = hashing_dim - self.num_discretization_bins = num_discretization_bins - self.features = { - name: self._standardize_feature(name, value) - for name, value in features.items() - } - self.crosses = [] - if crosses: - feature_set = set(features.keys()) - for cross in crosses: - if isinstance(cross, dict): - cross = serialization_lib.deserialize_keras_object(cross) - if isinstance(cross, Cross): - self.crosses.append(cross) - else: - if not crossing_dim: - raise ValueError( - "When specifying `crosses`, the argument " - "`crossing_dim` " - "(dimensionality of the crossing space) " - "should be specified as well." - ) - for key in cross: - if key not in feature_set: - raise ValueError( - "All features referenced " - "in the `crosses` argument " - "should be present in the `features` dict. " - f"Received unknown features: {cross}" - ) - self.crosses.append(Cross(cross, crossing_dim=crossing_dim)) - self.crosses_by_name = {cross.name: cross for cross in self.crosses} - - if output_mode not in {"dict", "concat"}: - raise ValueError( - "Invalid value for argument `output_mode`. " - "Expected one of {'dict', 'concat'}. " - f"Received: output_mode={output_mode}" - ) - self.output_mode = output_mode - - self.inputs = { - name: self._feature_to_input(name, value) - for name, value in self.features.items() - } - self.preprocessors = { - name: value.preprocessor for name, value in self.features.items() - } - self.encoded_features = None - self.crossers = { - cross.name: self._cross_to_crosser(cross) for cross in self.crosses - } - self.one_hot_encoders = {} - self.built = False - self._is_adapted = False - self.concat = None - self._preprocessed_features_names = None - self._crossed_features_names = None - - def _feature_to_input(self, name, feature): - return layers.Input(shape=(1,), dtype=feature.dtype, name=name) - - def _standardize_feature(self, name, feature): - if isinstance(feature, Feature): - return feature - - if isinstance(feature, dict): - return serialization_lib.deserialize_keras_object(feature) - - if feature == "float": - return self.float(name=name) - elif feature == "float_normalized": - return self.float_normalized(name=name) - elif feature == "float_rescaled": - return self.float_rescaled(name=name) - elif feature == "float_discretized": - return self.float_discretized( - name=name, num_bins=self.num_discretization_bins - ) - elif feature == "integer_categorical": - return self.integer_categorical(name=name) - elif feature == "string_categorical": - return self.string_categorical(name=name) - elif feature == "integer_hashed": - return self.integer_hashed(self.hashing_dim, name=name) - elif feature == "string_hashed": - return self.string_hashed(self.hashing_dim, name=name) - else: - raise ValueError(f"Invalid feature type: {feature}") - - def _cross_to_crosser(self, cross): - return layers.HashedCrossing(cross.crossing_dim, name=cross.name) - - def _list_adaptable_preprocessors(self): - adaptable_preprocessors = [] - for name in self.features.keys(): - preprocessor = self.preprocessors[name] - # Special case: a Normalization layer with preset mean/variance. - # Not adaptable. - if isinstance(preprocessor, layers.Normalization): - if preprocessor.input_mean is not None: - continue - if hasattr(preprocessor, "adapt"): - adaptable_preprocessors.append(name) - return adaptable_preprocessors - - def adapt(self, dataset): - if not isinstance(dataset, tf_data.Dataset): - raise ValueError( - "`adapt()` can only be called on a tf.data.Dataset. " - f"Received instead: {dataset} (of type {type(dataset)})" - ) - - for name in self._list_adaptable_preprocessors(): - # Call adapt() on each individual adaptable layer. - - # TODO: consider rewriting this to instead iterate on the - # dataset once, split each batch into individual features, - # and call the layer's `_adapt_function` on each batch - # to simulate the behavior of adapt() in a more performant fashion. - - feature_dataset = dataset.map(lambda x: x[name]) - preprocessor = self.preprocessors[name] - # TODO: consider adding an adapt progress bar. - # Sample 1 element to check the rank - for x in feature_dataset.take(1): - pass - if x.shape.rank == 0: - # The dataset yields unbatched scalars; batch it. - feature_dataset = feature_dataset.batch(32) - if x.shape.rank in {0, 1}: - # If the rank is 1, add a dimension - # so we can reduce on axis=-1. - # Note: if rank was previously 0, it is now 1. - feature_dataset = feature_dataset.map( - lambda x: ops.expand_dims(x, -1) - ) - preprocessor.adapt(feature_dataset) - self._is_adapted = True - self.get_encoded_features() # Finish building the layer - self.built = True - - def get_inputs(self): - self._check_if_built() - return self.inputs - - def get_encoded_features(self): - self._check_if_adapted() - - if self.encoded_features is None: - preprocessed_features = self._preprocess_features(self.inputs) - crossed_features = self._cross_features(preprocessed_features) - merged_features = self._merge_features( - preprocessed_features, crossed_features - ) - self.encoded_features = merged_features - return self.encoded_features - - def _preprocess_features(self, features): - return { - name: self.preprocessors[name](features[name]) - for name in features.keys() - } - - def _cross_features(self, features): - all_outputs = {} - for cross in self.crosses: - inputs = [features[name] for name in cross.feature_names] - outputs = self.crossers[cross.name](inputs) - all_outputs[cross.name] = outputs - return all_outputs - - def _merge_features(self, preprocessed_features, crossed_features): - if not self._preprocessed_features_names: - self._preprocessed_features_names = sorted( - preprocessed_features.keys() - ) - self._crossed_features_names = sorted(crossed_features.keys()) - - all_names = ( - self._preprocessed_features_names + self._crossed_features_names - ) - all_features = [ - preprocessed_features[name] - for name in self._preprocessed_features_names - ] + [crossed_features[name] for name in self._crossed_features_names] - - if self.output_mode == "dict": - output_dict = {} - else: - features_to_concat = [] - - if self.built: - # Fast mode. - for name, feature in zip(all_names, all_features): - encoder = self.one_hot_encoders.get(name, None) - if encoder: - feature = encoder(feature) - if self.output_mode == "dict": - output_dict[name] = feature - else: - features_to_concat.append(feature) - if self.output_mode == "dict": - return output_dict - else: - return self.concat(features_to_concat) - - # If the object isn't built, - # we create the encoder and concat layers below - all_specs = [ - self.features[name] for name in self._preprocessed_features_names - ] + [ - self.crosses_by_name[name] for name in self._crossed_features_names - ] - for name, feature, spec in zip(all_names, all_features, all_specs): - dtype = feature.dtype - - if spec.output_mode == "one_hot": - preprocessor = self.preprocessors.get( - name - ) or self.crossers.get(name) - cardinality = None - if not feature.dtype.startswith("int"): - raise ValueError( - f"Feature '{name}' has `output_mode='one_hot'`. " - "Thus its preprocessor should return an int64 dtype. " - f"Instead it returns a {dtype} dtype." - ) - - if isinstance( - preprocessor, (layers.IntegerLookup, layers.StringLookup) - ): - cardinality = preprocessor.vocabulary_size() - elif isinstance(preprocessor, layers.CategoryEncoding): - cardinality = preprocessor.num_tokens - elif isinstance(preprocessor, layers.Discretization): - cardinality = preprocessor.num_bins - elif isinstance( - preprocessor, (layers.HashedCrossing, layers.Hashing) - ): - cardinality = preprocessor.num_bins - else: - raise ValueError( - f"Feature '{name}' has `output_mode='one_hot'`. " - "However it isn't a standard feature and the " - "dimensionality of its output space is not known, " - "thus it cannot be one-hot encoded. " - "Try using `output_mode='int'`." - ) - if cardinality is not None: - encoder = layers.CategoryEncoding( - num_tokens=cardinality, output_mode="multi_hot" - ) - self.one_hot_encoders[name] = encoder - feature = encoder(feature) - - if self.output_mode == "concat": - dtype = feature.dtype - if dtype.startswith("int") or dtype == "string": - raise ValueError( - f"Cannot concatenate features because feature '{name}' " - f"has not been encoded (it has dtype {dtype}). " - "Consider using `output_mode='dict'`." - ) - features_to_concat.append(feature) - else: - output_dict[name] = feature - - if self.output_mode == "concat": - self.concat = layers.Concatenate(axis=-1) - return self.concat(features_to_concat) - else: - return output_dict - - def _check_if_adapted(self): - if not self._is_adapted: - if not self._list_adaptable_preprocessors(): - self._is_adapted = True - else: - raise ValueError( - "You need to call `.adapt(dataset)` on the FeatureSpace " - "before you can start using it." - ) - - def _check_if_built(self): - if not self.built: - self._check_if_adapted() - # Finishes building - self.get_encoded_features() - self.built = True - - def __call__(self, data): - self._check_if_built() - if not isinstance(data, dict): - raise ValueError( - "A FeatureSpace can only be called with a dict. " - f"Received: data={data} (of type {type(data)}" - ) - - data = { - key: ops.convert_to_tensor(value) for key, value in data.items() - } - rebatched = False - for name, x in data.items(): - if x.shape.rank == 0: - data[name] = ops.reshape(x, (1, 1)) - rebatched = True - elif x.shape.rank == 1: - data[name] = ops.expand_dims(x, -1) - - preprocessed_data = self._preprocess_features(data) - crossed_data = self._cross_features(preprocessed_data) - merged_data = self._merge_features(preprocessed_data, crossed_data) - if rebatched: - if self.output_mode == "concat": - assert merged_data.shape[0] == 1 - return ops.squeeze(merged_data, axis=0) - else: - for name, x in merged_data.items(): - if x.shape.rank == 2 and x.shape[0] == 1: - merged_data[name] = ops.squeeze(x, axis=0) - return merged_data - - def get_config(self): - return { - "features": serialization_lib.serialize_keras_object(self.features), - "output_mode": self.output_mode, - "crosses": serialization_lib.serialize_keras_object(self.crosses), - "crossing_dim": self.crossing_dim, - "hashing_dim": self.hashing_dim, - "num_discretization_bins": self.num_discretization_bins, - } - - @classmethod - def from_config(cls, config): - return cls(**config) - - def get_build_config(self): - return { - name: feature.preprocessor.get_build_config() - for name, feature in self.features.items() - } - - def build_from_config(self, config): - for name in config.keys(): - self.features[name].preprocessor.build_from_config(config[name]) - self._is_adapted = True - - def save(self, filepath): - """Save the `FeatureSpace` instance to a `.keras` file. - - You can reload it via `keras.models.load_model()`: - - ```python - feature_space.save("myfeaturespace.keras") - reloaded_feature_space = keras.models.load_model("myfeaturespace.keras") - ``` - """ - saving_lib.save_model(self, filepath) - - def save_own_variables(self, store): - return - - def load_own_variables(self, store): - return diff --git a/keras_core/utils/feature_space_test.py b/keras_core/utils/feature_space_test.py deleted file mode 100644 index 3083ff14d..000000000 --- a/keras_core/utils/feature_space_test.py +++ /dev/null @@ -1,378 +0,0 @@ -# from keras_core import testing -# from keras_core.utils import feature_space -# from keras_core import operations as ops -# from tensorflow import nest -# from tensorflow import data as tf_data -# from keras_core import layers -# from keras_core import models -# import os - - -# class FeatureSpaceTest(testing.TestCase): -# def _get_train_data_dict( -# self, as_dataset=False, as_tf_tensors=False, as_labeled_dataset=False -# ): -# data = { -# "float_1": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], -# "float_2": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], -# "float_3": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], -# "string_1": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"], -# "string_2": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"], -# "int_1": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], -# "int_2": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], -# "int_3": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], -# } -# if as_dataset: -# return tf_data.Dataset.from_tensor_slices(data) -# elif as_tf_tensors: -# return nest.map_structure(ops.convert_to_tensor, data) -# elif as_labeled_dataset: -# labels = [0, 1, 0, 1, 0, 0, 1, 0, 1, 1] -# return tf_data.Dataset.from_tensor_slices((data, labels)) -# return data - -# def test_basic_usage(self): -# fs = feature_space.FeatureSpace( -# features={ -# "float_1": "float", -# "float_2": "float_normalized", -# "float_3": "float_discretized", -# "string_1": "string_categorical", -# "string_2": "string_hashed", -# "int_1": "integer_categorical", -# "int_2": "integer_hashed", -# "int_3": "integer_categorical", -# }, -# crosses=[("float_3", "string_1"), ("string_2", "int_2")], -# output_mode="concat", -# ) -# # Test unbatched adapt -# fs.adapt(self._get_train_data_dict(as_dataset=True)) -# # Test batched adapt -# fs.adapt(self._get_train_data_dict(as_dataset=True).batch(4)) - -# # Test unbatched call on raw data -# data = { -# key: value[0] for key, value in self._get_train_data_dict().items() -# } -# out = fs(data) -# self.assertEqual(out.shape, [195]) - -# # Test unbatched call on TF tensors -# data = self._get_train_data_dict(as_tf_tensors=True) -# data = {key: value[0] for key, value in data.items()} -# out = fs(data) -# self.assertEqual(out.shape, [195]) - -# # Test batched call on raw data -# out = fs(self._get_train_data_dict()) -# self.assertEqual(out.shape, [10, 195]) - -# # Test batched call on TF tensors -# out = fs(self._get_train_data_dict(as_tf_tensors=True)) -# self.assertEqual(out.shape, [10, 195]) - -# def test_output_mode_dict(self): -# fs = feature_space.FeatureSpace( -# features={ -# "float_1": "float", -# "float_2": "float_normalized", -# "float_3": "float_discretized", -# "string_1": "string_categorical", -# "string_2": "string_hashed", -# "int_1": "integer_categorical", -# "int_2": "integer_hashed", -# "int_3": "integer_categorical", -# }, -# crosses=[("float_3", "string_1"), ("string_2", "int_2")], -# output_mode="dict", -# ) -# fs.adapt(self._get_train_data_dict(as_dataset=True)) - -# # Test unbatched call on raw data -# data = { -# key: value[0] for key, value in self._get_train_data_dict().items() -# } -# out = fs(data) -# self.assertIsInstance(out, dict) -# self.assertLen(out, 10) -# self.assertEqual(out["string_1"].shape, [11]) -# self.assertEqual(out["int_2"].shape, [32]) -# self.assertEqual(out["string_2_X_int_2"].shape, [32]) - -# # Test batched call on raw data -# out = fs(self._get_train_data_dict()) -# self.assertIsInstance(out, dict) -# self.assertLen(out, 10) -# self.assertEqual(out["string_1"].shape, [10, 11]) -# self.assertEqual(out["int_2"].shape, [10, 32]) -# self.assertEqual(out["string_2_X_int_2"].shape, [10, 32]) - -# # Test batched call on TF tensors -# out = fs(self._get_train_data_dict(as_tf_tensors=True)) -# self.assertIsInstance(out, dict) -# self.assertLen(out, 10) -# self.assertEqual(out["string_1"].shape, [10, 11]) -# self.assertEqual(out["int_2"].shape, [10, 32]) -# self.assertEqual(out["string_2_X_int_2"].shape, [10, 32]) - -# def test_output_mode_dict_of_ints(self): -# cls = feature_space.FeatureSpace -# fs = feature_space.FeatureSpace( -# features={ -# "float_1": "float", -# "float_2": "float_normalized", -# "float_3": "float_discretized", -# "string_1": cls.string_categorical(output_mode="int"), -# "string_2": cls.string_hashed(num_bins=32, output_mode="int"), -# "int_1": cls.integer_categorical(output_mode="int"), -# "int_2": cls.integer_hashed(num_bins=32, output_mode="int"), -# "int_3": cls.integer_categorical(output_mode="int"), -# }, -# crosses=[ -# cls.cross( -# ("float_3", "string_1"), output_mode="int", crossing_dim=32 -# ), -# cls.cross( -# ("string_2", "int_2"), output_mode="int", crossing_dim=32 -# ), -# ], -# output_mode="dict", -# ) -# fs.adapt(self._get_train_data_dict(as_dataset=True)) -# data = { -# key: value[0] for key, value in self._get_train_data_dict().items() -# } -# out = fs(data) -# self.assertIsInstance(out, dict) -# self.assertLen(out, 10) -# self.assertEqual(out["string_1"].shape, [1]) -# self.assertEqual(out["string_1"].dtype.name, "int64") -# self.assertEqual(out["int_2"].shape, [1]) -# self.assertEqual(out["int_2"].dtype.name, "int64") -# self.assertEqual(out["string_2_X_int_2"].shape, [1]) -# self.assertEqual(out["string_2_X_int_2"].dtype.name, "int64") - -# def test_functional_api_sync_processing(self): -# fs = feature_space.FeatureSpace( -# features={ -# "float_1": "float", -# "float_2": "float_normalized", -# "float_3": "float_discretized", -# "string_1": "string_categorical", -# "string_2": "string_hashed", -# "int_1": "integer_categorical", -# "int_2": "integer_hashed", -# "int_3": "integer_categorical", -# }, -# crosses=[("float_3", "string_1"), ("string_2", "int_2")], -# output_mode="concat", -# ) -# fs.adapt(self._get_train_data_dict(as_dataset=True)) -# inputs = fs.get_inputs() -# features = fs.get_encoded_features() -# outputs = layers.Dense(1)(features) -# model = models.Model(inputs=inputs, outputs=outputs) -# model.compile("adam", "mse") -# ds = self._get_train_data_dict(as_labeled_dataset=True) -# model.fit(ds.batch(4)) -# model.evaluate(ds.batch(4)) -# ds = self._get_train_data_dict(as_dataset=True) -# model.predict(ds.batch(4)) - -# def test_tf_data_async_processing(self): -# fs = feature_space.FeatureSpace( -# features={ -# "float_1": "float", -# "float_2": "float_normalized", -# "float_3": "float_discretized", -# "string_1": "string_categorical", -# "string_2": "string_hashed", -# "int_1": "integer_categorical", -# "int_2": "integer_hashed", -# "int_3": "integer_categorical", -# }, -# crosses=[("float_3", "string_1"), ("string_2", "int_2")], -# output_mode="concat", -# ) -# fs.adapt(self._get_train_data_dict(as_dataset=True)) -# features = fs.get_encoded_features() -# outputs = layers.Dense(1)(features) -# model = models.Model(inputs=features, outputs=outputs) -# model.compile("adam", "mse") -# ds = self._get_train_data_dict(as_labeled_dataset=True) -# # Try map before batch -# ds = ds.map(lambda x, y: (fs(x), y)) -# model.fit(ds.batch(4)) -# # Try map after batch -# ds = self._get_train_data_dict(as_labeled_dataset=True) -# ds = ds.batch(4) -# ds = ds.map(lambda x, y: (fs(x), y)) -# model.evaluate(ds) -# ds = self._get_train_data_dict(as_dataset=True) -# ds = ds.map(fs) -# model.predict(ds.batch(4)) - -# def test_advanced_usage(self): -# cls = feature_space.FeatureSpace -# fs = feature_space.FeatureSpace( -# features={ -# "float_1": cls.float(), -# "float_2": cls.float_normalized(), -# "float_3": cls.float_discretized(num_bins=3), -# "string_1": cls.string_categorical(max_tokens=5), -# "string_2": cls.string_hashed(num_bins=32), -# "int_1": cls.integer_categorical( -# max_tokens=5, num_oov_indices=2 -# ), -# "int_2": cls.integer_hashed(num_bins=32), -# "int_3": cls.integer_categorical(max_tokens=5), -# }, -# crosses=[ -# cls.cross(("float_3", "string_1"), crossing_dim=32), -# cls.cross(("string_2", "int_2"), crossing_dim=32), -# ], -# output_mode="concat", -# ) -# fs.adapt(self._get_train_data_dict(as_dataset=True)) -# data = { -# key: value[0] for key, value in self._get_train_data_dict().items() -# } -# out = fs(data) -# self.assertEqual(out.shape, [148]) - -# def test_manual_kpl(self): -# data = { -# "text": ["1st string", "2nd string", "3rd string"], -# } -# cls = feature_space.FeatureSpace - -# # Test with a tf-idf TextVectorization layer -# tv = layers.TextVectorization(output_mode="tf_idf") -# fs = feature_space.FeatureSpace( -# features={ -# "text": cls.feature( -# preprocessor=tv, dtype="string", output_mode="float" -# ), -# }, -# output_mode="concat", -# ) -# fs.adapt(tf_data.Dataset.from_tensor_slices(data)) -# out = fs(data) -# self.assertEqual(out.shape, [3, 5]) - -# def test_no_adapt(self): -# data = { -# "int_1": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], -# } -# fs = feature_space.FeatureSpace( -# { -# "int_1": "integer_hashed", -# }, -# output_mode="concat", -# ) -# out = fs(data) -# self.assertEqual(out.shape, [10, 32]) - -# def test_saving(self): -# cls = feature_space.FeatureSpace -# fs = feature_space.FeatureSpace( -# features={ -# "float_1": cls.float(), -# "float_2": cls.float_normalized(), -# "float_3": cls.float_discretized(num_bins=3), -# "string_1": cls.string_categorical(max_tokens=5), -# "string_2": cls.string_hashed(num_bins=32), -# "int_1": cls.integer_categorical( -# max_tokens=5, num_oov_indices=2 -# ), -# "int_2": cls.integer_hashed(num_bins=32), -# "int_3": cls.integer_categorical(max_tokens=5), -# }, -# crosses=[ -# cls.cross(("float_3", "string_1"), crossing_dim=32), -# cls.cross(("string_2", "int_2"), crossing_dim=32), -# ], -# output_mode="concat", -# ) -# fs.adapt(self._get_train_data_dict(as_dataset=True)) -# data = { -# key: value[0] for key, value in self._get_train_data_dict().items() -# } -# ref_out = fs(data) - -# temp_filepath = os.path.join(self.get_temp_dir(), "fs.keras") -# fs.save(temp_filepath) -# fs = models.models.load_model(temp_filepath) - -# # Save again immediately after loading to test idempotency -# temp_filepath = os.path.join(self.get_temp_dir(), "fs2.keras") -# fs.save(temp_filepath) - -# # Test correctness of the first saved FS -# out = fs(data) -# self.assertAllClose(out, ref_out) - -# inputs = fs.get_inputs() -# outputs = fs.get_encoded_features() -# model = models.Model(inputs=inputs, outputs=outputs) -# ds = self._get_train_data_dict(as_dataset=True) -# out = model.predict(ds.batch(4)) -# self.assertAllClose(out[0], ref_out) - -# # Test correctness of the re-saved FS -# fs = models.models.load_model(temp_filepath) -# out = fs(data) -# self.assertAllClose(out, ref_out) - -# def test_errors(self): -# # Test no features -# with self.assertRaisesRegex(ValueError, "cannot be None or empty"): -# feature_space.FeatureSpace(features={}) -# # Test no crossing dim -# with self.assertRaisesRegex(ValueError, "`crossing_dim`"): -# feature_space.FeatureSpace( -# features={ -# "f1": "integer_categorical", -# "f2": "integer_categorical", -# }, -# crosses=[("f1", "f2")], -# crossing_dim=None, -# ) -# # Test wrong cross feature name -# with self.assertRaisesRegex(ValueError, "should be present in "): -# feature_space.FeatureSpace( -# features={ -# "f1": "integer_categorical", -# "f2": "integer_categorical", -# }, -# crosses=[("f1", "unknown")], -# crossing_dim=32, -# ) -# # Test wrong output mode -# with self.assertRaisesRegex(ValueError, "for argument `output_mode`"): -# feature_space.FeatureSpace( -# features={ -# "f1": "integer_categorical", -# "f2": "integer_categorical", -# }, -# output_mode="unknown", -# ) -# # Test call before adapt -# with self.assertRaisesRegex(ValueError, "You need to call `.adapt"): -# fs = feature_space.FeatureSpace( -# features={ -# "f1": "integer_categorical", -# "f2": "integer_categorical", -# } -# ) -# fs({"f1": [0], "f2": [0]}) -# # Test get_encoded_features before adapt -# with self.assertRaisesRegex(ValueError, "You need to call `.adapt"): -# fs = feature_space.FeatureSpace( -# features={ -# "f1": "integer_categorical", -# "f2": "integer_categorical", -# } -# ) -# fs.get_encoded_features() \ No newline at end of file