keras/tests/test_multiprocessing.py
2017-03-29 14:02:43 +02:00

246 lines
6.9 KiB
Python

from __future__ import print_function
import os
import pytest
import numpy as np
from keras.models import Sequential
from keras.layers.core import Dense
from keras.utils.test_utils import keras_test
@keras_test
def test_multiprocessing_training():
    """Train with fit_generator using both pickle_safe=True and pickle_safe=False.

    The generator yields random contiguous windows of an in-memory dataset;
    the test only checks that both training calls complete without error.
    """
    arr_data = np.random.randint(0, 256, (50, 2))
    arr_labels = np.random.randint(0, 2, 50)

    def custom_generator():
        """Yield (X, y) windows of `batch_size` consecutive rows, forever."""
        batch_size = 10
        n_samples = 50
        while True:
            # randint's upper bound is exclusive: the +1 lets the final window
            # (start == n_samples - batch_size) be drawn as well.
            start = np.random.randint(0, n_samples - batch_size + 1)
            end = start + batch_size
            X = arr_data[start: end]
            y = arr_labels[start: end]
            yield X, y

    # Build a NN
    model = Sequential()
    model.add(Dense(1, input_shape=(2, )))
    model.compile(loss='mse', optimizer='adadelta')

    model.fit_generator(custom_generator(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_q_size=10,
                        workers=4,
                        pickle_safe=True)

    model.fit_generator(custom_generator(),
                        steps_per_epoch=5,
                        epochs=1,
                        verbose=1,
                        max_q_size=10,
                        pickle_safe=False)
@keras_test
def test_multiprocessing_training_fromfile():
    """Train with fit_generator using a generator that reads its data from disk.

    The arrays are saved to an ``.npz`` file inside a fresh temporary
    directory, and the generator re-loads them from that file. The directory
    is removed afterwards even if one of the training calls fails.
    """
    import shutil
    import tempfile

    arr_data = np.random.randint(0, 256, (50, 2))
    arr_labels = np.random.randint(0, 2, 50)

    # A private temp dir avoids clobbering, or leaving behind, a 'data.npz'
    # in the current working directory.
    tmp_dir = tempfile.mkdtemp()
    data_path = os.path.join(tmp_dir, 'data.npz')

    def custom_generator():
        """Load the saved arrays, then yield random (X, y) windows forever."""
        batch_size = 10
        n_samples = 50
        # Read both arrays once and close the archive right away, so no
        # file handle is still open when the temp directory is removed.
        with np.load(data_path) as arr:
            data = arr['data']
            labels = arr['labels']
        while True:
            # randint's upper bound is exclusive: the +1 lets the final window
            # (start == n_samples - batch_size) be drawn as well.
            start = np.random.randint(0, n_samples - batch_size + 1)
            end = start + batch_size
            yield data[start: end], labels[start: end]

    # Build a NN
    model = Sequential()
    model.add(Dense(1, input_shape=(2, )))
    model.compile(loss='mse', optimizer='adadelta')

    try:
        np.savez(data_path, **{'data': arr_data, 'labels': arr_labels})

        model.fit_generator(custom_generator(),
                            steps_per_epoch=5,
                            epochs=1,
                            verbose=1,
                            max_q_size=10,
                            workers=2,
                            pickle_safe=True)

        model.fit_generator(custom_generator(),
                            steps_per_epoch=5,
                            epochs=1,
                            verbose=1,
                            max_q_size=10,
                            pickle_safe=False)
    finally:
        # Best-effort cleanup. A failure here must not mask the real test
        # error, which is propagating through this `finally`.
        shutil.rmtree(tmp_dir, ignore_errors=True)
@keras_test
def test_multiprocessing_predicting():
    """Predict with predict_generator using both pickle_safe=True and pickle_safe=False.

    The generator yields inputs only (no labels); the test only checks that
    both prediction calls complete without error.
    """
    arr_data = np.random.randint(0, 256, (50, 2))

    def custom_generator():
        """Yield X windows of `batch_size` consecutive rows, forever."""
        batch_size = 10
        n_samples = 50
        while True:
            # randint's upper bound is exclusive: the +1 lets the final window
            # (start == n_samples - batch_size) be drawn as well.
            start = np.random.randint(0, n_samples - batch_size + 1)
            end = start + batch_size
            X = arr_data[start: end]
            yield X

    # Build a NN
    model = Sequential()
    model.add(Dense(1, input_shape=(2, )))
    model.compile(loss='mse', optimizer='adadelta')

    model.predict_generator(custom_generator(),
                            steps=5,
                            max_q_size=10,
                            workers=2,
                            pickle_safe=True)

    model.predict_generator(custom_generator(),
                            steps=5,
                            max_q_size=10,
                            pickle_safe=False)
@keras_test
def test_multiprocessing_evaluating():
    """Evaluate with evaluate_generator using both pickle_safe=True and pickle_safe=False.

    The test only checks that both evaluation calls complete without error.
    """
    arr_data = np.random.randint(0, 256, (50, 2))
    arr_labels = np.random.randint(0, 2, 50)

    def custom_generator():
        """Yield (X, y) windows of `batch_size` consecutive rows, forever."""
        batch_size = 10
        n_samples = 50
        while True:
            # randint's upper bound is exclusive: the +1 lets the final window
            # (start == n_samples - batch_size) be drawn as well.
            start = np.random.randint(0, n_samples - batch_size + 1)
            end = start + batch_size
            X = arr_data[start: end]
            y = arr_labels[start: end]
            yield X, y

    # Build a NN
    model = Sequential()
    model.add(Dense(1, input_shape=(2, )))
    model.compile(loss='mse', optimizer='adadelta')

    model.evaluate_generator(custom_generator(),
                             steps=5,
                             max_q_size=10,
                             workers=2,
                             pickle_safe=True)

    model.evaluate_generator(custom_generator(),
                             steps=5,
                             max_q_size=10,
                             pickle_safe=False)
@keras_test
def test_multiprocessing_fit_error():
    """fit_generator must raise when the generator fails part-way through.

    The generator yields `good_batches` valid batches and then raises
    RuntimeError. Both pickle_safe=True and pickle_safe=False are checked.
    """
    batch_size = 10
    good_batches = 3

    def custom_generator():
        """Raises an exception after a few good batches"""
        for _ in range(good_batches):
            # Valid bounds (low < high) are essential. The previous
            # randint(batch_size, 2, ...) raised ValueError on the very first
            # batch, so the "fail after good batches" path was never reached.
            yield (np.random.randint(0, 256, (batch_size, 2)),
                   np.random.randint(0, 2, batch_size))
        raise RuntimeError

    model = Sequential()
    model.add(Dense(1, input_shape=(2, )))
    model.compile(loss='mse', optimizer='adadelta')

    # Passed as the second positional argument (presumably steps_per_epoch in
    # this Keras API). It is deliberately much larger than `good_batches`, so
    # training cannot complete before the generator fails.
    samples = batch_size * (good_batches + 1)

    # Only a generic Exception is asserted. The error reaching the caller may
    # not be the generator's own RuntimeError (e.g. when it originates in a
    # worker) -- NOTE(review): confirm before narrowing.
    with pytest.raises(Exception):
        model.fit_generator(
            custom_generator(), samples, 1,
            workers=4, pickle_safe=True,
        )

    with pytest.raises(Exception):
        model.fit_generator(
            custom_generator(), samples, 1,
            pickle_safe=False,
        )
@keras_test
def test_multiprocessing_evaluate_error():
    """evaluate_generator must raise when the generator fails part-way through.

    The generator yields `good_batches` valid batches and then raises
    RuntimeError. One more step than that is requested, so the failure must
    surface. Both pickle_safe=True and pickle_safe=False are checked.
    """
    batch_size = 10
    good_batches = 3

    def custom_generator():
        """Raises an exception after a few good batches"""
        for _ in range(good_batches):
            # Valid bounds (low < high) are essential. The previous
            # randint(batch_size, 2, ...) raised ValueError on the very first
            # batch, so the "fail after good batches" path was never reached.
            yield (np.random.randint(0, 256, (batch_size, 2)),
                   np.random.randint(0, 2, batch_size))
        raise RuntimeError

    model = Sequential()
    model.add(Dense(1, input_shape=(2, )))
    model.compile(loss='mse', optimizer='adadelta')

    # Only a generic Exception is asserted. The error reaching the caller may
    # not be the generator's own RuntimeError (e.g. when it originates in a
    # worker) -- NOTE(review): confirm before narrowing.
    with pytest.raises(Exception):
        model.evaluate_generator(
            custom_generator(), good_batches + 1, 1,
            workers=4, pickle_safe=True,
        )

    with pytest.raises(Exception):
        model.evaluate_generator(
            custom_generator(), good_batches + 1, 1,
            pickle_safe=False,
        )
@keras_test
def test_multiprocessing_predict_error():
    """predict_generator must raise when the generator fails part-way through.

    The generator yields `good_batches` valid batches and then raises
    RuntimeError. One more step than that is requested, so the failure must
    surface. Both pickle_safe=True and pickle_safe=False are checked.
    """
    batch_size = 10
    good_batches = 3

    def custom_generator():
        """Raises an exception after a few good batches"""
        for _ in range(good_batches):
            # Valid bounds (low < high) are essential. The previous
            # randint(batch_size, 2, ...) raised ValueError on the very first
            # batch, so the "fail after good batches" path was never reached.
            yield (np.random.randint(0, 256, (batch_size, 2)),
                   np.random.randint(0, 2, batch_size))
        raise RuntimeError

    model = Sequential()
    model.add(Dense(1, input_shape=(2, )))
    model.compile(loss='mse', optimizer='adadelta')

    # Only a generic Exception is asserted. The error reaching the caller may
    # not be the generator's own RuntimeError (e.g. when it originates in a
    # worker) -- NOTE(review): confirm before narrowing.
    with pytest.raises(Exception):
        model.predict_generator(
            custom_generator(), good_batches + 1, 1,
            workers=4, pickle_safe=True,
        )

    with pytest.raises(Exception):
        model.predict_generator(
            custom_generator(), good_batches + 1, 1,
            pickle_safe=False,
        )
if __name__ == '__main__':
    # Executed directly as a script: hand this module's path to pytest.
    pytest_args = [__file__]
    pytest.main(pytest_args)