Compatible with Py38 (#351)

This commit is contained in:
Ramesh Sampath 2023-06-14 23:10:12 +05:30 committed by Francois Chollet
parent f90a500d4a
commit 26a4a6e737
6 changed files with 42 additions and 8 deletions

@ -1021,7 +1021,7 @@ class Layer(BackendLayer, Operation):
# Case: all input keyword arguments were plain tensors.
input_tensors = {
# We strip the `_shape` suffix to recover kwarg names.
k.removesuffix("_shape"): backend.KerasTensor(shape)
utils.removesuffix(k, "_shape"): backend.KerasTensor(shape)
for k, shape in shapes_dict.items()
}
try:
@ -1286,7 +1286,7 @@ def check_shapes_signature(target_fn, call_spec, cls):
f"Received `{method_name}()` argument "
f"`{name}`, which does not end in `_shape`."
)
expected_call_arg = name.removesuffix("_shape")
expected_call_arg = utils.removesuffix(name, "_shape")
if expected_call_arg not in call_spec.arguments_dict:
raise ValueError(
f"{error_preamble} For layer '{cls.__name__}', "

@ -1,6 +1,7 @@
import copy
from keras_core import operations as ops
from keras_core import utils
from keras_core.api_export import keras_core_export
from keras_core.layers.core.wrapper import Wrapper
from keras_core.layers.layer import Layer
@ -109,16 +110,16 @@ class Bidirectional(Wrapper):
# Recreate the forward layer from the original layer config, so that it
# will not carry over any state from the layer.
config = serialization_lib.serialize_keras_object(layer)
config["config"]["name"] = "forward_" + layer.name.removeprefix(
"forward_"
config["config"]["name"] = "forward_" + utils.removeprefix(
layer.name, "forward_"
)
self.forward_layer = serialization_lib.deserialize_keras_object(config)
if backward_layer is None:
config = serialization_lib.serialize_keras_object(layer)
config["config"]["go_backwards"] = True
config["config"]["name"] = "backward_" + layer.name.removeprefix(
"backward_"
config["config"]["name"] = "backward_" + utils.removeprefix(
layer.name, "backward_"
)
self.backward_layer = serialization_lib.deserialize_keras_object(
config

@ -8,6 +8,7 @@ from tensorflow import nest
from keras_core import backend
from keras_core import operations as ops
from keras_core import utils
from keras_core.models import Model
from keras_core.utils import traceback_utils
@ -336,7 +337,9 @@ def create_keras_tensors(input_shape, dtype):
return [keras_tensor.KerasTensor(s, dtype=dtype) for s in input_shape]
if isinstance(input_shape, dict):
return {
k.removesuffix("_shape"): keras_tensor.KerasTensor(v, dtype=dtype)
utils.removesuffix(k, "_shape"): keras_tensor.KerasTensor(
v, dtype=dtype
)
for k, v in input_shape.items()
}
@ -370,6 +373,6 @@ def create_eager_tensors(input_shape, dtype):
return [create_fn(s, dtype=dtype) for s in input_shape]
if isinstance(input_shape, dict):
return {
k.removesuffix("_shape"): create_fn(v, dtype=dtype)
utils.removesuffix(k, "_shape"): create_fn(v, dtype=dtype)
for k, v in input_shape.items()
}

@ -18,5 +18,7 @@ from keras_core.utils.numerical_utils import to_categorical
from keras_core.utils.progbar import Progbar
from keras_core.utils.python_utils import default
from keras_core.utils.python_utils import is_default
from keras_core.utils.python_utils import removeprefix
from keras_core.utils.python_utils import removesuffix
from keras_core.utils.rng_utils import set_random_seed
from keras_core.utils.sequence_utils import pad_sequences

@ -122,3 +122,21 @@ def remove_long_seq(maxlen, seq, label):
new_seq.append(x)
new_label.append(y)
return new_seq, new_label
def removeprefix(x, prefix):
"""Backport of `removeprefix` from PEP-616 (Python 3.9+)"""
if len(prefix) > 0 and x.startswith(prefix):
return x[len(prefix) :]
else:
return x
def removesuffix(x, suffix):
"""Backport of `removesuffix` from PEP-616 (Python 3.9+)"""
if len(suffix) > 0 and x.endswith(suffix):
return x[: -len(suffix)]
else:
return x

@ -10,3 +10,13 @@ class PythonUtilsTest(testing.TestCase):
serialized = python_utils.func_dump(my_function)
deserialized = python_utils.func_load(serialized)
self.assertEqual(deserialized(2, y=3), 5)
def test_removesuffix(self):
x = "model.keras"
self.assertEqual(python_utils.removesuffix(x, ".keras"), "model")
self.assertEqual(python_utils.removesuffix(x, "model"), x)
def test_removeprefix(self):
x = "model.keras"
self.assertEqual(python_utils.removeprefix(x, "model"), ".keras")
self.assertEqual(python_utils.removeprefix(x, ".keras"), x)