Fix wrong gradients for torch batch_norm and moments (#19102)
* fix wront gradients for torch batch_norm and moments. * fix the tests
This commit is contained in:
parent
d84e14bf14
commit
adacf2c495
@ -36,10 +36,12 @@ def build_keras_model(keras_module, num_classes):
|
||||
keras_module.layers.Conv2D(
|
||||
32, kernel_size=(3, 3), activation="relu"
|
||||
),
|
||||
keras_module.layers.BatchNormalization(),
|
||||
keras_module.layers.MaxPooling2D(pool_size=(2, 2)),
|
||||
keras_module.layers.Conv2D(
|
||||
64, kernel_size=(3, 3), activation="relu"
|
||||
),
|
||||
keras_module.layers.BatchNormalization(scale=False, center=True),
|
||||
keras_module.layers.MaxPooling2D(pool_size=(2, 2)),
|
||||
keras_module.layers.Flatten(),
|
||||
keras_module.layers.Dense(num_classes, activation="softmax"),
|
||||
|
@ -684,7 +684,7 @@ def moments(x, axes, keepdims=False, synchronized=False):
|
||||
# gradient is zero.
|
||||
variance = torch.mean(
|
||||
torch.square(x), dim=axes, keepdim=True
|
||||
) - torch.square(mean.detach())
|
||||
) - torch.square(mean)
|
||||
|
||||
if not keepdims:
|
||||
mean = torch.squeeze(mean, axes)
|
||||
@ -710,45 +710,30 @@ def batch_normalization(
|
||||
x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3
|
||||
):
|
||||
x = convert_to_tensor(x)
|
||||
mean = convert_to_tensor(mean).detach()
|
||||
variance = convert_to_tensor(variance).detach()
|
||||
mean = convert_to_tensor(mean)
|
||||
variance = convert_to_tensor(variance)
|
||||
|
||||
shape = [1] * len(x.shape)
|
||||
shape[axis] = mean.shape[0]
|
||||
mean = torch.reshape(mean, shape)
|
||||
variance = torch.reshape(variance, shape)
|
||||
|
||||
if offset is not None:
|
||||
offset = convert_to_tensor(offset)
|
||||
offset = torch.reshape(offset, shape)
|
||||
else:
|
||||
offset = torch.zeros_like(mean)
|
||||
if scale is not None:
|
||||
scale = convert_to_tensor(scale)
|
||||
scale = torch.reshape(scale, shape)
|
||||
else:
|
||||
scale = torch.ones_like(variance)
|
||||
|
||||
def _batch_norm():
|
||||
return tnn.batch_norm(
|
||||
input=x,
|
||||
running_mean=mean,
|
||||
running_var=variance,
|
||||
weight=scale,
|
||||
bias=offset,
|
||||
training=False,
|
||||
eps=epsilon,
|
||||
)
|
||||
|
||||
if axis == 1:
|
||||
return _batch_norm()
|
||||
|
||||
if axis < 0:
|
||||
axis = len(x.shape) + axis
|
||||
|
||||
order = list(range(len(x.shape)))
|
||||
order.pop(axis)
|
||||
order.insert(1, axis)
|
||||
x = x.permute(order)
|
||||
|
||||
x = _batch_norm()
|
||||
|
||||
order = list(range(len(x.shape)))
|
||||
order.pop(1)
|
||||
order.insert(axis, 1)
|
||||
return x.permute(order)
|
||||
return (
|
||||
x.subtract(mean)
|
||||
.mul_(variance.add(epsilon).rsqrt_().mul(scale))
|
||||
.add_(offset)
|
||||
)
|
||||
|
||||
|
||||
def ctc_loss(
|
||||
|
Loading…
Reference in New Issue
Block a user