# File: keras/keras_core/backend/common/variables_test.py
# Date: 2023-04-27 16:02:31 -07:00
#
# 45 lines · 1.5 KiB · Python

from keras_core import backend
from keras_core import initializers
from keras_core.backend.common.variables import AutocastScope
from keras_core.testing import test_case
class VariablesTest(test_case.TestCase):
    """Tests for backend variables: deferred initialization and autocasting."""

    def test_deferred_initialization(self):
        """A variable created inside a `StatelessScope` defers initialization.

        Inside the scope the variable has no stored value (`_value is None`)
        yet can still be used in arithmetic. After the scope exits, the test
        expects a stored value of the requested shape. Creating a variable
        with the concrete initializer `0` inside the scope must raise a
        ValueError mentioning "while in a stateless scope".
        """
        with backend.StatelessScope():
            v = backend.Variable(
                initializer=initializers.RandomNormal(), shape=(2, 2)
            )
            self.assertIsNone(v._value)
            # Variables can nevertheless be accessed
            _ = v + 1
        # Checked after leaving the scope: the deferred value is now stored.
        self.assertEqual(v._value.shape, (2, 2))

        with self.assertRaisesRegex(ValueError, "while in a stateless scope"):
            with backend.StatelessScope():
                v = backend.Variable(initializer=0)

    def test_autocasting(self):
        """`AutocastScope` casts float variable values but not int ones.

        Reading `.value` of a float32 variable inside
        `AutocastScope("float16")` yields float16; outside the scope it is
        float32 again. An int32 variable's value stays int32 inside the scope.
        """
        v = backend.Variable(
            initializer=initializers.RandomNormal(),
            shape=(2, 2),
            dtype="float32",
        )
        self.assertEqual(v.dtype, "float32")
        self.assertEqual(v.value.dtype.name, "float32")

        with AutocastScope("float16"):
            self.assertEqual(v.value.dtype.name, "float16")
        # Once the scope is exited, reads return the variable's own dtype.
        self.assertEqual(v.value.dtype.name, "float32")

        # Test non-float variables are not affected
        v = backend.Variable(
            initializer=initializers.Ones(), shape=(2, 2), dtype="int32"
        )
        self.assertEqual(v.dtype, "int32")
        self.assertEqual(v.value.dtype.name, "int32")

        with AutocastScope("float16"):
            self.assertEqual(v.value.dtype.name, "int32")