Improve int8 for Embedding (#19595)

Author: james77777778
Date: 2024-04-24 04:36:14 +08:00
Committed by: GitHub
Commit: c61857d380
Parent: 1ac17cc143

@@ -315,7 +315,6 @@ class Embedding(Layer):
         embeddings_initializer="zeros",
         embeddings_scale_initializer="ones",
     ):
-        self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1)
         self._embeddings = self.add_weight(
             name="embeddings",
             shape=(self.input_dim, self.output_dim),
@@ -323,9 +322,12 @@ class Embedding(Layer):
             dtype="int8",
             trainable=False,
         )
+        # We choose to reduce the axis of `output_dim` because, typically,
+        # `input_dim` is larger than `output_dim`. This reduces quantization
+        # error.
         self.embeddings_scale = self.add_weight(
             name="embeddings_scale",
-            shape=(self.output_dim,),
+            shape=(self.input_dim,),
            initializer=embeddings_scale_initializer,
             trainable=False,
         )
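
The added comment explains the new scale layout: quantizing along `output_dim` gives each of the `input_dim` vocabulary rows its own scale, so the scale vector now has shape `(input_dim,)`. A minimal NumPy sketch of that per-row abs-max scheme (illustrative only; the `127 / abs_max` convention, with dequantization by division, is assumed to mirror `quantizers.abs_max_quantize`):

import numpy as np

input_dim, output_dim = 1000, 64  # vocabulary size, embedding width
embeddings = np.random.randn(input_dim, output_dim).astype("float32")

# One abs-max per vocabulary row: reduce over the last axis (output_dim).
abs_max = np.max(np.abs(embeddings), axis=-1, keepdims=True)  # (input_dim, 1)
scale = 127.0 / abs_max                                       # (input_dim, 1)
q = np.clip(np.round(embeddings * scale), -127, 127).astype("int8")
embeddings_scale = np.squeeze(scale, axis=-1)                 # (input_dim,)

# Each row dequantizes with its own scale, so one large-magnitude row cannot
# inflate the quantization error of the others.
dequantized = q.astype("float32") / embeddings_scale[:, None]
print(float(np.abs(dequantized - embeddings).max()))
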
@@ -345,11 +347,12 @@ class Embedding(Layer):
         # not needed
         if backend.standardize_dtype(inputs.dtype) not in ("int32", "int64"):
             inputs = ops.cast(inputs, "int32")
+        embeddings_scale = ops.take(self.embeddings_scale, inputs, axis=0)
         outputs = ops.take(self._embeddings, inputs, axis=0)
         # De-scale outputs
-        outputs = ops.cast(outputs, self.compute_dtype)
         outputs = ops.divide(
-            outputs, ops.expand_dims(self.embeddings_scale, axis=0)
+            ops.cast(outputs, dtype=self.compute_dtype),
+            ops.expand_dims(embeddings_scale, axis=-1),
         )
         if self.lora_enabled:
             lora_outputs = ops.take(self.lora_embeddings_a, inputs, axis=0)
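
A sketch of the updated lookup path, with NumPy standing in for `ops.take` / `ops.divide`: the per-row scales are gathered with the same indices as the int8 rows, then broadcast over the last axis. Shapes and values here are illustrative only.

import numpy as np

q_embeddings = np.random.randint(-127, 128, size=(1000, 64)).astype("int8")
embeddings_scale = np.random.uniform(1.0, 10.0, size=(1000,)).astype("float32")
inputs = np.array([[3, 17, 42], [7, 7, 999]], dtype="int32")  # (batch, seq)

scales = embeddings_scale[inputs]  # ops.take(scale, inputs, axis=0) -> (2, 3)
rows = q_embeddings[inputs]        # ops.take(embeddings, inputs, axis=0) -> (2, 3, 64)
outputs = rows.astype("float32") / scales[..., None]  # expand_dims(scale, axis=-1)
print(outputs.shape)               # (2, 3, 64)
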
@@ -379,14 +382,12 @@ class Embedding(Layer):
         self._tracker.unlock()
         if mode == "int8":
-            # Configure `self.inputs_quantizer`
-            self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1)
             # Quantize `self._embeddings` to int8 and compute corresponding
             # scale
             embeddings_value, embeddings_scale = quantizers.abs_max_quantize(
-                self._embeddings, axis=0
+                self._embeddings, axis=-1
             )
-            embeddings_scale = ops.squeeze(embeddings_scale, axis=0)
+            embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
             self._untrack_variable(self._embeddings)
             del self._embeddings
             # Utilize a lambda expression as an initializer to prevent adding a
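
A hedged usage sketch of the resulting `quantize("int8")` behaviour: with the axis change, the layer should end up with an int8 kernel of shape `(input_dim, output_dim)` and a per-row scale of shape `(input_dim,)` rather than `(output_dim,)`. Variable names follow the diff; treat the exact attributes and shapes printed below as an assumption, not verified output.

import keras

layer = keras.layers.Embedding(input_dim=1000, output_dim=64)
layer.build((None,))
layer.quantize("int8")

print(layer._embeddings.shape)       # expected (1000, 64), int8 kernel
print(layer.embeddings_scale.shape)  # expected (1000,), one scale per row
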
@@ -412,15 +413,15 @@ class Embedding(Layer):
                 # Dequantize & quantize to merge lora weights into embeddings
                 # Note that this is a lossy compression
                 embeddings_value = ops.divide(
-                    embeddings_value, embeddings_scale
+                    embeddings_value, ops.expand_dims(embeddings_scale, axis=-1)
                 )
                 embeddings_value = ops.add(
                     embeddings_value,
                     ops.matmul(self.lora_embeddings_a, self.lora_embeddings_b),
                 )
                 embeddings_value, embeddings_scale = (
-                    quantizers.abs_max_quantize(embeddings_value, axis=0)
+                    quantizers.abs_max_quantize(embeddings_value, axis=-1)
                 )
-                embeddings_scale = ops.squeeze(embeddings_scale, axis=0)
+                embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
             return embeddings_value, embeddings_scale
         return self.embeddings, None
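
Finally, a NumPy sketch of the merged-LoRA path with the new per-row scales: dequantize with the `(input_dim,)` scale, add the low-rank update, and re-quantize along the last axis (lossy, as the comment in the diff notes). The scale convention is assumed as above; names are illustrative.

import numpy as np

input_dim, output_dim, rank = 1000, 64, 4
q = np.random.randint(-127, 128, size=(input_dim, output_dim)).astype("int8")
scale = np.random.uniform(1.0, 10.0, size=(input_dim,)).astype("float32")
lora_a = 0.01 * np.random.randn(input_dim, rank).astype("float32")
lora_b = 0.01 * np.random.randn(rank, output_dim).astype("float32")

value = q.astype("float32") / scale[:, None]  # ops.expand_dims(scale, axis=-1)
value = value + lora_a @ lora_b               # merge the LoRA update
abs_max = np.max(np.abs(value), axis=-1, keepdims=True)
new_scale = 127.0 / abs_max                   # abs_max_quantize(..., axis=-1)
new_q = np.clip(np.round(value * new_scale), -127, 127).astype("int8")
new_scale = np.squeeze(new_scale, axis=-1)    # squeeze(..., axis=-1) -> (input_dim,)
print(new_q.shape, new_scale.shape)
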