Support dict inputs in TFDataLayer, plus some lint fixes (#383)

This commit is contained in:
Ian Stenbit 2023-06-21 13:10:50 -06:00 committed by Francois Chollet
parent 2350e681eb
commit d955292989
3 changed files with 12 additions and 6 deletions

@ -34,9 +34,7 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
def categorical(logits, num_samples, dtype="int64", seed=None):
seed = tf_draw_seed(seed)
output = tf.random.stateless_categorical(
logits, num_samples, seed=seed
)
output = tf.random.stateless_categorical(logits, num_samples, seed=seed)
return tf.cast(output, dtype)

@ -31,7 +31,10 @@ def categorical(logits, num_samples, dtype="int32", seed=None):
dtype = to_torch_dtype(dtype)
generator = torch_seed_generator(seed, device=get_device())
return torch.multinomial(
logits, num_samples, replacement=True, generator=generator,
logits,
num_samples,
replacement=True,
generator=generator,
).type(dtype)

@ -1,3 +1,5 @@
from tensorflow import nest
from keras_core import backend
from keras_core.layers.layer import Layer
from keras_core.utils import backend_utils
@ -22,8 +24,11 @@ class TFDataLayer(Layer):
):
# We're in a TF graph, e.g. a tf.data pipeline.
self.backend.set_backend("tensorflow")
inputs = self.backend.convert_to_tensor(
inputs, dtype=self.compute_dtype
inputs = nest.map_structure(
lambda x: self.backend.convert_to_tensor(
x, dtype=self.compute_dtype
),
inputs,
)
switch_convert_input_args = False
if self._convert_input_args: