Improve int8 for Embedding (#19595)

parent 1ac17cc143
commit c61857d380
@@ -315,7 +315,6 @@ class Embedding(Layer):
         embeddings_initializer="zeros",
         embeddings_scale_initializer="ones",
     ):
-        self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1)
         self._embeddings = self.add_weight(
             name="embeddings",
             shape=(self.input_dim, self.output_dim),
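Review note: the diff does not state why `self.inputs_quantizer` is dropped from the int8 build step here (and from `quantize` further down), but it is consistent with the call path: an Embedding consumes integer indices, so there is no float input to quantize, only the weight table. A minimal usage sketch, assuming the Keras 3 `layer.quantize("int8")` entry point (the dims here are made up):

```python
import numpy as np
import keras

layer = keras.layers.Embedding(input_dim=1000, output_dim=16)
layer(np.array([[1, 2, 3]]))  # build the layer once
layer.quantize("int8")        # weights become an int8 table plus per-row scales
outputs = layer(np.array([[1, 2, 3]]))  # raw indices in, float vectors out
```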
@@ -323,9 +322,12 @@ class Embedding(Layer):
             dtype="int8",
             trainable=False,
         )
+        # We choose to reduce the axis of `output_dim` because, typically,
+        # `input_dim` is larger than `output_dim`. This reduces quantization
+        # error.
         self.embeddings_scale = self.add_weight(
             name="embeddings_scale",
-            shape=(self.output_dim,),
+            shape=(self.input_dim,),
             initializer=embeddings_scale_initializer,
             trainable=False,
         )
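The new comment claims the per-row scheme reduces quantization error. A minimal NumPy sketch of that claim (a standalone stand-in, not the Keras implementation): reducing over `output_dim` yields one scale per embedding row, shape `(input_dim,)`, so each row's 255 int8 levels only have to span that row's own range.

```python
import numpy as np

rng = np.random.default_rng(0)
input_dim, output_dim = 1000, 16
w = rng.normal(size=(input_dim, output_dim)).astype("float32")

def quantize_roundtrip(x, axis):
    # Abs-max scheme: the scale maps each slice's amax onto 127, and
    # dequantization divides by the scale (matching the call path below).
    scale = 127.0 / np.maximum(np.abs(x).max(axis=axis, keepdims=True), 1e-7)
    q = np.clip(np.round(x * scale), -127, 127).astype("int8")
    return q.astype("float32") / scale

err_per_row = np.abs(w - quantize_roundtrip(w, axis=-1)).mean()  # scale (1000, 1)
err_per_col = np.abs(w - quantize_roundtrip(w, axis=0)).mean()   # scale (1, 16)
print(err_per_row < err_per_col)  # True here: short rows get finer steps
```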
@@ -345,11 +347,12 @@ class Embedding(Layer):
         # not needed
         if backend.standardize_dtype(inputs.dtype) not in ("int32", "int64"):
             inputs = ops.cast(inputs, "int32")
+        embeddings_scale = ops.take(self.embeddings_scale, inputs, axis=0)
         outputs = ops.take(self._embeddings, inputs, axis=0)
         # De-scale outputs
-        outputs = ops.cast(outputs, self.compute_dtype)
         outputs = ops.divide(
-            outputs, ops.expand_dims(self.embeddings_scale, axis=0)
+            ops.cast(outputs, dtype=self.compute_dtype),
+            ops.expand_dims(embeddings_scale, axis=-1),
         )
         if self.lora_enabled:
             lora_outputs = ops.take(self.lora_embeddings_a, inputs, axis=0)
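The extra `ops.take` exists because each gathered row now carries its own scale: the scales must be gathered with the same indices and broadcast over the last axis before dividing. A NumPy stand-in for this lookup (`int8_lookup` is a hypothetical helper, not Keras code):

```python
import numpy as np

def int8_lookup(q_table, row_scales, indices):
    # q_table: (input_dim, output_dim) int8; row_scales: (input_dim,) float32
    rows = np.take(q_table, indices, axis=0)       # (..., output_dim), int8
    scales = np.take(row_scales, indices, axis=0)  # (...,), one per gathered row
    # De-scale: cast up, then divide by the per-row scale on the last axis.
    return rows.astype("float32") / scales[..., None]
```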
@@ -379,14 +382,12 @@ class Embedding(Layer):

         self._tracker.unlock()
         if mode == "int8":
-            # Configure `self.inputs_quantizer`
-            self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1)
             # Quantize `self._embeddings` to int8 and compute corresponding
             # scale
             embeddings_value, embeddings_scale = quantizers.abs_max_quantize(
-                self._embeddings, axis=0
+                self._embeddings, axis=-1
             )
-            embeddings_scale = ops.squeeze(embeddings_scale, axis=0)
+            embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
             self._untrack_variable(self._embeddings)
             del self._embeddings
             # Utilize a lambda expression as an initializer to prevent adding a
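The axis flip from 0 to -1 is the heart of the change: the amax reduction now runs over `output_dim`, and the squeeze drops the kept axis so the stored scale has shape `(input_dim,)`, matching the variable declared in the build hunk above. A NumPy stand-in, assuming (as the paired squeeze implies) that `abs_max_quantize` reduces over the given axis with keepdims:

```python
import numpy as np

def abs_max_quantize(x, axis, value_range=(-127, 127), epsilon=1e-7):
    # Hypothetical stand-in for the Keras quantizer: keepdims on the reduced
    # axis, which is why the caller squeezes that same axis afterwards.
    amax = np.abs(x).max(axis=axis, keepdims=True)
    scale = value_range[1] / np.maximum(amax, epsilon)
    q = np.clip(np.round(x * scale), *value_range).astype("int8")
    return q, scale

w = np.random.default_rng(0).normal(size=(1000, 16)).astype("float32")
q, scale = abs_max_quantize(w, axis=-1)  # scale shape: (1000, 1)
scale = np.squeeze(scale, axis=-1)       # stored as (input_dim,)
```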
@@ -412,15 +413,15 @@ class Embedding(Layer):
             # Dequantize & quantize to merge lora weights into embeddings
             # Note that this is a lossy compression
             embeddings_value = ops.divide(
-                embeddings_value, embeddings_scale
+                embeddings_value, ops.expand_dims(embeddings_scale, axis=-1)
             )
             embeddings_value = ops.add(
                 embeddings_value,
                 ops.matmul(self.lora_embeddings_a, self.lora_embeddings_b),
             )
             embeddings_value, embeddings_scale = (
-                quantizers.abs_max_quantize(embeddings_value, axis=0)
+                quantizers.abs_max_quantize(embeddings_value, axis=-1)
             )
-            embeddings_scale = ops.squeeze(embeddings_scale, axis=0)
+            embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
             return embeddings_value, embeddings_scale
         return self.embeddings, None
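The LoRA merge follows the same pattern: `expand_dims` puts the per-row scale back on the last axis for dequantization, and the re-quantization reduces over `output_dim` again. A NumPy sketch of the round trip (`merge_lora_int8` is hypothetical; the compression is lossy, as the comment notes):

```python
import numpy as np

def merge_lora_int8(q, row_scales, lora_a, lora_b):
    # Dequantize, fold in the low-rank update, then re-quantize per row.
    value = q.astype("float32") / row_scales[..., None]
    value = value + lora_a @ lora_b  # (input_dim, rank) @ (rank, output_dim)
    new_scales = 127.0 / np.maximum(np.abs(value).max(axis=-1), 1e-7)
    q_new = np.clip(np.round(value * new_scales[..., None]), -127, 127)
    return q_new.astype("int8"), new_scales  # scales: (input_dim,)
```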