Minor fixes
This commit is contained in:
parent
9e58d0d0fb
commit
8d63604975
@ -80,7 +80,7 @@ class KerasTensor:
|
||||
"class MyLayer(Layer):\n"
|
||||
" def call(self, x):\n"
|
||||
" return jax_fn(x)\n\n"
|
||||
"x = MyLayer()(x)"
|
||||
"x = MyLayer()(x)\n"
|
||||
"```\n"
|
||||
)
|
||||
|
||||
|
@ -10,6 +10,7 @@ from keras_core import optimizers as optimizers_module
|
||||
from keras_core.trainers import trainer as base_trainer
|
||||
from keras_core.trainers.data_adapters import data_adapter_utils
|
||||
from keras_core.trainers.epoch_iterator import EpochIterator
|
||||
from keras_core.utils import traceback_utils
|
||||
|
||||
|
||||
class TensorFlowTrainer(base_trainer.Trainer):
|
||||
@ -100,6 +101,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
|
||||
one_step_on_data, jit_compile=True, reduce_retracing=True
|
||||
)
|
||||
|
||||
@tf.autograph.experimental.do_not_convert
|
||||
def one_step_on_iterator(iterator):
|
||||
"""Runs a single training step given a Dataset iterator."""
|
||||
data = next(iterator)
|
||||
@ -113,8 +115,9 @@ class TensorFlowTrainer(base_trainer.Trainer):
|
||||
)
|
||||
return outputs
|
||||
|
||||
@tf.autograph.experimental.do_not_convert
|
||||
def multi_step_on_iterator(iterator):
|
||||
for _ in tf.range(self.steps_per_execution):
|
||||
for _ in range(self.steps_per_execution):
|
||||
outputs = one_step_on_iterator(iterator)
|
||||
return outputs
|
||||
|
||||
@ -142,6 +145,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
|
||||
one_step_on_data, jit_compile=True, reduce_retracing=True
|
||||
)
|
||||
|
||||
@tf.autograph.experimental.do_not_convert
|
||||
def one_step_on_iterator(iterator):
|
||||
"""Runs a single test step given a Dataset iterator."""
|
||||
data = next(iterator)
|
||||
@ -155,8 +159,9 @@ class TensorFlowTrainer(base_trainer.Trainer):
|
||||
)
|
||||
return outputs
|
||||
|
||||
@tf.autograph.experimental.do_not_convert
|
||||
def multi_step_on_iterator(iterator):
|
||||
for _ in tf.range(self.steps_per_execution):
|
||||
for _ in range(self.steps_per_execution):
|
||||
outputs = one_step_on_iterator(iterator)
|
||||
return outputs
|
||||
|
||||
@ -184,6 +189,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
|
||||
one_step_on_data, jit_compile=True, reduce_retracing=True
|
||||
)
|
||||
|
||||
@tf.autograph.experimental.do_not_convert
|
||||
def one_step_on_data_distributed(data):
|
||||
data = data[0]
|
||||
outputs = self.distribute_strategy.run(
|
||||
@ -196,6 +202,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
|
||||
)
|
||||
return outputs
|
||||
|
||||
@tf.autograph.experimental.do_not_convert
|
||||
def multi_step_on_data(data):
|
||||
outputs = one_step_on_data_distributed(data[:1])
|
||||
for single_step_data in data[1:]:
|
||||
@ -217,6 +224,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
|
||||
|
||||
self.predict_function = predict_function
|
||||
|
||||
@traceback_utils.filter_traceback
|
||||
def fit(
|
||||
self,
|
||||
x=None,
|
||||
@ -346,6 +354,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
|
||||
callbacks.on_train_end(logs=training_logs)
|
||||
return self.history
|
||||
|
||||
@traceback_utils.filter_traceback
|
||||
def evaluate(
|
||||
self,
|
||||
x=None,
|
||||
@ -405,6 +414,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
|
||||
return logs
|
||||
return self._flatten_metrics_in_order(logs)
|
||||
|
||||
@traceback_utils.filter_traceback
|
||||
def predict(
|
||||
self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
|
||||
):
|
||||
|
@ -9,6 +9,7 @@ from keras_core.saving import saving_api
|
||||
from keras_core.saving import saving_lib
|
||||
from keras_core.utils import io_utils
|
||||
from keras_core.utils import summary_utils
|
||||
from keras_core.utils import traceback_utils
|
||||
|
||||
if backend.backend() == "tensorflow":
|
||||
from keras_core.backend.tensorflow.trainer import (
|
||||
@ -164,6 +165,7 @@ class Model(Trainer, Layer):
|
||||
"Please use another name."
|
||||
)
|
||||
|
||||
@traceback_utils.filter_traceback
|
||||
def get_layer(self, name=None, index=None):
|
||||
"""Retrieves a layer based on either its name (unique) or index.
|
||||
|
||||
@ -204,6 +206,7 @@ class Model(Trainer, Layer):
|
||||
"Provide either a layer name or layer index at `get_layer`."
|
||||
)
|
||||
|
||||
@traceback_utils.filter_traceback
|
||||
def summary(
|
||||
self,
|
||||
line_length=None,
|
||||
@ -253,6 +256,7 @@ class Model(Trainer, Layer):
|
||||
layer_range=layer_range,
|
||||
)
|
||||
|
||||
@traceback_utils.filter_traceback
|
||||
def save(self, filepath, overwrite=True):
|
||||
if not str(filepath).endswith(".keras"):
|
||||
raise ValueError(
|
||||
@ -269,6 +273,7 @@ class Model(Trainer, Layer):
|
||||
return
|
||||
saving_lib.save_model(self, filepath)
|
||||
|
||||
@traceback_utils.filter_traceback
|
||||
def save_weights(self, filepath, overwrite=True):
|
||||
if not str(filepath).endswith(".weights.h5"):
|
||||
raise ValueError(
|
||||
@ -285,6 +290,7 @@ class Model(Trainer, Layer):
|
||||
return
|
||||
saving_lib.save_weights_only(self, filepath)
|
||||
|
||||
@traceback_utils.filter_traceback
|
||||
def load_weights(self, filepath, skip_mismatch=False, **kwargs):
|
||||
saving_api.load_weights(
|
||||
self, filepath, skip_mismatch=skip_mismatch, **kwargs
|
||||
@ -335,6 +341,7 @@ class Model(Trainer, Layer):
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
@traceback_utils.filter_traceback
|
||||
def export(self, filepath):
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -1,8 +1,6 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import keras_core
|
||||
from keras_core import testing
|
||||
from keras_core.operations import numpy as knp
|
||||
from keras_core.random import random
|
||||
@ -63,23 +61,3 @@ class RandomTest(testing.TestCase, parameterized.TestCase):
|
||||
x_res = random.dropout(x, rate=0.8, seed=0)
|
||||
self.assertGreater(knp.max(x_res), knp.max(x))
|
||||
self.assertGreater(knp.sum(x_res == 0), 2)
|
||||
|
||||
@pytest.mark.skipif(
|
||||
keras_core.backend.backend() != "jax",
|
||||
reason="This test requires `jax` as the backend.",
|
||||
)
|
||||
def test_dropout_jax_jit_stateless(self):
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
x = knp.ones(3)
|
||||
|
||||
@jax.jit
|
||||
def train_step(x):
|
||||
with keras_core.backend.StatelessScope():
|
||||
x = keras_core.layers.Dropout(rate=0.1)(x, training=True)
|
||||
return x
|
||||
|
||||
keras_core.utils.traceback_utils.disable_traceback_filtering()
|
||||
x = train_step(x)
|
||||
assert isinstance(x, jnp.ndarray)
|
||||
|
@ -16,10 +16,7 @@ class SeedGenerator:
|
||||
)
|
||||
|
||||
def seed_initializer(*args, **kwargs):
|
||||
from keras_core.backend import convert_to_tensor
|
||||
|
||||
dtype = kwargs.get("dtype", None)
|
||||
return convert_to_tensor([seed, 0], dtype=dtype)
|
||||
return [seed, 0]
|
||||
|
||||
self.state = Variable(
|
||||
seed_initializer,
|
||||
|
@ -176,25 +176,9 @@ class PyDatasetAdapter(DataAdapter):
|
||||
self.shuffle = shuffle
|
||||
|
||||
# Grab the first example
|
||||
data = self.py_dataset[0]
|
||||
if not isinstance(data, tuple):
|
||||
raise ValueError(
|
||||
"PyDataset.__getitem__() must return a tuple, either "
|
||||
"(input,) or (inputs, targets) or "
|
||||
"(inputs, targets, sample_weights). "
|
||||
f"Received: {data}"
|
||||
)
|
||||
if self.class_weight is not None:
|
||||
if len(data) == 3:
|
||||
raise ValueError(
|
||||
"You cannot `class_weight` and `sample_weight` "
|
||||
"at the same time."
|
||||
)
|
||||
if len(data) == 2:
|
||||
sw = data_adapter_utils.class_weight_to_sample_weights(
|
||||
data[1], class_weight
|
||||
)
|
||||
data = data + (sw,)
|
||||
batch = self.py_dataset[0]
|
||||
# Run checks on it and format it
|
||||
batch = self._standardize_batch(batch)
|
||||
|
||||
def get_tensor_spec(x):
|
||||
shape = x.shape
|
||||
@ -208,7 +192,32 @@ class PyDatasetAdapter(DataAdapter):
|
||||
shape[0] = None # The batch size is not guaranteed to be static.
|
||||
return tf.TensorSpec(shape=shape, dtype=x.dtype.name)
|
||||
|
||||
self._output_signature = tf.nest.map_structure(get_tensor_spec, data)
|
||||
self._output_signature = tf.nest.map_structure(get_tensor_spec, batch)
|
||||
|
||||
def _standardize_batch(self, batch):
|
||||
if isinstance(batch, np.ndarray):
|
||||
batch = (batch,)
|
||||
if isinstance(batch, list):
|
||||
batch = tuple(batch)
|
||||
if not isinstance(batch, tuple) or len(batch) not in {1, 2, 3}:
|
||||
raise ValueError(
|
||||
"PyDataset.__getitem__() must return a tuple, either "
|
||||
"(input,) or (inputs, targets) or "
|
||||
"(inputs, targets, sample_weights). "
|
||||
f"Received: {str(batch)[:100]}... of type {type(batch)}"
|
||||
)
|
||||
if self.class_weight is not None:
|
||||
if len(batch) == 3:
|
||||
raise ValueError(
|
||||
"You cannot specify `class_weight` "
|
||||
"and `sample_weight` at the same time."
|
||||
)
|
||||
if len(batch) == 2:
|
||||
sw = data_adapter_utils.class_weight_to_sample_weights(
|
||||
batch[1], class_weight
|
||||
)
|
||||
batch = batch + (sw,)
|
||||
return batch
|
||||
|
||||
def _make_multiprocessed_generator_fn(self):
|
||||
workers = self.py_dataset.workers
|
||||
@ -244,11 +253,7 @@ class PyDatasetAdapter(DataAdapter):
|
||||
def get_numpy_iterator(self):
|
||||
gen_fn = self._make_multiprocessed_generator_fn()
|
||||
for i, batch in enumerate(gen_fn()):
|
||||
if len(batch) == 2 and self.class_weight is not None:
|
||||
sw = data_adapter_utils.class_weight_to_sample_weights(
|
||||
batch[1], self.class_weight
|
||||
)
|
||||
batch = batch + (sw,)
|
||||
batch = self._standardize_batch(batch)
|
||||
yield batch
|
||||
if i >= len(self.py_dataset) - 1 and self.enqueuer:
|
||||
self.enqueuer.stop()
|
||||
@ -266,11 +271,11 @@ class PyDatasetAdapter(DataAdapter):
|
||||
def on_epoch_end(self):
|
||||
if self.enqueuer:
|
||||
self.enqueuer.stop()
|
||||
self._py_dataset.on_epoch_end()
|
||||
self.py_dataset.on_epoch_end()
|
||||
|
||||
@property
|
||||
def num_batches(self):
|
||||
return len(self._py_dataset)
|
||||
return len(self.py_dataset)
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
|
@ -7,6 +7,7 @@ from keras_core import optimizers
|
||||
from keras_core.saving import serialization_lib
|
||||
from keras_core.trainers.compile_utils import CompileLoss
|
||||
from keras_core.trainers.compile_utils import CompileMetrics
|
||||
from keras_core.utils import traceback_utils
|
||||
from keras_core.utils import tracking
|
||||
|
||||
|
||||
@ -18,6 +19,7 @@ class Trainer:
|
||||
self.compiled = False
|
||||
self.steps_per_execution = 1
|
||||
|
||||
@traceback_utils.filter_traceback
|
||||
@tracking.no_automatic_dependency_tracking
|
||||
def compile(
|
||||
self,
|
||||
|
Loading…
Reference in New Issue
Block a user