Compatible with Py38 (#351)
This commit is contained in:
parent
f90a500d4a
commit
26a4a6e737
@ -1021,7 +1021,7 @@ class Layer(BackendLayer, Operation):
|
||||
# Case: all input keyword arguments were plain tensors.
|
||||
input_tensors = {
|
||||
# We strip the `_shape` suffix to recover kwarg names.
|
||||
k.removesuffix("_shape"): backend.KerasTensor(shape)
|
||||
utils.removesuffix(k, "_shape"): backend.KerasTensor(shape)
|
||||
for k, shape in shapes_dict.items()
|
||||
}
|
||||
try:
|
||||
@ -1286,7 +1286,7 @@ def check_shapes_signature(target_fn, call_spec, cls):
|
||||
f"Received `{method_name}()` argument "
|
||||
f"`{name}`, which does not end in `_shape`."
|
||||
)
|
||||
expected_call_arg = name.removesuffix("_shape")
|
||||
expected_call_arg = utils.removesuffix(name, "_shape")
|
||||
if expected_call_arg not in call_spec.arguments_dict:
|
||||
raise ValueError(
|
||||
f"{error_preamble} For layer '{cls.__name__}', "
|
||||
|
@ -1,6 +1,7 @@
|
||||
import copy
|
||||
|
||||
from keras_core import operations as ops
|
||||
from keras_core import utils
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.layers.core.wrapper import Wrapper
|
||||
from keras_core.layers.layer import Layer
|
||||
@ -109,16 +110,16 @@ class Bidirectional(Wrapper):
|
||||
# Recreate the forward layer from the original layer config, so that it
|
||||
# will not carry over any state from the layer.
|
||||
config = serialization_lib.serialize_keras_object(layer)
|
||||
config["config"]["name"] = "forward_" + layer.name.removeprefix(
|
||||
"forward_"
|
||||
config["config"]["name"] = "forward_" + utils.removeprefix(
|
||||
layer.name, "forward_"
|
||||
)
|
||||
self.forward_layer = serialization_lib.deserialize_keras_object(config)
|
||||
|
||||
if backward_layer is None:
|
||||
config = serialization_lib.serialize_keras_object(layer)
|
||||
config["config"]["go_backwards"] = True
|
||||
config["config"]["name"] = "backward_" + layer.name.removeprefix(
|
||||
"backward_"
|
||||
config["config"]["name"] = "backward_" + utils.removeprefix(
|
||||
layer.name, "backward_"
|
||||
)
|
||||
self.backward_layer = serialization_lib.deserialize_keras_object(
|
||||
config
|
||||
|
@ -8,6 +8,7 @@ from tensorflow import nest
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core import operations as ops
|
||||
from keras_core import utils
|
||||
from keras_core.models import Model
|
||||
from keras_core.utils import traceback_utils
|
||||
|
||||
@ -336,7 +337,9 @@ def create_keras_tensors(input_shape, dtype):
|
||||
return [keras_tensor.KerasTensor(s, dtype=dtype) for s in input_shape]
|
||||
if isinstance(input_shape, dict):
|
||||
return {
|
||||
k.removesuffix("_shape"): keras_tensor.KerasTensor(v, dtype=dtype)
|
||||
utils.removesuffix(k, "_shape"): keras_tensor.KerasTensor(
|
||||
v, dtype=dtype
|
||||
)
|
||||
for k, v in input_shape.items()
|
||||
}
|
||||
|
||||
@ -370,6 +373,6 @@ def create_eager_tensors(input_shape, dtype):
|
||||
return [create_fn(s, dtype=dtype) for s in input_shape]
|
||||
if isinstance(input_shape, dict):
|
||||
return {
|
||||
k.removesuffix("_shape"): create_fn(v, dtype=dtype)
|
||||
utils.removesuffix(k, "_shape"): create_fn(v, dtype=dtype)
|
||||
for k, v in input_shape.items()
|
||||
}
|
||||
|
@ -18,5 +18,7 @@ from keras_core.utils.numerical_utils import to_categorical
|
||||
from keras_core.utils.progbar import Progbar
|
||||
from keras_core.utils.python_utils import default
|
||||
from keras_core.utils.python_utils import is_default
|
||||
from keras_core.utils.python_utils import removeprefix
|
||||
from keras_core.utils.python_utils import removesuffix
|
||||
from keras_core.utils.rng_utils import set_random_seed
|
||||
from keras_core.utils.sequence_utils import pad_sequences
|
||||
|
@ -122,3 +122,21 @@ def remove_long_seq(maxlen, seq, label):
|
||||
new_seq.append(x)
|
||||
new_label.append(y)
|
||||
return new_seq, new_label
|
||||
|
||||
|
||||
def removeprefix(x, prefix):
|
||||
"""Backport of `removeprefix` from PEP-616 (Python 3.9+)"""
|
||||
|
||||
if len(prefix) > 0 and x.startswith(prefix):
|
||||
return x[len(prefix) :]
|
||||
else:
|
||||
return x
|
||||
|
||||
|
||||
def removesuffix(x, suffix):
|
||||
"""Backport of `removesuffix` from PEP-616 (Python 3.9+)"""
|
||||
|
||||
if len(suffix) > 0 and x.endswith(suffix):
|
||||
return x[: -len(suffix)]
|
||||
else:
|
||||
return x
|
||||
|
@ -10,3 +10,13 @@ class PythonUtilsTest(testing.TestCase):
|
||||
serialized = python_utils.func_dump(my_function)
|
||||
deserialized = python_utils.func_load(serialized)
|
||||
self.assertEqual(deserialized(2, y=3), 5)
|
||||
|
||||
def test_removesuffix(self):
|
||||
x = "model.keras"
|
||||
self.assertEqual(python_utils.removesuffix(x, ".keras"), "model")
|
||||
self.assertEqual(python_utils.removesuffix(x, "model"), x)
|
||||
|
||||
def test_removeprefix(self):
|
||||
x = "model.keras"
|
||||
self.assertEqual(python_utils.removeprefix(x, "model"), ".keras")
|
||||
self.assertEqual(python_utils.removeprefix(x, ".keras"), x)
|
||||
|
Loading…
Reference in New Issue
Block a user