diff --git a/examples/demo_custom_layer_backend_agnostic.py b/examples/demo_custom_layer_backend_agnostic.py index 8083dbcac..5a13cd749 100644 --- a/examples/demo_custom_layer_backend_agnostic.py +++ b/examples/demo_custom_layer_backend_agnostic.py @@ -26,7 +26,7 @@ class MyDense(layers.Layer): # You can also use add_weight self.b = self.add_weight( shape=(self.units,), - initializer="zeros", + initializer=initializers.Zeros(), name="bias", trainable=True, ) diff --git a/examples/demo_functional.py b/examples/demo_functional.py index ba607fd0c..e93e5b0fb 100644 --- a/examples/demo_functional.py +++ b/examples/demo_functional.py @@ -6,7 +6,7 @@ from keras_core import losses from keras_core import metrics from keras_core import optimizers -inputs = layers.Input((100,), batch_size=32) +inputs = layers.Input((100,)) x = layers.Dense(256, activation="relu")(inputs) residual = x x = layers.Dense(256, activation="relu")(x) @@ -27,9 +27,18 @@ model.compile( loss=losses.MeanSquaredError(), metrics=[metrics.CategoricalAccuracy(name="acc"), metrics.MeanSquaredError(name="mse")], ) + +print("\nTrain model") history = model.fit( x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2 ) - -print("History:") +print("\nHistory:") print(history.history) + +print("\nEvaluate model") +scores = model.evaluate(x, y, return_dict=True) +print(scores) + +print("\nRun inference") +pred = model.predict(x) +print(f"Inferred output shape {pred.shape}") diff --git a/keras_core/backend/common/keras_tensor.py b/keras_core/backend/common/keras_tensor.py index 7d97c6c37..f1a8c68a0 100644 --- a/keras_core/backend/common/keras_tensor.py +++ b/keras_core/backend/common/keras_tensor.py @@ -9,10 +9,11 @@ class KerasTensor: def __init__(self, shape, dtype="float32", record_history=True, name=None): from keras_core import backend - if backend.DYNAMIC_SHAPES_OK: - shape = backend.standardize_shape(shape, fully_defined=False) - else: - shape = backend.standardize_shape(shape, fully_defined=True) + shape = backend.standardize_shape( + shape, + allow_dynamic_batch_size=backend.DYNAMIC_BATCH_SIZE_OK, + allow_all_dynamic=backend.DYNAMIC_SHAPES_OK, + ) self.shape = shape self.dtype = backend.standardize_dtype(dtype) self.name = name or auto_name(self.__class__.__name__) diff --git a/keras_core/backend/common/stateless_scope.py b/keras_core/backend/common/stateless_scope.py index 4279ba70a..ec1c13747 100644 --- a/keras_core/backend/common/stateless_scope.py +++ b/keras_core/backend/common/stateless_scope.py @@ -1,13 +1,12 @@ from keras_core.api_export import keras_core_export from keras_core.backend.common import global_state -from keras_core.backend.common.variables import KerasVariable -from keras_core.backend.common.variables import initialize_all_variables @keras_core_export("keras_core.StatelessScope") class StatelessScope: def __init__(self, state_mapping=None, collect_losses=False): from keras_core import backend + from keras_core.backend.common.variables import KerasVariable self.collect_losses = collect_losses self.losses = [] @@ -54,6 +53,10 @@ class StatelessScope: # We're back in eager scope; # if any variables were created within the stateless # scope, we initialize them here. 
+ from keras_core.backend.common.variables import ( + initialize_all_variables, + ) + initialize_all_variables() diff --git a/keras_core/backend/common/variables.py b/keras_core/backend/common/variables.py index e3de325f3..0b54ad2fd 100644 --- a/keras_core/backend/common/variables.py +++ b/keras_core/backend/common/variables.py @@ -1,5 +1,9 @@ +import numpy as np + from keras_core.backend import config from keras_core.backend.common import global_state +from keras_core.backend.common.stateless_scope import get_stateless_scope +from keras_core.backend.common.stateless_scope import in_stateless_scope from keras_core.utils.naming import auto_name @@ -20,7 +24,6 @@ class KerasVariable: f"Received: initializer={initializer} " f"and shape={shape}" ) - from keras_core.backend.common.stateless_scope import in_stateless_scope if in_stateless_scope(): if callable(initializer): @@ -76,6 +79,42 @@ class KerasVariable: return autocast_scope.maybe_cast(value) return value + def numpy(self): + return np.array(self.value) + + @property + def value(self): + if in_stateless_scope(): + scope = get_stateless_scope() + value = scope.get_current_value(self) + if value is not None: + return self._maybe_autocast(value) + if self._value is None: + # Unitialized variable. Return a placeholder. + # This is fine because it's only ever used + # in during shape inference / graph tracing + # (anything else would be a bug, to be fixed.) + return self._maybe_autocast( + self._initializer(self._shape, dtype=self._dtype) + ) + return self._maybe_autocast(self._value) + + def assign(self, value): + value = self._convert_to_tensor(value, dtype=self.dtype) + if value.shape != self.value.shape: + raise ValueError( + "The shape of the target variable and " + "the shape of the target value in " + "`variable.assign(value)` must match. 
" + f"Received: value.shape={value.shape}; " + f"variable.shape={self.value.shape}" + ) + if in_stateless_scope(): + scope = get_stateless_scope() + scope.add_update((self, value)) + else: + self._direct_assign(value) + @property def dtype(self): autocast_scope = get_autocast_scope() @@ -100,16 +139,6 @@ class KerasVariable: def _initialize(self, value): raise NotImplementedError - @property - def value(self): - raise NotImplementedError - - def numpy(self): - raise NotImplementedError - - def assign(self, value): - raise NotImplementedError - def _convert_to_tensor(self, value, dtype=None): raise NotImplementedError @@ -336,50 +365,59 @@ ALLOWED_DTYPES = { "int64", "bfloat16", "bool", -} - -PYTHON_DTYPES_MAP = { - bool: "bool", - int: "int", # TBD by backend - float: "float32", + "string", } def standardize_dtype(dtype): - if dtype is None: - return config.floatx() - if dtype in PYTHON_DTYPES_MAP: - dtype = PYTHON_DTYPES_MAP.get(dtype) if dtype == "int": if config.backend() == "tensorflow": dtype = "int64" else: dtype = "int32" + if dtype is None: + return config.floatx() if hasattr(dtype, "name"): dtype = dtype.name - if dtype not in ALLOWED_DTYPES: raise ValueError(f"Invalid dtype: {dtype}") return dtype -def standardize_shape(shape, fully_defined=False): +def standardize_shape( + shape, allow_dynamic_batch_size=True, allow_all_dynamic=True +): if not isinstance(shape, tuple): if shape is None: raise ValueError("Undefined shapes are not supported.") if not hasattr(shape, "__iter__"): raise ValueError(f"Cannot convert '{shape}' to a shape.") shape = tuple(shape) - for e in shape: - if not fully_defined and e is None: + + for i, e in enumerate(shape): + if i == 0 and allow_dynamic_batch_size and e is None: + continue + if allow_all_dynamic and e is None: continue if not isinstance(e, int): - raise ValueError( + msg = ( f"Cannot convert '{shape}' to a shape. " - f"Found invalid entry '{e}'. Only " - "fully-defined shapes are allowed with the " - f"{config.backend()} backend." + f"Found invalid entry '{e}'. " ) + if allow_dynamic_batch_size: + msg += ( + "Dynamic shapes (shapes with `None` entries) " + f"are not allowed with the {config.backend()}, " + "except for the batch size (axis 0)." + ) + else: + msg += ( + "Dynamic shapes (shapes with `None` entries) " + f"are not allowed with the {config.backend()}. " + "All dimensions should be positive integers, " + "including the batch size (axis 0)." + ) + raise ValueError(msg) if e < 0: raise ValueError( f"Cannot convert '{shape}' to a shape. 
" diff --git a/keras_core/backend/jax/__init__.py b/keras_core/backend/jax/__init__.py index b958f23ed..e6e6c4c47 100644 --- a/keras_core/backend/jax/__init__.py +++ b/keras_core/backend/jax/__init__.py @@ -4,6 +4,7 @@ from keras_core.backend.jax import math from keras_core.backend.jax import nn from keras_core.backend.jax import numpy from keras_core.backend.jax import random +from keras_core.backend.jax.core import DYNAMIC_BATCH_SIZE_OK from keras_core.backend.jax.core import DYNAMIC_SHAPES_OK from keras_core.backend.jax.core import Variable from keras_core.backend.jax.core import cast diff --git a/keras_core/backend/jax/core.py b/keras_core/backend/jax/core.py index 65be5f978..748eff92c 100644 --- a/keras_core/backend/jax/core.py +++ b/keras_core/backend/jax/core.py @@ -1,69 +1,30 @@ import jax import jax.numpy as jnp -import numpy as np from tensorflow import nest from keras_core.backend.common import KerasVariable from keras_core.backend.common import standardize_dtype from keras_core.backend.common.keras_tensor import KerasTensor from keras_core.backend.common.stateless_scope import StatelessScope -from keras_core.backend.common.stateless_scope import get_stateless_scope -from keras_core.backend.common.stateless_scope import in_stateless_scope DYNAMIC_SHAPES_OK = False # Dynamic shapes NG +DYNAMIC_BATCH_SIZE_OK = True class Variable(KerasVariable): def _initialize(self, value): self._value = jnp.array(value, dtype=self._dtype) - def assign(self, value): - value = convert_to_tensor(value, dtype=self.dtype) - if value.shape != self.shape: - raise ValueError( - "The shape of the target variable and " - "the shape of the target value in " - "`variable.assign(value)` must match. " - f"Received: value.shape={value.shape}; " - f"variable.shape={self.value.shape}" - ) - if in_stateless_scope(): - scope = get_stateless_scope() - scope.add_update((self, value)) - else: - if isinstance(value, jnp.ndarray) and value.dtype == self.dtype: - # Avoid a memory copy - self._value = value - else: - self._value = jnp.array(value, dtype=self.dtype) + def _direct_assign(self, value): + self._value = value - @property - def value(self): - if in_stateless_scope(): - scope = get_stateless_scope() - value = scope.get_current_value(self) - if value is not None: - return self._maybe_autocast(value) - if self._value is None: - # Unitialized variable. Return a placeholder. - # This is fine because it's only ever used - # in during shape inference with JAX tracer objects - # (anything else would be a bug, to be fixed.) - return self._maybe_autocast( - self._initializer(self._shape, dtype=self._dtype) - ) - return self._maybe_autocast(self._value) - - def numpy(self): - return np.array(self.value) + def _convert_to_tensor(self, value, dtype=None): + return convert_to_tensor(value, dtype=dtype) # Overload native accessor. 
def __jax_array__(self): return self.value - def _convert_to_tensor(self, value, dtype=None): - return convert_to_tensor(value, dtype=dtype) - def convert_to_tensor(x, dtype=None): if dtype is not None: @@ -98,10 +59,22 @@ def name_scope(name): # Shape / dtype inference util def compute_output_spec(fn, *args, **kwargs): with StatelessScope(): + dynamic_batch_map = {} + magic_number = 3 def convert_keras_tensor_to_jax(x): if isinstance(x, KerasTensor): - return jax.ShapeDtypeStruct(x.shape, dtype=x.dtype) + shape = x.shape + if shape and x.shape[0] is None: + shape = list(shape) + shape[0] = magic_number + dynamic_batch = True + else: + dynamic_batch = False + + jax_tensor = jax.ShapeDtypeStruct(shape, dtype=x.dtype) + dynamic_batch_map[jax_tensor] = dynamic_batch + return jax_tensor return x built_in_types = (type(None), int, float, str, bool, complex, bytes) @@ -136,6 +109,17 @@ def compute_output_spec(fn, *args, **kwargs): def convert_jax_spec_to_keras_tensor(x): if isinstance(x, jax.ShapeDtypeStruct): + if dynamic_batch_map.get(x, False): + shape = list(x.shape) + if shape[0] != magic_number: + raise ValueError( + f"Function {fn} appears to change the " + "batch size of its input. This is not " + "allowed when used in conjunction with " + "dynamic batch sizes. Consider using " + "a static batch size here." + ) + shape[0] = None return KerasTensor(x.shape, x.dtype) return x diff --git a/keras_core/backend/tensorflow/__init__.py b/keras_core/backend/tensorflow/__init__.py index d94123675..bb01a48b5 100644 --- a/keras_core/backend/tensorflow/__init__.py +++ b/keras_core/backend/tensorflow/__init__.py @@ -5,6 +5,7 @@ from keras_core.backend.tensorflow import nn from keras_core.backend.tensorflow import numpy from keras_core.backend.tensorflow import random from keras_core.backend.tensorflow.core import DYNAMIC_SHAPES_OK +from keras_core.backend.tensorflow.core import DYNAMIC_BATCH_SIZE_OK from keras_core.backend.tensorflow.core import Variable from keras_core.backend.tensorflow.core import cast from keras_core.backend.tensorflow.core import compute_output_spec diff --git a/keras_core/backend/tensorflow/core.py b/keras_core/backend/tensorflow/core.py index b19c0506a..acd41bb56 100644 --- a/keras_core/backend/tensorflow/core.py +++ b/keras_core/backend/tensorflow/core.py @@ -4,11 +4,10 @@ from keras_core.backend.common import KerasVariable from keras_core.backend.common import standardize_dtype from keras_core.backend.common.keras_tensor import KerasTensor from keras_core.backend.common.stateless_scope import StatelessScope -from keras_core.backend.common.stateless_scope import get_stateless_scope -from keras_core.backend.common.stateless_scope import in_stateless_scope from keras_core.utils.naming import auto_name DYNAMIC_SHAPES_OK = True +DYNAMIC_BATCH_SIZE_OK = True class Variable(KerasVariable, tf.__internal__.types.Tensor): @@ -23,37 +22,11 @@ class Variable(KerasVariable, tf.__internal__.types.Tensor): value, dtype=self._dtype, trainable=self.trainable, name=self.name ) - def assign(self, value): - value = convert_to_tensor(value, dtype=self.dtype) - if value.shape != self.value.shape: - raise ValueError( - "The shape of the target variable and " - "the shape of the target value in " - "`variable.assign(value)` must match. 
" - f"Received: value.shape={value.shape}; " - f"variable.shape={self.value.shape}" - ) - if in_stateless_scope(): - scope = get_stateless_scope() - scope.add_update((self, value)) - else: - self.value.assign(value) + def _direct_assign(self, value): + self.value.assign(value) - @property - def value(self): - if in_stateless_scope(): - scope = get_stateless_scope() - value = scope.get_current_value(self) - if value is not None: - return self._maybe_autocast(value) - if self._value is None: - # Unitialized variable. Return a placeholder. - # This is fine because it's only ever used - # during shape inference in a scratch graph - # (anything else would be a bug, to be fixed.) - init_val = self._initializer(self._shape, dtype=self._dtype) - return self._maybe_autocast(init_val) - return self._maybe_autocast(self._value) + def _convert_to_tensor(self, value, dtype=None): + return convert_to_tensor(value, dtype=dtype) def numpy(self): # noqa: F811 return self.value.numpy() @@ -66,9 +39,6 @@ class Variable(KerasVariable, tf.__internal__.types.Tensor): def __tf_tensor__(self, dtype=None, name=None): return tf.convert_to_tensor(self.value, dtype=dtype, name=name) - def _convert_to_tensor(self, value, dtype=None): - return convert_to_tensor(value, dtype=dtype) - def convert_to_tensor(x, dtype=None): if dtype is not None: diff --git a/keras_core/backend/torch/core.py b/keras_core/backend/torch/core.py index 20b448eaa..2f05d03b6 100644 --- a/keras_core/backend/torch/core.py +++ b/keras_core/backend/torch/core.py @@ -1,12 +1,11 @@ -import numpy as np import torch from keras_core.backend.common import KerasVariable from keras_core.backend.common import standardize_dtype -from keras_core.backend.common.stateless_scope import get_stateless_scope -from keras_core.backend.common.stateless_scope import in_stateless_scope DYNAMIC_SHAPES_OK = True +DYNAMIC_BATCH_SIZE_OK = True + TORCH_DTYPES = { "float16": torch.float16, @@ -37,53 +36,16 @@ class Variable(KerasVariable): def _initialize(self, value): self._value = convert_to_tensor(value, dtype=self._dtype) - def assign(self, value): - value = convert_to_tensor(value, dtype=self.dtype) - if value.shape != self.shape: - raise ValueError( - "The shape of the target variable and " - "the shape of the target value in " - "`variable.assign(value)` must match. " - f"Received: value.shape={value.shape}; " - f"variable.shape={self.value.shape}" - ) - if in_stateless_scope(): - scope = get_stateless_scope() - scope.add_update((self, value)) - else: - # torch `as_tensor` by default, doesn't copy if tensor is same type - self._value = convert_to_tensor(value, dtype=self.dtype) - - @property - def value(self): - if in_stateless_scope(): - scope = get_stateless_scope() - value = scope.get_current_value(self) - if value is not None: - return self._maybe_autocast(value) - if self._value is None: - # Unitialized variable. Return a placeholder. - # This is fine because it's only ever used - # during shape inference in a scratch graph - # (anything else would be a bug, to be fixed.) - return self._maybe_autocast( - convert_to_tensor( - self._initializer(self._shape, dtype=self._dtype), - dtype=self._dtype, - ) - ) - return self._maybe_autocast(self._value) - - def numpy(self): - return np.array(self.value) - - # Overload native accessor. 
- def __torch_function__(self, func, types, args=(), kwargs=None): - raise NotImplementedError + def _direct_assign(self, value): + self._value = value def _convert_to_tensor(self, value, dtype=None): return convert_to_tensor(value, dtype=dtype) + # Overload native accessor. + def __torch_function__(self, func, types, args=(), kwargs=None): + return func(self.value, *args, **kwargs) + def convert_to_tensor(x, dtype=None): # TODO: Need to address device placement arg of `as_tensor` diff --git a/keras_core/backend/torch/numpy.py b/keras_core/backend/torch/numpy.py index 452ca34ea..0e75bcc59 100644 --- a/keras_core/backend/torch/numpy.py +++ b/keras_core/backend/torch/numpy.py @@ -1,109 +1,93 @@ import torch -from keras_core.backend.torch.core import convert_to_tensor from keras_core.backend.torch.core import to_torch_dtype def add(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.add(x1, x2) + return x1 + x2 def einsum(subscripts, *operands, **kwargs): - operands = [convert_to_tensor(operand) for operand in operands] - return torch.einsum(subscripts, *operands) + pass + # return tfnp.einsum(subscripts, *operands, **kwargs) def subtract(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.subtract(x1, x2) + return x1 - x2 def matmul(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.matmul(x1, x2) + pass + # return tfnp.matmul(x1, x2) def multiply(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.multiply(x1, x2) + return x1 * x2 def mean(x, axis=None, keepdims=False): - x = convert_to_tensor(x) - return torch.mean(x, axis=axis, keepdims=keepdims) + return torch.mean(x, dim=axis, keepdim=keepdims) def max(x, axis=None, keepdims=False, initial=None): - x = convert_to_tensor(x) - if axis is None: - result = torch.max(x) - else: - if isinstance(axis, list): - axis = axis[-1] - result = torch.max(x, dim=axis, keepdim=keepdims) + # TODO: handle initial + return torch.max(x, dim=axis, keepdim=keepdims) + # The TensorFlow numpy API implementation doesn't support `initial` so we + # handle it manually here. + # if initial is not None: + # return tf.math.maximum( + # tfnp.max(x, axis=axis, keepdims=keepdims), initial + # ) - if isinstance(getattr(result, "values", None), torch.Tensor): - result = result.values + # TensorFlow returns -inf by default for an empty list, but for consistency + # with other backends and the numpy API we want to throw in this case. 
+ # tf.assert_greater( + # size(x), + # tf.constant(0, dtype=tf.int64), + # message="Cannot compute the max of an empty tensor.", + # ) - if initial is not None: - return torch.maximum(result, initial) - return result + # return tfnp.max(x, axis=axis, keepdims=keepdims) def ones(shape, dtype="float32"): dtype = to_torch_dtype(dtype) - return torch.ones(*shape, dtype=dtype) + torch.ones(*shape, dtype=dtype) def zeros(shape, dtype="float32"): dtype = to_torch_dtype(dtype) - return torch.zeros(*shape, dtype=dtype) + torch.zeros(*shape, dtype=dtype) def absolute(x): - return abs(x) + pass + # return tfnp.absolute(x) def abs(x): - x = convert_to_tensor(x) - return torch.abs(x) + pass + # return absolute(x) def all(x, axis=None, keepdims=False): - x = convert_to_tensor(x) - if axis is not None: - if isinstance(axis, list): - axis = axis[-1] - return torch.all(x, dim=axis, keepdim=keepdims) - else: - return torch.all(x) + pass + # return tfnp.all(x, axis=axis, keepdims=keepdims) def any(x, axis=None, keepdims=False): - x = convert_to_tensor(x) - if axis is not None: - if isinstance(axis, list): - axis = axis[-1] - return torch.any(x, dim=axis, keepdim=keepdims) - else: - return torch.any(x) + pass + # return tfnp.any(x, axis=axis, keepdims=keepdims) def amax(x, axis=None, keepdims=False): - x = convert_to_tensor(x) - if axis is not None: - return torch.amax(x, dim=axis, keepdim=keepdims) - else: - return torch.amax(x) + pass + # return tfnp.amax(x, axis=axis, keepdims=keepdims) def amin(x, axis=None, keepdims=False): - x = convert_to_tensor(x) - if axis is not None: - return torch.amin(x, dim=axis, keepdim=keepdims) - else: - return torch.amin(x) + pass + # return tfnp.amin(x, axis=axis, keepdims=keepdims) def append( @@ -111,694 +95,594 @@ def append( x2, axis=None, ): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - if axis is None: - return torch.cat((x1.flatten(), x2.flatten())) - return torch.cat((x1, x2), dim=axis) + pass + # return tfnp.append(x1, x2, axis=axis) def arange(start, stop=None, step=None, dtype=None): - dtype = to_torch_dtype(dtype) - if stop is None: - return torch.arange(start, step=step, dtype=dtype) - step = step or 1 - return torch.arange(start, stop, step=step, dtype=dtype) + pass + # return tfnp.arange(start, stop, step=step, dtype=dtype) def arccos(x): - x = convert_to_tensor(x) - return torch.arccos(x) + pass + # return tfnp.arccos(x) def arcsin(x): - x = convert_to_tensor(x) - return torch.arcsin(x) + pass + # return tfnp.arcsin(x) def arctan(x): - x = convert_to_tensor(x) - return torch.arctan(x) + pass + # return tfnp.arctan(x) def arctan2(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.arctan2(x1, x2) + pass + # return tfnp.arctan2(x1, x2) def argmax(x, axis=None): - x = convert_to_tensor(x) - return torch.argmax(x, dim=axis) + pass + # return tfnp.argmax(x, axis=axis) def argmin(x, axis=None): - x = convert_to_tensor(x) - return torch.argmin(x, dim=axis) + pass + # return tfnp.argmin(x, axis=axis) def argsort(x, axis=-1): - x = convert_to_tensor(x) - if axis is None: - axis = -1 - x = x.reshape(-1) - return torch.argsort(x, dim=axis) + pass + # return tfnp.argsort(x, axis=axis) def array(x, dtype=None): - dtype = to_torch_dtype(dtype) - if not isinstance(x, torch.Tensor): - return x - return x.numpy() + pass + # return tfnp.array(x, dtype=dtype) def average(x, axis=None, weights=None): - x = convert_to_tensor(x) - if weights is not None: - weights = convert_to_tensor(weights) - return torch.sum(torch.mul(x, weights), 
dim=axis) / torch.sum( - weights, dim=-1 - ) - return torch.mean(x, axis) - - -def bincount(x, weights=None, minlength=0): - x = convert_to_tensor(x, dtype=int) - weights = convert_to_tensor(weights) - return torch.bincount(x, weights, minlength) + pass + # return tfnp.average(x, weights=weights, axis=axis) def broadcast_to(x, shape): - x = convert_to_tensor(x) - return torch.broadcast_to(x, shape) + pass + # return tfnp.broadcast_to(x, shape) def ceil(x): - x = convert_to_tensor(x) - return torch.ceil(x) + pass + # return tfnp.ceil(x) def clip(x, x_min, x_max): - x = convert_to_tensor(x) - x_min, x_max = convert_to_tensor(x_min), convert_to_tensor(x_max) - return torch.clip(x, min=x_min, max=x_max) + pass + # return tfnp.clip(x, x_min, x_max) def concatenate(xs, axis=0): - xs = [convert_to_tensor(x) for x in xs] - return torch.cat(xs, dim=axis) + pass + # return tfnp.concatenate(xs, axis=axis) def conjugate(x): - if not isinstance(x, torch.Tensor): - x = torch.from_numpy(x) # needed for complex type conversion - return torch.conj(x).resolve_conj() + pass + # return tfnp.conjugate(x) def conj(x): - if not isinstance(x, torch.Tensor): - x = torch.from_numpy(x) # needed for complex type conversion - return torch.conj(x).resolve_conj() + pass + # return conjugate(x) def copy(x): - x = convert_to_tensor(x) - return torch.clone(x) + pass + # return tfnp.copy(x) def cos(x): - x = convert_to_tensor(x) - return torch.cos(x) + pass + # return tfnp.cos(x) def count_nonzero(x, axis=None): - x = convert_to_tensor(x) - return torch.count_nonzero(x, dim=axis).T + pass + # return tfnp.count_nonzero(x, axis=axis) def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None): - # TODO: There is API divergence between np.cross and torch.cross, - # preventing `axisa`, `axisb`, and `axisc` parameters from - # being used. 
- # https://github.com/pytorch/pytorch/issues/50273 - if axisa != -1 or axisb != -1 or axisc != -1: - raise NotImplementedError( - "Due to API divergence between `torch.cross()` and " - "`np.cross`, the following arguments are not supported: " - f"axisa={axisa}, axisb={axisb}, axisc={axisc}" - ) - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.cross(x1, x2, dim=axis) + pass + # return tfnp.cross( + # x1, + # x2, + # axisa=axisa, + # axisb=axisb, + # axisc=axisc, + # axis=axis, + # ) def cumprod(x, axis=None): - x = convert_to_tensor(x) - if axis is None: - x = x.flatten() - axis = 0 - return torch.cumprod(x, dim=axis) + pass + # return tfnp.cumprod(x, axis=axis) def cumsum(x, axis=None): - x = convert_to_tensor(x) - if axis is None: - x = x.flatten() - axis = 0 - return torch.cumsum(x, dim=axis) + pass + # return tfnp.cumsum(x, axis=axis) def diag(x, k=0): - x = convert_to_tensor(x) - return torch.diag(x, diagonal=k) + pass + # return tfnp.diag(x, k=k) def diagonal(x, offset=0, axis1=0, axis2=1): - x = convert_to_tensor(x) - return torch.diagonal( - x, - offset=offset, - dim1=axis1, - dim2=axis2, - ) + pass + # return tfnp.diagonal( + # x, + # offset=offset, + # axis1=axis1, + # axis2=axis2, + # ) def dot(x, y): - x, y = convert_to_tensor(x), convert_to_tensor(y) - if len(x.shape) == 0 or len(y.shape) == 0: - return torch.multiply(x, y) - return torch.matmul(x, y) + pass + # return tfnp.dot(x, y) def empty(shape, dtype="float32"): - dtype = to_torch_dtype(dtype) - return torch.empty(size=shape, dtype=dtype) + pass + # return tfnp.empty(shape, dtype=dtype) def equal(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.equal(x1, x2) + pass + # return tfnp.equal(x1, x2) def exp(x): - x = convert_to_tensor(x) - return torch.exp(x) + pass + # return tfnp.exp(x) def expand_dims(x, axis): - x = convert_to_tensor(x) - return torch.unsqueeze(x, dim=axis) + pass + # return tfnp.expand_dims(x, axis) def expm1(x): - x = convert_to_tensor(x) - return torch.expm1(x) + pass + # return tfnp.expm1(x) def flip(x, axis=None): - x = convert_to_tensor(x) - if axis is None: - axis = tuple(range(len(x.shape))) - if isinstance(axis, int): - axis = (axis,) - return torch.flip(x, dims=axis) + pass + # return tfnp.flip(x, axis=axis) def floor(x): - x = convert_to_tensor(x) - return torch.floor(x) + pass + # return tfnp.floor(x) def full(shape, fill_value, dtype=None): - dtype = to_torch_dtype(dtype) - if hasattr(fill_value, "__len__"): - raise NotImplementedError( - "`torch.full()` only accepts scalars for `fill_value`. " - f"Received: fill_value={fill_value}" - ) - # TODO: implement conversion of shape into repetitions for `torch.tile`` - # return torch.tile(fill_value, reps) - return torch.full(size=shape, fill_value=fill_value, dtype=dtype) + pass + # return tfnp.full(shape, fill_value, dtype=dtype) def full_like(x, fill_value, dtype=None): - dtype = to_torch_dtype(dtype) - if hasattr(fill_value, "__len__"): - raise NotImplementedError( - "`torch.full()` only accepts scalars for `fill_value`." 
- ) - # TODO: implement conversion of shape into repetitions for `torch.tile`` - # return torch.tile(fill_value, reps) - x = convert_to_tensor(x) - return torch.full_like(input=x, fill_value=fill_value, dtype=dtype) + pass + # return tfnp.full_like(x, fill_value, dtype=dtype) def greater(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.greater(x1, x2) + pass + # return tfnp.greater(x1, x2) def greater_equal(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.greater_equal(x1, x2) + pass + # return tfnp.greater_equal(x1, x2) def hstack(xs): - xs = [convert_to_tensor(x) for x in xs] - return torch.hstack(xs) + pass + # return tfnp.hstack(xs) def identity(n, dtype="float32"): - dtype = to_torch_dtype(dtype) - return torch.eye(n, dtype=dtype) + pass + # return tfnp.identity(n, dtype=dtype) def imag(x): - if not isinstance(x, torch.Tensor): - x = torch.from_numpy(x) # needed for complex type conversion - return torch.imag(x) + pass + # return tfnp.imag(x) def isclose(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.isclose(x1, x2) + pass + # return tfnp.isclose(x1, x2) def isfinite(x): - x = convert_to_tensor(x) - return torch.isfinite(x) + pass + # return tfnp.isfinite(x) def isinf(x): - x = convert_to_tensor(x) - return torch.isinf(x) + pass + # return tfnp.isinf(x) def isnan(x): - x = convert_to_tensor(x) - return torch.isnan(x) + pass + # return tfnp.isnan(x) def less(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.less(x1, x2) + pass + # return tfnp.less(x1, x2) def less_equal(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.less_equal(x1, x2) + pass + # return tfnp.less_equal(x1, x2) def linspace( start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0 ): - if axis != 0: - raise ValueError( - "torch.linspace does not support an `axis` argument. " - f"Received axis={axis}" - ) - dtype = to_torch_dtype(dtype) - if endpoint is False: - stop = stop - ((stop - start) / num) - linspace = torch.linspace( - start=start, - end=stop, - steps=num, - dtype=dtype, - ) - if retstep is True: - return (linspace, num) - return linspace + pass + # return tfnp.linspace( + # start, + # stop, + # num=num, + # endpoint=endpoint, + # retstep=retstep, + # dtype=dtype, + # axis=axis, + # ) def log(x): - x = convert_to_tensor(x) - return torch.log(x) + pass + # return tfnp.log(x) def log10(x): - x = convert_to_tensor(x) - return torch.log10(x) + pass + # return tfnp.log10(x) def log1p(x): - x = convert_to_tensor(x) - return torch.log1p(x) + pass + # return tfnp.log1p(x) def log2(x): - x = convert_to_tensor(x) - return torch.log2(x) + pass + # return tfnp.log2(x) def logaddexp(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.logaddexp(x1, x2) + pass + # return tfnp.logaddexp(x1, x2) def logical_and(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.logical_and(x1, x2) + pass + # return tfnp.logical_and(x1, x2) def logical_not(x): - x = convert_to_tensor(x) - return torch.logical_not(x) + pass + # return tfnp.logical_not(x) def logical_or(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.logical_or(x1, x2) + pass + # return tfnp.logical_or(x1, x2) def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0): - if axis != 0: - raise ValueError( - "torch.logspace does not support an `axis` argument. 
" - f"Received axis={axis}" - ) - dtype = to_torch_dtype(dtype) - if endpoint is False: - stop = stop - ((stop - start) / num) - logspace = torch.logspace( - start=start, - end=stop, - steps=num, - base=base, - dtype=dtype, - ) - return logspace + pass + # return tfnp.logspace( + # start, + # stop, + # num=num, + # endpoint=endpoint, + # base=base, + # dtype=dtype, + # axis=axis, + # ) def maximum(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.maximum(x1, x2) + pass + # return tfnp.maximum(x1, x2) def meshgrid(*x, indexing="xy"): - x = [convert_to_tensor(sc_tensor) for sc_tensor in x] - result = torch.meshgrid(x, indexing=indexing) - return [arr.numpy() for arr in result] + pass + # return tfnp.meshgrid(*x, indexing=indexing) def min(x, axis=None, keepdims=False, initial=None): - x = convert_to_tensor(x) - if axis is None: - result = torch.min(x) - else: - if isinstance(axis, list): - axis = axis[-1] - result = torch.min(x, dim=axis, keepdim=keepdims) + pass + ## The TensorFlow numpy API implementation doesn't support `initial` so we + ## handle it manually here. + # if initial is not None: + # return tf.math.minimum( + # tfnp.min(x, axis=axis, keepdims=keepdims), initial + # ) - if isinstance(getattr(result, "values", None), torch.Tensor): - result = result.values + ## TensorFlow returns inf by default for an empty list, but for consistency + ## with other backends and the numpy API we want to throw in this case. + # tf.assert_greater( + # size(x), + # tf.constant(0, dtype=tf.int64), + # message="Cannot compute the min of an empty tensor.", + # ) - if initial is not None: - return torch.minimum(result, initial) - return result + # return tfnp.min(x, axis=axis, keepdims=keepdims) def minimum(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.minimum(x1, x2) + pass + # return tfnp.minimum(x1, x2) def mod(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.remainder(x1, x2) + pass + # return tfnp.mod(x1, x2) def moveaxis(x, source, destination): - x = convert_to_tensor(x) - return torch.moveaxis(x, source=source, destination=destination) + pass + # return tfnp.moveaxis(x, source=source, destination=destination) def nan_to_num(x): - x = convert_to_tensor(x) - return torch.nan_to_num(x) + pass + ## Replace NaN with 0 + # x = tf.where(tf.math.is_nan(x), 0, x) + + ## Replace positive infinitiy with dtype.max + # x = tf.where(tf.math.is_inf(x) & (x > 0), x.dtype.max, x) + + ## Replace negative infinity with dtype.min + # x = tf.where(tf.math.is_inf(x) & (x < 0), x.dtype.min, x) + + # return x def ndim(x): - x = convert_to_tensor(x) - return x.ndim + pass + # return tfnp.ndim(x) def nonzero(x): - x = convert_to_tensor(x) - return torch.nonzero(x).T + pass + # return tfnp.nonzero(x) def not_equal(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.not_equal(x1, x2) + pass + # return tfnp.not_equal(x1, x2) def ones_like(x, dtype=None): - x = convert_to_tensor(x) - dtype = to_torch_dtype(dtype) - return torch.ones_like(x, dtype=dtype) + pass + # return tfnp.ones_like(x, dtype=dtype) def outer(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.outer(x1.flatten(), x2.flatten()) + pass + # return tfnp.outer(x1, x2) def pad(x, pad_width, mode="constant"): - x = convert_to_tensor(x) - pad_sum = () - for pad in pad_width: - pad_sum += pad - return torch.nn.functional.pad(x, pad=pad_sum, mode=mode) + pass + # return tfnp.pad(x, pad_width, mode=mode) def prod(x, 
axis=None, keepdims=False, dtype=None): - x = convert_to_tensor(x) - dtype = to_torch_dtype(dtype) - if axis is None: - return torch.prod(x, dtype=dtype) - elif isinstance(axis, list): - axis = axis[-1] - return torch.prod(x, dim=axis, keepdim=keepdims, dtype=dtype) + pass + # return tfnp.prod(x, axis=axis, keepdims=keepdims, dtype=dtype) def ravel(x): - x = convert_to_tensor(x) - return torch.ravel(x) + pass + # return tfnp.ravel(x) def real(x): - x = convert_to_tensor(x) - return torch.real(x) + pass + # return tfnp.real(x) def reciprocal(x): - x = convert_to_tensor(x) - return torch.reciprocal(x) + pass + # return tfnp.reciprocal(x) def repeat(x, repeats, axis=None): - x = convert_to_tensor(x) - repeats = convert_to_tensor(repeats, dtype=int) - return torch.repeat_interleave(x, repeats, dim=axis) + pass + # return tfnp.repeat(x, repeats, axis=axis) def reshape(x, new_shape): - x = convert_to_tensor(x) - return torch.reshape(x, new_shape) + pass + # return tfnp.reshape(x, new_shape) def roll(x, shift, axis=None): - x = convert_to_tensor(x) - return torch.roll(x, shift, dims=axis) + pass + # return tfnp.roll(x, shift, axis=axis) def sign(x): - x = convert_to_tensor(x) - return torch.sign(x) + pass + # return tfnp.sign(x) def sin(x): - x = convert_to_tensor(x) - return torch.sin(x) + pass + # return tfnp.sin(x) def size(x): - x_shape = convert_to_tensor(tuple(x.shape)) - return torch.prod(x_shape) + pass + # return tfnp.size(x) def sort(x, axis=-1): - x = convert_to_tensor(x) - return torch.sort(x, dim=axis).values + pass + # return tfnp.sort(x, axis=axis) def split(x, indices_or_sections, axis=0): - x = convert_to_tensor(x) - return torch.split( - tensor=x, - split_size_or_sections=indices_or_sections, - dim=axis, - ) + pass + # return tfnp.split(x, indices_or_sections, axis=axis) def stack(x, axis=0): - x = [convert_to_tensor(elem) for elem in x] - return torch.stack(x, dim=axis) + pass + # return tfnp.stack(x, axis=axis) def std(x, axis=None, keepdims=False): - x = convert_to_tensor(x) - # Remove Bessel correction to align with numpy - return torch.std(x, dim=axis, keepdim=keepdims, unbiased=False) + pass + # return tfnp.std(x, axis=axis, keepdims=keepdims) def swapaxes(x, axis1, axis2): - x = convert_to_tensor(x) - return torch.swapaxes(x, axis0=axis1, axis1=axis2) + pass + # return tfnp.swapaxes(x, axis1=axis1, axis2=axis2) def take(x, indices, axis=None): - x = convert_to_tensor(x) - indices = convert_to_tensor(indices).long() - if axis is not None: - return torch.index_select(x, dim=axis, index=indices) - return torch.take(x, index=indices) + pass + # return tfnp.take(x, indices, axis=axis) def take_along_axis(x, indices, axis=None): - x = convert_to_tensor(x) - indices = convert_to_tensor(indices).long() - return torch.take_along_dim(x, indices, dim=axis) + pass + # return tfnp.take_along_axis(x, indices, axis=axis) def tan(x): - x = convert_to_tensor(x) - return torch.tan(x) + pass + # return tfnp.tan(x) def tensordot(x1, x2, axes=2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.tensordot(x1, x2, dims=axes) + pass + # return tfnp.tensordot(x1, x2, axes=axes) def round(x, decimals=0): - x = convert_to_tensor(x) - return torch.round(x, decimals=decimals) + pass + # return tfnp.round(x, decimals=decimals) def tile(x, repeats): - x = convert_to_tensor(x) - return torch.tile(x, dims=repeats) + pass + # return tfnp.tile(x, repeats) -def trace(x, offset=None, axis1=None, axis2=None): - x = convert_to_tensor(x) - # TODO: implement support for these arguments - # 
API divergence between `np.trace()` and `torch.trace()` - if offset or axis1 or axis2: - raise NotImplementedError( - "Arguments not supported by `torch.trace: " - f"offset={offset}, axis1={axis1}, axis2={axis2}" - ) - return torch.trace(x) +def trace(x, offset=0, axis1=0, axis2=1): + pass + # return tfnp.trace(x, offset=offset, axis1=axis1, axis2=axis2) def tri(N, M=None, k=0, dtype="float32"): - dtype = to_torch_dtype(dtype) pass + # return tfnp.tri(N, M=M, k=k, dtype=dtype) def tril(x, k=0): - x = convert_to_tensor(x) - return torch.tril(x, diagonal=k) + pass + # return tfnp.tril(x, k=k) def triu(x, k=0): - x = convert_to_tensor(x) - return torch.triu(x, diagonal=k) + pass + # return tfnp.triu(x, k=k) def vdot(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.vdot(x1, x2) + pass + # return tfnp.vdot(x1, x2) def vstack(xs): - xs = [convert_to_tensor(x) for x in xs] - return torch.vstack(xs) + pass + # return tfnp.vstack(xs) def where(condition, x1, x2): - condition = convert_to_tensor(condition, dtype=bool) - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.where(condition, x1, x2) + pass + # return tfnp.where(condition, x1, x2) def divide(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.divide(x1, x2) + pass + # return tfnp.divide(x1, x2) def true_divide(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.true_divide(x1, x2) + pass + # return tfnp.true_divide(x1, x2) def power(x1, x2): - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.pow(x1, x2) + pass + # return tfnp.power(x1, x2) def negative(x): - x = convert_to_tensor(x) - return torch.negative(x) + pass + # return tfnp.negative(x) def square(x): - x = convert_to_tensor(x) - return torch.square(x) + pass + # return tfnp.square(x) def sqrt(x): - x = convert_to_tensor(x) - return torch.sqrt(x) + pass + # return tfnp.sqrt(x) def squeeze(x, axis=None): - x = convert_to_tensor(x) - if axis is not None: - return torch.squeeze(x, dim=axis) - return torch.squeeze(x) + pass + # return tfnp.squeeze(x, axis=axis) def transpose(x, axes=None): - x = convert_to_tensor(x) - if axes is not None: - return torch.permute(x, dims=axes) - return x.T + pass + # return tfnp.transpose(x, axes=axes) def var(x, axis=None, keepdims=False): - x = convert_to_tensor(x) - # Bessel correction removed for numpy compatibility - return torch.var(x, dim=axis, keepdim=keepdims, correction=0) + pass + # return tfnp.var(x, axis=axis, keepdims=keepdims) def sum(x, axis=None, keepdims=False): - x = convert_to_tensor(x) - if axis is not None: - return torch.sum(x, axis=axis, keepdim=keepdims) - return torch.sum(x) + pass + # return tfnp.sum(x, axis=axis, keepdims=keepdims) -def eye(N, M=None, k=None, dtype="float32"): - # TODO: implement support for `k` diagonal arg, - # does not exist in torch.eye() - if k is not None: - raise NotImplementedError( - "Due to API divergence bewtween `torch.eye` " - "and `np.eye`, the argument k is not supported: " - f"Received: k={k}" - ) - dtype = to_torch_dtype(dtype) - if M is not None: - return torch.eye(n=N, m=M, dtype=dtype) - return torch.eye(n=N, dtype=dtype) +def eye(N, M=None, k=0, dtype="float32"): + pass + # return tfnp.eye(N, M=M, k=k, dtype=dtype) diff --git a/keras_core/layers/reshaping/up_sampling1d_test.py b/keras_core/layers/reshaping/up_sampling1d_test.py index 941755936..606dd73f5 100644 --- a/keras_core/layers/reshaping/up_sampling1d_test.py +++ 
b/keras_core/layers/reshaping/up_sampling1d_test.py @@ -49,14 +49,19 @@ class UpSamplingTest(testing.TestCase): ) @pytest.mark.skipif( - not backend.DYNAMIC_SHAPES_OK, - reason="Backend does not support dynamic shapes", + not backend.DYNAMIC_BATCH_SIZE_OK, + reason="Backend does not support dynamic batch sizes", ) - def test_upsampling_1d_with_dynamic_shape(self): + def test_upsampling_1d_with_dynamic_batch_size(self): x = KerasTensor([None, 2, 3]) self.assertEqual(layers.UpSampling1D(size=2)(x).shape, (None, 4, 3)) self.assertEqual(layers.UpSampling1D(size=4)(x).shape, (None, 8, 3)) + @pytest.mark.skipif( + not backend.DYNAMIC_SHAPES_OK, + reason="Backend does not support dynamic shapes", + ) + def test_upsampling_1d_with_dynamic_shape(self): y = KerasTensor([2, None, 3]) self.assertEqual(layers.UpSampling1D(size=2)(y).shape, (2, None, 3)) self.assertEqual(layers.UpSampling1D(size=4)(y).shape, (2, None, 3)) diff --git a/keras_core/layers/reshaping/up_sampling3d.py b/keras_core/layers/reshaping/up_sampling3d.py index 5a7cace72..1549f76c9 100644 --- a/keras_core/layers/reshaping/up_sampling3d.py +++ b/keras_core/layers/reshaping/up_sampling3d.py @@ -13,7 +13,8 @@ class UpSampling3D(Layer): Repeats the 1st, 2nd and 3rd dimensions of the data by `size[0]`, `size[1]` and `size[2]` respectively. - Examples: + Example: + >>> input_shape = (2, 1, 2, 1, 3) >>> x = np.ones(input_shape) >>> y = keras_core.layers.UpSampling3D(size=(2, 2, 2))(x) @@ -108,6 +109,7 @@ class UpSampling3D(Layer): self, x, depth_factor, height_factor, width_factor, data_format ): """Resizes the volume contained in a 5D tensor. + Args: x: Tensor or variable to resize. depth_factor: Positive integer. @@ -116,11 +118,7 @@ class UpSampling3D(Layer): data_format: One of `"channels_first"`, `"channels_last"`. Returns: - A tensor. - - Raises: - ValueError: if `data_format` is neither - `channels_last` or `channels_first`. + Resized tensor. 
""" if data_format == "channels_first": output = ops.repeat(x, depth_factor, axis=2) @@ -133,4 +131,4 @@ class UpSampling3D(Layer): output = ops.repeat(output, width_factor, axis=3) return output else: - raise ValueError("Invalid data_format: " + str(data_format)) + raise ValueError(f"Invalid data_format: {data_format}") diff --git a/keras_core/layers/reshaping/zero_padding3d_test.py b/keras_core/layers/reshaping/zero_padding3d_test.py index 682e84c40..1c10d44a8 100644 --- a/keras_core/layers/reshaping/zero_padding3d_test.py +++ b/keras_core/layers/reshaping/zero_padding3d_test.py @@ -65,8 +65,8 @@ class ZeroPaddingTest(testing.TestCase, parameterized.TestCase): self.assertAllClose(outputs[:, 2:-2, 2:-2, 2:-2, :], inputs) @pytest.mark.skipif( - not backend.DYNAMIC_SHAPES_OK, - reason="Backend does not support dynamic shapes", + not backend.DYNAMIC_BATCH_SIZE_OK, + reason="Backend does not support dynamic batch sizes", ) def test_zero_padding_3d_with_dynamic_batch_size(self): input_layer = layers.Input(batch_shape=(None, 2, 3, 4, 5)) diff --git a/keras_core/operations/function_test.py b/keras_core/operations/function_test.py index 42e32e39b..8003aa43e 100644 --- a/keras_core/operations/function_test.py +++ b/keras_core/operations/function_test.py @@ -42,10 +42,11 @@ class FunctionTest(testing.TestCase): self.assertAllClose(y_val[0], np.ones((2, 3)) * 6) self.assertAllClose(y_val[1], np.ones((2, 3)) * 4) + @pytest.mark.skipif( + not backend.DYNAMIC_BATCH_SIZE_OK, + reason="Test only valid if dynamic batch sizes are supported", + ) def test_dynamic_shape_inference(self): - if not backend.DYNAMIC_SHAPES_OK: - pytest.skip("Test only valid for dynamic shape backends") - x = keras_tensor.KerasTensor((None, 3)) y = x**2 fn = function.Function(x, y) diff --git a/keras_core/operations/numpy_test.py b/keras_core/operations/numpy_test.py index ca929ba22..80f108951 100644 --- a/keras_core/operations/numpy_test.py +++ b/keras_core/operations/numpy_test.py @@ -1665,19 +1665,13 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase): y3 = np.ones([1, 5, 4, 2]) self.assertAllClose(np.array(knp.cross(x1, y1)), np.cross(x1, y1)) self.assertAllClose(np.array(knp.cross(x1, y2)), np.cross(x1, y2)) - if backend.backend() != "torch": - # API divergence between `torch.cross` and `np.cross` - # `torch.cross` only allows dim 3, `np.cross` allows dim 2 or 3 - self.assertAllClose(np.array(knp.cross(x1, y3)), np.cross(x1, y3)) - self.assertAllClose(np.array(knp.cross(x2, y3)), np.cross(x2, y3)) + self.assertAllClose(np.array(knp.cross(x1, y3)), np.cross(x1, y3)) + self.assertAllClose(np.array(knp.cross(x2, y3)), np.cross(x2, y3)) self.assertAllClose(np.array(knp.Cross()(x1, y1)), np.cross(x1, y1)) self.assertAllClose(np.array(knp.Cross()(x1, y2)), np.cross(x1, y2)) - if backend.backend() != "torch": - # API divergence between `torch.cross` and `np.cross` - # `torch.cross` only allows dim 3, `np.cross` allows dim 2 or 3 - self.assertAllClose(np.array(knp.Cross()(x1, y3)), np.cross(x1, y3)) - self.assertAllClose(np.array(knp.Cross()(x2, y3)), np.cross(x2, y3)) + self.assertAllClose(np.array(knp.Cross()(x1, y3)), np.cross(x1, y3)) + self.assertAllClose(np.array(knp.Cross()(x2, y3)), np.cross(x2, y3)) def test_einsum(self): x = np.arange(24).reshape([2, 3, 4]) @@ -1717,35 +1711,24 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase): def test_full_like(self): x = np.array([[1, 2, 3], [3, 2, 1]]) self.assertAllClose(np.array(knp.full_like(x, 2)), np.full_like(x, 2)) + self.assertAllClose( + np.array(knp.full_like(x, 
np.ones([2, 3]))), + np.full_like(x, np.ones([2, 3])), + ) self.assertAllClose( np.array(knp.full_like(x, 2, dtype="float32")), np.full_like(x, 2, dtype="float32"), ) self.assertAllClose(np.array(knp.FullLike()(x, 2)), np.full_like(x, 2)) - self.assertAllClose( - np.array(knp.FullLike()(x, 2, dtype="float32")), - np.full_like(x, 2, dtype="float32"), - ) - - # TODO: implement conversion of shape into repetitions, pass to - # `torch.tile`, since `torch.full()` only accepts scalars - # for `fill_value`." - @pytest.mark.skipif( - backend.backend() == "torch", - reason="`torch.full` only accepts scalars for `fill_value`.", - ) - def test_full_like_without_torch(self): - x = np.array([[1, 2, 3], [3, 2, 1]]) - self.assertAllClose( - np.array(knp.full_like(x, np.ones([2, 3]))), - np.full_like(x, np.ones([2, 3])), - ) - self.assertAllClose( np.array(knp.FullLike()(x, np.ones([2, 3]))), np.full_like(x, np.ones([2, 3])), ) + self.assertAllClose( + np.array(knp.FullLike()(x, 2, dtype="float32")), + np.full_like(x, 2, dtype="float32"), + ) def test_greater(self): x = np.array([[1, 2, 3], [3, 2, 1]]) @@ -1842,14 +1825,6 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase): np.linspace(0, 10, 5, endpoint=False), ) - # TODO: torch.linspace does not support tensor or array - # for start/stop, create manual implementation - @pytest.mark.skipif( - backend.backend() == "torch", - reason="`torch.linspace` has no support for array start/stop.", - ) - def test_linspace_without_torch(self): - start = np.zeros([2, 3, 4]) stop = np.ones([2, 3, 4]) self.assertAllClose( @@ -1954,13 +1929,6 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase): np.logspace(0, 10, 5, endpoint=False), ) - # TODO: torch.logspace does not support tensor or array - # for start/stop, create manual implementation - @pytest.mark.skipif( - backend.backend() == "torch", - reason="`torch.logspace` has no support for array start/stop.", - ) - def test_logspace_without_torch(self): start = np.zeros([2, 3, 4]) stop = np.ones([2, 3, 4]) self.assertAllClose( @@ -2036,11 +2004,6 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase): self.assertAllClose(np.array(knp.outer(x, y)), np.outer(x, y)) self.assertAllClose(np.array(knp.Outer()(x, y)), np.outer(x, y)) - # TODO: Fix numpy compatibility (squeeze by one dimension only) - @pytest.mark.skipif( - backend.backend() == "torch", - reason="`torch.take` and `np.take` have return shape divergence.", - ) def test_take(self): x = np.arange(24).reshape([1, 2, 3, 4]) indices = np.array([0, 1]) @@ -2804,23 +2767,6 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase): np.array(knp.pad(x, ((1, 1), (1, 1)))), np.pad(x, ((1, 1), (1, 1))), ) - - self.assertAllClose( - np.array(knp.Pad(((1, 1), (1, 1)))(x)), np.pad(x, ((1, 1), (1, 1))) - ) - self.assertAllClose( - np.array(knp.Pad(((1, 1), (1, 1)))(x)), - np.pad(x, ((1, 1), (1, 1))), - ) - - # TODO: implement padding with non-constant padding, - # bypass NotImplementedError for PyTorch - @pytest.mark.skipif( - backend.backend() == "torch", - reason="padding not implemented for non-constant use case", - ) - def test_pad_without_torch(self): - x = np.array([[1, 2], [3, 4]]) self.assertAllClose( np.array(knp.pad(x, ((1, 1), (1, 1)), mode="reflect")), np.pad(x, ((1, 1), (1, 1)), mode="reflect"), @@ -2830,6 +2776,13 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase): np.pad(x, ((1, 1), (1, 1)), mode="symmetric"), ) + self.assertAllClose( + np.array(knp.Pad(((1, 1), (1, 1)))(x)), np.pad(x, ((1, 1), (1, 1))) + ) + self.assertAllClose( + 
np.array(knp.Pad(((1, 1), (1, 1)))(x)), + np.pad(x, ((1, 1), (1, 1))), + ) self.assertAllClose( np.array(knp.Pad(((1, 1), (1, 1)), mode="reflect")(x)), np.pad(x, ((1, 1), (1, 1)), mode="reflect"), @@ -2952,12 +2905,6 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase): self.assertAllClose(np.array(knp.sort(x, axis=0)), np.sort(x, axis=0)) self.assertAllClose(np.array(knp.Sort(axis=0)(x)), np.sort(x, axis=0)) - # TODO: implement split for `torch` with support for conversion - # of numpy.split args. - @pytest.mark.skipif( - backend.backend() == "torch", - reason="`torch.split` and `np.split` have return arg divergence.", - ) def test_split(self): x = np.array([[1, 2, 3], [3, 2, 1]]) self.assertAllClose(np.array(knp.split(x, 2)), np.split(x, 2)) @@ -3025,13 +2972,7 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase): self.assertAllClose(np.array(knp.tile(x, [2, 3])), np.tile(x, [2, 3])) self.assertAllClose(np.array(knp.Tile([2, 3])(x)), np.tile(x, [2, 3])) - @pytest.mark.skipif( - backend.backend() == "torch", - reason="`torch.split` does not support args `offset`, `axis1`, `axis2`", - ) def test_trace(self): - # TODO: implement `torch.trace` support for arguments `offset`, - # `axis1`, `axis2` and delete NotImplementedError x = np.arange(24).reshape([1, 2, 3, 4]) self.assertAllClose(np.array(knp.trace(x)), np.trace(x)) self.assertAllClose( @@ -3071,14 +3012,7 @@ class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase): self.assertAllClose(np.array(knp.zeros([2, 3])), np.zeros([2, 3])) self.assertAllClose(np.array(knp.Zeros()([2, 3])), np.zeros([2, 3])) - @pytest.mark.skipif( - backend.backend() == "torch", - reason="`torch.eye` does not support arg `k`.", - ) def test_eye(self): - # TODO: implement support for `k` diagonal arg, - # does not exist in torch.eye() - self.assertAllClose(np.array(knp.eye(3)), np.eye(3)) self.assertAllClose(np.array(knp.eye(3, 4)), np.eye(3, 4)) self.assertAllClose(np.array(knp.eye(3, 4, 1)), np.eye(3, 4, 1)) @@ -3101,25 +3035,15 @@ class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase): self.assertAllClose( np.array(knp.full([2, 3], 0.1)), np.full([2, 3], 0.1) ) - - self.assertAllClose(np.array(knp.Full()([2, 3], 0)), np.full([2, 3], 0)) - self.assertAllClose( - np.array(knp.Full()([2, 3], 0.1)), np.full([2, 3], 0.1) - ) - - # TODO: implement conversion of shape into repetitions, pass to - # `torch.tile`, since `torch.full()` only accepts scalars - # for `fill_value`." 
- @pytest.mark.skipif( - backend.backend() == "torch", - reason="`torch.full` only accepts scalars for `fill_value`.", - ) - def test_full_without_torch(self): self.assertAllClose( np.array(knp.full([2, 3], np.array([1, 4, 5]))), np.full([2, 3], np.array([1, 4, 5])), ) + self.assertAllClose(np.array(knp.Full()([2, 3], 0)), np.full([2, 3], 0)) + self.assertAllClose( + np.array(knp.Full()([2, 3], 0.1)), np.full([2, 3], 0.1) + ) self.assertAllClose( np.array(knp.Full()([2, 3], np.array([1, 4, 5]))), np.full([2, 3], np.array([1, 4, 5])), @@ -3129,11 +3053,7 @@ class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase): self.assertAllClose(np.array(knp.identity(3)), np.identity(3)) self.assertAllClose(np.array(knp.Identity()(3)), np.identity(3)) - @pytest.mark.skipif( - backend.backend() == "torch", reason="No torch equivalent for `np.tri`" - ) def test_tri(self): - # TODO: create a manual implementation, as PyTorch has no equivalent self.assertAllClose(np.array(knp.tri(3)), np.tri(3)) self.assertAllClose(np.array(knp.tri(3, 4)), np.tri(3, 4)) self.assertAllClose(np.array(knp.tri(3, 4, 1)), np.tri(3, 4, 1)) diff --git a/keras_core/utils/feature_space.py b/keras_core/utils/feature_space.py new file mode 100644 index 000000000..8d66f4fb5 --- /dev/null +++ b/keras_core/utils/feature_space.py @@ -0,0 +1,757 @@ +from tensorflow import data as tf_data + +from keras_core import layers +from keras_core import operations as ops +from keras_core.api_export import keras_core_export +from keras_core.layers.layer import Layer +from keras_core.saving import saving_lib +from keras_core.saving import serialization_lib +from keras_core.utils.naming import auto_name + + +class Cross: + def __init__(self, feature_names, crossing_dim, output_mode="one_hot"): + if output_mode not in {"int", "one_hot"}: + raise ValueError( + "Invalid value for argument `output_mode`. " + "Expected one of {'int', 'one_hot'}. " + f"Received: output_mode={output_mode}" + ) + self.feature_names = tuple(feature_names) + self.crossing_dim = crossing_dim + self.output_mode = output_mode + + @property + def name(self): + return "_X_".join(self.feature_names) + + def get_config(self): + return { + "feature_names": self.feature_names, + "crossing_dim": self.crossing_dim, + "output_mode": self.output_mode, + } + + @classmethod + def from_config(cls, config): + return cls(**config) + + +class Feature: + def __init__(self, dtype, preprocessor, output_mode): + if output_mode not in {"int", "one_hot", "float"}: + raise ValueError( + "Invalid value for argument `output_mode`. " + "Expected one of {'int', 'one_hot', 'float'}. " + f"Received: output_mode={output_mode}" + ) + self.dtype = dtype + if isinstance(preprocessor, dict): + preprocessor = serialization_lib.deserialize_keras_object( + preprocessor + ) + self.preprocessor = preprocessor + self.output_mode = output_mode + + def get_config(self): + return { + "dtype": self.dtype, + "preprocessor": serialization_lib.serialize_keras_object( + self.preprocessor + ), + "output_mode": self.output_mode, + } + + @classmethod + def from_config(cls, config): + return cls(**config) + + +@keras_core_export("keras_core.utils.FeatureSpace") +class FeatureSpace(Layer): + """One-stop utility for preprocessing and encoding structured data. + + Arguments: + feature_names: Dict mapping the names of your features to their + type specification, e.g. `{"my_feature": "integer_categorical"}` + or `{"my_feature": FeatureSpace.integer_categorical()}`. 
+ For a complete list of all supported types, see + "Available feature types" paragraph below. + output_mode: One of `"concat"` or `"dict"`. In concat mode, all + features get concatenated together into a single vector. + In dict mode, the FeatureSpace returns a dict of individually + encoded features (with the same keys as the input dict keys). + crosses: List of features to be crossed together, e.g. + `crosses=[("feature_1", "feature_2")]`. The features will be + "crossed" by hashing their combined value into + a fixed-length vector. + crossing_dim: Default vector size for hashing crossed features. + Defaults to `32`. + hashing_dim: Default vector size for hashing features of type + `"integer_hashed"` and `"string_hashed"`. Defaults to `32`. + num_discretization_bins: Default number of bins to be used for + discretizing features of type `"float_discretized"`. + Defaults to `32`. + + **Available feature types:** + + Note that all features can be referred to by their string name, + e.g. `"integer_categorical"`. When using the string name, the default + argument values are used. + + ```python + # Plain float values. + FeatureSpace.float(name=None) + + # Float values to be preprocessed via featurewise standardization + # (i.e. via a `keras.layers.Normalization` layer). + FeatureSpace.float_normalized(name=None) + + # Float values to be preprocessed via linear rescaling + # (i.e. via a `keras.layers.Rescaling` layer). + FeatureSpace.float_rescaled(scale=1., offset=0., name=None) + + # Float values to be discretized. By default, the discrete + # representation will then be one-hot encoded. + FeatureSpace.float_discretized( + num_bins, bin_boundaries=None, output_mode="one_hot", name=None) + + # Integer values to be indexed. By default, the discrete + # representation will then be one-hot encoded. + FeatureSpace.integer_categorical( + max_tokens=None, num_oov_indices=1, output_mode="one_hot", name=None) + + # String values to be indexed. By default, the discrete + # representation will then be one-hot encoded. + FeatureSpace.string_categorical( + max_tokens=None, num_oov_indices=1, output_mode="one_hot", name=None) + + # Integer values to be hashed into a fixed number of bins. + # By default, the discrete representation will then be one-hot encoded. + FeatureSpace.integer_hashed(num_bins, output_mode="one_hot", name=None) + + # String values to be hashed into a fixed number of bins. + # By default, the discrete representation will then be one-hot encoded. + FeatureSpace.string_hashed(num_bins, output_mode="one_hot", name=None) + ``` + + Examples: + + **Basic usage with a dict of input data:** + + ```python + raw_data = { + "float_values": [0.0, 0.1, 0.2, 0.3], + "string_values": ["zero", "one", "two", "three"], + "int_values": [0, 1, 2, 3], + } + dataset = tf.data.Dataset.from_tensor_slices(raw_data) + + feature_space = FeatureSpace( + features={ + "float_values": "float_normalized", + "string_values": "string_categorical", + "int_values": "integer_categorical", + }, + crosses=[("string_values", "int_values")], + output_mode="concat", + ) + # Before you start using the FeatureSpace, + # you must `adapt()` it on some data. + feature_space.adapt(dataset) + + # You can call the FeatureSpace on a dict of data (batched or unbatched). 
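+    # In `"concat"` mode the call returns a single encoded vector per sample,
+    # with all features and feature crosses concatenated along the last axis.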
+ output_vector = feature_space(raw_data) + ``` + + **Basic usage with `tf.data`:** + + ```python + # Unlabeled data + preprocessed_ds = unlabeled_dataset.map(feature_space) + + # Labeled data + preprocessed_ds = labeled_dataset.map(lambda x, y: (feature_space(x), y)) + ``` + + **Basic usage with the Keras Functional API:** + + ```python + # Retrieve a dict Keras Input objects + inputs = feature_space.get_inputs() + # Retrieve the corresponding encoded Keras tensors + encoded_features = feature_space.get_encoded_features() + # Build a Functional model + outputs = keras.layers.Dense(1, activation="sigmoid")(encoded_features) + model = keras.Model(inputs, outputs) + ``` + + **Customizing each feature or feature cross:** + + ```python + feature_space = FeatureSpace( + features={ + "float_values": FeatureSpace.float_normalized(), + "string_values": FeatureSpace.string_categorical(max_tokens=10), + "int_values": FeatureSpace.integer_categorical(max_tokens=10), + }, + crosses=[ + FeatureSpace.cross(("string_values", "int_values"), crossing_dim=32) + ], + output_mode="concat", + ) + ``` + + **Returning a dict of integer-encoded features:** + + ```python + feature_space = FeatureSpace( + features={ + "string_values": FeatureSpace.string_categorical(output_mode="int"), + "int_values": FeatureSpace.integer_categorical(output_mode="int"), + }, + crosses=[ + FeatureSpace.cross( + feature_names=("string_values", "int_values"), + crossing_dim=32, + output_mode="int", + ) + ], + output_mode="dict", + ) + ``` + + **Specifying your own Keras preprocessing layer:** + + ```python + # Let's say that one of the features is a short text paragraph that + # we want to encode as a vector (one vector per paragraph) via TF-IDF. + data = { + "text": ["1st string", "2nd string", "3rd string"], + } + + # There's a Keras layer for this: TextVectorization. + custom_layer = layers.TextVectorization(output_mode="tf_idf") + + # We can use FeatureSpace.feature to create a custom feature + # that will use our preprocessing layer. + feature_space = FeatureSpace( + features={ + "text": FeatureSpace.feature( + preprocessor=custom_layer, dtype="string", output_mode="float" + ), + }, + output_mode="concat", + ) + feature_space.adapt(tf.data.Dataset.from_tensor_slices(data)) + output_vector = feature_space(data) + ``` + + **Retrieving the underlying Keras preprocessing layers:** + + ```python + # The preprocessing layer of each feature is available in `.preprocessors`. + preprocessing_layer = feature_space.preprocessors["feature1"] + + # The crossing layer of each feature cross is available in `.crossers`. + # It's an instance of keras.layers.HashedCrossing. 
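+    # Cross names are formed by joining the crossed feature names with "_X_".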
+ crossing_layer = feature_space.crossers["feature1_X_feature2"] + ``` + + **Saving and reloading a FeatureSpace:** + + ```python + feature_space.save("myfeaturespace.keras") + reloaded_feature_space = keras.models.load_model("myfeaturespace.keras") + ``` + """ + + @classmethod + def cross(cls, feature_names, crossing_dim, output_mode="one_hot"): + return Cross(feature_names, crossing_dim, output_mode=output_mode) + + @classmethod + def feature(cls, dtype, preprocessor, output_mode): + return Feature(dtype, preprocessor, output_mode) + + @classmethod + def float(cls, name=None): + from keras.layers.core import identity + + name = name or auto_name("float") + preprocessor = identity.Identity( + dtype="float32", name=f"{name}_preprocessor" + ) + return Feature( + dtype="float32", preprocessor=preprocessor, output_mode="float" + ) + + @classmethod + def float_rescaled(cls, scale=1.0, offset=0.0, name=None): + name = name or auto_name("float_rescaled") + preprocessor = layers.Rescaling( + scale=scale, offset=offset, name=f"{name}_preprocessor" + ) + return Feature( + dtype="float32", preprocessor=preprocessor, output_mode="float" + ) + + @classmethod + def float_normalized(cls, name=None): + name = name or auto_name("float_normalized") + preprocessor = layers.Normalization( + axis=-1, name=f"{name}_preprocessor" + ) + return Feature( + dtype="float32", preprocessor=preprocessor, output_mode="float" + ) + + @classmethod + def float_discretized( + cls, num_bins, bin_boundaries=None, output_mode="one_hot", name=None + ): + name = name or auto_name("float_discretized") + preprocessor = layers.Discretization( + num_bins=num_bins, + bin_boundaries=bin_boundaries, + name=f"{name}_preprocessor", + ) + return Feature( + dtype="float32", preprocessor=preprocessor, output_mode=output_mode + ) + + @classmethod + def integer_categorical( + cls, + max_tokens=None, + num_oov_indices=1, + output_mode="one_hot", + name=None, + ): + name = name or auto_name("integer_categorical") + preprocessor = layers.IntegerLookup( + name=f"{name}_preprocessor", + max_tokens=max_tokens, + num_oov_indices=num_oov_indices, + ) + return Feature( + dtype="int64", preprocessor=preprocessor, output_mode=output_mode + ) + + @classmethod + def string_categorical( + cls, + max_tokens=None, + num_oov_indices=1, + output_mode="one_hot", + name=None, + ): + name = name or auto_name("string_categorical") + preprocessor = layers.StringLookup( + name=f"{name}_preprocessor", + max_tokens=max_tokens, + num_oov_indices=num_oov_indices, + ) + return Feature( + dtype="string", preprocessor=preprocessor, output_mode=output_mode + ) + + @classmethod + def string_hashed(cls, num_bins, output_mode="one_hot", name=None): + name = name or auto_name("string_hashed") + preprocessor = layers.Hashing( + name=f"{name}_preprocessor", num_bins=num_bins + ) + return Feature( + dtype="string", preprocessor=preprocessor, output_mode=output_mode + ) + + @classmethod + def integer_hashed(cls, num_bins, output_mode="one_hot", name=None): + name = name or auto_name("integer_hashed") + preprocessor = layers.Hashing( + name=f"{name}_preprocessor", num_bins=num_bins + ) + return Feature( + dtype="int64", preprocessor=preprocessor, output_mode=output_mode + ) + + def __init__( + self, + features, + output_mode="concat", + crosses=None, + crossing_dim=32, + hashing_dim=32, + num_discretization_bins=32, + name=None, + ): + super().__init__(name=name) + if not features: + raise ValueError("The `features` argument cannot be None or empty.") + self.crossing_dim = 
crossing_dim + self.hashing_dim = hashing_dim + self.num_discretization_bins = num_discretization_bins + self.features = { + name: self._standardize_feature(name, value) + for name, value in features.items() + } + self.crosses = [] + if crosses: + feature_set = set(features.keys()) + for cross in crosses: + if isinstance(cross, dict): + cross = serialization_lib.deserialize_keras_object(cross) + if isinstance(cross, Cross): + self.crosses.append(cross) + else: + if not crossing_dim: + raise ValueError( + "When specifying `crosses`, the argument " + "`crossing_dim` " + "(dimensionality of the crossing space) " + "should be specified as well." + ) + for key in cross: + if key not in feature_set: + raise ValueError( + "All features referenced " + "in the `crosses` argument " + "should be present in the `features` dict. " + f"Received unknown features: {cross}" + ) + self.crosses.append(Cross(cross, crossing_dim=crossing_dim)) + self.crosses_by_name = {cross.name: cross for cross in self.crosses} + + if output_mode not in {"dict", "concat"}: + raise ValueError( + "Invalid value for argument `output_mode`. " + "Expected one of {'dict', 'concat'}. " + f"Received: output_mode={output_mode}" + ) + self.output_mode = output_mode + + self.inputs = { + name: self._feature_to_input(name, value) + for name, value in self.features.items() + } + self.preprocessors = { + name: value.preprocessor for name, value in self.features.items() + } + self.encoded_features = None + self.crossers = { + cross.name: self._cross_to_crosser(cross) for cross in self.crosses + } + self.one_hot_encoders = {} + self.built = False + self._is_adapted = False + self.concat = None + self._preprocessed_features_names = None + self._crossed_features_names = None + + def _feature_to_input(self, name, feature): + return layers.Input(shape=(1,), dtype=feature.dtype, name=name) + + def _standardize_feature(self, name, feature): + if isinstance(feature, Feature): + return feature + + if isinstance(feature, dict): + return serialization_lib.deserialize_keras_object(feature) + + if feature == "float": + return self.float(name=name) + elif feature == "float_normalized": + return self.float_normalized(name=name) + elif feature == "float_rescaled": + return self.float_rescaled(name=name) + elif feature == "float_discretized": + return self.float_discretized( + name=name, num_bins=self.num_discretization_bins + ) + elif feature == "integer_categorical": + return self.integer_categorical(name=name) + elif feature == "string_categorical": + return self.string_categorical(name=name) + elif feature == "integer_hashed": + return self.integer_hashed(self.hashing_dim, name=name) + elif feature == "string_hashed": + return self.string_hashed(self.hashing_dim, name=name) + else: + raise ValueError(f"Invalid feature type: {feature}") + + def _cross_to_crosser(self, cross): + return layers.HashedCrossing(cross.crossing_dim, name=cross.name) + + def _list_adaptable_preprocessors(self): + adaptable_preprocessors = [] + for name in self.features.keys(): + preprocessor = self.preprocessors[name] + # Special case: a Normalization layer with preset mean/variance. + # Not adaptable. + if isinstance(preprocessor, layers.Normalization): + if preprocessor.input_mean is not None: + continue + if hasattr(preprocessor, "adapt"): + adaptable_preprocessors.append(name) + return adaptable_preprocessors + + def adapt(self, dataset): + if not isinstance(dataset, tf_data.Dataset): + raise ValueError( + "`adapt()` can only be called on a tf.data.Dataset. 
" + f"Received instead: {dataset} (of type {type(dataset)})" + ) + + for name in self._list_adaptable_preprocessors(): + # Call adapt() on each individual adaptable layer. + + # TODO: consider rewriting this to instead iterate on the + # dataset once, split each batch into individual features, + # and call the layer's `_adapt_function` on each batch + # to simulate the behavior of adapt() in a more performant fashion. + + feature_dataset = dataset.map(lambda x: x[name]) + preprocessor = self.preprocessors[name] + # TODO: consider adding an adapt progress bar. + # Sample 1 element to check the rank + for x in feature_dataset.take(1): + pass + if x.shape.rank == 0: + # The dataset yields unbatched scalars; batch it. + feature_dataset = feature_dataset.batch(32) + if x.shape.rank in {0, 1}: + # If the rank is 1, add a dimension + # so we can reduce on axis=-1. + # Note: if rank was previously 0, it is now 1. + feature_dataset = feature_dataset.map( + lambda x: ops.expand_dims(x, -1) + ) + preprocessor.adapt(feature_dataset) + self._is_adapted = True + self.get_encoded_features() # Finish building the layer + self.built = True + + def get_inputs(self): + self._check_if_built() + return self.inputs + + def get_encoded_features(self): + self._check_if_adapted() + + if self.encoded_features is None: + preprocessed_features = self._preprocess_features(self.inputs) + crossed_features = self._cross_features(preprocessed_features) + merged_features = self._merge_features( + preprocessed_features, crossed_features + ) + self.encoded_features = merged_features + return self.encoded_features + + def _preprocess_features(self, features): + return { + name: self.preprocessors[name](features[name]) + for name in features.keys() + } + + def _cross_features(self, features): + all_outputs = {} + for cross in self.crosses: + inputs = [features[name] for name in cross.feature_names] + outputs = self.crossers[cross.name](inputs) + all_outputs[cross.name] = outputs + return all_outputs + + def _merge_features(self, preprocessed_features, crossed_features): + if not self._preprocessed_features_names: + self._preprocessed_features_names = sorted( + preprocessed_features.keys() + ) + self._crossed_features_names = sorted(crossed_features.keys()) + + all_names = ( + self._preprocessed_features_names + self._crossed_features_names + ) + all_features = [ + preprocessed_features[name] + for name in self._preprocessed_features_names + ] + [crossed_features[name] for name in self._crossed_features_names] + + if self.output_mode == "dict": + output_dict = {} + else: + features_to_concat = [] + + if self.built: + # Fast mode. 
+ for name, feature in zip(all_names, all_features): + encoder = self.one_hot_encoders.get(name, None) + if encoder: + feature = encoder(feature) + if self.output_mode == "dict": + output_dict[name] = feature + else: + features_to_concat.append(feature) + if self.output_mode == "dict": + return output_dict + else: + return self.concat(features_to_concat) + + # If the object isn't built, + # we create the encoder and concat layers below + all_specs = [ + self.features[name] for name in self._preprocessed_features_names + ] + [ + self.crosses_by_name[name] for name in self._crossed_features_names + ] + for name, feature, spec in zip(all_names, all_features, all_specs): + dtype = feature.dtype + + if spec.output_mode == "one_hot": + preprocessor = self.preprocessors.get( + name + ) or self.crossers.get(name) + cardinality = None + if not feature.dtype.startswith("int"): + raise ValueError( + f"Feature '{name}' has `output_mode='one_hot'`. " + "Thus its preprocessor should return an int64 dtype. " + f"Instead it returns a {dtype} dtype." + ) + + if isinstance( + preprocessor, (layers.IntegerLookup, layers.StringLookup) + ): + cardinality = preprocessor.vocabulary_size() + elif isinstance(preprocessor, layers.CategoryEncoding): + cardinality = preprocessor.num_tokens + elif isinstance(preprocessor, layers.Discretization): + cardinality = preprocessor.num_bins + elif isinstance( + preprocessor, (layers.HashedCrossing, layers.Hashing) + ): + cardinality = preprocessor.num_bins + else: + raise ValueError( + f"Feature '{name}' has `output_mode='one_hot'`. " + "However it isn't a standard feature and the " + "dimensionality of its output space is not known, " + "thus it cannot be one-hot encoded. " + "Try using `output_mode='int'`." + ) + if cardinality is not None: + encoder = layers.CategoryEncoding( + num_tokens=cardinality, output_mode="multi_hot" + ) + self.one_hot_encoders[name] = encoder + feature = encoder(feature) + + if self.output_mode == "concat": + dtype = feature.dtype + if dtype.startswith("int") or dtype == "string": + raise ValueError( + f"Cannot concatenate features because feature '{name}' " + f"has not been encoded (it has dtype {dtype}). " + "Consider using `output_mode='dict'`." + ) + features_to_concat.append(feature) + else: + output_dict[name] = feature + + if self.output_mode == "concat": + self.concat = layers.Concatenate(axis=-1) + return self.concat(features_to_concat) + else: + return output_dict + + def _check_if_adapted(self): + if not self._is_adapted: + if not self._list_adaptable_preprocessors(): + self._is_adapted = True + else: + raise ValueError( + "You need to call `.adapt(dataset)` on the FeatureSpace " + "before you can start using it." + ) + + def _check_if_built(self): + if not self.built: + self._check_if_adapted() + # Finishes building + self.get_encoded_features() + self.built = True + + def __call__(self, data): + self._check_if_built() + if not isinstance(data, dict): + raise ValueError( + "A FeatureSpace can only be called with a dict. 
" + f"Received: data={data} (of type {type(data)}" + ) + + data = { + key: ops.convert_to_tensor(value) for key, value in data.items() + } + rebatched = False + for name, x in data.items(): + if x.shape.rank == 0: + data[name] = ops.reshape(x, (1, 1)) + rebatched = True + elif x.shape.rank == 1: + data[name] = ops.expand_dims(x, -1) + + preprocessed_data = self._preprocess_features(data) + crossed_data = self._cross_features(preprocessed_data) + merged_data = self._merge_features(preprocessed_data, crossed_data) + if rebatched: + if self.output_mode == "concat": + assert merged_data.shape[0] == 1 + return ops.squeeze(merged_data, axis=0) + else: + for name, x in merged_data.items(): + if x.shape.rank == 2 and x.shape[0] == 1: + merged_data[name] = ops.squeeze(x, axis=0) + return merged_data + + def get_config(self): + return { + "features": serialization_lib.serialize_keras_object(self.features), + "output_mode": self.output_mode, + "crosses": serialization_lib.serialize_keras_object(self.crosses), + "crossing_dim": self.crossing_dim, + "hashing_dim": self.hashing_dim, + "num_discretization_bins": self.num_discretization_bins, + } + + @classmethod + def from_config(cls, config): + return cls(**config) + + def get_build_config(self): + return { + name: feature.preprocessor.get_build_config() + for name, feature in self.features.items() + } + + def build_from_config(self, config): + for name in config.keys(): + self.features[name].preprocessor.build_from_config(config[name]) + self._is_adapted = True + + def save(self, filepath): + """Save the `FeatureSpace` instance to a `.keras` file. + + You can reload it via `keras.models.load_model()`: + + ```python + feature_space.save("myfeaturespace.keras") + reloaded_feature_space = keras.models.load_model("myfeaturespace.keras") + ``` + """ + saving_lib.save_model(self, filepath) + + def save_own_variables(self, store): + return + + def load_own_variables(self, store): + return diff --git a/keras_core/utils/feature_space_test.py b/keras_core/utils/feature_space_test.py new file mode 100644 index 000000000..27a5483c3 --- /dev/null +++ b/keras_core/utils/feature_space_test.py @@ -0,0 +1,378 @@ +# from keras_core import testing +# from keras_core.utils import feature_space +# from keras_core import operations as ops +# from tensorflow import nest +# from tensorflow import data as tf_data +# from keras_core import layers +# from keras_core import models +# import os + + +# class FeatureSpaceTest(testing.TestCase): +# def _get_train_data_dict( +# self, as_dataset=False, as_tf_tensors=False, as_labeled_dataset=False +# ): +# data = { +# "float_1": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], +# "float_2": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], +# "float_3": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], +# "string_1": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"], +# "string_2": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"], +# "int_1": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], +# "int_2": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], +# "int_3": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], +# } +# if as_dataset: +# return tf_data.Dataset.from_tensor_slices(data) +# elif as_tf_tensors: +# return nest.map_structure(ops.convert_to_tensor, data) +# elif as_labeled_dataset: +# labels = [0, 1, 0, 1, 0, 0, 1, 0, 1, 1] +# return tf_data.Dataset.from_tensor_slices((data, labels)) +# return data + +# def test_basic_usage(self): +# fs = feature_space.FeatureSpace( +# features={ +# "float_1": "float", +# "float_2": "float_normalized", +# "float_3": 
"float_discretized", +# "string_1": "string_categorical", +# "string_2": "string_hashed", +# "int_1": "integer_categorical", +# "int_2": "integer_hashed", +# "int_3": "integer_categorical", +# }, +# crosses=[("float_3", "string_1"), ("string_2", "int_2")], +# output_mode="concat", +# ) +# # Test unbatched adapt +# fs.adapt(self._get_train_data_dict(as_dataset=True)) +# # Test batched adapt +# fs.adapt(self._get_train_data_dict(as_dataset=True).batch(4)) + +# # Test unbatched call on raw data +# data = { +# key: value[0] for key, value in self._get_train_data_dict().items() +# } +# out = fs(data) +# self.assertEqual(out.shape, [195]) + +# # Test unbatched call on TF tensors +# data = self._get_train_data_dict(as_tf_tensors=True) +# data = {key: value[0] for key, value in data.items()} +# out = fs(data) +# self.assertEqual(out.shape, [195]) + +# # Test batched call on raw data +# out = fs(self._get_train_data_dict()) +# self.assertEqual(out.shape, [10, 195]) + +# # Test batched call on TF tensors +# out = fs(self._get_train_data_dict(as_tf_tensors=True)) +# self.assertEqual(out.shape, [10, 195]) + +# def test_output_mode_dict(self): +# fs = feature_space.FeatureSpace( +# features={ +# "float_1": "float", +# "float_2": "float_normalized", +# "float_3": "float_discretized", +# "string_1": "string_categorical", +# "string_2": "string_hashed", +# "int_1": "integer_categorical", +# "int_2": "integer_hashed", +# "int_3": "integer_categorical", +# }, +# crosses=[("float_3", "string_1"), ("string_2", "int_2")], +# output_mode="dict", +# ) +# fs.adapt(self._get_train_data_dict(as_dataset=True)) + +# # Test unbatched call on raw data +# data = { +# key: value[0] for key, value in self._get_train_data_dict().items() +# } +# out = fs(data) +# self.assertIsInstance(out, dict) +# self.assertLen(out, 10) +# self.assertEqual(out["string_1"].shape, [11]) +# self.assertEqual(out["int_2"].shape, [32]) +# self.assertEqual(out["string_2_X_int_2"].shape, [32]) + +# # Test batched call on raw data +# out = fs(self._get_train_data_dict()) +# self.assertIsInstance(out, dict) +# self.assertLen(out, 10) +# self.assertEqual(out["string_1"].shape, [10, 11]) +# self.assertEqual(out["int_2"].shape, [10, 32]) +# self.assertEqual(out["string_2_X_int_2"].shape, [10, 32]) + +# # Test batched call on TF tensors +# out = fs(self._get_train_data_dict(as_tf_tensors=True)) +# self.assertIsInstance(out, dict) +# self.assertLen(out, 10) +# self.assertEqual(out["string_1"].shape, [10, 11]) +# self.assertEqual(out["int_2"].shape, [10, 32]) +# self.assertEqual(out["string_2_X_int_2"].shape, [10, 32]) + +# def test_output_mode_dict_of_ints(self): +# cls = feature_space.FeatureSpace +# fs = feature_space.FeatureSpace( +# features={ +# "float_1": "float", +# "float_2": "float_normalized", +# "float_3": "float_discretized", +# "string_1": cls.string_categorical(output_mode="int"), +# "string_2": cls.string_hashed(num_bins=32, output_mode="int"), +# "int_1": cls.integer_categorical(output_mode="int"), +# "int_2": cls.integer_hashed(num_bins=32, output_mode="int"), +# "int_3": cls.integer_categorical(output_mode="int"), +# }, +# crosses=[ +# cls.cross( +# ("float_3", "string_1"), output_mode="int", crossing_dim=32 +# ), +# cls.cross( +# ("string_2", "int_2"), output_mode="int", crossing_dim=32 +# ), +# ], +# output_mode="dict", +# ) +# fs.adapt(self._get_train_data_dict(as_dataset=True)) +# data = { +# key: value[0] for key, value in self._get_train_data_dict().items() +# } +# out = fs(data) +# self.assertIsInstance(out, dict) +# 
self.assertLen(out, 10) +# self.assertEqual(out["string_1"].shape, [1]) +# self.assertEqual(out["string_1"].dtype.name, "int64") +# self.assertEqual(out["int_2"].shape, [1]) +# self.assertEqual(out["int_2"].dtype.name, "int64") +# self.assertEqual(out["string_2_X_int_2"].shape, [1]) +# self.assertEqual(out["string_2_X_int_2"].dtype.name, "int64") + +# def test_functional_api_sync_processing(self): +# fs = feature_space.FeatureSpace( +# features={ +# "float_1": "float", +# "float_2": "float_normalized", +# "float_3": "float_discretized", +# "string_1": "string_categorical", +# "string_2": "string_hashed", +# "int_1": "integer_categorical", +# "int_2": "integer_hashed", +# "int_3": "integer_categorical", +# }, +# crosses=[("float_3", "string_1"), ("string_2", "int_2")], +# output_mode="concat", +# ) +# fs.adapt(self._get_train_data_dict(as_dataset=True)) +# inputs = fs.get_inputs() +# features = fs.get_encoded_features() +# outputs = layers.Dense(1)(features) +# model = models.Model(inputs=inputs, outputs=outputs) +# model.compile("adam", "mse") +# ds = self._get_train_data_dict(as_labeled_dataset=True) +# model.fit(ds.batch(4)) +# model.evaluate(ds.batch(4)) +# ds = self._get_train_data_dict(as_dataset=True) +# model.predict(ds.batch(4)) + +# def test_tf_data_async_processing(self): +# fs = feature_space.FeatureSpace( +# features={ +# "float_1": "float", +# "float_2": "float_normalized", +# "float_3": "float_discretized", +# "string_1": "string_categorical", +# "string_2": "string_hashed", +# "int_1": "integer_categorical", +# "int_2": "integer_hashed", +# "int_3": "integer_categorical", +# }, +# crosses=[("float_3", "string_1"), ("string_2", "int_2")], +# output_mode="concat", +# ) +# fs.adapt(self._get_train_data_dict(as_dataset=True)) +# features = fs.get_encoded_features() +# outputs = layers.Dense(1)(features) +# model = models.Model(inputs=features, outputs=outputs) +# model.compile("adam", "mse") +# ds = self._get_train_data_dict(as_labeled_dataset=True) +# # Try map before batch +# ds = ds.map(lambda x, y: (fs(x), y)) +# model.fit(ds.batch(4)) +# # Try map after batch +# ds = self._get_train_data_dict(as_labeled_dataset=True) +# ds = ds.batch(4) +# ds = ds.map(lambda x, y: (fs(x), y)) +# model.evaluate(ds) +# ds = self._get_train_data_dict(as_dataset=True) +# ds = ds.map(fs) +# model.predict(ds.batch(4)) + +# def test_advanced_usage(self): +# cls = feature_space.FeatureSpace +# fs = feature_space.FeatureSpace( +# features={ +# "float_1": cls.float(), +# "float_2": cls.float_normalized(), +# "float_3": cls.float_discretized(num_bins=3), +# "string_1": cls.string_categorical(max_tokens=5), +# "string_2": cls.string_hashed(num_bins=32), +# "int_1": cls.integer_categorical( +# max_tokens=5, num_oov_indices=2 +# ), +# "int_2": cls.integer_hashed(num_bins=32), +# "int_3": cls.integer_categorical(max_tokens=5), +# }, +# crosses=[ +# cls.cross(("float_3", "string_1"), crossing_dim=32), +# cls.cross(("string_2", "int_2"), crossing_dim=32), +# ], +# output_mode="concat", +# ) +# fs.adapt(self._get_train_data_dict(as_dataset=True)) +# data = { +# key: value[0] for key, value in self._get_train_data_dict().items() +# } +# out = fs(data) +# self.assertEqual(out.shape, [148]) + +# def test_manual_kpl(self): +# data = { +# "text": ["1st string", "2nd string", "3rd string"], +# } +# cls = feature_space.FeatureSpace + +# # Test with a tf-idf TextVectorization layer +# tv = layers.TextVectorization(output_mode="tf_idf") +# fs = feature_space.FeatureSpace( +# features={ +# "text": cls.feature( +# 
preprocessor=tv, dtype="string", output_mode="float" +# ), +# }, +# output_mode="concat", +# ) +# fs.adapt(tf_data.Dataset.from_tensor_slices(data)) +# out = fs(data) +# self.assertEqual(out.shape, [3, 5]) + +# def test_no_adapt(self): +# data = { +# "int_1": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], +# } +# fs = feature_space.FeatureSpace( +# { +# "int_1": "integer_hashed", +# }, +# output_mode="concat", +# ) +# out = fs(data) +# self.assertEqual(out.shape, [10, 32]) + +# def test_saving(self): +# cls = feature_space.FeatureSpace +# fs = feature_space.FeatureSpace( +# features={ +# "float_1": cls.float(), +# "float_2": cls.float_normalized(), +# "float_3": cls.float_discretized(num_bins=3), +# "string_1": cls.string_categorical(max_tokens=5), +# "string_2": cls.string_hashed(num_bins=32), +# "int_1": cls.integer_categorical( +# max_tokens=5, num_oov_indices=2 +# ), +# "int_2": cls.integer_hashed(num_bins=32), +# "int_3": cls.integer_categorical(max_tokens=5), +# }, +# crosses=[ +# cls.cross(("float_3", "string_1"), crossing_dim=32), +# cls.cross(("string_2", "int_2"), crossing_dim=32), +# ], +# output_mode="concat", +# ) +# fs.adapt(self._get_train_data_dict(as_dataset=True)) +# data = { +# key: value[0] for key, value in self._get_train_data_dict().items() +# } +# ref_out = fs(data) + +# temp_filepath = os.path.join(self.get_temp_dir(), "fs.keras") +# fs.save(temp_filepath) +# fs = models.models.load_model(temp_filepath) + +# # Save again immediately after loading to test idempotency +# temp_filepath = os.path.join(self.get_temp_dir(), "fs2.keras") +# fs.save(temp_filepath) + +# # Test correctness of the first saved FS +# out = fs(data) +# self.assertAllClose(out, ref_out) + +# inputs = fs.get_inputs() +# outputs = fs.get_encoded_features() +# model = models.Model(inputs=inputs, outputs=outputs) +# ds = self._get_train_data_dict(as_dataset=True) +# out = model.predict(ds.batch(4)) +# self.assertAllClose(out[0], ref_out) + +# # Test correctness of the re-saved FS +# fs = models.models.load_model(temp_filepath) +# out = fs(data) +# self.assertAllClose(out, ref_out) + +# def test_errors(self): +# # Test no features +# with self.assertRaisesRegex(ValueError, "cannot be None or empty"): +# feature_space.FeatureSpace(features={}) +# # Test no crossing dim +# with self.assertRaisesRegex(ValueError, "`crossing_dim`"): +# feature_space.FeatureSpace( +# features={ +# "f1": "integer_categorical", +# "f2": "integer_categorical", +# }, +# crosses=[("f1", "f2")], +# crossing_dim=None, +# ) +# # Test wrong cross feature name +# with self.assertRaisesRegex(ValueError, "should be present in "): +# feature_space.FeatureSpace( +# features={ +# "f1": "integer_categorical", +# "f2": "integer_categorical", +# }, +# crosses=[("f1", "unknown")], +# crossing_dim=32, +# ) +# # Test wrong output mode +# with self.assertRaisesRegex(ValueError, "for argument `output_mode`"): +# feature_space.FeatureSpace( +# features={ +# "f1": "integer_categorical", +# "f2": "integer_categorical", +# }, +# output_mode="unknown", +# ) +# # Test call before adapt +# with self.assertRaisesRegex(ValueError, "You need to call `.adapt"): +# fs = feature_space.FeatureSpace( +# features={ +# "f1": "integer_categorical", +# "f2": "integer_categorical", +# } +# ) +# fs({"f1": [0], "f2": [0]}) +# # Test get_encoded_features before adapt +# with self.assertRaisesRegex(ValueError, "You need to call `.adapt"): +# fs = feature_space.FeatureSpace( +# features={ +# "f1": "integer_categorical", +# "f2": "integer_categorical", +# } +# ) +# 
fs.get_encoded_features()