Remove reference to legacy Graph model in tests.

This commit is contained in:
fchollet 2016-11-25 01:20:10 -08:00
parent 703d5a1298
commit c0b32a9a04

@ -186,13 +186,12 @@ def test_ReduceLROnPlateau():
assert np.allclose(float(K.get_value(model.optimizer.lr)), 0.1, atol=K.epsilon())
@pytest.mark.skipif((K._BACKEND != 'tensorflow'),
@pytest.mark.skipif((K.backend() != 'tensorflow'),
reason="Requires tensorflow backend")
def test_TensorBoard():
import shutil
import tensorflow as tf
import keras.backend.tensorflow_backend as KTF
old_session = KTF.get_session()
filepath = './logs'
(X_train, y_train), (X_test, y_test) = get_test_data(nb_train=train_samples,
nb_test=test_samples,
@ -224,10 +223,6 @@ def test_TensorBoard():
yield {'X_vars': X_test, 'output': y_test}
# case 1 Sequential
with tf.Graph().as_default():
session = tf.Session('')
KTF.set_session(session)
model = Sequential()
model.add(Dense(nb_hidden, input_dim=input_dim, activation='relu'))
model.add(Dense(nb_class, activation='softmax'))
@ -267,50 +262,6 @@ def test_TensorBoard():
assert os.path.exists(filepath)
shutil.rmtree(filepath)
# case 2 Graph
with tf.Graph().as_default():
session = tf.Session('')
KTF.set_session(session)
model = Graph()
model.add_input(name='X_vars', input_shape=(input_dim,))
model.add_node(Dense(nb_hidden, activation="sigmoid"),
name='Dense1', input='X_vars')
model.add_node(Dense(nb_class, activation="softmax"),
name='last_dense',
input='Dense1')
model.add_output(name='output', input='last_dense')
model.compile(optimizer='sgd', loss={'output': 'mse'})
tsb = callbacks.TensorBoard(log_dir=filepath, histogram_freq=1)
cbks = [tsb]
# fit with validation
model.fit({'X_vars': X_train, 'output': y_train},
batch_size=batch_size,
validation_data={'X_vars': X_test, 'output': y_test},
callbacks=cbks, nb_epoch=2)
# fit wo validation
model.fit({'X_vars': X_train, 'output': y_train},
batch_size=batch_size,
callbacks=cbks, nb_epoch=2)
# fit generator with validation
model.fit_generator(data_generator_graph(True), 1000, nb_epoch=2,
validation_data={'X_vars': X_test, 'output': y_test},
callbacks=cbks)
# fit generator wo validation
model.fit_generator(data_generator_graph(True), 1000, nb_epoch=2,
callbacks=cbks)
assert os.path.exists(filepath)
shutil.rmtree(filepath)
KTF.set_session(old_session)
def test_LambdaCallback():
(X_train, y_train), (X_test, y_test) = get_test_data(nb_train=train_samples,
@ -343,7 +294,7 @@ def test_LambdaCallback():
assert not p.is_alive()
@pytest.mark.skipif((K._BACKEND != 'tensorflow'),
@pytest.mark.skipif((K.backend() != 'tensorflow'),
reason="Requires tensorflow backend")
def test_TensorBoard_with_ReduceLROnPlateau():
import shutil