Attempt to add support for saving/loading bfloat16 (#19091)
Currently, attempting to save or load weights in bfloat16 fails. There may be better ways to do this, but the approach JAX and TensorFlow seem to take is to use the ml-dtypes library so that bfloat16 works with NumPy. This is further complicated by the h5py format, which saves bfloat16 as a void type. This implementation currently just assumes that any two-byte void type is bfloat16, which seems a bit hacky; there is quite possibly a better way to do this.
This commit is contained in:
parent
bf1f463e0e
commit
7924aff566
@ -1,7 +1,9 @@
|
||||
import types
|
||||
|
||||
import h5py
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import ml_dtypes
|
||||
import numpy as np
|
||||
import tree
|
||||
from jax.tree_util import Partial
|
||||
@ -61,10 +63,20 @@ def convert_to_tensor(x, dtype=None, sparse=None):
|
||||
if dtype and dtype != x.dtype:
|
||||
return x.value.astype(dtype)
|
||||
return x.value
|
||||
if isinstance(x, h5py.Dataset):
|
||||
# h5py will handle bfloat16 as an opaque dtype.
|
||||
# We assume any two byte void dtypes are in fact bfloat16 type.
|
||||
if x.dtype == np.dtype((np.void, 2)):
|
||||
x = np.array(x, dtype=ml_dtypes.bfloat16)
|
||||
# h5py Datasets do not support converting on the fly for many dtypes.
|
||||
# Instead we convert "as is" and cast.
|
||||
return jnp.asarray(x).astype(dtype)
|
||||
return jnp.asarray(x, dtype=dtype)
|
||||
|
||||
|
||||
def convert_to_numpy(x):
    """Return `x` as a NumPy array.

    When `x` is a backend tensor whose dtype compares equal to
    `"bfloat16"`, the conversion explicitly requests `ml_dtypes.bfloat16`
    as the array dtype. Every other input goes through a plain
    `np.asarray(x)` call.
    """
    is_bfloat16_tensor = is_tensor(x) and x.dtype == "bfloat16"
    if not is_bfloat16_tensor:
        return np.asarray(x)
    # Pass the ml_dtypes bfloat16 type explicitly for bfloat16 tensors.
    return np.asarray(x, ml_dtypes.bfloat16)
|
||||
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
import h5py
|
||||
import ml_dtypes
|
||||
import numpy as np
|
||||
import tree
|
||||
|
||||
@ -39,6 +41,14 @@ def convert_to_tensor(x, dtype=None, sparse=None):
|
||||
dtype = result_type(
|
||||
*[getattr(item, "dtype", type(item)) for item in tree.flatten(x)]
|
||||
)
|
||||
if isinstance(x, h5py.Dataset):
|
||||
# h5py will handle bfloat16 as an opaque dtype.
|
||||
# We assume any two byte void dtypes are in fact bfloat16 type.
|
||||
if x.dtype == np.dtype((np.void, 2)):
|
||||
x = np.array(x, dtype=ml_dtypes.bfloat16)
|
||||
# h5py Datasets do not support converting on the fly for many dtypes.
|
||||
# Instead we convert "as is" and cast.
|
||||
return np.array(x).astype(dtype)
|
||||
return np.array(x, dtype=dtype)
|
||||
|
||||
|
||||
|
@ -1,5 +1,7 @@
|
||||
import types
|
||||
|
||||
import h5py
|
||||
import ml_dtypes
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice
|
||||
@ -107,6 +109,10 @@ def convert_to_tensor(x, dtype=None, sparse=None):
|
||||
if dtype is not None:
|
||||
dtype = standardize_dtype(dtype)
|
||||
if not tf.is_tensor(x):
|
||||
# h5py will handle bfloat16 as an opaque dtype.
|
||||
# We assume any two byte void dtypes are in fact bfloat16 type.
|
||||
if isinstance(x, h5py.Dataset) and x.dtype == np.dtype((np.void, 2)):
|
||||
x = np.array(x, dtype=ml_dtypes.bfloat16)
|
||||
if dtype == "bool":
|
||||
# TensorFlow boolean conversion is stricter than other backends.
|
||||
# It does not allow ints. We convert without dtype and cast instead.
|
||||
|
@ -1,6 +1,8 @@
|
||||
import contextlib
|
||||
import os
|
||||
|
||||
import h5py
|
||||
import ml_dtypes
|
||||
import numpy as np
|
||||
import torch
|
||||
import tree
|
||||
@ -179,6 +181,10 @@ def convert_to_tensor(x, dtype=None, sparse=None):
|
||||
x, dtype=to_torch_dtype(floatx()), device=get_device()
|
||||
)
|
||||
|
||||
# h5py will handle bfloat16 as an opaque dtype.
|
||||
# We assume any two byte void dtypes are in fact bfloat16 type.
|
||||
if isinstance(x, h5py.Dataset) and x.dtype == np.dtype((np.void, 2)):
|
||||
x = np.array(x, dtype=ml_dtypes.bfloat16)
|
||||
# Convert to np in case of any array-like that is not list or tuple.
|
||||
if not isinstance(x, (list, tuple)):
|
||||
x = np.array(x)
|
||||
@ -210,6 +216,12 @@ def convert_to_numpy(x):
|
||||
# Tensor has to be moved to CPU before converting to numpy.
|
||||
if x.is_cuda or x.is_mps:
|
||||
x = x.cpu()
|
||||
if x.dtype == torch.bfloat16:
|
||||
# Attempting to call .numpy() on a bfloat16 torch tensor leads
|
||||
# to an immediate error. Instead we upcast to float32 and then
|
||||
# convert to the numpy friendly bfloat16 type.
|
||||
# https://github.com/pytorch/pytorch/issues/90574
|
||||
return np.array(x.to(torch.float32)).astype(ml_dtypes.bfloat16)
|
||||
return np.array(x)
|
||||
|
||||
if isinstance(x, (list, tuple)):
|
||||
|
@ -188,7 +188,7 @@ class Dense(Layer):
|
||||
kernel_value = ops.convert_to_numpy(self.kernel)
|
||||
store["0"] = kernel_value
|
||||
if self.use_bias:
|
||||
store["1"] = self.bias.numpy()
|
||||
store["1"] = ops.convert_to_numpy(self.bias)
|
||||
|
||||
def load_own_variables(self, store):
|
||||
if not self.lora_enabled:
|
||||
|
@ -265,7 +265,7 @@ class EinsumDense(Layer):
|
||||
kernel_value = ops.convert_to_numpy(self.kernel)
|
||||
store["0"] = kernel_value
|
||||
if self.bias is not None:
|
||||
store["1"] = self.bias.numpy()
|
||||
store["1"] = ops.convert_to_numpy(self.bias)
|
||||
|
||||
def load_own_variables(self, store):
|
||||
if not self.lora_enabled:
|
||||
|
@ -26,6 +26,7 @@ from keras import backend
|
||||
from keras import constraints
|
||||
from keras import initializers
|
||||
from keras import mixed_precision
|
||||
from keras import ops
|
||||
from keras import regularizers
|
||||
from keras import utils
|
||||
from keras.api_export import keras_export
|
||||
@ -715,7 +716,7 @@ class Layer(BackendLayer, Operation):
|
||||
x.dtype = self.input_dtype
|
||||
return x
|
||||
elif hasattr(x, "__array__"):
|
||||
return backend.convert_to_tensor(x, dtype=self.input_dtype)
|
||||
return ops.convert_to_tensor(x, dtype=self.input_dtype)
|
||||
return x
|
||||
|
||||
# Used to avoid expensive `tree` operations in the most common case.
|
||||
@ -1119,7 +1120,7 @@ class Layer(BackendLayer, Operation):
|
||||
"""
|
||||
all_vars = self._trainable_variables + self._non_trainable_variables
|
||||
for i, v in enumerate(all_vars):
|
||||
store[f"{i}"] = v.numpy()
|
||||
store[f"{i}"] = ops.convert_to_numpy(v)
|
||||
|
||||
def load_own_variables(self, store):
|
||||
"""Loads the state of the layer.
|
||||
|
@ -3,11 +3,13 @@ import unittest.mock as mock
|
||||
|
||||
import numpy as np
|
||||
from absl import logging
|
||||
from absl.testing import parameterized
|
||||
|
||||
from keras import layers
|
||||
from keras.models import Sequential
|
||||
from keras.saving import saving_api
|
||||
from keras.testing import test_case
|
||||
from keras.testing.test_utils import named_product
|
||||
|
||||
|
||||
class SaveModelTests(test_case.TestCase):
|
||||
@ -63,23 +65,32 @@ class SaveModelTests(test_case.TestCase):
|
||||
saving_api.save_model(model, "model.png")
|
||||
|
||||
|
||||
class LoadModelTests(test_case.TestCase):
|
||||
def get_model(self):
|
||||
class LoadModelTests(test_case.TestCase, parameterized.TestCase):
|
||||
def get_model(self, dtype=None):
|
||||
return Sequential(
|
||||
[
|
||||
layers.Dense(5, input_shape=(3,)),
|
||||
layers.Dense(5, input_shape=(3,), dtype=dtype),
|
||||
layers.Softmax(),
|
||||
]
|
||||
)
|
||||
|
||||
def test_basic_load(self):
|
||||
@parameterized.named_parameters(
|
||||
[
|
||||
{"testcase_name": "bfloat16", "dtype": "bfloat16"},
|
||||
{"testcase_name": "float16", "dtype": "float16"},
|
||||
{"testcase_name": "float32", "dtype": "float32"},
|
||||
{"testcase_name": "float64", "dtype": "float64"},
|
||||
]
|
||||
)
|
||||
def test_basic_load(self, dtype):
|
||||
"""Test basic model loading."""
|
||||
model = self.get_model()
|
||||
model = self.get_model(dtype)
|
||||
filepath = os.path.join(self.get_temp_dir(), "test_model.keras")
|
||||
saving_api.save_model(model, filepath)
|
||||
|
||||
loaded_model = saving_api.load_model(filepath)
|
||||
x = np.random.uniform(size=(10, 3))
|
||||
self.assertEqual(loaded_model.weights[0].dtype, dtype)
|
||||
self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x)))
|
||||
|
||||
def test_load_unsupported_format(self):
|
||||
@ -119,25 +130,37 @@ class LoadModelTests(test_case.TestCase):
|
||||
os.remove(filepath)
|
||||
|
||||
|
||||
class LoadWeightsTests(test_case.TestCase):
|
||||
def get_model(self):
|
||||
class LoadWeightsTests(test_case.TestCase, parameterized.TestCase):
|
||||
def get_model(self, dtype=None):
|
||||
return Sequential(
|
||||
[
|
||||
layers.Dense(5, input_shape=(3,)),
|
||||
layers.Dense(5, input_shape=(3,), dtype=dtype),
|
||||
layers.Softmax(),
|
||||
]
|
||||
)
|
||||
|
||||
def test_load_keras_weights(self):
|
||||
@parameterized.named_parameters(
|
||||
named_product(
|
||||
source_dtype=["float64", "float32", "float16", "bfloat16"],
|
||||
dest_dtype=["float64", "float32", "float16", "bfloat16"],
|
||||
)
|
||||
)
|
||||
def test_load_keras_weights(self, source_dtype, dest_dtype):
|
||||
"""Test loading keras weights."""
|
||||
model = self.get_model()
|
||||
src_model = self.get_model(dtype=source_dtype)
|
||||
filepath = os.path.join(self.get_temp_dir(), "test_weights.weights.h5")
|
||||
model.save_weights(filepath)
|
||||
original_weights = model.get_weights()
|
||||
model.load_weights(filepath)
|
||||
loaded_weights = model.get_weights()
|
||||
for orig, loaded in zip(original_weights, loaded_weights):
|
||||
self.assertTrue(np.array_equal(orig, loaded))
|
||||
src_model.save_weights(filepath)
|
||||
src_weights = src_model.get_weights()
|
||||
dest_model = self.get_model(dtype=dest_dtype)
|
||||
dest_model.load_weights(filepath)
|
||||
dest_weights = dest_model.get_weights()
|
||||
for orig, loaded in zip(src_weights, dest_weights):
|
||||
self.assertAllClose(
|
||||
orig.astype("float32"),
|
||||
loaded.astype("float32"),
|
||||
atol=0.001,
|
||||
rtol=0.01,
|
||||
)
|
||||
|
||||
def test_load_h5_weights_by_name(self):
|
||||
"""Test loading h5 weights by name."""
|
||||
|
@ -9,10 +9,11 @@ pandas
|
||||
absl-py
|
||||
requests
|
||||
h5py
|
||||
ml-dtypes
|
||||
protobuf
|
||||
google
|
||||
tensorboard-plugin-profile
|
||||
rich
|
||||
build
|
||||
dm-tree
|
||||
pytest-cov
|
||||
pytest-cov
|
||||
|
1
setup.py
1
setup.py
@ -45,6 +45,7 @@ setup(
|
||||
"namex",
|
||||
"h5py",
|
||||
"dm-tree",
|
||||
"ml-dtypes",
|
||||
],
|
||||
# Supported Python versions
|
||||
python_requires=">=3.9",
|
||||
|
Loading…
Reference in New Issue
Block a user