From 1e3e94cf3c64a4186c289c7a4642ca9e10db80c8 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Wed, 3 May 2023 11:11:30 -0700 Subject: [PATCH] Fix demo --- demo_custom_jax_workflow.py | 1 + 1 file changed, 1 insertion(+) diff --git a/demo_custom_jax_workflow.py b/demo_custom_jax_workflow.py index a7819df7d..c5651942b 100644 --- a/demo_custom_jax_workflow.py +++ b/demo_custom_jax_workflow.py @@ -113,6 +113,7 @@ for data in dataset: print("Loss:", loss) # Post-processing model state update +trainable_variables, non_trainable_variables, optimizer_variables = state for variable, value in zip(model.trainable_variables, trainable_variables): variable.assign(value) for variable, value in zip(