未验证 提交 cbfd15f9 编写于 作者: G gongweibao 提交者: GitHub

Fix debugger bugs. (#9025)

上级 e13aec60
...@@ -16,6 +16,7 @@ import sys ...@@ -16,6 +16,7 @@ import sys
import re import re
from graphviz import GraphPreviewGenerator from graphviz import GraphPreviewGenerator
import proto.framework_pb2 as framework_pb2 import proto.framework_pb2 as framework_pb2
import paddle.fluid.core as core
_vartype2str_ = [ _vartype2str_ = [
"UNK", "UNK",
...@@ -52,9 +53,11 @@ reprtpl = "{ttype} {name} ({reprs})" ...@@ -52,9 +53,11 @@ reprtpl = "{ttype} {name} ({reprs})"
def repr_lodtensor(proto): def repr_lodtensor(proto):
if not proto.lod_tensor: return if proto.type.type != framework_pb2.VarType.LOD_TENSOR:
level = proto.lod_tensor.lod_level return
reprs = repr_tensor(proto.lod_tensor.tensor)
level = proto.type.lod_tensor.lod_level
reprs = repr_tensor(proto.type.lod_tensor.tensor)
return reprtpl.format( return reprtpl.format(
ttype="LoDTensor" if level > 0 else "Tensor", ttype="LoDTensor" if level > 0 else "Tensor",
name=proto.name, name=proto.name,
...@@ -62,20 +65,24 @@ def repr_lodtensor(proto): ...@@ -62,20 +65,24 @@ def repr_lodtensor(proto):
def repr_selected_rows(proto): def repr_selected_rows(proto):
if not proto.selected_rows: return if proto.type.type != framework_pb2.VarType.SELECTED_ROWS:
return
return reprtpl.format( return reprtpl.format(
ttype="SelectedRows", ttype="SelectedRows",
name=proto.name, name=proto.name,
reprs=repr_tensor(proto.selected_rows)) reprs=repr_tensor(proto.type.selected_rows))
def repr_tensor_array(proto): def repr_tensor_array(proto):
if not proto.tensor_array: return if proto.type.type != framework_pb2.VarType.LOD_TENSOR_ARRAY:
return
return reprtpl.format( return reprtpl.format(
ttype="TensorArray", ttype="TensorArray",
name=proto.name, name=proto.name,
reprs="level=%d, %s" % (proto.tensor_array.lod_level, reprs="level=%d, %s" % (proto.type.tensor_array.lod_level,
repr_tensor(proto.lod_tensor))) repr_tensor(proto.type.lod_tensor.tensor)))
type_handlers = [ type_handlers = [
...@@ -119,6 +126,7 @@ def pprint_block_codes(block_desc, show_backward=False): ...@@ -119,6 +126,7 @@ def pprint_block_codes(block_desc, show_backward=False):
def is_var_backward(var_desc): def is_var_backward(var_desc):
return "@GRAD" in var_desc.name return "@GRAD" in var_desc.name
#print(type(block_desc))
if type(block_desc) is not framework_pb2.BlockDesc: if type(block_desc) is not framework_pb2.BlockDesc:
block_desc = framework_pb2.BlockDesc.FromString( block_desc = framework_pb2.BlockDesc.FromString(
block_desc.serialize_to_string()) block_desc.serialize_to_string())
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import debuger
from paddle.fluid.framework import Program
class TestDebugger(unittest.TestCase):
def test_debug_str(self):
p = Program()
b = p.current_block()
#selected_rows
b.create_var(
name='selected_rows',
dtype="float32",
shape=[5, 10],
type=core.VarDesc.VarType.SELECTED_ROWS)
#tensor array
b.create_var(
name='tensor_array',
shape=[5, 10],
type=core.VarDesc.VarType.LOD_TENSOR_ARRAY)
#operator
mul_x = b.create_parameter(
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
mul_y = b.create_var(
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = b.create_var(
dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
b.append_op(
type="mul",
inputs={"X": mul_x,
"Y": mul_y},
outputs={"Out": mul_out},
attrs={"x_num_col_dims": 1})
print(debuger.pprint_program_codes(p.desc))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册