Merge branch 'master' of https://github.com/fchollet/keras
This commit is contained in:
commit
d137d00182
@ -622,20 +622,27 @@ class Lambda(Layer):
|
|||||||
return dict(list(base_config.items()) + list(config.items()))
|
return dict(list(base_config.items()) + list(config.items()))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config):
|
def from_config(cls, config, custom_objects={}):
|
||||||
|
# Insert custom objects into globals.
|
||||||
|
if custom_objects:
|
||||||
|
globs = globals().copy()
|
||||||
|
globs.update(custom_objects)
|
||||||
|
else:
|
||||||
|
globs = globals()
|
||||||
|
|
||||||
function_type = config.pop('function_type')
|
function_type = config.pop('function_type')
|
||||||
if function_type == 'function':
|
if function_type == 'function':
|
||||||
function = globals()[config['function']]
|
function = globs[config['function']]
|
||||||
elif function_type == 'lambda':
|
elif function_type == 'lambda':
|
||||||
function = func_load(config['function'], globs=globals())
|
function = func_load(config['function'], globs=globs)
|
||||||
else:
|
else:
|
||||||
raise TypeError('Unknown function type:', function_type)
|
raise TypeError('Unknown function type:', function_type)
|
||||||
|
|
||||||
output_shape_type = config.pop('output_shape_type')
|
output_shape_type = config.pop('output_shape_type')
|
||||||
if output_shape_type == 'function':
|
if output_shape_type == 'function':
|
||||||
output_shape = globals()[config['output_shape']]
|
output_shape = globs[config['output_shape']]
|
||||||
elif output_shape_type == 'lambda':
|
elif output_shape_type == 'lambda':
|
||||||
output_shape = func_load(config['output_shape'], globs=globals())
|
output_shape = func_load(config['output_shape'], globs=globs)
|
||||||
else:
|
else:
|
||||||
output_shape = config['output_shape']
|
output_shape = config['output_shape']
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
import inspect
|
||||||
|
|
||||||
from .generic_utils import get_from_module
|
from .generic_utils import get_from_module
|
||||||
from .np_utils import convert_kernel
|
from .np_utils import convert_kernel
|
||||||
@ -31,7 +32,12 @@ def layer_from_config(config, custom_objects={}):
|
|||||||
else:
|
else:
|
||||||
layer_class = get_from_module(class_name, globals(), 'layer',
|
layer_class = get_from_module(class_name, globals(), 'layer',
|
||||||
instantiate=False)
|
instantiate=False)
|
||||||
return layer_class.from_config(config['config'])
|
|
||||||
|
arg_spec = inspect.getargspec(layer_class.from_config)
|
||||||
|
if 'custom_objects' in arg_spec.args:
|
||||||
|
return layer_class.from_config(config['config'], custom_objects=custom_objects)
|
||||||
|
else:
|
||||||
|
return layer_class.from_config(config['config'])
|
||||||
|
|
||||||
|
|
||||||
def print_summary(layers, relevant_nodes=None,
|
def print_summary(layers, relevant_nodes=None,
|
||||||
|
@ -5,7 +5,7 @@ import numpy as np
|
|||||||
from numpy.testing import assert_allclose
|
from numpy.testing import assert_allclose
|
||||||
|
|
||||||
from keras.models import Model, Sequential
|
from keras.models import Model, Sequential
|
||||||
from keras.layers import Dense, Dropout, RepeatVector, TimeDistributed
|
from keras.layers import Dense, Dropout, Lambda, RepeatVector, TimeDistributed
|
||||||
from keras.layers import Input
|
from keras.layers import Input
|
||||||
from keras import optimizers
|
from keras import optimizers
|
||||||
from keras import objectives
|
from keras import objectives
|
||||||
@ -232,5 +232,35 @@ def test_loading_weights_by_name_2():
|
|||||||
assert_allclose(np.zeros_like(jessica[1]), jessica[1]) # biases init to 0
|
assert_allclose(np.zeros_like(jessica[1]), jessica[1]) # biases init to 0
|
||||||
|
|
||||||
|
|
||||||
|
# a function to be called from the Lambda layer
|
||||||
|
def square_fn(x):
|
||||||
|
return x * x
|
||||||
|
|
||||||
|
|
||||||
|
@keras_test
|
||||||
|
def test_saving_lambda_custom_objects():
|
||||||
|
input = Input(shape=(3,))
|
||||||
|
x = Lambda(lambda x: square_fn(x), output_shape=(3,))(input)
|
||||||
|
output = Dense(3)(x)
|
||||||
|
|
||||||
|
model = Model(input, output)
|
||||||
|
model.compile(loss=objectives.MSE,
|
||||||
|
optimizer=optimizers.RMSprop(lr=0.0001),
|
||||||
|
metrics=[metrics.categorical_accuracy])
|
||||||
|
x = np.random.random((1, 3))
|
||||||
|
y = np.random.random((1, 3))
|
||||||
|
model.train_on_batch(x, y)
|
||||||
|
|
||||||
|
out = model.predict(x)
|
||||||
|
_, fname = tempfile.mkstemp('.h5')
|
||||||
|
save_model(model, fname)
|
||||||
|
|
||||||
|
model = load_model(fname, custom_objects={'square_fn': square_fn})
|
||||||
|
os.remove(fname)
|
||||||
|
|
||||||
|
out2 = model.predict(x)
|
||||||
|
assert_allclose(out, out2, atol=1e-05)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pytest.main([__file__])
|
pytest.main([__file__])
|
||||||
|
Loading…
Reference in New Issue
Block a user