Support dict inputs in TFDataLayer, plus some lint fixes (#383)
This commit is contained in:
parent
2350e681eb
commit
d955292989
@ -34,9 +34,7 @@ def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
|
||||
|
||||
def categorical(logits, num_samples, dtype="int64", seed=None):
|
||||
seed = tf_draw_seed(seed)
|
||||
output = tf.random.stateless_categorical(
|
||||
logits, num_samples, seed=seed
|
||||
)
|
||||
output = tf.random.stateless_categorical(logits, num_samples, seed=seed)
|
||||
return tf.cast(output, dtype)
|
||||
|
||||
|
||||
|
@ -31,7 +31,10 @@ def categorical(logits, num_samples, dtype="int32", seed=None):
|
||||
dtype = to_torch_dtype(dtype)
|
||||
generator = torch_seed_generator(seed, device=get_device())
|
||||
return torch.multinomial(
|
||||
logits, num_samples, replacement=True, generator=generator,
|
||||
logits,
|
||||
num_samples,
|
||||
replacement=True,
|
||||
generator=generator,
|
||||
).type(dtype)
|
||||
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
from tensorflow import nest
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core.layers.layer import Layer
|
||||
from keras_core.utils import backend_utils
|
||||
@ -22,8 +24,11 @@ class TFDataLayer(Layer):
|
||||
):
|
||||
# We're in a TF graph, e.g. a tf.data pipeline.
|
||||
self.backend.set_backend("tensorflow")
|
||||
inputs = self.backend.convert_to_tensor(
|
||||
inputs, dtype=self.compute_dtype
|
||||
inputs = nest.map_structure(
|
||||
lambda x: self.backend.convert_to_tensor(
|
||||
x, dtype=self.compute_dtype
|
||||
),
|
||||
inputs,
|
||||
)
|
||||
switch_convert_input_args = False
|
||||
if self._convert_input_args:
|
||||
|
Loading…
Reference in New Issue
Block a user