keras/keras_core/backend/stateless_scope_test.py

37 lines
1.2 KiB
Python
Raw Normal View History

2023-04-18 21:49:38 +00:00
import numpy as np
2023-04-18 15:52:21 +00:00
from keras_core import backend
from keras_core import operations as ops
2023-04-18 21:49:38 +00:00
from keras_core import testing
2023-04-18 15:52:21 +00:00
from keras_core.backend.stateless_scope import StatelessScope
class TestStatelessScope(testing.TestCase):
    """Tests for reading, writing and tracking variables in a StatelessScope."""

    def test_basic_flow(self):
        """Exercise overridden reads, scoped writes and update retrieval.

        Inside the scope, `var1` and `var2` (initialized to zeros) are mapped
        to the ones-valued tensors passed via `state_mapping`, so their sum is
        2, and `var_out.assign(...)` is visible within the scope. After the
        scope exits, `var_out` still reads as zeros. The assignment made
        inside the scope is retrievable via `scope.get_current_value` and can
        be written back to the variable.
        """
        var1 = backend.Variable(np.zeros((2,)))
        var2 = backend.Variable(np.zeros((2,)))
        var_out = backend.Variable(np.zeros((2,)))
        value1 = ops.ones(shape=(2,))
        value2 = ops.ones(shape=(2,))

        with StatelessScope(
            state_mapping=[(var1, value1), (var2, value2)]
        ) as scope:
            out = var1 + var2
            var_out.assign(out)
            # `var_out + 0.0` reads the variable's value as seen in the
            # current context (here: inside the scope).
            var_out_value = var_out + 0.0
            # Inside scope: new value is used.
            self.assertAllClose(var_out_value, 2 * np.ones((2,)))

        # Out of scope: old value is used.
        var_out_value = var_out + 0.0
        self.assertAllClose(var_out_value, np.zeros((2,)))

        # Updates are tracked.
        var_out_value = scope.get_current_value(var_out)
        self.assertAllClose(var_out_value, 2 * np.ones((2,)))

        # Updates can be reapplied.
        var_out.assign(scope.get_current_value(var_out))
        # Re-read the variable itself: `var_out_value` was obtained from the
        # scope above, so asserting on it again would not verify `assign`.
        self.assertAllClose(var_out + 0.0, 2 * np.ones((2,)))