Attempt to add support for saving/loading bfloat16 (#19091)

Currently, attempting to save or load weights in bfloat16 will fail.
There may be better ways to do this, but the approach JAX and TensorFlow
seem to take is to use the ml-dtypes library to make bfloat16 work with
NumPy.
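
For example, a minimal sketch of that approach (an illustration, not
part of this diff): ml-dtypes registers bfloat16 as a real NumPy dtype,
so plain NumPy arrays can hold and cast it.

import ml_dtypes
import numpy as np

# With ml-dtypes installed, bfloat16 behaves like any other NumPy dtype.
x = np.array([1.5, 2.25, 3.0], dtype=ml_dtypes.bfloat16)
assert x.dtype == ml_dtypes.bfloat16
assert x.astype(np.float32)[0] == 1.5  # upcasting back works as usual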

This is further compounded by the h5py format, which saves bfloat16 as
a void type. This implementation currently just assumes that any
two-byte void type is bfloat16, which seems a bit hacky. There is quite
possibly a better way to do this.
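
A hedged sketch of the round trip being described, assuming h5py stores
the custom dtype as an opaque void type as the message says (the file
and dataset names are illustrative):

import h5py
import ml_dtypes
import numpy as np

with h5py.File("weights.h5", "w") as f:
    # Stored as an opaque two-byte void dtype, not as bfloat16.
    f["kernel"] = np.array([0.5, 1.25], dtype=ml_dtypes.bfloat16)

with h5py.File("weights.h5", "r") as f:
    ds = f["kernel"]
    # The two-byte-void heuristic: reinterpret as bfloat16.
    if ds.dtype == np.dtype((np.void, 2)):
        restored = np.array(ds, dtype=ml_dtypes.bfloat16)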
Matt Watson, 2024-01-24 14:38:25 -08:00 (committed by GitHub)
commit 7924aff566, parent bf1f463e0e
10 changed files with 87 additions and 21 deletions

@@ -1,7 +1,9 @@
import types
import h5py
import jax
import jax.numpy as jnp
import ml_dtypes
import numpy as np
import tree
from jax.tree_util import Partial
@@ -61,10 +63,20 @@ def convert_to_tensor(x, dtype=None, sparse=None):
        if dtype and dtype != x.dtype:
            return x.value.astype(dtype)
        return x.value
    if isinstance(x, h5py.Dataset):
        # h5py will handle bfloat16 as an opaque dtype.
        # We assume any two-byte void dtypes are in fact bfloat16.
        if x.dtype == np.dtype((np.void, 2)):
            x = np.array(x, dtype=ml_dtypes.bfloat16)
        # h5py Datasets do not support converting on the fly for many
        # dtypes. Instead we convert "as is" and cast.
        return jnp.asarray(x).astype(dtype)
    return jnp.asarray(x, dtype=dtype)


def convert_to_numpy(x):
    if is_tensor(x) and x.dtype == "bfloat16":
        return np.asarray(x, ml_dtypes.bfloat16)
    return np.asarray(x)
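
For context, a small sketch of why the JAX path can lean on ml-dtypes
directly (an illustration, not part of the diff): JAX itself uses
ml_dtypes, so a bfloat16 array converts to NumPy cleanly.

import jax.numpy as jnp
import ml_dtypes
import numpy as np

x = jnp.ones((2, 2), dtype="bfloat16")
out = np.asarray(x, ml_dtypes.bfloat16)  # no error, dtype preserved
assert out.dtype == ml_dtypes.bfloat16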

@@ -1,3 +1,5 @@
import h5py
import ml_dtypes
import numpy as np
import tree
@@ -39,6 +41,14 @@ def convert_to_tensor(x, dtype=None, sparse=None):
        dtype = result_type(
            *[getattr(item, "dtype", type(item)) for item in tree.flatten(x)]
        )
    if isinstance(x, h5py.Dataset):
        # h5py will handle bfloat16 as an opaque dtype.
        # We assume any two-byte void dtypes are in fact bfloat16.
        if x.dtype == np.dtype((np.void, 2)):
            x = np.array(x, dtype=ml_dtypes.bfloat16)
        # h5py Datasets do not support converting on the fly for many
        # dtypes. Instead we convert "as is" and cast.
        return np.array(x).astype(dtype)
    return np.array(x, dtype=dtype)

@@ -1,5 +1,7 @@
import types
import h5py
import ml_dtypes
import numpy as np
import tensorflow as tf
from tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice
@@ -107,6 +109,10 @@ def convert_to_tensor(x, dtype=None, sparse=None):
    if dtype is not None:
        dtype = standardize_dtype(dtype)
    if not tf.is_tensor(x):
        # h5py will handle bfloat16 as an opaque dtype.
        # We assume any two-byte void dtypes are in fact bfloat16.
        if isinstance(x, h5py.Dataset) and x.dtype == np.dtype((np.void, 2)):
            x = np.array(x, dtype=ml_dtypes.bfloat16)
        if dtype == "bool":
            # TensorFlow boolean conversion is stricter than other backends.
            # It does not allow ints. We convert without dtype and cast instead.

@@ -1,6 +1,8 @@
import contextlib
import os
import h5py
import ml_dtypes
import numpy as np
import torch
import tree
@@ -179,6 +181,10 @@ def convert_to_tensor(x, dtype=None, sparse=None):
            x, dtype=to_torch_dtype(floatx()), device=get_device()
        )
    # h5py will handle bfloat16 as an opaque dtype.
    # We assume any two-byte void dtypes are in fact bfloat16.
    if isinstance(x, h5py.Dataset) and x.dtype == np.dtype((np.void, 2)):
        x = np.array(x, dtype=ml_dtypes.bfloat16)
    # Convert to np in case of any array-like that is not list or tuple.
    if not isinstance(x, (list, tuple)):
        x = np.array(x)
@@ -210,6 +216,12 @@ def convert_to_numpy(x):
            # Tensor has to be moved to CPU before converting to numpy.
            if x.is_cuda or x.is_mps:
                x = x.cpu()
            if x.dtype == torch.bfloat16:
                # Attempting to call .numpy() on a bfloat16 torch tensor
                # leads to an immediate error. Instead we upcast to float32
                # and then convert to the numpy-friendly bfloat16 type.
                # https://github.com/pytorch/pytorch/issues/90574
                return np.array(x.to(torch.float32)).astype(ml_dtypes.bfloat16)
        return np.array(x)

    if isinstance(x, (list, tuple)):
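
A standalone sketch of the PyTorch workaround above (the same idea as
the diff, outside of Keras):

import ml_dtypes
import numpy as np
import torch

t = torch.tensor([1.0, 2.0], dtype=torch.bfloat16)
# t.numpy() raises for bfloat16 tensors
# (https://github.com/pytorch/pytorch/issues/90574), so upcast to
# float32 first, then cast the NumPy result down to bfloat16.
arr = np.array(t.to(torch.float32)).astype(ml_dtypes.bfloat16)
assert arr.dtype == ml_dtypes.bfloat16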

@@ -188,7 +188,7 @@ class Dense(Layer):
        kernel_value = ops.convert_to_numpy(self.kernel)
        store["0"] = kernel_value
        if self.use_bias:
            store["1"] = self.bias.numpy()
            store["1"] = ops.convert_to_numpy(self.bias)

    def load_own_variables(self, store):
        if not self.lora_enabled:

@@ -265,7 +265,7 @@ class EinsumDense(Layer):
        kernel_value = ops.convert_to_numpy(self.kernel)
        store["0"] = kernel_value
        if self.bias is not None:
            store["1"] = self.bias.numpy()
            store["1"] = ops.convert_to_numpy(self.bias)

    def load_own_variables(self, store):
        if not self.lora_enabled:

@@ -26,6 +26,7 @@ from keras import backend
from keras import constraints
from keras import initializers
from keras import mixed_precision
from keras import ops
from keras import regularizers
from keras import utils
from keras.api_export import keras_export
@@ -715,7 +716,7 @@ class Layer(BackendLayer, Operation):
                    x.dtype = self.input_dtype
                return x
            elif hasattr(x, "__array__"):
                return backend.convert_to_tensor(x, dtype=self.input_dtype)
                return ops.convert_to_tensor(x, dtype=self.input_dtype)
            return x

        # Used to avoid expensive `tree` operations in the most common case.
@@ -1119,7 +1120,7 @@ class Layer(BackendLayer, Operation):
        """
        all_vars = self._trainable_variables + self._non_trainable_variables
        for i, v in enumerate(all_vars):
            store[f"{i}"] = v.numpy()
            store[f"{i}"] = ops.convert_to_numpy(v)

    def load_own_variables(self, store):
        """Loads the state of the layer.

@@ -3,11 +3,13 @@ import unittest.mock as mock
import numpy as np
from absl import logging
from absl.testing import parameterized
from keras import layers
from keras.models import Sequential
from keras.saving import saving_api
from keras.testing import test_case
from keras.testing.test_utils import named_product
class SaveModelTests(test_case.TestCase):
@@ -63,23 +65,32 @@ class SaveModelTests(test_case.TestCase):
        saving_api.save_model(model, "model.png")


class LoadModelTests(test_case.TestCase):
    def get_model(self):
class LoadModelTests(test_case.TestCase, parameterized.TestCase):
    def get_model(self, dtype=None):
        return Sequential(
            [
                layers.Dense(5, input_shape=(3,)),
                layers.Dense(5, input_shape=(3,), dtype=dtype),
                layers.Softmax(),
            ]
        )

    def test_basic_load(self):
    @parameterized.named_parameters(
        [
            {"testcase_name": "bfloat16", "dtype": "bfloat16"},
            {"testcase_name": "float16", "dtype": "float16"},
            {"testcase_name": "float32", "dtype": "float32"},
            {"testcase_name": "float64", "dtype": "float64"},
        ]
    )
    def test_basic_load(self, dtype):
        """Test basic model loading."""
        model = self.get_model()
        model = self.get_model(dtype)
        filepath = os.path.join(self.get_temp_dir(), "test_model.keras")
        saving_api.save_model(model, filepath)
        loaded_model = saving_api.load_model(filepath)
        x = np.random.uniform(size=(10, 3))
        self.assertEqual(loaded_model.weights[0].dtype, dtype)
        self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x)))

    def test_load_unsupported_format(self):
@@ -119,25 +130,37 @@ class LoadModelTests(test_case.TestCase):
        os.remove(filepath)


class LoadWeightsTests(test_case.TestCase):
    def get_model(self):
class LoadWeightsTests(test_case.TestCase, parameterized.TestCase):
    def get_model(self, dtype=None):
        return Sequential(
            [
                layers.Dense(5, input_shape=(3,)),
                layers.Dense(5, input_shape=(3,), dtype=dtype),
                layers.Softmax(),
            ]
        )

    def test_load_keras_weights(self):
    @parameterized.named_parameters(
        named_product(
            source_dtype=["float64", "float32", "float16", "bfloat16"],
            dest_dtype=["float64", "float32", "float16", "bfloat16"],
        )
    )
    def test_load_keras_weights(self, source_dtype, dest_dtype):
        """Test loading keras weights."""
        model = self.get_model()
        src_model = self.get_model(dtype=source_dtype)
        filepath = os.path.join(self.get_temp_dir(), "test_weights.weights.h5")
        model.save_weights(filepath)
        original_weights = model.get_weights()
        model.load_weights(filepath)
        loaded_weights = model.get_weights()
        for orig, loaded in zip(original_weights, loaded_weights):
            self.assertTrue(np.array_equal(orig, loaded))
        src_model.save_weights(filepath)
        src_weights = src_model.get_weights()
        dest_model = self.get_model(dtype=dest_dtype)
        dest_model.load_weights(filepath)
        dest_weights = dest_model.get_weights()
        for orig, loaded in zip(src_weights, dest_weights):
            self.assertAllClose(
                orig.astype("float32"),
                loaded.astype("float32"),
                atol=0.001,
                rtol=0.01,
            )

    def test_load_h5_weights_by_name(self):
        """Test loading h5 weights by name."""

@@ -9,10 +9,11 @@ pandas
absl-py
requests
h5py
ml-dtypes
protobuf
google
tensorboard-plugin-profile
rich
build
dm-tree
pytest-cov
pytest-cov

@@ -45,6 +45,7 @@ setup(
        "namex",
        "h5py",
        "dm-tree",
        "ml-dtypes",
    ],
    # Supported Python versions
    python_requires=">=3.9",