Merge branch 'master' of ssh://github.com/fchollet/keras
This commit is contained in:
commit
a5a775b79f
@ -89,6 +89,7 @@ from keras.utils import data_utils
|
||||
from keras.utils import io_utils
|
||||
from keras.utils import layer_utils
|
||||
from keras.utils import np_utils
|
||||
from keras.utils import generic_utils
|
||||
|
||||
|
||||
EXCLUDE = {
|
||||
@ -265,6 +266,11 @@ PAGES = [
|
||||
'page': 'utils/np_utils.md',
|
||||
'all_module_functions': [np_utils]
|
||||
},
|
||||
{
|
||||
'page': 'utils/generic_utils.md',
|
||||
'all_module_functions': [generic_utils],
|
||||
'classes': [generic_utils.CustomObjectScope]
|
||||
},
|
||||
]
|
||||
|
||||
ROOT = 'http://keras.io/'
|
||||
|
@ -55,6 +55,7 @@ pages:
|
||||
- I/O Utils: utils/io_utils.md
|
||||
- Layer Utils: utils/layer_utils.md
|
||||
- Numpy Utils: utils/np_utils.md
|
||||
- Generic Utils: utils/generic_utils.md
|
||||
|
||||
|
||||
|
||||
|
@ -15,4 +15,4 @@ from . import objectives
|
||||
from . import optimizers
|
||||
from . import regularizers
|
||||
|
||||
__version__ = '1.2.0'
|
||||
__version__ = '1.2.1'
|
||||
|
@ -62,7 +62,7 @@ def preprocess_input(audio_path, dim_ordering='default'):
|
||||
|
||||
logam = librosa.logamplitude
|
||||
melgram = librosa.feature.melspectrogram
|
||||
x = logam(melgram(y=src, sr=sr, hop_lengthgth=hop_length,
|
||||
x = logam(melgram(y=src, sr=sr, hop_length=hop_length,
|
||||
n_fft=n_fft, n_mels=n_mels) ** 2,
|
||||
ref_power=1.0)
|
||||
|
||||
|
@ -77,6 +77,9 @@ def learning_phase():
|
||||
def set_learning_phase(value):
|
||||
"""Sets the learning phase to a fixed value,
|
||||
either 0 or 1 (integers).
|
||||
|
||||
# Raises
|
||||
ValueError: if `value` is neither `0` nor `1`.
|
||||
"""
|
||||
global _GRAPH_LEARNING_PHASES
|
||||
if value not in {0, 1}:
|
||||
@ -534,15 +537,17 @@ def eye(size, dtype=None, name=None):
|
||||
return variable(np.eye(size), dtype, name)
|
||||
|
||||
|
||||
def zeros_like(x, name=None):
|
||||
def zeros_like(x, dtype=None, name=None):
|
||||
"""Instantiates an all-zeros Keras variable
|
||||
of the same shape as another Keras variable or tensor and returns it.
|
||||
|
||||
# Arguments
|
||||
x: Keras variable or Keras tensor.
|
||||
dtype: String, dtype of returned Keras variable.
|
||||
None uses the dtype of x.
|
||||
|
||||
# Returns
|
||||
A Keras variable, filled with `0.0`.
|
||||
A Keras variable with the shape of x filled with zeros.
|
||||
|
||||
# Example
|
||||
```python
|
||||
@ -554,18 +559,20 @@ def zeros_like(x, name=None):
|
||||
[ 0., 0., 0.]], dtype=float32)
|
||||
```
|
||||
"""
|
||||
return tf.zeros_like(x, name=name)
|
||||
return tf.zeros_like(x, dtype=dtype, name=name)
|
||||
|
||||
|
||||
def ones_like(x, name=None):
|
||||
def ones_like(x, dtype=None, name=None):
|
||||
"""Instantiates an all-ones Keras variable
|
||||
of the same shape as another Keras variable or tensor and returns it.
|
||||
|
||||
# Arguments
|
||||
x: Keras variable or tensor.
|
||||
dtype: String, dtype of returned Keras variable.
|
||||
None uses the dtype of x.
|
||||
|
||||
# Returns
|
||||
A Keras variable, filled with `1.0`.
|
||||
A Keras variable with the shape of x filled with ones.
|
||||
|
||||
# Example
|
||||
```python
|
||||
@ -577,7 +584,7 @@ def ones_like(x, name=None):
|
||||
[ 1., 1., 1.]], dtype=float32)
|
||||
```
|
||||
"""
|
||||
return tf.ones_like(x, name=name)
|
||||
return tf.ones_like(x, dtype=dtype, name=name)
|
||||
|
||||
|
||||
def random_uniform_variable(shape, low, high, dtype=None,
|
||||
@ -779,16 +786,20 @@ def dot(x, y):
|
||||
(2, 4, 5)
|
||||
```
|
||||
"""
|
||||
if hasattr(tf, 'unstack'):
|
||||
unstack = tf.unstack
|
||||
else:
|
||||
unstack = tf.unpack
|
||||
if ndim(x) is not None and (ndim(x) > 2 or ndim(y) > 2):
|
||||
x_shape = []
|
||||
for i, s in zip(int_shape(x), tf.unpack(tf.shape(x))):
|
||||
for i, s in zip(int_shape(x), unstack(tf.shape(x))):
|
||||
if i is not None:
|
||||
x_shape.append(i)
|
||||
else:
|
||||
x_shape.append(s)
|
||||
x_shape = tuple(x_shape)
|
||||
y_shape = []
|
||||
for i, s in zip(int_shape(y), tf.unpack(tf.shape(y))):
|
||||
for i, s in zip(int_shape(y), unstack(tf.shape(y))):
|
||||
if i is not None:
|
||||
y_shape.append(i)
|
||||
else:
|
||||
@ -858,6 +869,8 @@ def batch_dot(x, y, axes=None):
|
||||
(32, 1, 30)
|
||||
```
|
||||
"""
|
||||
if ndim(x) < 3 or ndim(y) < 3:
|
||||
raise ValueError('Invalid dimensions for batch_dot: ', ndim(x), ndim(y))
|
||||
if isinstance(axes, int):
|
||||
axes = (axes, axes)
|
||||
if axes is not None:
|
||||
@ -1208,6 +1221,8 @@ def log(x):
|
||||
def round(x):
|
||||
"""Element-wise rounding to the closest integer.
|
||||
|
||||
In case of tie, the rounding mode used is "half to even".
|
||||
|
||||
# Arguments
|
||||
x: input tensor.
|
||||
|
||||
@ -1448,6 +1463,9 @@ def resize_images(X, height_factor, width_factor, dim_ordering):
|
||||
|
||||
# Returns
|
||||
A tensor.
|
||||
|
||||
# Raises
|
||||
ValueError: if `dim_ordering` is neither `tf` or `th`.
|
||||
"""
|
||||
if dim_ordering == 'th':
|
||||
original_shape = int_shape(X)
|
||||
@ -1480,6 +1498,9 @@ def resize_volumes(X, depth_factor, height_factor, width_factor, dim_ordering):
|
||||
|
||||
# Returns
|
||||
A tensor.
|
||||
|
||||
# Raises
|
||||
ValueError: if `dim_ordering` is neither `tf` or `th`.
|
||||
"""
|
||||
if dim_ordering == 'th':
|
||||
output = repeat_elements(X, depth_factor, axis=2)
|
||||
@ -1633,6 +1654,9 @@ def spatial_2d_padding(x, padding=(1, 1), dim_ordering='default'):
|
||||
|
||||
# Returns
|
||||
A padded 4D tensor.
|
||||
|
||||
# Raises
|
||||
ValueError: if `dim_ordering` is neither `tf` or `th`.
|
||||
"""
|
||||
if dim_ordering == 'default':
|
||||
dim_ordering = image_dim_ordering()
|
||||
@ -1658,6 +1682,9 @@ def asymmetric_spatial_2d_padding(x, top_pad=1, bottom_pad=1,
|
||||
|
||||
# Returns
|
||||
A padded 4D tensor.
|
||||
|
||||
# Raises
|
||||
ValueError: if `dim_ordering` is neither `tf` or `th`.
|
||||
"""
|
||||
if dim_ordering == 'default':
|
||||
dim_ordering = image_dim_ordering()
|
||||
@ -1686,6 +1713,10 @@ def spatial_3d_padding(x, padding=(1, 1, 1), dim_ordering='default'):
|
||||
|
||||
# Returns
|
||||
A padded 5D tensor.
|
||||
|
||||
# Raises
|
||||
ValueError: if `dim_ordering` is neither `tf` or `th`.
|
||||
|
||||
"""
|
||||
if dim_ordering == 'default':
|
||||
dim_ordering = image_dim_ordering()
|
||||
@ -1967,6 +1998,12 @@ def rnn(step_function, inputs, initial_states,
|
||||
at time `t` for sample `s`.
|
||||
new_states: list of tensors, latest states returned by
|
||||
the step function, of shape `(samples, ...)`.
|
||||
|
||||
# Raises
|
||||
ValueError: if input dimension is less than 3.
|
||||
ValueError: if `unroll` is `True` but input timestep is not a fixed number.
|
||||
ValueError: if `mask` is provided (not `None`) but states is not provided
|
||||
(`len(states)` == 0).
|
||||
"""
|
||||
ndim = len(inputs.get_shape())
|
||||
if ndim < 3:
|
||||
@ -1987,6 +2024,12 @@ def rnn(step_function, inputs, initial_states,
|
||||
# TODO: remove later.
|
||||
if hasattr(tf, 'select'):
|
||||
tf.where = tf.select
|
||||
if hasattr(tf, 'stack'):
|
||||
stack = tf.stack
|
||||
unstack = tf.unstack
|
||||
else:
|
||||
stack = tf.pack
|
||||
unstack = tf.unpack
|
||||
|
||||
if unroll:
|
||||
if not inputs.get_shape()[0]:
|
||||
@ -1996,12 +2039,12 @@ def rnn(step_function, inputs, initial_states,
|
||||
successive_states = []
|
||||
successive_outputs = []
|
||||
|
||||
input_list = tf.unpack(inputs)
|
||||
input_list = unstack(inputs)
|
||||
if go_backwards:
|
||||
input_list.reverse()
|
||||
|
||||
if mask is not None:
|
||||
mask_list = tf.unpack(mask)
|
||||
mask_list = unstack(mask)
|
||||
if go_backwards:
|
||||
mask_list.reverse()
|
||||
|
||||
@ -2066,7 +2109,10 @@ def rnn(step_function, inputs, initial_states,
|
||||
dtype=inputs.dtype,
|
||||
size=time_steps,
|
||||
tensor_array_name='input_ta')
|
||||
input_ta = input_ta.unpack(inputs)
|
||||
if hasattr(input_ta, 'unstack'):
|
||||
input_ta = input_ta.unstack(inputs)
|
||||
else:
|
||||
input_ta = input_ta.unpack(inputs)
|
||||
time = tf.constant(0, dtype='int32', name='time')
|
||||
|
||||
if mask is not None:
|
||||
@ -2084,7 +2130,10 @@ def rnn(step_function, inputs, initial_states,
|
||||
dtype=tf.bool,
|
||||
size=time_steps,
|
||||
tensor_array_name='mask_ta')
|
||||
mask_ta = mask_ta.unpack(mask)
|
||||
if hasattr(mask_ta, 'unstack'):
|
||||
mask_ta = mask_ta.unstack(mask)
|
||||
else:
|
||||
mask_ta = mask_ta.unpack(mask)
|
||||
|
||||
def _step(time, output_ta_t, *states):
|
||||
current_input = input_ta.read(time)
|
||||
@ -2121,7 +2170,10 @@ def rnn(step_function, inputs, initial_states,
|
||||
output_ta = final_outputs[1]
|
||||
new_states = final_outputs[2:]
|
||||
|
||||
outputs = output_ta.pack()
|
||||
if hasattr(output_ta, 'stack'):
|
||||
outputs = output_ta.stack()
|
||||
else:
|
||||
outputs = output_ta.pack()
|
||||
last_output = output_ta.read(last_time - 1)
|
||||
|
||||
axes = [1, 0] + list(range(2, len(outputs.get_shape())))
|
||||
@ -2469,6 +2521,7 @@ def _preprocess_deconv_output_shape(x, shape, dim_ordering):
|
||||
|
||||
if shape[0] is None:
|
||||
shape = (tf.shape(x)[0], ) + tuple(shape[1:])
|
||||
shape = tf.stack(list(shape))
|
||||
return shape
|
||||
|
||||
|
||||
@ -2588,6 +2641,9 @@ def conv2d(x, kernel, strides=(1, 1), border_mode='valid',
|
||||
|
||||
# Returns
|
||||
A tensor, result of 2D convolution.
|
||||
|
||||
# Raises
|
||||
ValueError: if `dim_ordering` is neither `tf` or `th`.
|
||||
"""
|
||||
if dim_ordering == 'default':
|
||||
dim_ordering = image_dim_ordering()
|
||||
@ -2625,6 +2681,9 @@ def deconv2d(x, kernel, output_shape, strides=(1, 1),
|
||||
|
||||
# Returns
|
||||
A tensor, result of transposed 2D convolution.
|
||||
|
||||
# Raises
|
||||
ValueError: if `dim_ordering` is neither `tf` or `th`.
|
||||
"""
|
||||
if dim_ordering == 'default':
|
||||
dim_ordering = image_dim_ordering()
|
||||
@ -2662,6 +2721,9 @@ def atrous_conv2d(x, kernel, rate=1,
|
||||
|
||||
# Returns
|
||||
A tensor, result of atrous transposed 2D convolution.
|
||||
|
||||
# Raises
|
||||
ValueError: if `dim_ordering` is neither `tf` or `th`.
|
||||
"""
|
||||
if dim_ordering == 'default':
|
||||
dim_ordering = image_dim_ordering()
|
||||
@ -2682,6 +2744,9 @@ def atrous_conv2d(x, kernel, rate=1,
|
||||
def separable_conv2d(x, depthwise_kernel, pointwise_kernel, strides=(1, 1),
|
||||
border_mode='valid', dim_ordering='default'):
|
||||
"""2-D convolution with separable filters.
|
||||
|
||||
# Raises
|
||||
ValueError: if `dim_ordering` is neither `tf` or `th`.
|
||||
"""
|
||||
if dim_ordering == 'default':
|
||||
dim_ordering = image_dim_ordering()
|
||||
@ -2716,6 +2781,9 @@ def conv3d(x, kernel, strides=(1, 1, 1),
|
||||
|
||||
# Returns
|
||||
A tensor, result of 3D convolution.
|
||||
|
||||
# Raises
|
||||
ValueError: if `dim_ordering` is neither `tf` or `th`.
|
||||
"""
|
||||
if dim_ordering == 'default':
|
||||
dim_ordering = image_dim_ordering()
|
||||
@ -2745,6 +2813,10 @@ def pool2d(x, pool_size, strides=(1, 1),
|
||||
|
||||
# Returns
|
||||
A tensor, result of 2D pooling.
|
||||
|
||||
# Raises
|
||||
ValueError: if `dim_ordering` is neither `tf` or `th`.
|
||||
ValueError: if `pool_mode` is neither `max` or `avg`.
|
||||
"""
|
||||
if dim_ordering == 'default':
|
||||
dim_ordering = image_dim_ordering()
|
||||
@ -2780,6 +2852,10 @@ def pool3d(x, pool_size, strides=(1, 1, 1), border_mode='valid',
|
||||
|
||||
# Returns
|
||||
A tensor, result of 3D pooling.
|
||||
|
||||
# Raises
|
||||
ValueError: if `dim_ordering` is neither `tf` or `th`.
|
||||
ValueError: if `pool_mode` is neither `max` or `avg`.
|
||||
"""
|
||||
if dim_ordering == 'default':
|
||||
dim_ordering = image_dim_ordering()
|
||||
|
@ -176,12 +176,12 @@ def eye(size, dtype=None, name=None):
|
||||
return variable(np.eye(size), dtype, name)
|
||||
|
||||
|
||||
def ones_like(x, name=None):
|
||||
return T.ones_like(x)
|
||||
def ones_like(x, dtype=None, name=None):
|
||||
return T.ones_like(x, dtype=dtype)
|
||||
|
||||
|
||||
def zeros_like(x, name=None):
|
||||
return T.zeros_like(x)
|
||||
def zeros_like(x, dtype=None, name=None):
|
||||
return T.zeros_like(x, dtype=dtype)
|
||||
|
||||
|
||||
def random_uniform_variable(shape, low, high, dtype=None, name=None):
|
||||
@ -389,7 +389,7 @@ def log(x):
|
||||
|
||||
|
||||
def round(x):
|
||||
return T.round(x)
|
||||
return T.round(x, mode='half_to_even')
|
||||
|
||||
|
||||
def sign(x):
|
||||
@ -1075,6 +1075,8 @@ def rnn(step_function, inputs, initial_states,
|
||||
initial_output = step_function(inputs[0], initial_states + constants)[0] * 0
|
||||
# Theano gets confused by broadcasting patterns in the scan op
|
||||
initial_output = T.unbroadcast(initial_output, 0, 1)
|
||||
if len(initial_states) > 0:
|
||||
initial_states[0] = T.unbroadcast(initial_states[0], 0, 1)
|
||||
|
||||
def _step(input, mask, output_tm1, *states):
|
||||
output, new_states = step_function(input, states)
|
||||
@ -1122,6 +1124,10 @@ def rnn(step_function, inputs, initial_states,
|
||||
output, new_states = step_function(input, states)
|
||||
return [output] + new_states
|
||||
|
||||
# Theano likes to make shape==1 dimensions in the initial states (outputs_info) broadcastable
|
||||
if len(initial_states) > 0:
|
||||
initial_states[0] = T.unbroadcast(initial_states[0], 1)
|
||||
|
||||
results, _ = theano.scan(
|
||||
_step,
|
||||
sequences=inputs,
|
||||
|
@ -396,55 +396,94 @@ def standardize_weights(y, sample_weight=None, class_weight=None,
|
||||
return np.ones((y.shape[0], y.shape[1]), dtype=K.floatx())
|
||||
|
||||
|
||||
def generator_queue(generator, max_q_size=10,
|
||||
wait_time=0.05, nb_worker=1, pickle_safe=False):
|
||||
class GeneratorEnqueuer(object):
|
||||
"""Builds a queue out of a data generator.
|
||||
If pickle_safe, use a multiprocessing approach. Else, use threading.
|
||||
Used in `fit_generator`, `evaluate_generator`, `predict_generator`.
|
||||
"""
|
||||
generator_threads = []
|
||||
if pickle_safe:
|
||||
q = multiprocessing.Queue(maxsize=max_q_size)
|
||||
_stop = multiprocessing.Event()
|
||||
else:
|
||||
q = queue.Queue()
|
||||
_stop = threading.Event()
|
||||
|
||||
try:
|
||||
# Arguments
|
||||
generator: a generator function which endlessly yields data
|
||||
pickle_safe: use multiprocessing if True, otherwise threading
|
||||
"""
|
||||
|
||||
def __init__(self, generator, pickle_safe=False):
|
||||
self._generator = generator
|
||||
self._pickle_safe = pickle_safe
|
||||
self._threads = []
|
||||
self._stop_event = None
|
||||
|
||||
self.queue = None
|
||||
|
||||
def start(self, nb_worker=1, max_q_size=10, wait_time=0.05):
|
||||
"""Kick off threads which add data from the generator into the queue.
|
||||
|
||||
# Arguments
|
||||
nb_worker: number of worker threads
|
||||
max_q_size: queue size (when full, threads could block on put())
|
||||
wait_time: time to sleep in-between calls to put()
|
||||
"""
|
||||
|
||||
def data_generator_task():
|
||||
while not _stop.is_set():
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
if pickle_safe or q.qsize() < max_q_size:
|
||||
generator_output = next(generator)
|
||||
q.put(generator_output)
|
||||
if self._pickle_safe or self.queue.qsize() < max_q_size:
|
||||
generator_output = next(self._generator)
|
||||
self.queue.put(generator_output)
|
||||
else:
|
||||
time.sleep(wait_time)
|
||||
except Exception:
|
||||
_stop.set()
|
||||
self._stop_event.set()
|
||||
raise
|
||||
|
||||
for i in range(nb_worker):
|
||||
if pickle_safe:
|
||||
# Reset random seed else all children processes
|
||||
# share the same seed
|
||||
np.random.seed()
|
||||
thread = multiprocessing.Process(target=data_generator_task)
|
||||
try:
|
||||
if self._pickle_safe:
|
||||
self.queue = multiprocessing.Queue(maxsize=max_q_size)
|
||||
self._stop_event = multiprocessing.Event()
|
||||
else:
|
||||
thread = threading.Thread(target=data_generator_task)
|
||||
generator_threads.append(thread)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
except:
|
||||
_stop.set()
|
||||
if pickle_safe:
|
||||
# Terminate all daemon processes
|
||||
for p in generator_threads:
|
||||
if p.is_alive():
|
||||
p.terminate()
|
||||
q.close()
|
||||
raise
|
||||
self.queue = queue.Queue()
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
return q, _stop, generator_threads
|
||||
for i in range(nb_worker):
|
||||
if self._pickle_safe:
|
||||
# Reset random seed else all children processes
|
||||
# share the same seed
|
||||
np.random.seed()
|
||||
thread = multiprocessing.Process(target=data_generator_task)
|
||||
thread.daemon = True
|
||||
else:
|
||||
thread = threading.Thread(target=data_generator_task)
|
||||
self._threads.append(thread)
|
||||
thread.start()
|
||||
except:
|
||||
self.stop()
|
||||
raise
|
||||
|
||||
def is_running(self):
|
||||
return self._stop_event is not None and not self._stop_event.is_set()
|
||||
|
||||
def stop(self, timeout=None):
|
||||
"""Stop running threads and wait for them to exit, if necessary.
|
||||
Should be called by the same thread which called start().
|
||||
|
||||
# Arguments
|
||||
timeout: maximum time to wait on thread.join()
|
||||
"""
|
||||
if self.is_running():
|
||||
self._stop_event.set()
|
||||
|
||||
for thread in self._threads:
|
||||
if thread.is_alive():
|
||||
if self._pickle_safe:
|
||||
thread.terminate()
|
||||
else:
|
||||
thread.join(timeout)
|
||||
|
||||
if self._pickle_safe:
|
||||
if self.queue is not None:
|
||||
self.queue.close()
|
||||
|
||||
self._threads = []
|
||||
self._stop_event = None
|
||||
self.queue = None
|
||||
|
||||
|
||||
class Model(Container):
|
||||
@ -1462,117 +1501,107 @@ class Model(Container):
|
||||
else:
|
||||
self.validation_data = None
|
||||
|
||||
# start generator thread storing batches into a queue
|
||||
data_gen_queue, _stop, generator_threads = generator_queue(
|
||||
generator,
|
||||
max_q_size=max_q_size,
|
||||
nb_worker=nb_worker,
|
||||
pickle_safe=pickle_safe)
|
||||
enqueuer = None
|
||||
|
||||
callback_model.stop_training = False
|
||||
while epoch < nb_epoch:
|
||||
callbacks.on_epoch_begin(epoch)
|
||||
samples_seen = 0
|
||||
batch_index = 0
|
||||
while samples_seen < samples_per_epoch:
|
||||
generator_output = None
|
||||
while not _stop.is_set():
|
||||
if not data_gen_queue.empty():
|
||||
generator_output = data_gen_queue.get()
|
||||
break
|
||||
try:
|
||||
enqueuer = GeneratorEnqueuer(generator, pickle_safe=pickle_safe)
|
||||
enqueuer.start(max_q_size=max_q_size, nb_worker=nb_worker)
|
||||
|
||||
callback_model.stop_training = False
|
||||
while epoch < nb_epoch:
|
||||
callbacks.on_epoch_begin(epoch)
|
||||
samples_seen = 0
|
||||
batch_index = 0
|
||||
while samples_seen < samples_per_epoch:
|
||||
generator_output = None
|
||||
while enqueuer.is_running():
|
||||
if not enqueuer.queue.empty():
|
||||
generator_output = enqueuer.queue.get()
|
||||
break
|
||||
else:
|
||||
time.sleep(wait_time)
|
||||
|
||||
if not hasattr(generator_output, '__len__'):
|
||||
raise ValueError('output of generator should be a tuple '
|
||||
'(x, y, sample_weight) '
|
||||
'or (x, y). Found: ' +
|
||||
str(generator_output))
|
||||
if len(generator_output) == 2:
|
||||
x, y = generator_output
|
||||
sample_weight = None
|
||||
elif len(generator_output) == 3:
|
||||
x, y, sample_weight = generator_output
|
||||
else:
|
||||
time.sleep(wait_time)
|
||||
raise ValueError('output of generator should be a tuple '
|
||||
'(x, y, sample_weight) '
|
||||
'or (x, y). Found: ' +
|
||||
str(generator_output))
|
||||
# build batch logs
|
||||
batch_logs = {}
|
||||
if isinstance(x, list):
|
||||
batch_size = x[0].shape[0]
|
||||
elif isinstance(x, dict):
|
||||
batch_size = list(x.values())[0].shape[0]
|
||||
else:
|
||||
batch_size = x.shape[0]
|
||||
batch_logs['batch'] = batch_index
|
||||
batch_logs['size'] = batch_size
|
||||
callbacks.on_batch_begin(batch_index, batch_logs)
|
||||
|
||||
if not hasattr(generator_output, '__len__'):
|
||||
_stop.set()
|
||||
raise ValueError('output of generator should be a tuple '
|
||||
'(x, y, sample_weight) '
|
||||
'or (x, y). Found: ' +
|
||||
str(generator_output))
|
||||
if len(generator_output) == 2:
|
||||
x, y = generator_output
|
||||
sample_weight = None
|
||||
elif len(generator_output) == 3:
|
||||
x, y, sample_weight = generator_output
|
||||
else:
|
||||
_stop.set()
|
||||
raise ValueError('output of generator should be a tuple '
|
||||
'(x, y, sample_weight) '
|
||||
'or (x, y). Found: ' +
|
||||
str(generator_output))
|
||||
# build batch logs
|
||||
batch_logs = {}
|
||||
if isinstance(x, list):
|
||||
batch_size = x[0].shape[0]
|
||||
elif isinstance(x, dict):
|
||||
batch_size = list(x.values())[0].shape[0]
|
||||
else:
|
||||
batch_size = x.shape[0]
|
||||
batch_logs['batch'] = batch_index
|
||||
batch_logs['size'] = batch_size
|
||||
callbacks.on_batch_begin(batch_index, batch_logs)
|
||||
|
||||
try:
|
||||
outs = self.train_on_batch(x, y,
|
||||
sample_weight=sample_weight,
|
||||
class_weight=class_weight)
|
||||
except:
|
||||
_stop.set()
|
||||
raise
|
||||
|
||||
if not isinstance(outs, list):
|
||||
outs = [outs]
|
||||
for l, o in zip(out_labels, outs):
|
||||
batch_logs[l] = o
|
||||
if not isinstance(outs, list):
|
||||
outs = [outs]
|
||||
for l, o in zip(out_labels, outs):
|
||||
batch_logs[l] = o
|
||||
|
||||
callbacks.on_batch_end(batch_index, batch_logs)
|
||||
callbacks.on_batch_end(batch_index, batch_logs)
|
||||
|
||||
# construct epoch logs
|
||||
epoch_logs = {}
|
||||
batch_index += 1
|
||||
samples_seen += batch_size
|
||||
# construct epoch logs
|
||||
epoch_logs = {}
|
||||
batch_index += 1
|
||||
samples_seen += batch_size
|
||||
|
||||
# epoch finished
|
||||
if samples_seen > samples_per_epoch:
|
||||
warnings.warn('Epoch comprised more than '
|
||||
'`samples_per_epoch` samples, '
|
||||
'which might affect learning results. '
|
||||
'Set `samples_per_epoch` correctly '
|
||||
'to avoid this warning.')
|
||||
if samples_seen >= samples_per_epoch and do_validation:
|
||||
if val_gen:
|
||||
val_outs = self.evaluate_generator(
|
||||
validation_data,
|
||||
nb_val_samples,
|
||||
max_q_size=max_q_size,
|
||||
nb_worker=nb_worker,
|
||||
pickle_safe=pickle_safe)
|
||||
else:
|
||||
# no need for try/except because
|
||||
# data has already been validated
|
||||
val_outs = self.evaluate(
|
||||
val_x, val_y,
|
||||
batch_size=batch_size,
|
||||
sample_weight=val_sample_weights,
|
||||
verbose=0)
|
||||
if not isinstance(val_outs, list):
|
||||
val_outs = [val_outs]
|
||||
# same labels assumed
|
||||
for l, o in zip(out_labels, val_outs):
|
||||
epoch_logs['val_' + l] = o
|
||||
# epoch finished
|
||||
if samples_seen > samples_per_epoch:
|
||||
warnings.warn('Epoch comprised more than '
|
||||
'`samples_per_epoch` samples, '
|
||||
'which might affect learning results. '
|
||||
'Set `samples_per_epoch` correctly '
|
||||
'to avoid this warning.')
|
||||
if samples_seen >= samples_per_epoch and do_validation:
|
||||
if val_gen:
|
||||
val_outs = self.evaluate_generator(
|
||||
validation_data,
|
||||
nb_val_samples,
|
||||
max_q_size=max_q_size,
|
||||
nb_worker=nb_worker,
|
||||
pickle_safe=pickle_safe)
|
||||
else:
|
||||
# no need for try/except because
|
||||
# data has already been validated
|
||||
val_outs = self.evaluate(
|
||||
val_x, val_y,
|
||||
batch_size=batch_size,
|
||||
sample_weight=val_sample_weights,
|
||||
verbose=0)
|
||||
if not isinstance(val_outs, list):
|
||||
val_outs = [val_outs]
|
||||
# same labels assumed
|
||||
for l, o in zip(out_labels, val_outs):
|
||||
epoch_logs['val_' + l] = o
|
||||
|
||||
callbacks.on_epoch_end(epoch, epoch_logs)
|
||||
epoch += 1
|
||||
if callback_model.stop_training:
|
||||
break
|
||||
callbacks.on_epoch_end(epoch, epoch_logs)
|
||||
epoch += 1
|
||||
if callback_model.stop_training:
|
||||
break
|
||||
|
||||
finally:
|
||||
if enqueuer is not None:
|
||||
enqueuer.stop()
|
||||
|
||||
_stop.set()
|
||||
if pickle_safe:
|
||||
# Terminate all daemon processes
|
||||
for p in generator_threads:
|
||||
if p.is_alive():
|
||||
p.terminate()
|
||||
data_gen_queue.close()
|
||||
callbacks.on_train_end()
|
||||
return self.history
|
||||
|
||||
@ -1611,60 +1640,53 @@ class Model(Container):
|
||||
wait_time = 0.01
|
||||
all_outs = []
|
||||
weights = []
|
||||
data_gen_queue, _stop, generator_threads = generator_queue(
|
||||
generator,
|
||||
max_q_size=max_q_size,
|
||||
nb_worker=nb_worker,
|
||||
pickle_safe=pickle_safe)
|
||||
|
||||
while processed_samples < val_samples:
|
||||
generator_output = None
|
||||
while not _stop.is_set():
|
||||
if not data_gen_queue.empty():
|
||||
generator_output = data_gen_queue.get()
|
||||
break
|
||||
enqueuer = None
|
||||
|
||||
try:
|
||||
enqueuer = GeneratorEnqueuer(generator, pickle_safe=pickle_safe)
|
||||
enqueuer.start(nb_worker=nb_worker, max_q_size=max_q_size)
|
||||
|
||||
while processed_samples < val_samples:
|
||||
generator_output = None
|
||||
while enqueuer.is_running():
|
||||
if not enqueuer.queue.empty():
|
||||
generator_output = enqueuer.queue.get()
|
||||
break
|
||||
else:
|
||||
time.sleep(wait_time)
|
||||
|
||||
if not hasattr(generator_output, '__len__'):
|
||||
raise ValueError('output of generator should be a tuple '
|
||||
'(x, y, sample_weight) '
|
||||
'or (x, y). Found: ' + str(generator_output))
|
||||
if len(generator_output) == 2:
|
||||
x, y = generator_output
|
||||
sample_weight = None
|
||||
elif len(generator_output) == 3:
|
||||
x, y, sample_weight = generator_output
|
||||
else:
|
||||
time.sleep(wait_time)
|
||||
raise ValueError('output of generator should be a tuple '
|
||||
'(x, y, sample_weight) '
|
||||
'or (x, y). Found: ' + str(generator_output))
|
||||
|
||||
if not hasattr(generator_output, '__len__'):
|
||||
_stop.set()
|
||||
raise ValueError('output of generator should be a tuple '
|
||||
'(x, y, sample_weight) '
|
||||
'or (x, y). Found: ' + str(generator_output))
|
||||
if len(generator_output) == 2:
|
||||
x, y = generator_output
|
||||
sample_weight = None
|
||||
elif len(generator_output) == 3:
|
||||
x, y, sample_weight = generator_output
|
||||
else:
|
||||
_stop.set()
|
||||
raise ValueError('output of generator should be a tuple '
|
||||
'(x, y, sample_weight) '
|
||||
'or (x, y). Found: ' + str(generator_output))
|
||||
try:
|
||||
outs = self.test_on_batch(x, y, sample_weight=sample_weight)
|
||||
except:
|
||||
_stop.set()
|
||||
raise
|
||||
|
||||
if isinstance(x, list):
|
||||
nb_samples = len(x[0])
|
||||
elif isinstance(x, dict):
|
||||
nb_samples = len(list(x.values())[0])
|
||||
else:
|
||||
nb_samples = len(x)
|
||||
all_outs.append(outs)
|
||||
if isinstance(x, list):
|
||||
nb_samples = len(x[0])
|
||||
elif isinstance(x, dict):
|
||||
nb_samples = len(list(x.values())[0])
|
||||
else:
|
||||
nb_samples = len(x)
|
||||
all_outs.append(outs)
|
||||
|
||||
processed_samples += nb_samples
|
||||
weights.append(nb_samples)
|
||||
processed_samples += nb_samples
|
||||
weights.append(nb_samples)
|
||||
|
||||
finally:
|
||||
if enqueuer is not None:
|
||||
enqueuer.stop()
|
||||
|
||||
_stop.set()
|
||||
if pickle_safe:
|
||||
# Terminate all daemon processes
|
||||
for p in generator_threads:
|
||||
if p.is_alive():
|
||||
p.terminate()
|
||||
data_gen_queue.close()
|
||||
if not isinstance(outs, list):
|
||||
return np.average(np.asarray(all_outs),
|
||||
weights=weights)
|
||||
@ -1704,68 +1726,61 @@ class Model(Container):
|
||||
processed_samples = 0
|
||||
wait_time = 0.01
|
||||
all_outs = []
|
||||
data_gen_queue, _stop, generator_threads = generator_queue(
|
||||
generator,
|
||||
max_q_size=max_q_size,
|
||||
nb_worker=nb_worker,
|
||||
pickle_safe=pickle_safe)
|
||||
|
||||
while processed_samples < val_samples:
|
||||
generator_output = None
|
||||
while not _stop.is_set():
|
||||
if not data_gen_queue.empty():
|
||||
generator_output = data_gen_queue.get()
|
||||
break
|
||||
enqueuer = None
|
||||
|
||||
try:
|
||||
enqueuer = GeneratorEnqueuer(generator, pickle_safe=pickle_safe)
|
||||
enqueuer.start(nb_worker=nb_worker, max_q_size=max_q_size)
|
||||
|
||||
while processed_samples < val_samples:
|
||||
generator_output = None
|
||||
while enqueuer.is_running():
|
||||
if not enqueuer.queue.empty():
|
||||
generator_output = enqueuer.queue.get()
|
||||
break
|
||||
else:
|
||||
time.sleep(wait_time)
|
||||
|
||||
if isinstance(generator_output, tuple):
|
||||
if len(generator_output) == 2:
|
||||
x, y = generator_output
|
||||
sample_weight = None
|
||||
elif len(generator_output) == 3:
|
||||
x, y, sample_weight = generator_output
|
||||
else:
|
||||
raise ValueError('output of generator should be a tuple '
|
||||
'(x, y, sample_weight) '
|
||||
'or (x, y). Found: ' +
|
||||
str(generator_output))
|
||||
else:
|
||||
time.sleep(wait_time)
|
||||
x = generator_output
|
||||
|
||||
if isinstance(generator_output, tuple):
|
||||
if len(generator_output) == 2:
|
||||
x, y = generator_output
|
||||
sample_weight = None
|
||||
elif len(generator_output) == 3:
|
||||
x, y, sample_weight = generator_output
|
||||
else:
|
||||
_stop.set()
|
||||
raise ValueError('output of generator should be a tuple '
|
||||
'(x, y, sample_weight) '
|
||||
'or (x, y). Found: ' +
|
||||
str(generator_output))
|
||||
else:
|
||||
x = generator_output
|
||||
|
||||
try:
|
||||
outs = self.predict_on_batch(x)
|
||||
except:
|
||||
_stop.set()
|
||||
raise
|
||||
|
||||
if isinstance(x, list):
|
||||
nb_samples = len(x[0])
|
||||
elif isinstance(x, dict):
|
||||
nb_samples = len(list(x.values())[0])
|
||||
else:
|
||||
nb_samples = len(x)
|
||||
if isinstance(x, list):
|
||||
nb_samples = len(x[0])
|
||||
elif isinstance(x, dict):
|
||||
nb_samples = len(list(x.values())[0])
|
||||
else:
|
||||
nb_samples = len(x)
|
||||
|
||||
if not isinstance(outs, list):
|
||||
outs = [outs]
|
||||
if not isinstance(outs, list):
|
||||
outs = [outs]
|
||||
|
||||
if len(all_outs) == 0:
|
||||
for out in outs:
|
||||
shape = (val_samples,) + out.shape[1:]
|
||||
all_outs.append(np.zeros(shape, dtype=K.floatx()))
|
||||
if len(all_outs) == 0:
|
||||
for out in outs:
|
||||
shape = (val_samples,) + out.shape[1:]
|
||||
all_outs.append(np.zeros(shape, dtype=K.floatx()))
|
||||
|
||||
for i, out in enumerate(outs):
|
||||
all_outs[i][processed_samples:(processed_samples + nb_samples)] = out
|
||||
processed_samples += nb_samples
|
||||
for i, out in enumerate(outs):
|
||||
all_outs[i][processed_samples:(processed_samples + nb_samples)] = out
|
||||
processed_samples += nb_samples
|
||||
|
||||
finally:
|
||||
if enqueuer is not None:
|
||||
enqueuer.stop()
|
||||
|
||||
_stop.set()
|
||||
if pickle_safe:
|
||||
# Terminate all daemon processes
|
||||
for p in generator_threads:
|
||||
if p.is_alive():
|
||||
p.terminate()
|
||||
data_gen_queue.close()
|
||||
if len(all_outs) == 1:
|
||||
return all_outs[0]
|
||||
return all_outs
|
||||
|
@ -19,6 +19,7 @@ from ..engine import Layer
|
||||
from ..engine import Merge
|
||||
from ..utils.generic_utils import func_dump
|
||||
from ..utils.generic_utils import func_load
|
||||
from ..utils.generic_utils import get_from_module
|
||||
|
||||
|
||||
class Masking(Layer):
|
||||
@ -676,7 +677,7 @@ class Lambda(Layer):
|
||||
|
||||
function_type = config.pop('function_type')
|
||||
if function_type == 'function':
|
||||
function = globs[config['function']]
|
||||
function = get_from_module(config['function'], globs, 'core')
|
||||
elif function_type == 'lambda':
|
||||
function = func_load(config['function'], globs=globs)
|
||||
else:
|
||||
@ -684,7 +685,7 @@ class Lambda(Layer):
|
||||
|
||||
output_shape_type = config.pop('output_shape_type')
|
||||
if output_shape_type == 'function':
|
||||
output_shape = globs[config['output_shape']]
|
||||
output_shape = get_from_module(config['output_shape'], globs, 'core')
|
||||
elif output_shape_type == 'lambda':
|
||||
output_shape = func_load(config['output_shape'], globs=globs)
|
||||
else:
|
||||
|
@ -68,10 +68,7 @@ class TimeDistributed(Wrapper):
|
||||
|
||||
The output will then have shape `(32, 10, 8)`.
|
||||
|
||||
Note this is strictly equivalent to
|
||||
using `layers.core.TimeDistributedDense`.
|
||||
However what is different about `TimeDistributed`
|
||||
is that it can be used with arbitrary layers, not just `Dense`,
|
||||
`TimeDistributed` can be used with arbitrary layers, not just `Dense`,
|
||||
for instance with a `Convolution2D` layer:
|
||||
|
||||
```python
|
||||
|
@ -3,7 +3,7 @@ from __future__ import absolute_import
|
||||
from six.moves import zip
|
||||
|
||||
from . import backend as K
|
||||
from .utils.generic_utils import get_from_module
|
||||
from .utils.generic_utils import get_from_module, get_custom_objects
|
||||
|
||||
if K.backend() == 'tensorflow':
|
||||
import tensorflow as tf
|
||||
@ -42,6 +42,8 @@ def optimizer_from_config(config, custom_objects=None):
|
||||
class_name = config['class_name']
|
||||
if custom_objects and class_name in custom_objects:
|
||||
cls = custom_objects[class_name]
|
||||
elif class_name in get_custom_objects():
|
||||
cls = get_custom_objects()[class_name]
|
||||
else:
|
||||
if class_name.lower() not in all_classes:
|
||||
raise ValueError('Optimizer class not found:', class_name)
|
||||
@ -211,6 +213,9 @@ class RMSprop(Optimizer):
|
||||
rho: float >= 0.
|
||||
epsilon: float >= 0. Fuzz factor.
|
||||
decay: float >= 0. Learning rate decay over each update.
|
||||
|
||||
# References
|
||||
- [rmsprop: Divide the gradient by a running average of its recent magnitude](http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)
|
||||
"""
|
||||
|
||||
def __init__(self, lr=0.001, rho=0.9, epsilon=1e-8, decay=0.,
|
||||
|
@ -226,7 +226,7 @@ def array_to_img(x, dim_ordering='default', scale=True):
|
||||
if dim_ordering == 'th':
|
||||
x = x.transpose(1, 2, 0)
|
||||
if scale:
|
||||
x += max(-np.min(x), 0)
|
||||
x = x + max(-np.min(x), 0)
|
||||
x_max = np.max(x)
|
||||
if x_max != 0:
|
||||
x /= x_max
|
||||
|
@ -9,10 +9,91 @@ import six
|
||||
import marshal
|
||||
import types as python_types
|
||||
|
||||
_GLOBAL_CUSTOM_OBJECTS = {}
|
||||
|
||||
|
||||
class CustomObjectScope(object):
|
||||
"""Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape.
|
||||
|
||||
Code within a `with` statement will be able to access custom objects
|
||||
by name. Changes to global custom objects persist within the enclosing `with` statement. At end of the `with`
|
||||
statement, global custom objects are reverted to state at beginning of the `with` statement.
|
||||
|
||||
# Example
|
||||
|
||||
Consider a custom object `MyObject`
|
||||
|
||||
```python
|
||||
with CustomObjectScope({"MyObject":MyObject}):
|
||||
layer = Dense(..., W_regularizer="MyObject")
|
||||
# save, load, etc. will recognize custom object by name
|
||||
```
|
||||
"""
|
||||
def __init__(self, *args):
|
||||
self.custom_objects = args
|
||||
self.backup = None
|
||||
|
||||
def __enter__(self):
|
||||
self.backup = _GLOBAL_CUSTOM_OBJECTS.copy()
|
||||
for objects in self.custom_objects:
|
||||
_GLOBAL_CUSTOM_OBJECTS.update(objects)
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, traceback):
|
||||
_GLOBAL_CUSTOM_OBJECTS.clear()
|
||||
_GLOBAL_CUSTOM_OBJECTS.update(self.backup)
|
||||
|
||||
|
||||
def custom_object_scope(*args):
|
||||
"""Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape.
|
||||
|
||||
Convenience wrapper for `CustomObjectScope`. Code within a `with` statement will be able to access custom objects
|
||||
by name. Changes to global custom objects persist within the enclosing `with` statement. At end of the `with`
|
||||
statement, global custom objects are reverted to state at beginning of the `with` statement.
|
||||
|
||||
# Example
|
||||
|
||||
Consider a custom object `MyObject`
|
||||
|
||||
```python
|
||||
with custom_object_scope({"MyObject":MyObject}):
|
||||
layer = Dense(..., W_regularizer="MyObject")
|
||||
# save, load, etc. will recognize custom object by name
|
||||
```
|
||||
|
||||
# Arguments
|
||||
*args: Variable length list of dictionaries of name, class pairs to add to custom objects.
|
||||
|
||||
# Returns
|
||||
Object of type `CustomObjectScope`.
|
||||
"""
|
||||
return CustomObjectScope(*args)
|
||||
|
||||
|
||||
def get_custom_objects():
|
||||
"""Retrieves a live reference to the global dictionary of custom objects (`_GLOBAL_CUSTOM_OBJECTS`).
|
||||
|
||||
Updating and clearing custom objects using `custom_object_scope` is preferred, but `get_custom_objects` can
|
||||
be used to directly access `_GLOBAL_CUSTOM_OBJECTS`.
|
||||
|
||||
# Example
|
||||
|
||||
```python
|
||||
get_custom_objects().clear()
|
||||
get_custom_objects()["MyObject"] = MyObject
|
||||
```
|
||||
|
||||
# Returns
|
||||
Global dictionary of names to classes (`_GLOBAL_CUSTOM_OBJECTS`).
|
||||
"""
|
||||
return _GLOBAL_CUSTOM_OBJECTS
|
||||
|
||||
|
||||
def get_from_module(identifier, module_params, module_name,
|
||||
instantiate=False, kwargs=None):
|
||||
"""Retrieves a class of function member of a module.
|
||||
"""Retrieves a class or function member of a module.
|
||||
|
||||
First checks `_GLOBAL_CUSTOM_OBJECTS` for `module_name`, then checks `module_params`.
|
||||
|
||||
# Arguments
|
||||
identifier: the object to retrieve. It could be specified
|
||||
@ -34,7 +115,11 @@ def get_from_module(identifier, module_params, module_name,
|
||||
ValueError: if the identifier cannot be found.
|
||||
"""
|
||||
if isinstance(identifier, six.string_types):
|
||||
res = module_params.get(identifier)
|
||||
res = None
|
||||
if identifier in _GLOBAL_CUSTOM_OBJECTS:
|
||||
res = _GLOBAL_CUSTOM_OBJECTS[identifier]
|
||||
if not res:
|
||||
res = module_params.get(identifier)
|
||||
if not res:
|
||||
raise ValueError('Invalid ' + str(module_name) + ': ' +
|
||||
str(identifier))
|
||||
@ -46,7 +131,11 @@ def get_from_module(identifier, module_params, module_name,
|
||||
return res
|
||||
elif isinstance(identifier, dict):
|
||||
name = identifier.pop('name')
|
||||
res = module_params.get(name)
|
||||
res = None
|
||||
if name in _GLOBAL_CUSTOM_OBJECTS:
|
||||
res = _GLOBAL_CUSTOM_OBJECTS[name]
|
||||
if not res:
|
||||
res = module_params.get(name)
|
||||
if res:
|
||||
return res(**identifier)
|
||||
else:
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import print_function
|
||||
import inspect
|
||||
|
||||
from .generic_utils import get_from_module
|
||||
from .generic_utils import get_from_module, get_custom_objects
|
||||
from .np_utils import convert_kernel
|
||||
from ..layers import *
|
||||
from ..models import Model, Sequential
|
||||
@ -22,8 +22,7 @@ def layer_from_config(config, custom_objects=None):
|
||||
# Insert custom layers into globals so they can
|
||||
# be accessed by `get_from_module`.
|
||||
if custom_objects:
|
||||
for cls_key in custom_objects:
|
||||
globals()[cls_key] = custom_objects[cls_key]
|
||||
get_custom_objects().update(custom_objects)
|
||||
|
||||
class_name = config['class_name']
|
||||
|
||||
|
@ -84,7 +84,7 @@ class BaseWrapper(object):
|
||||
if params_name not in legal_params:
|
||||
raise ValueError('{} is not a legal parameter'.format(params_name))
|
||||
|
||||
def get_params(self, _):
|
||||
def get_params(self, **params):
|
||||
"""Gets parameters for this estimator.
|
||||
|
||||
# Returns
|
||||
|
4
setup.py
4
setup.py
@ -3,12 +3,12 @@ from setuptools import find_packages
|
||||
|
||||
|
||||
setup(name='Keras',
|
||||
version='1.2.0',
|
||||
version='1.2.1',
|
||||
description='Deep Learning for Python',
|
||||
author='Francois Chollet',
|
||||
author_email='francois.chollet@gmail.com',
|
||||
url='https://github.com/fchollet/keras',
|
||||
download_url='https://github.com/fchollet/keras/tarball/1.2.0',
|
||||
download_url='https://github.com/fchollet/keras/tarball/1.2.1',
|
||||
license='MIT',
|
||||
install_requires=['theano', 'pyyaml', 'six'],
|
||||
extras_require={
|
||||
|
@ -81,6 +81,12 @@ class TestBackend(object):
|
||||
check_single_tensor_operation('reverse', (4, 3, 2), axes=1)
|
||||
check_single_tensor_operation('reverse', (4, 3, 2), axes=(1, 2))
|
||||
|
||||
def test_batch_dot_shape(self):
|
||||
with pytest.raises(ValueError):
|
||||
x_batch = KTF.ones(shape=(32, 20))
|
||||
y_batch = KTF.ones(shape=(32, 20))
|
||||
xy_batch_dot = KTF.batch_dot(x_batch, y_batch, axes=1)
|
||||
|
||||
def test_shape_operations(self):
|
||||
# concatenate
|
||||
xval = np.random.random((4, 3))
|
||||
|
@ -177,6 +177,102 @@ def test_multiprocessing_evaluating():
|
||||
assert reached_end
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_multiprocessing_fit_error():
|
||||
|
||||
batch_size = 32
|
||||
good_batches = 5
|
||||
|
||||
def myGenerator():
|
||||
"""Raises an exception after a few good batches"""
|
||||
for i in range(good_batches):
|
||||
yield (np.random.randint(batch_size, 256, (500, 2)),
|
||||
np.random.randint(batch_size, 2, 500))
|
||||
raise RuntimeError
|
||||
|
||||
model = Sequential()
|
||||
model.add(Dense(1, input_shape=(2, )))
|
||||
model.compile(loss='mse', optimizer='adadelta')
|
||||
|
||||
samples = batch_size * (good_batches + 1)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
model.fit_generator(
|
||||
myGenerator(), samples, 1,
|
||||
nb_worker=4, pickle_safe=True,
|
||||
)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
model.fit_generator(
|
||||
myGenerator(), samples, 1,
|
||||
pickle_safe=False,
|
||||
)
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_multiprocessing_evaluate_error():
|
||||
|
||||
batch_size = 32
|
||||
good_batches = 5
|
||||
|
||||
def myGenerator():
|
||||
"""Raises an exception after a few good batches"""
|
||||
for i in range(good_batches):
|
||||
yield (np.random.randint(batch_size, 256, (500, 2)),
|
||||
np.random.randint(batch_size, 2, 500))
|
||||
raise RuntimeError
|
||||
|
||||
model = Sequential()
|
||||
model.add(Dense(1, input_shape=(2, )))
|
||||
model.compile(loss='mse', optimizer='adadelta')
|
||||
|
||||
samples = batch_size * (good_batches + 1)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
model.evaluate_generator(
|
||||
myGenerator(), samples, 1,
|
||||
nb_worker=4, pickle_safe=True,
|
||||
)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
model.evaluate_generator(
|
||||
myGenerator(), samples, 1,
|
||||
pickle_safe=False,
|
||||
)
|
||||
|
||||
|
||||
@keras_test
|
||||
def test_multiprocessing_predict_error():
|
||||
|
||||
batch_size = 32
|
||||
good_batches = 5
|
||||
|
||||
def myGenerator():
|
||||
"""Raises an exception after a few good batches"""
|
||||
for i in range(good_batches):
|
||||
yield (np.random.randint(batch_size, 256, (500, 2)),
|
||||
np.random.randint(batch_size, 2, 500))
|
||||
raise RuntimeError
|
||||
|
||||
model = Sequential()
|
||||
model.add(Dense(1, input_shape=(2, )))
|
||||
model.compile(loss='mse', optimizer='adadelta')
|
||||
|
||||
samples = batch_size * (good_batches + 1)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
model.predict_generator(
|
||||
myGenerator(), samples, 1,
|
||||
nb_worker=4, pickle_safe=True,
|
||||
)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
model.predict_generator(
|
||||
myGenerator(), samples, 1,
|
||||
pickle_safe=False,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
pytest.main([__file__])
|
||||
|
@ -7,7 +7,7 @@ np.random.seed(1337)
|
||||
|
||||
from keras import backend as K
|
||||
from keras.models import Sequential
|
||||
from keras.layers.core import Dense, Activation, Merge, Lambda
|
||||
from keras.layers.core import Dense, Activation, Merge, Lambda, Reshape
|
||||
from keras.utils import np_utils
|
||||
from keras.utils.test_utils import get_test_data, keras_test
|
||||
from keras.models import model_from_json, model_from_yaml
|
||||
@ -287,14 +287,17 @@ def test_merge_dot():
|
||||
|
||||
left = Sequential()
|
||||
left.add(Dense(input_dim=input_dim, output_dim=nb_hidden))
|
||||
left.add(Reshape((nb_hidden, 1)))
|
||||
left.add(Activation('relu'))
|
||||
|
||||
right = Sequential()
|
||||
right.add(Dense(input_dim=input_dim, output_dim=nb_hidden))
|
||||
right.add(Reshape((nb_hidden, 1)))
|
||||
right.add(Activation('relu'))
|
||||
|
||||
model = Sequential()
|
||||
model.add(Merge([left, right], mode='dot', dot_axes=1))
|
||||
model.add(Reshape((1,)))
|
||||
model.add(Dense(nb_class))
|
||||
model.add(Activation('softmax'))
|
||||
|
||||
@ -302,14 +305,17 @@ def test_merge_dot():
|
||||
|
||||
left = Sequential()
|
||||
left.add(Dense(input_dim=input_dim, output_dim=nb_hidden))
|
||||
left.add(Reshape((nb_hidden, 1)))
|
||||
left.add(Activation('relu'))
|
||||
|
||||
right = Sequential()
|
||||
right.add(Dense(input_dim=input_dim, output_dim=nb_hidden))
|
||||
right.add(Reshape((nb_hidden, 1)))
|
||||
right.add(Activation('relu'))
|
||||
|
||||
model = Sequential()
|
||||
model.add(Merge([left, right], mode='dot', dot_axes=[1, 1]))
|
||||
model.add(Reshape((1,)))
|
||||
model.add(Dense(nb_class))
|
||||
model.add(Activation('softmax'))
|
||||
|
||||
|
30
tests/keras/utils/test_generic_utils.py
Normal file
30
tests/keras/utils/test_generic_utils.py
Normal file
@ -0,0 +1,30 @@
|
||||
import pytest
|
||||
import keras
|
||||
from keras import backend as K
|
||||
from keras.utils.generic_utils import custom_object_scope, get_custom_objects, get_from_module
|
||||
|
||||
|
||||
def test_custom_object_scope_adds_objects():
|
||||
get_custom_objects().clear()
|
||||
assert (len(get_custom_objects()) == 0)
|
||||
with custom_object_scope({"Test1": object, "Test2": object}, {"Test3": object}):
|
||||
assert (len(get_custom_objects()) == 3)
|
||||
assert (len(get_custom_objects()) == 0)
|
||||
|
||||
|
||||
class CustomObject(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def test_get_from_module_uses_custom_object():
|
||||
get_custom_objects().clear()
|
||||
assert (get_from_module("CustomObject", globals(), "test_generic_utils") == CustomObject)
|
||||
with pytest.raises(ValueError):
|
||||
get_from_module("TestObject", globals(), "test_generic_utils")
|
||||
with custom_object_scope({"TestObject": CustomObject}):
|
||||
assert (get_from_module("TestObject", globals(), "test_generic_utils") == CustomObject)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
Loading…
Reference in New Issue
Block a user