Merge branch 'main' of github.com:keras-team/keras-core

This commit is contained in:
Francois Chollet 2023-04-21 23:16:51 -07:00
parent a8c426fc59
commit e82672951a
36 changed files with 162 additions and 88 deletions

@ -28,7 +28,10 @@ class KerasVariable:
raise NotImplementedError
def __repr__(self):
return f"<KerasVariable shape={self.shape}, dtype={self.dtype}, name={self.name}>"
return (
f"<KerasVariable shape={self.shape}, dtype={self.dtype}, "
"name={self.name}>"
)
ALLOWED_DTYPES = {
@ -70,10 +73,12 @@ def standardize_shape(shape, fully_defined=False):
continue
if not isinstance(e, int):
raise ValueError(
f"Cannot convert '{shape}' to a shape. Found invalid entry '{e}'"
f"Cannot convert '{shape}' to a shape. "
f"Found invalid entry '{e}'"
)
if e < 0:
raise ValueError(
f"Cannot convert '{shape}' to a shape. Negative dimensions are not allowed."
f"Cannot convert '{shape}' to a shape. "
"Negative dimensions are not allowed."
)
return shape

@ -214,7 +214,8 @@ class JAXTrainer(base_trainer.Trainer):
break
# Update variable values
# NOTE: doing this after each step would be a big performance bottleneck.
# NOTE: doing this after each step would be a big performance
# bottleneck.
for ref_v, v in zip(self.trainable_variables, trainable_variables):
ref_v.assign(v)
for ref_v, v in zip(
@ -399,7 +400,8 @@ class JAXTrainer(base_trainer.Trainer):
metrics_variables,
)
logs, state = test_step(state, data)
# Note that trainable variables are not returned since they're immutable here.
# Note that trainable variables are not returned since they're
# immutable here.
non_trainable_variables, metrics_variables = state
callbacks.on_test_batch_end(step, logs)
@ -434,7 +436,7 @@ class JAXTrainer(base_trainer.Trainer):
# Build the model on one batch of data.
for _, data in epoch_iterator.enumerate_epoch(return_type="np"):
# Build model
y_pred = self(data)
self(data)
break
# Container that configures and calls callbacks.

@ -21,7 +21,10 @@ class KerasTensor:
return operations.Cast(dtype=dtype)(self)
def __repr__(self):
return f"<KerasTensor shape={self.shape}, dtype={self.dtype}, name={self.name}>"
return (
f"<KerasTensor shape={self.shape}, dtype={self.dtype}, "
"name={self.name}>"
)
def __iter__(self):
raise NotImplementedError(

@ -28,7 +28,8 @@ class StatelessScope:
"Invalid variable value in VariableSwapScope: "
"all values in argument `mapping` must be tensors with "
"a shape that matches the corresponding variable shape. "
f"For variable {k}, received invalid value {v} with shape {v.shape}."
f"For variable {k}, received invalid value {v} with shape "
f"{v.shape}."
)
self.state_mapping[id(k)] = v

@ -58,7 +58,7 @@ class Variable(KerasVariable, tf.__internal__.types.Tensor):
def ndim(self):
return self.value.ndim
def numpy(self):
def numpy(self): # noqa: F811
return self.value.numpy()
# Overload native accessor.

@ -26,8 +26,8 @@ class CallbackList(Callback):
Args:
callbacks: List of `Callback` instances.
add_history: Whether a `History` callback should be added, if one does
not already exist in the `callbacks` list.
add_history: Whether a `History` callback should be added, if one
does not already exist in the `callbacks` list.
add_progbar: Whether a `ProgbarLogger` callback should be added, if
one does not already exist in the `callbacks` list.
model: The `Model` these callbacks are used with.

@ -57,9 +57,9 @@ class Zeros(Initializer):
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only numeric or boolean dtypes
are supported. If not specified, `keras_core.backend.floatx()` is
used, which default to `float32` unless you configured it otherwise
(via `keras_core.backend.set_floatx(float_dtype)`).
are supported. If not specified, `keras_core.backend.floatx()`
is used, which default to `float32` unless you configured it
otherwise (via `keras_core.backend.set_floatx(float_dtype)`).
**kwargs: Additional keyword arguments.
"""
dtype = standardize_dtype(dtype)
@ -89,9 +89,9 @@ class Ones(Initializer):
Args:
shape: Shape of the tensor.
dtype: Optional dtype of the tensor. Only numeric or boolean dtypes
are supported. If not specified, `keras_core.backend.floatx()` is
used, which default to `float32` unless you configured it otherwise
(via `keras_core.backend.set_floatx(float_dtype)`).
are supported. If not specified, `keras_core.backend.floatx()`
is used, which default to `float32` unless you configured it
otherwise (via `keras_core.backend.set_floatx(float_dtype)`).
**kwargs: Additional keyword arguments.
"""
dtype = standardize_dtype(dtype)

@ -41,8 +41,8 @@ class Initializer:
Note that we don't have to implement `from_config()` in the example above
since the constructor arguments of the class the keys in the config returned
by `get_config()` are the same. In this case, the default `from_config()` works
fine.
by `get_config()` are the same. In this case, the default `from_config()`
works fine.
"""
def __call__(self, shape, dtype=None):

@ -17,11 +17,13 @@ class InputLayer(Layer):
super().__init__(name=name)
if shape is not None and batch_shape is not None:
raise ValueError(
"You cannot pass both `shape` and `batch_shape` at the same time."
"You cannot pass both `shape` and `batch_shape` at the "
"same time."
)
if batch_size is not None and batch_shape is not None:
raise ValueError(
"You cannot pass both `batch_size` and `batch_shape` at the same time."
"You cannot pass both `batch_size` and `batch_shape` at the "
"same time."
)
if shape is None and batch_shape is None:
raise ValueError("You must pass a `shape` argument.")
@ -36,7 +38,8 @@ class InputLayer(Layer):
if not isinstance(input_tensor, backend.KerasTensor):
raise ValueError(
"Argument `input_tensor` must be a KerasTensor. "
f"Received invalid type: input_tensor={input_tensor} (of type {type(input_tensor)})"
f"Received invalid type: input_tensor={input_tensor} "
f"(of type {type(input_tensor)})"
)
else:
input_tensor = backend.KerasTensor(

@ -7,8 +7,8 @@ from keras_core.api_export import keras_core_export
class Dropout(layers.Layer):
"""Applies dropout to the input.
The `Dropout` layer randomly sets input units to 0 with a frequency of `rate`
at each step during training time, which helps prevent overfitting.
The `Dropout` layer randomly sets input units to 0 with a frequency of
`rate` at each step during training time, which helps prevent overfitting.
Inputs not set to 0 are scaled up by `1 / (1 - rate)` such that the sum over
all inputs is unchanged.

@ -4,6 +4,10 @@ from keras_core.losses.losses import LossFunctionWrapper
from keras_core.losses.losses import MeanSquaredError
def deserialize(obj):
raise NotImplementedError
@keras_core_export("keras_core.losses.get")
def get(identifier):
"""Retrieves a Keras loss as a `function`/`Loss` class instance.
@ -29,8 +33,8 @@ def get(identifier):
Args:
identifier: A loss identifier. One of None or string name of a loss
function/class or loss configuration dictionary or a loss function or a
loss class instance.
function/class or loss configuration dictionary or a loss function
or a loss class instance.
Returns:
A Keras loss as a `function`/ `Loss` class instance.

@ -54,7 +54,7 @@ class Loss:
def standardize_reduction(reduction):
allowed = {"sum_over_batch_size", "sum", None}
if not reduction in allowed:
if reduction not in allowed:
raise ValueError(
"Invalid value for argument `reduction`. "
f"Expected on of {allowed}. Received: "

@ -106,7 +106,8 @@ class LossTest(testing.TestCase):
loss,
)
# @testing.parametrize("uprank", ["mask", "sample_weight", "y_true", "y_pred"])
# @testing.parametrize(
# "uprank", ["mask", "sample_weight", "y_true", "y_pred"])
# TODO: use parameterization decorator
def test_rank_adjustment(self):
for uprank in ["mask", "sample_weight", "ys"]:

@ -6,6 +6,10 @@ from keras_core.metrics.reduction_metrics import Sum
from keras_core.metrics.regression_metrics import MeanSquaredError
def deserialize(obj):
raise NotImplementedError
@keras_core_export("keras_core.metrics.get")
def get(identifier):
"""Retrieves a Keras metric as a `function`/`Metric` class instance.
@ -31,8 +35,8 @@ def get(identifier):
Args:
identifier: A metric identifier. One of None or string name of a metric
function/class or metric configuration dictionary or a metric function
or a metric class instance
function/class or metric configuration dictionary or a metric
function or a metric class instance
Returns:
A Keras metric as a `function`/ `Metric` class instance.

@ -19,7 +19,7 @@ class MeanSquaredErrorTest(testing.TestCase):
[[0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]]
)
update_op = mse_obj.update_state(y_true, y_pred)
mse_obj.update_state(y_true, y_pred)
result = mse_obj.result()
self.assertAllClose(0.5, result, atol=1e-5)

