Update layout_map to use re.search instead of re.match. (#18555)
* Update layout_map to use re.search instead of re.match. This allow user to skip the leading ".*' and make the rule more readable. * Address review comments * Fix unit test
This commit is contained in:
parent
c57e454f20
commit
f1ed36d5ed
@ -346,10 +346,10 @@ class ModelParallel(Distribution):
|
||||
# will be split across 4 devices. Any other variable that doesn't
|
||||
# match any key in the layout map will be fully replicated.
|
||||
layout_map = LayoutMap(device_mesh)
|
||||
layout_map['.*dense.*kernel'] = (None, 'model')
|
||||
layout_map['.*dense.*bias'] = ('model',)
|
||||
layout_map['.*conv2d.*kernel'] = (None, None, None, 'model')
|
||||
layout_map['.*conv2d.*bias'] = ('model',)
|
||||
layout_map['dense.*kernel'] = (None, 'model')
|
||||
layout_map['dense.*bias'] = ('model',)
|
||||
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
|
||||
layout_map['conv2d.*bias'] = ('model',)
|
||||
|
||||
distribution = ModelParallel(device_mesh=device_mesh,
|
||||
layout_map=layout_map,
|
||||
@ -437,10 +437,10 @@ class LayoutMap(collections.abc.MutableMapping):
|
||||
|
||||
```python
|
||||
layout_map = LayoutMap(device_mesh=None)
|
||||
layout_map['.*dense.*kernel'] = (None, 'model') # layout_2d
|
||||
layout_map['.*dense.*bias'] = ('model',) # layout_1d
|
||||
layout_map['.*conv2d.*kernel'] = TensorLayout((None, None, None, 'model'))
|
||||
layout_map['.*conv2d.*bias'] = TensorLayout(('model',)) # layout_1d
|
||||
layout_map['dense.*kernel'] = (None, 'model') # layout_2d
|
||||
layout_map['dense.*bias'] = ('model',) # layout_1d
|
||||
layout_map['conv2d.*kernel'] = TensorLayout((None, None, None, 'model'))
|
||||
layout_map['conv2d.*bias'] = TensorLayout(('model',)) # layout_1d
|
||||
|
||||
layout_1 = layout_map['dense_1.kernel'] # layout_1 == layout_2d
|
||||
layout_2 = layout_map['dense_1.bias'] # layout_2 == layout_1d
|
||||
@ -465,9 +465,9 @@ class LayoutMap(collections.abc.MutableMapping):
|
||||
"""Retrieves the corresponding layout by the string key.
|
||||
|
||||
When there isn't an exact match, all the existing keys in the layout map
|
||||
will be treated as a regex and map against the input key again. The
|
||||
first match will be returned, based on the key insertion order. Returns
|
||||
`None` if there isn't any match found.
|
||||
will be treated as a regex and map against the input key again. When
|
||||
there are multiple matches for the regex, an `ValueError` will be
|
||||
raised. Returns `None` if there isn't any match found.
|
||||
|
||||
Args:
|
||||
key: String key to query a layout.
|
||||
@ -478,9 +478,19 @@ class LayoutMap(collections.abc.MutableMapping):
|
||||
if key in self._layout_map:
|
||||
return self._layout_map[key]
|
||||
|
||||
matching_keys = []
|
||||
for k in self._layout_map:
|
||||
if re.match(k, key):
|
||||
return self._layout_map[k]
|
||||
if re.search(k, key):
|
||||
matching_keys.append(k)
|
||||
if len(matching_keys) > 1:
|
||||
raise ValueError(
|
||||
f"Path '{key}' matches multiple layout "
|
||||
f"specification keys: {matching_keys}. Please make "
|
||||
"sure each tensor/variable path only matches at most "
|
||||
"one layout specification key in the LayoutMap."
|
||||
)
|
||||
elif len(matching_keys) == 1:
|
||||
return self._layout_map[matching_keys[0]]
|
||||
return None
|
||||
|
||||
def __setitem__(self, key, layout):
|
||||
|
@ -319,15 +319,18 @@ class LayoutMapTest(testing.TestCase):
|
||||
layout_map["dense.*kernel"] = self.replicated_2d
|
||||
layout_map["dense.*bias"] = self.replicated_1d
|
||||
|
||||
layout_map[".*bias"] = self.sharded_1d
|
||||
layout_map["bias"] = self.sharded_1d
|
||||
|
||||
self.assertEqual(layout_map["dense/kernel"], self.sharded_2d)
|
||||
self.assertEqual(layout_map["dense/bias"], self.sharded_1d)
|
||||
|
||||
# Map against the wildcard bias rule for dense, and based on the order
|
||||
# of insertion, it will not use .*bias.
|
||||
self.assertEqual(layout_map["dense_2/kernel"], self.replicated_2d)
|
||||
self.assertEqual(layout_map["dense_2/bias"], self.replicated_1d)
|
||||
# Map against the wildcard bias rule for dense. This will cause a
|
||||
# ValueError
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, "Path 'dense_2/bias' matches multiple layout"
|
||||
):
|
||||
layout_map["dense_2/bias"]
|
||||
|
||||
self.assertIsNone(layout_map["conv2d/kernel"])
|
||||
self.assertEqual(layout_map["conv2d/bias"], self.sharded_1d)
|
||||
|
Loading…
Reference in New Issue
Block a user