Remove reference to legacy Graph model in tests.
This commit is contained in:
parent
703d5a1298
commit
c0b32a9a04
@ -186,13 +186,12 @@ def test_ReduceLROnPlateau():
|
||||
assert np.allclose(float(K.get_value(model.optimizer.lr)), 0.1, atol=K.epsilon())
|
||||
|
||||
|
||||
@pytest.mark.skipif((K._BACKEND != 'tensorflow'),
|
||||
@pytest.mark.skipif((K.backend() != 'tensorflow'),
|
||||
reason="Requires tensorflow backend")
|
||||
def test_TensorBoard():
|
||||
import shutil
|
||||
import tensorflow as tf
|
||||
import keras.backend.tensorflow_backend as KTF
|
||||
old_session = KTF.get_session()
|
||||
filepath = './logs'
|
||||
(X_train, y_train), (X_test, y_test) = get_test_data(nb_train=train_samples,
|
||||
nb_test=test_samples,
|
||||
@ -224,10 +223,6 @@ def test_TensorBoard():
|
||||
yield {'X_vars': X_test, 'output': y_test}
|
||||
|
||||
# case 1 Sequential
|
||||
|
||||
with tf.Graph().as_default():
|
||||
session = tf.Session('')
|
||||
KTF.set_session(session)
|
||||
model = Sequential()
|
||||
model.add(Dense(nb_hidden, input_dim=input_dim, activation='relu'))
|
||||
model.add(Dense(nb_class, activation='softmax'))
|
||||
@ -267,50 +262,6 @@ def test_TensorBoard():
|
||||
assert os.path.exists(filepath)
|
||||
shutil.rmtree(filepath)
|
||||
|
||||
# case 2 Graph
|
||||
|
||||
with tf.Graph().as_default():
|
||||
session = tf.Session('')
|
||||
KTF.set_session(session)
|
||||
model = Graph()
|
||||
model.add_input(name='X_vars', input_shape=(input_dim,))
|
||||
|
||||
model.add_node(Dense(nb_hidden, activation="sigmoid"),
|
||||
name='Dense1', input='X_vars')
|
||||
model.add_node(Dense(nb_class, activation="softmax"),
|
||||
name='last_dense',
|
||||
input='Dense1')
|
||||
model.add_output(name='output', input='last_dense')
|
||||
model.compile(optimizer='sgd', loss={'output': 'mse'})
|
||||
|
||||
tsb = callbacks.TensorBoard(log_dir=filepath, histogram_freq=1)
|
||||
cbks = [tsb]
|
||||
|
||||
# fit with validation
|
||||
model.fit({'X_vars': X_train, 'output': y_train},
|
||||
batch_size=batch_size,
|
||||
validation_data={'X_vars': X_test, 'output': y_test},
|
||||
callbacks=cbks, nb_epoch=2)
|
||||
|
||||
# fit wo validation
|
||||
model.fit({'X_vars': X_train, 'output': y_train},
|
||||
batch_size=batch_size,
|
||||
callbacks=cbks, nb_epoch=2)
|
||||
|
||||
# fit generator with validation
|
||||
model.fit_generator(data_generator_graph(True), 1000, nb_epoch=2,
|
||||
validation_data={'X_vars': X_test, 'output': y_test},
|
||||
callbacks=cbks)
|
||||
|
||||
# fit generator wo validation
|
||||
model.fit_generator(data_generator_graph(True), 1000, nb_epoch=2,
|
||||
callbacks=cbks)
|
||||
|
||||
assert os.path.exists(filepath)
|
||||
shutil.rmtree(filepath)
|
||||
|
||||
KTF.set_session(old_session)
|
||||
|
||||
|
||||
def test_LambdaCallback():
|
||||
(X_train, y_train), (X_test, y_test) = get_test_data(nb_train=train_samples,
|
||||
@ -343,7 +294,7 @@ def test_LambdaCallback():
|
||||
assert not p.is_alive()
|
||||
|
||||
|
||||
@pytest.mark.skipif((K._BACKEND != 'tensorflow'),
|
||||
@pytest.mark.skipif((K.backend() != 'tensorflow'),
|
||||
reason="Requires tensorflow backend")
|
||||
def test_TensorBoard_with_ReduceLROnPlateau():
|
||||
import shutil
|
||||
|
Loading…
Reference in New Issue
Block a user