keras/setup.py
Matt Watson 7924aff566
Attempt to add support for saving/loading bfloat16 (#19091)
Currently attempting to save or load weights in bfloat16 will fail.
There may be better ways to do this, but the approach jax and tf seem
to take is to use the ml-dtypes library to allow bfloat16 to work with
numpy.

This is further compounded by the h5py format, which saves bfloat16 as
a void type. This implementation currently just assumes any two-byte
void type is bfloat16, which seems a bit hacky. There is quite possibly
a better way to do this.
2024-01-24 14:38:25 -08:00

67 lines
1.8 KiB
Python

"""Setup script."""
import os
import pathlib
from setuptools import find_packages
from setuptools import setup
def read(rel_path):
    """Return the text content of a file located next to this script.

    Args:
        rel_path: Path of the file to read. It is joined onto the
            directory that contains this setup script.

    Returns:
        The complete content of the file as a string.
    """
    here = os.path.abspath(os.path.dirname(__file__))
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(os.path.join(here, rel_path), encoding="utf-8") as fp:
        return fp.read()
def get_version(rel_path):
    """Extract the version string declared in the file at *rel_path*.

    The file is scanned line by line for the first line that starts with
    ``__version__``; the text between the first pair of quote characters
    on that line is returned.

    Args:
        rel_path: Path of the file to scan, passed through to ``read``.

    Returns:
        The quoted version string found on the ``__version__`` line.

    Raises:
        RuntimeError: If no line starts with ``__version__``.
    """
    for text in read(rel_path).splitlines():
        if not text.startswith("__version__"):
            continue
        # Accept either quoting style; double quotes take precedence
        # when both kinds appear on the line.
        quote = '"' if '"' in text else "'"
        return text.split(quote)[1]
    raise RuntimeError("Unable to find version string.")
# Directory that contains this setup script.
HERE = pathlib.Path(__file__).parent

# Long description for the package index. Decode explicitly as UTF-8 so
# the build does not depend on the platform's default locale encoding.
README = (HERE / "README.md").read_text(encoding="utf-8")

# Prefer the dedicated version module when it exists, otherwise fall back
# to the package's __init__.py. The existence check is anchored at HERE
# (not the current working directory) so it inspects the same location
# that get_version() subsequently reads from.
if (HERE / "keras" / "version.py").exists():
    VERSION = get_version("keras/version.py")
else:
    VERSION = get_version("keras/__init__.py")
setup(
    # Project identity.
    name="keras",
    version=VERSION,
    url="https://github.com/keras-team/keras",
    license="Apache License 2.0",
    # Maintainer contact.
    author="Keras team",
    author_email="keras-users@googlegroups.com",
    # Short summary plus the README rendered as Markdown.
    description="Multi-backend Keras.",
    long_description=README,
    long_description_content_type="text/markdown",
    # Supported Python versions.
    python_requires=">=3.9",
    # Runtime dependencies installed alongside the package.
    install_requires=[
        "absl-py",
        "numpy",
        "rich",
        "namex",
        "h5py",
        "dm-tree",
        "ml-dtypes",
    ],
    # Trove classifiers describing maturity, languages, platforms, audience.
    classifiers=[
        "Development Status :: 4 - Beta",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: 3 :: Only",
        "Operating System :: Unix",
        "Operating System :: MacOS",
        "Intended Audience :: Science/Research",
        "Topic :: Scientific/Engineering",
        "Topic :: Software Development",
    ],
    # NOTE(review): find_packages' `exclude` is matched against dotted
    # package names, so a file-style pattern such as "*_test.py" likely
    # excludes nothing -- confirm whether test modules are meant to ship.
    packages=find_packages(exclude=("*_test.py",)),
)