Replace dm-tree
with optree
(#19306)
* Refactor `keras.utils.tree` * Fix tests * Replace `dm-tree` with `optree` * Eliminate `tf.nest` * Resolve comments * Fix merge conflicts * Update exporting path
This commit is contained in:
parent
3fcb38c1a7
commit
e2b43e2e74
@ -9,6 +9,7 @@ from keras.backend.common.keras_tensor import KerasTensor
|
||||
from keras.backend.common.name_scope import name_scope as base_name_scope
|
||||
from keras.backend.common.stateless_scope import StatelessScope
|
||||
from keras.backend.common.stateless_scope import in_stateless_scope
|
||||
from keras.utils import tree
|
||||
from keras.utils.naming import auto_name
|
||||
|
||||
SUPPORTS_SPARSE_TENSORS = True
|
||||
@ -189,7 +190,7 @@ def compute_output_spec(fn, *args, **kwargs):
|
||||
)
|
||||
return x
|
||||
|
||||
args, kwargs = tf.nest.map_structure(
|
||||
args, kwargs = tree.map_structure(
|
||||
convert_keras_tensor_to_tf, (args, kwargs)
|
||||
)
|
||||
tf_out = fn(*args, **kwargs)
|
||||
@ -201,9 +202,7 @@ def compute_output_spec(fn, *args, **kwargs):
|
||||
)
|
||||
return x
|
||||
|
||||
output_spec = tf.nest.map_structure(
|
||||
convert_tf_to_keras_tensor, tf_out
|
||||
)
|
||||
output_spec = tree.map_structure(convert_tf_to_keras_tensor, tf_out)
|
||||
return output_spec
|
||||
|
||||
|
||||
|
@ -3,6 +3,7 @@ import tensorflow as tf
|
||||
from keras.backend.tensorflow.trackable import KerasAutoTrackable
|
||||
from keras.utils import tf_utils
|
||||
from keras.utils import tracking
|
||||
from keras.utils import tree
|
||||
|
||||
|
||||
class TFLayer(KerasAutoTrackable):
|
||||
@ -27,16 +28,16 @@ class TFLayer(KerasAutoTrackable):
|
||||
if self._saved_model_inputs_spec is not None:
|
||||
return # Already set.
|
||||
|
||||
inputs_spec = tf.nest.map_structure(tf_utils.get_tensor_spec, inputs)
|
||||
args_spec = tf.nest.map_structure(tf_utils.get_tensor_spec, args or [])
|
||||
inputs_spec = tree.map_structure(tf_utils.get_tensor_spec, inputs)
|
||||
args_spec = tree.map_structure(tf_utils.get_tensor_spec, args or [])
|
||||
kwargs_spec = {}
|
||||
# Filter out non-tensor arguments from kwargs.
|
||||
for key, kwarg in kwargs.items():
|
||||
flat_kwarg = tf.nest.flatten(kwarg)
|
||||
flat_kwarg = tree.flatten(kwarg)
|
||||
flat_specs = [tf_utils.get_tensor_spec(x) for x in flat_kwarg]
|
||||
if any(s is None for s in flat_specs):
|
||||
continue
|
||||
kwargs_spec[key] = tf.nest.pack_sequence_as(kwarg, flat_specs)
|
||||
kwargs_spec[key] = tree.pack_sequence_as(kwarg, flat_specs)
|
||||
|
||||
self._saved_model_inputs_spec = inputs_spec
|
||||
self._saved_model_arg_spec = (
|
||||
@ -94,7 +95,7 @@ class TFLayer(KerasAutoTrackable):
|
||||
|
||||
if inputs is not None:
|
||||
input_signature = [
|
||||
tf.nest.map_structure(
|
||||
tree.map_structure(
|
||||
lambda x: tf.TensorSpec(x.shape, self.compute_dtype),
|
||||
inputs,
|
||||
)
|
||||
@ -108,7 +109,7 @@ class TFLayer(KerasAutoTrackable):
|
||||
]
|
||||
else:
|
||||
input_signature = [
|
||||
tf.nest.map_structure(
|
||||
tree.map_structure(
|
||||
lambda x: tf.TensorSpec(x.shape, self.compute_dtype),
|
||||
shapes_dict,
|
||||
)
|
||||
|
@ -17,6 +17,7 @@ from keras.backend.common.backend_utils import canonicalize_axis
|
||||
from keras.backend.common.backend_utils import to_tuple_or_list
|
||||
from keras.backend.tensorflow import sparse
|
||||
from keras.backend.tensorflow.core import convert_to_tensor
|
||||
from keras.utils import tree
|
||||
|
||||
|
||||
@sparse.elementwise_binary_union(tf.sparse.add)
|
||||
@ -95,7 +96,7 @@ def _normalize_einsum_subscripts(subscripts):
|
||||
|
||||
|
||||
def einsum(subscripts, *operands, **kwargs):
|
||||
operands = tf.nest.map_structure(convert_to_tensor, operands)
|
||||
operands = tree.map_structure(convert_to_tensor, operands)
|
||||
subscripts = _normalize_einsum_subscripts(subscripts)
|
||||
|
||||
def is_valid_for_custom_ops(subscripts, *operands):
|
||||
@ -240,7 +241,7 @@ def einsum(subscripts, *operands, **kwargs):
|
||||
# output_type="int32"
|
||||
if "int" in compute_dtype and output_type is None:
|
||||
compute_dtype = config.floatx()
|
||||
operands = tf.nest.map_structure(
|
||||
operands = tree.map_structure(
|
||||
lambda x: tf.cast(x, compute_dtype), operands
|
||||
)
|
||||
result = use_custom_ops(subscripts, *operands, output_type=output_type)
|
||||
@ -248,7 +249,7 @@ def einsum(subscripts, *operands, **kwargs):
|
||||
# TODO: tf.einsum doesn't support integer dtype with gpu
|
||||
if "int" in compute_dtype:
|
||||
compute_dtype = config.floatx()
|
||||
operands = tf.nest.map_structure(
|
||||
operands = tree.map_structure(
|
||||
lambda x: tf.cast(x, compute_dtype), operands
|
||||
)
|
||||
result = tf.einsum(subscripts, *operands, **kwargs)
|
||||
@ -763,11 +764,11 @@ def concatenate(xs, axis=0):
|
||||
)
|
||||
for x in xs
|
||||
]
|
||||
xs = tf.nest.map_structure(convert_to_tensor, xs)
|
||||
xs = tree.map_structure(convert_to_tensor, xs)
|
||||
dtype_set = set([x.dtype for x in xs])
|
||||
if len(dtype_set) > 1:
|
||||
dtype = dtypes.result_type(*dtype_set)
|
||||
xs = tf.nest.map_structure(lambda x: tf.cast(x, dtype), xs)
|
||||
xs = tree.map_structure(lambda x: tf.cast(x, dtype), xs)
|
||||
return tf.concat(xs, axis=axis)
|
||||
|
||||
|
||||
@ -872,7 +873,7 @@ def digitize(x, bins):
|
||||
bins = list(bins)
|
||||
|
||||
# bins must be float type
|
||||
bins = tf.nest.map_structure(lambda x: float(x), bins)
|
||||
bins = tree.map_structure(lambda x: float(x), bins)
|
||||
|
||||
# TODO: tf.raw_ops.Bucketize doesn't support bool, bfloat16, float16, int8
|
||||
# int16, uint8, uint16, uint32
|
||||
@ -1023,7 +1024,7 @@ def hstack(xs):
|
||||
dtype_set = set([getattr(x, "dtype", type(x)) for x in xs])
|
||||
if len(dtype_set) > 1:
|
||||
dtype = dtypes.result_type(*dtype_set)
|
||||
xs = tf.nest.map_structure(lambda x: convert_to_tensor(x, dtype), xs)
|
||||
xs = tree.map_structure(lambda x: convert_to_tensor(x, dtype), xs)
|
||||
rank = tf.rank(xs[0])
|
||||
return tf.cond(
|
||||
tf.equal(rank, 1),
|
||||
@ -1328,9 +1329,7 @@ def ndim(x):
|
||||
def nonzero(x):
|
||||
x = convert_to_tensor(x)
|
||||
result = tf.unstack(tf.where(tf.cast(x, "bool")), x.shape.rank, axis=1)
|
||||
return tf.nest.map_structure(
|
||||
lambda indices: tf.cast(indices, "int32"), result
|
||||
)
|
||||
return tree.map_structure(lambda indices: tf.cast(indices, "int32"), result)
|
||||
|
||||
|
||||
def not_equal(x1, x2):
|
||||
@ -1620,7 +1619,7 @@ def stack(x, axis=0):
|
||||
dtype_set = set([getattr(a, "dtype", type(a)) for a in x])
|
||||
if len(dtype_set) > 1:
|
||||
dtype = dtypes.result_type(*dtype_set)
|
||||
x = tf.nest.map_structure(lambda a: convert_to_tensor(a, dtype), x)
|
||||
x = tree.map_structure(lambda a: convert_to_tensor(a, dtype), x)
|
||||
return tf.stack(x, axis=axis)
|
||||
|
||||
|
||||
@ -1807,7 +1806,7 @@ def vstack(xs):
|
||||
dtype_set = set([getattr(x, "dtype", type(x)) for x in xs])
|
||||
if len(dtype_set) > 1:
|
||||
dtype = dtypes.result_type(*dtype_set)
|
||||
xs = tf.nest.map_structure(lambda x: convert_to_tensor(x, dtype), xs)
|
||||
xs = tree.map_structure(lambda x: convert_to_tensor(x, dtype), xs)
|
||||
return tf.concat(xs, axis=0)
|
||||
|
||||
|
||||
|
@ -225,7 +225,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
|
||||
outputs = one_step_on_data_distributed(data[:1])
|
||||
for single_step_data in data[1:]:
|
||||
step_outputs = one_step_on_data_distributed([single_step_data])
|
||||
outputs = tf.nest.map_structure(
|
||||
outputs = tree.map_structure(
|
||||
lambda t1, t2: concat([t1, t2]), outputs, step_outputs
|
||||
)
|
||||
return outputs
|
||||
@ -473,7 +473,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
|
||||
|
||||
def append_to_outputs(batch_outputs, outputs):
|
||||
if outputs is None:
|
||||
outputs = tf.nest.map_structure(
|
||||
outputs = tree.map_structure(
|
||||
lambda batch_output: [batch_output],
|
||||
batch_outputs,
|
||||
)
|
||||
@ -521,7 +521,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
|
||||
outputs = tree.map_structure_up_to(
|
||||
batch_outputs, potentially_ragged_concat, outputs
|
||||
)
|
||||
return tf.nest.map_structure(convert_to_np_if_not_ragged, outputs)
|
||||
return tree.map_structure(convert_to_np_if_not_ragged, outputs)
|
||||
|
||||
def train_on_batch(
|
||||
self,
|
||||
@ -549,7 +549,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
|
||||
yield (x, y, sample_weight)
|
||||
|
||||
logs = self.train_function(data())
|
||||
logs = tf.nest.map_structure(lambda x: np.array(x), logs)
|
||||
logs = tree.map_structure(lambda x: np.array(x), logs)
|
||||
if return_dict:
|
||||
return logs
|
||||
return self._flatten_metrics_in_order(logs)
|
||||
@ -568,7 +568,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
|
||||
yield (x, y, sample_weight)
|
||||
|
||||
logs = self.test_function(data())
|
||||
logs = tf.nest.map_structure(lambda x: np.array(x), logs)
|
||||
logs = tree.map_structure(lambda x: np.array(x), logs)
|
||||
if return_dict:
|
||||
return logs
|
||||
return self._flatten_metrics_in_order(logs)
|
||||
@ -576,7 +576,7 @@ class TensorFlowTrainer(base_trainer.Trainer):
|
||||
def predict_on_batch(self, x):
|
||||
self.make_predict_function()
|
||||
batch_outputs = self.predict_function([(x,)])
|
||||
batch_outputs = tf.nest.map_structure(
|
||||
batch_outputs = tree.map_structure(
|
||||
convert_to_np_if_not_ragged, batch_outputs
|
||||
)
|
||||
return batch_outputs
|
||||
@ -771,7 +771,7 @@ def reduce_per_replica(values, strategy, reduction):
|
||||
f"Received: reduction={reduction}."
|
||||
)
|
||||
|
||||
return tf.nest.map_structure(_reduce, values)
|
||||
return tree.map_structure(_reduce, values)
|
||||
|
||||
|
||||
def _multi_worker_concat(v, strategy):
|
||||
|
@ -8,6 +8,7 @@ from keras.layers import Layer
|
||||
from keras.models import Functional
|
||||
from keras.models import Sequential
|
||||
from keras.utils import io_utils
|
||||
from keras.utils import tree
|
||||
from keras.utils.module_utils import tensorflow as tf
|
||||
|
||||
|
||||
@ -143,16 +144,16 @@ class ExportArchive:
|
||||
# Variables in the lists below are actually part of the trackables
|
||||
# that get saved, because the lists are created in __init__.
|
||||
if backend.backend() == "jax":
|
||||
self._tf_trackable.variables += tf.nest.flatten(
|
||||
tf.nest.map_structure(tf.Variable, resource.variables)
|
||||
self._tf_trackable.variables += tree.flatten(
|
||||
tree.map_structure(tf.Variable, resource.variables)
|
||||
)
|
||||
self._tf_trackable.trainable_variables += tf.nest.flatten(
|
||||
tf.nest.map_structure(
|
||||
self._tf_trackable.trainable_variables += tree.flatten(
|
||||
tree.map_structure(
|
||||
tf.Variable, resource.trainable_variables
|
||||
)
|
||||
)
|
||||
self._tf_trackable.non_trainable_variables += tf.nest.flatten(
|
||||
tf.nest.map_structure(
|
||||
self._tf_trackable.non_trainable_variables += tree.flatten(
|
||||
tree.map_structure(
|
||||
tf.Variable, resource.non_trainable_variables
|
||||
)
|
||||
)
|
||||
@ -362,9 +363,7 @@ class ExportArchive:
|
||||
f"{list(set(type(v) for v in variables))}"
|
||||
)
|
||||
if backend.backend() == "jax":
|
||||
variables = tf.nest.flatten(
|
||||
tf.nest.map_structure(tf.Variable, variables)
|
||||
)
|
||||
variables = tree.flatten(tree.map_structure(tf.Variable, variables))
|
||||
setattr(self._tf_trackable, name, list(variables))
|
||||
|
||||
def write_out(self, filepath, options=None):
|
||||
@ -470,7 +469,7 @@ class ExportArchive:
|
||||
|
||||
def _spec_to_poly_shape(self, spec):
|
||||
if isinstance(spec, (dict, list)):
|
||||
return tf.nest.map_structure(self._spec_to_poly_shape, spec)
|
||||
return tree.map_structure(self._spec_to_poly_shape, spec)
|
||||
spec_shape = spec.shape
|
||||
spec_shape = str(spec_shape).replace("None", "b")
|
||||
return spec_shape
|
||||
@ -500,7 +499,7 @@ def export_model(model, filepath):
|
||||
export_archive = ExportArchive()
|
||||
export_archive.track(model)
|
||||
if isinstance(model, (Functional, Sequential)):
|
||||
input_signature = tf.nest.map_structure(_make_tensor_spec, model.inputs)
|
||||
input_signature = tree.map_structure(_make_tensor_spec, model.inputs)
|
||||
if isinstance(input_signature, list) and len(input_signature) > 1:
|
||||
input_signature = [input_signature]
|
||||
export_archive.add_endpoint("serve", model.__call__, input_signature)
|
||||
|
@ -261,7 +261,7 @@ def _clone_functional_model(model, input_tensors=None, clone_function=None):
|
||||
)
|
||||
try:
|
||||
tree.assert_same_structure(input_tensors, model.input)
|
||||
except TypeError as e:
|
||||
except (ValueError, TypeError) as e:
|
||||
raise ValueError(
|
||||
"`input_tensors` must have the same structure as model.input"
|
||||
f"\nReference structure: {model.input}"
|
||||
|
@ -789,9 +789,8 @@ class CoreOpsCallsTests(testing.TestCase):
|
||||
|
||||
def test_cond_check_output_spec_other_types(self):
|
||||
cond_op = core.Cond()
|
||||
# Create mock objects with dtype and shape attributes
|
||||
mock_spec1 = Mock(dtype="float32", shape=(2, 2))
|
||||
mock_spec2 = Mock(dtype="float32", shape=(2, 2))
|
||||
mock_spec1 = KerasTensor(shape=(2, 2), dtype="float32")
|
||||
mock_spec2 = KerasTensor(shape=(2, 2), dtype="float32")
|
||||
self.assertTrue(cond_op._check_output_spec(mock_spec1, mock_spec2))
|
||||
|
||||
def test_cond_check_output_spec_none(self):
|
||||
|
@ -1,5 +1,8 @@
|
||||
from functools import wraps
|
||||
|
||||
import optree
|
||||
import optree.utils
|
||||
|
||||
from keras.backend.common.global_state import get_global_attribute
|
||||
from keras.backend.common.global_state import set_global_attribute
|
||||
from keras.utils import python_utils
|
||||
@ -110,6 +113,7 @@ class Tracker:
|
||||
self.stored_ids[store_name].add(id(value))
|
||||
|
||||
|
||||
@optree.register_pytree_node_class(namespace="keras")
|
||||
class TrackedList(list):
|
||||
def __init__(self, values=None, tracker=None):
|
||||
self.tracker = tracker
|
||||
@ -160,7 +164,17 @@ class TrackedList(list):
|
||||
if self.tracker:
|
||||
self.tracker.untrack(value)
|
||||
|
||||
def tree_flatten(self):
|
||||
# For optree
|
||||
return (self, None)
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, metadata, children):
|
||||
# For optree
|
||||
return cls(children)
|
||||
|
||||
|
||||
@optree.register_pytree_node_class(namespace="keras")
|
||||
class TrackedDict(dict):
|
||||
def __init__(self, values=None, tracker=None):
|
||||
self.tracker = tracker
|
||||
@ -199,7 +213,20 @@ class TrackedDict(dict):
|
||||
self.tracker.untrack(value)
|
||||
super().clear()
|
||||
|
||||
def tree_flatten(self):
|
||||
# For optree
|
||||
keys, values = optree.utils.unzip2(
|
||||
optree.utils.total_order_sorted(self.items(), key=lambda kv: kv[0])
|
||||
)
|
||||
return values, list(keys), keys
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, keys, values):
|
||||
# For optree
|
||||
return cls(optree.utils.safe_zip(keys, values))
|
||||
|
||||
|
||||
@optree.register_pytree_node_class(namespace="keras")
|
||||
class TrackedSet(set):
|
||||
def __init__(self, values=None, tracker=None):
|
||||
self.tracker = tracker
|
||||
@ -233,3 +260,12 @@ class TrackedSet(set):
|
||||
for value in self:
|
||||
self.tracker.untrack(value)
|
||||
super().clear()
|
||||
|
||||
def tree_flatten(self):
|
||||
# For optree
|
||||
return (self, None)
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, metadata, children):
|
||||
# For optree
|
||||
return cls(children)
|
||||
|
@ -1,56 +1,389 @@
|
||||
import tree
|
||||
import collections
|
||||
import collections.abc
|
||||
import types
|
||||
|
||||
import optree
|
||||
|
||||
def is_nested(structure):
|
||||
return tree.is_nested(structure)
|
||||
from keras.api_export import keras_export
|
||||
from keras.backend.config import backend
|
||||
|
||||
# Register backend-specific node classes
|
||||
if backend() == "tensorflow":
|
||||
from tensorflow.python.trackable.data_structures import ListWrapper
|
||||
|
||||
def flatten(structure):
|
||||
return tree.flatten(structure)
|
||||
|
||||
|
||||
def map_structure(func, *structures, **kwargs):
|
||||
return tree.map_structure(func, *structures, **kwargs)
|
||||
|
||||
|
||||
def map_structure_up_to(shallow_structure, func, *structures, **kwargs):
|
||||
return tree.map_structure_up_to(
|
||||
shallow_structure, func, *structures, **kwargs
|
||||
optree.register_pytree_node(
|
||||
ListWrapper,
|
||||
lambda x: (x, None),
|
||||
lambda metadata, children: ListWrapper(list(children)),
|
||||
namespace="keras",
|
||||
)
|
||||
|
||||
|
||||
def traverse(func, structure, top_down=True):
|
||||
return tree.traverse(func, structure, top_down=top_down)
|
||||
@keras_export("keras.tree.is_nested")
|
||||
def is_nested(structure):
|
||||
"""Checks if a given structure is nested.
|
||||
|
||||
Examples:
|
||||
|
||||
def assert_same_structure(a, b, check_types=True):
|
||||
return tree.assert_same_structure(a, b, check_types=check_types)
|
||||
|
||||
|
||||
def sequence_like(instance, args):
|
||||
"""Converts the sequence `args` to the same type as `instance`.
|
||||
>>> keras.tree.is_nested(42)
|
||||
False
|
||||
>>> keras.tree.is_nested({"foo": 42})
|
||||
True
|
||||
|
||||
Args:
|
||||
instance: an instance of `tuple`, `list`, `namedtuple`, `dict`, or
|
||||
`collections.OrderedDict`.
|
||||
args: elements to be converted to the `instance` type.
|
||||
structure: A structure to check.
|
||||
|
||||
Returns:
|
||||
`args` with the type of `instance`.
|
||||
`True` if a given structure is nested, i.e. is a sequence, a mapping,
|
||||
or a namedtuple, and `False` otherwise.
|
||||
"""
|
||||
return tree._sequence_like(instance, args)
|
||||
return not optree.tree_is_leaf(
|
||||
structure, none_is_leaf=True, namespace="keras"
|
||||
)
|
||||
|
||||
|
||||
@keras_export("keras.tree.traverse")
|
||||
def traverse(func, structure, top_down=True):
|
||||
"""Traverses the given nested structure, applying the given function.
|
||||
|
||||
The traversal is depth-first. If `top_down` is True (default), parents
|
||||
are returned before their children (giving the option to avoid traversing
|
||||
into a sub-tree).
|
||||
|
||||
Examples:
|
||||
|
||||
>>> v = []
|
||||
>>> keras.tree.traverse(v.append, [(1, 2), [3], {"a": 4}], top_down=True)
|
||||
[(1, 2), [3], {'a': 4}]
|
||||
>>> v
|
||||
[[(1, 2), [3], {'a': 4}], (1, 2), 1, 2, [3], 3, {'a': 4}, 4]
|
||||
|
||||
>>> v = []
|
||||
>>> keras.tree.traverse(v.append, [(1, 2), [3], {"a": 4}], top_down=False)
|
||||
[(1, 2), [3], {'a': 4}]
|
||||
>>> v
|
||||
[1, 2, (1, 2), 3, [3], 4, {'a': 4}, [(1, 2), [3], {'a': 4}]]
|
||||
|
||||
Args:
|
||||
func: The function to be applied to each sub-nest of the structure.
|
||||
|
||||
When traversing top-down:
|
||||
If `func(subtree) is None` the traversal continues into the
|
||||
sub-tree.
|
||||
If `func(subtree) is not None` the traversal does not continue
|
||||
into the sub-tree. The sub-tree will be replaced by `func(subtree)`
|
||||
in the returned structure (to replace the sub-tree with `None`, use
|
||||
the special value `_MAP_TO_NONE`).
|
||||
|
||||
When traversing bottom-up:
|
||||
If `func(subtree) is None` the traversed sub-tree is returned
|
||||
unaltered.
|
||||
If `func(subtree) is not None` the sub-tree will be replaced by
|
||||
`func(subtree)` in the returned structure (to replace the sub-tree
|
||||
with None, use the special value `_MAP_TO_NONE`).
|
||||
|
||||
structure: The structure to traverse.
|
||||
top_down: If True, parent structures will be visited before their
|
||||
children.
|
||||
|
||||
Returns:
|
||||
The structured output from the traversal.
|
||||
"""
|
||||
|
||||
# From https://github.com/google/jax/pull/19695
|
||||
def traverse_children():
|
||||
children, treedef = optree.tree_flatten(
|
||||
structure,
|
||||
is_leaf=lambda x: x is not structure,
|
||||
none_is_leaf=True,
|
||||
namespace="keras",
|
||||
)
|
||||
if treedef.num_nodes == 1 and treedef.num_leaves == 1:
|
||||
return structure
|
||||
else:
|
||||
return optree.tree_unflatten(
|
||||
treedef,
|
||||
[traverse(func, c, top_down=top_down) for c in children],
|
||||
)
|
||||
|
||||
if top_down:
|
||||
ret = func(structure)
|
||||
if ret is None:
|
||||
return traverse_children()
|
||||
else:
|
||||
traversed_structure = traverse_children()
|
||||
ret = func(traversed_structure)
|
||||
if ret is None:
|
||||
return traversed_structure
|
||||
return None if ret is _MAP_TO_NONE else ret
|
||||
|
||||
|
||||
@keras_export("keras.tree.flatten")
|
||||
def flatten(structure):
|
||||
"""Flattens a possibly nested structure into a list.
|
||||
|
||||
In the case of dict instances, the sequence consists of the values,
|
||||
sorted by key to ensure deterministic behavior. This is true also for
|
||||
`collections.OrderedDict` instances: their sequence order is
|
||||
considered. The same convention is followed in `unflatten_as`.
|
||||
This correctly unflattens dicts and `OrderedDict` after they have been
|
||||
flattened, or vice-versa.
|
||||
|
||||
Dictionaries with non-sortable keys cannot be flattened.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> keras.tree.flatten([[1, 2, 3], [4, [5], [[6]]]])
|
||||
[1, 2, 3, 4, 5, 6]
|
||||
>>> keras.tree.flatten(None)
|
||||
[None]
|
||||
>>> keras.tree.flatten(1)
|
||||
[1]
|
||||
>>> keras.tree.flatten({100: 'world!', 6: 'Hello'})
|
||||
['Hello', 'world!']
|
||||
|
||||
Args:
|
||||
structure: An arbitrarily nested structure.
|
||||
|
||||
Returns:
|
||||
A list, the flattened version of the input `structure`.
|
||||
"""
|
||||
# optree.tree_flatten returns a pair (leaves, treespec) where the first
|
||||
# element is a list of leaf values and the second element is a treespec
|
||||
# representing the structure of the pytree.
|
||||
leaves, _ = optree.tree_flatten(
|
||||
structure, none_is_leaf=True, namespace="keras"
|
||||
)
|
||||
return leaves
|
||||
|
||||
|
||||
@keras_export("keras.tree.unflatten_as")
|
||||
def unflatten_as(structure, flat_sequence):
|
||||
"""Unflattens a sequence into a given structure.
|
||||
|
||||
If `structure` is a scalar, `flat_sequence` must be a single-element list;
|
||||
in this case the return value is ``flat_sequence[0]``.
|
||||
|
||||
If `structure` is or contains a dict instance, the keys will be sorted to
|
||||
pack the flat sequence in deterministic order. This is true also for
|
||||
`collections.OrderedDict` instances: their sequence order is considered.
|
||||
The same convention is followed in `flatten`. This correctly unflattens
|
||||
dicts and `OrderedDict` after they have been flattened, or vice-versa.
|
||||
|
||||
Dictionaries with non-sortable keys cannot be unflattened.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> keras.tree.unflatten_as([[1, 2], [[3], [4]]], [5, 6, 7, 8])
|
||||
[[5, 6], [[7], [8]]]
|
||||
>>> keras.tree.unflatten_as(None, [1])
|
||||
1
|
||||
>>> keras.tree.unflatten_as({1: None, 2: None}, ['Hello', 'world!'])
|
||||
{1: 'Hello', 2: 'world!'}
|
||||
|
||||
Args:
|
||||
structure: Arbitrarily nested structure.
|
||||
flat_sequence: Sequence to unflatten.
|
||||
|
||||
Returns:
|
||||
`flat_sequence` unflattened into `structure`.
|
||||
"""
|
||||
if not is_nested(flat_sequence):
|
||||
raise TypeError(
|
||||
f"flat_sequence must be a sequence not a {type(flat_sequence)}:\n"
|
||||
f"{flat_sequence}"
|
||||
)
|
||||
if not is_nested(structure):
|
||||
if len(flat_sequence) != 1:
|
||||
raise ValueError(
|
||||
"Structure is a scalar but "
|
||||
f"len(flat_sequence) == {len(flat_sequence)} > 1"
|
||||
)
|
||||
return flat_sequence[0]
|
||||
structure_spec = optree.tree_structure(
|
||||
structure, none_is_leaf=True, namespace="keras"
|
||||
)
|
||||
return structure_spec.unflatten(flat_sequence)
|
||||
|
||||
|
||||
@keras_export("keras.tree.map_structure")
|
||||
def map_structure(func, *structures):
|
||||
"""Maps `func` through given structures.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> structure = [[1], [2], [3]]
|
||||
>>> keras.tree.map_structure(lambda v: v**2, structure)
|
||||
[[1], [4], [9]]
|
||||
>>> keras.tree.map_structure(lambda x, y: x * y, structure, structure)
|
||||
[[1], [4], [9]]
|
||||
|
||||
>>> Foo = collections.namedtuple('Foo', ['a', 'b'])
|
||||
>>> structure = Foo(a=1, b=2)
|
||||
>>> keras.tree.map_structure(lambda v: v * 2, structure)
|
||||
Foo(a=2, b=4)
|
||||
|
||||
Args:
|
||||
func: A callable that accepts as many arguments as there are structures.
|
||||
*structures: Arbitrarily nested structures of the same layout.
|
||||
|
||||
Returns:
|
||||
A new structure with the same layout as the given ones.
|
||||
"""
|
||||
if not callable(func):
|
||||
raise TypeError(f"`func` must be callable. Received: func={func}")
|
||||
if not structures:
|
||||
raise ValueError("Must provide at least one structure")
|
||||
for other in structures[1:]:
|
||||
assert_same_structure(structures[0], other, check_types=False)
|
||||
return optree.tree_map(
|
||||
func, *structures, none_is_leaf=True, namespace="keras"
|
||||
)
|
||||
|
||||
|
||||
@keras_export("keras.tree.map_structure_up_to")
|
||||
def map_structure_up_to(shallow_structure, func, *structures):
|
||||
"""Maps `func` through given structures up to `shallow_structure`.
|
||||
|
||||
This is a variant of `map_structure` which only maps the given structures
|
||||
up to `shallow_structure`. All further nested components are retained as-is.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> shallow_structure = [None, None]
|
||||
>>> structure = [[1, 1], [2, 2]]
|
||||
>>> keras.tree.map_structure_up_to(shallow_structure, len, structure)
|
||||
[2, 2]
|
||||
|
||||
>>> shallow_structure = [None, [None, None]]
|
||||
>>> keras.tree.map_structure_up_to(shallow_structure, str, structure)
|
||||
['[1, 1]', ['2', '2']]
|
||||
|
||||
Args:
|
||||
shallow_structure: A structure with layout common to all `structures`.
|
||||
func: A callable that accepts as many arguments as there are structures.
|
||||
*structures: Arbitrarily nested structures of the same layout.
|
||||
|
||||
Returns:
|
||||
A new structure with the same layout as `shallow_structure`.
|
||||
"""
|
||||
return _map_structure_with_path_up_to(
|
||||
shallow_structure,
|
||||
lambda _, *args: func(*args), # Discards path.
|
||||
*structures,
|
||||
)
|
||||
|
||||
|
||||
@keras_export("keras.tree.assert_same_structure")
|
||||
def assert_same_structure(a, b, check_types=True):
|
||||
"""Asserts that two structures are nested in the same way.
|
||||
|
||||
Note that namedtuples with identical name and fields will not be considered
|
||||
as same structures even `check_types=False`.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> keras.tree.assert_same_structure([(0, 1)], [(2, 3)])
|
||||
|
||||
>>> Foo = collections.namedtuple('Foo', ['a', 'b'])
|
||||
>>> AlsoFoo = collections.namedtuple('Foo', ['a', 'b'])
|
||||
>>> keras.tree.assert_same_structure(Foo(0, 1), Foo(2, 3))
|
||||
>>> keras.tree.assert_same_structure(Foo(0, 1), AlsoFoo(2, 3))
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
ValueError: `a` and `b` don't have the same structure.
|
||||
...
|
||||
|
||||
Args:
|
||||
a: an arbitrarily nested structure.
|
||||
b: an arbitrarily nested structure.
|
||||
check_types: if `True` (default) types of leaves are checked as well.
|
||||
"""
|
||||
a_structure = optree.tree_structure(a, none_is_leaf=True, namespace="keras")
|
||||
b_structure = optree.tree_structure(b, none_is_leaf=True, namespace="keras")
|
||||
if a_structure != b_structure:
|
||||
raise ValueError(
|
||||
"`a` and `b` don't have the same structure. "
|
||||
f"Received: structure of a={a_structure}, "
|
||||
f"structure of b={b_structure}"
|
||||
)
|
||||
if check_types:
|
||||
type_structure = optree.tree_map(
|
||||
lambda x, y: type(x) is type(y),
|
||||
a,
|
||||
b,
|
||||
none_is_leaf=True,
|
||||
namespace="keras",
|
||||
)
|
||||
if not optree.tree_all(
|
||||
type_structure, none_is_leaf=True, namespace="keras"
|
||||
):
|
||||
raise TypeError(
|
||||
"The type of the leaves of `a` and `b` doesn't match."
|
||||
)
|
||||
|
||||
|
||||
@keras_export("keras.tree.pack_sequence_as")
|
||||
def pack_sequence_as(structure, flat_sequence, sequence_fn=None):
|
||||
"""Implements sequence packing, i.e. nest.pack_sequence_as()."""
|
||||
is_nested_fn = tree.is_nested
|
||||
sequence_fn = sequence_fn or tree._sequence_like
|
||||
"""Returns a given flattened sequence packed into a given structure.
|
||||
|
||||
If `structure` is an atom, `flat_sequence` must be a single-item list; in
|
||||
this case the return value is `flat_sequence[0]`.
|
||||
|
||||
If `structure` is or contains a dict instance, the keys will be sorted to
|
||||
pack the flat sequence in deterministic order. This is true also for
|
||||
`OrderedDict` instances: their sequence order is considered. The same
|
||||
convention is followed in `flatten`. This correctly repacks dicts and
|
||||
`OrderedDicts` after they have been flattened, or vice-versa.
|
||||
|
||||
Dictionaries with non-sortable keys cannot be flattened.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> structure = {"key3": "", "key1": "", "key2": ""}
|
||||
>>> flat_sequence = ["value1", "value2", "value3"]
|
||||
>>> keras.tree.pack_sequence_as(structure, flat_sequence)
|
||||
{"key3": "value3", "key1": "value1", "key2": "value2"}
|
||||
|
||||
>>> structure = (("a", "b"), ("c", "d", "e"), "f")
|
||||
>>> flat_sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
|
||||
>>> keras.tree.pack_sequence_as(structure, flat_sequence)
|
||||
((1.0, 2.0), (3.0, 4.0, 5.0), 6.0)
|
||||
|
||||
>>> structure = {"key3": {"c": ("alpha", "beta"), "a": ("gamma")},
|
||||
... "key1": {"e": "val1", "d": "val2"}}
|
||||
>>> flat_sequence = ["val2", "val1", 3.0, 1.0, 2.0]
|
||||
>>> keras.tree.pack_sequence_as(structure, flat_sequence)
|
||||
{'key3': {'c': (1.0, 2.0), 'a': 3.0}, 'key1': {'e': 'val1', 'd': 'val2'}}
|
||||
|
||||
>>> structure = ["a"]
|
||||
>>> flat_sequence = [np.array([[1, 2], [3, 4]])]
|
||||
>>> keras.tree.pack_sequence_as(structure, flat_sequence)
|
||||
[array([[1, 2],
|
||||
[3, 4]])]
|
||||
|
||||
>>> structure = ["a"]
|
||||
>>> flat_sequence = [keras.ops.ones([2, 2])]
|
||||
>>> keras.tree.pack_sequence_as(structure, flat_sequence)
|
||||
[array([[1., 1.],
|
||||
[1., 1.]]]
|
||||
|
||||
Args:
|
||||
structure: Arbitrarily nested structure.
|
||||
flat_sequence: Flat sequence to pack.
|
||||
sequence_fn: Defaults to `_sequence_like`.
|
||||
|
||||
Returns:
|
||||
`flat_sequence` converted to have the same recursive structure as
|
||||
`structure`.
|
||||
"""
|
||||
sequence_fn = sequence_fn or _sequence_like
|
||||
|
||||
def truncate(value, length):
|
||||
value_str = str(value)
|
||||
return value_str[:length] + (value_str[length:] and "...")
|
||||
|
||||
if not is_nested_fn(flat_sequence):
|
||||
if not is_nested(flat_sequence):
|
||||
raise TypeError(
|
||||
"Attempted to pack value:\n {}\ninto a structure, but found "
|
||||
"incompatible type `{}` instead.".format(
|
||||
@ -58,7 +391,7 @@ def pack_sequence_as(structure, flat_sequence, sequence_fn=None):
|
||||
)
|
||||
)
|
||||
|
||||
if not is_nested_fn(structure):
|
||||
if not is_nested(structure):
|
||||
if len(flat_sequence) != 1:
|
||||
raise ValueError(
|
||||
"The target structure is of type `{}`\n {}\nHowever the input "
|
||||
@ -74,13 +407,13 @@ def pack_sequence_as(structure, flat_sequence, sequence_fn=None):
|
||||
return flat_sequence[0]
|
||||
|
||||
try:
|
||||
final_index, packed = packed_nest_with_indices(
|
||||
structure, flat_sequence, 0, is_nested_fn, sequence_fn
|
||||
final_index, packed = _packed_nest_with_indices(
|
||||
structure, flat_sequence, 0, sequence_fn
|
||||
)
|
||||
if final_index < len(flat_sequence):
|
||||
raise IndexError
|
||||
except IndexError:
|
||||
flat_structure = tree.flatten(structure)
|
||||
flat_structure = flatten(structure)
|
||||
if len(flat_structure) != len(flat_sequence):
|
||||
# pylint: disable=raise-missing-from
|
||||
raise ValueError(
|
||||
@ -92,33 +425,147 @@ def pack_sequence_as(structure, flat_sequence, sequence_fn=None):
|
||||
return sequence_fn(structure, packed)
|
||||
|
||||
|
||||
def packed_nest_with_indices(
|
||||
structure, flat, index, is_nested_fn, sequence_fn=None
|
||||
):
|
||||
"""Helper function for pack_sequence_as.
|
||||
@keras_export("keras.tree.lists_to_tuples")
|
||||
def lists_to_tuples(structure):
|
||||
"""Converts `list`s to `tuple`s."""
|
||||
|
||||
Args:
|
||||
structure: structure to mimic.
|
||||
flat: Flattened values to output substructure for.
|
||||
index: Index at which to start reading from flat.
|
||||
is_nested_fn: Function used to test if a value should
|
||||
be treated as a nested structure.
|
||||
sequence_fn: Function used to generate a new structure instance.
|
||||
def sequence_fn(instance, args):
|
||||
if isinstance(instance, list):
|
||||
return tuple(args)
|
||||
return _sequence_like(instance, args)
|
||||
|
||||
Returns:
|
||||
The tuple (new_index, child), where:
|
||||
* new_index - the updated index into `flat`
|
||||
having processed `structure`.
|
||||
* packed - the subset of `flat` corresponding to `structure`,
|
||||
having started at `index`, and packed into the same nested
|
||||
format.
|
||||
"""
|
||||
return pack_sequence_as(
|
||||
structure, flatten(structure), sequence_fn=sequence_fn
|
||||
)
|
||||
|
||||
|
||||
class _MapToNone:
|
||||
"""A special object used as a sentinel within `traverse`."""
|
||||
|
||||
def __repr__(self):
|
||||
return "keras.utils.tree._MAP_TO_NONE"
|
||||
|
||||
|
||||
_MAP_TO_NONE = _MapToNone()
|
||||
|
||||
|
||||
def _yield_flat_up_to(shallow_tree, input_tree, path=()):
|
||||
if isinstance(shallow_tree, (str, bytes)) or not (
|
||||
isinstance(
|
||||
shallow_tree, (collections.abc.Mapping, collections.abc.Sequence)
|
||||
)
|
||||
or optree.is_namedtuple(shallow_tree)
|
||||
):
|
||||
yield (path, input_tree)
|
||||
else:
|
||||
input_tree = dict(_yield_sorted_items(input_tree))
|
||||
for shallow_key, shallow_subtree in _yield_sorted_items(shallow_tree):
|
||||
subpath = path + (shallow_key,)
|
||||
input_subtree = input_tree[shallow_key]
|
||||
for leaf_path, leaf_value in _yield_flat_up_to(
|
||||
shallow_subtree, input_subtree, path=subpath
|
||||
):
|
||||
yield (leaf_path, leaf_value)
|
||||
|
||||
|
||||
def _multiyield_flat_up_to(shallow_tree, *input_trees):
|
||||
"""Same as `_yield_flat_up_to`, but takes multiple input trees."""
|
||||
zipped_iterators = zip(
|
||||
*[
|
||||
_yield_flat_up_to(shallow_tree, input_tree)
|
||||
for input_tree in input_trees
|
||||
]
|
||||
)
|
||||
try:
|
||||
for paths_and_values in zipped_iterators:
|
||||
paths, values = zip(*paths_and_values)
|
||||
yield paths[:1] + values
|
||||
except KeyError as e:
|
||||
paths = locals().get("paths", ((),))
|
||||
raise ValueError(
|
||||
f"Could not find key '{e.args[0]}' in some `input_trees`. "
|
||||
"Please ensure the structure of all `input_trees` are "
|
||||
"compatible with `shallow_tree`. The last valid path "
|
||||
f"yielded was {paths[0]}."
|
||||
) from e
|
||||
|
||||
|
||||
def _map_structure_with_path_up_to(shallow_structure, func, *structures):
|
||||
results = []
|
||||
for path_and_values in _multiyield_flat_up_to(
|
||||
shallow_structure, *structures
|
||||
):
|
||||
results.append(func(*path_and_values))
|
||||
shallow_structure_spec = optree.tree_structure(
|
||||
shallow_structure, none_is_leaf=True, namespace="keras"
|
||||
)
|
||||
return shallow_structure_spec.unflatten(results)
|
||||
|
||||
|
||||
def _sequence_like(instance, args):
|
||||
# TODO: Support attrs library
|
||||
if isinstance(instance, (dict, collections.abc.Mapping)):
|
||||
# Pack dictionaries in a deterministic order by sorting the keys.
|
||||
# Notice this means that we ignore the original order of `OrderedDict`
|
||||
# instances. This is intentional, to avoid potential bugs caused by
|
||||
# mixing ordered and plain dicts (e.g., flattening a dict but using a
|
||||
# corresponding `OrderedDict` to pack it back).
|
||||
result = dict(zip(sorted(instance), args))
|
||||
keys_and_values = ((key, result[key]) for key in instance)
|
||||
if isinstance(instance, collections.defaultdict):
|
||||
# `defaultdict` requires a default factory as the first argument.
|
||||
return type(instance)(instance.default_factory, keys_and_values)
|
||||
elif isinstance(instance, types.MappingProxyType):
|
||||
# MappingProxyType requires a dict to proxy to.
|
||||
return type(instance)(dict(keys_and_values))
|
||||
else:
|
||||
return type(instance)(keys_and_values)
|
||||
elif isinstance(instance, collections.abc.MappingView):
|
||||
# We can't directly construct mapping views, so we create a list instead
|
||||
return list(args)
|
||||
elif optree.is_namedtuple(instance):
|
||||
instance_type = type(instance)
|
||||
try:
|
||||
return instance_type(*args)
|
||||
except Exception as e:
|
||||
raise TypeError(
|
||||
f"Couldn't traverse {instance!r} with arguments {args}"
|
||||
) from e
|
||||
else:
|
||||
# Not a namedtuple
|
||||
return type(instance)(args)
|
||||
|
||||
|
||||
def _yield_sorted_items(iterable):
|
||||
# TODO: Support attrs library
|
||||
if isinstance(iterable, collections.abc.Mapping):
|
||||
# Iterate through dictionaries in a deterministic order by sorting the
|
||||
# keys. Notice this means that we ignore the original order of
|
||||
# `OrderedDict` instances. This is intentional, to avoid potential bugs
|
||||
# caused by mixing ordered and plain dicts (e.g., flattening a dict but
|
||||
# using a corresponding `OrderedDict` to pack it back).
|
||||
for key in sorted(iterable):
|
||||
yield key, iterable[key]
|
||||
elif optree.is_namedtuple(iterable):
|
||||
for field in iterable._fields:
|
||||
yield (field, getattr(iterable, field))
|
||||
else:
|
||||
for item in enumerate(iterable):
|
||||
yield item
|
||||
|
||||
|
||||
def _yield_value(iterable):
|
||||
for _, v in _yield_sorted_items(iterable):
|
||||
yield v
|
||||
|
||||
|
||||
def _packed_nest_with_indices(structure, flat, index, sequence_fn=None):
|
||||
packed = []
|
||||
sequence_fn = sequence_fn or tree._sequence_like
|
||||
for s in yield_value(structure):
|
||||
if is_nested_fn(s):
|
||||
new_index, child = packed_nest_with_indices(
|
||||
s, flat, index, is_nested_fn, sequence_fn
|
||||
sequence_fn = sequence_fn or _sequence_like
|
||||
for s in _yield_value(structure):
|
||||
if is_nested(s):
|
||||
new_index, child = _packed_nest_with_indices(
|
||||
s, flat, index, sequence_fn
|
||||
)
|
||||
packed.append(sequence_fn(s, child))
|
||||
index = new_index
|
||||
@ -126,21 +573,3 @@ def packed_nest_with_indices(
|
||||
packed.append(flat[index])
|
||||
index += 1
|
||||
return index, packed
|
||||
|
||||
|
||||
def yield_value(iterable):
|
||||
for _, v in tree._yield_sorted_items(iterable):
|
||||
yield v
|
||||
|
||||
|
||||
def lists_to_tuples(structure):
|
||||
def sequence_fn(instance, args):
|
||||
if isinstance(instance, list):
|
||||
return tuple(args)
|
||||
return tree._sequence_like(instance, args)
|
||||
|
||||
return pack_sequence_as(
|
||||
structure,
|
||||
tree.flatten(structure),
|
||||
sequence_fn=sequence_fn,
|
||||
)
|
||||
|
291
keras/utils/tree_test.py
Normal file
291
keras/utils/tree_test.py
Normal file
@ -0,0 +1,291 @@
|
||||
import collections
|
||||
|
||||
import numpy as np
|
||||
|
||||
from keras import ops
|
||||
from keras import testing
|
||||
from keras.utils import tree
|
||||
|
||||
STRUCTURE1 = (((1, 2), 3), 4, (5, 6))
|
||||
STRUCTURE2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
|
||||
STRUCTURE_DIFFERENT_NUM_ELEMENTS = ("spam", "eggs")
|
||||
STRUCTURE_DIFFERENT_NESTING = (((1, 2), 3), 4, 5, (6,))
|
||||
|
||||
|
||||
class TreeTest(testing.TestCase):
|
||||
def test_is_nested(self):
|
||||
self.assertFalse(tree.is_nested("1234"))
|
||||
self.assertFalse(tree.is_nested(b"1234"))
|
||||
self.assertFalse(tree.is_nested(bytearray("1234", "ascii")))
|
||||
self.assertTrue(tree.is_nested([1, 3, [4, 5]]))
|
||||
self.assertTrue(tree.is_nested(((7, 8), (5, 6))))
|
||||
self.assertTrue(tree.is_nested([]))
|
||||
self.assertTrue(tree.is_nested({"a": 1, "b": 2}))
|
||||
self.assertFalse(tree.is_nested(set([1, 2])))
|
||||
ones = np.ones([2, 3])
|
||||
self.assertFalse(tree.is_nested(ones))
|
||||
self.assertFalse(tree.is_nested(np.tanh(ones)))
|
||||
self.assertFalse(tree.is_nested(np.ones((4, 5))))
|
||||
|
||||
def test_flatten_and_unflatten(self):
|
||||
structure = ((3, 4), 5, (6, 7, (9, 10), 8))
|
||||
flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
|
||||
|
||||
self.assertEqual(tree.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8])
|
||||
self.assertEqual(
|
||||
tree.unflatten_as(structure, flat),
|
||||
(("a", "b"), "c", ("d", "e", ("f", "g"), "h")),
|
||||
)
|
||||
point = collections.namedtuple("Point", ["x", "y"])
|
||||
structure = (point(x=4, y=2), ((point(x=1, y=0),),))
|
||||
flat = [4, 2, 1, 0]
|
||||
self.assertEqual(tree.flatten(structure), flat)
|
||||
restructured_from_flat = tree.unflatten_as(structure, flat)
|
||||
self.assertEqual(restructured_from_flat, structure)
|
||||
self.assertEqual(restructured_from_flat[0].x, 4)
|
||||
self.assertEqual(restructured_from_flat[0].y, 2)
|
||||
self.assertEqual(restructured_from_flat[1][0][0].x, 1)
|
||||
self.assertEqual(restructured_from_flat[1][0][0].y, 0)
|
||||
|
||||
self.assertEqual([5], tree.flatten(5))
|
||||
self.assertEqual([np.array([5])], tree.flatten(np.array([5])))
|
||||
|
||||
self.assertEqual("a", tree.unflatten_as(5, ["a"]))
|
||||
self.assertEqual(
|
||||
np.array([5]), tree.unflatten_as("scalar", [np.array([5])])
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "Structure is a scalar"):
|
||||
tree.unflatten_as("scalar", [4, 5])
|
||||
with self.assertRaisesRegex(TypeError, "flat_sequence"):
|
||||
tree.unflatten_as([4, 5], "bad_sequence")
|
||||
with self.assertRaises(ValueError):
|
||||
tree.unflatten_as([5, 6, [7, 8]], ["a", "b", "c"])
|
||||
|
||||
self.assertEqual(
|
||||
tree.unflatten_as({1: None, 2: None}, ["Hello", "world!"]),
|
||||
{1: "Hello", 2: "world!"},
|
||||
)
|
||||
|
||||
def test_flatten_dict_order(self):
|
||||
ordered = collections.OrderedDict(
|
||||
[("d", 3), ("b", 1), ("a", 0), ("c", 2)]
|
||||
)
|
||||
plain = {"d": 3, "b": 1, "a": 0, "c": 2}
|
||||
ordered_flat = tree.flatten(ordered)
|
||||
plain_flat = tree.flatten(plain)
|
||||
self.assertEqual([3, 1, 0, 2], ordered_flat)
|
||||
self.assertEqual([0, 1, 2, 3], plain_flat)
|
||||
|
||||
def test_unflatten_dict_order(self):
|
||||
ordered = collections.OrderedDict(
|
||||
[("d", 0), ("b", 0), ("a", 0), ("c", 0)]
|
||||
)
|
||||
plain = {"d": 0, "b": 0, "a": 0, "c": 0}
|
||||
seq = [0, 1, 2, 3]
|
||||
ordered_reconstruction = tree.unflatten_as(ordered, seq)
|
||||
plain_reconstruction = tree.unflatten_as(plain, seq)
|
||||
self.assertEqual(
|
||||
collections.OrderedDict([("d", 0), ("b", 1), ("a", 2), ("c", 3)]),
|
||||
ordered_reconstruction,
|
||||
)
|
||||
self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction)
|
||||
|
||||
def test_map_structure(self):
|
||||
structure2 = (((7, 8), 9), 10, (11, 12))
|
||||
structure1_plus1 = tree.map_structure(lambda x: x + 1, STRUCTURE1)
|
||||
tree.assert_same_structure(STRUCTURE1, structure1_plus1)
|
||||
self.assertAllEqual([2, 3, 4, 5, 6, 7], tree.flatten(structure1_plus1))
|
||||
structure1_plus_structure2 = tree.map_structure(
|
||||
lambda x, y: x + y, STRUCTURE1, structure2
|
||||
)
|
||||
self.assertEqual(
|
||||
(((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)),
|
||||
structure1_plus_structure2,
|
||||
)
|
||||
|
||||
self.assertEqual(3, tree.map_structure(lambda x: x - 1, 4))
|
||||
|
||||
self.assertEqual(7, tree.map_structure(lambda x, y: x + y, 3, 4))
|
||||
|
||||
# Empty structures
|
||||
self.assertEqual((), tree.map_structure(lambda x: x + 1, ()))
|
||||
self.assertEqual([], tree.map_structure(lambda x: x + 1, []))
|
||||
self.assertEqual({}, tree.map_structure(lambda x: x + 1, {}))
|
||||
empty_nt = collections.namedtuple("empty_nt", "")
|
||||
self.assertEqual(
|
||||
empty_nt(), tree.map_structure(lambda x: x + 1, empty_nt())
|
||||
)
|
||||
|
||||
# This is checking actual equality of types, empty list != empty tuple
|
||||
self.assertNotEqual((), tree.map_structure(lambda x: x + 1, []))
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "callable"):
|
||||
tree.map_structure("bad", structure1_plus1)
|
||||
with self.assertRaisesRegex(ValueError, "at least one structure"):
|
||||
tree.map_structure(lambda x: x)
|
||||
with self.assertRaisesRegex(ValueError, "have the same structure"):
|
||||
tree.map_structure(lambda x, y: None, (3, 4), (3, 4, 5))
|
||||
with self.assertRaisesRegex(ValueError, "have the same structure"):
|
||||
tree.map_structure(lambda x, y: None, 3, (3,))
|
||||
with self.assertRaisesRegex(ValueError, "have the same structure"):
|
||||
tree.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5])
|
||||
with self.assertRaisesRegex(ValueError, "have the same structure"):
|
||||
tree.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))
|
||||
|
||||
structure1_list = [[[1, 2], 3], 4, [5, 6]]
|
||||
with self.assertRaisesRegex(ValueError, "have the same structure"):
|
||||
tree.map_structure(lambda x, y: None, STRUCTURE1, structure1_list)
|
||||
|
||||
def test_map_structure_up_to(self):
|
||||
# Named tuples.
|
||||
ab_tuple = collections.namedtuple("ab_tuple", "a, b")
|
||||
op_tuple = collections.namedtuple("op_tuple", "add, mul")
|
||||
inp_val = ab_tuple(a=2, b=3)
|
||||
inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
|
||||
out = tree.map_structure_up_to(
|
||||
inp_val,
|
||||
lambda val, ops: (val + ops.add) * ops.mul,
|
||||
inp_val,
|
||||
inp_ops,
|
||||
)
|
||||
self.assertEqual(out.a, 6)
|
||||
self.assertEqual(out.b, 15)
|
||||
|
||||
# Lists.
|
||||
data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
|
||||
name_list = ["evens", ["odds", "primes"]]
|
||||
out = tree.map_structure_up_to(
|
||||
name_list,
|
||||
lambda name, sec: "first_{}_{}".format(len(sec), name),
|
||||
name_list,
|
||||
data_list,
|
||||
)
|
||||
self.assertEqual(
|
||||
out, ["first_4_evens", ["first_5_odds", "first_3_primes"]]
|
||||
)
|
||||
|
||||
def test_assert_same_structure(self):
|
||||
tree.assert_same_structure(STRUCTURE1, STRUCTURE2, check_types=False)
|
||||
tree.assert_same_structure("abc", 1.0, check_types=False)
|
||||
tree.assert_same_structure(b"abc", 1.0, check_types=False)
|
||||
tree.assert_same_structure("abc", 1.0, check_types=False)
|
||||
tree.assert_same_structure(
|
||||
bytearray("abc", "ascii"), 1.0, check_types=False
|
||||
)
|
||||
tree.assert_same_structure("abc", np.array([0, 1]), check_types=False)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "have the same structure"):
|
||||
tree.assert_same_structure(
|
||||
STRUCTURE1, STRUCTURE_DIFFERENT_NUM_ELEMENTS
|
||||
)
|
||||
with self.assertRaisesRegex(ValueError, "have the same structure"):
|
||||
tree.assert_same_structure([0, 1], np.array([0, 1]))
|
||||
with self.assertRaisesRegex(ValueError, "have the same structure"):
|
||||
tree.assert_same_structure(0, [0, 1])
|
||||
with self.assertRaisesRegex(ValueError, "have the same structure"):
|
||||
tree.assert_same_structure((0, 1), [0, 1])
|
||||
with self.assertRaisesRegex(ValueError, "have the same structure"):
|
||||
tree.assert_same_structure(STRUCTURE1, STRUCTURE_DIFFERENT_NESTING)
|
||||
with self.assertRaisesRegex(ValueError, "have the same structure"):
|
||||
tree.assert_same_structure([[3], 4], [3, [4]])
|
||||
with self.assertRaisesRegex(ValueError, "have the same structure"):
|
||||
tree.assert_same_structure({"a": 1}, {"b": 1})
|
||||
structure1_list = [[[1, 2], 3], 4, [5, 6]]
|
||||
with self.assertRaisesRegex(ValueError, "have the same structure"):
|
||||
tree.assert_same_structure(STRUCTURE1, structure1_list)
|
||||
tree.assert_same_structure(STRUCTURE1, STRUCTURE2, check_types=False)
|
||||
with self.assertRaisesRegex(ValueError, "have the same structure"):
|
||||
tree.assert_same_structure(
|
||||
STRUCTURE1, structure1_list, check_types=False
|
||||
)
|
||||
|
||||
def test_pack_sequence_as(self):
|
||||
structure = {"key3": "", "key1": "", "key2": ""}
|
||||
flat_sequence = ["value1", "value2", "value3"]
|
||||
self.assertEqual(
|
||||
tree.pack_sequence_as(structure, flat_sequence),
|
||||
{"key3": "value3", "key1": "value1", "key2": "value2"},
|
||||
)
|
||||
structure = (("a", "b"), ("c", "d", "e"), "f")
|
||||
flat_sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
|
||||
self.assertEqual(
|
||||
tree.pack_sequence_as(structure, flat_sequence),
|
||||
((1.0, 2.0), (3.0, 4.0, 5.0), 6.0),
|
||||
)
|
||||
structure = {
|
||||
"key3": {"c": ("alpha", "beta"), "a": ("gamma")},
|
||||
"key1": {"e": "val1", "d": "val2"},
|
||||
}
|
||||
flat_sequence = ["val2", "val1", 3.0, 1.0, 2.0]
|
||||
self.assertEqual(
|
||||
tree.pack_sequence_as(structure, flat_sequence),
|
||||
{
|
||||
"key3": {"c": (1.0, 2.0), "a": 3.0},
|
||||
"key1": {"e": "val1", "d": "val2"},
|
||||
},
|
||||
)
|
||||
structure = ["a"]
|
||||
flat_sequence = [np.array([[1, 2], [3, 4]])]
|
||||
self.assertAllClose(
|
||||
tree.pack_sequence_as(structure, flat_sequence),
|
||||
[np.array([[1, 2], [3, 4]])],
|
||||
)
|
||||
structure = ["a"]
|
||||
flat_sequence = [ops.ones([2, 2])]
|
||||
self.assertAllClose(
|
||||
tree.pack_sequence_as(structure, flat_sequence),
|
||||
[ops.ones([2, 2])],
|
||||
)
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "Attempted to pack value:"):
|
||||
structure = ["a"]
|
||||
flat_sequence = 1
|
||||
tree.pack_sequence_as(structure, flat_sequence)
|
||||
with self.assertRaisesRegex(ValueError, "The target structure is of"):
|
||||
structure = "a"
|
||||
flat_sequence = [1, 2]
|
||||
tree.pack_sequence_as(structure, flat_sequence)
|
||||
|
||||
def test_lists_to_tuples(self):
|
||||
structure = [1, 2, 3]
|
||||
self.assertEqual(tree.lists_to_tuples(structure), (1, 2, 3))
|
||||
structure = [[1], [2, 3]]
|
||||
self.assertEqual(tree.lists_to_tuples(structure), ((1,), (2, 3)))
|
||||
structure = [[1], [2, [3]]]
|
||||
self.assertEqual(tree.lists_to_tuples(structure), ((1,), (2, (3,))))
|
||||
|
||||
def test_traverse(self):
|
||||
# Lists to tuples
|
||||
structure = [(1, 2), [3], {"a": [4]}]
|
||||
self.assertEqual(
|
||||
((1, 2), (3,), {"a": (4,)}),
|
||||
tree.traverse(
|
||||
lambda x: tuple(x) if isinstance(x, list) else x,
|
||||
structure,
|
||||
top_down=False,
|
||||
),
|
||||
)
|
||||
# EarlyTermination
|
||||
structure = [(1, [2]), [3, (4, 5, 6)]]
|
||||
visited = []
|
||||
|
||||
def visit(x):
|
||||
visited.append(x)
|
||||
return "X" if isinstance(x, tuple) and len(x) > 2 else None
|
||||
|
||||
output = tree.traverse(visit, structure)
|
||||
self.assertEqual([(1, [2]), [3, "X"]], output)
|
||||
self.assertEqual(
|
||||
[
|
||||
[(1, [2]), [3, (4, 5, 6)]],
|
||||
(1, [2]),
|
||||
1,
|
||||
[2],
|
||||
2,
|
||||
[3, (4, 5, 6)],
|
||||
3,
|
||||
(4, 5, 6),
|
||||
],
|
||||
visited,
|
||||
)
|
@ -15,5 +15,5 @@ google
|
||||
tensorboard-plugin-profile
|
||||
rich
|
||||
build
|
||||
dm-tree
|
||||
optree
|
||||
pytest-cov
|
||||
|
2
setup.py
2
setup.py
@ -44,7 +44,7 @@ setup(
|
||||
"rich",
|
||||
"namex",
|
||||
"h5py",
|
||||
"dm-tree",
|
||||
"optree",
|
||||
"ml-dtypes",
|
||||
],
|
||||
# Supported Python versions
|
||||
|
Loading…
Reference in New Issue
Block a user