Add IndexLookup layer (private class)

Francois Chollet 2023-08-15 16:39:57 -07:00
parent 4513bf7ba0
commit 75e6c6440c
6 changed files with 1445 additions and 8 deletions

@@ -80,6 +80,7 @@ from keras_core.layers.preprocessing.center_crop import CenterCrop
from keras_core.layers.preprocessing.discretization import Discretization
from keras_core.layers.preprocessing.hashed_crossing import HashedCrossing
from keras_core.layers.preprocessing.hashing import Hashing
from keras_core.layers.preprocessing.index_lookup import IndexLookup
from keras_core.layers.preprocessing.integer_lookup import IntegerLookup
from keras_core.layers.preprocessing.normalization import Normalization
from keras_core.layers.preprocessing.random_brightness import RandomBrightness

@@ -603,6 +603,11 @@ class Layer(BackendLayer, Operation):
"""The dtype of the state (weights) of the layer."""
return self.dtype_policy.variable_dtype
@property
def input_dtype(self):
"""The dtype layer inputs should be converted to."""
return self.dtype_policy.compute_dtype
@property
def supports_masking(self):
"""Whether this layer supports computing a mask using `compute_mask`."""
@@ -627,20 +632,20 @@ class Layer(BackendLayer, Operation):
if (
self.autocast
and backend.is_float_dtype(x.dtype)
and x.dtype != self.compute_dtype
and x.dtype != self.input_dtype
):
x = backend.cast(x, dtype=self.compute_dtype)
x = backend.cast(x, dtype=self.input_dtype)
return x
elif isinstance(x, backend.KerasTensor):
if (
self.autocast
and backend.is_float_dtype(x.dtype)
and x.dtype != self.compute_dtype
and x.dtype != self.input_dtype
):
x.dtype = self.compute_dtype
x.dtype = self.input_dtype
return x
elif hasattr(x, "__array__"):
return backend.convert_to_tensor(x, dtype=self.compute_dtype)
return backend.convert_to_tensor(x, dtype=self.input_dtype)
return x
# Used to avoid expensive `tree` operations in the most common case.
@@ -648,7 +653,7 @@ class Layer(BackendLayer, Operation):
kwargs
or len(args) != 1
or not backend.is_tensor(args[0])
or backend.standardize_dtype(args[0].dtype) != self.compute_dtype
or backend.standardize_dtype(args[0].dtype) != self.input_dtype
) and self._convert_input_args:
args = tree.map_structure(maybe_convert, args)
kwargs = tree.map_structure(maybe_convert, kwargs)
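
These hunks route input conversion through the new `input_dtype` property (which defaults to `dtype_policy.compute_dtype`), giving subclasses a single override point. A minimal sketch, using a hypothetical `TokenIds` layer, of what overriding it does:

from keras_core.layers import Layer

class TokenIds(Layer):
    # Hypothetical layer whose inputs should arrive as int64 token ids
    # rather than being cast to the float compute dtype.
    @property
    def input_dtype(self):
        return "int64"

# With this override, `maybe_convert` in `Layer.__call__` converts
# array-like inputs via convert_to_tensor(..., dtype="int64") instead
# of casting them to floatx.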
@@ -1408,7 +1413,11 @@ def update_shapes_dict_for_target_fn(
# Single arg: don't check names, pass first shape.
if len(expected_names) == 1:
key = expected_names[0]
input_shape = tuple(shapes_dict.values())[0]
values = tuple(shapes_dict.values())
if values:
input_shape = values[0]
else:
input_shape = None
return {key: input_shape}
# Multiple args: check that all names line up.

@@ -0,0 +1,992 @@
import collections
import numpy as np
from keras_core import backend
from keras_core.layers.layer import Layer
from keras_core.utils import argument_validation
from keras_core.utils import tf_utils
from keras_core.utils.module_utils import tensorflow as tf
class IndexLookup(Layer):
"""Maps values from a vocabulary to integer indices.
This layer translates a set of arbitrary hashables into an integer output
via a table-based lookup, with optional out-of-vocabulary handling. This is
the basis layer for both IntegerLookup and StringLookup; it holds the common
logic but is not intended to be exported as part of the Keras API.
Args:
max_tokens: The maximum size of the vocabulary for this layer.
If `None`, there is no cap on the size of the vocabulary.
Note that this size includes the OOV and mask tokens.
num_oov_indices: The number of out-of-vocabulary tokens to use.
If this value is more than 1, OOV inputs are hashed to determine
their OOV value. If this value is 0,
OOV inputs will cause an error when calling the layer.
mask_token: A token that represents masked inputs.
When `output_mode` is `"int"`,
the token is included in vocabulary and mapped to index 0.
In other output modes, the token will not appear in the vocabulary
and instances of the mask token in the input will be dropped.
If set to `None`, no mask term will be added.
oov_token: Only used when `invert` is `True`.
The token to return for OOV indices.
vocabulary: Optional. Either an array or a string path to a text file.
If passing an array, can pass a tuple, list, 1D numpy array,
or 1D tensor containing the vocabulary terms.
If passing a file path, the file should contain one line per term
in the vocabulary. If this argument is set,
there is no need to `adapt` the layer.
vocabulary_dtype: The dtype of the vocabulary terms.
For example, `"int64"` or `"string"`.
idf_weights: Only valid when `output_mode` is `"tf_idf"`.
A tuple, list, 1D numpy array, or 1D tensor of the same length
as the vocabulary, containing the floating point
inverse document frequency weights, which will be multiplied
by per-sample term counts for the final TF-IDF
weight. If the `vocabulary` argument is set, and `output_mode`
is `"tf_idf"`, this argument must be supplied.
invert: Only valid when `output_mode` is `"int"`.
If `True`, this layer will map indices to vocabulary items
instead of mapping vocabulary items to indices.
Defaults to `False`.
output_mode: Specification for the output of the layer. Values can be
`"int"`, `"one_hot"`, `"multi_hot"`, `"count"`, or `"tf_idf"`
configuring the layer as follows:
- `"int"`: Return the raw integer indices of the input tokens.
- `"one_hot"`: Encodes each individual element in the input into an
array the same size as the vocabulary, containing a 1
at the element index. If the last dimension is size 1,
will encode on that dimension.
If the last dimension is not size 1,
will append a new dimension for the encoded output.
- `"multi_hot"`: Encodes each sample in the input into
a single array the same size as the vocabulary,
containing a 1 for each vocabulary term present in the sample.
Treats the last dimension as the sample dimension,
if input shape is `(..., sample_length)`, output shape will
be `(..., num_tokens)`.
- `"count"`: As `"multi_hot"`, but the int array contains a count
of the number of times the token at that index appeared
in the sample.
- `"tf_idf"`: As `"multi_hot"`, but the TF-IDF algorithm
is applied to find the value in each token slot.
Defaults to `"int"`.
pad_to_max_tokens: Only valid when `output_mode` is `"multi_hot"`,
`"count"`, or `"tf_idf"`. If `True`, the output will have its
feature axis padded to `max_tokens` even if the number
of unique tokens in the vocabulary is less than `max_tokens`,
resulting in a tensor of shape `(batch_size, max_tokens)`
regardless of vocabulary size. Defaults to `False`.
sparse: Boolean. Only applicable to `"one_hot"`, `"multi_hot"`,
`"count"` and `"tf-idf"` output modes.
If `True`, returns a `SparseTensor` instead of a dense `Tensor`.
Defaults to `False`.
"""
def __init__(
self,
max_tokens,
num_oov_indices,
mask_token,
oov_token,
vocabulary_dtype,
vocabulary=None,
idf_weights=None,
invert=False,
output_mode="int",
sparse=False,
pad_to_max_tokens=False,
name=None,
**kwargs,
):
# If max_tokens is set, the value must be greater than 1 - otherwise we
# are creating a 0-element vocab, which doesn't make sense.
if max_tokens is not None and max_tokens <= 1:
raise ValueError(
"If set, `max_tokens` must be greater than 1. "
f"Received: max_tokens={max_tokens}"
)
if pad_to_max_tokens and max_tokens is None:
raise ValueError(
"If pad_to_max_tokens is True, must set `max_tokens`. "
f"Received: max_tokens={max_tokens}"
)
if num_oov_indices < 0:
raise ValueError(
"`num_oov_indices` must be greater than or equal to 0. "
f"Received: num_oov_indices={num_oov_indices}"
)
# Support deprecated names for output_modes.
if output_mode == "binary":
output_mode = "multi_hot"
if output_mode == "tf-idf":
output_mode = "tf_idf"
argument_validation.validate_string_arg(
output_mode,
allowable_strings=(
"int",
"one_hot",
"multi_hot",
"count",
"tf_idf",
),
caller_name=self.__class__.__name__,
arg_name="output_mode",
)
if invert and output_mode != "int":
raise ValueError(
"`output_mode` must be `'int'` when `invert` is true. "
f"Received: output_mode={output_mode}"
)
if sparse and output_mode == "int":
raise ValueError(
"`sparse` may only be true if `output_mode` is "
"`'one_hot'`, `'multi_hot'`, `'count'` or `'tf_idf'`. "
f"Received: sparse={sparse} and "
f"output_mode={output_mode}"
)
if idf_weights is not None and output_mode != "tf_idf":
raise ValueError(
"`idf_weights` should only be set if `output_mode` is "
f"`'tf_idf'`. Received: idf_weights={idf_weights} and "
f"output_mode={output_mode}"
)
super().__init__(name=name)
self._convert_input_args = False
self._allow_non_tensor_positional_args = True
self.supports_jit = False
self.invert = invert
self.max_tokens = max_tokens
self.num_oov_indices = num_oov_indices
self.mask_token = mask_token
self.oov_token = oov_token
self.output_mode = output_mode
self.sparse = sparse
self.pad_to_max_tokens = pad_to_max_tokens
self.vocabulary_dtype = tf.as_dtype(vocabulary_dtype).name
self._frozen_vocab_size = kwargs.pop("vocabulary_size", None)
self.input_vocabulary = vocabulary
self.input_idf_weights = idf_weights
# We set this hidden attr to
# persist the fact that we have a non-adaptable layer with a
# manually set vocab.
self._has_input_vocabulary = kwargs.pop(
"has_input_vocabulary", (vocabulary is not None)
)
kwargs.pop("trainable", None)
kwargs.pop("dtype", None)
if kwargs:
raise ValueError(f"Unrecognized keyword argument(s): {kwargs}")
if invert:
self._key_dtype = "int64"
self._value_dtype = self.vocabulary_dtype
mask_key = 0
mask_value = mask_token
self._default_value = self.oov_token
else:
self._key_dtype = self.vocabulary_dtype
self._value_dtype = "int64"
mask_key = mask_token
# Masks should map to 0 for int output and be dropped otherwise. Max
# ints will be dropped from the bincount op.
mask_value = (
0
if self.output_mode == "int"
else tf.as_dtype(self._value_dtype).max
)
if self.num_oov_indices == 0:
# If there are no OOV indices, we map OOV tokens to -1 and error
# out during call if we find a negative index.
self._default_value = -1
elif self.num_oov_indices == 1:
# If there is only one OOV index, we can set that index as the
# default value of the index_lookup table.
self._default_value = self._oov_start_index()
else:
# If we have multiple OOV values, we need to do a further
# hashing step; to make this easier, we set the OOV value to -1.
# (This lets us do a vectorized add and cast to boolean to
# determine locations where we need to do extra hashing.)
self._default_value = -1
if self.mask_token is not None:
self._mask_key = tf.convert_to_tensor(mask_key, self._key_dtype)
self._mask_value = tf.convert_to_tensor(
mask_value, self._value_dtype
)
if self.output_mode == "tf_idf":
if self._has_input_vocabulary and idf_weights is None:
raise ValueError(
"When specifying the `vocabulary` argument, "
"in TF-IDF output mode, the `idf_weights` argument "
"must also be provided."
)
if idf_weights is not None:
self.idf_weights = tf.Variable(
idf_weights,
dtype=backend.floatx(),
trainable=False,
)
self.idf_weights_const = self.idf_weights.value()
if vocabulary is not None:
self.set_vocabulary(vocabulary, idf_weights)
else:
# When restoring from a keras SavedModel, the loading code will
# expect to find and restore a lookup_table attribute on the layer.
# This table needs to be uninitialized as a StaticHashTable cannot
# be initialized twice.
self.lookup_table = self._uninitialized_lookup_table()
# Only set up adapt state if we did not receive a vocab on construction.
if not self._has_input_vocabulary:
# Set adapt state.
self.token_counts = tf.lookup.experimental.MutableHashTable(
key_dtype=vocabulary_dtype,
value_dtype="int64",
default_value=0,
)
if self.output_mode == "tf_idf":
self.token_document_counts = (
tf.lookup.experimental.MutableHashTable(
key_dtype=vocabulary_dtype,
value_dtype="int64",
default_value=0,
)
)
self.num_documents = tf.Variable(
0, dtype="int64", trainable=False
)
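# Index layout sketch: with `"int"` output, index 0 is the mask token,
# indices [oov_start, oov_start + num_oov_indices) are the OOV slots, and
# real tokens start at `_token_start_index()`. With num_oov_indices > 1,
# OOV inputs are spread over the OOV slots by hashing (see
# `_lookup_dense` below), e.g. for strings:
#   tf.strings.to_hash_bucket_fast(x, num_oov_indices) + oov_start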
def get_vocabulary(self, include_special_tokens=True):
"""Returns the current vocabulary of the layer.
Args:
include_special_tokens: If `True`, the returned vocabulary
will include mask and OOV tokens,
and a term's index in the vocabulary
will equal the term's index when calling the layer.
If `False`, the returned vocabulary will not include
any mask or OOV tokens.
"""
# The lookup table data will not be sorted, so we will create an inverted
# lookup here, and use that to lookup a range of indices
# [0, vocab_size).
if self.lookup_table.size() == 0:
vocab, indices = [], []
else:
keys, values = self.lookup_table.export()
vocab, indices = (values, keys) if self.invert else (keys, values)
vocab, indices = (
self._tensor_vocab_to_numpy(vocab),
indices.numpy(),
)
lookup = collections.defaultdict(
lambda: self.oov_token, zip(indices, vocab)
)
vocab = [lookup[x] for x in range(self.vocabulary_size())]
if self.mask_token is not None and self.output_mode == "int":
vocab[0] = self.mask_token
if not include_special_tokens:
vocab = vocab[self._token_start_index() :]
if self.vocabulary_dtype == "string":
return [
i.decode("utf-8") if isinstance(i, bytes) else i for i in vocab
]
else:
return vocab
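# For the string example above (mask "" and OOV "[OOV]"):
#   layer.get_vocabulary()
#       # -> ["", "[OOV]", "one", "two", "three"]
#   layer.get_vocabulary(include_special_tokens=False)
#       # -> ["one", "two", "three"]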
def vocabulary_size(self):
"""Gets the current size of the layer's vocabulary.
Returns:
The integer size of the vocabulary, including optional mask and oov
indices.
"""
if tf.executing_eagerly():
return (
int(self.lookup_table.size().numpy())
+ self._token_start_index()
)
else:
return self.lookup_table.size() + self._token_start_index()
def get_config(self):
config = {
"invert": self.invert,
"max_tokens": self.max_tokens,
"num_oov_indices": self.num_oov_indices,
"oov_token": self.oov_token,
"mask_token": self.mask_token,
"output_mode": self.output_mode,
"sparse": self.sparse,
"pad_to_max_tokens": self.pad_to_max_tokens,
"vocabulary_dtype": self.vocabulary_dtype,
"idf_weights": listify_tensors(self.input_idf_weights),
"vocabulary": listify_tensors(self.input_vocabulary),
"vocabulary_size": self._frozen_vocab_size,
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def _record_vocabulary_size(self):
self._ensure_vocab_size_unchanged()
with tf.init_scope():
self._frozen_vocab_size = self.vocabulary_size()
def set_vocabulary(self, vocabulary, idf_weights=None):
"""Sets vocabulary (and optionally document frequency) for this layer.
This method sets the vocabulary and idf weights for this layer directly,
instead of analyzing a dataset through `adapt`. It should be used
whenever the vocab (and optionally document frequency) information is
already known. If vocabulary data is already present in the layer, this
method will replace it.
Args:
vocabulary: Either an array or a string path to a text file.
If passing an array, can pass a tuple, list,
1D numpy array, or 1D tensor containing the vocabulary terms.
If passing a file path, the file should contain one line
per term in the vocabulary.
idf_weights: A tuple, list, 1D numpy array, or 1D tensor
of inverse document frequency weights with equal
length to vocabulary. Must be set if `output_mode`
is `"tf_idf"`. Should not be set otherwise.
"""
if self.output_mode == "tf_idf":
if idf_weights is None:
raise ValueError(
"`idf_weights` must be set if output_mode is 'tf_idf'."
)
elif idf_weights is not None:
raise ValueError(
"`idf_weights` should only be set if output_mode is "
f"`'tf_idf'`. Received: output_mode={self.output_mode} "
f"and idf_weights={idf_weights}"
)
if isinstance(vocabulary, str):
if not tf.io.gfile.exists(vocabulary):
raise ValueError(
f"Vocabulary file {vocabulary} does not exist."
)
if self.output_mode == "tf_idf":
raise ValueError(
"output_mode `'tf_idf'` does not support loading a "
"vocabulary from file."
)
self.lookup_table = self._lookup_table_from_file(vocabulary)
self._record_vocabulary_size()
return
if not tf.executing_eagerly() and (
tf.is_tensor(vocabulary) or tf.is_tensor(idf_weights)
):
raise RuntimeError(
f"Cannot set a tensor vocabulary on layer {self.name} "
"when not executing eagerly. "
"Create this layer or call `set_vocabulary()` "
"outside of any traced function."
)
# TODO(mattdangerw): for better performance we should rewrite this
# entire function to operate on tensors and convert vocabulary to a
# tensor here.
if tf.is_tensor(vocabulary):
vocabulary = self._tensor_vocab_to_numpy(vocabulary)
elif isinstance(vocabulary, (list, tuple)):
vocabulary = np.array(vocabulary)
if tf.is_tensor(idf_weights):
idf_weights = idf_weights.numpy()
elif isinstance(idf_weights, (list, tuple)):
idf_weights = np.array(idf_weights)
if vocabulary.size == 0:
raise ValueError(
"Cannot set an empty vocabulary. "
f"Received: vocabulary={vocabulary}"
)
oov_start = self._oov_start_index()
token_start = self._token_start_index()
special_tokens = [self.mask_token] * oov_start + [
self.oov_token
] * self.num_oov_indices
found_special_tokens = np.array_equal(
special_tokens, vocabulary[:token_start]
)
if found_special_tokens:
tokens = vocabulary[token_start:]
else:
tokens = vocabulary
repeated_tokens = self._find_repeated_tokens(tokens)
if repeated_tokens:
raise ValueError(
"The passed vocabulary has at least one repeated "
"term. Please uniquify your dataset. The repeated terms "
f"are: {repeated_tokens}"
)
if self.mask_token is not None and self.mask_token in tokens:
mask_index = np.argwhere(vocabulary == self.mask_token)[-1]
raise ValueError(
"Found reserved mask token at unexpected location in "
"`vocabulary`. Note that passed `vocabulary` does not need to "
"include the OOV and mask tokens. Either remove all mask and "
"OOV tokens, or include them only at the start of the "
f"vocabulary in precisely this order: {special_tokens}. "
f"Received: mask_token={self.mask_token} at "
f"vocabulary index {mask_index}"
)
# Only error out for oov_token when invert=True. When invert=False,
# oov_token is unused during lookup.
if (
self.oov_token is not None
and self.invert
and self.oov_token in tokens
):
oov_index = np.argwhere(vocabulary == self.oov_token)[-1]
raise ValueError(
"Found reserved OOV token at unexpected location in "
"`vocabulary`. Note that passed `vocabulary` does not need to "
"include the OOV and mask tokens. Either remove all mask and "
"OOV tokens, or include them only at the start of the "
f"vocabulary in precisely this order: {special_tokens}. "
f"Received: oov_token={self.oov_token} at "
f"vocabulary index {oov_index}"
)
new_vocab_size = token_start + len(tokens)
if self.max_tokens is not None and (new_vocab_size > self.max_tokens):
raise ValueError(
"Attempted to set a vocabulary larger than the maximum vocab "
f"size. Received vocabulary size is {new_vocab_size}; "
f"`max_tokens` is {self.max_tokens}."
)
self.lookup_table = self._lookup_table_from_tokens(tokens)
self._record_vocabulary_size()
if self.output_mode == "tf_idf" and idf_weights is not None:
if len(vocabulary) != len(idf_weights):
raise ValueError(
"`idf_weights` must be the same length as vocabulary. "
f"len(idf_weights) is {len(idf_weights)}; "
f"len(vocabulary) is {len(vocabulary)}"
)
idf_weights = self._convert_to_ndarray(idf_weights)
if idf_weights.ndim != 1:
raise ValueError(
"TF-IDF data must be a 1-index array. "
f"Received: type(idf_weights)={type(idf_weights)}"
)
# If the passed vocabulary has no special tokens, we need to pad the
# front of idf_weights. We don't have real document frequencies for
# these tokens so we will use an average of all idf_weights passed
# in as a reasonable default.
if found_special_tokens:
front_padding = 0
front_padding_value = 0
else:
front_padding = token_start
front_padding_value = np.average(idf_weights)
# If pad_to_max_tokens is true, and max_tokens is greater than our
# total vocab size, we need to pad the back of idf_weights with
# zeros as well.
back_padding_value = 0
if self.pad_to_max_tokens and self.max_tokens is not None:
back_padding = (
self.max_tokens - front_padding - len(idf_weights)
)
else:
back_padding = 0
weights = np.pad(
idf_weights,
(front_padding, back_padding),
"constant",
constant_values=(front_padding_value, back_padding_value),
)
weights = tf.convert_to_tensor(weights, dtype=backend.floatx())
self.idf_weights = tf.Variable(
weights,
trainable=False,
)
self.idf_weights_const = self.idf_weights.value()
def build(self):
self.built = True
def get_build_config(self):
return {}
def build_from_config(self, config):
self.build()
@property
def compute_dtype(self):
return self.vocabulary_dtype
@property
def variable_dtype(self):
return self.vocabulary_dtype
def compute_output_spec(self, inputs):
if self.output_mode == "int":
output_dtype = "int64"
output_shape = inputs.shape
else:
output_dtype = backend.floatx()
depth = (
self.max_tokens
if self.pad_to_max_tokens
else self._frozen_vocab_size
)
output_shape = (inputs.shape[0], depth)
return backend.KerasTensor(output_shape, dtype=output_dtype)
def adapt(self, data):
self.reset_state()
if isinstance(data, tf.data.Dataset):
for batch in data:
self.update_state(batch)
else:
data = ensure_tensor(data, dtype=self.vocabulary_dtype)
if data.shape.rank == 1:
# A plain list of strings is treated as one document per string.
data = tf.expand_dims(data, -1)
self.update_state(data)
self.finalize_state()
def update_state(self, data):
if self._has_input_vocabulary:
raise ValueError(
f"Cannot adapt layer '{self.name}' after setting a static "
"vocabulary via `vocabulary` argument or "
"`set_vocabulary()` method."
)
data = ensure_tensor(data, dtype=self.vocabulary_dtype)
if data.shape.rank == 0:
data = tf.expand_dims(data, 0)
if data.shape.rank == 1:
# Expand dims on axis 0 for tf-idf. A 1-d tensor
# is a single document.
data = tf.expand_dims(data, 0)
tokens, counts = self._num_tokens(data)
self.token_counts.insert(
tokens, counts + self.token_counts.lookup(tokens)
)
if self.output_mode == "tf_idf":
# Dedupe each row of our dataset.
deduped_doc_data = [tf.unique(x)[0] for x in data]
deduped_doc_data = tf.concat(deduped_doc_data, axis=0)
# Flatten and count tokens.
tokens, counts = self._num_tokens(deduped_doc_data)
self.token_document_counts.insert(
tokens, counts + self.token_document_counts.lookup(tokens)
)
if isinstance(data, tf.RaggedTensor):
self.num_documents.assign_add(data.nrows())
else:
self.num_documents.assign_add(
tf.shape(data, out_type="int64")[0]
)
def finalize_state(self):
if self._has_input_vocabulary or tf.equal(self.token_counts.size(), 0):
# Finalize idf_weights to a const for call even if we don't need to
# compute a new vocabulary.
if self.output_mode == "tf_idf":
self.idf_weights_const = self.idf_weights.value()
self._record_vocabulary_size()
return
# Remove special tokens from our counts.
if self.mask_token is not None:
self.token_counts.remove(
tf.convert_to_tensor([self.mask_token], self.vocabulary_dtype)
)
if self.oov_token is not None:
self.token_counts.remove(
tf.convert_to_tensor([self.oov_token], self.vocabulary_dtype)
)
tokens, counts = self.token_counts.export()
# To keep vocabs deterministic, we sort our tokens by count and break
ties by sorting the tokens themselves. TensorFlow has no ops for
# sorting strings, so we need to use numpy for the sort.
sorted_indices = np.lexsort((tokens.numpy(), counts.numpy()))[::-1]
token_start = self._token_start_index()
if self.max_tokens:
max_learned_tokens = self.max_tokens - token_start
sorted_indices = sorted_indices[:max_learned_tokens]
tokens = tf.gather(tokens, sorted_indices)
self.lookup_table = self._lookup_table_from_tokens(tokens)
if self.output_mode == "tf_idf":
token_document_counts = self.token_document_counts.lookup(tokens)
idf_weights = self._inverse_document_frequency(
token_document_counts, self.num_documents
)
idf_weights = tf.cast(idf_weights, backend.floatx())
# Pad the front of idf_weights with the average idf weight for OOV
# tokens. We cannot compute the real idf weight of OOV in a single
# pass.
idf_weights = tf.pad(
idf_weights,
[[self._token_start_index(), 0]],
constant_values=tf.reduce_mean(idf_weights),
)
if self.pad_to_max_tokens and self.max_tokens is not None:
# Pad the back of idf_weights with zeros.
idf_weights = tf.pad(
idf_weights,
[[0, self.max_tokens - tf.size(idf_weights)]],
constant_values=0,
)
self.idf_weights = tf.Variable(
idf_weights,
dtype=backend.floatx(),
trainable=False,
)
self.idf_weights_const = self.idf_weights.value()
# We call this here to save memory: now that we've built our vocabulary,
# we don't want to keep every token we've seen in separate lookup
# tables.
self.reset_state()
self._record_vocabulary_size()
def reset_state(self):
if self._has_input_vocabulary:
return
self.token_counts.remove(self.token_counts.export()[0])
if self.output_mode == "tf_idf":
self.token_document_counts.remove(
self.token_document_counts.export()[0]
)
self.num_documents.assign(0)
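# Adapt lifecycle sketch: `adapt(data)` above is equivalent to
#   layer.reset_state()
#   for batch in data:        # or a single update for array data
#       layer.update_state(batch)
#   layer.finalize_state()
# after which the learned vocabulary is frozen into `lookup_table`.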
def call(self, inputs):
self._ensure_known_vocab_size()
inputs = ensure_tensor(inputs, dtype=self._key_dtype)
original_shape = inputs.shape
# Some ops will not handle scalar input, so uprank to rank 1.
if inputs.shape.rank == 0:
inputs = self._expand_dims(inputs, -1)
if isinstance(inputs, tf.SparseTensor):
lookups = tf.SparseTensor(
inputs.indices,
self._lookup_dense(inputs.values),
inputs.dense_shape,
)
elif isinstance(inputs, tf.RaggedTensor):
lookups = tf.ragged.map_flat_values(self._lookup_dense, inputs)
else:
lookups = self._lookup_dense(inputs)
if self.output_mode == "int":
# If we received a scalar input, downrank back to a scalar.
if original_shape.rank == 0:
lookups = tf.squeeze(lookups, -1)
return lookups
depth = (
self.max_tokens
if self.pad_to_max_tokens
else self._frozen_vocab_size
)
idf_weights = (
self.idf_weights_const if self.output_mode == "tf_idf" else None
)
return tf_utils.encode_categorical_inputs(
lookups,
output_mode=self.output_mode,
depth=depth,
dtype=self._value_dtype,
sparse=self.sparse,
idf_weights=idf_weights,
)
def _lookup_dense(self, inputs):
"""Lookup table values for a dense Tensor, handling masking and OOV."""
# When executing eagerly and tracing keras.Input objects,
# do not call lookup.
# This is critical for restoring SavedModel, which will first trace
# layer.call and then attempt to restore the table. We need the table to
# be uninitialized for the restore to work, but calling the table
# uninitialized would error.
if tf.executing_eagerly() and backend.is_keras_tensor(inputs):
lookups = tf.zeros_like(inputs, dtype=self._value_dtype)
else:
lookups = self.lookup_table.lookup(inputs)
if self.mask_token is not None:
mask_locations = tf.equal(inputs, self._mask_key)
lookups = tf.where(mask_locations, self._mask_value, lookups)
if self.invert:
return lookups
lookup_checks = []
if self.num_oov_indices == 0:
# If we have zero oov indices, we need to check for oov inputs.
oov_indices = tf.where(tf.equal(lookups, -1))
oov_inputs = tf.gather_nd(inputs, oov_indices)
msg = tf.strings.format(
"When `num_oov_indices=0` all inputs should be in vocabulary, "
"found OOV values {}, consider setting `num_oov_indices=1`.",
(oov_inputs,),
)
assertion = tf.Assert(tf.equal(tf.size(oov_indices), 0), [msg])
lookup_checks.append(assertion)
elif self.num_oov_indices > 1:
# If we have multiple oov indices, we need a further hashing step.
if tf.as_dtype(self._key_dtype).is_integer:
oov_indices = tf.math.floormod(inputs, self.num_oov_indices)
else:
oov_indices = tf.strings.to_hash_bucket_fast(
inputs, num_buckets=self.num_oov_indices
)
oov_indices = oov_indices + self._oov_start_index()
oov_locations = tf.equal(lookups, self._default_value)
lookups = tf.where(oov_locations, oov_indices, lookups)
with tf.control_dependencies(lookup_checks):
return tf.identity(lookups)
def save_own_variables(self, store):
if self.output_mode == "tf_idf":
store["idf_weights"] = self.idf_weights_const.numpy()
def load_own_variables(self, store):
if self.output_mode == "tf_idf":
self.idf_weights.assign(store["idf_weights"])
self.idf_weights_const = self.idf_weights.value()
def save_assets(self, dir_path):
if self.input_vocabulary:
# Vocab saved in config.
# TODO: consider unifying both paths.
return
vocabulary = self.get_vocabulary(include_special_tokens=True)
vocabulary_filepath = tf.io.gfile.join(dir_path, "vocabulary.txt")
with open(vocabulary_filepath, "w") as f:
f.write("\n".join([str(w) for w in vocabulary]))
def load_assets(self, dir_path):
if self.input_vocabulary:
# Vocab saved in config.
# TODO: consider unifying both paths.
return
vocabulary_filepath = tf.io.gfile.join(dir_path, "vocabulary.txt")
# TODO: fix bug with include_special_tokens and set reload from file.
with open(vocabulary_filepath, "r") as f:
lines = f.read().split("\n")
if tf.as_dtype(self.vocabulary_dtype) == tf.string:
values = [str(line) for line in lines]
else:
values = [int(line) for line in lines]
if self.output_mode == "tf_idf":
self.set_vocabulary(values, idf_weights=False)
else:
self.set_vocabulary(values)
def _uninitialized_lookup_table(self):
with tf.init_scope():
initializer = NullInitializer(self._key_dtype, self._value_dtype)
return tf.lookup.StaticHashTable(initializer, self._default_value)
def _lookup_table_from_tokens(self, tokens):
with tf.init_scope():
token_start = self._token_start_index()
token_end = token_start + tf.size(tokens)
indices_dtype = (
self._key_dtype if self.invert else self._value_dtype
)
indices = tf.range(token_start, token_end, dtype=indices_dtype)
keys, values = (
(indices, tokens) if self.invert else (tokens, indices)
)
initializer = tf.lookup.KeyValueTensorInitializer(
keys, values, self._key_dtype, self._value_dtype
)
return tf.lookup.StaticHashTable(initializer, self._default_value)
def _lookup_table_from_file(self, filename):
if self.invert:
key_index = tf.lookup.TextFileIndex.LINE_NUMBER
value_index = tf.lookup.TextFileIndex.WHOLE_LINE
else:
key_index = tf.lookup.TextFileIndex.WHOLE_LINE
value_index = tf.lookup.TextFileIndex.LINE_NUMBER
with tf.init_scope():
initializer = tf.lookup.TextFileInitializer(
filename=filename,
key_dtype=self._key_dtype,
key_index=key_index,
value_dtype=self._value_dtype,
value_index=value_index,
value_index_offset=self._token_start_index(),
)
return tf.lookup.StaticHashTable(initializer, self._default_value)
def _convert_to_ndarray(self, x):
return np.array(x) if isinstance(x, (list, tuple)) else x
def _expand_dims(self, inputs, axis):
if isinstance(inputs, tf.SparseTensor):
return tf.sparse.expand_dims(inputs, axis)
else:
return tf.expand_dims(inputs, axis)
def _oov_start_index(self):
return (
1
if self.mask_token is not None and self.output_mode == "int"
else 0
)
def _token_start_index(self):
return self._oov_start_index() + self.num_oov_indices
def _ensure_known_vocab_size(self):
if self.output_mode == "int" or self.pad_to_max_tokens:
return
if self._frozen_vocab_size is None:
raise RuntimeError(
f"When using `output_mode={self.output_mode}` "
"and `pad_to_max_tokens=False`, "
"you must set the layer's vocabulary before calling it. Either "
"pass a `vocabulary` argument to the layer, or call `adapt` "
"with some sample data."
)
def _ensure_vocab_size_unchanged(self):
if self.output_mode == "int" or self.pad_to_max_tokens:
return
with tf.init_scope():
new_vocab_size = self.vocabulary_size()
if (
self._frozen_vocab_size is not None
and new_vocab_size != self._frozen_vocab_size
):
raise RuntimeError(
f"When using `output_mode={self.output_mode}` "
"and `pad_to_max_tokens=False`, "
"the vocabulary size cannot be changed after the layer is "
f"called. Old vocab size is {self._frozen_vocab_size}, "
f"new vocab size is {new_vocab_size}"
)
def _find_repeated_tokens(self, vocabulary):
"""Return all repeated tokens in a vocabulary."""
vocabulary_set = set(vocabulary)
if len(vocabulary) != len(vocabulary_set):
return [
item
for item, count in collections.Counter(vocabulary).items()
if count > 1
]
else:
return []
def _num_tokens(self, data):
"""Count the number of tokens in a ragged, sparse or dense tensor."""
if isinstance(data, tf.SparseTensor):
flat_values = data.values
elif isinstance(data, tf.RaggedTensor):
flat_values = data.flat_values
else:
flat_values = tf.reshape(data, [-1])
tokens, _, counts = tf.unique_with_counts(flat_values, out_idx="int64")
return tokens, counts
def _inverse_document_frequency(self, token_document_counts, num_documents):
"""Computes the inverse-document-frequency (IDF) component of "tf_idf".
Args:
token_document_counts: An array of the # of documents each token
appears in.
num_documents: An int representing the total number of documents.
Returns:
An array of "inverse document frequency" weights.
"""
return tf.math.log(1 + num_documents / (1 + token_document_counts))
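# Worked example (matches `test_adapt_tf_idf` below): with 6 single-token
# documents and document counts [3, 2, 1] for "one", "two", "three":
#   idf = log(1 + 6 / (1 + [3, 2, 1])) ~= [0.916, 1.099, 1.386]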
# Override points for IntegerLookup and StringLookup.
def _tensor_vocab_to_numpy(self, vocabulary):
"""Converts a tensor vocabulary to a numpy vocabulary."""
return vocabulary.numpy()
class NullInitializer(tf.lookup.KeyValueTensorInitializer):
"""A placeholder initializer for restoring this layer from a SavedModel."""
def __init__(self, key_dtype, value_dtype):
"""Construct a table initializer object.
Args:
key_dtype: Type of the table keys.
value_dtype: Type of the table values.
"""
self._key_dtype = key_dtype
self._value_dtype = value_dtype
@property
def key_dtype(self):
"""The expected table key dtype."""
return self._key_dtype
@property
def value_dtype(self):
"""The expected table value dtype."""
return self._value_dtype
def initialize(self, table):
"""Returns the table initialization op."""
pass
def listify_tensors(x):
"""Convert any tensors or numpy arrays to lists for config serialization."""
if tf.is_tensor(x):
x = x.numpy()
if isinstance(x, np.ndarray):
x = x.tolist()
return x
def ensure_tensor(inputs, dtype=None):
"""Ensures the input is a Tensor, SparseTensor or RaggedTensor."""
if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor, tf.SparseTensor)):
inputs = tf.convert_to_tensor(inputs, dtype)
if dtype is not None and inputs.dtype != dtype:
inputs = tf.cast(inputs, dtype)
return inputs

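Before the tests, a short round-trip sketch of the forward and inverted layers (values taken from `test_invert` below):

import tensorflow as tf
from keras_core import layers

kwargs = {
    "max_tokens": 7,
    "num_oov_indices": 1,
    "mask_token": "",
    "oov_token": "[OOV]",
    "vocabulary_dtype": "string",
    "vocabulary": ["one", "two", "three"],
}
to_ids = layers.IndexLookup(**kwargs)                  # tokens -> indices
to_tokens = layers.IndexLookup(invert=True, **kwargs)  # indices -> tokens

ids = to_ids(tf.constant(["one", "two", "four"]))      # [2, 3, 1]
tokens = to_tokens(ids)                                # [b"one", b"two", b"[OOV]"]
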
@@ -0,0 +1,426 @@
import os
import numpy as np
import pytest
import tensorflow as tf
from absl.testing import parameterized
from keras_core import backend
from keras_core import layers
from keras_core import models
from keras_core import testing
from keras_core.saving import saving_api
class IndexLookupLayerTest(testing.TestCase, parameterized.TestCase):
def test_basics_string_vocab(self):
# Case: adapt + list inputs
adapt_data = ["one", "one", "one", "two", "two", "three"]
input_data = ["one", "two", "four"]
kwargs = {
"max_tokens": 7,
"num_oov_indices": 1,
"mask_token": "",
"oov_token": "[OOV]",
"vocabulary_dtype": "string",
}
layer = layers.IndexLookup(**kwargs)
layer.adapt(adapt_data)
self.assertEqual(
layer.get_vocabulary(), ["", "[OOV]", "one", "two", "three"]
)
self.assertEqual(
layer.get_vocabulary(include_special_tokens=False),
["one", "two", "three"],
)
output = layer(input_data)
self.assertEqual(list(output), [2, 3, 1])
if backend.backend() != "torch":
self.run_class_serialization_test(layer)
# Case: numpy array input
output = layer(np.array(input_data))
self.assertEqual(list(output), [2, 3, 1])
# Case: fixed vocab + list inputs
vocabulary = ["one", "two", "three"]
layer = layers.IndexLookup(vocabulary=vocabulary, **kwargs)
self.assertEqual(
layer.get_vocabulary(), ["", "[OOV]", "one", "two", "three"]
)
self.assertEqual(
layer.get_vocabulary(include_special_tokens=False),
["one", "two", "three"],
)
output = layer(input_data)
self.assertEqual(list(output), [2, 3, 1])
if backend.backend() != "torch":
self.run_class_serialization_test(layer)
# Case: fixed vocab with special tokens + list inputs
vocabulary_with_special_tokens = ["", "[OOV]", "one", "two", "three"]
layer = layers.IndexLookup(
vocabulary=vocabulary_with_special_tokens, **kwargs
)
self.assertEqual(
layer.get_vocabulary(), ["", "[OOV]", "one", "two", "three"]
)
self.assertEqual(
layer.get_vocabulary(include_special_tokens=False),
["one", "two", "three"],
)
output = layer(input_data)
self.assertEqual(list(output), [2, 3, 1])
if backend.backend() != "torch":
self.run_class_serialization_test(layer)
# Case: set vocabulary
layer = layers.IndexLookup(**kwargs)
layer.set_vocabulary(vocabulary)
self.assertEqual(
layer.get_vocabulary(), ["", "[OOV]", "one", "two", "three"]
)
self.assertEqual(
layer.get_vocabulary(include_special_tokens=False),
["one", "two", "three"],
)
output = layer(input_data)
self.assertEqual(list(output), [2, 3, 1])
if backend.backend() != "torch":
self.run_class_serialization_test(layer)
# Case: set vocabulary (with special tokens)
layer = layers.IndexLookup(**kwargs)
layer.set_vocabulary(vocabulary_with_special_tokens)
self.assertEqual(
layer.get_vocabulary(), ["", "[OOV]", "one", "two", "three"]
)
self.assertEqual(
layer.get_vocabulary(include_special_tokens=False),
["one", "two", "three"],
)
output = layer(input_data)
self.assertEqual(list(output), [2, 3, 1])
if backend.backend() != "torch":
self.run_class_serialization_test(layer)
def test_basics_integer_vocab(self):
# Case: adapt + list inputs
adapt_data = [1, 1, 1, 2, 2, 3]
input_data = [1, 2, 4]
kwargs = {
"max_tokens": 7,
"num_oov_indices": 1,
"mask_token": 0,
"oov_token": -1,
"vocabulary_dtype": "int64",
}
layer = layers.IndexLookup(**kwargs)
layer.adapt(adapt_data)
self.assertEqual(layer.get_vocabulary(), [0, -1, 1, 2, 3])
self.assertEqual(
layer.get_vocabulary(include_special_tokens=False),
[1, 2, 3],
)
output = layer(input_data)
self.assertEqual(list(output), [2, 3, 1])
if backend.backend() != "torch":
self.run_class_serialization_test(layer)
# Case: numpy array input
output = layer(np.array(input_data))
self.assertEqual(list(output), [2, 3, 1])
# Case: fixed vocab + list inputs
vocabulary = [1, 2, 3]
layer = layers.IndexLookup(vocabulary=vocabulary, **kwargs)
self.assertEqual(layer.get_vocabulary(), [0, -1, 1, 2, 3])
self.assertEqual(
layer.get_vocabulary(include_special_tokens=False),
[1, 2, 3],
)
output = layer(input_data)
self.assertEqual(list(output), [2, 3, 1])
if backend.backend() != "torch":
self.run_class_serialization_test(layer)
# Case: fixed vocab with special tokens + list inputs
vocabulary_with_special_tokens = [0, -1, 1, 2, 3]
layer = layers.IndexLookup(
vocabulary=vocabulary_with_special_tokens, **kwargs
)
self.assertEqual(layer.get_vocabulary(), [0, -1, 1, 2, 3])
self.assertEqual(
layer.get_vocabulary(include_special_tokens=False),
[1, 2, 3],
)
output = layer(input_data)
self.assertEqual(list(output), [2, 3, 1])
if backend.backend() != "torch":
self.run_class_serialization_test(layer)
# Case: set vocabulary
layer = layers.IndexLookup(**kwargs)
layer.set_vocabulary(vocabulary)
self.assertEqual(layer.get_vocabulary(), [0, -1, 1, 2, 3])
self.assertEqual(
layer.get_vocabulary(include_special_tokens=False),
[1, 2, 3],
)
output = layer(input_data)
self.assertEqual(list(output), [2, 3, 1])
if backend.backend() != "torch":
self.run_class_serialization_test(layer)
# Case: set vocabulary (with special tokens)
layer = layers.IndexLookup(**kwargs)
layer.set_vocabulary(vocabulary_with_special_tokens)
self.assertEqual(layer.get_vocabulary(), [0, -1, 1, 2, 3])
self.assertEqual(
layer.get_vocabulary(include_special_tokens=False),
[1, 2, 3],
)
output = layer(input_data)
self.assertEqual(list(output), [2, 3, 1])
if backend.backend() != "torch":
self.run_class_serialization_test(layer)
def test_max_tokens_adapt(self):
adapt_data = [1, 1, 1, 2, 2, 3]
input_data = [1, 2, 3, 4]
kwargs = {
"max_tokens": 4,
"num_oov_indices": 1,
"mask_token": 0,
"oov_token": -1,
"vocabulary_dtype": "int64",
}
layer = layers.IndexLookup(**kwargs)
layer.adapt(adapt_data)
self.assertEqual(layer.get_vocabulary(), [0, -1, 1, 2])
self.assertEqual(
layer.get_vocabulary(include_special_tokens=False),
[1, 2],
)
output = layer(input_data)
self.assertEqual(list(output), [2, 3, 1, 1])
if backend.backend() != "torch":
self.run_class_serialization_test(layer)
def test_pad_to_max_tokens(self):
vocabulary = [1, 2]
input_data = [1, 2]
kwargs = {
"max_tokens": 5,
"num_oov_indices": 1,
"mask_token": 0,
"oov_token": -1,
"vocabulary_dtype": "int64",
"vocabulary": vocabulary,
"pad_to_max_tokens": True,
"output_mode": "multi_hot",
}
layer = layers.IndexLookup(**kwargs)
output = layer(input_data)
self.assertAllClose(output, [0, 1, 1, 0, 0])
if backend.backend() != "torch":
self.run_class_serialization_test(layer)
def test_output_modes(self):
vocabulary = ["one", "two", "three"]
single_sample_input_data = ["one", "two", "four"]
batch_input_data = [["one", "two", "four", "two"]]
kwargs = {
"max_tokens": 7,
"num_oov_indices": 1,
"mask_token": "",
"oov_token": "[OOV]",
"vocabulary_dtype": "string",
"vocabulary": vocabulary,
}
# int
kwargs["output_mode"] = "int"
layer = layers.IndexLookup(**kwargs)
output = layer(single_sample_input_data)
self.assertAllClose(output, [2, 3, 1])
output = layer(batch_input_data)
self.assertAllClose(output, [[2, 3, 1, 3]])
# multi-hot
kwargs["output_mode"] = "multi_hot"
layer = layers.IndexLookup(**kwargs)
output = layer(single_sample_input_data)
self.assertAllClose(output, [1, 1, 1, 0])
output = layer(batch_input_data)
self.assertAllClose(output, [[1, 1, 1, 0]])
# one-hot
kwargs["output_mode"] = "one_hot"
layer = layers.IndexLookup(**kwargs)
output = layer(single_sample_input_data)
self.assertAllClose(output, [[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0]])
# count
kwargs["output_mode"] = "count"
layer = layers.IndexLookup(**kwargs)
output = layer(single_sample_input_data)
self.assertAllClose(output, [1, 1, 1, 0])
output = layer(batch_input_data)
self.assertAllClose(output, [[1, 1, 2, 0]])
# tf-idf
kwargs["output_mode"] = "tf_idf"
kwargs["idf_weights"] = np.array([0.1, 0.2, 0.3])
layer = layers.IndexLookup(**kwargs)
output = layer(single_sample_input_data)
self.assertAllClose(output, [0.2, 0.1, 0.2, 0.0])
output = layer(batch_input_data)
self.assertAllClose(output, [[0.2, 0.1, 0.4, 0.0]])
def test_sparse_outputs(self):
# TODO
pass
def test_adapt_tf_idf(self):
# Case: unbatched data
adapt_data = ["one", "one", "one", "two", "two", "three"]
input_data = ["one", "two", "four"]
kwargs = {
"max_tokens": 7,
"num_oov_indices": 1,
"mask_token": "",
"oov_token": "[OOV]",
"vocabulary_dtype": "string",
"output_mode": "tf_idf",
}
layer = layers.IndexLookup(**kwargs)
layer.adapt(adapt_data)
output = layer(input_data)
# Document counts for one, two, three = [3, 2, 1]
idf_weights = np.log(1 + len(adapt_data) / (1 + np.array([3, 2, 1])))
self.assertAllClose(layer.idf_weights[1:], idf_weights)
self.assertAllClose(output, [1.1337324, 0.91629076, 1.0986123, 0.0])
# Case: batched data
adapt_data = [["one", "one"], ["one", "two"], ["two", "three"]]
input_data = [["one", "two"], ["two", "four"]]
kwargs = {
"max_tokens": 7,
"num_oov_indices": 1,
"mask_token": "",
"oov_token": "[OOV]",
"vocabulary_dtype": "string",
"output_mode": "tf_idf",
}
layer = layers.IndexLookup(**kwargs)
layer.adapt(adapt_data)
# Document counts for one, two, three = [2, 2, 1]
idf_weights = np.log(1 + len(adapt_data) / (1 + np.array([2, 2, 1])))
self.assertAllClose(layer.idf_weights[1:], idf_weights)
output = layer(input_data)
self.assertAllClose(
output,
[
[0.0, 0.6931472, 0.6931472, 0.0],
[0.76752836, 0.0, 0.6931472, 0.0],
],
)
def test_invert(self):
vocabulary = ["one", "two", "three"]
single_sample_input_data = [2, 3, 1]
batch_input_data = [[2, 3, 1, 3]]
kwargs = {
"max_tokens": 7,
"num_oov_indices": 1,
"mask_token": "",
"oov_token": "[OOV]",
"vocabulary_dtype": "string",
"vocabulary": vocabulary,
"invert": True,
"output_mode": "int",
}
layer = layers.IndexLookup(**kwargs)
output = layer(single_sample_input_data)
self.assertEqual(
[w.decode("utf-8") for w in output.numpy()], ["one", "two", "[OOV]"]
)
output = layer(batch_input_data)
self.assertEqual(
[w.decode("utf-8") for w in output.numpy()[0]],
["one", "two", "[OOV]", "two"],
)
@pytest.mark.skipif(
backend.backend() != "tensorflow", reason="Requires string input dtype"
)
def test_saving(self):
# Test with adapt()
vocabulary = ["one", "two", "three"]
adapt_data = ["one", "one", "one", "two", "two", "three"]
batch_input_data = np.array([["one", "two", "four"]])
kwargs = {
"max_tokens": 7,
"num_oov_indices": 1,
"mask_token": "",
"oov_token": "[OOV]",
"vocabulary_dtype": "string",
"output_mode": "int",
}
layer = layers.IndexLookup(**kwargs)
layer.adapt(adapt_data)
model = models.Sequential(
[
layers.Input(shape=(None,), dtype="string"),
layer,
]
)
output_1 = model(batch_input_data)
path = os.path.join(self.get_temp_dir(), "model.keras")
model.save(path)
model = saving_api.load_model(path)
output_2 = model(batch_input_data)
self.assertAllClose(output_1, output_2)
# Test when vocabulary is provided
kwargs["vocabulary"] = vocabulary
layer = layers.IndexLookup(**kwargs)
model = models.Sequential(
[
layers.Input(shape=(None,), dtype="string"),
layer,
]
)
output_1 = model(batch_input_data)
path = os.path.join(self.get_temp_dir(), "model.keras")
model.save(path)
model = saving_api.load_model(path)
output_2 = model(batch_input_data)
self.assertAllClose(output_1, output_2)
def test_adapt_with_tf_data(self):
# Case: adapt + list inputs
adapt_data = tf.data.Dataset.from_tensor_slices(
["one", "one", "one", "two", "two", "three"]
).batch(2)
input_data = ["one", "two", "four"]
kwargs = {
"max_tokens": 7,
"num_oov_indices": 1,
"mask_token": "",
"oov_token": "[OOV]",
"vocabulary_dtype": "string",
}
layer = layers.IndexLookup(**kwargs)
layer.adapt(adapt_data)
self.assertEqual(
layer.get_vocabulary(), ["", "[OOV]", "one", "two", "three"]
)
self.assertEqual(
layer.get_vocabulary(include_special_tokens=False),
["one", "two", "three"],
)
output = layer(input_data)
self.assertEqual(list(output), [2, 3, 1])
if backend.backend() != "torch":
self.run_class_serialization_test(layer)

@@ -255,6 +255,15 @@ class Sequential(Model):
f"Sequential model '{self.name}' has no defined outputs yet."
)
@property
def input_dtype(self):
# Sequential.__call__ will try to convert its inputs
# to the dtype expected by its input layer, if any.
layers = self._layers
if layers and isinstance(layers[0], InputLayer):
return layers[0].dtype
return super().input_dtype
def _is_layer_name_unique(self, layer):
for ref_layer in self._layers:
if layer.name == ref_layer.name and ref_layer is not layer:

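A sketch of what the override above enables, using the string model from `test_saving`: because the first layer is a string `InputLayer`, `Sequential.input_dtype` is `"string"`, so raw string arrays are converted to a string tensor rather than cast to floatx:

import numpy as np
from keras_core import layers, models

model = models.Sequential([
    layers.Input(shape=(None,), dtype="string"),
    layers.IndexLookup(
        max_tokens=7, num_oov_indices=1, mask_token="",
        oov_token="[OOV]", vocabulary_dtype="string",
        vocabulary=["one", "two", "three"]),
])
# model.input_dtype == "string"
model(np.array([["one", "two", "four"]]))  # -> [[2, 3, 1]], no float cast
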
@@ -108,7 +108,7 @@ def encode_categorical_inputs(
bincounts.dense_shape,
)
else:
return tf.multiply(bincounts, idf_weights)
return tf.multiply(tf.cast(bincounts, idf_weights.dtype), idf_weights)
def get_tensor_spec(t, dynamic_batch=False, name=None):
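
The added cast guards against a dtype mismatch: `tf.multiply` requires both operands to share a dtype, and the bincount output is integer-typed while `idf_weights` is floating point. A minimal illustration:

import tensorflow as tf

bincounts = tf.constant([1, 0, 2], dtype="int64")
idf_weights = tf.constant([0.1, 0.2, 0.3], dtype="float32")

# tf.multiply(bincounts, idf_weights)  # would raise: dtypes must match
tf.multiply(tf.cast(bincounts, idf_weights.dtype), idf_weights)
# -> [0.1, 0.0, 0.6]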