283 lines
11 KiB
Python
283 lines
11 KiB
Python
from __future__ import print_function
|
|
import pytest
|
|
import numpy as np
|
|
import os
|
|
np.random.seed(1337)
|
|
|
|
from keras.models import Graph, Sequential
|
|
from keras.layers import containers
|
|
from keras.layers.core import Dense, Activation
|
|
from keras.utils.test_utils import get_test_data
|
|
from keras.utils.layer_utils import model_summary
|
|
|
|
# Small random arrays. They are kept as module-level names and in this exact
# order because each call consumes draws from the NumPy generator seeded above.
X = np.random.random((100, 32))
X2 = np.random.random((100, 32))
y = np.random.random((100, 4))
y2 = np.random.random((100,))

# Train/test split with 32-feature inputs and 4-value targets.
(X_train, y_train), (X_test, y_test) = get_test_data(
    nb_train=1000,
    nb_test=200,
    input_shape=(32,),
    classification=False,
    output_shape=(4,))

# Second split with the same input shape but single-value targets.
(X2_train, y2_train), (X2_test, y2_test) = get_test_data(
    nb_train=1000,
    nb_test=200,
    input_shape=(32,),
    classification=False,
    output_shape=(1,))
|
|
|
|
|
|
def test_1o_1i():
    """Train and evaluate a non-sequential graph with 1 input and 1 output.

    'input1' feeds two branches ('dense2' directly, and 'dense1' followed by
    'dense3') whose results are summed into 'output1'. Besides the basic
    fit/predict/evaluate cycle, fitting with a validation split and with
    explicit validation data is exercised.
    """
    np.random.seed(1337)

    graph = Graph()
    graph.add_input(name='input1', input_shape=(32,))

    graph.add_node(Dense(16), name='dense1', input='input1')
    graph.add_node(Dense(4), name='dense2', input='input1')
    graph.add_node(Dense(4), name='dense3', input='dense1')

    graph.add_output(name='output1',
                     inputs=['dense2', 'dense3'],
                     merge_mode='sum')
    graph.compile('rmsprop', {'output1': 'mse'})

    graph.fit({'input1': X_train, 'output1': y_train},
              nb_epoch=10)
    out = graph.predict({'input1': X_test})
    # isinstance() inspects `out` itself; `type(out == dict)` was always truthy
    # and therefore never verified anything.
    assert isinstance(out, dict)
    assert len(out) == 1

    # The batch-level calls are smoke tests; only evaluate()'s loss is checked.
    graph.test_on_batch({'input1': X_test, 'output1': y_test})
    graph.train_on_batch({'input1': X_test, 'output1': y_test})
    loss = graph.evaluate({'input1': X_test, 'output1': y_test})
    assert loss < 2.5

    # test validation split
    graph.fit({'input1': X_train, 'output1': y_train},
              validation_split=0.2, nb_epoch=1)
    # test validation data
    graph.fit({'input1': X_train, 'output1': y_train},
              validation_data={'input1': X_train, 'output1': y_train},
              nb_epoch=1)
|
|
|
|
|
|
def test_1o_1i_2():
    """Train and evaluate a more complex graph with 1 input and 1 output.

    The graph contains an Activation node, a node with merged ('sum') inputs
    and an output that sums two nodes. After training, the config and summary
    helpers are called as smoke tests.
    """
    graph = Graph()
    graph.add_input(name='input1', input_shape=(32,))

    graph.add_node(Dense(16), name='dense1', input='input1')
    graph.add_node(Dense(4), name='dense2-0', input='input1')
    graph.add_node(Activation('relu'), name='dense2', input='dense2-0')

    graph.add_node(Dense(16), name='dense3', input='dense2')
    graph.add_node(Dense(4), name='dense4', inputs=['dense1', 'dense3'],
                   merge_mode='sum')

    graph.add_output(name='output1', inputs=['dense2', 'dense4'],
                     merge_mode='sum')
    graph.compile('rmsprop', {'output1': 'mse'})

    graph.fit({'input1': X_train, 'output1': y_train},
              nb_epoch=10)
    out = graph.predict({'input1': X_train})
    # isinstance() inspects `out` itself; `type(out == dict)` was always truthy
    # and therefore never verified anything.
    assert isinstance(out, dict)
    assert len(out) == 1

    # The batch-level calls are smoke tests; only evaluate()'s loss is checked.
    graph.test_on_batch({'input1': X_test, 'output1': y_test})
    graph.train_on_batch({'input1': X_test, 'output1': y_test})
    loss = graph.evaluate({'input1': X_test, 'output1': y_test})
    assert loss < 2.5

    # Smoke-test the introspection helpers on the trained graph.
    graph.get_config(verbose=1)
    graph.summary()
