Format code

This commit is contained in:
Francois Chollet 2023-05-31 18:11:12 -07:00
parent 00841c00ed
commit 7267c4e32c
2 changed files with 13 additions and 9 deletions

@ -44,9 +44,6 @@ class Variable(
return tf.convert_to_tensor(self.value, dtype=dtype, name=name)
# Methods below are for SavedModel support
def _write_object_proto(self, *args, **kwargs):
return self.value._write_object_proto(*args, **kwargs)
@property
def _shared_name(self):
return self.value._shared_name
@ -57,8 +54,12 @@ class Variable(
def _restore_from_tensors(self, restored_tensors):
return self.value._restore_from_tensors(restored_tensors)
def _export_to_saved_model_graph(self, object_map, tensor_map, options, **kwargs):
resource_list = self.value._export_to_saved_model_graph(object_map, tensor_map, options, **kwargs)
def _export_to_saved_model_graph(
self, object_map, tensor_map, options, **kwargs
):
resource_list = self.value._export_to_saved_model_graph(
object_map, tensor_map, options, **kwargs
)
object_map[self] = tf.Variable(object_map[self.value])
return resource_list

@ -2,7 +2,6 @@ import tensorflow as tf
class TFLayer(tf.__internal__.tracking.AutoTrackable):
@property
def _default_save_signature(self):
"""For SavedModel support: returns the default serving signature."""
@ -12,11 +11,15 @@ class TFLayer(tf.__internal__.tracking.AutoTrackable):
input_shape = tuple(shapes_dict.values())[0]
input_signature = [tf.TensorSpec(input_shape, self.compute_dtype)]
else:
input_signature = [tf.nest.map_structure(lambda x: tf.TensorSpec(x.shape, self.compute_dtype), shapes_dict)]
input_signature = [
tf.nest.map_structure(
lambda x: tf.TensorSpec(x.shape, self.compute_dtype),
shapes_dict,
)
]
@tf.function(input_signature=input_signature)
def serving_default(inputs):
return self(inputs)
return serving_default