class JaxLayer:
    """Layer class whose build step can be extended through `_post_build`."""

    def _post_build(self):
        """Can be overridden to perform post-build actions.

        The default implementation does nothing and returns ``None``.
        Subclasses may override this hook to add their own post-build logic.
        """