|
|
|
|
|
|
|
|
def test_1o_2i():
    """Train and evaluate a non-sequential graph with 2 inputs and 1 output.

    'input1' goes through 'dense1' -> 'dense3', 'input2' through 'dense2';
    both branches are summed into 'output1'.
    """
    graph = Graph()
    graph.add_input(name='input1', input_shape=(32,))
    graph.add_input(name='input2', input_shape=(32,))

    graph.add_node(Dense(16), name='dense1', input='input1')
    graph.add_node(Dense(4), name='dense2', input='input2')
    graph.add_node(Dense(4), name='dense3', input='dense1')

    graph.add_output(name='output1', inputs=['dense2', 'dense3'],
                     merge_mode='sum')
    graph.compile('rmsprop', {'output1': 'mse'})

    train_data = {'input1': X_train, 'input2': X2_train, 'output1': y_train}
    test_data = {'input1': X_test, 'input2': X2_test, 'output1': y_test}

    graph.fit(train_data, nb_epoch=10)
    out = graph.predict({'input1': X_test, 'input2': X2_test})
    # isinstance() inspects `out` itself; `type(out == dict)` was always truthy
    # and therefore never verified anything.
    assert isinstance(out, dict)
    assert len(out) == 1

    # The batch-level calls are smoke tests; only evaluate()'s loss is checked.
    graph.test_on_batch(test_data)
    graph.train_on_batch(test_data)
    loss = graph.evaluate(test_data)
    assert loss < 3.0

    graph.get_config(verbose=1)
|
|
|
|
|
|
def _build_2o_1i_graph():
    """Build and compile the 1-input / 2-output graph used by the weights test."""
    graph = Graph()
    graph.add_input(name='input1', input_shape=(32,))

    graph.add_node(Dense(16), name='dense1', input='input1')
    graph.add_node(Dense(4), name='dense2', input='input1')
    graph.add_node(Dense(1), name='dense3', input='dense1')

    graph.add_output(name='output1', input='dense2')
    graph.add_output(name='output2', input='dense3')
    graph.compile('rmsprop', {'output1': 'mse', 'output2': 'mse'})
    return graph


def test_2o_1i_weights():
    """Train a graph with 1 input and 2 outputs, then round-trip its weights.

    The weights of the trained graph are saved to a temporary HDF5 file and
    loaded into a freshly built graph of the same architecture; the reloaded
    graph must reproduce the evaluation loss of the trained one.
    """
    graph = _build_2o_1i_graph()

    test_data = {'input1': X_test, 'output1': y_test, 'output2': y2_test}

    graph.fit({'input1': X_train, 'output1': y_train, 'output2': y2_train},
              nb_epoch=10)
    out = graph.predict({'input1': X_test})
    # isinstance() inspects `out` itself; `type(out == dict)` was always truthy
    # and therefore never verified anything.
    assert isinstance(out, dict)
    assert len(out) == 2

    # The batch-level calls are smoke tests; only evaluate()'s loss is checked.
    graph.test_on_batch(test_data)
    graph.train_on_batch(test_data)
    loss = graph.evaluate(test_data)
    assert loss < 4.

    # test weight saving
    fname = 'test_2o_1i_weights_temp.h5'
    graph.save_weights(fname, overwrite=True)
    try:
        graph = _build_2o_1i_graph()
        graph.load_weights(fname)
    finally:
        # Remove the temp file even when rebuilding or loading fails, so a
        # failing run does not leave it behind in the working directory.
        os.remove(fname)

    # Exact equality is intended: the reloaded graph is evaluated on the same
    # data as the trained one.
    nloss = graph.evaluate(test_data)
    assert loss == nloss
|
|
|
|
|
|
def test_2o_1i_sample_weights():
    """Exercise per-output sample weights on a graph with 1 input and 2 outputs.

    Random per-sample weights are passed to fit, test_on_batch,
    train_on_batch and evaluate. Apart from the shape of the prediction
    result, no loss value is asserted; the calls only have to succeed.
    """
    graph = Graph()
    graph.add_input(name='input1', input_shape=(32,))

    graph.add_node(Dense(16), name='dense1', input='input1')
    graph.add_node(Dense(4), name='dense2', input='input1')
    graph.add_node(Dense(1), name='dense3', input='dense1')

    graph.add_output(name='output1', input='dense2')
    graph.add_output(name='output2', input='dense3')

    # One weight per sample for each output, for the train and test sets.
    weights1 = np.random.uniform(size=y_train.shape[0])
    weights2 = np.random.uniform(size=y2_train.shape[0])
    weights1_test = np.random.uniform(size=y_test.shape[0])
    weights2_test = np.random.uniform(size=y2_test.shape[0])

    graph.compile('rmsprop', {'output1': 'mse', 'output2': 'mse'})

    train_data = {'input1': X_train, 'output1': y_train, 'output2': y2_train}
    train_weights = {'output1': weights1, 'output2': weights2}

    graph.fit(train_data, nb_epoch=10, sample_weight=train_weights)
    out = graph.predict({'input1': X_test})
    # isinstance() inspects `out` itself; `type(out == dict)` was always truthy
    # and therefore never verified anything.
    assert isinstance(out, dict)
    assert len(out) == 2

    graph.test_on_batch({'input1': X_test, 'output1': y_test, 'output2': y2_test},
                        sample_weight={'output1': weights1_test,
                                       'output2': weights2_test})
    graph.train_on_batch(train_data, sample_weight=train_weights)
    graph.evaluate(train_data, sample_weight=train_weights)