@ -148,7 +148,6 @@ class Functional(Function, Model):
def _adjust_input_rank(self, flat_inputs):
flat_ref_shapes = [x.shape for x in self._inputs]
names = [x.name for x in self._inputs]
adjusted = []
for x, ref_shape in zip(flat_inputs, flat_ref_shapes):
x_rank = len(x.shape)

@ -334,8 +334,9 @@ def max_pool(
dimensions only.
pool_size: int or tuple/list of integers of size
`len(inputs_spatial_shape)`, specifying the size of the pooling
window for each spatial dimension of the input tensor. If `pool_size`
is int, then every spatial dimension shares the same `pool_size`.
window for each spatial dimension of the input tensor. If
`pool_size` is int, then every spatial dimension shares the same
`pool_size`.
strides: int or tuple/list of integers of size
`len(inputs_spatial_shape)`. The stride of the sliding window for
each spatial dimension of the input tensor. If `strides` is int,
@ -435,8 +436,9 @@ def average_pool(
dimensions only.
pool_size: int or tuple/list of integers of size
`len(inputs_spatial_shape)`, specifying the size of the pooling
window for each spatial dimension of the input tensor. If `pool_size`
is int, then every spatial dimension shares the same `pool_size`.
window for each spatial dimension of the input tensor. If
`pool_size` is int, then every spatial dimension shares the same
`pool_size`.
strides: int or tuple/list of integers of size
`len(inputs_spatial_shape)`. The stride of the sliding window for
each spatial dimension of the input tensor. If `strides` is int,
@ -511,7 +513,8 @@ class Conv(Operation):
raise ValueError(
"Dilation must be None, scalar or tuple/list of length of "
"inputs' spatial shape, but received "
f"`dilation_rate={self.dilation_rate}` and input of shape {input_shape}."
f"`dilation_rate={self.dilation_rate}` and "
f"input of shape {input_shape}."
)
spatial_shape = np.array(spatial_shape)
kernel_spatial_shape = np.array(kernel.shape[:-2])
@ -651,7 +654,8 @@ class DepthwiseConv(Operation):
raise ValueError(
"Dilation must be None, scalar or tuple/list of length of "
"inputs' spatial shape, but received "
f"`dilation_rate={self.dilation_rate}` and input of shape {input_shape}."
f"`dilation_rate={self.dilation_rate}` and input of "
f"shape {input_shape}."
)
spatial_shape = np.array(spatial_shape)
kernel_spatial_shape = np.array(kernel.shape[:-2])

@ -281,6 +281,10 @@ class NNOpsDynamicShapeTest(testing.TestCase):
)
@pytest.mark.skipif(
backend() != "tensorflow",
reason="Not have other backend support yet.",
)
class NNOpsStaticShapeTest(testing.TestCase):
def test_relu(self):
x = KerasTensor([1, 2, 3])
@ -543,6 +547,10 @@ class NNOpsStaticShapeTest(testing.TestCase):
)
@pytest.mark.skipif(
backend() != "tensorflow",
reason="Not have other backend support yet.",
)
class NNOpsCorrectnessTest(testing.TestCase):
def test_relu(self):
x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)

