提交 3026aba7 编写于 作者: M minqiyang

Fix net_drawer

test=develop
上级 e2130502
......@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
try:
from .graphviz import Digraph
from .graphviz import Graph
except ImportError:
logger.info(
'Cannot import graphviz, which is required for drawing a network. This '
......@@ -112,7 +112,7 @@ def draw_graph(startup_program, main_program, **kwargs):
filename = kwargs.get("filename")
if filename == None:
filename = str(graph_id) + ".gv"
g = Digraph(
g = Graph(
name=str(graph_id),
filename=filename,
graph_attr=GRAPH_STYLE,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册