Fix wrong gradients for torch batch_norm and moments (#19102)

* fix wrong gradients for torch batch_norm and moments.

* fix the tests
Haifeng Jin 2024-01-25 12:18:36 -08:00 committed by GitHub
parent d84e14bf14
commit adacf2c495
2 changed files with 18 additions and 31 deletions
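
For context (not part of the commit message): the torch backend previously called .detach() on the batch statistics, which severs the autograd path from the variance back through the mean, so any loss depending on the normalized output got wrong gradients. A minimal sketch of the failure mode, with illustrative names:

```python
import torch

x = torch.randn(8, requires_grad=True)

# Old behavior: mean detached inside Var(x) = E[x^2] - E[x]^2, so the
# gradient contribution through the mean is silently dropped.
bad = torch.mean(x * x) - torch.square(x.mean().detach())
bad.backward()
grad_old = x.grad.clone()  # equals 2*x/n, missing the -2*mean/n term

# Fixed behavior: gradient matches d/dx of the population variance,
# 2*(x - mean)/n.
x.grad = None
good = torch.mean(x * x) - torch.square(x.mean())
good.backward()
print(torch.allclose(x.grad, 2 * (x - x.mean()).detach() / 8))  # True
print(torch.allclose(grad_old, x.grad))  # False (unless mean(x) == 0)
```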

@@ -36,10 +36,12 @@ def build_keras_model(keras_module, num_classes):
             keras_module.layers.Conv2D(
                 32, kernel_size=(3, 3), activation="relu"
             ),
+            keras_module.layers.BatchNormalization(),
             keras_module.layers.MaxPooling2D(pool_size=(2, 2)),
             keras_module.layers.Conv2D(
                 64, kernel_size=(3, 3), activation="relu"
             ),
+            keras_module.layers.BatchNormalization(scale=False, center=True),
             keras_module.layers.MaxPooling2D(pool_size=(2, 2)),
             keras_module.layers.Flatten(),
             keras_module.layers.Dense(num_classes, activation="softmax"),
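
The two added layers route the test model through both batch-norm configurations touched by this fix: the default layer (gamma and beta present) and one with scale=False, which hands the backend scale=None. A hedged sketch of why that second configuration matters (Keras 3 API and weight names assumed, not taken from this diff):

```python
# Assumption: Keras 3. With scale=False the layer builds no gamma
# variable, so the torch backend's batch_normalization() receives
# scale=None and must take its torch.ones_like(variance) fallback branch.
import keras

layer = keras.layers.BatchNormalization(scale=False, center=True)
layer.build((None, 4))
# Expect beta, moving_mean, moving_variance -- and no gamma.
print([v.name for v in layer.weights])
```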

@@ -684,7 +684,7 @@ def moments(x, axes, keepdims=False, synchronized=False):
     # gradient is zero.
     variance = torch.mean(
         torch.square(x), dim=axes, keepdim=True
-    ) - torch.square(mean.detach())
+    ) - torch.square(mean)
     if not keepdims:
         mean = torch.squeeze(mean, axes)
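
With the detach() gone, the d(variance)/d(mean) path contributes again, so the E[x^2] - E[x]^2 shortcut differentiates the same way as torch's own population variance. A standalone check of that claim (a sketch, not part of the commit):

```python
import torch

x = torch.randn(2, 3, 4, requires_grad=True)
axes = (0, 1)

# Variance the way moments() computes it after this fix.
mean = torch.mean(x, dim=axes, keepdim=True)
variance = torch.mean(
    torch.square(x), dim=axes, keepdim=True
) - torch.square(mean)
variance.sum().backward()
grad_shortcut = x.grad.clone()

# Reference gradient through torch's population variance.
x.grad = None
torch.var(x, dim=axes, unbiased=False, keepdim=True).sum().backward()
print(torch.allclose(grad_shortcut, x.grad, atol=1e-6))  # True
```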
@@ -710,45 +710,30 @@ def batch_normalization(
     x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3
 ):
     x = convert_to_tensor(x)
-    mean = convert_to_tensor(mean).detach()
-    variance = convert_to_tensor(variance).detach()
+    mean = convert_to_tensor(mean)
+    variance = convert_to_tensor(variance)
     shape = [1] * len(x.shape)
     shape[axis] = mean.shape[0]
     mean = torch.reshape(mean, shape)
     variance = torch.reshape(variance, shape)

     if offset is not None:
         offset = convert_to_tensor(offset)
         offset = torch.reshape(offset, shape)
     else:
         offset = torch.zeros_like(mean)
     if scale is not None:
         scale = convert_to_tensor(scale)
         scale = torch.reshape(scale, shape)
     else:
         scale = torch.ones_like(variance)

-    def _batch_norm():
-        return tnn.batch_norm(
-            input=x,
-            running_mean=mean,
-            running_var=variance,
-            weight=scale,
-            bias=offset,
-            training=False,
-            eps=epsilon,
-        )
-
-    if axis == 1:
-        return _batch_norm()
-
-    if axis < 0:
-        axis = len(x.shape) + axis
-
-    order = list(range(len(x.shape)))
-    order.pop(axis)
-    order.insert(1, axis)
-    x = x.permute(order)
-    x = _batch_norm()
-    order = list(range(len(x.shape)))
-    order.pop(1)
-    order.insert(axis, 1)
-    return x.permute(order)
+    return (
+        x.subtract(mean)
+        .mul_(variance.add(epsilon).rsqrt_().mul(scale))
+        .add_(offset)
+    )


 def ctc_loss(
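
The rewrite drops tnn.batch_norm, which required detached running stats and a channels-at-axis-1 layout (hence the permutes), in favor of a plain tensor expression that autograd can differentiate through mean and variance. A standalone sketch of both properties (shapes and names are illustrative, assuming channels at axis 1):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 5)
mean = torch.randn(3, requires_grad=True)
variance = (torch.rand(3) + 0.5).requires_grad_()
scale, offset = torch.randn(3), torch.randn(3)
eps = 1e-3

# The new formulation: (x - mean) * rsqrt(var + eps) * scale + offset,
# broadcast over axis=1.
shape = (1, 3, 1)
out = (x - mean.reshape(shape)) * (
    variance.reshape(shape) + eps
).rsqrt() * scale.reshape(shape) + offset.reshape(shape)

# Numerically matches torch's inference-mode kernel...
ref = F.batch_norm(
    x,
    mean.detach(),
    variance.detach(),
    weight=scale,
    bias=offset,
    training=False,
    eps=eps,
)
print(torch.allclose(out, ref, atol=1e-6))  # True

# ...but, unlike the old detached version, lets gradients reach the stats.
out.sum().backward()
print(mean.grad is not None, variance.grad is not None)  # True True
```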