keras/integration_tests/import_test.py

118 lines
2.9 KiB
Python
Raw Normal View History

2023-09-25 16:37:22 +00:00
import os
import re
import subprocess
2023-09-25 16:41:01 +00:00
from keras import backend
2023-09-25 16:37:22 +00:00
# Maps each backend name to the space-separated pip requirement(s) that are
# installed for the active backend and uninstalled for the other backends.
BACKEND_REQ = {
    "tensorflow": "tensorflow",
    "torch": "torch torchvision",
    "jax": "jax jaxlib",
}
def setup_package():
    """Build the Keras wheel and return the path of the built `.whl` file.

    Removes any stale `tmp_build_dir`, runs `pip_build.py` with `python3`,
    echoes the build's stdout and stderr, and extracts the first
    whitespace-preceded token ending in `.whl` from the build's stdout.

    Returns:
        The wheel path as a string, without surrounding whitespace.

    Raises:
        ValueError: If no `.whl` path is found in the build output.
    """
    subprocess.run("rm -rf tmp_build_dir", shell=True)
    build_process = subprocess.run(
        "python3 pip_build.py",
        capture_output=True,
        text=True,
        shell=True,
    )
    print(build_process.stdout)
    # Always surface stderr: it matters most when the build failed and no
    # wheel path can be found below.
    print(build_process.stderr)
    match = re.search(
        r"\s[^\s]*\.whl",
        build_process.stdout,
    )
    if not match:
        raise ValueError("Installing Keras package unsuccessful. ")
    # The pattern consumes the whitespace character in front of the path;
    # strip it so the result is usable as a single `pip install` argument.
    return match.group().strip()
def create_virtualenv():
    """Create the `test_env` virtual environment.

    Prepends `/test_env/bin/` to this process's `PATH`, then runs
    `python3 -m venv test_env` through `run_commands_local`.
    """
    inherited_path = os.environ.get("PATH", "")
    # NOTE(review): the PATH entry is the absolute `/test_env/bin/`, while
    # the venv itself is created relative to the working directory --
    # confirm this is intended.
    os.environ["PATH"] = os.pathsep.join(["/test_env/bin/", inherited_path])
    run_commands_local(["python3 -m venv test_env"])
def manage_venv_installs(whl_path):
    """Install the active backend and the Keras wheel into the virtualenv.

    Installs the pip requirements of the backend reported by
    `backend.backend()`, the common requirements and `pytest`, uninstalls
    the requirements of every other backend listed in `BACKEND_REQ`, and
    finally installs the wheel. All commands go through
    `run_commands_venv`.

    Args:
        whl_path: Path to the built Keras `.whl` file.

    Raises:
        KeyError: If the active backend has no entry in `BACKEND_REQ`.
    """
    current_backend = backend.backend()
    # Collect the requirements of *every* other backend rather than
    # indexing exactly two of them; sorted so the command is deterministic.
    other_backend_reqs = [
        BACKEND_REQ[name]
        for name in sorted(BACKEND_REQ)
        if name != current_backend
    ]
    install_setup = [
        # Installs the backend's package and common requirements
        "pip install " + BACKEND_REQ[current_backend],
        "pip install -r requirements-common.txt",
        "pip install pytest",
    ]
    if other_backend_reqs:
        # Ensure other backends are uninstalled. Skipped when there are
        # none, since `pip uninstall` needs at least one requirement.
        install_setup.append(
            "pip uninstall -y " + " ".join(other_backend_reqs)
        )
    # Install `.whl` package. Surrounding whitespace would otherwise turn
    # into an empty argument when the command is split into words.
    install_setup.append("pip install " + whl_path.strip())
    run_commands_venv(install_setup)
2023-09-25 16:41:01 +00:00
def run_keras_flow():
    """Run the basic full-flow integration script with pytest in the venv."""
    # Handed to `run_commands_venv`, which prefixes the executable with
    # `test_env/bin/`.
    pytest_command = "python -m pytest integration_tests/basic_full_flow.py"
    run_commands_venv([pytest_command])
def cleanup():
    """Remove the virtualenv, the build directory and `*+cpu` files.

    Each command is passed to `run_commands_local` and executed from the
    current working directory.
    """
    run_commands_local(
        [
            # NOTE(review): `exit` is run as its own shell command, so it
            # presumably has no effect on this process -- confirm whether
            # it is still needed.
            "exit",
            # Virtual environment and wheel build directory.
            "rm -rf test_env",
            "rm -rf tmp_build_dir",
            # Miscellaneous install logs.
            "rm -f *+cpu",
        ]
    )
def run_commands_local(commands):
    """Run each command string through the shell, one after another.

    Exit statuses are not inspected, so a failing command does not stop
    the remaining ones.

    Args:
        commands: Iterable of shell command strings.
    """
    for shell_command in commands:
        subprocess.run(shell_command, shell=True)
def run_commands_venv(commands):
    """Run each command with its executable taken from `test_env/bin/`.

    Every command string is split on whitespace, its first word is
    prefixed with `test_env/bin/`, and the result is executed directly
    (without a shell) and waited for.

    Args:
        commands: Iterable of command strings, e.g. `"pip install pytest"`.

    Raises:
        subprocess.CalledProcessError: If a command exits with a non-zero
            status.
    """
    for command in commands:
        # `split()` without an argument collapses runs of whitespace, so
        # doubled or leading spaces do not produce empty arguments.
        cmd_with_args = command.split()
        if not cmd_with_args:
            # Nothing to execute for a blank command.
            continue
        cmd_with_args[0] = "test_env/bin/" + cmd_with_args[0]
        # `check=True` makes a failed install or test run fail the caller
        # instead of being silently ignored.
        subprocess.run(cmd_with_args, check=True)
2023-09-25 16:41:01 +00:00
def test_keras_imports():
    """Build, install and exercise the Keras wheel in a fresh virtualenv.

    Builds the wheel, creates the `test_env` virtual environment, installs
    the active backend plus the wheel into it, and runs the basic
    full-flow integration script. The virtual environment and build
    artifacts are removed afterwards, whether or not a step failed.
    """
    try:
        # Builds the Keras package and returns the package file path.
        whl_path = setup_package()
        # Creates the virtual environment.
        create_virtualenv()
        # Ensures the backend's package is installed
        # and the other backends are uninstalled.
        manage_venv_installs(whl_path)
        # Runs the basic-flow test inside the virtual environment.
        run_keras_flow()
    finally:
        # Removes virtual environment and associated files, even when one
        # of the steps above raised.
        cleanup()
# Allow running this integration test directly as a script.
if __name__ == "__main__":
    test_keras_imports()