20 lines
721 B
Python
20 lines
721 B
Python
def is_shape_tuple(x):
    """Return whether ``x`` is a single shape sequence.

    A shape here is a ``list`` or ``tuple`` whose every entry is either an
    ``int`` or ``None``. An empty list/tuple also qualifies. Any other
    object, including other sequence types such as ``str``, does not.

    Note: because ``bool`` is a subclass of ``int``, boolean entries are
    accepted as dimensions too.

    Args:
        x: Any object.

    Returns:
        ``True`` if ``x`` is a shape list/tuple, otherwise ``False``.
    """
    if not isinstance(x, (list, tuple)):
        return False
    allowed_dim_types = (int, type(None))
    return all(isinstance(dim, allowed_dim_types) for dim in x)
|
|
|
|
|
|
def map_shape_structure(fn, struct):
    """Variant of nest.map_structure whose leaves are whole shape tuples.

    The structure is walked recursively. Whenever a node satisfies
    ``is_shape_tuple``, it is converted to a ``tuple`` and handed to ``fn``
    as one unit instead of being descended into.

    Args:
        fn: Callable invoked once per shape, receiving that shape as a tuple.
        struct: A shape, or an arbitrarily nested ``list`` / ``tuple`` /
            ``dict`` whose leaves are shapes.

    Returns:
        A structure mirroring ``struct`` with each shape replaced by the
        result of ``fn``. Containers are rebuilt as plain ``list``, ``tuple``
        and ``dict`` respectively, so subclasses (e.g. namedtuples holding
        nested shapes) are not preserved.

    Raises:
        ValueError: If some node is neither a shape nor a list/tuple/dict.
    """
    # Shapes must be tested first: a shape is itself a list/tuple, so the
    # container branches below would otherwise recurse into single dims.
    if is_shape_tuple(struct):
        return fn(tuple(struct))

    def recurse(child):
        return map_shape_structure(fn, child)

    if isinstance(struct, list):
        return list(map(recurse, struct))
    if isinstance(struct, tuple):
        return tuple(map(recurse, struct))
    if isinstance(struct, dict):
        return {key: recurse(value) for key, value in struct.items()}
    raise ValueError(f"Cannot map function to unknown object {struct}")
|