Api conversions/zeropadding cropping (#5723)
* Add ZeroPadding2/3D and Cropping2/3D API conversion interfaces. * Add ZeroPadding2/3D and Cropping2/3D API conversion interfaces.
This commit is contained in:
parent
7696a13995
commit
84711475f8
@ -1289,6 +1289,7 @@ class ZeroPadding2D(Layer):
|
||||
`(batch, channels, padded_rows, padded_cols)`
|
||||
"""
|
||||
|
||||
@interfaces.legacy_zeropadding2d_support
|
||||
def __init__(self,
|
||||
padding=(1, 1),
|
||||
data_format=None,
|
||||
@ -1383,6 +1384,7 @@ class ZeroPadding3D(Layer):
|
||||
`(batch, depth, first_padded_axis, second_padded_axis, third_axis_to_pad)`
|
||||
"""
|
||||
|
||||
@interfaces.legacy_zeropadding3d_support
|
||||
def __init__(self, padding=(1, 1, 1), data_format=None, **kwargs):
|
||||
super(ZeroPadding3D, self).__init__(**kwargs)
|
||||
self.data_format = conv_utils.normalize_data_format(data_format)
|
||||
@ -1542,6 +1544,7 @@ class Cropping2D(Layer):
|
||||
```
|
||||
"""
|
||||
|
||||
@interfaces.legacy_cropping2d_support
|
||||
def __init__(self, cropping=((0, 0), (0, 0)),
|
||||
data_format=None, **kwargs):
|
||||
super(Cropping2D, self).__init__(**kwargs)
|
||||
@ -1669,6 +1672,7 @@ class Cropping3D(Layer):
|
||||
`(batch, depth, first_cropped_axis, second_cropped_axis, third_cropped_axis)`
|
||||
"""
|
||||
|
||||
@interfaces.legacy_cropping3d_support
|
||||
def __init__(self, cropping=((1, 1), (1, 1), (1, 1)),
|
||||
data_format=None, **kwargs):
|
||||
super(Cropping3D, self).__init__(**kwargs)
|
||||
|
@ -406,7 +406,6 @@ def convlstm2d_args_preprocessor(args, kwargs):
|
||||
args, kwargs, _converted = conv2d_args_preprocessor(args, kwargs)
|
||||
return args, kwargs, converted + _converted
|
||||
|
||||
|
||||
legacy_convlstm2d_support = generate_legacy_interface(
|
||||
allowed_positional_args=['filters', 'kernel_size'],
|
||||
conversions=[('nb_filter', 'filters'),
|
||||
@ -427,9 +426,64 @@ legacy_convlstm2d_support = generate_legacy_interface(
|
||||
'default': None}},
|
||||
preprocessor=convlstm2d_args_preprocessor)
|
||||
|
||||
|
||||
legacy_batchnorm_support = generate_legacy_interface(
|
||||
allowed_positional_args=[],
|
||||
conversions=[('beta_init', 'beta_initializer'),
|
||||
('gamma_init', 'gamma_initializer')],
|
||||
preprocessor=batchnorm_args_preprocessor)
|
||||
|
||||
|
||||
def zeropadding2d_preprocessor(args, kwargs):
|
||||
converted = []
|
||||
if 'padding' in kwargs and isinstance(kwargs['padding'], dict):
|
||||
if set(kwargs['padding'].keys()) <= {'top_pad', 'bottom_pad',
|
||||
'left_pad', 'right_pad'}:
|
||||
top_pad = kwargs['padding'].get('top_pad', 0)
|
||||
bottom_pad = kwargs['padding'].get('bottom_pad', 0)
|
||||
left_pad = kwargs['padding'].get('left_pad', 0)
|
||||
right_pad = kwargs['padding'].get('right_pad', 0)
|
||||
kwargs['padding'] = ((top_pad, bottom_pad), (left_pad, right_pad))
|
||||
warnings.warn('The `padding` argument in the Keras 2 API no longer'
|
||||
'accepts dict types. You can now input argument as: '
|
||||
'`padding`=(top_pad, bottom_pad, left_pad, right_pad)')
|
||||
elif len(args) == 2 and isinstance(args[1], dict):
|
||||
if set(args[1].keys()) <= {'top_pad', 'bottom_pad',
|
||||
'left_pad', 'right_pad'}:
|
||||
top_pad = args[1].get('top_pad', 0)
|
||||
bottom_pad = args[1].get('bottom_pad', 0)
|
||||
left_pad = args[1].get('left_pad', 0)
|
||||
right_pad = args[1].get('right_pad', 0)
|
||||
args = (args[0], ((top_pad, bottom_pad), (left_pad, right_pad)))
|
||||
warnings.warn('The `padding` argument in the Keras 2 API no longer'
|
||||
'accepts dict types. You can now input argument as: '
|
||||
'`padding`=((top_pad, bottom_pad), (left_pad, right_pad))')
|
||||
return args, kwargs, converted
|
||||
|
||||
legacy_zeropadding2d_support = generate_legacy_interface(
|
||||
allowed_positional_args=['padding'],
|
||||
conversions=[('dim_ordering', 'data_format')],
|
||||
value_conversions={'dim_ordering': {'tf': 'channels_last',
|
||||
'th': 'channels_first',
|
||||
'default': None}},
|
||||
preprocessor=zeropadding2d_preprocessor)
|
||||
|
||||
legacy_zeropadding3d_support = generate_legacy_interface(
|
||||
allowed_positional_args=['padding'],
|
||||
conversions=[('dim_ordering', 'data_format')],
|
||||
value_conversions={'dim_ordering': {'tf': 'channels_last',
|
||||
'th': 'channels_first',
|
||||
'default': None}})
|
||||
|
||||
legacy_cropping2d_support = generate_legacy_interface(
|
||||
allowed_positional_args=['cropping'],
|
||||
conversions=[('dim_ordering', 'data_format')],
|
||||
value_conversions={'dim_ordering': {'tf': 'channels_last',
|
||||
'th': 'channels_first',
|
||||
'default': None}})
|
||||
|
||||
legacy_cropping3d_support = generate_legacy_interface(
|
||||
allowed_positional_args=['cropping'],
|
||||
conversions=[('dim_ordering', 'data_format')],
|
||||
value_conversions={'dim_ordering': {'tf': 'channels_last',
|
||||
'th': 'channels_first',
|
||||
'default': None}})
|
||||
|
@ -388,8 +388,12 @@ def test_upsampling2d_legacy_interface():
|
||||
|
||||
@keras_test
|
||||
def test_upsampling3d_legacy_interface():
|
||||
old_layer = keras.layers.UpSampling3D((2, 2, 2), dim_ordering='tf', name='us3d')
|
||||
new_layer = keras.layers.UpSampling3D((2, 2, 2), data_format='channels_last', name='us3d')
|
||||
old_layer = keras.layers.UpSampling3D((2, 2, 2),
|
||||
dim_ordering='tf',
|
||||
name='us3d')
|
||||
new_layer = keras.layers.UpSampling3D((2, 2, 2),
|
||||
data_format='channels_last',
|
||||
name='us3d')
|
||||
assert json.dumps(old_layer.get_config()) == json.dumps(new_layer.get_config())
|
||||
|
||||
|
||||
@ -706,5 +710,44 @@ def test_atrousconv2d_legacy_interface():
|
||||
assert json.dumps(old_layer.get_config()) == json.dumps(new_layer.get_config())
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_zeropadding2d_legacy_interface():
|
||||
old_layer = keras.layers.ZeroPadding2D(padding={'right_pad': 4,
|
||||
'bottom_pad': 2,
|
||||
'top_pad': 1,
|
||||
'left_pad': 3},
|
||||
dim_ordering='tf',
|
||||
name='zp2d')
|
||||
new_layer = keras.layers.ZeroPadding2D(((1, 2), (3, 4)),
|
||||
data_format='channels_last',
|
||||
name='zp2d')
|
||||
assert json.dumps(old_layer.get_config()) == json.dumps(new_layer.get_config())
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_zeropadding3d_legacy_interface():
|
||||
old_layer = keras.layers.ZeroPadding3D((2, 2, 2),
|
||||
dim_ordering='tf',
|
||||
name='zp3d')
|
||||
new_layer = keras.layers.ZeroPadding3D((2, 2, 2),
|
||||
data_format='channels_last',
|
||||
name='zp3d')
|
||||
assert json.dumps(old_layer.get_config()) == json.dumps(new_layer.get_config())
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_cropping2d_legacy_interface():
|
||||
old_layer = keras.layers.Cropping2D(dim_ordering='tf', name='c2d')
|
||||
new_layer = keras.layers.Cropping2D(data_format='channels_last', name='c2d')
|
||||
assert json.dumps(old_layer.get_config()) == json.dumps(new_layer.get_config())
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_cropping3d_legacy_interface():
|
||||
old_layer = keras.layers.Cropping3D(dim_ordering='tf', name='c3d')
|
||||
new_layer = keras.layers.Cropping3D(data_format='channels_last', name='c3d')
|
||||
assert json.dumps(old_layer.get_config()) == json.dumps(new_layer.get_config())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
|
Loading…
Reference in New Issue
Block a user