Replace dm-tree with optree (#19306)

* Refactor `keras.utils.tree`

* Fix tests

* Replace `dm-tree` with `optree`

* Eliminate `tf.nest`

* Resolve comments

* Fix merge conflicts

* Update exporting path
james77777778 2024-03-15 23:53:05 +08:00 committed by GitHub
parent 3fcb38c1a7
commit e2b43e2e74
12 changed files with 876 additions and 123 deletions
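Before the per-file hunks, a quick orientation for reviewers who have not used optree: every wrapper added in `keras/utils/tree.py` passes `none_is_leaf=True` (so `None` is treated as a leaf, matching `tf.nest`/dm-tree behavior) and `namespace="keras"` (so the containers registered below, such as `TrackedList`, participate in flattening). A minimal sketch outside the diff, with made-up data:

```python
import optree

# None counts as a leaf under none_is_leaf=True, mirroring tf.nest/dm-tree.
leaves, treespec = optree.tree_flatten(
    {"a": 1, "b": None}, none_is_leaf=True, namespace="keras"
)
print(leaves)  # [1, None]  (dict leaves come out sorted by key)

# tree_map applies a function leaf-wise and rebuilds the same structure.
doubled = optree.tree_map(
    lambda x: None if x is None else x * 2,
    {"a": 1, "b": None},
    none_is_leaf=True,
    namespace="keras",
)
print(doubled)  # {'a': 2, 'b': None}
```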

@@ -9,6 +9,7 @@ from keras.backend.common.keras_tensor import KerasTensor
from keras.backend.common.name_scope import name_scope as base_name_scope
from keras.backend.common.stateless_scope import StatelessScope
from keras.backend.common.stateless_scope import in_stateless_scope
from keras.utils import tree
from keras.utils.naming import auto_name
SUPPORTS_SPARSE_TENSORS = True
@@ -189,7 +190,7 @@ def compute_output_spec(fn, *args, **kwargs):
)
return x
args, kwargs = tf.nest.map_structure(
args, kwargs = tree.map_structure(
convert_keras_tensor_to_tf, (args, kwargs)
)
tf_out = fn(*args, **kwargs)
@@ -201,9 +202,7 @@ def compute_output_spec(fn, *args, **kwargs):
)
return x
output_spec = tf.nest.map_structure(
convert_tf_to_keras_tensor, tf_out
)
output_spec = tree.map_structure(convert_tf_to_keras_tensor, tf_out)
return output_spec
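For context, the call pattern in this hunk (map a converter over the packed `(args, kwargs)` pair, invoke the function, then map a second converter over its possibly nested output) can be sketched backend-free; `to_backend`/`from_backend` below are hypothetical stand-ins for the `convert_keras_tensor_to_tf`/`convert_tf_to_keras_tensor` closures:

```python
from keras.utils import tree

def run_with_converters(fn, *args, **kwargs):
    to_backend = lambda x: float(x)   # e.g. KerasTensor -> symbolic tf tensor
    from_backend = lambda x: int(x)   # e.g. tf tensor -> KerasTensor

    # One map_structure call covers positional and keyword arguments at once.
    args, kwargs = tree.map_structure(to_backend, (args, kwargs))
    outputs = fn(*args, **kwargs)
    return tree.map_structure(from_backend, outputs)

result = run_with_converters(
    lambda a, b, scale=1: {"parts": (a, b), "sum": (a + b) * scale},
    1, 2, scale=3,
)
print(result)  # {'parts': (1, 2), 'sum': 9}
```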

@@ -3,6 +3,7 @@ import tensorflow as tf
from keras.backend.tensorflow.trackable import KerasAutoTrackable
from keras.utils import tf_utils
from keras.utils import tracking
from keras.utils import tree
class TFLayer(KerasAutoTrackable):
@@ -27,16 +28,16 @@ class TFLayer(KerasAutoTrackable):
if self._saved_model_inputs_spec is not None:
return # Already set.
inputs_spec = tf.nest.map_structure(tf_utils.get_tensor_spec, inputs)
args_spec = tf.nest.map_structure(tf_utils.get_tensor_spec, args or [])
inputs_spec = tree.map_structure(tf_utils.get_tensor_spec, inputs)
args_spec = tree.map_structure(tf_utils.get_tensor_spec, args or [])
kwargs_spec = {}
# Filter out non-tensor arguments from kwargs.
for key, kwarg in kwargs.items():
flat_kwarg = tf.nest.flatten(kwarg)
flat_kwarg = tree.flatten(kwarg)
flat_specs = [tf_utils.get_tensor_spec(x) for x in flat_kwarg]
if any(s is None for s in flat_specs):
continue
kwargs_spec[key] = tf.nest.pack_sequence_as(kwarg, flat_specs)
kwargs_spec[key] = tree.pack_sequence_as(kwarg, flat_specs)
self._saved_model_inputs_spec = inputs_spec
self._saved_model_arg_spec = (
@@ -94,7 +95,7 @@ class TFLayer(KerasAutoTrackable):
if inputs is not None:
input_signature = [
tf.nest.map_structure(
tree.map_structure(
lambda x: tf.TensorSpec(x.shape, self.compute_dtype),
inputs,
)
@@ -108,7 +109,7 @@ class TFLayer(KerasAutoTrackable):
]
else:
input_signature = [
tf.nest.map_structure(
tree.map_structure(
lambda x: tf.TensorSpec(x.shape, self.compute_dtype),
shapes_dict,
)
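The flatten / `pack_sequence_as` round trip above (drop any kwarg that contains a non-tensor leaf, otherwise repack its specs into the original layout) is easy to exercise without TensorFlow; `maybe_spec` is a hypothetical stand-in for `tf_utils.get_tensor_spec`, returning `None` for non-tensor leaves:

```python
from keras.utils import tree

def filter_tensor_kwargs(kwargs, maybe_spec):
    """Keep only kwargs whose every leaf converts to a spec (sketch)."""
    kwargs_spec = {}
    for key, kwarg in kwargs.items():
        flat_kwarg = tree.flatten(kwarg)
        flat_specs = [maybe_spec(x) for x in flat_kwarg]
        if any(s is None for s in flat_specs):
            continue  # at least one non-tensor leaf: skip this kwarg
        # Repack the specs into the same nested layout as the original value.
        kwargs_spec[key] = tree.pack_sequence_as(kwarg, flat_specs)
    return kwargs_spec

# Hypothetical rule: ints are "tensors", everything else is not.
maybe_spec = lambda x: ("spec", x) if isinstance(x, int) else None
print(filter_tensor_kwargs({"a": [1, 2], "b": [1, "nope"]}, maybe_spec))
# {'a': [('spec', 1), ('spec', 2)]}
```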

@@ -17,6 +17,7 @@ from keras.backend.common.backend_utils import canonicalize_axis
from keras.backend.common.backend_utils import to_tuple_or_list
from keras.backend.tensorflow import sparse
from keras.backend.tensorflow.core import convert_to_tensor
from keras.utils import tree
@sparse.elementwise_binary_union(tf.sparse.add)
@@ -95,7 +96,7 @@ def _normalize_einsum_subscripts(subscripts):
def einsum(subscripts, *operands, **kwargs):
operands = tf.nest.map_structure(convert_to_tensor, operands)
operands = tree.map_structure(convert_to_tensor, operands)
subscripts = _normalize_einsum_subscripts(subscripts)
def is_valid_for_custom_ops(subscripts, *operands):
@@ -240,7 +241,7 @@ def einsum(subscripts, *operands, **kwargs):
# output_type="int32"
if "int" in compute_dtype and output_type is None:
compute_dtype = config.floatx()
operands = tf.nest.map_structure(
operands = tree.map_structure(
lambda x: tf.cast(x, compute_dtype), operands
)
result = use_custom_ops(subscripts, *operands, output_type=output_type)
@@ -248,7 +249,7 @@ def einsum(subscripts, *operands, **kwargs):
# TODO: tf.einsum doesn't support integer dtype with gpu
if "int" in compute_dtype:
compute_dtype = config.floatx()
operands = tf.nest.map_structure(
operands = tree.map_structure(
lambda x: tf.cast(x, compute_dtype), operands
)
result = tf.einsum(subscripts, *operands, **kwargs)
@@ -763,11 +764,11 @@ def concatenate(xs, axis=0):
)
for x in xs
]
xs = tf.nest.map_structure(convert_to_tensor, xs)
xs = tree.map_structure(convert_to_tensor, xs)
dtype_set = set([x.dtype for x in xs])
if len(dtype_set) > 1:
dtype = dtypes.result_type(*dtype_set)
xs = tf.nest.map_structure(lambda x: tf.cast(x, dtype), xs)
xs = tree.map_structure(lambda x: tf.cast(x, dtype), xs)
return tf.concat(xs, axis=axis)
@@ -872,7 +873,7 @@ def digitize(x, bins):
bins = list(bins)
# bins must be float type
bins = tf.nest.map_structure(lambda x: float(x), bins)
bins = tree.map_structure(lambda x: float(x), bins)
# TODO: tf.raw_ops.Bucketize doesn't support bool, bfloat16, float16, int8
# int16, uint8, uint16, uint32
@@ -1023,7 +1024,7 @@ def hstack(xs):
dtype_set = set([getattr(x, "dtype", type(x)) for x in xs])
if len(dtype_set) > 1:
dtype = dtypes.result_type(*dtype_set)
xs = tf.nest.map_structure(lambda x: convert_to_tensor(x, dtype), xs)
xs = tree.map_structure(lambda x: convert_to_tensor(x, dtype), xs)
rank = tf.rank(xs[0])
return tf.cond(
tf.equal(rank, 1),
@@ -1328,9 +1329,7 @@ def ndim(x):
def nonzero(x):
x = convert_to_tensor(x)
result = tf.unstack(tf.where(tf.cast(x, "bool")), x.shape.rank, axis=1)
return tf.nest.map_structure(
lambda indices: tf.cast(indices, "int32"), result
)
return tree.map_structure(lambda indices: tf.cast(indices, "int32"), result)
def not_equal(x1, x2):
@@ -1620,7 +1619,7 @@ def stack(x, axis=0):
dtype_set = set([getattr(a, "dtype", type(a)) for a in x])
if len(dtype_set) > 1:
dtype = dtypes.result_type(*dtype_set)
x = tf.nest.map_structure(lambda a: convert_to_tensor(a, dtype), x)
x = tree.map_structure(lambda a: convert_to_tensor(a, dtype), x)
return tf.stack(x, axis=axis)
@@ -1807,7 +1806,7 @@ def vstack(xs):
dtype_set = set([getattr(x, "dtype", type(x)) for x in xs])
if len(dtype_set) > 1:
dtype = dtypes.result_type(*dtype_set)
xs = tf.nest.map_structure(lambda x: convert_to_tensor(x, dtype), xs)
xs = tree.map_structure(lambda x: convert_to_tensor(x, dtype), xs)
return tf.concat(xs, axis=0)
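The recurring dtype-promotion pattern in this file (collect the leaf dtypes, compute a common result type, then cast every element with a single `tree.map_structure` call) looks like this in isolation; numpy stands in for backend tensors and `np.result_type` for `dtypes.result_type`:

```python
import numpy as np

from keras.utils import tree

def concat_with_promotion(xs, axis=0):
    xs = tree.map_structure(np.asarray, xs)
    dtype_set = {x.dtype for x in xs}
    if len(dtype_set) > 1:
        dtype = np.result_type(*dtype_set)
        xs = tree.map_structure(lambda x: x.astype(dtype), xs)
    return np.concatenate(xs, axis=axis)

out = concat_with_promotion(
    [np.arange(3, dtype="int32"), np.ones(3, dtype="float32")]
)
print(out.dtype)  # float64 under numpy promotion rules
```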

@@ -225,7 +225,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
outputs = one_step_on_data_distributed(data[:1])
for single_step_data in data[1:]:
step_outputs = one_step_on_data_distributed([single_step_data])
outputs = tf.nest.map_structure(
outputs = tree.map_structure(
lambda t1, t2: concat([t1, t2]), outputs, step_outputs
)
return outputs
@@ -473,7 +473,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
def append_to_outputs(batch_outputs, outputs):
if outputs is None:
outputs = tf.nest.map_structure(
outputs = tree.map_structure(
lambda batch_output: [batch_output],
batch_outputs,
)
@@ -521,7 +521,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
outputs = tree.map_structure_up_to(
batch_outputs, potentially_ragged_concat, outputs
)
return tf.nest.map_structure(convert_to_np_if_not_ragged, outputs)
return tree.map_structure(convert_to_np_if_not_ragged, outputs)
def train_on_batch(
self,
@@ -549,7 +549,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
yield (x, y, sample_weight)
logs = self.train_function(data())
logs = tf.nest.map_structure(lambda x: np.array(x), logs)
logs = tree.map_structure(lambda x: np.array(x), logs)
if return_dict:
return logs
return self._flatten_metrics_in_order(logs)
@@ -568,7 +568,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
yield (x, y, sample_weight)
logs = self.test_function(data())
logs = tf.nest.map_structure(lambda x: np.array(x), logs)
logs = tree.map_structure(lambda x: np.array(x), logs)
if return_dict:
return logs
return self._flatten_metrics_in_order(logs)
@@ -576,7 +576,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
def predict_on_batch(self, x):
self.make_predict_function()
batch_outputs = self.predict_function([(x,)])
batch_outputs = tf.nest.map_structure(
batch_outputs = tree.map_structure(
convert_to_np_if_not_ragged, batch_outputs
)
return batch_outputs
@@ -771,7 +771,7 @@ def reduce_per_replica(values, strategy, reduction):
f"Received: reduction={reduction}."
)
return tf.nest.map_structure(_reduce, values)
return tree.map_structure(_reduce, values)
def _multi_worker_concat(v, strategy):
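The accumulation loop above relies on `tree.map_structure` accepting several structures with an identical layout and zipping them leaf-wise; a sketch with numpy arrays standing in for tensors and `np.concatenate` for `concat`:

```python
import numpy as np

from keras.utils import tree

# Each "step" returns a nested structure of batch outputs with the same layout.
step_outputs = [
    {"loss": np.array([0.9]), "preds": (np.array([1, 2]), np.array([3]))},
    {"loss": np.array([0.7]), "preds": (np.array([4, 5]), np.array([6]))},
]

outputs = step_outputs[0]
for step in step_outputs[1:]:
    # Matching leaves from both structures are passed to the lambda together.
    outputs = tree.map_structure(
        lambda t1, t2: np.concatenate([t1, t2]), outputs, step
    )

print(outputs["preds"][0])  # [1 2 4 5]
```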

@@ -8,6 +8,7 @@ from keras.layers import Layer
from keras.models import Functional
from keras.models import Sequential
from keras.utils import io_utils
from keras.utils import tree
from keras.utils.module_utils import tensorflow as tf
@@ -143,16 +144,16 @@ class ExportArchive:
# Variables in the lists below are actually part of the trackables
# that get saved, because the lists are created in __init__.
if backend.backend() == "jax":
self._tf_trackable.variables += tf.nest.flatten(
tf.nest.map_structure(tf.Variable, resource.variables)
self._tf_trackable.variables += tree.flatten(
tree.map_structure(tf.Variable, resource.variables)
)
self._tf_trackable.trainable_variables += tf.nest.flatten(
tf.nest.map_structure(
self._tf_trackable.trainable_variables += tree.flatten(
tree.map_structure(
tf.Variable, resource.trainable_variables
)
)
self._tf_trackable.non_trainable_variables += tf.nest.flatten(
tf.nest.map_structure(
self._tf_trackable.non_trainable_variables += tree.flatten(
tree.map_structure(
tf.Variable, resource.non_trainable_variables
)
)
@@ -362,9 +363,7 @@ class ExportArchive:
f"{list(set(type(v) for v in variables))}"
)
if backend.backend() == "jax":
variables = tf.nest.flatten(
tf.nest.map_structure(tf.Variable, variables)
)
variables = tree.flatten(tree.map_structure(tf.Variable, variables))
setattr(self._tf_trackable, name, list(variables))
def write_out(self, filepath, options=None):
@@ -470,7 +469,7 @@ class ExportArchive:
def _spec_to_poly_shape(self, spec):
if isinstance(spec, (dict, list)):
return tf.nest.map_structure(self._spec_to_poly_shape, spec)
return tree.map_structure(self._spec_to_poly_shape, spec)
spec_shape = spec.shape
spec_shape = str(spec_shape).replace("None", "b")
return spec_shape
@@ -500,7 +499,7 @@ def export_model(model, filepath):
export_archive = ExportArchive()
export_archive.track(model)
if isinstance(model, (Functional, Sequential)):
input_signature = tf.nest.map_structure(_make_tensor_spec, model.inputs)
input_signature = tree.map_structure(_make_tensor_spec, model.inputs)
if isinstance(input_signature, list) and len(input_signature) > 1:
input_signature = [input_signature]
export_archive.add_endpoint("serve", model.__call__, input_signature)
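Building the serving signature is the same single-call pattern over `model.inputs`; a hedged sketch with a hypothetical two-input model, approximating `_make_tensor_spec` with `tf.TensorSpec` directly:

```python
import tensorflow as tf

from keras import layers
from keras import models
from keras.utils import tree

# Hypothetical two-input model, just to exercise the signature-building step.
a = layers.Input(shape=(4,), name="a")
b = layers.Input(shape=(8,), name="b")
model = models.Model([a, b], layers.Concatenate()([a, b]))

input_signature = tree.map_structure(
    lambda x: tf.TensorSpec(x.shape, x.dtype), model.inputs
)
# tf.function expects one positional spec per endpoint argument, so a
# multi-input list is wrapped in an outer list, as export_model does above.
if isinstance(input_signature, list) and len(input_signature) > 1:
    input_signature = [input_signature]
print(input_signature)
```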

@@ -261,7 +261,7 @@ def _clone_functional_model(model, input_tensors=None, clone_function=None):
)
try:
tree.assert_same_structure(input_tensors, model.input)
except TypeError as e:
except (ValueError, TypeError) as e:
raise ValueError(
"`input_tensors` must have the same structure as model.input"
f"\nReference structure: {model.input}"
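The widened `except (ValueError, TypeError)` lines up with the optree-backed `assert_same_structure` added below, which raises `ValueError` for layout mismatches and reserves `TypeError` for leaf-type mismatches; a quick illustration:

```python
from keras.utils import tree

try:
    tree.assert_same_structure({"a": 1}, [1])
except (ValueError, TypeError) as e:
    # With the optree-backed implementation this is a ValueError describing
    # the two treespecs; catching both keeps the friendlier message above.
    print(type(e).__name__)  # ValueError
```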

@@ -789,9 +789,8 @@ class CoreOpsCallsTests(testing.TestCase):
def test_cond_check_output_spec_other_types(self):
cond_op = core.Cond()
# Create mock objects with dtype and shape attributes
mock_spec1 = Mock(dtype="float32", shape=(2, 2))
mock_spec2 = Mock(dtype="float32", shape=(2, 2))
mock_spec1 = KerasTensor(shape=(2, 2), dtype="float32")
mock_spec2 = KerasTensor(shape=(2, 2), dtype="float32")
self.assertTrue(cond_op._check_output_spec(mock_spec1, mock_spec2))
def test_cond_check_output_spec_none(self):

@@ -1,5 +1,8 @@
from functools import wraps
import optree
import optree.utils
from keras.backend.common.global_state import get_global_attribute
from keras.backend.common.global_state import set_global_attribute
from keras.utils import python_utils
@@ -110,6 +113,7 @@ class Tracker:
self.stored_ids[store_name].add(id(value))
@optree.register_pytree_node_class(namespace="keras")
class TrackedList(list):
def __init__(self, values=None, tracker=None):
self.tracker = tracker
@@ -160,7 +164,17 @@ class TrackedList(list):
if self.tracker:
self.tracker.untrack(value)
def tree_flatten(self):
# For optree
return (self, None)
@classmethod
def tree_unflatten(cls, metadata, children):
# For optree
return cls(children)
@optree.register_pytree_node_class(namespace="keras")
class TrackedDict(dict):
def __init__(self, values=None, tracker=None):
self.tracker = tracker
@@ -199,7 +213,20 @@ class TrackedDict(dict):
self.tracker.untrack(value)
super().clear()
def tree_flatten(self):
# For optree
keys, values = optree.utils.unzip2(
optree.utils.total_order_sorted(self.items(), key=lambda kv: kv[0])
)
return values, list(keys), keys
@classmethod
def tree_unflatten(cls, keys, values):
# For optree
return cls(optree.utils.safe_zip(keys, values))
@optree.register_pytree_node_class(namespace="keras")
class TrackedSet(set):
def __init__(self, values=None, tracker=None):
self.tracker = tracker
@@ -233,3 +260,12 @@ class TrackedSet(set):
for value in self:
self.tracker.untrack(value)
super().clear()
def tree_flatten(self):
# For optree
return (self, None)
@classmethod
def tree_unflatten(cls, metadata, children):
# For optree
return cls(children)
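The `tree_flatten`/`tree_unflatten` pairs added above follow optree's class-registration protocol: `tree_flatten` returns `(children, metadata)` and the classmethod rebuilds the container from them. A standalone sketch with a hypothetical container registered in a scratch namespace:

```python
import optree

@optree.register_pytree_node_class(namespace="demo")  # hypothetical namespace
class LoggedList(list):
    """A list subclass that optree can flatten, like TrackedList above."""

    def tree_flatten(self):
        # (children, metadata); a plain list needs no metadata.
        return (self, None)

    @classmethod
    def tree_unflatten(cls, metadata, children):
        return cls(children)

leaves, treespec = optree.tree_flatten(
    LoggedList([1, {"a": 2}]), none_is_leaf=True, namespace="demo"
)
print(leaves)                                     # [1, 2]
print(optree.tree_unflatten(treespec, [10, 20]))  # [10, {'a': 20}]
```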

@@ -1,56 +1,389 @@
import tree
import collections
import collections.abc
import types
import optree
def is_nested(structure):
return tree.is_nested(structure)
from keras.api_export import keras_export
from keras.backend.config import backend
# Register backend-specific node classes
if backend() == "tensorflow":
from tensorflow.python.trackable.data_structures import ListWrapper
def flatten(structure):
return tree.flatten(structure)
def map_structure(func, *structures, **kwargs):
return tree.map_structure(func, *structures, **kwargs)
def map_structure_up_to(shallow_structure, func, *structures, **kwargs):
return tree.map_structure_up_to(
shallow_structure, func, *structures, **kwargs
optree.register_pytree_node(
ListWrapper,
lambda x: (x, None),
lambda metadata, children: ListWrapper(list(children)),
namespace="keras",
)
def traverse(func, structure, top_down=True):
return tree.traverse(func, structure, top_down=top_down)
@keras_export("keras.tree.is_nested")
def is_nested(structure):
"""Checks if a given structure is nested.
Examples:
def assert_same_structure(a, b, check_types=True):
return tree.assert_same_structure(a, b, check_types=check_types)
def sequence_like(instance, args):
"""Converts the sequence `args` to the same type as `instance`.
>>> keras.tree.is_nested(42)
False
>>> keras.tree.is_nested({"foo": 42})
True
Args:
instance: an instance of `tuple`, `list`, `namedtuple`, `dict`, or
`collections.OrderedDict`.
args: elements to be converted to the `instance` type.
structure: A structure to check.
Returns:
`args` with the type of `instance`.
`True` if a given structure is nested, i.e. is a sequence, a mapping,
or a namedtuple, and `False` otherwise.
"""
return tree._sequence_like(instance, args)
return not optree.tree_is_leaf(
structure, none_is_leaf=True, namespace="keras"
)
@keras_export("keras.tree.traverse")
def traverse(func, structure, top_down=True):
"""Traverses the given nested structure, applying the given function.
The traversal is depth-first. If `top_down` is True (default), parents
are returned before their children (giving the option to avoid traversing
into a sub-tree).
Examples:
>>> v = []
>>> keras.tree.traverse(v.append, [(1, 2), [3], {"a": 4}], top_down=True)
[(1, 2), [3], {'a': 4}]
>>> v
[[(1, 2), [3], {'a': 4}], (1, 2), 1, 2, [3], 3, {'a': 4}, 4]
>>> v = []
>>> keras.tree.traverse(v.append, [(1, 2), [3], {"a": 4}], top_down=False)
[(1, 2), [3], {'a': 4}]
>>> v
[1, 2, (1, 2), 3, [3], 4, {'a': 4}, [(1, 2), [3], {'a': 4}]]
Args:
func: The function to be applied to each sub-nest of the structure.
When traversing top-down:
If `func(subtree) is None` the traversal continues into the
sub-tree.
If `func(subtree) is not None` the traversal does not continue
into the sub-tree. The sub-tree will be replaced by `func(subtree)`
in the returned structure (to replace the sub-tree with `None`, use
the special value `_MAP_TO_NONE`).
When traversing bottom-up:
If `func(subtree) is None` the traversed sub-tree is returned
unaltered.
If `func(subtree) is not None` the sub-tree will be replaced by
`func(subtree)` in the returned structure (to replace the sub-tree
with None, use the special value `_MAP_TO_NONE`).
structure: The structure to traverse.
top_down: If True, parent structures will be visited before their
children.
Returns:
The structured output from the traversal.
"""
# From https://github.com/google/jax/pull/19695
def traverse_children():
children, treedef = optree.tree_flatten(
structure,
is_leaf=lambda x: x is not structure,
none_is_leaf=True,
namespace="keras",
)
if treedef.num_nodes == 1 and treedef.num_leaves == 1:
return structure
else:
return optree.tree_unflatten(
treedef,
[traverse(func, c, top_down=top_down) for c in children],
)
if top_down:
ret = func(structure)
if ret is None:
return traverse_children()
else:
traversed_structure = traverse_children()
ret = func(traversed_structure)
if ret is None:
return traversed_structure
return None if ret is _MAP_TO_NONE else ret
@keras_export("keras.tree.flatten")
def flatten(structure):
"""Flattens a possibly nested structure into a list.
In the case of dict instances, the sequence consists of the values,
sorted by key to ensure deterministic behavior. This is true also for
`collections.OrderedDict` instances: their sequence order is
considered. The same convention is followed in `unflatten_as`.
This correctly unflattens dicts and `OrderedDict` after they have been
flattened, or vice-versa.
Dictionaries with non-sortable keys cannot be flattened.
Examples:
>>> keras.tree.flatten([[1, 2, 3], [4, [5], [[6]]]])
[1, 2, 3, 4, 5, 6]
>>> keras.tree.flatten(None)
[None]
>>> keras.tree.flatten(1)
[1]
>>> keras.tree.flatten({100: 'world!', 6: 'Hello'})
['Hello', 'world!']
Args:
structure: An arbitrarily nested structure.
Returns:
A list, the flattened version of the input `structure`.
"""
# optree.tree_flatten returns a pair (leaves, treespec) where the first
# element is a list of leaf values and the second element is a treespec
# representing the structure of the pytree.
leaves, _ = optree.tree_flatten(
structure, none_is_leaf=True, namespace="keras"
)
return leaves
@keras_export("keras.tree.unflatten_as")
def unflatten_as(structure, flat_sequence):
"""Unflattens a sequence into a given structure.
If `structure` is a scalar, `flat_sequence` must be a single-element list;
in this case the return value is ``flat_sequence[0]``.
If `structure` is or contains a dict instance, the keys will be sorted to
pack the flat sequence in deterministic order. This is true also for
`collections.OrderedDict` instances: their sequence order is considered.
The same convention is followed in `flatten`. This correctly unflattens
dicts and `OrderedDict` after they have been flattened, or vice-versa.
Dictionaries with non-sortable keys cannot be unflattened.
Examples:
>>> keras.tree.unflatten_as([[1, 2], [[3], [4]]], [5, 6, 7, 8])
[[5, 6], [[7], [8]]]
>>> keras.tree.unflatten_as(None, [1])
1
>>> keras.tree.unflatten_as({1: None, 2: None}, ['Hello', 'world!'])
{1: 'Hello', 2: 'world!'}
Args:
structure: Arbitrarily nested structure.
flat_sequence: Sequence to unflatten.
Returns:
`flat_sequence` unflattened into `structure`.
"""
if not is_nested(flat_sequence):
raise TypeError(
f"flat_sequence must be a sequence not a {type(flat_sequence)}:\n"
f"{flat_sequence}"
)
if not is_nested(structure):
if len(flat_sequence) != 1:
raise ValueError(
"Structure is a scalar but "
f"len(flat_sequence) == {len(flat_sequence)} > 1"
)
return flat_sequence[0]
structure_spec = optree.tree_structure(
structure, none_is_leaf=True, namespace="keras"
)
return structure_spec.unflatten(flat_sequence)
@keras_export("keras.tree.map_structure")
def map_structure(func, *structures):
"""Maps `func` through given structures.
Examples:
>>> structure = [[1], [2], [3]]
>>> keras.tree.map_structure(lambda v: v**2, structure)
[[1], [4], [9]]
>>> keras.tree.map_structure(lambda x, y: x * y, structure, structure)
[[1], [4], [9]]
>>> Foo = collections.namedtuple('Foo', ['a', 'b'])
>>> structure = Foo(a=1, b=2)
>>> keras.tree.map_structure(lambda v: v * 2, structure)
Foo(a=2, b=4)
Args:
func: A callable that accepts as many arguments as there are structures.
*structures: Arbitrarily nested structures of the same layout.
Returns:
A new structure with the same layout as the given ones.
"""
if not callable(func):
raise TypeError(f"`func` must be callable. Received: func={func}")
if not structures:
raise ValueError("Must provide at least one structure")
for other in structures[1:]:
assert_same_structure(structures[0], other, check_types=False)
return optree.tree_map(
func, *structures, none_is_leaf=True, namespace="keras"
)
@keras_export("keras.tree.map_structure_up_to")
def map_structure_up_to(shallow_structure, func, *structures):
"""Maps `func` through given structures up to `shallow_structure`.
This is a variant of `map_structure` which only maps the given structures
up to `shallow_structure`. All further nested components are retained as-is.
Examples:
>>> shallow_structure = [None, None]
>>> structure = [[1, 1], [2, 2]]
>>> keras.tree.map_structure_up_to(shallow_structure, len, structure)
[2, 2]
>>> shallow_structure = [None, [None, None]]
>>> keras.tree.map_structure_up_to(shallow_structure, str, structure)
['[1, 1]', ['2', '2']]
Args:
shallow_structure: A structure with layout common to all `structures`.
func: A callable that accepts as many arguments as there are structures.
*structures: Arbitrarily nested structures of the same layout.
Returns:
A new structure with the same layout as `shallow_structure`.
"""
return _map_structure_with_path_up_to(
shallow_structure,
lambda _, *args: func(*args), # Discards path.
*structures,
)
@keras_export("keras.tree.assert_same_structure")
def assert_same_structure(a, b, check_types=True):
"""Asserts that two structures are nested in the same way.
Note that namedtuples with identical name and fields will not be considered
the same structure even with `check_types=False`.
Examples:
>>> keras.tree.assert_same_structure([(0, 1)], [(2, 3)])
>>> Foo = collections.namedtuple('Foo', ['a', 'b'])
>>> AlsoFoo = collections.namedtuple('Foo', ['a', 'b'])
>>> keras.tree.assert_same_structure(Foo(0, 1), Foo(2, 3))
>>> keras.tree.assert_same_structure(Foo(0, 1), AlsoFoo(2, 3))
Traceback (most recent call last):
...
ValueError: `a` and `b` don't have the same structure.
...
Args:
a: an arbitrarily nested structure.
b: an arbitrarily nested structure.
check_types: if `True` (default) types of leaves are checked as well.
"""
a_structure = optree.tree_structure(a, none_is_leaf=True, namespace="keras")
b_structure = optree.tree_structure(b, none_is_leaf=True, namespace="keras")
if a_structure != b_structure:
raise ValueError(
"`a` and `b` don't have the same structure. "
f"Received: structure of a={a_structure}, "
f"structure of b={b_structure}"
)
if check_types:
type_structure = optree.tree_map(
lambda x, y: type(x) is type(y),
a,
b,
none_is_leaf=True,
namespace="keras",
)
if not optree.tree_all(
type_structure, none_is_leaf=True, namespace="keras"
):
raise TypeError(
"The type of the leaves of `a` and `b` doesn't match."
)
@keras_export("keras.tree.pack_sequence_as")
def pack_sequence_as(structure, flat_sequence, sequence_fn=None):
"""Implements sequence packing, i.e. nest.pack_sequence_as()."""
is_nested_fn = tree.is_nested
sequence_fn = sequence_fn or tree._sequence_like
"""Returns a given flattened sequence packed into a given structure.
If `structure` is an atom, `flat_sequence` must be a single-item list; in
this case the return value is `flat_sequence[0]`.
If `structure` is or contains a dict instance, the keys will be sorted to
pack the flat sequence in deterministic order. This is true also for
`OrderedDict` instances: their sequence order is considered. The same
convention is followed in `flatten`. This correctly repacks dicts and
`OrderedDicts` after they have been flattened, or vice-versa.
Dictionaries with non-sortable keys cannot be flattened.
Examples:
>>> structure = {"key3": "", "key1": "", "key2": ""}
>>> flat_sequence = ["value1", "value2", "value3"]
>>> keras.tree.pack_sequence_as(structure, flat_sequence)
{"key3": "value3", "key1": "value1", "key2": "value2"}
>>> structure = (("a", "b"), ("c", "d", "e"), "f")
>>> flat_sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
>>> keras.tree.pack_sequence_as(structure, flat_sequence)
((1.0, 2.0), (3.0, 4.0, 5.0), 6.0)
>>> structure = {"key3": {"c": ("alpha", "beta"), "a": ("gamma")},
... "key1": {"e": "val1", "d": "val2"}}
>>> flat_sequence = ["val2", "val1", 3.0, 1.0, 2.0]
>>> keras.tree.pack_sequence_as(structure, flat_sequence)
{'key3': {'c': (1.0, 2.0), 'a': 3.0}, 'key1': {'e': 'val1', 'd': 'val2'}}
>>> structure = ["a"]
>>> flat_sequence = [np.array([[1, 2], [3, 4]])]
>>> keras.tree.pack_sequence_as(structure, flat_sequence)
[array([[1, 2],
[3, 4]])]
>>> structure = ["a"]
>>> flat_sequence = [keras.ops.ones([2, 2])]
>>> keras.tree.pack_sequence_as(structure, flat_sequence)
[array([[1., 1.],
[1., 1.]])]
Args:
structure: Arbitrarily nested structure.
flat_sequence: Flat sequence to pack.
sequence_fn: Defaults to `_sequence_like`.
Returns:
`flat_sequence` converted to have the same recursive structure as
`structure`.
"""
sequence_fn = sequence_fn or _sequence_like
def truncate(value, length):
value_str = str(value)
return value_str[:length] + (value_str[length:] and "...")
if not is_nested_fn(flat_sequence):
if not is_nested(flat_sequence):
raise TypeError(
"Attempted to pack value:\n {}\ninto a structure, but found "
"incompatible type `{}` instead.".format(
@@ -58,7 +391,7 @@ def pack_sequence_as(structure, flat_sequence, sequence_fn=None):
)
)
if not is_nested_fn(structure):
if not is_nested(structure):
if len(flat_sequence) != 1:
raise ValueError(
"The target structure is of type `{}`\n {}\nHowever the input "
@@ -74,13 +407,13 @@ def pack_sequence_as(structure, flat_sequence, sequence_fn=None):
return flat_sequence[0]
try:
final_index, packed = packed_nest_with_indices(
structure, flat_sequence, 0, is_nested_fn, sequence_fn
final_index, packed = _packed_nest_with_indices(
structure, flat_sequence, 0, sequence_fn
)
if final_index < len(flat_sequence):
raise IndexError
except IndexError:
flat_structure = tree.flatten(structure)
flat_structure = flatten(structure)
if len(flat_structure) != len(flat_sequence):
# pylint: disable=raise-missing-from
raise ValueError(
@@ -92,33 +425,147 @@ def pack_sequence_as(structure, flat_sequence, sequence_fn=None):
return sequence_fn(structure, packed)
def packed_nest_with_indices(
structure, flat, index, is_nested_fn, sequence_fn=None
):
"""Helper function for pack_sequence_as.
@keras_export("keras.tree.lists_to_tuples")
def lists_to_tuples(structure):
"""Converts `list`s to `tuple`s."""
Args:
structure: structure to mimic.
flat: Flattened values to output substructure for.
index: Index at which to start reading from flat.
is_nested_fn: Function used to test if a value should
be treated as a nested structure.
sequence_fn: Function used to generate a new structure instance.
def sequence_fn(instance, args):
if isinstance(instance, list):
return tuple(args)
return _sequence_like(instance, args)
Returns:
The tuple (new_index, child), where:
* new_index - the updated index into `flat`
having processed `structure`.
* packed - the subset of `flat` corresponding to `structure`,
having started at `index`, and packed into the same nested
format.
"""
return pack_sequence_as(
structure, flatten(structure), sequence_fn=sequence_fn
)
class _MapToNone:
"""A special object used as a sentinel within `traverse`."""
def __repr__(self):
return "keras.utils.tree._MAP_TO_NONE"
_MAP_TO_NONE = _MapToNone()
def _yield_flat_up_to(shallow_tree, input_tree, path=()):
if isinstance(shallow_tree, (str, bytes)) or not (
isinstance(
shallow_tree, (collections.abc.Mapping, collections.abc.Sequence)
)
or optree.is_namedtuple(shallow_tree)
):
yield (path, input_tree)
else:
input_tree = dict(_yield_sorted_items(input_tree))
for shallow_key, shallow_subtree in _yield_sorted_items(shallow_tree):
subpath = path + (shallow_key,)
input_subtree = input_tree[shallow_key]
for leaf_path, leaf_value in _yield_flat_up_to(
shallow_subtree, input_subtree, path=subpath
):
yield (leaf_path, leaf_value)
def _multiyield_flat_up_to(shallow_tree, *input_trees):
"""Same as `_yield_flat_up_to`, but takes multiple input trees."""
zipped_iterators = zip(
*[
_yield_flat_up_to(shallow_tree, input_tree)
for input_tree in input_trees
]
)
try:
for paths_and_values in zipped_iterators:
paths, values = zip(*paths_and_values)
yield paths[:1] + values
except KeyError as e:
paths = locals().get("paths", ((),))
raise ValueError(
f"Could not find key '{e.args[0]}' in some `input_trees`. "
"Please ensure the structure of all `input_trees` are "
"compatible with `shallow_tree`. The last valid path "
f"yielded was {paths[0]}."
) from e
def _map_structure_with_path_up_to(shallow_structure, func, *structures):
results = []
for path_and_values in _multiyield_flat_up_to(
shallow_structure, *structures
):
results.append(func(*path_and_values))
shallow_structure_spec = optree.tree_structure(
shallow_structure, none_is_leaf=True, namespace="keras"
)
return shallow_structure_spec.unflatten(results)
def _sequence_like(instance, args):
# TODO: Support attrs library
if isinstance(instance, (dict, collections.abc.Mapping)):
# Pack dictionaries in a deterministic order by sorting the keys.
# Notice this means that we ignore the original order of `OrderedDict`
# instances. This is intentional, to avoid potential bugs caused by
# mixing ordered and plain dicts (e.g., flattening a dict but using a
# corresponding `OrderedDict` to pack it back).
result = dict(zip(sorted(instance), args))
keys_and_values = ((key, result[key]) for key in instance)
if isinstance(instance, collections.defaultdict):
# `defaultdict` requires a default factory as the first argument.
return type(instance)(instance.default_factory, keys_and_values)
elif isinstance(instance, types.MappingProxyType):
# MappingProxyType requires a dict to proxy to.
return type(instance)(dict(keys_and_values))
else:
return type(instance)(keys_and_values)
elif isinstance(instance, collections.abc.MappingView):
# We can't directly construct mapping views, so we create a list instead
return list(args)
elif optree.is_namedtuple(instance):
instance_type = type(instance)
try:
return instance_type(*args)
except Exception as e:
raise TypeError(
f"Couldn't traverse {instance!r} with arguments {args}"
) from e
else:
# Not a namedtuple
return type(instance)(args)
def _yield_sorted_items(iterable):
# TODO: Support attrs library
if isinstance(iterable, collections.abc.Mapping):
# Iterate through dictionaries in a deterministic order by sorting the
# keys. Notice this means that we ignore the original order of
# `OrderedDict` instances. This is intentional, to avoid potential bugs
# caused by mixing ordered and plain dicts (e.g., flattening a dict but
# using a corresponding `OrderedDict` to pack it back).
for key in sorted(iterable):
yield key, iterable[key]
elif optree.is_namedtuple(iterable):
for field in iterable._fields:
yield (field, getattr(iterable, field))
else:
for item in enumerate(iterable):
yield item
def _yield_value(iterable):
for _, v in _yield_sorted_items(iterable):
yield v
def _packed_nest_with_indices(structure, flat, index, sequence_fn=None):
packed = []
sequence_fn = sequence_fn or tree._sequence_like
for s in yield_value(structure):
if is_nested_fn(s):
new_index, child = packed_nest_with_indices(
s, flat, index, is_nested_fn, sequence_fn
sequence_fn = sequence_fn or _sequence_like
for s in _yield_value(structure):
if is_nested(s):
new_index, child = _packed_nest_with_indices(
s, flat, index, sequence_fn
)
packed.append(sequence_fn(s, child))
index = new_index
@@ -126,21 +573,3 @@ def packed_nest_with_indices(
packed.append(flat[index])
index += 1
return index, packed
def yield_value(iterable):
for _, v in tree._yield_sorted_items(iterable):
yield v
def lists_to_tuples(structure):
def sequence_fn(instance, args):
if isinstance(instance, list):
return tuple(args)
return tree._sequence_like(instance, args)
return pack_sequence_as(
structure,
tree.flatten(structure),
sequence_fn=sequence_fn,
)
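One subtle piece of the new module is how `traverse` obtains the immediate children of a node: `optree.tree_flatten` with `is_leaf=lambda x: x is not structure` treats everything except the root as a leaf, so flattening stops one level down (the trick referenced from the JAX PR in the comment). Isolated, it looks like this:

```python
import optree

def immediate_children(structure):
    # Everything that is not the root object itself is treated as a leaf,
    # so flattening yields only the root's direct children.
    children, treedef = optree.tree_flatten(
        structure,
        is_leaf=lambda x: x is not structure,
        none_is_leaf=True,
        namespace="keras",
    )
    if treedef.num_nodes == 1 and treedef.num_leaves == 1:
        return None  # `structure` is itself a leaf
    return children

print(immediate_children([(1, 2), [3], {"a": 4}]))  # [(1, 2), [3], {'a': 4}]
print(immediate_children(42))                       # None
```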

keras/utils/tree_test.py (new file, 291 additions)

@@ -0,0 +1,291 @@
import collections
import numpy as np
from keras import ops
from keras import testing
from keras.utils import tree
STRUCTURE1 = (((1, 2), 3), 4, (5, 6))
STRUCTURE2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
STRUCTURE_DIFFERENT_NUM_ELEMENTS = ("spam", "eggs")
STRUCTURE_DIFFERENT_NESTING = (((1, 2), 3), 4, 5, (6,))
class TreeTest(testing.TestCase):
def test_is_nested(self):
self.assertFalse(tree.is_nested("1234"))
self.assertFalse(tree.is_nested(b"1234"))
self.assertFalse(tree.is_nested(bytearray("1234", "ascii")))
self.assertTrue(tree.is_nested([1, 3, [4, 5]]))
self.assertTrue(tree.is_nested(((7, 8), (5, 6))))
self.assertTrue(tree.is_nested([]))
self.assertTrue(tree.is_nested({"a": 1, "b": 2}))
self.assertFalse(tree.is_nested(set([1, 2])))
ones = np.ones([2, 3])
self.assertFalse(tree.is_nested(ones))
self.assertFalse(tree.is_nested(np.tanh(ones)))
self.assertFalse(tree.is_nested(np.ones((4, 5))))
def test_flatten_and_unflatten(self):
structure = ((3, 4), 5, (6, 7, (9, 10), 8))
flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
self.assertEqual(tree.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8])
self.assertEqual(
tree.unflatten_as(structure, flat),
(("a", "b"), "c", ("d", "e", ("f", "g"), "h")),
)
point = collections.namedtuple("Point", ["x", "y"])
structure = (point(x=4, y=2), ((point(x=1, y=0),),))
flat = [4, 2, 1, 0]
self.assertEqual(tree.flatten(structure), flat)
restructured_from_flat = tree.unflatten_as(structure, flat)
self.assertEqual(restructured_from_flat, structure)
self.assertEqual(restructured_from_flat[0].x, 4)
self.assertEqual(restructured_from_flat[0].y, 2)
self.assertEqual(restructured_from_flat[1][0][0].x, 1)
self.assertEqual(restructured_from_flat[1][0][0].y, 0)
self.assertEqual([5], tree.flatten(5))
self.assertEqual([np.array([5])], tree.flatten(np.array([5])))
self.assertEqual("a", tree.unflatten_as(5, ["a"]))
self.assertEqual(
np.array([5]), tree.unflatten_as("scalar", [np.array([5])])
)
with self.assertRaisesRegex(ValueError, "Structure is a scalar"):
tree.unflatten_as("scalar", [4, 5])
with self.assertRaisesRegex(TypeError, "flat_sequence"):
tree.unflatten_as([4, 5], "bad_sequence")
with self.assertRaises(ValueError):
tree.unflatten_as([5, 6, [7, 8]], ["a", "b", "c"])
self.assertEqual(
tree.unflatten_as({1: None, 2: None}, ["Hello", "world!"]),
{1: "Hello", 2: "world!"},
)
def test_flatten_dict_order(self):
ordered = collections.OrderedDict(
[("d", 3), ("b", 1), ("a", 0), ("c", 2)]
)
plain = {"d": 3, "b": 1, "a": 0, "c": 2}
ordered_flat = tree.flatten(ordered)
plain_flat = tree.flatten(plain)
self.assertEqual([3, 1, 0, 2], ordered_flat)
self.assertEqual([0, 1, 2, 3], plain_flat)
def test_unflatten_dict_order(self):
ordered = collections.OrderedDict(
[("d", 0), ("b", 0), ("a", 0), ("c", 0)]
)
plain = {"d": 0, "b": 0, "a": 0, "c": 0}
seq = [0, 1, 2, 3]
ordered_reconstruction = tree.unflatten_as(ordered, seq)
plain_reconstruction = tree.unflatten_as(plain, seq)
self.assertEqual(
collections.OrderedDict([("d", 0), ("b", 1), ("a", 2), ("c", 3)]),
ordered_reconstruction,
)
self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
def test_map_structure(self):
structure2 = (((7, 8), 9), 10, (11, 12))
structure1_plus1 = tree.map_structure(lambda x: x + 1, STRUCTURE1)
tree.assert_same_structure(STRUCTURE1, structure1_plus1)
self.assertAllEqual([2, 3, 4, 5, 6, 7], tree.flatten(structure1_plus1))
structure1_plus_structure2 = tree.map_structure(
lambda x, y: x + y, STRUCTURE1, structure2
)
self.assertEqual(
(((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
structure1_plus_structure2,
)
self.assertEqual(3, tree.map_structure(lambda x: x - 1, 4))
self.assertEqual(7, tree.map_structure(lambda x, y: x + y, 3, 4))
# Empty structures
self.assertEqual((), tree.map_structure(lambda x: x + 1, ()))
self.assertEqual([], tree.map_structure(lambda x: x + 1, []))
self.assertEqual({}, tree.map_structure(lambda x: x + 1, {}))
empty_nt = collections.namedtuple("empty_nt", "")
self.assertEqual(
empty_nt(), tree.map_structure(lambda x: x + 1, empty_nt())
)
# This is checking actual equality of types, empty list != empty tuple
self.assertNotEqual((), tree.map_structure(lambda x: x + 1, []))
with self.assertRaisesRegex(TypeError, "callable"):
tree.map_structure("bad", structure1_plus1)
with self.assertRaisesRegex(ValueError, "at least one structure"):
tree.map_structure(lambda x: x)
with self.assertRaisesRegex(ValueError, "have the same structure"):
tree.map_structure(lambda x, y: None, (3, 4), (3, 4, 5))
with self.assertRaisesRegex(ValueError, "have the same structure"):
tree.map_structure(lambda x, y: None, 3, (3,))
with self.assertRaisesRegex(ValueError, "have the same structure"):
tree.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5])
with self.assertRaisesRegex(ValueError, "have the same structure"):
tree.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))
structure1_list = [[[1, 2], 3], 4, [5, 6]]
with self.assertRaisesRegex(ValueError, "have the same structure"):
tree.map_structure(lambda x, y: None, STRUCTURE1, structure1_list)
def test_map_structure_up_to(self):
# Named tuples.
ab_tuple = collections.namedtuple("ab_tuple", "a, b")
op_tuple = collections.namedtuple("op_tuple", "add, mul")
inp_val = ab_tuple(a=2, b=3)
inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
out = tree.map_structure_up_to(
inp_val,
lambda val, ops: (val + ops.add) * ops.mul,
inp_val,
inp_ops,
)
self.assertEqual(out.a, 6)
self.assertEqual(out.b, 15)
# Lists.
data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
name_list = ["evens", ["odds", "primes"]]
out = tree.map_structure_up_to(
name_list,
lambda name, sec: "first_{}_{}".format(len(sec), name),
name_list,
data_list,
)
self.assertEqual(
out, ["first_4_evens", ["first_5_odds", "first_3_primes"]]
)
def test_assert_same_structure(self):
tree.assert_same_structure(STRUCTURE1, STRUCTURE2, check_types=False)
tree.assert_same_structure("abc", 1.0, check_types=False)
tree.assert_same_structure(b"abc", 1.0, check_types=False)
tree.assert_same_structure("abc", 1.0, check_types=False)
tree.assert_same_structure(
bytearray("abc", "ascii"), 1.0, check_types=False
)
tree.assert_same_structure("abc", np.array([0, 1]), check_types=False)
with self.assertRaisesRegex(ValueError, "have the same structure"):
tree.assert_same_structure(
STRUCTURE1, STRUCTURE_DIFFERENT_NUM_ELEMENTS
)
with self.assertRaisesRegex(ValueError, "have the same structure"):
tree.assert_same_structure([0, 1], np.array([0, 1]))
with self.assertRaisesRegex(ValueError, "have the same structure"):
tree.assert_same_structure(0, [0, 1])
with self.assertRaisesRegex(ValueError, "have the same structure"):
tree.assert_same_structure((0, 1), [0, 1])
with self.assertRaisesRegex(ValueError, "have the same structure"):
tree.assert_same_structure(STRUCTURE1, STRUCTURE_DIFFERENT_NESTING)
with self.assertRaisesRegex(ValueError, "have the same structure"):
tree.assert_same_structure([[3], 4], [3, [4]])
with self.assertRaisesRegex(ValueError, "have the same structure"):
tree.assert_same_structure({"a": 1}, {"b": 1})
structure1_list = [[[1, 2], 3], 4, [5, 6]]
with self.assertRaisesRegex(ValueError, "have the same structure"):
tree.assert_same_structure(STRUCTURE1, structure1_list)
tree.assert_same_structure(STRUCTURE1, STRUCTURE2, check_types=False)
with self.assertRaisesRegex(ValueError, "have the same structure"):
tree.assert_same_structure(
STRUCTURE1, structure1_list, check_types=False
)
def test_pack_sequence_as(self):
structure = {"key3": "", "key1": "", "key2": ""}
flat_sequence = ["value1", "value2", "value3"]
self.assertEqual(
tree.pack_sequence_as(structure, flat_sequence),
{"key3": "value3", "key1": "value1", "key2": "value2"},
)
structure = (("a", "b"), ("c", "d", "e"), "f")
flat_sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
self.assertEqual(
tree.pack_sequence_as(structure, flat_sequence),
((1.0, 2.0), (3.0, 4.0, 5.0), 6.0),
)
structure = {
"key3": {"c": ("alpha", "beta"), "a": ("gamma")},
"key1": {"e": "val1", "d": "val2"},
}
flat_sequence = ["val2", "val1", 3.0, 1.0, 2.0]
self.assertEqual(
tree.pack_sequence_as(structure, flat_sequence),
{
"key3": {"c": (1.0, 2.0), "a": 3.0},
"key1": {"e": "val1", "d": "val2"},
},
)
structure = ["a"]
flat_sequence = [np.array([[1, 2], [3, 4]])]
self.assertAllClose(
tree.pack_sequence_as(structure, flat_sequence),
[np.array([[1, 2], [3, 4]])],
)
structure = ["a"]
flat_sequence = [ops.ones([2, 2])]
self.assertAllClose(
tree.pack_sequence_as(structure, flat_sequence),
[ops.ones([2, 2])],
)
with self.assertRaisesRegex(TypeError, "Attempted to pack value:"):
structure = ["a"]
flat_sequence = 1
tree.pack_sequence_as(structure, flat_sequence)
with self.assertRaisesRegex(ValueError, "The target structure is of"):
structure = "a"
flat_sequence = [1, 2]
tree.pack_sequence_as(structure, flat_sequence)
def test_lists_to_tuples(self):
structure = [1, 2, 3]
self.assertEqual(tree.lists_to_tuples(structure), (1, 2, 3))
structure = [[1], [2, 3]]
self.assertEqual(tree.lists_to_tuples(structure), ((1,), (2, 3)))
structure = [[1], [2, [3]]]
self.assertEqual(tree.lists_to_tuples(structure), ((1,), (2, (3,))))
def test_traverse(self):
# Lists to tuples
structure = [(1, 2), [3], {"a": [4]}]
self.assertEqual(
((1, 2), (3,), {"a": (4,)}),
tree.traverse(
lambda x: tuple(x) if isinstance(x, list) else x,
structure,
top_down=False,
),
)
# EarlyTermination
structure = [(1, [2]), [3, (4, 5, 6)]]
visited = []
def visit(x):
visited.append(x)
return "X" if isinstance(x, tuple) and len(x) > 2 else None
output = tree.traverse(visit, structure)
self.assertEqual([(1, [2]), [3, "X"]], output)
self.assertEqual(
[
[(1, [2]), [3, (4, 5, 6)]],
(1, [2]),
1,
[2],
2,
[3, (4, 5, 6)],
3,
(4, 5, 6),
],
visited,
)

@@ -15,5 +15,5 @@ google
tensorboard-plugin-profile
rich
build
dm-tree
optree
pytest-cov

@@ -44,7 +44,7 @@ setup(
"rich",
"namex",
"h5py",
"dm-tree",
"optree",
"ml-dtypes",
],
# Supported Python versions