未验证 提交 c1b6692f 编写于 作者: G gongweibao 提交者: GitHub

Fix debuger bugs. (#9705)

Fix debuger bugs
上级 be853853
...@@ -16,6 +16,7 @@ import sys ...@@ -16,6 +16,7 @@ import sys
import re import re
from graphviz import GraphPreviewGenerator from graphviz import GraphPreviewGenerator
import proto.framework_pb2 as framework_pb2 import proto.framework_pb2 as framework_pb2
from google.protobuf import text_format
_vartype2str_ = [ _vartype2str_ = [
"UNK", "UNK",
...@@ -100,7 +101,7 @@ def repr_var(vardesc): ...@@ -100,7 +101,7 @@ def repr_var(vardesc):
def pprint_program_codes(program_desc): def pprint_program_codes(program_desc):
reprs = [] reprs = []
for block_idx in range(program_desc.num_blocks()): for block_idx in range(program_desc.desc.num_blocks()):
block_desc = program_desc.block(block_idx) block_desc = program_desc.block(block_idx)
block_repr = pprint_block_codes(block_desc) block_repr = pprint_block_codes(block_desc)
reprs.append(block_repr) reprs.append(block_repr)
...@@ -127,7 +128,7 @@ def pprint_block_codes(block_desc, show_backward=False): ...@@ -127,7 +128,7 @@ def pprint_block_codes(block_desc, show_backward=False):
if type(block_desc) is not framework_pb2.BlockDesc: if type(block_desc) is not framework_pb2.BlockDesc:
block_desc = framework_pb2.BlockDesc.FromString( block_desc = framework_pb2.BlockDesc.FromString(
block_desc.serialize_to_string()) block_desc.desc.serialize_to_string())
var_reprs = [] var_reprs = []
op_reprs = [] op_reprs = []
for var in block_desc.vars: for var in block_desc.vars:
...@@ -237,13 +238,13 @@ def draw_block_graphviz(block, highlights=None, path="./temp.dot"): ...@@ -237,13 +238,13 @@ def draw_block_graphviz(block, highlights=None, path="./temp.dot"):
# draw parameters and args # draw parameters and args
vars = {} vars = {}
for var in desc.vars: for var in desc.vars:
shape = [str(i) for i in var.lod_tensor.tensor.dims] # TODO(gongwb): format the var.type
if not shape:
shape = ['null']
# create var # create var
if var.persistable: if var.persistable:
varn = graph.add_param( varn = graph.add_param(
var.name, var.type, shape, highlight=need_highlight(var.name)) var.name,
str(var.type).replace("\n", "<br />", 1),
highlight=need_highlight(var.name))
else: else:
varn = graph.add_arg(var.name, highlight=need_highlight(var.name)) varn = graph.add_arg(var.name, highlight=need_highlight(var.name))
vars[var.name] = varn vars[var.name] = varn
...@@ -268,4 +269,4 @@ def draw_block_graphviz(block, highlights=None, path="./temp.dot"): ...@@ -268,4 +269,4 @@ def draw_block_graphviz(block, highlights=None, path="./temp.dot"):
for var in op.outputs: for var in op.outputs:
add_op_link_var(opn, var, True) add_op_link_var(opn, var, True)
graph(path, show=True) graph(path, show=False)
...@@ -83,7 +83,7 @@ class Graph(object): ...@@ -83,7 +83,7 @@ class Graph(object):
file = open(dot_path, 'w') file = open(dot_path, 'w')
file.write(self.__str__()) file.write(self.__str__())
image_path = os.path.join( image_path = os.path.join(
os.path.dirname(__file__), dot_path[:-3] + "pdf") os.path.dirname(dot_path), dot_path[:-3] + "pdf")
cmd = ["dot", "-Tpdf", dot_path, "-o", image_path] cmd = ["dot", "-Tpdf", dot_path, "-o", image_path]
subprocess.Popen( subprocess.Popen(
cmd, cmd,
...@@ -199,7 +199,7 @@ class GraphPreviewGenerator(object): ...@@ -199,7 +199,7 @@ class GraphPreviewGenerator(object):
else: else:
self.graph.show(path) self.graph.show(path)
def add_param(self, name, data_type, shape, highlight=False): def add_param(self, name, data_type, highlight=False):
label = '\n'.join([ label = '\n'.join([
'<<table cellpadding="5">', '<<table cellpadding="5">',
' <tr>', ' <tr>',
...@@ -214,11 +214,6 @@ class GraphPreviewGenerator(object): ...@@ -214,11 +214,6 @@ class GraphPreviewGenerator(object):
str(data_type), str(data_type),
' </td>' ' </td>'
' </tr>', ' </tr>',
' <tr>',
' <td>',
'[%s]' % 'x'.join(shape),
' </td>'
' </tr>',
'</table>>', '</table>>',
]) ])
return self.graph.node( return self.graph.node(
......
...@@ -51,7 +51,9 @@ class TestDebugger(unittest.TestCase): ...@@ -51,7 +51,9 @@ class TestDebugger(unittest.TestCase):
outputs={"Out": mul_out}, outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1}) attrs={"x_num_col_dims": 1})
print(debuger.pprint_program_codes(p.desc)) print(debuger.pprint_program_codes(p))
debuger.draw_block_graphviz(p.block(0), path="./test.dot")
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册