Update layout_map to use re.search instead of re.match. (#18555)

* Update layout_map to use re.search instead of re.match.

This allow user to skip the leading ".*' and make the rule more readable.

* Address review comments

* Fix unit test
This commit is contained in:
Qianli Scott Zhu 2023-10-06 14:44:58 -07:00 committed by GitHub
parent c57e454f20
commit f1ed36d5ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 17 deletions

@ -346,10 +346,10 @@ class ModelParallel(Distribution):
# will be split across 4 devices. Any other variable that doesn't
# match any key in the layout map will be fully replicated.
layout_map = LayoutMap(device_mesh)
layout_map['.*dense.*kernel'] = (None, 'model')
layout_map['.*dense.*bias'] = ('model',)
layout_map['.*conv2d.*kernel'] = (None, None, None, 'model')
layout_map['.*conv2d.*bias'] = ('model',)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
layout_map['conv2d.*bias'] = ('model',)
distribution = ModelParallel(device_mesh=device_mesh,
layout_map=layout_map,
@ -437,10 +437,10 @@ class LayoutMap(collections.abc.MutableMapping):
```python
layout_map = LayoutMap(device_mesh=None)
layout_map['.*dense.*kernel'] = (None, 'model') # layout_2d
layout_map['.*dense.*bias'] = ('model',) # layout_1d
layout_map['.*conv2d.*kernel'] = TensorLayout((None, None, None, 'model'))
layout_map['.*conv2d.*bias'] = TensorLayout(('model',)) # layout_1d
layout_map['dense.*kernel'] = (None, 'model') # layout_2d
layout_map['dense.*bias'] = ('model',) # layout_1d
layout_map['conv2d.*kernel'] = TensorLayout((None, None, None, 'model'))
layout_map['conv2d.*bias'] = TensorLayout(('model',)) # layout_1d
layout_1 = layout_map['dense_1.kernel'] # layout_1 == layout_2d
layout_2 = layout_map['dense_1.bias'] # layout_2 == layout_1d
@ -465,9 +465,9 @@ class LayoutMap(collections.abc.MutableMapping):
"""Retrieves the corresponding layout by the string key.
When there isn't an exact match, all the existing keys in the layout map
will be treated as a regex and map against the input key again. The
first match will be returned, based on the key insertion order. Returns
`None` if there isn't any match found.
will be treated as a regex and map against the input key again. When
there are multiple matches for the regex, an `ValueError` will be
raised. Returns `None` if there isn't any match found.
Args:
key: String key to query a layout.
@ -478,9 +478,19 @@ class LayoutMap(collections.abc.MutableMapping):
if key in self._layout_map:
return self._layout_map[key]
matching_keys = []
for k in self._layout_map:
if re.match(k, key):
return self._layout_map[k]
if re.search(k, key):
matching_keys.append(k)
if len(matching_keys) > 1:
raise ValueError(
f"Path '{key}' matches multiple layout "
f"specification keys: {matching_keys}. Please make "
"sure each tensor/variable path only matches at most "
"one layout specification key in the LayoutMap."
)
elif len(matching_keys) == 1:
return self._layout_map[matching_keys[0]]
return None
def __setitem__(self, key, layout):

@ -319,15 +319,18 @@ class LayoutMapTest(testing.TestCase):
layout_map["dense.*kernel"] = self.replicated_2d
layout_map["dense.*bias"] = self.replicated_1d
layout_map[".*bias"] = self.sharded_1d
layout_map["bias"] = self.sharded_1d
self.assertEqual(layout_map["dense/kernel"], self.sharded_2d)
self.assertEqual(layout_map["dense/bias"], self.sharded_1d)
# Map against the wildcard bias rule for dense, and based on the order
# of insertion, it will not use .*bias.
self.assertEqual(layout_map["dense_2/kernel"], self.replicated_2d)
self.assertEqual(layout_map["dense_2/bias"], self.replicated_1d)
# Map against the wildcard bias rule for dense. This will cause a
# ValueError
with self.assertRaisesRegex(
ValueError, "Path 'dense_2/bias' matches multiple layout"
):
layout_map["dense_2/bias"]
self.assertIsNone(layout_map["conv2d/kernel"])
self.assertEqual(layout_map["conv2d/bias"], self.sharded_1d)