Use absl.testing.parameterized for tree_test.py. (#19842)

For consistency, use `absl.testing.parameterized` instead of `parameterized` for `tree_test.py` since that is used for all other tests.

It's one less dependency. It also says `optree` or `dmtree` in each test name.
This commit is contained in:
hertschuh 2024-06-11 15:55:13 -07:00 committed by GitHub
parent 26abe697a8
commit 224de28928
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 106 additions and 127 deletions

@ -1,7 +1,7 @@
import collections
import numpy as np
import parameterized
from absl.testing import parameterized
from keras.src import ops
from keras.src import testing
@ -13,97 +13,89 @@ STRUCTURE2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
STRUCTURE_DIFFERENT_NUM_ELEMENTS = ("spam", "eggs")
STRUCTURE_DIFFERENT_NESTING = (((1, 2), 3), 4, 5, (6,))
TEST_CASES = []
if dmtree.available:
from keras.src.tree import dmtree_impl
@parameterized.parameterized_class(
("tree_module",),
[
(dmtree,),
(optree,),
],
)
TEST_CASES += [
{
"testcase_name": "dmtree",
"tree_impl": dmtree_impl,
"is_optree": False,
}
]
if optree.available:
from keras.src.tree import optree_impl
TEST_CASES += [
{
"testcase_name": "optree",
"tree_impl": optree_impl,
"is_optree": True,
},
]
@parameterized.named_parameters(TEST_CASES)
class TreeTest(testing.TestCase):
tree_impl = None
tree_api_name = ""
def setUp(self):
super().setUp()
if self.tree_module.available:
if self.tree_module == dmtree:
from keras.src.tree import dmtree_impl as tree_impl
self.tree_impl = tree_impl
self.tree_api_name = "dmtree"
elif self.tree_module == optree:
from keras.src.tree import optree_impl as tree_impl
self.tree_impl = tree_impl
self.tree_api_name = "optree"
else:
raise ValueError(f"Unrecognized module {self.tree_module}.")
else:
self.skipTest(f"Skip since {self.tree_module} is not available.")
def test_is_nested(self):
self.assertFalse(self.tree_impl.is_nested("1234"))
self.assertFalse(self.tree_impl.is_nested(b"1234"))
self.assertFalse(self.tree_impl.is_nested(bytearray("1234", "ascii")))
self.assertTrue(self.tree_impl.is_nested([1, 3, [4, 5]]))
self.assertTrue(self.tree_impl.is_nested(((7, 8), (5, 6))))
self.assertTrue(self.tree_impl.is_nested([]))
self.assertTrue(self.tree_impl.is_nested({"a": 1, "b": 2}))
self.assertFalse(self.tree_impl.is_nested(set([1, 2])))
def test_is_nested(self, tree_impl, is_optree):
self.assertFalse(tree_impl.is_nested("1234"))
self.assertFalse(tree_impl.is_nested(b"1234"))
self.assertFalse(tree_impl.is_nested(bytearray("1234", "ascii")))
self.assertTrue(tree_impl.is_nested([1, 3, [4, 5]]))
self.assertTrue(tree_impl.is_nested(((7, 8), (5, 6))))
self.assertTrue(tree_impl.is_nested([]))
self.assertTrue(tree_impl.is_nested({"a": 1, "b": 2}))
self.assertFalse(tree_impl.is_nested(set([1, 2])))
ones = np.ones([2, 3])
self.assertFalse(self.tree_impl.is_nested(ones))
self.assertFalse(self.tree_impl.is_nested(np.tanh(ones)))
self.assertFalse(self.tree_impl.is_nested(np.ones((4, 5))))
self.assertFalse(tree_impl.is_nested(ones))
self.assertFalse(tree_impl.is_nested(np.tanh(ones)))
self.assertFalse(tree_impl.is_nested(np.ones((4, 5))))
def test_flatten(self):
def test_flatten(self, tree_impl, is_optree):
structure = ((3, 4), 5, (6, 7, (9, 10), 8))
flat = ["a", "b", "c", "d", "e", "f", "g", "h"]
self.assertEqual(
self.tree_impl.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8]
tree_impl.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8]
)
point = collections.namedtuple("Point", ["x", "y"])
structure = (point(x=4, y=2), ((point(x=1, y=0),),))
flat = [4, 2, 1, 0]
self.assertEqual(self.tree_impl.flatten(structure), flat)
self.assertEqual(tree_impl.flatten(structure), flat)
self.assertEqual([5], self.tree_impl.flatten(5))
self.assertEqual([np.array([5])], self.tree_impl.flatten(np.array([5])))
self.assertEqual([5], tree_impl.flatten(5))
self.assertEqual([np.array([5])], tree_impl.flatten(np.array([5])))
def test_flatten_dict_order(self):
def test_flatten_dict_order(self, tree_impl, is_optree):
ordered = collections.OrderedDict(
[("d", 3), ("b", 1), ("a", 0), ("c", 2)]
)
plain = {"d": 3, "b": 1, "a": 0, "c": 2}
ordered_flat = self.tree_impl.flatten(ordered)
plain_flat = self.tree_impl.flatten(plain)
ordered_flat = tree_impl.flatten(ordered)
plain_flat = tree_impl.flatten(plain)
# dmtree does not respect the ordered dict.
if self.tree_api_name == "optree":
if is_optree:
self.assertEqual([3, 1, 0, 2], ordered_flat)
else:
self.assertEqual([0, 1, 2, 3], ordered_flat)
self.assertEqual([0, 1, 2, 3], plain_flat)
def test_map_structure(self):
def test_map_structure(self, tree_impl, is_optree):
assertion_message = (
"have the same structure"
if self.tree_api_name == "optree"
if is_optree
else "have the same nested structure"
)
assertion_type_error = (
ValueError if self.tree_api_name == "optree" else TypeError
)
assertion_type_error = ValueError if is_optree else TypeError
structure2 = (((7, 8), 9), 10, (11, 12))
structure1_plus1 = self.tree_impl.map_structure(
lambda x: x + 1, STRUCTURE1
)
self.tree_impl.assert_same_structure(STRUCTURE1, structure1_plus1)
structure1_plus1 = tree_impl.map_structure(lambda x: x + 1, STRUCTURE1)
tree_impl.assert_same_structure(STRUCTURE1, structure1_plus1)
self.assertAllEqual(
[2, 3, 4, 5, 6, 7], self.tree_impl.flatten(structure1_plus1)
[2, 3, 4, 5, 6, 7], tree_impl.flatten(structure1_plus1)
)
structure1_plus_structure2 = self.tree_impl.map_structure(
structure1_plus_structure2 = tree_impl.map_structure(
lambda x, y: x + y, STRUCTURE1, structure2
)
self.assertEqual(
@ -111,57 +103,49 @@ class TreeTest(testing.TestCase):
structure1_plus_structure2,
)
self.assertEqual(3, self.tree_impl.map_structure(lambda x: x - 1, 4))
self.assertEqual(3, tree_impl.map_structure(lambda x: x - 1, 4))
self.assertEqual(
7, self.tree_impl.map_structure(lambda x, y: x + y, 3, 4)
)
self.assertEqual(7, tree_impl.map_structure(lambda x, y: x + y, 3, 4))
# Empty structures
self.assertEqual((), self.tree_impl.map_structure(lambda x: x + 1, ()))
self.assertEqual([], self.tree_impl.map_structure(lambda x: x + 1, []))
self.assertEqual({}, self.tree_impl.map_structure(lambda x: x + 1, {}))
self.assertEqual((), tree_impl.map_structure(lambda x: x + 1, ()))
self.assertEqual([], tree_impl.map_structure(lambda x: x + 1, []))
self.assertEqual({}, tree_impl.map_structure(lambda x: x + 1, {}))
empty_nt = collections.namedtuple("empty_nt", "")
self.assertEqual(
empty_nt(),
self.tree_impl.map_structure(lambda x: x + 1, empty_nt()),
tree_impl.map_structure(lambda x: x + 1, empty_nt()),
)
# This is checking actual equality of types, empty list != empty tuple
self.assertNotEqual(
(), self.tree_impl.map_structure(lambda x: x + 1, [])
)
self.assertNotEqual((), tree_impl.map_structure(lambda x: x + 1, []))
with self.assertRaisesRegex(TypeError, "callable"):
self.tree_impl.map_structure("bad", structure1_plus1)
tree_impl.map_structure("bad", structure1_plus1)
with self.assertRaisesRegex(ValueError, "at least one structure"):
self.tree_impl.map_structure(lambda x: x)
tree_impl.map_structure(lambda x: x)
with self.assertRaisesRegex(ValueError, assertion_message):
self.tree_impl.map_structure(lambda x, y: None, (3, 4), (3, 4, 5))
tree_impl.map_structure(lambda x, y: None, (3, 4), (3, 4, 5))
with self.assertRaisesRegex(ValueError, assertion_message):
self.tree_impl.map_structure(lambda x, y: None, 3, (3,))
tree_impl.map_structure(lambda x, y: None, 3, (3,))
with self.assertRaisesRegex(assertion_type_error, assertion_message):
self.tree_impl.map_structure(
lambda x, y: None, ((3, 4), 5), [(3, 4), 5]
)
tree_impl.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5])
with self.assertRaisesRegex(ValueError, assertion_message):
self.tree_impl.map_structure(
lambda x, y: None, ((3, 4), 5), (3, (4, 5))
)
tree_impl.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)))
structure1_list = [[[1, 2], 3], 4, [5, 6]]
with self.assertRaisesRegex(assertion_type_error, assertion_message):
self.tree_impl.map_structure(
tree_impl.map_structure(
lambda x, y: None, STRUCTURE1, structure1_list
)
def test_map_structure_up_to(self):
def test_map_structure_up_to(self, tree_impl, is_optree):
# Named tuples.
ab_tuple = collections.namedtuple("ab_tuple", "a, b")
op_tuple = collections.namedtuple("op_tuple", "add, mul")
inp_val = ab_tuple(a=2, b=3)
inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
out = self.tree_impl.map_structure_up_to(
out = tree_impl.map_structure_up_to(
inp_val,
lambda val, ops: (val + ops.add) * ops.mul,
inp_val,
@ -173,7 +157,7 @@ class TreeTest(testing.TestCase):
# Lists.
data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
name_list = ["evens", ["odds", "primes"]]
out = self.tree_impl.map_structure_up_to(
out = tree_impl.map_structure_up_to(
name_list,
lambda name, sec: "first_{}_{}".format(len(sec), name),
name_list,
@ -183,77 +167,75 @@ class TreeTest(testing.TestCase):
out, ["first_4_evens", ["first_5_odds", "first_3_primes"]]
)
def test_assert_same_structure(self):
def test_assert_same_structure(self, tree_impl, is_optree):
assertion_message = (
"have the same structure"
if self.tree_api_name == "optree"
if is_optree
else "have the same nested structure"
)
assertion_type_error = (
ValueError if self.tree_api_name == "optree" else TypeError
)
self.tree_impl.assert_same_structure(
assertion_type_error = ValueError if is_optree else TypeError
tree_impl.assert_same_structure(
STRUCTURE1, STRUCTURE2, check_types=False
)
self.tree_impl.assert_same_structure("abc", 1.0, check_types=False)
self.tree_impl.assert_same_structure(b"abc", 1.0, check_types=False)
self.tree_impl.assert_same_structure("abc", 1.0, check_types=False)
self.tree_impl.assert_same_structure(
tree_impl.assert_same_structure("abc", 1.0, check_types=False)
tree_impl.assert_same_structure(b"abc", 1.0, check_types=False)
tree_impl.assert_same_structure("abc", 1.0, check_types=False)
tree_impl.assert_same_structure(
bytearray("abc", "ascii"), 1.0, check_types=False
)
self.tree_impl.assert_same_structure(
tree_impl.assert_same_structure(
"abc", np.array([0, 1]), check_types=False
)
with self.assertRaisesRegex(ValueError, assertion_message):
self.tree_impl.assert_same_structure(
tree_impl.assert_same_structure(
STRUCTURE1, STRUCTURE_DIFFERENT_NUM_ELEMENTS
)
with self.assertRaisesRegex(ValueError, assertion_message):
self.tree_impl.assert_same_structure([0, 1], np.array([0, 1]))
tree_impl.assert_same_structure([0, 1], np.array([0, 1]))
with self.assertRaisesRegex(ValueError, assertion_message):
self.tree_impl.assert_same_structure(0, [0, 1])
tree_impl.assert_same_structure(0, [0, 1])
with self.assertRaisesRegex(assertion_type_error, assertion_message):
self.tree_impl.assert_same_structure((0, 1), [0, 1])
tree_impl.assert_same_structure((0, 1), [0, 1])
with self.assertRaisesRegex(ValueError, assertion_message):
self.tree_impl.assert_same_structure(
tree_impl.assert_same_structure(
STRUCTURE1, STRUCTURE_DIFFERENT_NESTING
)
with self.assertRaisesRegex(ValueError, assertion_message):
self.tree_impl.assert_same_structure([[3], 4], [3, [4]])
tree_impl.assert_same_structure([[3], 4], [3, [4]])
with self.assertRaisesRegex(ValueError, assertion_message):
self.tree_impl.assert_same_structure({"a": 1}, {"b": 1})
tree_impl.assert_same_structure({"a": 1}, {"b": 1})
structure1_list = [[[1, 2], 3], 4, [5, 6]]
with self.assertRaisesRegex(assertion_type_error, assertion_message):
self.tree_impl.assert_same_structure(STRUCTURE1, structure1_list)
self.tree_impl.assert_same_structure(
tree_impl.assert_same_structure(STRUCTURE1, structure1_list)
tree_impl.assert_same_structure(
STRUCTURE1, STRUCTURE2, check_types=False
)
# dm-tree treat list and tuple only on type mismatch, but optree treat
# them as structure mismatch.
if self.tree_api_name == "optree":
if is_optree:
with self.assertRaisesRegex(
assertion_type_error, assertion_message
):
self.tree_impl.assert_same_structure(
tree_impl.assert_same_structure(
STRUCTURE1, structure1_list, check_types=False
)
else:
self.tree_impl.assert_same_structure(
tree_impl.assert_same_structure(
STRUCTURE1, structure1_list, check_types=False
)
def test_pack_sequence_as(self):
def test_pack_sequence_as(self, tree_impl, is_optree):
structure = {"key3": "", "key1": "", "key2": ""}
flat_sequence = ["value1", "value2", "value3"]
self.assertEqual(
self.tree_impl.pack_sequence_as(structure, flat_sequence),
tree_impl.pack_sequence_as(structure, flat_sequence),
{"key3": "value3", "key1": "value1", "key2": "value2"},
)
structure = (("a", "b"), ("c", "d", "e"), "f")
flat_sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
self.assertEqual(
self.tree_impl.pack_sequence_as(structure, flat_sequence),
tree_impl.pack_sequence_as(structure, flat_sequence),
((1.0, 2.0), (3.0, 4.0, 5.0), 6.0),
)
structure = {
@ -262,7 +244,7 @@ class TreeTest(testing.TestCase):
}
flat_sequence = ["val2", "val1", 3.0, 1.0, 2.0]
self.assertEqual(
self.tree_impl.pack_sequence_as(structure, flat_sequence),
tree_impl.pack_sequence_as(structure, flat_sequence),
{
"key3": {"c": (1.0, 2.0), "a": 3.0},
"key1": {"e": "val1", "d": "val2"},
@ -271,43 +253,41 @@ class TreeTest(testing.TestCase):
structure = ["a"]
flat_sequence = [np.array([[1, 2], [3, 4]])]
self.assertAllClose(
self.tree_impl.pack_sequence_as(structure, flat_sequence),
tree_impl.pack_sequence_as(structure, flat_sequence),
[np.array([[1, 2], [3, 4]])],
)
structure = ["a"]
flat_sequence = [ops.ones([2, 2])]
self.assertAllClose(
self.tree_impl.pack_sequence_as(structure, flat_sequence),
tree_impl.pack_sequence_as(structure, flat_sequence),
[ops.ones([2, 2])],
)
with self.assertRaisesRegex(TypeError, "Attempted to pack value:"):
structure = ["a"]
flat_sequence = 1
self.tree_impl.pack_sequence_as(structure, flat_sequence)
tree_impl.pack_sequence_as(structure, flat_sequence)
with self.assertRaisesRegex(ValueError, "The target structure is of"):
structure = "a"
flat_sequence = [1, 2]
self.tree_impl.pack_sequence_as(structure, flat_sequence)
tree_impl.pack_sequence_as(structure, flat_sequence)
def test_lists_to_tuples(self):
def test_lists_to_tuples(self, tree_impl, is_optree):
structure = [1, 2, 3]
self.assertEqual(self.tree_impl.lists_to_tuples(structure), (1, 2, 3))
self.assertEqual(tree_impl.lists_to_tuples(structure), (1, 2, 3))
structure = [[1], [2, 3]]
self.assertEqual(
self.tree_impl.lists_to_tuples(structure), ((1,), (2, 3))
)
self.assertEqual(tree_impl.lists_to_tuples(structure), ((1,), (2, 3)))
structure = [[1], [2, [3]]]
self.assertEqual(
self.tree_impl.lists_to_tuples(structure), ((1,), (2, (3,)))
tree_impl.lists_to_tuples(structure), ((1,), (2, (3,)))
)
def test_traverse(self):
def test_traverse(self, tree_impl, is_optree):
# Lists to tuples
structure = [(1, 2), [3], {"a": [4]}]
self.assertEqual(
((1, 2), (3,), {"a": (4,)}),
self.tree_impl.traverse(
tree_impl.traverse(
lambda x: tuple(x) if isinstance(x, list) else x,
structure,
top_down=False,
@ -321,7 +301,7 @@ class TreeTest(testing.TestCase):
visited.append(x)
return "X" if isinstance(x, tuple) and len(x) > 2 else None
output = self.tree_impl.traverse(visit, structure)
output = tree_impl.traverse(visit, structure)
self.assertEqual([(1, [2]), [3, "X"]], output)
self.assertEqual(
[

@ -19,5 +19,4 @@ optree
pytest-cov
packaging
# for tree_test.py
parameterized
dm_tree
dm_tree