diff --git a/jax_integration_test.py b/jax_integration_test.py index eb7f656c7..712b4aa6e 100644 --- a/jax_integration_test.py +++ b/jax_integration_test.py @@ -1,9 +1,9 @@ from keras_core import backend -from keras_core.layers.layer import Layer -from keras_core.backend import KerasTensor -from keras_core.operations.function import Function from keras_core import initializers +from keras_core.backend import KerasTensor +from keras_core.layers.layer import Layer from keras_core.operations import numpy as knp +from keras_core.operations.function import Function class MiniDense(Layer): diff --git a/jax_training_scratchpad.py b/jax_training_scratchpad.py index f6e563d65..c8558cd6c 100644 --- a/jax_training_scratchpad.py +++ b/jax_training_scratchpad.py @@ -1,11 +1,12 @@ -from keras_core import operations as ops -from keras_core import backend -from keras_core.optimizers import SGD -from keras_core.layers.layer import Layer -from keras_core import initializers import jax -from jax import numpy as jnp import numpy as np +from jax import numpy as jnp + +from keras_core import backend +from keras_core import initializers +from keras_core import operations as ops +from keras_core.layers.layer import Layer +from keras_core.optimizers import SGD class MiniDense(Layer): diff --git a/keras_core/__init__.py b/keras_core/__init__.py index e69de29bb..b8d6c0502 100644 --- a/keras_core/__init__.py +++ b/keras_core/__init__.py @@ -0,0 +1,2 @@ +from keras_core import backend +from keras_core import operations diff --git a/keras_core/api_export.py b/keras_core/api_export.py index 1b91736dd..d07bb0348 100644 --- a/keras_core/api_export.py +++ b/keras_core/api_export.py @@ -1,4 +1,5 @@ import types + from keras_core.saving import register_keras_core_serializable try: diff --git a/keras_core/backend/__init__.py b/keras_core/backend/__init__.py index 1dede3eb3..cf0386510 100644 --- a/keras_core/backend/__init__.py +++ b/keras_core/backend/__init__.py @@ -1,21 +1,20 @@ -import os import json +import os import sys -from keras_core.utils.io_utils import print_msg -from keras_core.backend.keras_tensor import KerasTensor -from keras_core.backend.keras_tensor import is_keras_tensor -from keras_core.backend.keras_tensor import any_symbolic_tensors -from keras_core.backend.config import floatx -from keras_core.backend.config import epsilon -from keras_core.backend.config import image_data_format -from keras_core.backend.config import set_floatx -from keras_core.backend.config import set_epsilon -from keras_core.backend.config import set_image_data_format -from keras_core.backend.common import standardize_shape -from keras_core.backend.common import standardize_dtype from keras_core.backend.common import StatelessScope - +from keras_core.backend.common import standardize_dtype +from keras_core.backend.common import standardize_shape +from keras_core.backend.config import epsilon +from keras_core.backend.config import floatx +from keras_core.backend.config import image_data_format +from keras_core.backend.config import set_epsilon +from keras_core.backend.config import set_floatx +from keras_core.backend.config import set_image_data_format +from keras_core.backend.keras_tensor import KerasTensor +from keras_core.backend.keras_tensor import any_symbolic_tensors +from keras_core.backend.keras_tensor import is_keras_tensor +from keras_core.utils.io_utils import print_msg # Set Keras base dir path given KERAS_HOME env variable, if applicable. # Otherwise either ~/.keras or /tmp. 
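The re-ordered import hunks above, and throughout this diff, follow the isort settings added in setup.cfg near the end of the patch (profile = black, line_length = 80, known_first_party = keras_core,tests) together with the --sl single-line flag used in shell/format.sh: one import per line, grouped as standard library, then third-party, then keras_core, with each group alphabetized. As an illustration only (this snippet is not itself a hunk of the patch; the module names are taken from the hunks above), the layout those settings produce looks like:

import json  # standard-library group first
import os

import numpy as np  # third-party group second
from tensorflow import nest

from keras_core import backend  # first-party (known_first_party) group last
from keras_core.backend.keras_tensor import KerasTensor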
diff --git a/keras_core/backend/common.py b/keras_core/backend/common.py index f32148767..966199e47 100644 --- a/keras_core/backend/common.py +++ b/keras_core/backend/common.py @@ -1,7 +1,9 @@ -from keras_core.backend.config import floatx -from tensorflow import nest import threading +from tensorflow import nest + +from keras_core.backend.config import floatx + class KerasVariable: def __init__(self, value, dtype, trainable=True, name=None): diff --git a/keras_core/backend/jax/__init__.py b/keras_core/backend/jax/__init__.py index ecb79f197..796c32961 100644 --- a/keras_core/backend/jax/__init__.py +++ b/keras_core/backend/jax/__init__.py @@ -1,14 +1,15 @@ -from jax import numpy as jnp +import jax import numpy as np +from jax import numpy as jnp +from tensorflow import nest + from keras_core.backend.common import KerasVariable +from keras_core.backend.common import StatelessScope +from keras_core.backend.common import get_stateless_scope +from keras_core.backend.common import in_stateless_scope from keras_core.backend.common import standardize_dtype from keras_core.backend.keras_tensor import KerasTensor from keras_core.utils.naming import auto_name -from keras_core.backend.common import in_stateless_scope -from keras_core.backend.common import get_stateless_scope -from keras_core.backend.common import StatelessScope -from tensorflow import nest -import jax DYNAMIC_SHAPES_OK = False # Dynamic shapes NG diff --git a/keras_core/backend/jax/random.py b/keras_core/backend/jax/random.py index d21342c6f..8ed06c16e 100644 --- a/keras_core/backend/jax/random.py +++ b/keras_core/backend/jax/random.py @@ -1,4 +1,5 @@ import jax + from keras_core.backend import floatx from keras_core.backend.random import draw_seed diff --git a/keras_core/backend/keras_tensor.py b/keras_core/backend/keras_tensor.py index 90a686073..1fdbd0534 100644 --- a/keras_core/backend/keras_tensor.py +++ b/keras_core/backend/keras_tensor.py @@ -1,6 +1,7 @@ -from keras_core.utils.naming import auto_name from tensorflow import nest +from keras_core.utils.naming import auto_name + class KerasTensor: def __init__(self, shape, dtype="float32", name=None): diff --git a/keras_core/backend/random/__init__.py b/keras_core/backend/random/__init__.py index 386cd274f..898cee974 100644 --- a/keras_core/backend/random/__init__.py +++ b/keras_core/backend/random/__init__.py @@ -1,5 +1,4 @@ from keras_core.backend import backend - from keras_core.backend.random.random_seed_generator import RandomSeedGenerator from keras_core.backend.random.random_seed_generator import draw_seed from keras_core.backend.random.random_seed_generator import make_default_seed diff --git a/keras_core/backend/tensorflow/__init__.py b/keras_core/backend/tensorflow/__init__.py index 86d3c2d8c..a47ad5f1d 100644 --- a/keras_core/backend/tensorflow/__init__.py +++ b/keras_core/backend/tensorflow/__init__.py @@ -1,13 +1,13 @@ import tensorflow as tf from tensorflow.experimental import numpy as tfnp + from keras_core.backend.common import KerasVariable +from keras_core.backend.common import get_stateless_scope +from keras_core.backend.common import in_stateless_scope from keras_core.backend.common import standardize_dtype from keras_core.backend.keras_tensor import KerasTensor -from keras_core.utils.naming import auto_name from keras_core.backend.tensorflow.trainer import Trainer -from keras_core.backend.common import in_stateless_scope -from keras_core.backend.common import get_stateless_scope - +from keras_core.utils.naming import auto_name DYNAMIC_SHAPES_OK = True diff 
--git a/keras_core/backend/tensorflow/random.py b/keras_core/backend/tensorflow/random.py index 1ed09ce53..9b1498b14 100644 --- a/keras_core/backend/tensorflow/random.py +++ b/keras_core/backend/tensorflow/random.py @@ -1,6 +1,7 @@ -from keras_core.backend.random import draw_seed import tensorflow as tf + from keras_core.backend import floatx +from keras_core.backend.random import draw_seed def tf_draw_seed(seed): diff --git a/keras_core/backend/tensorflow/trainer.py b/keras_core/backend/tensorflow/trainer.py index 73fa32c48..8697a074d 100644 --- a/keras_core/backend/tensorflow/trainer.py +++ b/keras_core/backend/tensorflow/trainer.py @@ -1,6 +1,7 @@ -from keras_core.trainers import trainer import tensorflow as tf +from keras_core.trainers import trainer + class Trainer(trainer.Trainer): def train_step(self, data): diff --git a/keras_core/initializers/__init__.py b/keras_core/initializers/__init__.py index e99bbfb02..65490a83c 100644 --- a/keras_core/initializers/__init__.py +++ b/keras_core/initializers/__init__.py @@ -1,3 +1,3 @@ +from keras_core.initializers.constant_initializers import * from keras_core.initializers.initializer import Initializer from keras_core.initializers.random_initializers import * -from keras_core.initializers.constant_initializers import * diff --git a/keras_core/initializers/constant_initializers.py b/keras_core/initializers/constant_initializers.py index e8eda6e82..18343095b 100644 --- a/keras_core/initializers/constant_initializers.py +++ b/keras_core/initializers/constant_initializers.py @@ -1,6 +1,6 @@ +from keras_core.backend import standardize_dtype from keras_core.initializers.initializer import Initializer from keras_core.operations import numpy as knp -from keras_core.backend import standardize_dtype class Zeros(Initializer): diff --git a/keras_core/initializers/initializer.py b/keras_core/initializers/initializer.py index bdc9cf9c6..7a95d1418 100644 --- a/keras_core/initializers/initializer.py +++ b/keras_core/initializers/initializer.py @@ -1,4 +1,5 @@ import warnings + from keras_core.api_export import keras_core_export diff --git a/keras_core/initializers/random_initializers.py b/keras_core/initializers/random_initializers.py index 8494ee464..165ced222 100644 --- a/keras_core/initializers/random_initializers.py +++ b/keras_core/initializers/random_initializers.py @@ -1,6 +1,7 @@ import math -from keras_core.initializers.initializer import Initializer + from keras_core.backend import random +from keras_core.initializers.initializer import Initializer class VarianceScaling(Initializer): diff --git a/keras_core/layers/input_spec.py b/keras_core/layers/input_spec.py index 99785f52b..57053d000 100644 --- a/keras_core/layers/input_spec.py +++ b/keras_core/layers/input_spec.py @@ -1,5 +1,6 @@ -from keras_core import backend from tensorflow import nest + +from keras_core import backend from keras_core.api_export import keras_core_export diff --git a/keras_core/layers/layer.py b/keras_core/layers/layer.py index d3ce465b4..ede48180c 100644 --- a/keras_core/layers/layer.py +++ b/keras_core/layers/layer.py @@ -14,21 +14,23 @@ And some more magic: - metric tracking - RNG seed tracking """ -from keras_core.operations.operation import Operation -from keras_core.backend import KerasTensor -from keras_core import backend -from keras_core.utils.tracking import Tracker -from keras_core.metrics.metric import Metric -from keras_core import utils -from keras_core.utils import summary_utils -from keras_core.layers import input_spec -from keras_core.api_export import 
keras_core_export -from tensorflow import nest -from tensorflow import keras as tf_keras -import numpy as np +import collections import inspect import threading -import collections + +import numpy as np +from tensorflow import keras as tf_keras +from tensorflow import nest + +from keras_core import backend +from keras_core import utils +from keras_core.api_export import keras_core_export +from keras_core.backend import KerasTensor +from keras_core.layers import input_spec +from keras_core.metrics.metric import Metric +from keras_core.operations.operation import Operation +from keras_core.utils import summary_utils +from keras_core.utils.tracking import Tracker # TODO: cache all call signature processing. See layer_utils.CallFunctionSpec() in Keras. diff --git a/keras_core/layers/layer_test.py b/keras_core/layers/layer_test.py index a11272450..ffc36d8e8 100644 --- a/keras_core/layers/layer_test.py +++ b/keras_core/layers/layer_test.py @@ -1,7 +1,8 @@ +import numpy as np + from keras_core import testing from keras_core.engine import keras_tensor from keras_core.layers.layer import Layer -import numpy as np class FunctionTest(testing.TestCase): diff --git a/keras_core/losses/loss.py b/keras_core/losses/loss.py index 5f0f2b0de..70e24c505 100644 --- a/keras_core/losses/loss.py +++ b/keras_core/losses/loss.py @@ -1,8 +1,9 @@ -from keras_core import operations as ops from keras_core import backend from keras_core.utils.naming import auto_name from keras_core.utils import dtype_utils +from keras_core import operations as ops from keras_core.api_export import keras_core_export +from keras_core.utils.naming import auto_name @keras_core_export(["keras_core.Loss", "keras_core.losses.Loss"]) diff --git a/keras_core/losses/loss_test.py b/keras_core/losses/loss_test.py index 28e05d79e..36035c95b 100644 --- a/keras_core/losses/loss_test.py +++ b/keras_core/losses/loss_test.py @@ -1,7 +1,8 @@ +import numpy as np + +from keras_core import operations as ops from keras_core import testing from keras_core.losses.loss import Loss -from keras_core import operations as ops -import numpy as np class ExampleLoss(Loss): diff --git a/keras_core/losses/losses.py b/keras_core/losses/losses.py index 4dacc74ea..ae06a561a 100644 --- a/keras_core/losses/losses.py +++ b/keras_core/losses/losses.py @@ -1,5 +1,5 @@ -from keras_core.losses.loss import Loss from keras_core import operations as ops +from keras_core.losses.loss import Loss from keras_core.losses.loss import squeeze_to_same_rank diff --git a/keras_core/metrics/metric.py b/keras_core/metrics/metric.py index 4b87b6972..5b146d5ef 100644 --- a/keras_core/metrics/metric.py +++ b/keras_core/metrics/metric.py @@ -1,7 +1,7 @@ from keras_core import backend -from keras_core.utils.tracking import Tracker -from keras_core.utils.naming import auto_name from keras_core.api_export import keras_core_export +from keras_core.utils.naming import auto_name +from keras_core.utils.tracking import Tracker @keras_core_export(["keras_core.Metric", "keras_core.metrics.Metric"]) diff --git a/keras_core/metrics/metric_test.py b/keras_core/metrics/metric_test.py index cb24e275a..878928c09 100644 --- a/keras_core/metrics/metric_test.py +++ b/keras_core/metrics/metric_test.py @@ -1,9 +1,10 @@ -from keras_core import testing +import numpy as np + +from keras_core import backend from keras_core import initializers from keras_core import operations as ops -from keras_core import backend +from keras_core import testing from keras_core.metrics.metric import Metric -import numpy as np class 
ExampleMetric(Metric): diff --git a/keras_core/metrics/regression_metrics.py b/keras_core/metrics/regression_metrics.py index 758c0ef2f..ffcbee696 100644 --- a/keras_core/metrics/regression_metrics.py +++ b/keras_core/metrics/regression_metrics.py @@ -1,6 +1,6 @@ -from keras_core.metrics.metric import Metric from keras_core import backend from keras_core import initializers +from keras_core.metrics.metric import Metric class MeanSquareError(Metric): diff --git a/keras_core/models/functional.py b/keras_core/models/functional.py index 410ac8ead..57d35c5ef 100644 --- a/keras_core/models/functional.py +++ b/keras_core/models/functional.py @@ -1,6 +1,6 @@ -from keras_core.operations.function import Function -from keras_core.models.model import Model from keras_core.layers.layer import Layer +from keras_core.models.model import Model +from keras_core.operations.function import Function class Functional(Function, Model): diff --git a/keras_core/models/model.py b/keras_core/models/model.py index 77b59b92d..4b214ebaf 100644 --- a/keras_core/models/model.py +++ b/keras_core/models/model.py @@ -1,6 +1,6 @@ from keras_core.api_export import keras_core_export -from keras_core.layers.layer import Layer from keras_core.backend import Trainer +from keras_core.layers.layer import Layer @keras_core_export(["keras_core.Model", "keras_core.models.Model"]) diff --git a/keras_core/operations/__init__.py b/keras_core/operations/__init__.py index 40d96a2b5..a78b36bc8 100644 --- a/keras_core/operations/__init__.py +++ b/keras_core/operations/__init__.py @@ -2,11 +2,11 @@ # from keras_core.operations.numpy import Add, add # from keras_core.operations.numpy import Multiply, multiply -from keras_core.operations.numpy import * -from keras_core.operations.nn import * -from keras_core.backend import is_tensor -from keras_core.backend import convert_to_tensor from keras_core.backend import cast -from keras_core.backend import shape from keras_core.backend import cond +from keras_core.backend import convert_to_tensor +from keras_core.backend import is_tensor from keras_core.backend import name_scope +from keras_core.backend import shape +from keras_core.operations.nn import * +from keras_core.operations.numpy import * diff --git a/keras_core/operations/function.py b/keras_core/operations/function.py index 86ef81171..deb1ec52e 100644 --- a/keras_core/operations/function.py +++ b/keras_core/operations/function.py @@ -1,8 +1,10 @@ +import collections + +from tensorflow import nest + from keras_core.backend import KerasTensor from keras_core.operations.operation import Operation from keras_core.utils.naming import auto_name -from tensorflow import nest -import collections class Function(Operation): diff --git a/keras_core/operations/function_test.py b/keras_core/operations/function_test.py index d183f0949..469b6eeb4 100644 --- a/keras_core/operations/function_test.py +++ b/keras_core/operations/function_test.py @@ -1,10 +1,11 @@ +import numpy as np +import pytest + +from keras_core import backend +from keras_core import testing from keras_core.backend import keras_tensor from keras_core.operations import function from keras_core.operations import numpy as knp -from keras_core import testing -from keras_core import backend -import pytest -import numpy as np class FunctionTest(testing.TestCase): diff --git a/keras_core/operations/node.py b/keras_core/operations/node.py index e29f95074..24378d171 100644 --- a/keras_core/operations/node.py +++ b/keras_core/operations/node.py @@ -1,5 +1,7 @@ import collections + from tensorflow 
import nest + from keras_core.backend import KerasTensor from keras_core.operations.symbolic_arguments import SymbolicArguments diff --git a/keras_core/operations/numpy.py b/keras_core/operations/numpy.py index f2dfb4a5c..aa8496a73 100644 --- a/keras_core/operations/numpy.py +++ b/keras_core/operations/numpy.py @@ -1,201 +1,217 @@ """ MANIFEST: -matmul -add -subtract -multiply -divide -true_divide -power -negative +abs absolute +add +all +amax +amin +append +arange +arccos +arcsin +arctan +arctanh +argmax +argmin +argsort +array +array_equal +average +broadcast_to +ceil +clip +concatenate +conj +conjugate +copy +cos +count_nonzero +cov +cross +cumprod +cumsum +diag +diag_indices +diagonal +diff +divide +dot +dtype +einsum +empty +equal +exp +expand_dims +expm1 +eye +flip +floor +full +full_like +greater +greater_equal +hstack +identity +imag +indices +interp +isclose +isfinite +isin +isinf +isnan +isscalar +issubdtype +issubsctype +less +less_equal +linspace +log +log10 +log1p +log2 +logaddexp +logical_and +logical_not +logical_or +logspace +matmul +max +maximum mean -var -zeros +median +meshgrid +mgrid +min +minimum +mod +moveaxis +multiply +nan_to_num +ndim +nonzero +not_equal ones +ones_like +outer +pad +percentile +power +prod +ravel +real +reciprocal +repeat +reshape +roll +round +shape +sign +sin +size +sort +split +sqrt +square +squeeze +stack +std +subtract +sum +swapaxes +take +take_along_axis +tan +tensordot +tile +trace +transpose +tri +tril +triu +true_divide +unique +unravel_index +vdot +vectorize +vstack +where +zeros +zeros_like """ +import numpy as np + +from keras_core import backend from keras_core.backend import KerasTensor from keras_core.backend import any_symbolic_tensors -from keras_core.backend import convert_to_tensor -from keras_core.operations.symbolic_arguments import SymbolicArguments from keras_core.operations.operation import Operation -from keras_core import backend -from tensorflow import nest -import jax - -# TODO: replace this function with one that can handle -# dynamic shapes.
-def compute_np_output_spec(op_name, *args, **kwargs): - op = getattr(jax.numpy, op_name) - - def convert_keras_tensor_to_jax_array(x): - if isinstance(x, KerasTensor): - return jax.numpy.zeros(x.shape, dtype=x.dtype) - return x - - args, kwargs = SymbolicArguments(*args, **kwargs).convert( - convert_keras_tensor_to_jax_array - ) - jax_out = jax.eval_shape(op, *args, **kwargs) - - def convert_jax_spec_to_keras_tensor(x): - if isinstance(x, jax.ShapeDtypeStruct): - return KerasTensor(x.shape, x.dtype) - return x - - return nest.map_structure(convert_jax_spec_to_keras_tensor, jax_out) - - -##################### -### Two-input ops ### -##################### - - -### matmul ### - - -class Matmul(Operation): - def call(self, x1, x2): - return backend.execute("matmul", x1, x2) - - def compute_output_spec(self, x1, x2): - return compute_np_output_spec("matmul", x1, x2) - - -def matmul(x1, x2): - if any_symbolic_tensors((x1, x2)): - return Matmul().symbolic_call(x1, x2) - x1 = convert_to_tensor(x1, x1.dtype) - x2 = convert_to_tensor(x2, x2.dtype) - return backend.execute("matmul", x1, x2) - - -### add ### - - -class Add(Operation): - def call(self, x1, x2): - return backend.execute("add", x1, x2) - - def compute_output_spec(self, x1, x2): - return compute_np_output_spec("add", x1, x2) - - -def add(x1, x2): - if any_symbolic_tensors((x1, x2)): - return Add().symbolic_call(x1, x2) - return backend.execute("add", x1, x2) - - -### subtract ### - - -class Subtract(Operation): - def call(self, x1, x2): - return backend.execute("subtract", x1, x2) - - def compute_output_spec(self, x1, x2): - return compute_np_output_spec("subtract", x1, x2) - - -def subtract(x1, x2): - if any_symbolic_tensors((x1, x2)): - return Subtract().symbolic_call(x1, x2) - return backend.execute("subtract", x1, x2) - - -### multiply ### - - -class Multiply(Operation): - def call(self, x1, x2): - return backend.execute("multiply", x1, x2) - - def compute_output_spec(self, x1, x2): - return compute_np_output_spec("multiply", x1, x2) - - -def multiply(x1, x2): - if any_symbolic_tensors((x1, x2)): - return Multiply().symbolic_call(x1, x2) - return backend.execute("multiply", x1, x2) - - -### divide ### - - -class Divide(Operation): - def call(self, x1, x2): - return backend.execute("divide", x1, x2) - - def compute_output_spec(self, x1, x2): - return compute_np_output_spec("divide", x1, x2) - - -def divide(x1, x2): - if any_symbolic_tensors((x1, x2)): - return Divide().symbolic_call(x1, x2) - return backend.execute("divide", x1, x2) - - -### true_divide ### - - -class TrueDivide(Operation): - def call(self, x1, x2): - return backend.execute("true_divide", x1, x2) - - def compute_output_spec(self, x1, x2): - return compute_np_output_spec("true_divide", x1, x2) - - -def true_divide(x1, x2): - if any_symbolic_tensors((x1, x2)): - return TrueDivide().symbolic_call(x1, x2) - return backend.execute("true_divide", x1, x2) - - -class Power(Operation): - def call(self, x1, x2): - return backend.execute("power", x1, x2) - - def compute_output_spec(self, x1, x2): - return KerasTensor(x1.shape, dtype=x1.dtype) - - -def power(x1, x2): - if any_symbolic_tensors((x1, x2)): - return Power().symbolic_call(x1, x2) - return backend.execute("power", x1, x2) - - -######################## -### Single-input ops ### -######################## - -### negative ### - - -class Negative(Operation): - def call(self, x): - return backend.execute("negative", x) - - def compute_output_spec(self, x): - return KerasTensor(x.shape, dtype=x.dtype) - - -def negative(x): - 
if any_symbolic_tensors((x,)): - return Negative().symbolic_call(x) - return backend.execute("negative", x) - - -### absolute ### +def broadcast_shapes(shape1, shape2): + # Broadcast input shapes to a unified shape. + # Convert to list for mutability. + shape1 = list(shape1) + shape2 = list(shape2) + origin_shape1 = shape1 + origin_shape2 = shape2 + + if len(shape1) > len(shape2): + shape2 = [None] * (len(shape1) - len(shape2)) + shape2 + if len(shape1) < len(shape2): + shape1 = [None] * (len(shape2) - len(shape1)) + shape1 + output_shape = list(shape1) + for i in range(len(shape1)): + if shape1[i] == 1: + output_shape[i] = shape2[i] + elif shape1[i] is None: + output_shape[i] = shape2[i] + else: + if shape2[i] == 1 or shape2[i] is None or shape2[i] == shape1[i]: + output_shape[i] = shape1[i] + else: + raise ValueError( + "Cannot broadcast shape, the failure dim has value " + f"{shape1[i]}, which cannot be broadcasted to {shape2[i]}. " + f"Input shapes are: {origin_shape1} and {origin_shape2}." + ) + + return output_shape + + +def reduce_shape(shape, axis=None, keepdims=False): + shape = list(shape) + if axis is None: + if keepdims: + output_shape = [1 for _ in range(len(shape))] + else: + output_shape = [] + return output_shape + + if keepdims: + for ax in axis: + shape[ax] = 1 + return shape + else: + for ax in axis: + shape[ax] = -1 + output_shape = list(filter((-1).__ne__, shape)) + return output_shape + + +def shape_equal(shape1, shape2, axis=None): + if len(shape1) != len(shape2): + return False + if axis is not None: + shape1 = list(shape1) + shape2 = list(shape2) + for ax in axis: + shape1[ax] = -1 + shape2[ax] = -1 + return shape1 == shape2 class Absolute(Operation): @@ -212,7 +228,297 @@ def absolute(x): return backend.execute("absolute", x) -### square ### +class Abs(Absolute): + pass + + +def abs(x): + return absolute(x) + + +class Add(Operation): + def call(self, x1, x2): + return backend.execute("add", x1, x2) + + def compute_output_spec(self, x1, x2): + output_shape = broadcast_shapes(x1.shape, x2.shape) + return KerasTensor(output_shape, dtype=x1.dtype) + + +def add(x1, x2): + if any_symbolic_tensors((x1, x2)): + return Add().symbolic_call(x1, x2) + return backend.execute("add", x1, x2) + + +class All(Operation): + def __init__(self, axis=None, keepdims=False): + super().__init__() + if isinstance(axis, int): + self.axis = [axis] + else: + self.axis = axis + self.keepdims = keepdims + + def call(self, x): + return backend.execute( + "all", + x, + axis=self.axis, + keepdims=self.keepdims, + ) + + def compute_output_spec(self, x): + return KerasTensor( + reduce_shape( + x.shape, + axis=self.axis, + keepdims=self.keepdims, + ), + dtype=x.dtype, + ) + + +def all(x, axis=None, keepdims=False): + if any_symbolic_tensors((x,)): + return All(axis=axis, keepdims=keepdims).symbolic_call(x) + return backend.execute("all", x, axis=axis, keepdims=keepdims) + + +class Amax(Operation): + def __init__(self, axis=None, keepdims=False): + super().__init__() + if isinstance(axis, int): + axis = [axis] + self.axis = axis + self.keepdims = keepdims + + def call(self, x): + return backend.execute( + "amax", + x, + axis=self.axis, + keepdims=self.keepdims, + ) + + def compute_output_spec(self, x): + return KerasTensor( + reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims), + dtype=x.dtype, + ) + + +def amax(x, axis=None, keepdims=False): + if any_symbolic_tensors((x,)): + return Amax(axis=axis, keepdims=keepdims).symbolic_call(x) + return backend.execute("amax", x, axis=axis,
keepdims=keepdims) + + +class Amin(Operation): + def __init__(self, axis=None, keepdims=False): + super().__init__() + if isinstance(axis, int): + axis = [axis] + self.axis = axis + self.keepdims = keepdims + + def call(self, x): + return backend.execute( + "amin", x, axis=self.axis, keepdims=self.keepdims + ) + + def compute_output_spec(self, x): + return KerasTensor( + reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims), + dtype=x.dtype, + ) + + +def amin(x, axis=None, keepdims=False): + if any_symbolic_tensors((x,)): + return Amin(axis=axis, keepdims=keepdims).symbolic_call(x) + return backend.execute("amin", x, axis=axis, keepdims=keepdims) + + +class Append(Operation): + def __init__(self, axis=None): + super().__init__() + self.axis = axis + + def call(self, x1, x2): + return backend.execute("append", x1, x2, axis=self.axis) + + def compute_output_spec(self, x1, x2): + x1_shape = x1.shape + x2_shape = x2.shape + if self.axis is None: + if None in x1_shape or None in x2_shape: + output_shape = [None] + else: + output_shape = [int(np.prod(x1_shape) + np.prod(x2_shape))] + return KerasTensor(output_shape, dtype=x1.dtype) + + if not shape_equal(x1_shape, x2_shape, [self.axis]): + raise ValueError( + "`append` requires inputs to have the same shape except the " + f"`axis={self.axis}`, but received shape {x1_shape} and " + f"{x2_shape}." + ) + + output_shape = list(x1_shape) + output_shape[self.axis] = x1_shape[self.axis] + x2_shape[self.axis] + return KerasTensor(output_shape, dtype=x1.dtype) + + +def append( + x1, + x2, + axis=None, +): + if any_symbolic_tensors((x1, x2)): + return Append(axis=axis).symbolic_call(x1, x2) + return backend.execute("append", x1, x2, axis=axis) + + +class Matmul(Operation): + def call(self, x1, x2): + return backend.execute("matmul", x1, x2) + + def compute_output_spec(self, x1, x2): + x1_shape = x1.shape + x2_shape = x2.shape + if len(x1_shape) == 1: + x1_shape = (1, x1_shape[0]) + if len(x2_shape) == 1: + x2_shape = (x2_shape[0], 1) + if ( + x1_shape[-1] is not None + and x2_shape[-2] is not None + and x1_shape[-1] != x2_shape[-2] + ): + raise ValueError( + "Inner dimensions (`x1.shape[-1]` and `x2.shape[-2]`) must be " + f"equal, but received `x1.shape={x1.shape}` and " + f"`x2.shape={x2.shape}`."
+ ) + + leading_shape = broadcast_shapes(x1_shape[:-2], x2_shape[:-2]) + last_2_dims_shape = [x1_shape[-2], x2_shape[-1]] + output_shape = leading_shape + last_2_dims_shape + if len(x1.shape) == 1: + del output_shape[-2] + if len(x2.shape) == 1: + del output_shape[-1] + return KerasTensor(output_shape, dtype=x1.dtype) + + +def matmul(x1, x2): + if any_symbolic_tensors((x1, x2)): + return Matmul().symbolic_call(x1, x2) + return backend.execute("matmul", x1, x2) + + +class Subtract(Operation): + def call(self, x1, x2): + return backend.execute("subtract", x1, x2) + + def compute_output_spec(self, x1, x2): + output_shape = broadcast_shapes(x1.shape, x2.shape) + return KerasTensor(output_shape, dtype=x1.dtype) + + +def subtract(x1, x2): + if any_symbolic_tensors((x1, x2)): + return Subtract().symbolic_call(x1, x2) + return backend.execute("subtract", x1, x2) + + +class Multiply(Operation): + def call(self, x1, x2): + return backend.execute("multiply", x1, x2) + + def compute_output_spec(self, x1, x2): + output_shape = broadcast_shapes(x1.shape, x2.shape) + return KerasTensor(output_shape, dtype=x1.dtype) + + +def multiply(x1, x2): + if any_symbolic_tensors((x1, x2)): + return Multiply().symbolic_call(x1, x2) + return backend.execute("multiply", x1, x2) + + +class Divide(Operation): + def call(self, x1, x2): + return backend.execute("divide", x1, x2) + + def compute_output_spec(self, x1, x2): + output_shape = broadcast_shapes(x1.shape, x2.shape) + return KerasTensor(output_shape, dtype=x1.dtype) + + +def divide(x1, x2): + if any_symbolic_tensors((x1, x2)): + return Divide().symbolic_call(x1, x2) + return backend.execute("divide", x1, x2) + + +class TrueDivide(Operation): + def call(self, x1, x2): + return backend.execute("true_divide", x1, x2) + + def compute_output_spec(self, x1, x2): + output_shape = broadcast_shapes(x1.shape, x2.shape) + return KerasTensor(output_shape, dtype=x1.dtype) + + +def true_divide(x1, x2): + if any_symbolic_tensors((x1, x2)): + return TrueDivide().symbolic_call(x1, x2) + return backend.execute("true_divide", x1, x2) + + +class Power(Operation): + def call(self, x1, x2): + return backend.execute("power", x1, x2) + + def compute_output_spec(self, x1, x2): + output_shape = broadcast_shapes(x1.shape, x2.shape) + return KerasTensor(output_shape, dtype=x1.dtype) + + +def power(x1, x2): + if any_symbolic_tensors((x1, x2)): + return Power().symbolic_call(x1, x2) + return backend.execute("power", x1, x2) + + +class Negative(Operation): + def call(self, x): + return backend.execute("negative", x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +def negative(x): + if any_symbolic_tensors((x,)): + return Negative().symbolic_call(x) + return backend.execute("negative", x) + + +class Absolute(Operation): + def call(self, x): + return backend.execute("absolute", x) + + def compute_output_spec(self, x): + return KerasTensor(x.shape, dtype=x.dtype) + + +def absolute(x): + if any_symbolic_tensors((x,)): + return Absolute().symbolic_call(x) + return backend.execute("absolute", x) class Square(Operation): @@ -229,58 +535,70 @@ def square(x): return backend.execute("square", x) -##################### -### Reshaping ops ### -##################### - - -### squeeze ### - - class Squeeze(Operation): def __init__(self, axis=None): + super().__init__() self.axis = axis - def call(self, a): - return backend.execute("squeeze", a, axis=self.axis) + def call(self, x): + return backend.execute("squeeze", x, axis=self.axis) - def compute_output_spec(self, a): 
- return compute_np_output_spec("squeeze", a, axis=self.axis) + def compute_output_spec(self, x, axis=None): + input_shape = list(x.shape) + if axis is None: + output_shape = list(filter((1).__ne__, input_shape)) + return KerasTensor(output_shape) + else: + if input_shape[axis] != 1: + raise ValueError( + f"Cannot squeeze axis {axis}, because the dimension is not " + "1." + ) + del input_shape[axis] + return KerasTensor(input_shape, dtype=x.dtype) -def squeeze(a, axis=None): - if any_symbolic_tensors((a,)): - return Squeeze().symbolic_call(a, axis=axis) - return backend.execute("squeeze", a, axis=axis) - - -### transpose ### +def squeeze(x, axis=None): + if any_symbolic_tensors((x,)): + return Squeeze().symbolic_call(x, axis=axis) + return backend.execute("squeeze", x, axis=axis) class Transpose(Operation): def __init__(self, axes=None): + super().__init__() self.axes = axes - def call(self, a): - return backend.execute("transpose", a, axes=self.axes) + def call(self, x): + return backend.execute("transpose", x, axes=self.axes) - def compute_output_spec(self, a): - return compute_np_output_spec("transpose", a, axes=self.axes) + def compute_output_spec(self, x): + x_shape = x.shape + if self.axes is None: + return KerasTensor(x_shape[::-1]) + + if len(self.axes) != len(x_shape): + raise ValueError( + "axis must be a list of the same length as the input shape, " + f"expected {len(x_shape)}, but received {len(self.axes)}." + ) + output_shape = [] + for ax in self.axes: + output_shape.append(x_shape[ax]) + return KerasTensor(output_shape, dtype=x.dtype) -def transpose(a, axes=None): - if any_symbolic_tensors((a,)): - return Transpose().symbolic_call(a, axes=axes) - return backend.execute("transpose", a, axes=axes) - - -##################### -### Reduction ops ### -##################### +def transpose(x, axes=None): + if any_symbolic_tensors((x,)): + return Transpose(axes=axes).symbolic_call(x) + return backend.execute("transpose", x, axes=axes) class Mean(Operation): def __init__(self, axis=None, keepdims=False): + super().__init__() + if isinstance(axis, int): + axis = [axis] self.axis = axis self.keepdims = keepdims @@ -290,8 +608,9 @@ class Mean(Operation): ) def compute_output_spec(self, x): - return compute_np_output_spec( - "mean", x, axis=self.axis, keepdims=self.keepdims + return KerasTensor( + reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims), + dtype=x.dtype, ) @@ -303,6 +622,9 @@ def mean(x, axis=None, keepdims=False): class Var(Operation): def __init__(self, axis=None, keepdims=False): + super().__init__() + if isinstance(axis, int): + axis = [axis] self.axis = axis self.keepdims = keepdims @@ -310,8 +632,9 @@ class Var(Operation): return backend.execute("var", x, axis=self.axis, keepdims=self.keepdims) def compute_output_spec(self, x): - return compute_np_output_spec( - "var", x, axis=self.axis, keepdims=self.keepdims + return KerasTensor( + reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims), + dtype=x.dtype, ) @@ -323,6 +646,9 @@ def var(x, axis=None, keepdims=False): class Sum(Operation): def __init__(self, axis=None, keepdims=False): + super().__init__() + if isinstance(axis, int): + axis = [axis] self.axis = axis self.keepdims = keepdims @@ -330,8 +656,9 @@ class Sum(Operation): return backend.execute("sum", x, axis=self.axis, keepdims=self.keepdims) def compute_output_spec(self, x): - return compute_np_output_spec( - "sum", x, axis=self.axis, keepdims=self.keepdims + return KerasTensor( + reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims), + 
dtype=x.dtype, ) @@ -341,14 +668,6 @@ def sum(x, axis=None, keepdims=False): return backend.execute("sum", x, axis=axis, keepdims=keepdims) -########################## -### Array creation ops ### -########################## - - -### zeros ### - - class Zeros(Operation): def call(self, shape, dtype="float32"): return backend.execute("zeros", shape, dtype) @@ -361,9 +680,6 @@ def zeros(shape, dtype="float32"): return backend.execute("zeros", shape, dtype) -### ones ### - - class Ones(Operation): def call(self, shape, dtype="float32"): return backend.execute("ones", shape, dtype) @@ -376,9 +692,6 @@ def ones(shape, dtype="float32"): return backend.execute("ones", shape, dtype) -### eye ### - - class Eye(Operation): def call(self, N, M=None, k=0, dtype="float32"): return backend.execute("eye", N, M=M, k=k, dtype=dtype) diff --git a/keras_core/operations/numpy_test.py b/keras_core/operations/numpy_test.py new file mode 100644 index 000000000..66284f6f8 --- /dev/null +++ b/keras_core/operations/numpy_test.py @@ -0,0 +1,594 @@ +import numpy as np + +from keras_core import backend +from keras_core import testing +from keras_core.backend.keras_tensor import KerasTensor +from keras_core.operations import numpy as knp +from keras_core.operations import operation + + +class NumpyTwoInputOpsShapeTest(testing.TestCase): + def test_add(self): + x = KerasTensor((2, 3)) + y = KerasTensor((2, 3)) + self.assertEqual(knp.add(x, y).shape, (2, 3)) + + x = KerasTensor((None, 3)) + y = KerasTensor((2, None)) + self.assertEqual(knp.add(x, y).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor([2, 3]) + y = KerasTensor([2, 3, 4]) + knp.add(x, y) + + def test_subtract(self): + x = KerasTensor([2, 3]) + y = KerasTensor([2, 3]) + self.assertEqual(knp.subtract(x, y).shape, (2, 3)) + + x = KerasTensor([None, 3]) + y = KerasTensor([2, None]) + self.assertEqual(knp.subtract(x, y).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor([2, 3]) + y = KerasTensor([2, 3, 4]) + knp.subtract(x, y) + + def test_multiply(self): + x = KerasTensor([2, 3]) + y = KerasTensor([2, 3]) + self.assertEqual(knp.multiply(x, y).shape, (2, 3)) + + x = KerasTensor([None, 3]) + y = KerasTensor([2, None]) + self.assertEqual(knp.multiply(x, y).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor([2, 3]) + y = KerasTensor([2, 3, 4]) + knp.multiply(x, y) + + def test_matmul(self): + x = KerasTensor([2, 3]) + y = KerasTensor([3, 2]) + self.assertEqual(knp.matmul(x, y).shape, (2, 2)) + + x = KerasTensor([None, 3, 4]) + y = KerasTensor([3, None, 4, 5]) + self.assertEqual(knp.matmul(x, y).shape, (3, None, 3, 5)) + + with self.assertRaises(ValueError): + x = KerasTensor([3, 4]) + y = KerasTensor([2, 3, 4]) + knp.matmul(x, y) + + def test_power(self): + x = KerasTensor([2, 3]) + y = KerasTensor([2, 3]) + self.assertEqual(knp.power(x, y).shape, (2, 3)) + + x = KerasTensor([None, 3]) + y = KerasTensor([2, None]) + self.assertEqual(knp.power(x, y).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor([2, 3]) + y = KerasTensor([2, 3, 4]) + knp.power(x, y) + + def test_divide(self): + x = KerasTensor([2, 3]) + y = KerasTensor([2, 3]) + self.assertEqual(knp.divide(x, y).shape, (2, 3)) + + x = KerasTensor([None, 3]) + y = KerasTensor([2, None]) + self.assertEqual(knp.divide(x, y).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor([2, 3]) + y = KerasTensor([2, 3, 4]) + knp.divide(x, y) + + def test_true_divide(self): + x = KerasTensor([2, 3]) + y = KerasTensor([2, 
3]) + self.assertEqual(knp.true_divide(x, y).shape, (2, 3)) + + x = KerasTensor([None, 3]) + y = KerasTensor([2, None]) + self.assertEqual(knp.true_divide(x, y).shape, (2, 3)) + + with self.assertRaises(ValueError): + x = KerasTensor([2, 3]) + y = KerasTensor([2, 3, 4]) + knp.true_divide(x, y) + + def test_append(self): + x = KerasTensor([2, 3]) + y = KerasTensor([2, 3]) + self.assertEqual(knp.append(x, y).shape, (12,)) + + x = KerasTensor([2, 3]) + y = KerasTensor([2, 3]) + self.assertEqual(knp.append(x, y, axis=0).shape, (4, 3)) + + x = KerasTensor([None, 3]) + y = KerasTensor([2, None]) + self.assertEqual(knp.append(x, y).shape, (None,)) + + with self.assertRaises(ValueError): + x = KerasTensor([2, 3]) + y = KerasTensor([2, 3, 4]) + knp.append(x, y, axis=2) + + +class NumpyOneInputOpsShapeTest(testing.TestCase): + def test_mean(self): + x = KerasTensor([2, 3]) + self.assertEqual(knp.mean(x).shape, ()) + + x = KerasTensor([None, 3]) + self.assertEqual(knp.mean(x).shape, ()) + + x = KerasTensor([None, 3, 3]) + self.assertEqual(knp.mean(x, axis=1).shape, (None, 3)) + self.assertEqual(knp.mean(x, axis=1, keepdims=True).shape, (None, 1, 3)) + + def test_all(self): + x = KerasTensor([2, 3]) + self.assertEqual(knp.all(x).shape, ()) + + x = KerasTensor([None, 3]) + self.assertEqual(knp.all(x).shape, ()) + + x = KerasTensor([None, 3, 3]) + self.assertEqual(knp.all(x, axis=1).shape, (None, 3)) + self.assertEqual(knp.all(x, axis=1, keepdims=True).shape, (None, 1, 3)) + + def test_var(self): + x = KerasTensor([2, 3]) + self.assertEqual(knp.var(x).shape, ()) + + x = KerasTensor([None, 3]) + self.assertEqual(knp.var(x).shape, ()) + + x = KerasTensor([None, 3, 3]) + self.assertEqual(knp.var(x, axis=1).shape, (None, 3)) + self.assertEqual(knp.var(x, axis=1, keepdims=True).shape, (None, 1, 3)) + + def test_sum(self): + x = KerasTensor([2, 3]) + self.assertEqual(knp.sum(x).shape, ()) + + x = KerasTensor([None, 3]) + self.assertEqual(knp.sum(x).shape, ()) + + x = KerasTensor([None, 3, 3]) + self.assertEqual(knp.sum(x, axis=1).shape, (None, 3)) + self.assertEqual(knp.sum(x, axis=1, keepdims=True).shape, (None, 1, 3)) + + def test_amax(self): + x = KerasTensor([2, 3]) + self.assertEqual(knp.amax(x).shape, ()) + + x = KerasTensor([None, 3]) + self.assertEqual(knp.amax(x).shape, ()) + + x = KerasTensor([None, 3, 3]) + self.assertEqual(knp.amax(x, axis=1).shape, (None, 3)) + self.assertEqual(knp.amax(x, axis=1, keepdims=True).shape, (None, 1, 3)) + + def test_amin(self): + x = KerasTensor([2, 3]) + self.assertEqual(knp.amin(x).shape, ()) + + x = KerasTensor([None, 3]) + self.assertEqual(knp.amin(x).shape, ()) + + x = KerasTensor([None, 3, 3]) + self.assertEqual(knp.amin(x, axis=1).shape, (None, 3)) + self.assertEqual(knp.amin(x, axis=1, keepdims=True).shape, (None, 1, 3)) + + def test_square(self): + x = KerasTensor([2, 3]) + self.assertEqual(knp.square(x).shape, (2, 3)) + + x = KerasTensor([None, 3]) + self.assertEqual(knp.square(x).shape, (None, 3)) + + def test_negative(self): + x = KerasTensor([2, 3]) + self.assertEqual(knp.negative(x).shape, (2, 3)) + + x = KerasTensor([None, 3]) + self.assertEqual(knp.negative(x).shape, (None, 3)) + + def test_abs(self): + x = KerasTensor([2, 3]) + self.assertEqual(knp.abs(x).shape, (2, 3)) + + x = KerasTensor([None, 3]) + self.assertEqual(knp.abs(x).shape, (None, 3)) + + def test_absolute(self): + x = KerasTensor([2, 3]) + self.assertEqual(knp.absolute(x).shape, (2, 3)) + + x = KerasTensor([None, 3]) + self.assertEqual(knp.absolute(x).shape, (None, 3)) + + def 
test_squeeze(self): + x = KerasTensor([2, 3]) + self.assertEqual(knp.squeeze(x).shape, (2, 3)) + + x = KerasTensor([None, 1]) + self.assertEqual(knp.squeeze(x).shape, (None,)) + self.assertEqual(knp.squeeze(x, axis=1).shape, (None,)) + + with self.assertRaises(ValueError): + x = KerasTensor([None, 1]) + knp.squeeze(x, axis=0) + + def test_transpose(self): + x = KerasTensor([2, 3]) + self.assertEqual(knp.transpose(x).shape, (3, 2)) + + x = KerasTensor([None, 3]) + self.assertEqual(knp.transpose(x).shape, (3, None)) + + x = KerasTensor([None, 3, 3]) + self.assertEqual(knp.transpose(x, (2, 0, 1)).shape, (3, None, 3)) + + +class NumpyTwoInputOpsCorrectnessTest(testing.TestCase): + def test_add(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + z = np.array([[[1, 2, 3], [3, 2, 1]]]) + np.testing.assert_array_equal(np.array(knp.add(x, y)), np.add(x, y)) + np.testing.assert_array_equal(np.array(knp.add(x, z)), np.add(x, z)) + + np.testing.assert_array_equal(np.array(knp.Add()(x, y)), np.add(x, y)) + np.testing.assert_array_equal(np.array(knp.Add()(x, z)), np.add(x, z)) + + def test_subtract(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + z = np.array([[[1, 2, 3], [3, 2, 1]]]) + np.testing.assert_array_equal( + np.array(knp.subtract(x, y)), np.subtract(x, y) + ) + np.testing.assert_array_equal( + np.array(knp.subtract(x, z)), np.subtract(x, z) + ) + + np.testing.assert_array_equal( + np.array(knp.Subtract()(x, y)), np.subtract(x, y) + ) + np.testing.assert_array_equal( + np.array(knp.Subtract()(x, z)), np.subtract(x, z) + ) + + def test_multiply(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + z = np.array([[[1, 2, 3], [3, 2, 1]]]) + np.testing.assert_array_equal( + np.array(knp.multiply(x, y)), np.multiply(x, y) + ) + np.testing.assert_array_equal( + np.array(knp.multiply(x, z)), np.multiply(x, z) + ) + + np.testing.assert_array_equal( + np.array(knp.Multiply()(x, y)), np.multiply(x, y) + ) + np.testing.assert_array_equal( + np.array(knp.Multiply()(x, z)), np.multiply(x, z) + ) + + def test_matmul(self): + x = np.ones([2, 3, 4, 5]) + y = np.ones([2, 3, 5, 6]) + z = np.ones([5, 6]) + np.testing.assert_array_equal( + np.array(knp.matmul(x, y)), np.matmul(x, y) + ) + np.testing.assert_array_equal( + np.array(knp.matmul(x, z)), np.matmul(x, z) + ) + + np.testing.assert_array_equal( + np.array(knp.Matmul()(x, y)), np.matmul(x, y) + ) + np.testing.assert_array_equal( + np.array(knp.Matmul()(x, z)), np.matmul(x, z) + ) + + def test_power(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + z = np.array([[[1, 2, 3], [3, 2, 1]]]) + np.testing.assert_array_equal(np.array(knp.power(x, y)), np.power(x, y)) + np.testing.assert_array_equal(np.array(knp.power(x, z)), np.power(x, z)) + + np.testing.assert_array_equal( + np.array(knp.Power()(x, y)), np.power(x, y) + ) + np.testing.assert_array_equal( + np.array(knp.Power()(x, z)), np.power(x, z) + ) + + def test_divide(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + z = np.array([[[1, 2, 3], [3, 2, 1]]]) + np.testing.assert_array_equal( + np.array(knp.divide(x, y)), np.divide(x, y) + ) + np.testing.assert_array_equal( + np.array(knp.divide(x, z)), np.divide(x, z) + ) + + np.testing.assert_array_equal( + np.array(knp.Divide()(x, y)), np.divide(x, y) + ) + np.testing.assert_array_equal( + np.array(knp.Divide()(x, z)), np.divide(x, z) + ) + + def test_true_divide(self): + x = 
np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + z = np.array([[[1, 2, 3], [3, 2, 1]]]) + np.testing.assert_array_equal( + np.array(knp.true_divide(x, y)), np.true_divide(x, y) + ) + np.testing.assert_array_equal( + np.array(knp.true_divide(x, z)), np.true_divide(x, z) + ) + + np.testing.assert_array_equal( + np.array(knp.TrueDivide()(x, y)), np.true_divide(x, y) + ) + np.testing.assert_array_equal( + np.array(knp.TrueDivide()(x, z)), np.true_divide(x, z) + ) + + def test_append(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + y = np.array([[4, 5, 6], [3, 2, 1]]) + z = np.array([[[1, 2, 3], [3, 2, 1]], [[4, 5, 6], [3, 2, 1]]]) + np.testing.assert_array_equal( + np.array(knp.append(x, y)), np.append(x, y) + ) + np.testing.assert_array_equal( + np.array(knp.append(x, y, axis=1)), np.append(x, y, axis=1) + ) + np.testing.assert_array_equal( + np.array(knp.append(x, z)), np.append(x, z) + ) + + np.testing.assert_array_equal( + np.array(knp.Append()(x, y)), np.append(x, y) + ) + np.testing.assert_array_equal( + np.array(knp.Append(axis=1)(x, y)), np.append(x, y, axis=1) + ) + np.testing.assert_array_equal( + np.array(knp.Append()(x, z)), np.append(x, z) + ) + + +class NumpyOneInputOpsCorrectnessTest(testing.TestCase): + def test_mean(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + np.testing.assert_array_equal(np.array(knp.mean(x)), np.mean(x)) + np.testing.assert_array_equal( + np.array(knp.mean(x, axis=1)), np.mean(x, axis=1) + ) + np.testing.assert_array_equal( + np.array(knp.mean(x, axis=1, keepdims=True)), + np.mean(x, axis=1, keepdims=True), + ) + + np.testing.assert_array_equal(np.array(knp.Mean()(x)), np.mean(x)) + np.testing.assert_array_equal( + np.array(knp.Mean(axis=1)(x)), np.mean(x, axis=1) + ) + np.testing.assert_array_equal( + np.array(knp.Mean(axis=1, keepdims=True)(x)), + np.mean(x, axis=1, keepdims=True), + ) + + def test_all(self): + x = np.array([[True, False, True], [True, True, True]]) + np.testing.assert_array_equal(np.array(knp.all(x)), np.all(x)) + np.testing.assert_array_equal( + np.array(knp.all(x, axis=1)), np.all(x, axis=1) + ) + np.testing.assert_array_equal( + np.array(knp.all(x, axis=1, keepdims=True)), + np.all(x, axis=1, keepdims=True), + ) + + np.testing.assert_array_equal(np.array(knp.All()(x)), np.all(x)) + np.testing.assert_array_equal( + np.array(knp.All(axis=1)(x)), np.all(x, axis=1) + ) + np.testing.assert_array_equal( + np.array(knp.All(axis=1, keepdims=True)(x)), + np.all(x, axis=1, keepdims=True), + ) + + def test_var(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + np.testing.assert_array_equal(np.array(knp.var(x)), np.var(x)) + np.testing.assert_array_equal( + np.array(knp.var(x, axis=1)), np.var(x, axis=1) + ) + np.testing.assert_array_equal( + np.array(knp.var(x, axis=1, keepdims=True)), + np.var(x, axis=1, keepdims=True), + ) + + np.testing.assert_array_equal(np.array(knp.Var()(x)), np.var(x)) + np.testing.assert_array_equal( + np.array(knp.Var(axis=1)(x)), np.var(x, axis=1) + ) + np.testing.assert_array_equal( + np.array(knp.Var(axis=1, keepdims=True)(x)), + np.var(x, axis=1, keepdims=True), + ) + + def test_sum(self): + x = np.array([[1, 2, 3], [3, 2, 1]]) + np.testing.assert_array_equal(np.array(knp.sum(x)), np.sum(x)) + np.testing.assert_array_equal( + np.array(knp.sum(x, axis=1)), np.sum(x, axis=1) + ) + np.testing.assert_array_equal( + np.array(knp.sum(x, axis=1, keepdims=True)), + np.sum(x, axis=1, keepdims=True), + ) + + np.testing.assert_array_equal(np.array(knp.Sum()(x)), np.sum(x)) + 
np.testing.assert_array_equal(
+            np.array(knp.Sum(axis=1)(x)), np.sum(x, axis=1)
+        )
+        np.testing.assert_array_equal(
+            np.array(knp.Sum(axis=1, keepdims=True)(x)),
+            np.sum(x, axis=1, keepdims=True),
+        )
+
+    def test_amax(self):
+        x = np.array([[1, 2, 3], [3, 2, 1]])
+        np.testing.assert_array_equal(np.array(knp.amax(x)), np.amax(x))
+        np.testing.assert_array_equal(
+            np.array(knp.amax(x, axis=1)), np.amax(x, axis=1)
+        )
+        np.testing.assert_array_equal(
+            np.array(knp.amax(x, axis=1, keepdims=True)),
+            np.amax(x, axis=1, keepdims=True),
+        )
+
+        np.testing.assert_array_equal(np.array(knp.Amax()(x)), np.amax(x))
+        np.testing.assert_array_equal(
+            np.array(knp.Amax(axis=1)(x)), np.amax(x, axis=1)
+        )
+        np.testing.assert_array_equal(
+            np.array(knp.Amax(axis=1, keepdims=True)(x)),
+            np.amax(x, axis=1, keepdims=True),
+        )
+
+    def test_amin(self):
+        x = np.array([[1, 2, 3], [3, 2, 1]])
+        np.testing.assert_array_equal(np.array(knp.amin(x)), np.amin(x))
+        np.testing.assert_array_equal(
+            np.array(knp.amin(x, axis=1)), np.amin(x, axis=1)
+        )
+        np.testing.assert_array_equal(
+            np.array(knp.amin(x, axis=1, keepdims=True)),
+            np.amin(x, axis=1, keepdims=True),
+        )
+
+        np.testing.assert_array_equal(np.array(knp.Amin()(x)), np.amin(x))
+        np.testing.assert_array_equal(
+            np.array(knp.Amin(axis=1)(x)), np.amin(x, axis=1)
+        )
+        np.testing.assert_array_equal(
+            np.array(knp.Amin(axis=1, keepdims=True)(x)),
+            np.amin(x, axis=1, keepdims=True),
+        )
+
+    def test_square(self):
+        x = np.array([[1, 2, 3], [3, 2, 1]])
+        np.testing.assert_array_equal(np.array(knp.square(x)), np.square(x))
+
+        np.testing.assert_array_equal(np.array(knp.Square()(x)), np.square(x))
+
+    def test_negative(self):
+        x = np.array([[1, 2, 3], [3, 2, 1]])
+        np.testing.assert_array_equal(np.array(knp.negative(x)), np.negative(x))
+
+        np.testing.assert_array_equal(
+            np.array(knp.Negative()(x)), np.negative(x)
+        )
+
+    def test_abs(self):
+        x = np.array([[1, 2, 3], [3, 2, 1]])
+        np.testing.assert_array_equal(np.array(knp.abs(x)), np.abs(x))
+
+        np.testing.assert_array_equal(np.array(knp.Abs()(x)), np.abs(x))
+
+    def test_absolute(self):
+        x = np.array([[1, 2, 3], [3, 2, 1]])
+        np.testing.assert_array_equal(np.array(knp.absolute(x)), np.absolute(x))
+
+        np.testing.assert_array_equal(
+            np.array(knp.Absolute()(x)), np.absolute(x)
+        )
+
+    def test_squeeze(self):
+        x = np.ones([1, 2, 3, 4, 5])
+        np.testing.assert_array_equal(np.array(knp.squeeze(x)), np.squeeze(x))
+        np.testing.assert_array_equal(
+            np.array(knp.squeeze(x, axis=0)), np.squeeze(x, axis=0)
+        )
+
+        np.testing.assert_array_equal(np.array(knp.Squeeze()(x)), np.squeeze(x))
+        np.testing.assert_array_equal(
+            np.array(knp.Squeeze(axis=0)(x)), np.squeeze(x, axis=0)
+        )
+
+    def test_transpose(self):
+        x = np.ones([1, 2, 3, 4, 5])
+        np.testing.assert_array_equal(
+            np.array(knp.transpose(x)), np.transpose(x)
+        )
+        np.testing.assert_array_equal(
+            np.array(knp.transpose(x, axes=(1, 0, 3, 2, 4))),
+            np.transpose(x, axes=(1, 0, 3, 2, 4)),
+        )
+
+        np.testing.assert_array_equal(
+            np.array(knp.Transpose()(x)), np.transpose(x)
+        )
+        np.testing.assert_array_equal(
+            np.array(knp.Transpose(axes=(1, 0, 3, 2, 4))(x)),
+            np.transpose(x, axes=(1, 0, 3, 2, 4)),
+        )
+
+
+class NumpyArrayCreateOpsCorrectnessTest(testing.TestCase):
+    def test_ones(self):
+        np.testing.assert_array_equal(
+            np.array(knp.ones([2, 3])), np.ones([2, 3])
+        )
+        np.testing.assert_array_equal(
+            np.array(knp.Ones()([2, 3])), np.ones([2, 3])
+        )
+
+    def test_zeros(self):
+        np.testing.assert_array_equal(
+            np.array(knp.zeros([2, 3])), np.zeros([2, 3])
+        )
+        np.testing.assert_array_equal(
+            np.array(knp.Zeros()([2, 3])), np.zeros([2, 3])
+        )
+
+    def test_eye(self):
+        np.testing.assert_array_equal(np.array(knp.eye(3)), np.eye(3))
+        np.testing.assert_array_equal(np.array(knp.eye(3, 4)), np.eye(3, 4))
+        np.testing.assert_array_equal(
+            np.array(knp.eye(3, 4, 1)), np.eye(3, 4, 1)
+        )
+
+        np.testing.assert_array_equal(np.array(knp.Eye()(3)), np.eye(3))
+        np.testing.assert_array_equal(np.array(knp.Eye()(3, 4)), np.eye(3, 4))
+        np.testing.assert_array_equal(
+            np.array(knp.Eye()(3, 4, 1)), np.eye(3, 4, 1)
+        )
diff --git a/keras_core/operations/operation.py b/keras_core/operations/operation.py
index a418db37b..84a706be4 100644
--- a/keras_core/operations/operation.py
+++ b/keras_core/operations/operation.py
@@ -1,7 +1,7 @@
+from keras_core import backend
 from keras_core.backend.keras_tensor import any_symbolic_tensors
 from keras_core.operations.node import Node
 from keras_core.utils.naming import auto_name
-from keras_core import backend


 class Operation:
diff --git a/keras_core/operations/operation_test.py b/keras_core/operations/operation_test.py
index 79be189b5..ecc9197eb 100644
--- a/keras_core/operations/operation_test.py
+++ b/keras_core/operations/operation_test.py
@@ -1,9 +1,10 @@
-from keras_core.operations import operation
-from keras_core.engine import keras_tensor
-from keras_core.operations import numpy as knp
+import numpy as np
+
 from keras_core import backend
 from keras_core import testing
-import numpy as np
+from keras_core.engine import keras_tensor
+from keras_core.operations import numpy as knp
+from keras_core.operations import operation


 class OpWithMultipleInputs(operation.Operation):
diff --git a/keras_core/operations/random.py b/keras_core/operations/random.py
index 06d85a546..30af95399 100644
--- a/keras_core/operations/random.py
+++ b/keras_core/operations/random.py
@@ -8,7 +8,7 @@ truncated_normal
 dropout
 """

-from keras_core.backend.random import normal
-from keras_core.backend.random import uniform
-from keras_core.backend.random import truncated_normal
 from keras_core.backend.random import dropout
+from keras_core.backend.random import normal
+from keras_core.backend.random import truncated_normal
+from keras_core.backend.random import uniform
diff --git a/keras_core/operations/symbolic_arguments.py b/keras_core/operations/symbolic_arguments.py
index efa35282e..816cf1f95 100644
--- a/keras_core/operations/symbolic_arguments.py
+++ b/keras_core/operations/symbolic_arguments.py
@@ -1,4 +1,5 @@
 from tensorflow import nest
+
 from keras_core.backend import KerasTensor


diff --git a/keras_core/optimizers/optimizer.py b/keras_core/optimizers/optimizer.py
index 2db6e185f..d199f3039 100644
--- a/keras_core/optimizers/optimizer.py
+++ b/keras_core/optimizers/optimizer.py
@@ -1,13 +1,14 @@
-from keras_core import backend
-from keras_core import operations as ops
-from keras_core.utils.tracking import Tracker
-from keras_core import initializers
-from keras_core.optimizers.schedules import learning_rate_schedule
-from keras_core.utils.naming import auto_name
-from keras_core.api_export import keras_core_export
 import re
 import warnings

+from keras_core import backend
+from keras_core import initializers
+from keras_core import operations as ops
+from keras_core.api_export import keras_core_export
+from keras_core.optimizers.schedules import learning_rate_schedule
+from keras_core.utils.naming import auto_name
+from keras_core.utils.tracking import Tracker
+

 @keras_core_export(["keras_core.Optimizer", "keras_core.optimizers.Optimizer"])
 class Optimizer:
diff --git a/keras_core/optimizers/sgd.py b/keras_core/optimizers/sgd.py
index 757f20136..43c3c441b 100644
--- a/keras_core/optimizers/sgd.py
+++ b/keras_core/optimizers/sgd.py
@@ -1,5 +1,5 @@
-from keras_core.optimizers import optimizer
 from keras_core import operations as ops
+from keras_core.optimizers import optimizer


 class SGD(optimizer.Optimizer):
diff --git a/keras_core/regularizers/__init__.py b/keras_core/regularizers/__init__.py
index 5ea79f11c..0dcb1bf11 100644
--- a/keras_core/regularizers/__init__.py
+++ b/keras_core/regularizers/__init__.py
@@ -1,5 +1,5 @@
-from keras_core.regularizers.regularizers import Regularizer
 from keras_core.regularizers.regularizers import L1
-from keras_core.regularizers.regularizers import L2
 from keras_core.regularizers.regularizers import L1L2
+from keras_core.regularizers.regularizers import L2
 from keras_core.regularizers.regularizers import OrthogonalRegularizer
+from keras_core.regularizers.regularizers import Regularizer
diff --git a/keras_core/regularizers/regularizers.py b/keras_core/regularizers/regularizers.py
index 23f308f58..03cd8847d 100644
--- a/keras_core/regularizers/regularizers.py
+++ b/keras_core/regularizers/regularizers.py
@@ -1,4 +1,5 @@
 import math
+
 from keras_core import operations as ops
 from keras_core.api_export import keras_core_export

diff --git a/keras_core/regularizers/regularizers_test.py b/keras_core/regularizers/regularizers_test.py
index 22bf5fd77..705b508ef 100644
--- a/keras_core/regularizers/regularizers_test.py
+++ b/keras_core/regularizers/regularizers_test.py
@@ -1,9 +1,10 @@
-from keras_core import testing
+import numpy as np
+
+from keras_core import backend
 from keras_core import initializers
 from keras_core import operations as ops
-from keras_core import backend
 from keras_core import regularizers
-import numpy as np
+from keras_core import testing


 # TODO: serialization tests
diff --git a/keras_core/testing/test_case.py b/keras_core/testing/test_case.py
index e4cc81687..73eaebbeb 100644
--- a/keras_core/testing/test_case.py
+++ b/keras_core/testing/test_case.py
@@ -1,6 +1,7 @@
-import numpy as np
 import unittest

+import numpy as np
+

 class TestCase(unittest.TestCase):
     def assertAllClose(self, x1, x2, atol=1e-7, rtol=1e-7):
diff --git a/keras_core/utils/io_utils.py b/keras_core/utils/io_utils.py
index 817683832..3ee281154 100644
--- a/keras_core/utils/io_utils.py
+++ b/keras_core/utils/io_utils.py
@@ -1,8 +1,10 @@
-from keras_core.api_export import keras_core_export
-import threading
 import sys
+import threading
+
 from absl import logging

+from keras_core.api_export import keras_core_export
+
 INTERACTIVE_LOGGING = threading.local()
 INTERACTIVE_LOGGING.enable = True

diff --git a/keras_core/utils/summary_utils.py b/keras_core/utils/summary_utils.py
index 676c542d6..3aec4245a 100644
--- a/keras_core/utils/summary_utils.py
+++ b/keras_core/utils/summary_utils.py
@@ -79,7 +79,8 @@ def print_summary(
             matches `layer_range[1]`. By default (`None`) all layers in the
             model are included in the summary.
""" - from keras_core.models import Sequential, Functional + from keras_core.models import Functional + from keras_core.models import Sequential if print_fn is None: print_fn = io_utils.print_msg diff --git a/requirements.txt b/requirements.txt index 5cca9dd6a..2859033f5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,7 @@ -absl tensorflow -jax -namex \ No newline at end of file +jax[cpu] +namex +black>=22 +flake8 +isort +pytest \ No newline at end of file diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 000000000..b62019718 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,53 @@ +[tool:pytest] +filterwarnings = + error + ignore::DeprecationWarning + ignore::ImportWarning + ignore::RuntimeWarning + ignore::PendingDeprecationWarning + ignore::FutureWarning + ignore::UserWarning + # Ignore a spurious warning on tf-nightly related to save model changes. + ignore:Custom mask layers require a config + +addopts=-vv + +# Do not run tests in the `build` folders +norecursedirs = build + +[isort] +known_first_party = keras_core,tests +default_section = THIRDPARTY +line_length = 80 +profile = black + +[coverage:report] +exclude_lines = + pragma: no cover + @abstract + raise NotImplementedError +omit = *_test.py + +[flake8] + +ignore = + # Conflicts with black + E203 + # defaults flake8 ignores + E121,E123,E126,E226,E24,E704,W503,W504 + # Function name should be lowercase + N802 + # lowercase ... imported as non lowercase + # Useful to ignore for "import keras.backend as K" + N812 + # do not use bare 'except' + E722 + +exclude = + *_pb2.py + *_pb2_grpc.py + +#imported but unused in __init__.py, that's ok. +per-file-ignores = **/__init__.py:F401 + +max-line-length = 80 \ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..dd9892940 --- /dev/null +++ b/setup.py @@ -0,0 +1,42 @@ +"""Setup script.""" + +import pathlib + +from setuptools import find_packages +from setuptools import setup + +HERE = pathlib.Path(__file__).parent + +setup( + name="keras-core", + description="Multi-backend Keras.", + long_description_content_type="text/markdown", + version="0.1.0", + url="https://github.com/keras-team/keras-core", + author="Keras team", + author_email="keras@google.com", + license="Apache License 2.0", + install_requires=[ + "absl-py", + "numpy", + "packaging", + ], + # Supported Python versions + python_requires=">=3.8", + classifiers=[ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3 :: Only", + "Operating System :: Unix", + "Operating System :: Microsoft :: Windows", + "Operating System :: MacOS", + "Intended Audience :: Science/Research", + "Topic :: Scientific/Engineering", + "Topic :: Software Development", + ], + packages=find_packages(exclude=("*_test.py",)), +) diff --git a/shell/format.sh b/shell/format.sh new file mode 100644 index 000000000..c1b2926a6 --- /dev/null +++ b/shell/format.sh @@ -0,0 +1,9 @@ +#!/bin/bash -e + +base_dir=$(dirname $(dirname $0)) +targets="${base_dir}/*.py ${base_dir}/keras_core/" + +isort --sp "${base_dir}/setup.cfg" --sl ${targets} +black --line-length 80 ${targets} + +flake8 --config "${base_dir}/setup.cfg" --max-line-length=200 ${targets} diff --git a/tf_integration_test.py b/tf_integration_test.py index ad74b044d..28af5d9a4 100644 --- a/tf_integration_test.py +++ 
b/tf_integration_test.py @@ -1,9 +1,9 @@ from keras_core import backend -from keras_core.layers.layer import Layer -from keras_core.backend import KerasTensor -from keras_core.operations.function import Function from keras_core import initializers +from keras_core.backend import KerasTensor +from keras_core.layers.layer import Layer from keras_core.operations import numpy as knp +from keras_core.operations.function import Function class MiniDense(Layer):