From c61857d380d3bf5529421db65cb03a20aca969e0 Mon Sep 17 00:00:00 2001
From: james77777778 <20734616+james77777778@users.noreply.github.com>
Date: Wed, 24 Apr 2024 04:36:14 +0800
Subject: [PATCH] Improve int8 for `Embedding` (#19595)

---
 keras/src/layers/core/embedding.py | 23 ++++++++++++-----------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/keras/src/layers/core/embedding.py b/keras/src/layers/core/embedding.py
index 0de116379..bb85c3dd1 100644
--- a/keras/src/layers/core/embedding.py
+++ b/keras/src/layers/core/embedding.py
@@ -315,7 +315,6 @@ class Embedding(Layer):
         embeddings_initializer="zeros",
         embeddings_scale_initializer="ones",
     ):
-        self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1)
         self._embeddings = self.add_weight(
             name="embeddings",
             shape=(self.input_dim, self.output_dim),
@@ -323,9 +322,12 @@ class Embedding(Layer):
             dtype="int8",
             trainable=False,
         )
+        # We choose to reduce the axis of `output_dim` because, typically,
+        # `input_dim` is larger than `output_dim`. This reduces quantization
+        # error.
         self.embeddings_scale = self.add_weight(
             name="embeddings_scale",
-            shape=(self.output_dim,),
+            shape=(self.input_dim,),
             initializer=embeddings_scale_initializer,
             trainable=False,
         )
@@ -345,11 +347,12 @@ class Embedding(Layer):
         # not needed
         if backend.standardize_dtype(inputs.dtype) not in ("int32", "int64"):
             inputs = ops.cast(inputs, "int32")
+        embeddings_scale = ops.take(self.embeddings_scale, inputs, axis=0)
         outputs = ops.take(self._embeddings, inputs, axis=0)
         # De-scale outputs
-        outputs = ops.cast(outputs, self.compute_dtype)
         outputs = ops.divide(
-            outputs, ops.expand_dims(self.embeddings_scale, axis=0)
+            ops.cast(outputs, dtype=self.compute_dtype),
+            ops.expand_dims(embeddings_scale, axis=-1),
         )
         if self.lora_enabled:
             lora_outputs = ops.take(self.lora_embeddings_a, inputs, axis=0)
@@ -379,14 +382,12 @@ class Embedding(Layer):
         self._tracker.unlock()
 
         if mode == "int8":
-            # Configure `self.inputs_quantizer`
-            self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1)
             # Quantize `self._embeddings` to int8 and compute corresponding
             # scale
             embeddings_value, embeddings_scale = quantizers.abs_max_quantize(
-                self._embeddings, axis=0
+                self._embeddings, axis=-1
             )
-            embeddings_scale = ops.squeeze(embeddings_scale, axis=0)
+            embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
             self._untrack_variable(self._embeddings)
             del self._embeddings
             # Utilize a lambda expression as an initializer to prevent adding a
@@ -412,15 +413,15 @@ class Embedding(Layer):
                 # Dequantize & quantize to merge lora weights into embeddings
                 # Note that this is a lossy compression
                 embeddings_value = ops.divide(
-                    embeddings_value, embeddings_scale
+                    embeddings_value, ops.expand_dims(embeddings_scale, axis=-1)
                 )
                 embeddings_value = ops.add(
                     embeddings_value,
                     ops.matmul(self.lora_embeddings_a, self.lora_embeddings_b),
                 )
                 embeddings_value, embeddings_scale = (
-                    quantizers.abs_max_quantize(embeddings_value, axis=0)
+                    quantizers.abs_max_quantize(embeddings_value, axis=-1)
                 )
-                embeddings_scale = ops.squeeze(embeddings_scale, axis=0)
+                embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
             return embeddings_value, embeddings_scale
         return self.embeddings, None
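
For reference, below is a minimal standalone sketch (not the Keras implementation) of the scheme this patch switches to: the embedding table is abs-max quantized per row, so there is one scale per `input_dim` entry rather than one per `output_dim` column, and the lookup gathers both the int8 row and its scale before de-scaling. The helper names (`quantize_per_row`, `lookup`) and the NumPy usage are illustrative only.

```python
# Sketch of per-row abs-max int8 quantization for an embedding table.
# Assumption: `quantize_per_row` / `lookup` are hypothetical helpers, not Keras APIs.
import numpy as np

def quantize_per_row(table):
    # One scale per row (per `input_dim` entry), reducing over `output_dim`.
    # The scale maps each row's max magnitude onto the int8 range [-127, 127].
    abs_max = np.max(np.abs(table), axis=-1, keepdims=True)  # (input_dim, 1)
    scale = np.divide(
        127.0, abs_max, out=np.ones_like(abs_max), where=abs_max != 0
    )
    q = np.round(table * scale).astype("int8")  # (input_dim, output_dim)
    return q, np.squeeze(scale, axis=-1)        # scales: (input_dim,)

def lookup(q, scales, ids):
    # Gather int8 rows and their per-row scales, then de-scale to float,
    # mirroring the patch's convention of dividing by the stored scale.
    rows = q[ids].astype("float32")             # (..., output_dim)
    return rows / np.expand_dims(scales[ids], axis=-1)

table = np.random.randn(1000, 16).astype("float32")  # (input_dim, output_dim)
q, scales = quantize_per_row(table)
ids = np.array([3, 42, 7])
out = lookup(q, scales, ids)
print(np.max(np.abs(out - table[ids])))  # small quantization error
```

Because `input_dim` is usually much larger than `output_dim`, quantizing along the last axis yields many more scales (one per vocabulary row), which is what the added comment in the patch refers to when it says this reduces quantization error; it also lets the call path gather only the scales for the looked-up ids instead of broadcasting a single `(output_dim,)` scale vector.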