Minor fixes

Francois Chollet 2023-06-03 10:36:26 -07:00
parent 9e58d0d0fb
commit 8d63604975
7 changed files with 55 additions and 56 deletions

@ -80,7 +80,7 @@ class KerasTensor:
"class MyLayer(Layer):\n"
" def call(self, x):\n"
" return jax_fn(x)\n\n"
"x = MyLayer()(x)"
"x = MyLayer()(x)\n"
"```\n"
)

@ -10,6 +10,7 @@ from keras_core import optimizers as optimizers_module
from keras_core.trainers import trainer as base_trainer
from keras_core.trainers.data_adapters import data_adapter_utils
from keras_core.trainers.epoch_iterator import EpochIterator
from keras_core.utils import traceback_utils
class TensorFlowTrainer(base_trainer.Trainer):
@ -100,6 +101,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
one_step_on_data, jit_compile=True, reduce_retracing=True
)
@tf.autograph.experimental.do_not_convert
def one_step_on_iterator(iterator):
"""Runs a single training step given a Dataset iterator."""
data = next(iterator)
@ -113,8 +115,9 @@ class TensorFlowTrainer(base_trainer.Trainer):
)
return outputs
@tf.autograph.experimental.do_not_convert
def multi_step_on_iterator(iterator):
for _ in tf.range(self.steps_per_execution):
for _ in range(self.steps_per_execution):
outputs = one_step_on_iterator(iterator)
return outputs
@ -142,6 +145,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
one_step_on_data, jit_compile=True, reduce_retracing=True
)
@tf.autograph.experimental.do_not_convert
def one_step_on_iterator(iterator):
"""Runs a single test step given a Dataset iterator."""
data = next(iterator)
@ -155,8 +159,9 @@ class TensorFlowTrainer(base_trainer.Trainer):
)
return outputs
@tf.autograph.experimental.do_not_convert
def multi_step_on_iterator(iterator):
for _ in tf.range(self.steps_per_execution):
for _ in range(self.steps_per_execution):
outputs = one_step_on_iterator(iterator)
return outputs
@ -184,6 +189,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
one_step_on_data, jit_compile=True, reduce_retracing=True
)
@tf.autograph.experimental.do_not_convert
def one_step_on_data_distributed(data):
data = data[0]
outputs = self.distribute_strategy.run(
@ -196,6 +202,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
)
return outputs
@tf.autograph.experimental.do_not_convert
def multi_step_on_data(data):
outputs = one_step_on_data_distributed(data[:1])
for single_step_data in data[1:]:
@ -217,6 +224,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
self.predict_function = predict_function
@traceback_utils.filter_traceback
def fit(
self,
x=None,
@ -346,6 +354,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
callbacks.on_train_end(logs=training_logs)
return self.history
@traceback_utils.filter_traceback
def evaluate(
self,
x=None,
@ -405,6 +414,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
return logs
return self._flatten_metrics_in_order(logs)
@traceback_utils.filter_traceback
def predict(
self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
):
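
Two related changes run through this trainer file: each compiled step function gains `@tf.autograph.experimental.do_not_convert`, and the multi-step loops switch from `tf.range` to Python `range`. With AutoGraph disabled, the Python loop is unrolled while the `tf.function` is traced, so the graph contains `steps_per_execution` inlined copies of the step rather than a dynamic `tf.while_loop`. The public entry points (`fit`, `evaluate`, `predict`) also gain `@traceback_utils.filter_traceback`, so user-facing errors are re-raised with framework-internal frames filtered out. Below is a minimal, standalone sketch of the unrolling behaviour, assuming TensorFlow 2.x; it is not the Keras code, and `steps` and `multi_step` are illustrative stand-ins:

import tensorflow as tf

steps = 4  # stand-in for self.steps_per_execution

@tf.function(jit_compile=True, reduce_retracing=True)
@tf.autograph.experimental.do_not_convert
def multi_step(x):
    # AutoGraph is off, so this Python loop is unrolled at trace time into
    # four chained add ops instead of being rewritten as a tf.while_loop.
    for _ in range(steps):
        x = x + 1.0
    return x

print(multi_step(tf.constant(0.0)))  # tf.Tensor(4.0, shape=(), dtype=float32)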

@ -9,6 +9,7 @@ from keras_core.saving import saving_api
from keras_core.saving import saving_lib
from keras_core.utils import io_utils
from keras_core.utils import summary_utils
from keras_core.utils import traceback_utils
if backend.backend() == "tensorflow":
from keras_core.backend.tensorflow.trainer import (
@ -164,6 +165,7 @@ class Model(Trainer, Layer):
"Please use another name."
)
@traceback_utils.filter_traceback
def get_layer(self, name=None, index=None):
"""Retrieves a layer based on either its name (unique) or index.
@ -204,6 +206,7 @@ class Model(Trainer, Layer):
"Provide either a layer name or layer index at `get_layer`."
)
@traceback_utils.filter_traceback
def summary(
self,
line_length=None,
@ -253,6 +256,7 @@ class Model(Trainer, Layer):
layer_range=layer_range,
)
@traceback_utils.filter_traceback
def save(self, filepath, overwrite=True):
if not str(filepath).endswith(".keras"):
raise ValueError(
@ -269,6 +273,7 @@ class Model(Trainer, Layer):
return
saving_lib.save_model(self, filepath)
@traceback_utils.filter_traceback
def save_weights(self, filepath, overwrite=True):
if not str(filepath).endswith(".weights.h5"):
raise ValueError(
@ -285,6 +290,7 @@ class Model(Trainer, Layer):
return
saving_lib.save_weights_only(self, filepath)
@traceback_utils.filter_traceback
def load_weights(self, filepath, skip_mismatch=False, **kwargs):
saving_api.load_weights(
self, filepath, skip_mismatch=skip_mismatch, **kwargs
@ -335,6 +341,7 @@ class Model(Trainer, Layer):
stacklevel=2,
)
@traceback_utils.filter_traceback
def export(self, filepath):
raise NotImplementedError
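
The model-level methods (`get_layer`, `summary`, `save`, `save_weights`, `load_weights`, `export`) receive the same `@traceback_utils.filter_traceback` treatment as the trainer entry points. As a rough idea of what such a decorator does, here is a generic sketch (not the `keras_core.utils.traceback_utils` implementation): it re-raises exceptions after dropping frames whose file path points into the framework.

import functools
import types

def filter_traceback_sketch(fn, internal_marker="keras_core"):
    """Generic sketch: hide frames whose file path mentions the framework."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            kept = []
            tb = e.__traceback__
            while tb is not None:  # walk the traceback linked list
                if internal_marker not in tb.tb_frame.f_code.co_filename:
                    kept.append(tb)
                tb = tb.tb_next
            filtered = None
            for tb in reversed(kept):  # rebuild the chain from the kept frames
                filtered = types.TracebackType(
                    filtered, tb.tb_frame, tb.tb_lasti, tb.tb_lineno
                )
            raise e.with_traceback(filtered or e.__traceback__)
    return wrapper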

@ -1,8 +1,6 @@
import numpy as np
import pytest
from absl.testing import parameterized
import keras_core
from keras_core import testing
from keras_core.operations import numpy as knp
from keras_core.random import random
@ -63,23 +61,3 @@ class RandomTest(testing.TestCase, parameterized.TestCase):
x_res = random.dropout(x, rate=0.8, seed=0)
self.assertGreater(knp.max(x_res), knp.max(x))
self.assertGreater(knp.sum(x_res == 0), 2)
@pytest.mark.skipif(
keras_core.backend.backend() != "jax",
reason="This test requires `jax` as the backend.",
)
def test_dropout_jax_jit_stateless(self):
import jax
import jax.numpy as jnp
x = knp.ones(3)
@jax.jit
def train_step(x):
with keras_core.backend.StatelessScope():
x = keras_core.layers.Dropout(rate=0.1)(x, training=True)
return x
keras_core.utils.traceback_utils.disable_traceback_filtering()
x = train_step(x)
assert isinstance(x, jnp.ndarray)

@ -16,10 +16,7 @@ class SeedGenerator:
)
def seed_initializer(*args, **kwargs):
from keras_core.backend import convert_to_tensor
dtype = kwargs.get("dtype", None)
return convert_to_tensor([seed, 0], dtype=dtype)
return [seed, 0]
self.state = Variable(
seed_initializer,
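
The seed initializer now returns a plain `[seed, 0]` list and lets `Variable` handle any conversion, instead of importing `convert_to_tensor` itself. The two-element state follows the usual `[seed, counter]` pattern: each draw reads the state and bumps the counter so successive calls produce different but reproducible streams. A toy illustration of that pattern (not keras_core's SeedGenerator; the NumPy-based seeding below is purely illustrative):

import numpy as np

class TinySeedState:
    def __init__(self, seed):
        self.state = [seed, 0]  # mirrors the `[seed, 0]` initializer above

    def draw(self, size):
        seed, counter = self.state
        self.state = [seed, counter + 1]  # bump the counter for the next call
        rng = np.random.default_rng([seed, counter])  # derive a per-call stream
        return rng.random(size)

gen = TinySeedState(1337)
print(gen.draw(2), gen.draw(2))  # two different, reproducible draws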

@ -176,25 +176,9 @@ class PyDatasetAdapter(DataAdapter):
self.shuffle = shuffle
# Grab the first example
data = self.py_dataset[0]
if not isinstance(data, tuple):
raise ValueError(
"PyDataset.__getitem__() must return a tuple, either "
"(input,) or (inputs, targets) or "
"(inputs, targets, sample_weights). "
f"Received: {data}"
)
if self.class_weight is not None:
if len(data) == 3:
raise ValueError(
"You cannot `class_weight` and `sample_weight` "
"at the same time."
)
if len(data) == 2:
sw = data_adapter_utils.class_weight_to_sample_weights(
data[1], class_weight
)
data = data + (sw,)
batch = self.py_dataset[0]
# Run checks on it and format it
batch = self._standardize_batch(batch)
def get_tensor_spec(x):
shape = x.shape
@ -208,7 +192,32 @@ class PyDatasetAdapter(DataAdapter):
shape[0] = None # The batch size is not guaranteed to be static.
return tf.TensorSpec(shape=shape, dtype=x.dtype.name)
self._output_signature = tf.nest.map_structure(get_tensor_spec, data)
self._output_signature = tf.nest.map_structure(get_tensor_spec, batch)
def _standardize_batch(self, batch):
if isinstance(batch, np.ndarray):
batch = (batch,)
if isinstance(batch, list):
batch = tuple(batch)
if not isinstance(batch, tuple) or len(batch) not in {1, 2, 3}:
raise ValueError(
"PyDataset.__getitem__() must return a tuple, either "
"(input,) or (inputs, targets) or "
"(inputs, targets, sample_weights). "
f"Received: {str(batch)[:100]}... of type {type(batch)}"
)
if self.class_weight is not None:
if len(batch) == 3:
raise ValueError(
"You cannot specify `class_weight` "
"and `sample_weight` at the same time."
)
if len(batch) == 2:
sw = data_adapter_utils.class_weight_to_sample_weights(
batch[1], class_weight
)
batch = batch + (sw,)
return batch
def _make_multiprocessed_generator_fn(self):
workers = self.py_dataset.workers
@ -244,11 +253,7 @@ class PyDatasetAdapter(DataAdapter):
def get_numpy_iterator(self):
gen_fn = self._make_multiprocessed_generator_fn()
for i, batch in enumerate(gen_fn()):
if len(batch) == 2 and self.class_weight is not None:
sw = data_adapter_utils.class_weight_to_sample_weights(
batch[1], self.class_weight
)
batch = batch + (sw,)
batch = self._standardize_batch(batch)
yield batch
if i >= len(self.py_dataset) - 1 and self.enqueuer:
self.enqueuer.stop()
@ -266,11 +271,11 @@ class PyDatasetAdapter(DataAdapter):
def on_epoch_end(self):
if self.enqueuer:
self.enqueuer.stop()
self._py_dataset.on_epoch_end()
self.py_dataset.on_epoch_end()
@property
def num_batches(self):
return len(self._py_dataset)
return len(self.py_dataset)
@property
def batch_size(self):
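
The per-batch validation and `class_weight` handling that used to be duplicated between the constructor and `get_numpy_iterator` now lives in a single `_standardize_batch` helper, and the adapter consistently uses the public `py_dataset` attribute instead of the older `_py_dataset` name. Roughly, the normalization behaves as in the standalone paraphrase below (simplified; `to_sample_weights` stands in for `data_adapter_utils.class_weight_to_sample_weights`, and the class_weight/sample_weight conflict check is omitted):

import numpy as np

def to_sample_weights(y, class_weight):
    # Stand-in: map integer class labels to per-sample weights.
    return np.array([class_weight[int(c)] for c in np.asarray(y).reshape(-1)])

def standardize_batch(batch, class_weight=None):
    if isinstance(batch, np.ndarray):
        batch = (batch,)            # a bare array becomes a 1-tuple
    if isinstance(batch, list):
        batch = tuple(batch)        # lists are accepted and tupled
    if not isinstance(batch, tuple) or len(batch) not in {1, 2, 3}:
        raise ValueError(
            "Expected (x,), (x, y) or (x, y, sample_weights), got "
            f"{str(batch)[:100]} of type {type(batch)}"
        )
    if class_weight is not None and len(batch) == 2:
        # Derive sample weights from class weights when only (x, y) is given.
        batch = batch + (to_sample_weights(batch[1], class_weight),)
    return batch

x, y = np.zeros((4, 2)), np.array([0, 1, 1, 0])
print(standardize_batch([x, y], class_weight={0: 1.0, 1: 2.0})[2])  # [1. 2. 2. 1.]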

@ -7,6 +7,7 @@ from keras_core import optimizers
from keras_core.saving import serialization_lib
from keras_core.trainers.compile_utils import CompileLoss
from keras_core.trainers.compile_utils import CompileMetrics
from keras_core.utils import traceback_utils
from keras_core.utils import tracking
@ -18,6 +19,7 @@ class Trainer:
self.compiled = False
self.steps_per_execution = 1
@traceback_utils.filter_traceback
@tracking.no_automatic_dependency_tracking
def compile(
self,