diff --git a/python/paddle/fluid/net_drawer.py b/python/paddle/fluid/net_drawer.py index 0b61c23d07e95acf7b4564753f748e7fb497e73e..8485d7d32fed8554c6d9afd610db230f52497da1 100644 --- a/python/paddle/fluid/net_drawer.py +++ b/python/paddle/fluid/net_drawer.py @@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) try: - from .graphviz import Digraph + from .graphviz import Graph except ImportError: logger.info( 'Cannot import graphviz, which is required for drawing a network. This ' @@ -112,7 +112,7 @@ def draw_graph(startup_program, main_program, **kwargs): filename = kwargs.get("filename") if filename == None: filename = str(graph_id) + ".gv" - g = Digraph( + g = Graph( name=str(graph_id), filename=filename, graph_attr=GRAPH_STYLE,