提交 3026aba7 编写于 作者: M minqiyang

Fix net_drawer

test=develop
上级 e2130502
...@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__) ...@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
try: try:
from .graphviz import Digraph from .graphviz import Graph
except ImportError: except ImportError:
logger.info( logger.info(
'Cannot import graphviz, which is required for drawing a network. This ' 'Cannot import graphviz, which is required for drawing a network. This '
...@@ -112,7 +112,7 @@ def draw_graph(startup_program, main_program, **kwargs): ...@@ -112,7 +112,7 @@ def draw_graph(startup_program, main_program, **kwargs):
filename = kwargs.get("filename") filename = kwargs.get("filename")
if filename == None: if filename == None:
filename = str(graph_id) + ".gv" filename = str(graph_id) + ".gv"
g = Digraph( g = Graph(
name=str(graph_id), name=str(graph_id),
filename=filename, filename=filename,
graph_attr=GRAPH_STYLE, graph_attr=GRAPH_STYLE,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册