Merge branch 'main' of github.com:keras-team/keras-core
parent a8c426fc59
commit e82672951a
@@ -28,7 +28,10 @@ class KerasVariable:
raise NotImplementedError

def __repr__(self):
return f"<KerasVariable shape={self.shape}, dtype={self.dtype}, name={self.name}>"
return (
f"<KerasVariable shape={self.shape}, dtype={self.dtype}, "
"name={self.name}>"
)


ALLOWED_DTYPES = {

@@ -70,10 +73,12 @@ def standardize_shape(shape, fully_defined=False):
continue
if not isinstance(e, int):
raise ValueError(
f"Cannot convert '{shape}' to a shape. Found invalid entry '{e}'"
f"Cannot convert '{shape}' to a shape. "
f"Found invalid entry '{e}'"
)
if e < 0:
raise ValueError(
f"Cannot convert '{shape}' to a shape. Negative dimensions are not allowed."
f"Cannot convert '{shape}' to a shape. "
"Negative dimensions are not allowed."
)
return shape
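An aside on the `standardize_shape` checks above, for illustration only (the import path and the handling of `None` entries are assumed, they are not fully shown in this hunk):

    from keras_core.backend import standardize_shape  # assumed import path

    standardize_shape((None, 32, 3))  # returns the shape; None entries presumably skipped via `continue`
    standardize_shape((32, "a"))      # ValueError: invalid entry 'a'
    standardize_shape((32, -1))       # ValueError: negative dimensions are not allowed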
@@ -214,7 +214,8 @@ class JAXTrainer(base_trainer.Trainer):
break

# Update variable values
# NOTE: doing this after each step would be a big performance bottleneck.
# NOTE: doing this after each step would be a big performance
# bottleneck.
for ref_v, v in zip(self.trainable_variables, trainable_variables):
ref_v.assign(v)
for ref_v, v in zip(
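The hunk above reflects that the JAX trainer writes values back into the Keras variables only once the loop is finished, since assigning after every step would be the bottleneck called out in the comment. A rough sketch of that pattern (the wrapped `zip(...)` call is cut off by the hunk; the non-trainable handling below is an assumption):

    # After the epoch loop: sync the final JAX values back into the model in one pass.
    for ref_v, v in zip(self.trainable_variables, trainable_variables):
        ref_v.assign(v)
    for ref_v, v in zip(self.non_trainable_variables, non_trainable_variables):
        ref_v.assign(v)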
@@ -399,7 +400,8 @@ class JAXTrainer(base_trainer.Trainer):
metrics_variables,
)
logs, state = test_step(state, data)
# Note that trainable variables are not returned since they're immutable here.
# Note that trainable variables are not returned since they're
# immutable here.
non_trainable_variables, metrics_variables = state

callbacks.on_test_batch_end(step, logs)

@@ -434,7 +436,7 @@ class JAXTrainer(base_trainer.Trainer):
# Build the model on one batch of data.
for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
# Build model
y_pred = self(data)
self(data)
break

# Container that configures and calls callbacks.

@@ -21,7 +21,10 @@ class KerasTensor:
return operations.Cast(dtype=dtype)(self)

def __repr__(self):
return f"<KerasTensor shape={self.shape}, dtype={self.dtype}, name={self.name}>"
return (
f"<KerasTensor shape={self.shape}, dtype={self.dtype}, "
"name={self.name}>"
)

def __iter__(self):
raise NotImplementedError(

@@ -28,7 +28,8 @@ class StatelessScope:
"Invalid variable value in VariableSwapScope: "
"all values in argument `mapping` must be tensors with "
"a shape that matches the corresponding variable shape. "
f"For variable {k}, received invalid value {v} with shape {v.shape}."
f"For variable {k}, received invalid value {v} with shape "
f"{v.shape}."
)
self.state_mapping[id(k)] = v

@@ -58,7 +58,7 @@ class Variable(KerasVariable, tf.__internal__.types.Tensor):
def ndim(self):
return self.value.ndim

def numpy(self):
def numpy(self): # noqa: F811
return self.value.numpy()

# Overload native accessor.

@@ -26,8 +26,8 @@ class CallbackList(Callback):

Args:
callbacks: List of `Callback` instances.
add_history: Whether a `History` callback should be added, if one does
not already exist in the `callbacks` list.
add_history: Whether a `History` callback should be added, if one
does not already exist in the `callbacks` list.
add_progbar: Whether a `ProgbarLogger` callback should be added, if
one does not already exist in the `callbacks` list.
model: The `Model` these callbacks are used with.
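For illustration only, a construction sketch that matches the documented arguments above (the surrounding variables `user_callbacks`, `model`, `step`, and `logs` are assumed):

    callbacks = CallbackList(
        callbacks=user_callbacks,  # list of Callback instances
        add_history=True,          # add a History callback if none is present
        add_progbar=True,          # add a ProgbarLogger callback if none is present
        model=model,
    )
    callbacks.on_test_batch_end(step, logs)  # forwarded to every callback in the list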
@@ -57,9 +57,9 @@ class Zeros(Initializer):
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only numeric or boolean dtypes
are supported. If not specified, `keras_core.backend.floatx()` is
used, which default to `float32` unless you configured it otherwise
(via `keras_core.backend.set_floatx(float_dtype)`).
are supported. If not specified, `keras_core.backend.floatx()`
is used, which default to `float32` unless you configured it
otherwise (via `keras_core.backend.set_floatx(float_dtype)`).
**kwargs: Additional keyword arguments.
"""
dtype = standardize_dtype(dtype)
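A usage sketch consistent with the docstring above (standard Keras initializer call semantics assumed):

    from keras_core import initializers

    init = initializers.Zeros()
    values = init(shape=(2, 3))                       # dtype defaults to keras_core.backend.floatx()
    values_f16 = init(shape=(2, 3), dtype="float16")  # explicit dtype override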
@@ -89,9 +89,9 @@ class Ones(Initializer):
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only numeric or boolean dtypes
are supported. If not specified, `keras_core.backend.floatx()` is
used, which default to `float32` unless you configured it otherwise
(via `keras_core.backend.set_floatx(float_dtype)`).
are supported. If not specified, `keras_core.backend.floatx()`
is used, which default to `float32` unless you configured it
otherwise (via `keras_core.backend.set_floatx(float_dtype)`).
**kwargs: Additional keyword arguments.
"""
dtype = standardize_dtype(dtype)

@@ -41,8 +41,8 @@ class Initializer:

Note that we don't have to implement `from_config()` in the example above
since the constructor arguments of the class the keys in the config returned
by `get_config()` are the same. In this case, the default `from_config()` works
fine.
by `get_config()` are the same. In this case, the default `from_config()`
works fine.
"""
def __call__(self, shape, dtype=None):
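To make the `from_config()` note above concrete: a minimal hypothetical initializer whose `get_config()` keys mirror its constructor arguments, so the default `from_config()` can rebuild it (the `ops.full` helper is an assumption):

    class ConstantOf(Initializer):  # hypothetical example, not part of this commit
        def __init__(self, value=0.0):
            self.value = value

        def __call__(self, shape, dtype=None):
            return ops.full(shape, self.value, dtype=dtype)  # assumed full-like op

        def get_config(self):
            return {"value": self.value}  # same key as the __init__ argument

    # ConstantOf.from_config(ConstantOf(3.0).get_config()) then works unchanged.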
@@ -17,11 +17,13 @@ class InputLayer(Layer):
super().__init__(name=name)
if shape is not None and batch_shape is not None:
raise ValueError(
"You cannot pass both `shape` and `batch_shape` at the same time."
"You cannot pass both `shape` and `batch_shape` at the "
"same time."
)
if batch_size is not None and batch_shape is not None:
raise ValueError(
"You cannot pass both `batch_size` and `batch_shape` at the same time."
"You cannot pass both `batch_size` and `batch_shape` at the "
"same time."
)
if shape is None and batch_shape is None:
raise ValueError("You must pass a `shape` argument.")
@@ -36,7 +38,8 @@ class InputLayer(Layer):
if not isinstance(input_tensor, backend.KerasTensor):
raise ValueError(
"Argument `input_tensor` must be a KerasTensor. "
f"Received invalid type: input_tensor={input_tensor} (of type {type(input_tensor)})"
f"Received invalid type: input_tensor={input_tensor} "
f"(of type {type(input_tensor)})"
)
else:
input_tensor = backend.KerasTensor(

@@ -7,8 +7,8 @@ from keras_core.api_export import keras_core_export
class Dropout(layers.Layer):
"""Applies dropout to the input.

The `Dropout` layer randomly sets input units to 0 with a frequency of `rate`
at each step during training time, which helps prevent overfitting.
The `Dropout` layer randomly sets input units to 0 with a frequency of
`rate` at each step during training time, which helps prevent overfitting.
Inputs not set to 0 are scaled up by `1 / (1 - rate)` such that the sum over
all inputs is unchanged.
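A small numeric illustration of the `1 / (1 - rate)` scaling described above (plain NumPy; not the layer's actual implementation or RNG):

    import numpy as np

    rate = 0.5
    x = np.array([1.0, 2.0, 3.0, 4.0])
    mask = np.array([1.0, 0.0, 1.0, 0.0])  # units kept at this step (illustrative)
    y = x * mask / (1.0 - rate)            # survivors are scaled by 2
    # x.sum() == 10.0; averaged over random masks, y.sum() stays 10.0 in expectation.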
@@ -4,6 +4,10 @@ from keras_core.losses.losses import LossFunctionWrapper
from keras_core.losses.losses import MeanSquaredError


def deserialize(obj):
raise NotImplementedError


@keras_core_export("keras_core.losses.get")
def get(identifier):
"""Retrieves a Keras loss as a `function`/`Loss` class instance.

@@ -29,8 +33,8 @@ def get(identifier):

Args:
identifier: A loss identifier. One of None or string name of a loss
function/class or loss configuration dictionary or a loss function or a
loss class instance.
function/class or loss configuration dictionary or a loss function
or a loss class instance.

Returns:
A Keras loss as a `function`/ `Loss` class instance.
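A usage sketch based on the docstring above (standard Keras `get()` semantics assumed; note that in this commit the config-dict path still goes through the `deserialize` stub, which raises NotImplementedError):

    from keras_core import losses

    losses.get("mean_squared_error")       # string name resolved to the MSE loss
    losses.get(losses.MeanSquaredError())  # an instance is presumably returned as-is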
@@ -54,7 +54,7 @@ class Loss:

def standardize_reduction(reduction):
allowed = {"sum_over_batch_size", "sum", None}
if not reduction in allowed:
if reduction not in allowed:
raise ValueError(
"Invalid value for argument `reduction`. "
f"Expected on of {allowed}. Received: "

@@ -106,7 +106,8 @@ class LossTest(testing.TestCase):
loss,
)

# @testing.parametrize("uprank", ["mask", "sample_weight", "y_true", "y_pred"])
# @testing.parametrize(
# "uprank", ["mask", "sample_weight", "y_true", "y_pred"])
# TODO: use parameterization decorator
def test_rank_adjustment(self):
for uprank in ["mask", "sample_weight", "ys"]:
@@ -6,6 +6,10 @@ from keras_core.metrics.reduction_metrics import Sum
from keras_core.metrics.regression_metrics import MeanSquaredError


def deserialize(obj):
raise NotImplementedError


@keras_core_export("keras_core.metrics.get")
def get(identifier):
"""Retrieves a Keras metric as a `function`/`Metric` class instance.

@@ -31,8 +35,8 @@ def get(identifier):

Args:
identifier: A metric identifier. One of None or string name of a metric
function/class or metric configuration dictionary or a metric function
or a metric class instance
function/class or metric configuration dictionary or a metric
function or a metric class instance

Returns:
A Keras metric as a `function`/ `Metric` class instance.

@@ -19,7 +19,7 @@ class MeanSquaredErrorTest(testing.TestCase):
[[0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]]
)

update_op = mse_obj.update_state(y_true, y_pred)
mse_obj.update_state(y_true, y_pred)
result = mse_obj.result()
self.assertAllClose(0.5, result, atol=1e-5)
@@ -148,7 +148,6 @@ class Functional(Function, Model):

def _adjust_input_rank(self, flat_inputs):
flat_ref_shapes = [x.shape for x in self._inputs]
names = [x.name for x in self._inputs]
adjusted = []
for x, ref_shape in zip(flat_inputs, flat_ref_shapes):
x_rank = len(x.shape)

@@ -334,8 +334,9 @@ def max_pool(
dimensions only.
pool_size: int or tuple/list of integers of size
`len(inputs_spatial_shape)`, specifying the size of the pooling
window for each spatial dimension of the input tensor. If `pool_size`
is int, then every spatial dimension shares the same `pool_size`.
window for each spatial dimension of the input tensor. If
`pool_size` is int, then every spatial dimension shares the same
`pool_size`.
strides: int or tuple/list of integers of size
`len(inputs_spatial_shape)`. The stride of the sliding window for
each spatial dimension of the input tensor. If `strides` is int,

@@ -435,8 +436,9 @@ def average_pool(
dimensions only.
pool_size: int or tuple/list of integers of size
`len(inputs_spatial_shape)`, specifying the size of the pooling
window for each spatial dimension of the input tensor. If `pool_size`
is int, then every spatial dimension shares the same `pool_size`.
window for each spatial dimension of the input tensor. If
`pool_size` is int, then every spatial dimension shares the same
`pool_size`.
strides: int or tuple/list of integers of size
`len(inputs_spatial_shape)`. The stride of the sliding window for
each spatial dimension of the input tensor. If `strides` is int,

@@ -511,7 +513,8 @@ class Conv(Operation):
raise ValueError(
"Dilation must be None, scalar or tuple/list of length of "
"inputs' spatial shape, but received "
f"`dilation_rate={self.dilation_rate}` and input of shape {input_shape}."
f"`dilation_rate={self.dilation_rate}` and "
f"input of shape {input_shape}."
)
spatial_shape = np.array(spatial_shape)
kernel_spatial_shape = np.array(kernel.shape[:-2])

@@ -651,7 +654,8 @@ class DepthwiseConv(Operation):
raise ValueError(
"Dilation must be None, scalar or tuple/list of length of "
"inputs' spatial shape, but received "
f"`dilation_rate={self.dilation_rate}` and input of shape {input_shape}."
f"`dilation_rate={self.dilation_rate}` and input of "
f"shape {input_shape}."
)
spatial_shape = np.array(spatial_shape)
kernel_spatial_shape = np.array(kernel.shape[:-2])
@@ -281,6 +281,10 @@ class NNOpsDynamicShapeTest(testing.TestCase):
)


@pytest.mark.skipif(
backend() != "tensorflow",
reason="Not have other backend support yet.",
)
class NNOpsStaticShapeTest(testing.TestCase):
def test_relu(self):
x = KerasTensor([1, 2, 3])

@@ -543,6 +547,10 @@ class NNOpsStaticShapeTest(testing.TestCase):
)


@pytest.mark.skipif(
backend() != "tensorflow",
reason="Not have other backend support yet.",
)
class NNOpsCorrectnessTest(testing.TestCase):
def test_relu(self):
x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)
@@ -10,9 +10,9 @@ class Node:
"""A `Node` describes an operation `__call__()` event.

A Keras Function is a DAG with `Node` instances as nodes, and
`KerasTensor` instances as edges. Nodes aren't `Operation` instances, because a
single operation could be called multiple times, which would result in graph
cycles.
`KerasTensor` instances as edges. Nodes aren't `Operation` instances,
because a single operation could be called multiple times, which would
result in graph cycles.

A `__call__()` event involves input tensors (and other input arguments),
the operation that was called, and the resulting output tensors.
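To make the bookkeeping described above concrete, a rough traversal sketch (attribute names taken from the docstrings in this diff; everything else is assumed):

    history = output_tensor._keras_history        # KerasHistory(operation, node_index, tensor_index)
    op = history.operation                        # the Operation that produced the tensor
    node = op._inbound_nodes[history.node_index]  # the Node for that specific __call__() event
    parents = node.parent_nodes                   # Nodes whose outputs this node consumed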
@@ -89,7 +89,11 @@ class Node:

@property
def parent_nodes(self):
"""Returns all the `Node`s whose output this node immediately depends on."""
"""The parent `Node`s.

Returns:
all the `Node`s whose output this node immediately depends on.
"""
node_deps = []
for kt in self.arguments.keras_tensors:
op = kt._keras_history.operation

@@ -116,9 +120,9 @@ class KerasHistory(
operation: The Operation instance that produced the Tensor.
node_index: The specific call to the Operation that produced this Tensor.
Operations can be called multiple times in order to share weights. A new
node is created every time an Operation is called. The corresponding node
that represents the call event that produced the Tensor can be found at
`op._inbound_nodes[node_index]`.
node is created every time an Operation is called. The corresponding
node that represents the call event that produced the Tensor can be
found at `op._inbound_nodes[node_index]`.
tensor_index: The output index for this Tensor.
Always zero if the Operation that produced this Tensor
only has one output. Nested structures of

@@ -1716,7 +1716,8 @@ class Meshgrid(Operation):
super().__init__()
if indexing not in ("xy", "ij"):
raise ValueError(
"Valid values for `indexing` are 'xy' and 'ij', but received {index}."
"Valid values for `indexing` are 'xy' and 'ij', "
"but received {index}."
)
self.indexing = indexing
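For reference, the 'xy' vs 'ij' choice mirrors NumPy's `np.meshgrid` convention, which this op presumably follows:

    import numpy as np

    x = np.array([1, 2, 3])
    y = np.array([10, 20])
    xx, yy = np.meshgrid(x, y, indexing="xy")  # cartesian: outputs have shape (2, 3)
    xi, yi = np.meshgrid(x, y, indexing="ij")  # matrix: outputs have shape (3, 2)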
@@ -37,7 +37,8 @@ class Operation:
return backend.compute_output_spec(self.call, *args, **kwargs)
except Exception as e:
raise RuntimeError(
"Could not automatically infer the output shape / dtype of this operation. "
"Could not automatically infer the output shape / dtype of "
"this operation. "
"Please implement the `compute_output_spec` method "
f"on your object ({self.__class__.__name__}). "
f"Error encountered: {e}"

@@ -287,7 +287,7 @@ class Optimizer:
def _filter_empty_gradients(self, grads_and_vars):
filtered = [(g, v) for g, v in grads_and_vars if g is not None]
if not filtered:
raise ValueError(f"No gradients provided for any variable.")
raise ValueError("No gradients provided for any variable.")
if len(filtered) < len(grads_and_vars):
missing_grad_vars = [v for g, v in grads_and_vars if g is None]
warnings.warn(
@@ -30,5 +30,5 @@ class RegularizersTest(testing.TestCase):
def test_orthogonal_regularizer(self):
value = np.random.random((4, 4))
x = backend.Variable(value)
y = regularizers.OrthogonalRegularizer(factor=0.1, mode="rows")(x)
regularizers.OrthogonalRegularizer(factor=0.1, mode="rows")(x)
# TODO

@@ -261,7 +261,8 @@ def serialize_with_public_class(cls, inner_config=None):
Keras API or has been registered as serializable via
`keras_core.saving.register_keras_serializable()`.
"""
# This gets the `keras_core.*` exported name, such as "keras_core.optimizers.Adam".
# This gets the `keras_core.*` exported name, such as
# "keras_core.optimizers.Adam".
keras_api_name = api_export.get_name_from_symbol(cls)

# Case of custom or unknown class object

@@ -293,8 +294,8 @@ def serialize_with_public_fn(fn, config, fn_module_name=None):

Called to check and retrieve the config of any function that has a public
Keras API or has been registered as serializable via
`keras_core.saving.register_keras_serializable()`. If function's module name is
already known, returns corresponding config.
`keras_core.saving.register_keras_serializable()`. If function's module name
is already known, returns corresponding config.
"""
if fn_module_name:
return {

@@ -373,9 +374,9 @@ def deserialize_keras_object(
- `module`: String. The path of the python module. Built-in Keras classes
expect to have prefix `keras_core`.
- `registered_name`: String. The key the class is registered under via
`keras_core.saving.register_keras_serializable(package, name)` API. The key has
the format of '{package}>{name}', where `package` and `name` are the
arguments passed to `register_keras_serializable()`. If `name` is not
`keras_core.saving.register_keras_serializable(package, name)` API. The
key has the format of '{package}>{name}', where `package` and `name` are
the arguments passed to `register_keras_serializable()`. If `name` is not
provided, it uses the class name. If `registered_name` successfully
resolves to a class (that was registered), the `class_name` and `config`
values in the dict will not be used. `registered_name` is only used for
@@ -410,7 +410,7 @@ class CompileLoss(losses_module.Loss):
raise ValueError(
f"When there is only a single output, the `loss_weights` argument "
"must be a Python float. "
f"Received instead:\loss_weights={loss_weights} of type {type(loss_weights)}"
f"Received instead: loss_weights={loss_weights} of type {type(loss_weights)}"
)
flat_loss_weights.append(loss_weights)
else:

@@ -447,7 +447,7 @@ class CompileLoss(losses_module.Loss):
"For a model with multiple outputs, "
f"when providing the `loss_weights` argument as a list, "
"it should have as many entries as the model has outputs. "
f"Received:\loss_weights={loss_weights}\nof length {len(loss_weights)} "
f"Received: loss_weights={loss_weights} of length {len(loss_weights)} "
f"whereas the model has {len(y_pred)} outputs."
)
if not all(isinstance(e, float) for e in loss_weights):

@@ -503,7 +503,7 @@ class CompileLoss(losses_module.Loss):
raise ValueError(
f"In the dict argument `loss_weights`, key "
f"'{name}' does not correspond to any model output. "
f"Received:\loss_weights={loss_weights}"
f"Received: loss_weights={loss_weights}"
)
if not isinstance(loss_weights[name], float):
raise ValueError(
@@ -88,7 +88,8 @@ class TestArrayDataAdapter(testing.TestCase):
self.assertEqual(len(batch), 3)
bx, by, bw = batch
self.assertTrue(isinstance(bx, dict))
# NOTE: the y list was converted to a tuple for tf.data compatibility.
# NOTE: the y list was converted to a tuple for tf.data
# compatibility.
self.assertTrue(isinstance(by, tuple))
self.assertTrue(isinstance(bw, tuple))

@@ -118,7 +119,8 @@ class TestArrayDataAdapter(testing.TestCase):
self.assertEqual(len(batch), 3)
bx, by, bw = batch
self.assertTrue(isinstance(bx, dict))
# NOTE: the y list was converted to a tuple for tf.data compatibility.
# NOTE: the y list was converted to a tuple for tf.data
# compatibility.
self.assertTrue(isinstance(by, tuple))
self.assertTrue(isinstance(bw, tuple))

@@ -23,8 +23,8 @@ class DataAdapter(object):

Returns:
A `tf.data.Dataset`. Caller might use the dataset in different
context, e.g. iter(dataset) in eager to get the value directly, or in
graph mode, provide the iterator tensor to Keras model function.
context, e.g. iter(dataset) in eager to get the value directly, or
in graph mode, provide the iterator tensor to Keras model function.
"""
raise NotImplementedError

@@ -38,9 +38,9 @@ class DataAdapter(object):
may or may not have an end state.

Returns:
int, the number of batches for the dataset, or None if it is unknown.
The caller could use this to control the loop of training, show
progress bar, or handle unexpected StopIteration error.
int, the number of batches for the dataset, or None if it is
unknown. The caller could use this to control the loop of training,
show progress bar, or handle unexpected StopIteration error.
"""
raise NotImplementedError
@@ -141,10 +141,10 @@ def train_validation_split(arrays, validation_split):
The last part of data will become validation data.

Args:
arrays: Tensors to split. Allowed inputs are arbitrarily nested structures
of Tensors and NumPy arrays.
validation_split: Float between 0 and 1. The proportion of the dataset to
include in the validation split. The rest of the dataset will be
arrays: Tensors to split. Allowed inputs are arbitrarily nested
structures of Tensors and NumPy arrays.
validation_split: Float between 0 and 1. The proportion of the dataset
to include in the validation split. The rest of the dataset will be
included in the training split.

Returns:

@@ -158,7 +158,8 @@ def train_validation_split(arrays, validation_split):
unsplitable = [type(t) for t in flat_arrays if not _can_split(t)]
if unsplitable:
raise ValueError(
"Argument `validation_split` is only supported for tf.Tensors or NumPy "
"Argument `validation_split` is only supported "
"for tf.Tensors or NumPy "
"arrays. Found incompatible type in the input: {unsplitable}"
)
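A usage sketch based on the signature and docstring above (the return structure is not shown in this hunk and is assumed to be a train/validation pair):

    import numpy as np

    x = np.arange(100).reshape(50, 2)
    y = np.arange(50)
    # Hypothetical call: the last 20% of samples become validation data.
    (x_train, y_train), (x_val, y_val) = train_validation_split((x, y), validation_split=0.2)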
@@ -10,7 +10,8 @@ class TFDatasetAdapter(DataAdapter):
def __init__(self, dataset, class_weight=None):
if not isinstance(dataset, tf.data.Dataset):
raise ValueError(
f"Expected argument `dataset` to be a tf.data.Dataset. Received: {dataset}"
"Expected argument `dataset` to be a tf.data.Dataset. "
f"Received: {dataset}"
)
if class_weight is not None:
dataset = dataset.map(

@@ -72,7 +73,8 @@ def make_class_weight_map_fn(class_weight):
x, y, sw = data_adapter_utils.unpack_x_y_sample_weight(data)
if sw is not None:
raise ValueError(
"You cannot `class_weight` and `sample_weight` at the same time."
"You cannot `class_weight` and `sample_weight` "
"at the same time."
)
if tf.nest.is_nested(y):
raise ValueError(

@@ -138,8 +138,8 @@ class Trainer:
loss = self._compile_loss(y, y_pred, sample_weight)
if loss is not None:
losses.append(loss)
for l in self.losses:
losses.append(ops.cast(l, dtype=backend.floatx()))
for loss in self.losses:
losses.append(ops.cast(loss, dtype=backend.floatx()))
if len(losses) == 0:
raise ValueError(
"No loss to compute. Provide a `loss` argument in `compile()`."

@@ -74,7 +74,8 @@ class TestTrainer(testing.TestCase):
# Fit the model to make sure compile_metrics are built
model.fit(x, y, batch_size=2, epochs=1)

# The model should have 3 metrics: loss_tracker, compile_metrics, my_metric
# The model should have 3 metrics: loss_tracker, compile_metrics,
# my_metric.
self.assertEqual(len(model.metrics), 3)
self.assertEqual(model.metrics[0], model._loss_tracker)
self.assertEqual(model.metrics[1], model.my_metric)
@@ -42,7 +42,8 @@ class Tracker:
self.tracker = Tracker(
# Format: `name: (test_fn, store)`
{
"variables": (lambda x: isinstance(x, Variable), self._variables),
"variables":
(lambda x: isinstance(x, Variable), self._variables),
"metrics": (lambda x: isinstance(x, Metric), self._metrics),
"layers": (lambda x: isinstance(x, Layer), self._layers),
}

@@ -1,2 +1,9 @@
[tool.black]
line-length = 80

[tool.isort]
profile = "black"
force_single_line = "True"
known_first_party = ["keras_core", "tests"]
default_section = "THIRDPARTY"
line_length = 80
setup.cfg (28 changed lines)
@@ -15,12 +15,6 @@ addopts=-vv
# Do not run tests in the `build` folders
norecursedirs = build

[isort]
known_first_party = keras_core,tests
default_section = THIRDPARTY
line_length = 80
profile = black

[coverage:report]
exclude_lines =
pragma: no cover

@@ -44,12 +38,28 @@ ignore =
E722
# wildcard imports
F401,F403
# too many "#"
E266

exclude =
*_pb2.py
*_pb2_grpc.py

#imported but unused in __init__.py, that's ok.
per-file-ignores = **/__init__.py:F401

max-line-length = 80
per-file-ignores =
# import not used
**/__init__.py:F401
# TODO: remove the following files when long lines are reformatted.
keras_core/trainers/trainer.py:E501
keras_core/trainers/epoch_iterator.py:E501
keras_core/trainers/data_adapters/array_data_adapter.py:E501
keras_core/trainers/compile_utils.py:E501
keras_core/saving/object_registration.py:E501
keras_core/regularizers/regularizers.py:E501
keras_core/optimizers/optimizer.py:E501
keras_core/operations/function.py:E501
keras_core/models/sequential.py:E501
keras_core/models/functional.py:E501
keras_core/layers/layer.py:E501
keras_core/initializers/random_initializers.py:E501
max-line-length = 80
@@ -3,7 +3,7 @@
base_dir=$(dirname $(dirname $0))
targets="${base_dir}/*.py ${base_dir}/keras_core/"

isort --sp "${base_dir}/setup.cfg" --sl ${targets}
black --line-length 80 ${targets}
isort --sp "${base_dir}/pyproject.toml" ${targets}
black --config "${base_dir}/pyproject.toml" ${targets}

flake8 --config "${base_dir}/setup.cfg" --max-line-length=200 ${targets}
flake8 --config "${base_dir}/setup.cfg" ${targets}
shell/lint.sh (new executable file, 9 lines)
@@ -0,0 +1,9 @@
#!/bin/bash -e

base_dir=$(dirname $(dirname $0))
targets="${base_dir}/*.py ${base_dir}/keras_core/"

isort --sp "${base_dir}/pyproject.toml" -c ${targets}
black --config "${base_dir}/pyproject.toml" --check ${targets}

flake8 --config "${base_dir}/setup.cfg" ${targets}