Checkpoint

Francois Chollet 2023-05-18 15:07:14 -07:00
parent afa8882b3d
commit 1baec319b0
18 changed files with 1616 additions and 704 deletions

@ -26,7 +26,7 @@ class MyDense(layers.Layer):
# You can also use add_weight # You can also use add_weight
self.b = self.add_weight( self.b = self.add_weight(
shape=(self.units,), shape=(self.units,),
initializer="zeros", initializer=initializers.Zeros(),
name="bias", name="bias",
trainable=True, trainable=True,
) )
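For context, a minimal sketch of the custom layer this hunk edits, assuming the surrounding file imports `initializers` from `keras_core` and uses the usual `build()` signature (the kernel weight below is an assumption, not shown in the hunk):

```python
from keras_core import initializers
from keras_core import layers


class MyDense(layers.Layer):
    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        # Kernel weight (assumed; not part of the visible hunk).
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="glorot_uniform",
            name="kernel",
            trainable=True,
        )
        # You can also use add_weight with an initializer instance,
        # which is what this commit switches to.
        self.b = self.add_weight(
            shape=(self.units,),
            initializer=initializers.Zeros(),
            name="bias",
            trainable=True,
        )
```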

@ -6,7 +6,7 @@ from keras_core import losses
from keras_core import metrics from keras_core import metrics
from keras_core import optimizers from keras_core import optimizers
inputs = layers.Input((100,), batch_size=32) inputs = layers.Input((100,))
x = layers.Dense(256, activation="relu")(inputs) x = layers.Dense(256, activation="relu")(inputs)
residual = x residual = x
x = layers.Dense(256, activation="relu")(x) x = layers.Dense(256, activation="relu")(x)
@ -27,9 +27,18 @@ model.compile(
loss=losses.MeanSquaredError(), loss=losses.MeanSquaredError(),
metrics=[metrics.CategoricalAccuracy(name="acc"), metrics.MeanSquaredError(name="mse")], metrics=[metrics.CategoricalAccuracy(name="acc"), metrics.MeanSquaredError(name="mse")],
) )
print("\nTrain model")
history = model.fit( history = model.fit(
x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2 x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2
) )
print("\nHistory:")
print("History:")
print(history.history) print(history.history)
print("\nEvaluate model")
scores = model.evaluate(x, y, return_dict=True)
print(scores)
print("\nRun inference")
pred = model.predict(x)
print(f"Inferred output shape {pred.shape}")

@ -9,10 +9,11 @@ class KerasTensor:
def __init__(self, shape, dtype="float32", record_history=True, name=None): def __init__(self, shape, dtype="float32", record_history=True, name=None):
from keras_core import backend from keras_core import backend
if backend.DYNAMIC_SHAPES_OK: shape = backend.standardize_shape(
shape = backend.standardize_shape(shape, fully_defined=False) shape,
else: allow_dynamic_batch_size=backend.DYNAMIC_BATCH_SIZE_OK,
shape = backend.standardize_shape(shape, fully_defined=True) allow_all_dynamic=backend.DYNAMIC_SHAPES_OK,
)
self.shape = shape self.shape = shape
self.dtype = backend.standardize_dtype(dtype) self.dtype = backend.standardize_dtype(dtype)
self.name = name or auto_name(self.__class__.__name__) self.name = name or auto_name(self.__class__.__name__)
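In effect, whether a given axis of a `KerasTensor` may be `None` is now decided per backend by the two flags. A sketch of the resulting contract, based on the flag values set elsewhere in this commit:

```python
from keras_core.backend.common.keras_tensor import KerasTensor

# TensorFlow / PyTorch backends (DYNAMIC_SHAPES_OK = True):
# any axis may be None.
kt = KerasTensor((None, None, 3))

# JAX backend (DYNAMIC_SHAPES_OK = False, DYNAMIC_BATCH_SIZE_OK = True):
# only the batch axis (axis 0) may be None, so the shape above would raise
# ValueError, while a dynamic batch size alone is fine:
kt = KerasTensor((None, 32, 3))
```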

@ -1,13 +1,12 @@
from keras_core.api_export import keras_core_export from keras_core.api_export import keras_core_export
from keras_core.backend.common import global_state from keras_core.backend.common import global_state
from keras_core.backend.common.variables import KerasVariable
from keras_core.backend.common.variables import initialize_all_variables
@keras_core_export("keras_core.StatelessScope") @keras_core_export("keras_core.StatelessScope")
class StatelessScope: class StatelessScope:
def __init__(self, state_mapping=None, collect_losses=False): def __init__(self, state_mapping=None, collect_losses=False):
from keras_core import backend from keras_core import backend
from keras_core.backend.common.variables import KerasVariable
self.collect_losses = collect_losses self.collect_losses = collect_losses
self.losses = [] self.losses = []
@ -54,6 +53,10 @@ class StatelessScope:
# We're back in eager scope; # We're back in eager scope;
# if any variables were created within the stateless # if any variables were created within the stateless
# scope, we initialize them here. # scope, we initialize them here.
from keras_core.backend.common.variables import (
initialize_all_variables,
)
initialize_all_variables() initialize_all_variables()
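A hedged usage sketch of the scope. The `Variable` constructor and the exact `state_mapping` format are assumptions; the lookup and exit behavior follow the code above:

```python
import numpy as np

from keras_core import backend
from keras_core.backend.common.stateless_scope import StatelessScope

# Assumed setup: a backend variable plus a replacement value.
v = backend.Variable(initializer=np.zeros((2,)), dtype="float32")
new_value = backend.convert_to_tensor(np.ones((2,)))

# Inside the scope, reads of `v.value` resolve to the mapped value and
# `v.assign(...)` records an update on the scope instead of mutating `v`.
with StatelessScope(state_mapping=[(v, new_value)]) as scope:
    assert scope.get_current_value(v) is not None

# Back in eager scope: any variables created *inside* the scope are
# initialized on exit via initialize_all_variables(), per the hunk above.
```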

@ -1,5 +1,9 @@
import numpy as np
from keras_core.backend import config from keras_core.backend import config
from keras_core.backend.common import global_state from keras_core.backend.common import global_state
from keras_core.backend.common.stateless_scope import get_stateless_scope
from keras_core.backend.common.stateless_scope import in_stateless_scope
from keras_core.utils.naming import auto_name from keras_core.utils.naming import auto_name
@ -20,7 +24,6 @@ class KerasVariable:
f"Received: initializer={initializer} " f"Received: initializer={initializer} "
f"and shape={shape}" f"and shape={shape}"
) )
from keras_core.backend.common.stateless_scope import in_stateless_scope
if in_stateless_scope(): if in_stateless_scope():
if callable(initializer): if callable(initializer):
@ -76,6 +79,42 @@ class KerasVariable:
return autocast_scope.maybe_cast(value) return autocast_scope.maybe_cast(value)
return value return value
def numpy(self):
return np.array(self.value)
@property
def value(self):
if in_stateless_scope():
scope = get_stateless_scope()
value = scope.get_current_value(self)
if value is not None:
return self._maybe_autocast(value)
if self._value is None:
# Uninitialized variable. Return a placeholder.
# This is fine because it's only ever used
# during shape inference / graph tracing
# (anything else would be a bug, to be fixed).
return self._maybe_autocast(
self._initializer(self._shape, dtype=self._dtype)
)
return self._maybe_autocast(self._value)
def assign(self, value):
value = self._convert_to_tensor(value, dtype=self.dtype)
if value.shape != self.value.shape:
raise ValueError(
"The shape of the target variable and "
"the shape of the target value in "
"`variable.assign(value)` must match. "
f"Received: value.shape={value.shape}; "
f"variable.shape={self.value.shape}"
)
if in_stateless_scope():
scope = get_stateless_scope()
scope.add_update((self, value))
else:
self._direct_assign(value)
@property @property
def dtype(self): def dtype(self):
autocast_scope = get_autocast_scope() autocast_scope = get_autocast_scope()
@ -100,16 +139,6 @@ class KerasVariable:
def _initialize(self, value): def _initialize(self, value):
raise NotImplementedError raise NotImplementedError
@property
def value(self):
raise NotImplementedError
def numpy(self):
raise NotImplementedError
def assign(self, value):
raise NotImplementedError
def _convert_to_tensor(self, value, dtype=None): def _convert_to_tensor(self, value, dtype=None):
raise NotImplementedError raise NotImplementedError
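With `value`, `numpy()`, and `assign()` now living on the base class, a backend only has to supply the three low-level hooks left as `NotImplementedError` here. A minimal sketch of the pattern, using a hypothetical NumPy-backed variable for illustration (the JAX, TensorFlow, and Torch variables changed later in this commit follow the same shape):

```python
import numpy as np

from keras_core.backend.common import KerasVariable


class NumpyVariable(KerasVariable):
    def _initialize(self, value):
        # Called once with the initial value produced by the initializer.
        self._value = np.array(value, dtype=self._dtype)

    def _direct_assign(self, value):
        # Called by the base-class assign() when not in a stateless scope.
        self._value = value

    def _convert_to_tensor(self, value, dtype=None):
        # Used by assign() to coerce incoming values before the shape check.
        return np.asarray(value, dtype=dtype)
```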
@ -336,50 +365,59 @@ ALLOWED_DTYPES = {
"int64", "int64",
"bfloat16", "bfloat16",
"bool", "bool",
} "string",
PYTHON_DTYPES_MAP = {
bool: "bool",
int: "int", # TBD by backend
float: "float32",
} }
def standardize_dtype(dtype): def standardize_dtype(dtype):
if dtype is None:
return config.floatx()
if dtype in PYTHON_DTYPES_MAP:
dtype = PYTHON_DTYPES_MAP.get(dtype)
if dtype == "int": if dtype == "int":
if config.backend() == "tensorflow": if config.backend() == "tensorflow":
dtype = "int64" dtype = "int64"
else: else:
dtype = "int32" dtype = "int32"
if dtype is None:
return config.floatx()
if hasattr(dtype, "name"): if hasattr(dtype, "name"):
dtype = dtype.name dtype = dtype.name
if dtype not in ALLOWED_DTYPES: if dtype not in ALLOWED_DTYPES:
raise ValueError(f"Invalid dtype: {dtype}") raise ValueError(f"Invalid dtype: {dtype}")
return dtype return dtype
def standardize_shape(shape, fully_defined=False): def standardize_shape(
shape, allow_dynamic_batch_size=True, allow_all_dynamic=True
):
if not isinstance(shape, tuple): if not isinstance(shape, tuple):
if shape is None: if shape is None:
raise ValueError("Undefined shapes are not supported.") raise ValueError("Undefined shapes are not supported.")
if not hasattr(shape, "__iter__"): if not hasattr(shape, "__iter__"):
raise ValueError(f"Cannot convert '{shape}' to a shape.") raise ValueError(f"Cannot convert '{shape}' to a shape.")
shape = tuple(shape) shape = tuple(shape)
for e in shape:
if not fully_defined and e is None: for i, e in enumerate(shape):
if i == 0 and allow_dynamic_batch_size and e is None:
continue
if allow_all_dynamic and e is None:
continue continue
if not isinstance(e, int): if not isinstance(e, int):
raise ValueError( msg = (
f"Cannot convert '{shape}' to a shape. " f"Cannot convert '{shape}' to a shape. "
f"Found invalid entry '{e}'. Only " f"Found invalid entry '{e}'. "
"fully-defined shapes are allowed with the "
f"{config.backend()} backend."
) )
if allow_dynamic_batch_size:
msg += (
"Dynamic shapes (shapes with `None` entries) "
f"are not allowed with the {config.backend()}, "
"except for the batch size (axis 0)."
)
else:
msg += (
"Dynamic shapes (shapes with `None` entries) "
f"are not allowed with the {config.backend()}. "
"All dimensions should be positive integers, "
"including the batch size (axis 0)."
)
raise ValueError(msg)
if e < 0: if e < 0:
raise ValueError( raise ValueError(
f"Cannot convert '{shape}' to a shape. " f"Cannot convert '{shape}' to a shape. "

@ -4,6 +4,7 @@ from keras_core.backend.jax import math
from keras_core.backend.jax import nn from keras_core.backend.jax import nn
from keras_core.backend.jax import numpy from keras_core.backend.jax import numpy
from keras_core.backend.jax import random from keras_core.backend.jax import random
from keras_core.backend.jax.core import DYNAMIC_BATCH_SIZE_OK
from keras_core.backend.jax.core import DYNAMIC_SHAPES_OK from keras_core.backend.jax.core import DYNAMIC_SHAPES_OK
from keras_core.backend.jax.core import Variable from keras_core.backend.jax.core import Variable
from keras_core.backend.jax.core import cast from keras_core.backend.jax.core import cast

@ -1,69 +1,30 @@
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np
from tensorflow import nest from tensorflow import nest
from keras_core.backend.common import KerasVariable from keras_core.backend.common import KerasVariable
from keras_core.backend.common import standardize_dtype from keras_core.backend.common import standardize_dtype
from keras_core.backend.common.keras_tensor import KerasTensor from keras_core.backend.common.keras_tensor import KerasTensor
from keras_core.backend.common.stateless_scope import StatelessScope from keras_core.backend.common.stateless_scope import StatelessScope
from keras_core.backend.common.stateless_scope import get_stateless_scope
from keras_core.backend.common.stateless_scope import in_stateless_scope
DYNAMIC_SHAPES_OK = False # Dynamic shapes NG DYNAMIC_SHAPES_OK = False # Dynamic shapes NG
DYNAMIC_BATCH_SIZE_OK = True
class Variable(KerasVariable): class Variable(KerasVariable):
def _initialize(self, value): def _initialize(self, value):
self._value = jnp.array(value, dtype=self._dtype) self._value = jnp.array(value, dtype=self._dtype)
def assign(self, value): def _direct_assign(self, value):
value = convert_to_tensor(value, dtype=self.dtype)
if value.shape != self.shape:
raise ValueError(
"The shape of the target variable and "
"the shape of the target value in "
"`variable.assign(value)` must match. "
f"Received: value.shape={value.shape}; "
f"variable.shape={self.value.shape}"
)
if in_stateless_scope():
scope = get_stateless_scope()
scope.add_update((self, value))
else:
if isinstance(value, jnp.ndarray) and value.dtype == self.dtype:
# Avoid a memory copy
self._value = value self._value = value
else:
self._value = jnp.array(value, dtype=self.dtype)
@property def _convert_to_tensor(self, value, dtype=None):
def value(self): return convert_to_tensor(value, dtype=dtype)
if in_stateless_scope():
scope = get_stateless_scope()
value = scope.get_current_value(self)
if value is not None:
return self._maybe_autocast(value)
if self._value is None:
# Unitialized variable. Return a placeholder.
# This is fine because it's only ever used
# in during shape inference with JAX tracer objects
# (anything else would be a bug, to be fixed.)
return self._maybe_autocast(
self._initializer(self._shape, dtype=self._dtype)
)
return self._maybe_autocast(self._value)
def numpy(self):
return np.array(self.value)
# Overload native accessor. # Overload native accessor.
def __jax_array__(self): def __jax_array__(self):
return self.value return self.value
def _convert_to_tensor(self, value, dtype=None):
return convert_to_tensor(value, dtype=dtype)
def convert_to_tensor(x, dtype=None): def convert_to_tensor(x, dtype=None):
if dtype is not None: if dtype is not None:
@ -98,10 +59,22 @@ def name_scope(name):
# Shape / dtype inference util # Shape / dtype inference util
def compute_output_spec(fn, *args, **kwargs): def compute_output_spec(fn, *args, **kwargs):
with StatelessScope(): with StatelessScope():
dynamic_batch_map = {}
magic_number = 3
def convert_keras_tensor_to_jax(x): def convert_keras_tensor_to_jax(x):
if isinstance(x, KerasTensor): if isinstance(x, KerasTensor):
return jax.ShapeDtypeStruct(x.shape, dtype=x.dtype) shape = x.shape
if shape and x.shape[0] is None:
shape = list(shape)
shape[0] = magic_number
dynamic_batch = True
else:
dynamic_batch = False
jax_tensor = jax.ShapeDtypeStruct(shape, dtype=x.dtype)
dynamic_batch_map[jax_tensor] = dynamic_batch
return jax_tensor
return x return x
built_in_types = (type(None), int, float, str, bool, complex, bytes) built_in_types = (type(None), int, float, str, bool, complex, bytes)
@ -136,6 +109,17 @@ def compute_output_spec(fn, *args, **kwargs):
def convert_jax_spec_to_keras_tensor(x): def convert_jax_spec_to_keras_tensor(x):
if isinstance(x, jax.ShapeDtypeStruct): if isinstance(x, jax.ShapeDtypeStruct):
if dynamic_batch_map.get(x, False):
shape = list(x.shape)
if shape[0] != magic_number:
raise ValueError(
f"Function {fn} appears to change the "
"batch size of its input. This is not "
"allowed when used in conjunction with "
"dynamic batch sizes. Consider using "
"a static batch size here."
)
shape[0] = None
return KerasTensor(x.shape, x.dtype) return KerasTensor(x.shape, x.dtype)
return x return x
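The trick in this hunk: JAX shape inference needs concrete dimensions, so a dynamic batch axis (`None`) is temporarily replaced by a small placeholder (3), the function is traced, and the placeholder is mapped back to `None` in the output spec; if the traced function altered that axis, the round trip fails with the `ValueError` above. A hedged usage sketch, assuming `compute_output_spec` is re-exported by the active backend as in the backend `__init__` hunks in this commit:

```python
from keras_core import backend
from keras_core.backend.common.keras_tensor import KerasTensor

# With the JAX backend, a KerasTensor with a dynamic batch axis...
x = KerasTensor((None, 3))

# ...is traced with a placeholder batch size, then restored to None.
y = backend.compute_output_spec(lambda t: t * 2, x)
print(y.shape)  # (None, 3)

# A function that changes the batch axis (e.g. a reduction over axis 0)
# would raise the ValueError shown above.
```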

@ -5,6 +5,7 @@ from keras_core.backend.tensorflow import nn
from keras_core.backend.tensorflow import numpy from keras_core.backend.tensorflow import numpy
from keras_core.backend.tensorflow import random from keras_core.backend.tensorflow import random
from keras_core.backend.tensorflow.core import DYNAMIC_SHAPES_OK from keras_core.backend.tensorflow.core import DYNAMIC_SHAPES_OK
from keras_core.backend.tensorflow.core import DYNAMIC_BATCH_SIZE_OK
from keras_core.backend.tensorflow.core import Variable from keras_core.backend.tensorflow.core import Variable
from keras_core.backend.tensorflow.core import cast from keras_core.backend.tensorflow.core import cast
from keras_core.backend.tensorflow.core import compute_output_spec from keras_core.backend.tensorflow.core import compute_output_spec

@ -4,11 +4,10 @@ from keras_core.backend.common import KerasVariable
from keras_core.backend.common import standardize_dtype from keras_core.backend.common import standardize_dtype
from keras_core.backend.common.keras_tensor import KerasTensor from keras_core.backend.common.keras_tensor import KerasTensor
from keras_core.backend.common.stateless_scope import StatelessScope from keras_core.backend.common.stateless_scope import StatelessScope
from keras_core.backend.common.stateless_scope import get_stateless_scope
from keras_core.backend.common.stateless_scope import in_stateless_scope
from keras_core.utils.naming import auto_name from keras_core.utils.naming import auto_name
DYNAMIC_SHAPES_OK = True DYNAMIC_SHAPES_OK = True
DYNAMIC_BATCH_SIZE_OK = True
class Variable(KerasVariable, tf.__internal__.types.Tensor): class Variable(KerasVariable, tf.__internal__.types.Tensor):
@ -23,37 +22,11 @@ class Variable(KerasVariable, tf.__internal__.types.Tensor):
value, dtype=self._dtype, trainable=self.trainable, name=self.name value, dtype=self._dtype, trainable=self.trainable, name=self.name
) )
def assign(self, value): def _direct_assign(self, value):
value = convert_to_tensor(value, dtype=self.dtype)
if value.shape != self.value.shape:
raise ValueError(
"The shape of the target variable and "
"the shape of the target value in "
"`variable.assign(value)` must match. "
f"Received: value.shape={value.shape}; "
f"variable.shape={self.value.shape}"
)
if in_stateless_scope():
scope = get_stateless_scope()
scope.add_update((self, value))
else:
self.value.assign(value) self.value.assign(value)
@property def _convert_to_tensor(self, value, dtype=None):
def value(self): return convert_to_tensor(value, dtype=dtype)
if in_stateless_scope():
scope = get_stateless_scope()
value = scope.get_current_value(self)
if value is not None:
return self._maybe_autocast(value)
if self._value is None:
# Unitialized variable. Return a placeholder.
# This is fine because it's only ever used
# during shape inference in a scratch graph
# (anything else would be a bug, to be fixed.)
init_val = self._initializer(self._shape, dtype=self._dtype)
return self._maybe_autocast(init_val)
return self._maybe_autocast(self._value)
def numpy(self): # noqa: F811 def numpy(self): # noqa: F811
return self.value.numpy() return self.value.numpy()
@ -66,9 +39,6 @@ class Variable(KerasVariable, tf.__internal__.types.Tensor):
def __tf_tensor__(self, dtype=None, name=None): def __tf_tensor__(self, dtype=None, name=None):
return tf.convert_to_tensor(self.value, dtype=dtype, name=name) return tf.convert_to_tensor(self.value, dtype=dtype, name=name)
def _convert_to_tensor(self, value, dtype=None):
return convert_to_tensor(value, dtype=dtype)
def convert_to_tensor(x, dtype=None): def convert_to_tensor(x, dtype=None):
if dtype is not None: if dtype is not None:

@ -1,12 +1,11 @@
import numpy as np
import torch import torch
from keras_core.backend.common import KerasVariable from keras_core.backend.common import KerasVariable
from keras_core.backend.common import standardize_dtype from keras_core.backend.common import standardize_dtype
from keras_core.backend.common.stateless_scope import get_stateless_scope
from keras_core.backend.common.stateless_scope import in_stateless_scope
DYNAMIC_SHAPES_OK = True DYNAMIC_SHAPES_OK = True
DYNAMIC_BATCH_SIZE_OK = True
TORCH_DTYPES = { TORCH_DTYPES = {
"float16": torch.float16, "float16": torch.float16,
@ -37,53 +36,16 @@ class Variable(KerasVariable):
def _initialize(self, value): def _initialize(self, value):
self._value = convert_to_tensor(value, dtype=self._dtype) self._value = convert_to_tensor(value, dtype=self._dtype)
def assign(self, value): def _direct_assign(self, value):
value = convert_to_tensor(value, dtype=self.dtype) self._value = value
if value.shape != self.shape:
raise ValueError(
"The shape of the target variable and "
"the shape of the target value in "
"`variable.assign(value)` must match. "
f"Received: value.shape={value.shape}; "
f"variable.shape={self.value.shape}"
)
if in_stateless_scope():
scope = get_stateless_scope()
scope.add_update((self, value))
else:
# torch `as_tensor` by default, doesn't copy if tensor is same type
self._value = convert_to_tensor(value, dtype=self.dtype)
@property
def value(self):
if in_stateless_scope():
scope = get_stateless_scope()
value = scope.get_current_value(self)
if value is not None:
return self._maybe_autocast(value)
if self._value is None:
# Unitialized variable. Return a placeholder.
# This is fine because it's only ever used
# during shape inference in a scratch graph
# (anything else would be a bug, to be fixed.)
return self._maybe_autocast(
convert_to_tensor(
self._initializer(self._shape, dtype=self._dtype),
dtype=self._dtype,
)
)
return self._maybe_autocast(self._value)
def numpy(self):
return np.array(self.value)
# Overload native accessor.
def __torch_function__(self, func, types, args=(), kwargs=None):
raise NotImplementedError
def _convert_to_tensor(self, value, dtype=None): def _convert_to_tensor(self, value, dtype=None):
return convert_to_tensor(value, dtype=dtype) return convert_to_tensor(value, dtype=dtype)
# Overload native accessor.
def __torch_function__(self, func, types, args=(), kwargs=None):
return func(self.value, *args, **kwargs)
def convert_to_tensor(x, dtype=None): def convert_to_tensor(x, dtype=None):
# TODO: Need to address device placement arg of `as_tensor` # TODO: Need to address device placement arg of `as_tensor`

File diff suppressed because it is too large

@ -49,14 +49,19 @@ class UpSamplingTest(testing.TestCase):
) )
@pytest.mark.skipif( @pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK, not backend.DYNAMIC_BATCH_SIZE_OK,
reason="Backend does not support dynamic shapes", reason="Backend does not support dynamic batch sizes",
) )
def test_upsampling_1d_with_dynamic_shape(self): def test_upsampling_1d_with_dynamic_batch_size(self):
x = KerasTensor([None, 2, 3]) x = KerasTensor([None, 2, 3])
self.assertEqual(layers.UpSampling1D(size=2)(x).shape, (None, 4, 3)) self.assertEqual(layers.UpSampling1D(size=2)(x).shape, (None, 4, 3))
self.assertEqual(layers.UpSampling1D(size=4)(x).shape, (None, 8, 3)) self.assertEqual(layers.UpSampling1D(size=4)(x).shape, (None, 8, 3))
@pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK,
reason="Backend does not support dynamic shapes",
)
def test_upsampling_1d_with_dynamic_shape(self):
y = KerasTensor([2, None, 3]) y = KerasTensor([2, None, 3])
self.assertEqual(layers.UpSampling1D(size=2)(y).shape, (2, None, 3)) self.assertEqual(layers.UpSampling1D(size=2)(y).shape, (2, None, 3))
self.assertEqual(layers.UpSampling1D(size=4)(y).shape, (2, None, 3)) self.assertEqual(layers.UpSampling1D(size=4)(y).shape, (2, None, 3))

@ -13,7 +13,8 @@ class UpSampling3D(Layer):
Repeats the 1st, 2nd and 3rd dimensions Repeats the 1st, 2nd and 3rd dimensions
of the data by `size[0]`, `size[1]` and `size[2]` respectively. of the data by `size[0]`, `size[1]` and `size[2]` respectively.
Examples: Example:
>>> input_shape = (2, 1, 2, 1, 3) >>> input_shape = (2, 1, 2, 1, 3)
>>> x = np.ones(input_shape) >>> x = np.ones(input_shape)
>>> y = keras_core.layers.UpSampling3D(size=(2, 2, 2))(x) >>> y = keras_core.layers.UpSampling3D(size=(2, 2, 2))(x)
@ -108,6 +109,7 @@ class UpSampling3D(Layer):
self, x, depth_factor, height_factor, width_factor, data_format self, x, depth_factor, height_factor, width_factor, data_format
): ):
"""Resizes the volume contained in a 5D tensor. """Resizes the volume contained in a 5D tensor.
Args: Args:
x: Tensor or variable to resize. x: Tensor or variable to resize.
depth_factor: Positive integer. depth_factor: Positive integer.
@ -116,11 +118,7 @@ class UpSampling3D(Layer):
data_format: One of `"channels_first"`, `"channels_last"`. data_format: One of `"channels_first"`, `"channels_last"`.
Returns: Returns:
A tensor. Resized tensor.
Raises:
ValueError: if `data_format` is neither
`channels_last` or `channels_first`.
""" """
if data_format == "channels_first": if data_format == "channels_first":
output = ops.repeat(x, depth_factor, axis=2) output = ops.repeat(x, depth_factor, axis=2)
@ -133,4 +131,4 @@ class UpSampling3D(Layer):
output = ops.repeat(output, width_factor, axis=3) output = ops.repeat(output, width_factor, axis=3)
return output return output
else: else:
raise ValueError("Invalid data_format: " + str(data_format)) raise ValueError(f"Invalid data_format: {data_format}")

@ -65,8 +65,8 @@ class ZeroPaddingTest(testing.TestCase, parameterized.TestCase):
self.assertAllClose(outputs[:, 2:-2, 2:-2, 2:-2, :], inputs) self.assertAllClose(outputs[:, 2:-2, 2:-2, 2:-2, :], inputs)
@pytest.mark.skipif( @pytest.mark.skipif(
not backend.DYNAMIC_SHAPES_OK, not backend.DYNAMIC_BATCH_SIZE_OK,
reason="Backend does not support dynamic shapes", reason="Backend does not support dynamic batch sizes",
) )
def test_zero_padding_3d_with_dynamic_batch_size(self): def test_zero_padding_3d_with_dynamic_batch_size(self):
input_layer = layers.Input(batch_shape=(None, 2, 3, 4, 5)) input_layer = layers.Input(batch_shape=(None, 2, 3, 4, 5))

@ -42,10 +42,11 @@ class FunctionTest(testing.TestCase):
self.assertAllClose(y_val[0], np.ones((2, 3)) * 6) self.assertAllClose(y_val[0], np.ones((2, 3)) * 6)
self.assertAllClose(y_val[1], np.ones((2, 3)) * 4) self.assertAllClose(y_val[1], np.ones((2, 3)) * 4)
@pytest.mark.skipif(
not backend.DYNAMIC_BATCH_SIZE_OK,
reason="Test only valid if dynamic batch sizes are supported",
)
def test_dynamic_shape_inference(self): def test_dynamic_shape_inference(self):
if not backend.DYNAMIC_SHAPES_OK:
pytest.skip("Test only valid for dynamic shape backends")
x = keras_tensor.KerasTensor((None, 3)) x = keras_tensor.KerasTensor((None, 3))
y = x**2 y = x**2
fn = function.Function(x, y) fn = function.Function(x, y)

@ -1665,17 +1665,11 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
y3 = np.ones([1, 5, 4, 2]) y3 = np.ones([1, 5, 4, 2])
self.assertAllClose(np.array(knp.cross(x1, y1)), np.cross(x1, y1)) self.assertAllClose(np.array(knp.cross(x1, y1)), np.cross(x1, y1))
self.assertAllClose(np.array(knp.cross(x1, y2)), np.cross(x1, y2)) self.assertAllClose(np.array(knp.cross(x1, y2)), np.cross(x1, y2))
if backend.backend() != "torch":
# API divergence between `torch.cross` and `np.cross`
# `torch.cross` only allows dim 3, `np.cross` allows dim 2 or 3
self.assertAllClose(np.array(knp.cross(x1, y3)), np.cross(x1, y3)) self.assertAllClose(np.array(knp.cross(x1, y3)), np.cross(x1, y3))
self.assertAllClose(np.array(knp.cross(x2, y3)), np.cross(x2, y3)) self.assertAllClose(np.array(knp.cross(x2, y3)), np.cross(x2, y3))
self.assertAllClose(np.array(knp.Cross()(x1, y1)), np.cross(x1, y1)) self.assertAllClose(np.array(knp.Cross()(x1, y1)), np.cross(x1, y1))
self.assertAllClose(np.array(knp.Cross()(x1, y2)), np.cross(x1, y2)) self.assertAllClose(np.array(knp.Cross()(x1, y2)), np.cross(x1, y2))
if backend.backend() != "torch":
# API divergence between `torch.cross` and `np.cross`
# `torch.cross` only allows dim 3, `np.cross` allows dim 2 or 3
self.assertAllClose(np.array(knp.Cross()(x1, y3)), np.cross(x1, y3)) self.assertAllClose(np.array(knp.Cross()(x1, y3)), np.cross(x1, y3))
self.assertAllClose(np.array(knp.Cross()(x2, y3)), np.cross(x2, y3)) self.assertAllClose(np.array(knp.Cross()(x2, y3)), np.cross(x2, y3))
@ -1717,35 +1711,24 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
def test_full_like(self): def test_full_like(self):
x = np.array([[1, 2, 3], [3, 2, 1]]) x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(np.array(knp.full_like(x, 2)), np.full_like(x, 2)) self.assertAllClose(np.array(knp.full_like(x, 2)), np.full_like(x, 2))
self.assertAllClose(
np.array(knp.full_like(x, np.ones([2, 3]))),
np.full_like(x, np.ones([2, 3])),
)
self.assertAllClose( self.assertAllClose(
np.array(knp.full_like(x, 2, dtype="float32")), np.array(knp.full_like(x, 2, dtype="float32")),
np.full_like(x, 2, dtype="float32"), np.full_like(x, 2, dtype="float32"),
) )
self.assertAllClose(np.array(knp.FullLike()(x, 2)), np.full_like(x, 2)) self.assertAllClose(np.array(knp.FullLike()(x, 2)), np.full_like(x, 2))
self.assertAllClose(
np.array(knp.FullLike()(x, 2, dtype="float32")),
np.full_like(x, 2, dtype="float32"),
)
# TODO: implement conversion of shape into repetitions, pass to
# `torch.tile`, since `torch.full()` only accepts scalars
# for `fill_value`."
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.full` only accepts scalars for `fill_value`.",
)
def test_full_like_without_torch(self):
x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(
np.array(knp.full_like(x, np.ones([2, 3]))),
np.full_like(x, np.ones([2, 3])),
)
self.assertAllClose( self.assertAllClose(
np.array(knp.FullLike()(x, np.ones([2, 3]))), np.array(knp.FullLike()(x, np.ones([2, 3]))),
np.full_like(x, np.ones([2, 3])), np.full_like(x, np.ones([2, 3])),
) )
self.assertAllClose(
np.array(knp.FullLike()(x, 2, dtype="float32")),
np.full_like(x, 2, dtype="float32"),
)
def test_greater(self): def test_greater(self):
x = np.array([[1, 2, 3], [3, 2, 1]]) x = np.array([[1, 2, 3], [3, 2, 1]])
@ -1842,14 +1825,6 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
np.linspace(0, 10, 5, endpoint=False), np.linspace(0, 10, 5, endpoint=False),
) )
# TODO: torch.linspace does not support tensor or array
# for start/stop, create manual implementation
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.linspace` has no support for array start/stop.",
)
def test_linspace_without_torch(self):
start = np.zeros([2, 3, 4]) start = np.zeros([2, 3, 4])
stop = np.ones([2, 3, 4]) stop = np.ones([2, 3, 4])
self.assertAllClose( self.assertAllClose(
@ -1954,13 +1929,6 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
np.logspace(0, 10, 5, endpoint=False), np.logspace(0, 10, 5, endpoint=False),
) )
# TODO: torch.logspace does not support tensor or array
# for start/stop, create manual implementation
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.logspace` has no support for array start/stop.",
)
def test_logspace_without_torch(self):
start = np.zeros([2, 3, 4]) start = np.zeros([2, 3, 4])
stop = np.ones([2, 3, 4]) stop = np.ones([2, 3, 4])
self.assertAllClose( self.assertAllClose(
@ -2036,11 +2004,6 @@ class NumpyTwoInputOpsCorretnessTest(testing.TestCase):
self.assertAllClose(np.array(knp.outer(x, y)), np.outer(x, y)) self.assertAllClose(np.array(knp.outer(x, y)), np.outer(x, y))
self.assertAllClose(np.array(knp.Outer()(x, y)), np.outer(x, y)) self.assertAllClose(np.array(knp.Outer()(x, y)), np.outer(x, y))
# TODO: Fix numpy compatibility (squeeze by one dimension only)
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.take` and `np.take` have return shape divergence.",
)
def test_take(self): def test_take(self):
x = np.arange(24).reshape([1, 2, 3, 4]) x = np.arange(24).reshape([1, 2, 3, 4])
indices = np.array([0, 1]) indices = np.array([0, 1])
@ -2804,23 +2767,6 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
np.array(knp.pad(x, ((1, 1), (1, 1)))), np.array(knp.pad(x, ((1, 1), (1, 1)))),
np.pad(x, ((1, 1), (1, 1))), np.pad(x, ((1, 1), (1, 1))),
) )
self.assertAllClose(
np.array(knp.Pad(((1, 1), (1, 1)))(x)), np.pad(x, ((1, 1), (1, 1)))
)
self.assertAllClose(
np.array(knp.Pad(((1, 1), (1, 1)))(x)),
np.pad(x, ((1, 1), (1, 1))),
)
# TODO: implement padding with non-constant padding,
# bypass NotImplementedError for PyTorch
@pytest.mark.skipif(
backend.backend() == "torch",
reason="padding not implemented for non-constant use case",
)
def test_pad_without_torch(self):
x = np.array([[1, 2], [3, 4]])
self.assertAllClose( self.assertAllClose(
np.array(knp.pad(x, ((1, 1), (1, 1)), mode="reflect")), np.array(knp.pad(x, ((1, 1), (1, 1)), mode="reflect")),
np.pad(x, ((1, 1), (1, 1)), mode="reflect"), np.pad(x, ((1, 1), (1, 1)), mode="reflect"),
@ -2830,6 +2776,13 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
np.pad(x, ((1, 1), (1, 1)), mode="symmetric"), np.pad(x, ((1, 1), (1, 1)), mode="symmetric"),
) )
self.assertAllClose(
np.array(knp.Pad(((1, 1), (1, 1)))(x)), np.pad(x, ((1, 1), (1, 1)))
)
self.assertAllClose(
np.array(knp.Pad(((1, 1), (1, 1)))(x)),
np.pad(x, ((1, 1), (1, 1))),
)
self.assertAllClose( self.assertAllClose(
np.array(knp.Pad(((1, 1), (1, 1)), mode="reflect")(x)), np.array(knp.Pad(((1, 1), (1, 1)), mode="reflect")(x)),
np.pad(x, ((1, 1), (1, 1)), mode="reflect"), np.pad(x, ((1, 1), (1, 1)), mode="reflect"),
@ -2952,12 +2905,6 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
self.assertAllClose(np.array(knp.sort(x, axis=0)), np.sort(x, axis=0)) self.assertAllClose(np.array(knp.sort(x, axis=0)), np.sort(x, axis=0))
self.assertAllClose(np.array(knp.Sort(axis=0)(x)), np.sort(x, axis=0)) self.assertAllClose(np.array(knp.Sort(axis=0)(x)), np.sort(x, axis=0))
# TODO: implement split for `torch` with support for conversion
# of numpy.split args.
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.split` and `np.split` have return arg divergence.",
)
def test_split(self): def test_split(self):
x = np.array([[1, 2, 3], [3, 2, 1]]) x = np.array([[1, 2, 3], [3, 2, 1]])
self.assertAllClose(np.array(knp.split(x, 2)), np.split(x, 2)) self.assertAllClose(np.array(knp.split(x, 2)), np.split(x, 2))
@ -3025,13 +2972,7 @@ class NumpyOneInputOpsCorrectnessTest(testing.TestCase):
self.assertAllClose(np.array(knp.tile(x, [2, 3])), np.tile(x, [2, 3])) self.assertAllClose(np.array(knp.tile(x, [2, 3])), np.tile(x, [2, 3]))
self.assertAllClose(np.array(knp.Tile([2, 3])(x)), np.tile(x, [2, 3])) self.assertAllClose(np.array(knp.Tile([2, 3])(x)), np.tile(x, [2, 3]))
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.split` does not support args `offset`, `axis1`, `axis2`",
)
def test_trace(self): def test_trace(self):
# TODO: implement `torch.trace` support for arguments `offset`,
# `axis1`, `axis2` and delete NotImplementedError
x = np.arange(24).reshape([1, 2, 3, 4]) x = np.arange(24).reshape([1, 2, 3, 4])
self.assertAllClose(np.array(knp.trace(x)), np.trace(x)) self.assertAllClose(np.array(knp.trace(x)), np.trace(x))
self.assertAllClose( self.assertAllClose(
@ -3071,14 +3012,7 @@ class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase):
self.assertAllClose(np.array(knp.zeros([2, 3])), np.zeros([2, 3])) self.assertAllClose(np.array(knp.zeros([2, 3])), np.zeros([2, 3]))
self.assertAllClose(np.array(knp.Zeros()([2, 3])), np.zeros([2, 3])) self.assertAllClose(np.array(knp.Zeros()([2, 3])), np.zeros([2, 3]))
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.eye` does not support arg `k`.",
)
def test_eye(self): def test_eye(self):
# TODO: implement support for `k` diagonal arg,
# does not exist in torch.eye()
self.assertAllClose(np.array(knp.eye(3)), np.eye(3)) self.assertAllClose(np.array(knp.eye(3)), np.eye(3))
self.assertAllClose(np.array(knp.eye(3, 4)), np.eye(3, 4)) self.assertAllClose(np.array(knp.eye(3, 4)), np.eye(3, 4))
self.assertAllClose(np.array(knp.eye(3, 4, 1)), np.eye(3, 4, 1)) self.assertAllClose(np.array(knp.eye(3, 4, 1)), np.eye(3, 4, 1))
@ -3101,25 +3035,15 @@ class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase):
self.assertAllClose( self.assertAllClose(
np.array(knp.full([2, 3], 0.1)), np.full([2, 3], 0.1) np.array(knp.full([2, 3], 0.1)), np.full([2, 3], 0.1)
) )
self.assertAllClose(np.array(knp.Full()([2, 3], 0)), np.full([2, 3], 0))
self.assertAllClose(
np.array(knp.Full()([2, 3], 0.1)), np.full([2, 3], 0.1)
)
# TODO: implement conversion of shape into repetitions, pass to
# `torch.tile`, since `torch.full()` only accepts scalars
# for `fill_value`."
@pytest.mark.skipif(
backend.backend() == "torch",
reason="`torch.full` only accepts scalars for `fill_value`.",
)
def test_full_without_torch(self):
self.assertAllClose( self.assertAllClose(
np.array(knp.full([2, 3], np.array([1, 4, 5]))), np.array(knp.full([2, 3], np.array([1, 4, 5]))),
np.full([2, 3], np.array([1, 4, 5])), np.full([2, 3], np.array([1, 4, 5])),
) )
self.assertAllClose(np.array(knp.Full()([2, 3], 0)), np.full([2, 3], 0))
self.assertAllClose(
np.array(knp.Full()([2, 3], 0.1)), np.full([2, 3], 0.1)
)
self.assertAllClose( self.assertAllClose(
np.array(knp.Full()([2, 3], np.array([1, 4, 5]))), np.array(knp.Full()([2, 3], np.array([1, 4, 5]))),
np.full([2, 3], np.array([1, 4, 5])), np.full([2, 3], np.array([1, 4, 5])),
@ -3129,11 +3053,7 @@ class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase):
self.assertAllClose(np.array(knp.identity(3)), np.identity(3)) self.assertAllClose(np.array(knp.identity(3)), np.identity(3))
self.assertAllClose(np.array(knp.Identity()(3)), np.identity(3)) self.assertAllClose(np.array(knp.Identity()(3)), np.identity(3))
@pytest.mark.skipif(
backend.backend() == "torch", reason="No torch equivalent for `np.tri`"
)
def test_tri(self): def test_tri(self):
# TODO: create a manual implementation, as PyTorch has no equivalent
self.assertAllClose(np.array(knp.tri(3)), np.tri(3)) self.assertAllClose(np.array(knp.tri(3)), np.tri(3))
self.assertAllClose(np.array(knp.tri(3, 4)), np.tri(3, 4)) self.assertAllClose(np.array(knp.tri(3, 4)), np.tri(3, 4))
self.assertAllClose(np.array(knp.tri(3, 4, 1)), np.tri(3, 4, 1)) self.assertAllClose(np.array(knp.tri(3, 4, 1)), np.tri(3, 4, 1))

@ -0,0 +1,757 @@
from tensorflow import data as tf_data
from keras_core import layers
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers.layer import Layer
from keras_core.saving import saving_lib
from keras_core.saving import serialization_lib
from keras_core.utils.naming import auto_name
class Cross:
def __init__(self, feature_names, crossing_dim, output_mode="one_hot"):
if output_mode not in {"int", "one_hot"}:
raise ValueError(
"Invalid value for argument `output_mode`. "
"Expected one of {'int', 'one_hot'}. "
f"Received: output_mode={output_mode}"
)
self.feature_names = tuple(feature_names)
self.crossing_dim = crossing_dim
self.output_mode = output_mode
@property
def name(self):
return "_X_".join(self.feature_names)
def get_config(self):
return {
"feature_names": self.feature_names,
"crossing_dim": self.crossing_dim,
"output_mode": self.output_mode,
}
@classmethod
def from_config(cls, config):
return cls(**config)
class Feature:
def __init__(self, dtype, preprocessor, output_mode):
if output_mode not in {"int", "one_hot", "float"}:
raise ValueError(
"Invalid value for argument `output_mode`. "
"Expected one of {'int', 'one_hot', 'float'}. "
f"Received: output_mode={output_mode}"
)
self.dtype = dtype
if isinstance(preprocessor, dict):
preprocessor = serialization_lib.deserialize_keras_object(
preprocessor
)
self.preprocessor = preprocessor
self.output_mode = output_mode
def get_config(self):
return {
"dtype": self.dtype,
"preprocessor": serialization_lib.serialize_keras_object(
self.preprocessor
),
"output_mode": self.output_mode,
}
@classmethod
def from_config(cls, config):
return cls(**config)
@keras_core_export("keras_core.utils.FeatureSpace")
class FeatureSpace(Layer):
"""One-stop utility for preprocessing and encoding structured data.
Arguments:
features: Dict mapping the names of your features to their
type specification, e.g. `{"my_feature": "integer_categorical"}`
or `{"my_feature": FeatureSpace.integer_categorical()}`.
For a complete list of all supported types, see
"Available feature types" paragraph below.
output_mode: One of `"concat"` or `"dict"`. In concat mode, all
features get concatenated together into a single vector.
In dict mode, the FeatureSpace returns a dict of individually
encoded features (with the same keys as the input dict keys).
crosses: List of features to be crossed together, e.g.
`crosses=[("feature_1", "feature_2")]`. The features will be
"crossed" by hashing their combined value into
a fixed-length vector.
crossing_dim: Default vector size for hashing crossed features.
Defaults to `32`.
hashing_dim: Default vector size for hashing features of type
`"integer_hashed"` and `"string_hashed"`. Defaults to `32`.
num_discretization_bins: Default number of bins to be used for
discretizing features of type `"float_discretized"`.
Defaults to `32`.
**Available feature types:**
Note that all features can be referred to by their string name,
e.g. `"integer_categorical"`. When using the string name, the default
argument values are used.
```python
# Plain float values.
FeatureSpace.float(name=None)
# Float values to be preprocessed via featurewise standardization
# (i.e. via a `keras.layers.Normalization` layer).
FeatureSpace.float_normalized(name=None)
# Float values to be preprocessed via linear rescaling
# (i.e. via a `keras.layers.Rescaling` layer).
FeatureSpace.float_rescaled(scale=1., offset=0., name=None)
# Float values to be discretized. By default, the discrete
# representation will then be one-hot encoded.
FeatureSpace.float_discretized(
num_bins, bin_boundaries=None, output_mode="one_hot", name=None)
# Integer values to be indexed. By default, the discrete
# representation will then be one-hot encoded.
FeatureSpace.integer_categorical(
max_tokens=None, num_oov_indices=1, output_mode="one_hot", name=None)
# String values to be indexed. By default, the discrete
# representation will then be one-hot encoded.
FeatureSpace.string_categorical(
max_tokens=None, num_oov_indices=1, output_mode="one_hot", name=None)
# Integer values to be hashed into a fixed number of bins.
# By default, the discrete representation will then be one-hot encoded.
FeatureSpace.integer_hashed(num_bins, output_mode="one_hot", name=None)
# String values to be hashed into a fixed number of bins.
# By default, the discrete representation will then be one-hot encoded.
FeatureSpace.string_hashed(num_bins, output_mode="one_hot", name=None)
```
Examples:
**Basic usage with a dict of input data:**
```python
raw_data = {
"float_values": [0.0, 0.1, 0.2, 0.3],
"string_values": ["zero", "one", "two", "three"],
"int_values": [0, 1, 2, 3],
}
dataset = tf.data.Dataset.from_tensor_slices(raw_data)
feature_space = FeatureSpace(
features={
"float_values": "float_normalized",
"string_values": "string_categorical",
"int_values": "integer_categorical",
},
crosses=[("string_values", "int_values")],
output_mode="concat",
)
# Before you start using the FeatureSpace,
# you must `adapt()` it on some data.
feature_space.adapt(dataset)
# You can call the FeatureSpace on a dict of data (batched or unbatched).
output_vector = feature_space(raw_data)
```
**Basic usage with `tf.data`:**
```python
# Unlabeled data
preprocessed_ds = unlabeled_dataset.map(feature_space)
# Labeled data
preprocessed_ds = labeled_dataset.map(lambda x, y: (feature_space(x), y))
```
**Basic usage with the Keras Functional API:**
```python
# Retrieve a dict of Keras Input objects
inputs = feature_space.get_inputs()
# Retrieve the corresponding encoded Keras tensors
encoded_features = feature_space.get_encoded_features()
# Build a Functional model
outputs = keras.layers.Dense(1, activation="sigmoid")(encoded_features)
model = keras.Model(inputs, outputs)
```
**Customizing each feature or feature cross:**
```python
feature_space = FeatureSpace(
features={
"float_values": FeatureSpace.float_normalized(),
"string_values": FeatureSpace.string_categorical(max_tokens=10),
"int_values": FeatureSpace.integer_categorical(max_tokens=10),
},
crosses=[
FeatureSpace.cross(("string_values", "int_values"), crossing_dim=32)
],
output_mode="concat",
)
```
**Returning a dict of integer-encoded features:**
```python
feature_space = FeatureSpace(
features={
"string_values": FeatureSpace.string_categorical(output_mode="int"),
"int_values": FeatureSpace.integer_categorical(output_mode="int"),
},
crosses=[
FeatureSpace.cross(
feature_names=("string_values", "int_values"),
crossing_dim=32,
output_mode="int",
)
],
output_mode="dict",
)
```
**Specifying your own Keras preprocessing layer:**
```python
# Let's say that one of the features is a short text paragraph that
# we want to encode as a vector (one vector per paragraph) via TF-IDF.
data = {
"text": ["1st string", "2nd string", "3rd string"],
}
# There's a Keras layer for this: TextVectorization.
custom_layer = layers.TextVectorization(output_mode="tf_idf")
# We can use FeatureSpace.feature to create a custom feature
# that will use our preprocessing layer.
feature_space = FeatureSpace(
features={
"text": FeatureSpace.feature(
preprocessor=custom_layer, dtype="string", output_mode="float"
),
},
output_mode="concat",
)
feature_space.adapt(tf.data.Dataset.from_tensor_slices(data))
output_vector = feature_space(data)
```
**Retrieving the underlying Keras preprocessing layers:**
```python
# The preprocessing layer of each feature is available in `.preprocessors`.
preprocessing_layer = feature_space.preprocessors["feature1"]
# The crossing layer of each feature cross is available in `.crossers`.
# It's an instance of keras.layers.HashedCrossing.
crossing_layer = feature_space.crossers["feature1_X_feature2"]
```
**Saving and reloading a FeatureSpace:**
```python
feature_space.save("myfeaturespace.keras")
reloaded_feature_space = keras.models.load_model("myfeaturespace.keras")
```
"""
@classmethod
def cross(cls, feature_names, crossing_dim, output_mode="one_hot"):
return Cross(feature_names, crossing_dim, output_mode=output_mode)
@classmethod
def feature(cls, dtype, preprocessor, output_mode):
return Feature(dtype, preprocessor, output_mode)
@classmethod
def float(cls, name=None):
from keras.layers.core import identity
name = name or auto_name("float")
preprocessor = identity.Identity(
dtype="float32", name=f"{name}_preprocessor"
)
return Feature(
dtype="float32", preprocessor=preprocessor, output_mode="float"
)
@classmethod
def float_rescaled(cls, scale=1.0, offset=0.0, name=None):
name = name or auto_name("float_rescaled")
preprocessor = layers.Rescaling(
scale=scale, offset=offset, name=f"{name}_preprocessor"
)
return Feature(
dtype="float32", preprocessor=preprocessor, output_mode="float"
)
@classmethod
def float_normalized(cls, name=None):
name = name or auto_name("float_normalized")
preprocessor = layers.Normalization(
axis=-1, name=f"{name}_preprocessor"
)
return Feature(
dtype="float32", preprocessor=preprocessor, output_mode="float"
)
@classmethod
def float_discretized(
cls, num_bins, bin_boundaries=None, output_mode="one_hot", name=None
):
name = name or auto_name("float_discretized")
preprocessor = layers.Discretization(
num_bins=num_bins,
bin_boundaries=bin_boundaries,
name=f"{name}_preprocessor",
)
return Feature(
dtype="float32", preprocessor=preprocessor, output_mode=output_mode
)
@classmethod
def integer_categorical(
cls,
max_tokens=None,
num_oov_indices=1,
output_mode="one_hot",
name=None,
):
name = name or auto_name("integer_categorical")
preprocessor = layers.IntegerLookup(
name=f"{name}_preprocessor",
max_tokens=max_tokens,
num_oov_indices=num_oov_indices,
)
return Feature(
dtype="int64", preprocessor=preprocessor, output_mode=output_mode
)
@classmethod
def string_categorical(
cls,
max_tokens=None,
num_oov_indices=1,
output_mode="one_hot",
name=None,
):
name = name or auto_name("string_categorical")
preprocessor = layers.StringLookup(
name=f"{name}_preprocessor",
max_tokens=max_tokens,
num_oov_indices=num_oov_indices,
)
return Feature(
dtype="string", preprocessor=preprocessor, output_mode=output_mode
)
@classmethod
def string_hashed(cls, num_bins, output_mode="one_hot", name=None):
name = name or auto_name("string_hashed")
preprocessor = layers.Hashing(
name=f"{name}_preprocessor", num_bins=num_bins
)
return Feature(
dtype="string", preprocessor=preprocessor, output_mode=output_mode
)
@classmethod
def integer_hashed(cls, num_bins, output_mode="one_hot", name=None):
name = name or auto_name("integer_hashed")
preprocessor = layers.Hashing(
name=f"{name}_preprocessor", num_bins=num_bins
)
return Feature(
dtype="int64", preprocessor=preprocessor, output_mode=output_mode
)
def __init__(
self,
features,
output_mode="concat",
crosses=None,
crossing_dim=32,
hashing_dim=32,
num_discretization_bins=32,
name=None,
):
super().__init__(name=name)
if not features:
raise ValueError("The `features` argument cannot be None or empty.")
self.crossing_dim = crossing_dim
self.hashing_dim = hashing_dim
self.num_discretization_bins = num_discretization_bins
self.features = {
name: self._standardize_feature(name, value)
for name, value in features.items()
}
self.crosses = []
if crosses:
feature_set = set(features.keys())
for cross in crosses:
if isinstance(cross, dict):
cross = serialization_lib.deserialize_keras_object(cross)
if isinstance(cross, Cross):
self.crosses.append(cross)
else:
if not crossing_dim:
raise ValueError(
"When specifying `crosses`, the argument "
"`crossing_dim` "
"(dimensionality of the crossing space) "
"should be specified as well."
)
for key in cross:
if key not in feature_set:
raise ValueError(
"All features referenced "
"in the `crosses` argument "
"should be present in the `features` dict. "
f"Received unknown features: {cross}"
)
self.crosses.append(Cross(cross, crossing_dim=crossing_dim))
self.crosses_by_name = {cross.name: cross for cross in self.crosses}
if output_mode not in {"dict", "concat"}:
raise ValueError(
"Invalid value for argument `output_mode`. "
"Expected one of {'dict', 'concat'}. "
f"Received: output_mode={output_mode}"
)
self.output_mode = output_mode
self.inputs = {
name: self._feature_to_input(name, value)
for name, value in self.features.items()
}
self.preprocessors = {
name: value.preprocessor for name, value in self.features.items()
}
self.encoded_features = None
self.crossers = {
cross.name: self._cross_to_crosser(cross) for cross in self.crosses
}
self.one_hot_encoders = {}
self.built = False
self._is_adapted = False
self.concat = None
self._preprocessed_features_names = None
self._crossed_features_names = None
def _feature_to_input(self, name, feature):
return layers.Input(shape=(1,), dtype=feature.dtype, name=name)
def _standardize_feature(self, name, feature):
if isinstance(feature, Feature):
return feature
if isinstance(feature, dict):
return serialization_lib.deserialize_keras_object(feature)
if feature == "float":
return self.float(name=name)
elif feature == "float_normalized":
return self.float_normalized(name=name)
elif feature == "float_rescaled":
return self.float_rescaled(name=name)
elif feature == "float_discretized":
return self.float_discretized(
name=name, num_bins=self.num_discretization_bins
)
elif feature == "integer_categorical":
return self.integer_categorical(name=name)
elif feature == "string_categorical":
return self.string_categorical(name=name)
elif feature == "integer_hashed":
return self.integer_hashed(self.hashing_dim, name=name)
elif feature == "string_hashed":
return self.string_hashed(self.hashing_dim, name=name)
else:
raise ValueError(f"Invalid feature type: {feature}")
def _cross_to_crosser(self, cross):
return layers.HashedCrossing(cross.crossing_dim, name=cross.name)
def _list_adaptable_preprocessors(self):
adaptable_preprocessors = []
for name in self.features.keys():
preprocessor = self.preprocessors[name]
# Special case: a Normalization layer with preset mean/variance.
# Not adaptable.
if isinstance(preprocessor, layers.Normalization):
if preprocessor.input_mean is not None:
continue
if hasattr(preprocessor, "adapt"):
adaptable_preprocessors.append(name)
return adaptable_preprocessors
def adapt(self, dataset):
if not isinstance(dataset, tf_data.Dataset):
raise ValueError(
"`adapt()` can only be called on a tf.data.Dataset. "
f"Received instead: {dataset} (of type {type(dataset)})"
)
for name in self._list_adaptable_preprocessors():
# Call adapt() on each individual adaptable layer.
# TODO: consider rewriting this to instead iterate on the
# dataset once, split each batch into individual features,
# and call the layer's `_adapt_function` on each batch
# to simulate the behavior of adapt() in a more performant fashion.
feature_dataset = dataset.map(lambda x: x[name])
preprocessor = self.preprocessors[name]
# TODO: consider adding an adapt progress bar.
# Sample 1 element to check the rank
for x in feature_dataset.take(1):
pass
if x.shape.rank == 0:
# The dataset yields unbatched scalars; batch it.
feature_dataset = feature_dataset.batch(32)
if x.shape.rank in {0, 1}:
# If the rank is 1, add a dimension
# so we can reduce on axis=-1.
# Note: if rank was previously 0, it is now 1.
feature_dataset = feature_dataset.map(
lambda x: ops.expand_dims(x, -1)
)
preprocessor.adapt(feature_dataset)
self._is_adapted = True
self.get_encoded_features() # Finish building the layer
self.built = True
def get_inputs(self):
self._check_if_built()
return self.inputs
def get_encoded_features(self):
self._check_if_adapted()
if self.encoded_features is None:
preprocessed_features = self._preprocess_features(self.inputs)
crossed_features = self._cross_features(preprocessed_features)
merged_features = self._merge_features(
preprocessed_features, crossed_features
)
self.encoded_features = merged_features
return self.encoded_features
def _preprocess_features(self, features):
return {
name: self.preprocessors[name](features[name])
for name in features.keys()
}
def _cross_features(self, features):
all_outputs = {}
for cross in self.crosses:
inputs = [features[name] for name in cross.feature_names]
outputs = self.crossers[cross.name](inputs)
all_outputs[cross.name] = outputs
return all_outputs
def _merge_features(self, preprocessed_features, crossed_features):
if not self._preprocessed_features_names:
self._preprocessed_features_names = sorted(
preprocessed_features.keys()
)
self._crossed_features_names = sorted(crossed_features.keys())
all_names = (
self._preprocessed_features_names + self._crossed_features_names
)
all_features = [
preprocessed_features[name]
for name in self._preprocessed_features_names
] + [crossed_features[name] for name in self._crossed_features_names]
if self.output_mode == "dict":
output_dict = {}
else:
features_to_concat = []
if self.built:
# Fast mode.
for name, feature in zip(all_names, all_features):
encoder = self.one_hot_encoders.get(name, None)
if encoder:
feature = encoder(feature)
if self.output_mode == "dict":
output_dict[name] = feature
else:
features_to_concat.append(feature)
if self.output_mode == "dict":
return output_dict
else:
return self.concat(features_to_concat)
# If the object isn't built,
# we create the encoder and concat layers below
all_specs = [
self.features[name] for name in self._preprocessed_features_names
] + [
self.crosses_by_name[name] for name in self._crossed_features_names
]
for name, feature, spec in zip(all_names, all_features, all_specs):
dtype = feature.dtype
if spec.output_mode == "one_hot":
preprocessor = self.preprocessors.get(
name
) or self.crossers.get(name)
cardinality = None
if not feature.dtype.startswith("int"):
raise ValueError(
f"Feature '{name}' has `output_mode='one_hot'`. "
"Thus its preprocessor should return an int64 dtype. "
f"Instead it returns a {dtype} dtype."
)
if isinstance(
preprocessor, (layers.IntegerLookup, layers.StringLookup)
):
cardinality = preprocessor.vocabulary_size()
elif isinstance(preprocessor, layers.CategoryEncoding):
cardinality = preprocessor.num_tokens
elif isinstance(preprocessor, layers.Discretization):
cardinality = preprocessor.num_bins
elif isinstance(
preprocessor, (layers.HashedCrossing, layers.Hashing)
):
cardinality = preprocessor.num_bins
else:
raise ValueError(
f"Feature '{name}' has `output_mode='one_hot'`. "
"However it isn't a standard feature and the "
"dimensionality of its output space is not known, "
"thus it cannot be one-hot encoded. "
"Try using `output_mode='int'`."
)
if cardinality is not None:
encoder = layers.CategoryEncoding(
num_tokens=cardinality, output_mode="multi_hot"
)
self.one_hot_encoders[name] = encoder
feature = encoder(feature)
if self.output_mode == "concat":
dtype = feature.dtype
if dtype.startswith("int") or dtype == "string":
raise ValueError(
f"Cannot concatenate features because feature '{name}' "
f"has not been encoded (it has dtype {dtype}). "
"Consider using `output_mode='dict'`."
)
features_to_concat.append(feature)
else:
output_dict[name] = feature
if self.output_mode == "concat":
self.concat = layers.Concatenate(axis=-1)
return self.concat(features_to_concat)
else:
return output_dict
def _check_if_adapted(self):
if not self._is_adapted:
if not self._list_adaptable_preprocessors():
self._is_adapted = True
else:
raise ValueError(
"You need to call `.adapt(dataset)` on the FeatureSpace "
"before you can start using it."
)
def _check_if_built(self):
if not self.built:
self._check_if_adapted()
# Finishes building
self.get_encoded_features()
self.built = True
def __call__(self, data):
self._check_if_built()
if not isinstance(data, dict):
raise ValueError(
"A FeatureSpace can only be called with a dict. "
f"Received: data={data} (of type {type(data)}"
)
data = {
key: ops.convert_to_tensor(value) for key, value in data.items()
}
rebatched = False
for name, x in data.items():
if x.shape.rank == 0:
data[name] = ops.reshape(x, (1, 1))
rebatched = True
elif x.shape.rank == 1:
data[name] = ops.expand_dims(x, -1)
preprocessed_data = self._preprocess_features(data)
crossed_data = self._cross_features(preprocessed_data)
merged_data = self._merge_features(preprocessed_data, crossed_data)
if rebatched:
if self.output_mode == "concat":
assert merged_data.shape[0] == 1
return ops.squeeze(merged_data, axis=0)
else:
for name, x in merged_data.items():
if x.shape.rank == 2 and x.shape[0] == 1:
merged_data[name] = ops.squeeze(x, axis=0)
return merged_data
def get_config(self):
return {
"features": serialization_lib.serialize_keras_object(self.features),
"output_mode": self.output_mode,
"crosses": serialization_lib.serialize_keras_object(self.crosses),
"crossing_dim": self.crossing_dim,
"hashing_dim": self.hashing_dim,
"num_discretization_bins": self.num_discretization_bins,
}
@classmethod
def from_config(cls, config):
return cls(**config)
def get_build_config(self):
return {
name: feature.preprocessor.get_build_config()
for name, feature in self.features.items()
}
def build_from_config(self, config):
for name in config.keys():
self.features[name].preprocessor.build_from_config(config[name])
self._is_adapted = True
def save(self, filepath):
"""Save the `FeatureSpace` instance to a `.keras` file.
You can reload it via `keras.models.load_model()`:
```python
feature_space.save("myfeaturespace.keras")
reloaded_feature_space = keras.models.load_model("myfeaturespace.keras")
```
"""
saving_lib.save_model(self, filepath)
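# The FeatureSpace tracks no variables of its own; all state lives in the
# preprocessing layers, so these hooks are intentionally no-ops.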
def save_own_variables(self, store):
return
def load_own_variables(self, store):
return
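# Minimal usage sketch (illustrative only: the feature names and values below
# are hypothetical, and TensorFlow is assumed to be importable as `tf`):
#
#     fs = FeatureSpace(
#         features={
#             "age": "integer_categorical",
#             "income": "float_normalized",
#         },
#         output_mode="concat",
#     )
#     fs.adapt(tf.data.Dataset.from_tensor_slices(
#         {"age": [20, 30, 40], "income": [1.0, 2.0, 3.0]}
#     ))
#     encoded = fs({"age": 25, "income": 1.5})  # single concatenated vector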

@ -0,0 +1,378 @@
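# FeatureSpace test suite, left fully commented out in this checkpoint.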
# from keras_core import testing
# from keras_core.utils import feature_space
# from keras_core import operations as ops
# from tensorflow import nest
# from tensorflow import data as tf_data
# from keras_core import layers
# from keras_core import models
# import os
# class FeatureSpaceTest(testing.TestCase):
# def _get_train_data_dict(
# self, as_dataset=False, as_tf_tensors=False, as_labeled_dataset=False
# ):
# data = {
# "float_1": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
# "float_2": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
# "float_3": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
# "string_1": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
# "string_2": ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"],
# "int_1": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
# "int_2": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
# "int_3": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
# }
# if as_dataset:
# return tf_data.Dataset.from_tensor_slices(data)
# elif as_tf_tensors:
# return nest.map_structure(ops.convert_to_tensor, data)
# elif as_labeled_dataset:
# labels = [0, 1, 0, 1, 0, 0, 1, 0, 1, 1]
# return tf_data.Dataset.from_tensor_slices((data, labels))
# return data
# def test_basic_usage(self):
# fs = feature_space.FeatureSpace(
# features={
# "float_1": "float",
# "float_2": "float_normalized",
# "float_3": "float_discretized",
# "string_1": "string_categorical",
# "string_2": "string_hashed",
# "int_1": "integer_categorical",
# "int_2": "integer_hashed",
# "int_3": "integer_categorical",
# },
# crosses=[("float_3", "string_1"), ("string_2", "int_2")],
# output_mode="concat",
# )
# # Test unbatched adapt
# fs.adapt(self._get_train_data_dict(as_dataset=True))
# # Test batched adapt
# fs.adapt(self._get_train_data_dict(as_dataset=True).batch(4))
# # Test unbatched call on raw data
# data = {
# key: value[0] for key, value in self._get_train_data_dict().items()
# }
# out = fs(data)
# self.assertEqual(out.shape, [195])
# # Test unbatched call on TF tensors
# data = self._get_train_data_dict(as_tf_tensors=True)
# data = {key: value[0] for key, value in data.items()}
# out = fs(data)
# self.assertEqual(out.shape, [195])
# # Test batched call on raw data
# out = fs(self._get_train_data_dict())
# self.assertEqual(out.shape, [10, 195])
# # Test batched call on TF tensors
# out = fs(self._get_train_data_dict(as_tf_tensors=True))
# self.assertEqual(out.shape, [10, 195])
# def test_output_mode_dict(self):
# fs = feature_space.FeatureSpace(
# features={
# "float_1": "float",
# "float_2": "float_normalized",
# "float_3": "float_discretized",
# "string_1": "string_categorical",
# "string_2": "string_hashed",
# "int_1": "integer_categorical",
# "int_2": "integer_hashed",
# "int_3": "integer_categorical",
# },
# crosses=[("float_3", "string_1"), ("string_2", "int_2")],
# output_mode="dict",
# )
# fs.adapt(self._get_train_data_dict(as_dataset=True))
# # Test unbatched call on raw data
# data = {
# key: value[0] for key, value in self._get_train_data_dict().items()
# }
# out = fs(data)
# self.assertIsInstance(out, dict)
# self.assertLen(out, 10)
# self.assertEqual(out["string_1"].shape, [11])
# self.assertEqual(out["int_2"].shape, [32])
# self.assertEqual(out["string_2_X_int_2"].shape, [32])
# # Test batched call on raw data
# out = fs(self._get_train_data_dict())
# self.assertIsInstance(out, dict)
# self.assertLen(out, 10)
# self.assertEqual(out["string_1"].shape, [10, 11])
# self.assertEqual(out["int_2"].shape, [10, 32])
# self.assertEqual(out["string_2_X_int_2"].shape, [10, 32])
# # Test batched call on TF tensors
# out = fs(self._get_train_data_dict(as_tf_tensors=True))
# self.assertIsInstance(out, dict)
# self.assertLen(out, 10)
# self.assertEqual(out["string_1"].shape, [10, 11])
# self.assertEqual(out["int_2"].shape, [10, 32])
# self.assertEqual(out["string_2_X_int_2"].shape, [10, 32])
# def test_output_mode_dict_of_ints(self):
# cls = feature_space.FeatureSpace
# fs = feature_space.FeatureSpace(
# features={
# "float_1": "float",
# "float_2": "float_normalized",
# "float_3": "float_discretized",
# "string_1": cls.string_categorical(output_mode="int"),
# "string_2": cls.string_hashed(num_bins=32, output_mode="int"),
# "int_1": cls.integer_categorical(output_mode="int"),
# "int_2": cls.integer_hashed(num_bins=32, output_mode="int"),
# "int_3": cls.integer_categorical(output_mode="int"),
# },
# crosses=[
# cls.cross(
# ("float_3", "string_1"), output_mode="int", crossing_dim=32
# ),
# cls.cross(
# ("string_2", "int_2"), output_mode="int", crossing_dim=32
# ),
# ],
# output_mode="dict",
# )
# fs.adapt(self._get_train_data_dict(as_dataset=True))
# data = {
# key: value[0] for key, value in self._get_train_data_dict().items()
# }
# out = fs(data)
# self.assertIsInstance(out, dict)
# self.assertLen(out, 10)
# self.assertEqual(out["string_1"].shape, [1])
# self.assertEqual(out["string_1"].dtype.name, "int64")
# self.assertEqual(out["int_2"].shape, [1])
# self.assertEqual(out["int_2"].dtype.name, "int64")
# self.assertEqual(out["string_2_X_int_2"].shape, [1])
# self.assertEqual(out["string_2_X_int_2"].dtype.name, "int64")
# def test_functional_api_sync_processing(self):
# fs = feature_space.FeatureSpace(
# features={
# "float_1": "float",
# "float_2": "float_normalized",
# "float_3": "float_discretized",
# "string_1": "string_categorical",
# "string_2": "string_hashed",
# "int_1": "integer_categorical",
# "int_2": "integer_hashed",
# "int_3": "integer_categorical",
# },
# crosses=[("float_3", "string_1"), ("string_2", "int_2")],
# output_mode="concat",
# )
# fs.adapt(self._get_train_data_dict(as_dataset=True))
# inputs = fs.get_inputs()
# features = fs.get_encoded_features()
# outputs = layers.Dense(1)(features)
# model = models.Model(inputs=inputs, outputs=outputs)
# model.compile("adam", "mse")
# ds = self._get_train_data_dict(as_labeled_dataset=True)
# model.fit(ds.batch(4))
# model.evaluate(ds.batch(4))
# ds = self._get_train_data_dict(as_dataset=True)
# model.predict(ds.batch(4))
# def test_tf_data_async_processing(self):
# fs = feature_space.FeatureSpace(
# features={
# "float_1": "float",
# "float_2": "float_normalized",
# "float_3": "float_discretized",
# "string_1": "string_categorical",
# "string_2": "string_hashed",
# "int_1": "integer_categorical",
# "int_2": "integer_hashed",
# "int_3": "integer_categorical",
# },
# crosses=[("float_3", "string_1"), ("string_2", "int_2")],
# output_mode="concat",
# )
# fs.adapt(self._get_train_data_dict(as_dataset=True))
# features = fs.get_encoded_features()
# outputs = layers.Dense(1)(features)
# model = models.Model(inputs=features, outputs=outputs)
# model.compile("adam", "mse")
# ds = self._get_train_data_dict(as_labeled_dataset=True)
# # Try map before batch
# ds = ds.map(lambda x, y: (fs(x), y))
# model.fit(ds.batch(4))
# # Try map after batch
# ds = self._get_train_data_dict(as_labeled_dataset=True)
# ds = ds.batch(4)
# ds = ds.map(lambda x, y: (fs(x), y))
# model.evaluate(ds)
# ds = self._get_train_data_dict(as_dataset=True)
# ds = ds.map(fs)
# model.predict(ds.batch(4))
# def test_advanced_usage(self):
# cls = feature_space.FeatureSpace
# fs = feature_space.FeatureSpace(
# features={
# "float_1": cls.float(),
# "float_2": cls.float_normalized(),
# "float_3": cls.float_discretized(num_bins=3),
# "string_1": cls.string_categorical(max_tokens=5),
# "string_2": cls.string_hashed(num_bins=32),
# "int_1": cls.integer_categorical(
# max_tokens=5, num_oov_indices=2
# ),
# "int_2": cls.integer_hashed(num_bins=32),
# "int_3": cls.integer_categorical(max_tokens=5),
# },
# crosses=[
# cls.cross(("float_3", "string_1"), crossing_dim=32),
# cls.cross(("string_2", "int_2"), crossing_dim=32),
# ],
# output_mode="concat",
# )
# fs.adapt(self._get_train_data_dict(as_dataset=True))
# data = {
# key: value[0] for key, value in self._get_train_data_dict().items()
# }
# out = fs(data)
# self.assertEqual(out.shape, [148])
# def test_manual_kpl(self):
# data = {
# "text": ["1st string", "2nd string", "3rd string"],
# }
# cls = feature_space.FeatureSpace
# # Test with a tf-idf TextVectorization layer
# tv = layers.TextVectorization(output_mode="tf_idf")
# fs = feature_space.FeatureSpace(
# features={
# "text": cls.feature(
# preprocessor=tv, dtype="string", output_mode="float"
# ),
# },
# output_mode="concat",
# )
# fs.adapt(tf_data.Dataset.from_tensor_slices(data))
# out = fs(data)
# self.assertEqual(out.shape, [3, 5])
# def test_no_adapt(self):
# data = {
# "int_1": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
# }
# fs = feature_space.FeatureSpace(
# {
# "int_1": "integer_hashed",
# },
# output_mode="concat",
# )
# out = fs(data)
# self.assertEqual(out.shape, [10, 32])
# def test_saving(self):
# cls = feature_space.FeatureSpace
# fs = feature_space.FeatureSpace(
# features={
# "float_1": cls.float(),
# "float_2": cls.float_normalized(),
# "float_3": cls.float_discretized(num_bins=3),
# "string_1": cls.string_categorical(max_tokens=5),
# "string_2": cls.string_hashed(num_bins=32),
# "int_1": cls.integer_categorical(
# max_tokens=5, num_oov_indices=2
# ),
# "int_2": cls.integer_hashed(num_bins=32),
# "int_3": cls.integer_categorical(max_tokens=5),
# },
# crosses=[
# cls.cross(("float_3", "string_1"), crossing_dim=32),
# cls.cross(("string_2", "int_2"), crossing_dim=32),
# ],
# output_mode="concat",
# )
# fs.adapt(self._get_train_data_dict(as_dataset=True))
# data = {
# key: value[0] for key, value in self._get_train_data_dict().items()
# }
# ref_out = fs(data)
# temp_filepath = os.path.join(self.get_temp_dir(), "fs.keras")
# fs.save(temp_filepath)
# fs = models.load_model(temp_filepath)
# # Save again immediately after loading to test idempotency
# temp_filepath = os.path.join(self.get_temp_dir(), "fs2.keras")
# fs.save(temp_filepath)
# # Test correctness of the first saved FS
# out = fs(data)
# self.assertAllClose(out, ref_out)
# inputs = fs.get_inputs()
# outputs = fs.get_encoded_features()
# model = models.Model(inputs=inputs, outputs=outputs)
# ds = self._get_train_data_dict(as_dataset=True)
# out = model.predict(ds.batch(4))
# self.assertAllClose(out[0], ref_out)
# # Test correctness of the re-saved FS
# fs = models.load_model(temp_filepath)
# out = fs(data)
# self.assertAllClose(out, ref_out)
# def test_errors(self):
# # Test no features
# with self.assertRaisesRegex(ValueError, "cannot be None or empty"):
# feature_space.FeatureSpace(features={})
# # Test no crossing dim
# with self.assertRaisesRegex(ValueError, "`crossing_dim`"):
# feature_space.FeatureSpace(
# features={
# "f1": "integer_categorical",
# "f2": "integer_categorical",
# },
# crosses=[("f1", "f2")],
# crossing_dim=None,
# )
# # Test wrong cross feature name
# with self.assertRaisesRegex(ValueError, "should be present in "):
# feature_space.FeatureSpace(
# features={
# "f1": "integer_categorical",
# "f2": "integer_categorical",
# },
# crosses=[("f1", "unknown")],
# crossing_dim=32,
# )
# # Test wrong output mode
# with self.assertRaisesRegex(ValueError, "for argument `output_mode`"):
# feature_space.FeatureSpace(
# features={
# "f1": "integer_categorical",
# "f2": "integer_categorical",
# },
# output_mode="unknown",
# )
# # Test call before adapt
# with self.assertRaisesRegex(ValueError, "You need to call `.adapt"):
# fs = feature_space.FeatureSpace(
# features={
# "f1": "integer_categorical",
# "f2": "integer_categorical",
# }
# )
# fs({"f1": [0], "f2": [0]})
# # Test get_encoded_features before adapt
# with self.assertRaisesRegex(ValueError, "You need to call `.adapt"):
# fs = feature_space.FeatureSpace(
# features={
# "f1": "integer_categorical",
# "f2": "integer_categorical",
# }
# )
# fs.get_encoded_features()