diff --git a/keras/layers/normalization.py b/keras/layers/normalization.py index 47b92f0cf..6e48ff24c 100644 --- a/keras/layers/normalization.py +++ b/keras/layers/normalization.py @@ -140,7 +140,7 @@ class BatchNormalization(Layer): self.updates = [K.moving_average_update(self.running_mean, mean, self.momentum), K.moving_average_update(self.running_std, std, self.momentum)] - if sorted(reduction_axes) == range(K.ndim(x))[:-1]: + if K.backend() == 'tensorflow' and sorted(reduction_axes) == range(K.ndim(x))[:-1]: x_normed_running = K.batch_normalization( x, self.running_mean, self.running_std, self.beta, self.gamma,