@ -10,9 +10,9 @@ class Node:
"""A `Node` describes an operation `__call__()` event.
A Keras Function is a DAG with `Node` instances as nodes, and
`KerasTensor` instances as edges. Nodes aren't `Operation` instances, because a
single operation could be called multiple times, which would result in graph
cycles.
`KerasTensor` instances as edges. Nodes aren't `Operation` instances,
because a single operation could be called multiple times, which would
result in graph cycles.
A `__call__()` event involves input tensors (and other input arguments),
the operation that was called, and the resulting output tensors.
@ -89,7 +89,11 @@ class Node:
@property
def parent_nodes(self):
"""Returns all the `Node`s whose output this node immediately depends on."""
"""The parent `Node`s.
Returns:
all the `Node`s whose output this node immediately depends on.
"""
node_deps = []
for kt in self.arguments.keras_tensors:
op = kt._keras_history.operation
@ -116,9 +120,9 @@ class KerasHistory(
operation: The Operation instance that produced the Tensor.
node_index: The specific call to the Operation that produced this Tensor.
Operations can be called multiple times in order to share weights. A new
node is created every time an Operation is called. The corresponding node
that represents the call event that produced the Tensor can be found at
`op._inbound_nodes[node_index]`.
node is created every time an Operation is called. The corresponding
node that represents the call event that produced the Tensor can be
found at `op._inbound_nodes[node_index]`.
tensor_index: The output index for this Tensor.
Always zero if the Operation that produced this Tensor
only has one output. Nested structures of

@ -1716,7 +1716,8 @@ class Meshgrid(Operation):
super().__init__()
if indexing not in ("xy", "ij"):
raise ValueError(
"Valid values for `indexing` are 'xy' and 'ij', but received {index}."
"Valid values for `indexing` are 'xy' and 'ij', "
"but received {index}."
)
self.indexing = indexing

@ -37,7 +37,8 @@ class Operation:
return backend.compute_output_spec(self.call, *args, **kwargs)
except Exception as e:
raise RuntimeError(
"Could not automatically infer the output shape / dtype of this operation. "
"Could not automatically infer the output shape / dtype of "
"this operation. "
"Please implement the `compute_output_spec` method "
f"on your object ({self.__class__.__name__}). "
f"Error encountered: {e}"

@ -287,7 +287,7 @@ class Optimizer:
def _filter_empty_gradients(self, grads_and_vars):
filtered = [(g, v) for g, v in grads_and_vars if g is not None]
if not filtered:
raise ValueError(f"No gradients provided for any variable.")
raise ValueError("No gradients provided for any variable.")
if len(filtered) < len(grads_and_vars):
missing_grad_vars = [v for g, v in grads_and_vars if g is None]
warnings.warn(

@ -30,5 +30,5 @@ class RegularizersTest(testing.TestCase):
def test_orthogonal_regularizer(self):
value = np.random.random((4, 4))
x = backend.Variable(value)
y = regularizers.OrthogonalRegularizer(factor=0.1, mode="rows")(x)
regularizers.OrthogonalRegularizer(factor=0.1, mode="rows")(x)
# TODO

@ -261,7 +261,8 @@ def serialize_with_public_class(cls, inner_config=None):
Keras API or has been registered as serializable via
`keras_core.saving.register_keras_serializable()`.
"""
# This gets the `keras_core.*` exported name, such as "keras_core.optimizers.Adam".
# This gets the `keras_core.*` exported name, such as
# "keras_core.optimizers.Adam".
keras_api_name = api_export.get_name_from_symbol(cls)
# Case of custom or unknown class object
@ -293,8 +294,8 @@ def serialize_with_public_fn(fn, config, fn_module_name=None):
Called to check and retrieve the config of any function that has a public
Keras API or has been registered as serializable via
`keras_core.saving.register_keras_serializable()`. If function's module name is
already known, returns corresponding config.
`keras_core.saving.register_keras_serializable()`. If function's module name
is already known, returns corresponding config.
"""
if fn_module_name:
return {
@ -373,9 +374,9 @@ def deserialize_keras_object(
- `module`: String. The path of the python module. Built-in Keras classes
expect to have prefix `keras_core`.
- `registered_name`: String. The key the class is registered under via
`keras_core.saving.register_keras_serializable(package, name)` API. The key has
the format of '{package}>{name}', where `package` and `name` are the
arguments passed to `register_keras_serializable()`. If `name` is not
`keras_core.saving.register_keras_serializable(package, name)` API. The
key has the format of '{package}>{name}', where `package` and `name` are
the arguments passed to `register_keras_serializable()`. If `name` is not
provided, it uses the class name. If `registered_name` successfully
resolves to a class (that was registered), the `class_name` and `config`
values in the dict will not be used. `registered_name` is only used for

@ -410,7 +410,7 @@ class CompileLoss(losses_module.Loss):
raise ValueError(
f"When there is only a single output, the `loss_weights` argument "
"must be a Python float. "
f"Received instead:\loss_weights={loss_weights} of type {type(loss_weights)}"
f"Received instead: loss_weights={loss_weights} of type {type(loss_weights)}"
)
flat_loss_weights.append(loss_weights)
else:
@ -447,7 +447,7 @@ class CompileLoss(losses_module.Loss):
"For a model with multiple outputs, "
f"when providing the `loss_weights` argument as a list, "
"it should have as many entries as the model has outputs. "
f"Received:\loss_weights={loss_weights}\nof length {len(loss_weights)} "
f"Received: loss_weights={loss_weights} of length {len(loss_weights)} "
f"whereas the model has {len(y_pred)} outputs."
)
if not all(isinstance(e, float) for e in loss_weights):
@ -503,7 +503,7 @@ class CompileLoss(losses_module.Loss):
raise ValueError(
f"In the dict argument `loss_weights`, key "
f"'{name}' does not correspond to any model output. "
f"Received:\loss_weights={loss_weights}"
f"Received: loss_weights={loss_weights}"
)
if not isinstance(loss_weights[name], float):
raise ValueError(

@ -88,7 +88,8 @@ class TestArrayDataAdapter(testing.TestCase):
self.assertEqual(len(batch), 3)
bx, by, bw = batch
self.assertTrue(isinstance(bx, dict))
# NOTE: the y list was converted to a tuple for tf.data compatibility.
# NOTE: the y list was converted to a tuple for tf.data
# compatibility.
self.assertTrue(isinstance(by, tuple))
self.assertTrue(isinstance(bw, tuple))
@ -118,7 +119,8 @@ class TestArrayDataAdapter(testing.TestCase):
self.assertEqual(len(batch), 3)
bx, by, bw = batch
self.assertTrue(isinstance(bx, dict))
# NOTE: the y list was converted to a tuple for tf.data compatibility.
# NOTE: the y list was converted to a tuple for tf.data
# compatibility.
self.assertTrue(isinstance(by, tuple))
self.assertTrue(isinstance(bw, tuple))

@ -23,8 +23,8 @@ class DataAdapter(object):
Returns:
A `tf.data.Dataset`. Caller might use the dataset in different
context, e.g. iter(dataset) in eager to get the value directly, or in
graph mode, provide the iterator tensor to Keras model function.
context, e.g. iter(dataset) in eager to get the value directly, or
in graph mode, provide the iterator tensor to Keras model function.
"""
raise NotImplementedError
@ -38,9 +38,9 @@ class DataAdapter(object):
may or may not have an end state.
Returns:
int, the number of batches for the dataset, or None if it is unknown.
The caller could use this to control the loop of training, show
progress bar, or handle unexpected StopIteration error.
int, the number of batches for the dataset, or None if it is
unknown. The caller could use this to control the loop of training,
show progress bar, or handle unexpected StopIteration error.
"""
raise NotImplementedError

@ -141,10 +141,10 @@ def train_validation_split(arrays, validation_split):
The last part of data will become validation data.
Args:
arrays: Tensors to split. Allowed inputs are arbitrarily nested structures
of Tensors and NumPy arrays.
validation_split: Float between 0 and 1. The proportion of the dataset to
include in the validation split. The rest of the dataset will be
arrays: Tensors to split. Allowed inputs are arbitrarily nested
structures of Tensors and NumPy arrays.
validation_split: Float between 0 and 1. The proportion of the dataset
to include in the validation split. The rest of the dataset will be
included in the training split.
Returns:
@ -158,7 +158,8 @@ def train_validation_split(arrays, validation_split):
unsplitable = [type(t) for t in flat_arrays if not _can_split(t)]
if unsplitable:
raise ValueError(
"Argument `validation_split` is only supported for tf.Tensors or NumPy "
"Argument `validation_split` is only supported "
"for tf.Tensors or NumPy "
"arrays. Found incompatible type in the input: {unsplitable}"
)

@ -10,7 +10,8 @@ class TFDatasetAdapter(DataAdapter):
def __init__(self, dataset, class_weight=None):
if not isinstance(dataset, tf.data.Dataset):
raise ValueError(
f"Expected argument `dataset` to be a tf.data.Dataset. Received: {dataset}"
"Expected argument `dataset` to be a tf.data.Dataset. "
f"Received: {dataset}"
)
if class_weight is not None:
dataset = dataset.map(
@ -72,7 +73,8 @@ def make_class_weight_map_fn(class_weight):
x, y, sw = data_adapter_utils.unpack_x_y_sample_weight(data)
if sw is not None:
raise ValueError(
"You cannot `class_weight` and `sample_weight` at the same time."
"You cannot `class_weight` and `sample_weight` "
"at the same time."
)
if tf.nest.is_nested(y):
raise ValueError(

@ -138,8 +138,8 @@ class Trainer:
loss = self._compile_loss(y, y_pred, sample_weight)
if loss is not None:
losses.append(loss)
for l in self.losses:
losses.append(ops.cast(l, dtype=backend.floatx()))
for loss in self.losses:
losses.append(ops.cast(loss, dtype=backend.floatx()))
if len(losses) == 0:
raise ValueError(
"No loss to compute. Provide a `loss` argument in `compile()`."

@ -74,7 +74,8 @@ class TestTrainer(testing.TestCase):
# Fit the model to make sure compile_metrics are built
model.fit(x, y, batch_size=2, epochs=1)
# The model should have 3 metrics: loss_tracker, compile_metrics, my_metric
# The model should have 3 metrics: loss_tracker, compile_metrics,
# my_metric.
self.assertEqual(len(model.metrics), 3)
self.assertEqual(model.metrics[0], model._loss_tracker)
self.assertEqual(model.metrics[1], model.my_metric)

@ -42,7 +42,8 @@ class Tracker:
self.tracker = Tracker(
# Format: `name: (test_fn, store)`
{
"variables": (lambda x: isinstance(x, Variable), self._variables),
"variables":
(lambda x: isinstance(x, Variable), self._variables),
"metrics": (lambda x: isinstance(x, Metric), self._metrics),
"layers": (lambda x: isinstance(x, Layer), self._layers),
}

@ -1,2 +1,9 @@
[tool.black]
line-length = 80
[tool.isort]
profile = "black"
force_single_line = "True"
known_first_party = ["keras_core", "tests"]
default_section = "THIRDPARTY"
line_length = 80

@ -15,12 +15,6 @@ addopts=-vv
# Do not run tests in the `build` folders
norecursedirs = build
[isort]
known_first_party = keras_core,tests
default_section = THIRDPARTY
line_length = 80
profile = black
[coverage:report]
exclude_lines =
pragma: no cover
@ -44,12 +38,28 @@ ignore =
E722
# wildcard imports
F401,F403
# too many "#"
E266
exclude =
*_pb2.py
*_pb2_grpc.py
#imported but unused in __init__.py, that's ok.
per-file-ignores = **/__init__.py:F401
max-line-length = 80
per-file-ignores =
# import not used
**/__init__.py:F401
# TODO: remove the following files when long lines are reformatted.
keras_core/trainers/trainer.py:E501
keras_core/trainers/epoch_iterator.py:E501
keras_core/trainers/data_adapters/array_data_adapter.py:E501
keras_core/trainers/compile_utils.py:E501
keras_core/saving/object_registration.py:E501
keras_core/regularizers/regularizers.py:E501
keras_core/optimizers/optimizer.py:E501
keras_core/operations/function.py:E501
keras_core/models/sequential.py:E501
keras_core/models/functional.py:E501
keras_core/layers/layer.py:E501
keras_core/initializers/random_initializers.py:E501
max-line-length = 80

@ -3,7 +3,7 @@
base_dir=$(dirname $(dirname $0))
targets="${base_dir}/*.py ${base_dir}/keras_core/"
isort --sp "${base_dir}/setup.cfg" --sl ${targets}
black --line-length 80 ${targets}
isort --sp "${base_dir}/pyproject.toml" ${targets}
black --config "${base_dir}/pyproject.toml" ${targets}
flake8 --config "${base_dir}/setup.cfg" --max-line-length=200 ${targets}
flake8 --config "${base_dir}/setup.cfg" ${targets}

9
shell/lint.sh Executable file

@ -0,0 +1,9 @@
#!/bin/bash -e
base_dir=$(dirname $(dirname $0))
targets="${base_dir}/*.py ${base_dir}/keras_core/"
isort --sp "${base_dir}/pyproject.toml" -c ${targets}
black --config "${base_dir}/pyproject.toml" --check ${targets}
flake8 --config "${base_dir}/setup.cfg" ${targets}