"""Utilities related to model visualization.""" import os import sys from keras_core.api_export import keras_core_export from keras_core.operations.function import make_node_key from keras_core.utils import io_utils try: # pydot-ng is a fork of pydot that is better maintained. import pydot_ng as pydot except ImportError: # pydotplus is an improved version of pydot try: import pydotplus as pydot except ImportError: # Fall back on pydot if necessary. try: import pydot except ImportError: pydot = None def check_pydot(): """Returns True if PyDot is available.""" return pydot is not None def check_graphviz(): """Returns True if both PyDot and Graphviz are available.""" if not check_pydot(): return False try: # Attempt to create an image of a blank graph # to check the pydot/graphviz installation. pydot.Dot.create(pydot.Dot()) return True except (OSError, pydot.InvocationException): return False def add_edge(dot, src, dst): if not dot.get_edge(src, dst): edge = pydot.Edge(src, dst) edge.set("penwidth", "2") dot.add_edge(edge) def get_layer_activation_name(layer): if hasattr(layer.activation, "name"): activation_name = layer.activation.name elif hasattr(layer.activation, "__name__"): activation_name = layer.activation.__name__ else: activation_name = str(layer.activation) return activation_name def make_layer_label(layer, **kwargs): class_name = layer.__class__.__name__ show_layer_names = kwargs.pop("show_layer_names") show_layer_activations = kwargs.pop("show_layer_activations") show_dtype = kwargs.pop("show_dtype") show_shapes = kwargs.pop("show_shapes") show_trainable = kwargs.pop("show_trainable") if kwargs: raise ValueError(f"Invalid kwargs: {kwargs}") table = ( '<' ) colspan = max( 1, sum(int(x) for x in (show_dtype, show_shapes, show_trainable)) ) if show_layer_names: table += ( f'" ) else: table += ( f'" ) if ( show_layer_activations and hasattr(layer, "activation") and layer.activation is not None ): table += ( f'" ) cols = [] if show_shapes: shape = None try: shape = layer.output.shape except ValueError: pass cols.append( ( '" ) ) if show_dtype: dtype = None try: dtype = layer.output.dtype except ValueError: pass cols.append( ( '" ) ) if show_trainable and hasattr(layer, "trainable") and layer.weights: if layer.trainable: cols.append( ( '" ) ) else: cols.append( ( '" ) ) if cols: colspan = len(cols) else: colspan = 1 if cols: table += "" + "".join(cols) + "" table += "
' '' f"{layer.name} ({class_name})" "
' '' f"{class_name}" "
' '' f"Activation: {get_layer_activation_name(layer)}" "
' f'Output shape: {shape or "?"}' "' f'Output dtype: {dtype or "?"}' "' '' "Trainable' '' "Non-trainable
>" return table def make_node(layer, **kwargs): node = pydot.Node(str(id(layer)), label=make_layer_label(layer, **kwargs)) node.set("fontname", "Helvetica") node.set("border", "0") node.set("margin", "0") return node @keras_core_export("keras_core.utils.model_to_dot") def model_to_dot( model, show_shapes=False, show_dtype=False, show_layer_names=True, rankdir="TB", expand_nested=False, dpi=200, subgraph=False, show_layer_activations=False, show_trainable=False, **kwargs, ): """Convert a Keras model to dot format. Args: model: A Keras model instance. show_shapes: whether to display shape information. show_dtype: whether to display layer dtypes. show_layer_names: whether to display layer names. rankdir: `rankdir` argument passed to PyDot, a string specifying the format of the plot: `"TB"` creates a vertical plot; `"LR"` creates a horizontal plot. expand_nested: whether to expand nested Functional models into clusters. dpi: Image resolution in dots per inch. subgraph: whether to return a `pydot.Cluster` instance. show_layer_activations: Display layer activations (only for layers that have an `activation` property). show_trainable: whether to display if a layer is trainable. Returns: A `pydot.Dot` instance representing the Keras model or a `pydot.Cluster` instance representing nested model if `subgraph=True`. """ if not model.built: raise ValueError( "This model has not yet been built. " "Build the model first by calling `build()` or by calling " "the model on a batch of data." ) from keras_core.models import functional from keras_core.models import sequential # from keras_core.layers import Wrapper if not check_pydot(): raise ImportError( "You must install pydot (`pip install pydot`) for " "model_to_dot to work." ) if subgraph: dot = pydot.Cluster(style="dashed", graph_name=model.name) dot.set("label", model.name) dot.set("labeljust", "l") else: dot = pydot.Dot() dot.set("rankdir", rankdir) dot.set("concentrate", True) dot.set("dpi", dpi) dot.set("splines", "ortho") dot.set_node_defaults(shape="record") if kwargs.pop("layer_range", None) is not None: raise ValueError("Argument `layer_range` is no longer supported.") if kwargs: raise ValueError(f"Unrecognized keyword arguments: {kwargs}") kwargs = { "show_layer_names": show_layer_names, "show_layer_activations": show_layer_activations, "show_dtype": show_dtype, "show_shapes": show_shapes, "show_trainable": show_trainable, } if isinstance(model, sequential.Sequential): # TODO layers = model.layers elif not isinstance(model, functional.Functional): # We treat subclassed models as a single node. node = make_node(model, **kwargs) dot.add_node(node) return dot else: layers = model._operations # Create graph nodes. sub_n_first_node = {} sub_n_last_node = {} for i, layer in enumerate(layers): # Process nested functional models. if expand_nested and isinstance(layer, functional.Functional): submodel = model_to_dot( layer, show_shapes, show_dtype, show_layer_names, rankdir, expand_nested, subgraph=True, show_layer_activations=show_layer_activations, show_trainable=show_trainable, ) # sub_n : submodel sub_n_nodes = submodel.get_nodes() sub_n_first_node[layer.name] = sub_n_nodes[0] sub_n_last_node[layer.name] = sub_n_nodes[-1] dot.add_subgraph(submodel) else: node = make_node(layer, **kwargs) dot.add_node(node) # Connect nodes with edges. # Sequential case. if isinstance(model, sequential.Sequential): for i in range(len(layers) - 1): inbound_layer_id = str(id(layers[i])) layer_id = str(id(layers[i + 1])) add_edge(dot, inbound_layer_id, layer_id) return dot # Functional case. for i, layer in enumerate(layers): layer_id = str(id(layer)) for i, node in enumerate(layer._inbound_nodes): node_key = make_node_key(layer, i) if node_key in model._nodes: for parent_node in node.parent_nodes: inbound_layer = parent_node.operation inbound_layer_id = str(id(inbound_layer)) if not expand_nested: assert dot.get_node(inbound_layer_id) assert dot.get_node(layer_id) add_edge(dot, inbound_layer_id, layer_id) else: # if inbound_layer is not Functional if not isinstance(inbound_layer, functional.Functional): # if current layer is not Functional if not isinstance(layer, functional.Functional): assert dot.get_node(inbound_layer_id) assert dot.get_node(layer_id) add_edge(dot, inbound_layer_id, layer_id) # if current layer is Functional elif isinstance(layer, functional.Functional): add_edge( dot, inbound_layer_id, sub_n_first_node[layer.name].get_name(), ) # if inbound_layer is Functional elif isinstance(inbound_layer, functional.Functional): name = sub_n_last_node[ inbound_layer.name ].get_name() if isinstance(layer, functional.Functional): output_name = sub_n_first_node[ layer.name ].get_name() add_edge(dot, name, output_name) else: add_edge(dot, name, layer_id) return dot @keras_core_export("keras_core.utils.plot_model") def plot_model( model, to_file="model.png", show_shapes=False, show_dtype=False, show_layer_names=False, rankdir="TB", expand_nested=False, dpi=200, show_layer_activations=False, show_trainable=False, **kwargs, ): """Converts a Keras model to dot format and save to a file. Example: ```python inputs = ... outputs = ... model = keras_core.Model(inputs=inputs, outputs=outputs) dot_img_file = '/tmp/model_1.png' keras_core.utils.plot_model(model, to_file=dot_img_file, show_shapes=True) ``` Args: model: A Keras model instance to_file: File name of the plot image. show_shapes: whether to display shape information. show_dtype: whether to display layer dtypes. show_layer_names: whether to display layer names. rankdir: `rankdir` argument passed to PyDot, a string specifying the format of the plot: `"TB"` creates a vertical plot; `"LR"` creates a horizontal plot. expand_nested: whether to expand nested Functional models into clusters. dpi: Image resolution in dots per inch. show_layer_activations: Display layer activations (only for layers that have an `activation` property). show_trainable: whether to display if a layer is trainable. Returns: A Jupyter notebook Image object if Jupyter is installed. This enables in-line display of the model plots in notebooks. """ if not model.built: raise ValueError( "This model has not yet been built. " "Build the model first by calling `build()` or by calling " "the model on a batch of data." ) if not check_pydot(): message = ( "You must install pydot (`pip install pydot`) " "for `plot_model` to work." ) if "IPython.core.magics.namespace" in sys.modules: # We don't raise an exception here in order to avoid crashing # notebook tests where graphviz is not available. io_utils.print_msg(message) return else: raise ImportError(message) if not check_graphviz(): message = ( "You must install graphviz " "(see instructions at https://graphviz.gitlab.io/download/) " "for `plot_model` to work." ) if "IPython.core.magics.namespace" in sys.modules: # We don't raise an exception here in order to avoid crashing # notebook tests where graphviz is not available. io_utils.print_msg(message) return else: raise ImportError(message) if kwargs.pop("layer_range", None) is not None: raise ValueError("Argument `layer_range` is no longer supported.") if kwargs: raise ValueError(f"Unrecognized keyword arguments: {kwargs}") dot = model_to_dot( model, show_shapes=show_shapes, show_dtype=show_dtype, show_layer_names=show_layer_names, rankdir=rankdir, expand_nested=expand_nested, dpi=dpi, show_layer_activations=show_layer_activations, show_trainable=show_trainable, ) to_file = str(to_file) if dot is None: return _, extension = os.path.splitext(to_file) if not extension: extension = "png" else: extension = extension[1:] # Save image to disk. dot.write(to_file, format=extension) # Return the image as a Jupyter Image object, to be displayed in-line. # Note that we cannot easily detect whether the code is running in a # notebook, and thus we always return the Image if Jupyter is available. if extension != "pdf": try: from IPython import display return display.Image(filename=to_file) except ImportError: pass