# keras/keras_core/backend/stateless_scope_test.py
# 2023-04-18 14:49:38 -07:00
#
# 37 lines
# 1.2 KiB
# Python

import numpy as np
from keras_core import backend
from keras_core import operations as ops
from keras_core import testing
from keras_core.backend.stateless_scope import StatelessScope
class TestStatelessScope(testing.TestCase):
    """Tests for `StatelessScope` variable reads and assignments."""

    def test_basic_flow(self):
        """Exercise the enter / exit / re-apply cycle of a `StatelessScope`.

        The test checks four things in order:

        1. Inside the scope, reads of mapped variables and of a variable
           assigned within the scope yield the scoped values.
        2. After the scope exits, the variable still holds its original
           value.
        3. The scope reports the value assigned inside it via
           `get_current_value`.
        4. That reported value can be assigned back to the variable
           outside the scope, and the variable then holds it.
        """
        var1 = backend.Variable(np.zeros((2,)))
        var2 = backend.Variable(np.zeros((2,)))
        var_out = backend.Variable(np.zeros((2,)))

        value1 = ops.ones(shape=(2,))
        value2 = ops.ones(shape=(2,))
        with StatelessScope(
            state_mapping=[(var1, value1), (var2, value2)]
        ) as scope:
            out = var1 + var2
            var_out.assign(out)
            # NOTE(review): `+ 0.0` presumably forces a read of the
            # variable's current value into a plain tensor -- confirm.
            var_out_value = var_out + 0.0
            # Inside scope: new value is used.
            self.assertAllClose(var_out_value, 2 * np.ones((2,)))

        # Out of scope: old value is used.
        var_out_value = var_out + 0.0
        self.assertAllClose(var_out_value, np.zeros((2,)))

        # Updates are tracked.
        var_out_value = scope.get_current_value(var_out)
        self.assertAllClose(var_out_value, 2 * np.ones((2,)))

        # Updates can be reapplied.
        var_out.assign(scope.get_current_value(var_out))
        # Re-read the variable itself: asserting on the value fetched from
        # the scope above would pass even if the assign had no effect.
        var_out_value = var_out + 0.0
        self.assertAllClose(var_out_value, 2 * np.ones((2,)))