|
|
|
|
|
|
def test_recursive():
    """Use a Graph container as a layer inside a Sequential model.

    The container sums two branches into 'output1' and is placed between two
    Dense layers; the resulting model is trained, evaluated and queried.
    """
    # test layer-like API
    inner = containers.Graph()
    inner.add_input(name='input1', input_shape=(32,))
    inner.add_node(Dense(16), name='dense1', input='input1')
    inner.add_node(Dense(4), name='dense2', input='input1')
    inner.add_node(Dense(4), name='dense3', input='dense1')
    inner.add_output(
        name='output1',
        inputs=['dense2', 'dense3'],
        merge_mode='sum')

    model = Sequential()
    for layer in (Dense(32, input_shape=(32,)), inner, Dense(4)):
        model.add(layer)

    model.compile('rmsprop', 'mse')

    model.fit(X_train, y_train, batch_size=10, nb_epoch=10)
    test_loss = model.evaluate(X_test, y_test)
    assert test_loss < 2.5

    # The remaining calls only need to run without raising.
    model.evaluate(X_test, y_test, show_accuracy=True)
    model.predict(X_test)
    model.get_config(verbose=1)
|
|
|
|
|
|
def test_create_output():
    """Train and evaluate a graph whose output is declared via create_output.

    Instead of a separate add_output() call, the final node 'output1' is
    added with create_output=True and is then used as the output name in
    compile/fit/evaluate.
    """
    graph = Graph()
    graph.add_input(name='input1', input_shape=(32,))

    graph.add_node(Dense(16), name='dense1', input='input1')
    graph.add_node(Dense(4), name='dense2', input='input1')
    graph.add_node(Dense(4), name='dense3', input='dense1')
    graph.add_node(Dense(4), name='output1', inputs=['dense2', 'dense3'],
                   merge_mode='sum', create_output=True)
    graph.compile('rmsprop', {'output1': 'mse'})

    graph.fit({'input1': X_train, 'output1': y_train},
              nb_epoch=10)
    out = graph.predict({'input1': X_test})
    # isinstance() inspects `out` itself; `type(out == dict)` was always truthy
    # and therefore never verified anything.
    assert isinstance(out, dict)
    assert len(out) == 1

    # The batch-level calls are smoke tests; only evaluate()'s loss is checked.
    graph.test_on_batch({'input1': X_test, 'output1': y_test})
    graph.train_on_batch({'input1': X_test, 'output1': y_test})
    loss = graph.evaluate({'input1': X_test, 'output1': y_test})
    assert loss < 2.5
|
|
|
|
|
|
def test_count_params():
    """Check Graph.count_params() against a hand-computed parameter count.

    The count is compared both before and after the graph is compiled.
    """
    nb_units = 100
    nb_classes = 2

    graph = Graph()
    graph.add_input(name='input1', input_shape=(32,))
    graph.add_input(name='input2', input_shape=(32,))
    graph.add_node(Dense(nb_units), name='dense1', input='input1')
    graph.add_node(Dense(nb_classes), name='dense2', input='input2')
    graph.add_node(Dense(nb_classes), name='dense3', input='dense1')
    graph.add_output(name='output', inputs=['dense2', 'dense3'],
                     merge_mode='sum')

    # Each Dense layer contributes (inputs * units) weights plus `units` biases.
    expected = sum([
        32 * nb_units + nb_units,            # dense1
        32 * nb_classes + nb_classes,        # dense2
        nb_units * nb_classes + nb_classes,  # dense3
    ])

    assert expected == graph.count_params()

    graph.compile('rmsprop', {'output': 'binary_crossentropy'})

    assert expected == graph.count_params()
|
|
|
|
|
|
# Running this file directly hands it to pytest.
if __name__ == '__main__':
    pytest.main([__file__])
|