diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py index c3a410681..12d6dee53 100644 --- a/keras/backend/theano_backend.py +++ b/keras/backend/theano_backend.py @@ -314,7 +314,8 @@ def normalize_batch_in_training(x, gamma, beta, def batch_normalization(x, mean, std, beta, gamma, epsilon=0.0001): '''Apply batch normalization on x given mean, std, beta and gamma. ''' - normed = (x - mean) * (gamma * T.inv(std + epsilon)) + beta + normed = T.nnet.bn.batch_normalization(x, gamma, beta, mean, std + epsilon, + mode='high_mem') return normed