122 lines
4.6 KiB
Python
122 lines
4.6 KiB
Python
import keras_core
|
|
from keras_core import testing
|
|
from keras_core.saving import object_registration
|
|
from keras_core.saving import serialization_lib
|
|
|
|
|
|
class TestObjectRegistration(testing.TestCase):
    """Tests for custom-object scoping and serializable registration."""

    def test_custom_object_scope(self):
        """Names supplied to `custom_object_scope` resolve via `get`.

        The lookups are asserted on the calling thread and again from a
        second thread that opens its own scope.
        """

        def custom_fn():
            pass

        class CustomClass:
            pass

        def scope_objects():
            # Build a fresh mapping for every scope rather than sharing one
            # dict between the two `with` statements.
            return {"CustomClass": CustomClass, "custom_fn": custom_fn}

        def assert_lookups_resolve():
            # Shared by both threads so the two checks cannot drift apart.
            actual_custom_fn = keras_core.activations.get("custom_fn")
            self.assertEqual(actual_custom_fn, custom_fn)
            # Compared through `__class__`: the lookup is expected to hand
            # back an instance of the custom class, not the class itself.
            actual_custom_class = keras_core.regularizers.get("CustomClass")
            self.assertEqual(actual_custom_class.__class__, CustomClass)

        def check_get_in_thread():
            with object_registration.custom_object_scope(scope_objects()):
                assert_lookups_resolve()

        with object_registration.custom_object_scope(scope_objects()):
            assert_lookups_resolve()
            # NOTE(review): relies on `testing.TestCase` providing
            # `checkedThread` (presumably re-raising failures from the
            # worker on `join`) -- confirm against the base class.
            checked_thread = self.checkedThread(check_get_in_thread)
            checked_thread.start()
            checked_thread.join()

    def test_serialize_custom_class_with_default_name(self):
        """A class registered without arguments is named `Custom>TestClass`.

        Also checks that an instance survives a serialize/deserialize
        round trip through `serialization_lib`.
        """

        @object_registration.register_keras_serializable()
        class TestClass:
            def __init__(self, value):
                self._value = value

            def get_config(self):
                return {"value": self._value}

            @classmethod
            def from_config(cls, config):
                return cls(**config)

        serialized_name = "Custom>TestClass"
        inst = TestClass(value=10)
        class_name = object_registration.GLOBAL_CUSTOM_NAMES[TestClass]
        self.assertEqual(serialized_name, class_name)

        config = serialization_lib.serialize_keras_object(inst)
        # The config carries the bare class name, not the registered name.
        self.assertEqual("TestClass", config["class_name"])
        new_inst = serialization_lib.deserialize_keras_object(config)
        self.assertIsNot(inst, new_inst)
        self.assertIsInstance(new_inst, TestClass)
        self.assertEqual(10, new_inst._value)

    def test_serialize_custom_class_with_custom_name(self):
        """An explicit package and name yield `TestPackage>CustomName`.

        Checks the registry mapping in both directions (object to name and
        name to object) and a round trip through the `keras_core.saving`
        entry points.
        """

        @object_registration.register_keras_serializable(
            "TestPackage", "CustomName"
        )
        class OtherTestClass:
            def __init__(self, val):
                self._val = val

            def get_config(self):
                return {"val": self._val}

            @classmethod
            def from_config(cls, config):
                return cls(**config)

        serialized_name = "TestPackage>CustomName"
        inst = OtherTestClass(val=5)
        class_name = object_registration.GLOBAL_CUSTOM_NAMES[OtherTestClass]
        self.assertEqual(serialized_name, class_name)
        fn_class_name = object_registration.get_registered_name(OtherTestClass)
        self.assertEqual(fn_class_name, class_name)

        cls = object_registration.get_registered_object(fn_class_name)
        self.assertEqual(OtherTestClass, cls)

        config = keras_core.saving.serialize_keras_object(inst)
        # `class_name` is the Python class name even though the class was
        # registered under the custom name "CustomName".
        self.assertEqual("OtherTestClass", config["class_name"])
        new_inst = keras_core.saving.deserialize_keras_object(config)
        self.assertIsNot(inst, new_inst)
        self.assertIsInstance(new_inst, OtherTestClass)
        self.assertEqual(5, new_inst._val)

    def test_serialize_custom_function(self):
        """A registered function is named `Custom>my_fn` and round-trips.

        The function must remain callable both after deserialization and
        when fetched from the registry by its registered name.
        """

        @object_registration.register_keras_serializable()
        def my_fn():
            return 42

        serialized_name = "Custom>my_fn"
        class_name = object_registration.GLOBAL_CUSTOM_NAMES[my_fn]
        self.assertEqual(serialized_name, class_name)
        fn_class_name = object_registration.get_registered_name(my_fn)
        self.assertEqual(fn_class_name, class_name)

        config = keras_core.saving.serialize_keras_object(my_fn)
        fn = keras_core.saving.deserialize_keras_object(config)
        self.assertEqual(42, fn())

        fn_2 = object_registration.get_registered_object(fn_class_name)
        self.assertEqual(42, fn_2())

    def test_serialize_custom_class_without_get_config_fails(self):
        """Registering a class that lacks `get_config` raises `ValueError`."""
        with self.assertRaisesRegex(
            ValueError,
            "Cannot register a class that does not have a get_config.*",
        ):
            # The decorator itself is expected to raise, so the class is
            # defined inside the `assertRaisesRegex` block.
            @object_registration.register_keras_serializable(
                "TestPackage", "TestClass"
            )
            class TestClass:
                def __init__(self, value):
                    self._value = value