Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b24f84c8
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b24f84c8
编写于
9月 01, 2021
作者:
0
0x45f
提交者:
GitHub
9月 01, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Dy2stat]modify dy2stat error message in compile time (#35320)
* modify dy2stat error message in compile time * fix variable name
上级
b53887fd
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
79 addition
and
15 deletion
+79
-15
python/paddle/fluid/dygraph/dygraph_to_static/error.py
python/paddle/fluid/dygraph/dygraph_to_static/error.py
+61
-12
python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py
...dle/fluid/tests/unittests/dygraph_to_static/test_error.py
+18
-3
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/error.py
浏览文件 @
b24f84c8
...
...
@@ -16,6 +16,7 @@ import os
import
six
import
sys
import
traceback
import
linecache
from
paddle.fluid.dygraph.dygraph_to_static.origin_info
import
Location
,
OriginInfo
,
global_origin_info_map
...
...
@@ -29,6 +30,9 @@ DEFAULT_SIMPLIFY_NEW_ERROR = 1
DISABLE_ERROR_ENV_NAME
=
"TRANSLATOR_DISABLE_NEW_ERROR"
DEFAULT_DISABLE_NEW_ERROR
=
0
SOURCE_CODE_RANGE
=
5
BLANK_COUNT_BEFORE_FILE_STR
=
4
def
attach_error_data
(
error
,
in_runtime
=
False
):
"""
...
...
@@ -40,6 +44,7 @@ def attach_error_data(error, in_runtime=False):
Returns:
An error attached data about original source code information and traceback.
"""
e_type
,
e_value
,
e_traceback
=
sys
.
exc_info
()
tb
=
traceback
.
extract_tb
(
e_traceback
)[
1
:]
...
...
@@ -82,12 +87,49 @@ class TraceBackFrame(OriginInfo):
def
formated_message
(
self
):
# self.source_code may be empty in some functions.
# For example, decorator generated function
return
'
File "{}", line {}, in {}
\n\t
{}'
.
format
(
return
'
'
*
BLANK_COUNT_BEFORE_FILE_STR
+
'
File "{}", line {}, in {}
\n\t
{}'
.
format
(
self
.
location
.
filepath
,
self
.
location
.
lineno
,
self
.
function_name
,
self
.
source_code
.
lstrip
()
if
isinstance
(
self
.
source_code
,
str
)
else
self
.
source_code
)
class
TraceBackFrameRange
(
OriginInfo
):
"""
Traceback frame information.
"""
def
__init__
(
self
,
location
,
function_name
):
self
.
location
=
location
self
.
function_name
=
function_name
self
.
source_code
=
[]
blank_count
=
[]
begin_lineno
=
max
(
1
,
self
.
location
.
lineno
-
int
(
SOURCE_CODE_RANGE
/
2
))
for
i
in
range
(
begin_lineno
,
begin_lineno
+
SOURCE_CODE_RANGE
):
line
=
linecache
.
getline
(
self
.
location
.
filepath
,
i
)
line_lstrip
=
line
.
strip
()
self
.
source_code
.
append
(
line_lstrip
)
blank_count
.
append
(
len
(
line
)
-
len
(
line_lstrip
))
if
i
==
self
.
location
.
lineno
:
hint_msg
=
'~'
*
len
(
self
.
source_code
[
-
1
])
+
' <--- HERE'
self
.
source_code
.
append
(
hint_msg
)
blank_count
.
append
(
blank_count
[
-
1
])
linecache
.
clearcache
()
min_black_count
=
min
(
blank_count
)
for
i
in
range
(
len
(
self
.
source_code
)):
self
.
source_code
[
i
]
=
' '
*
(
blank_count
[
i
]
-
min_black_count
+
BLANK_COUNT_BEFORE_FILE_STR
*
2
)
+
self
.
source_code
[
i
]
def
formated_message
(
self
):
msg
=
' '
*
BLANK_COUNT_BEFORE_FILE_STR
+
'File "{}", line {}, in {}
\n
'
.
format
(
self
.
location
.
filepath
,
self
.
location
.
lineno
,
self
.
function_name
)
# add empty line after range code
return
msg
+
'
\n
'
.
join
(
self
.
source_code
)
+
'
\n
'
class
ErrorData
(
object
):
"""
Error data attached to an exception which is raised in un-transformed code.
...
...
@@ -128,26 +170,34 @@ class ErrorData(object):
return
'
\n
'
.
join
(
message_lines
)
# Step2: Optimizes stack information with source code information of dygraph from user.
for
filepath
,
lineno
,
funcname
,
code
in
self
.
origin_traceback
:
whether_source_range
=
True
for
filepath
,
lineno
,
funcname
,
code
in
self
.
origin_traceback
[::
-
1
]:
loc
=
Location
(
filepath
,
lineno
)
dygraph_func_info
=
self
.
origin_info_map
.
get
(
loc
.
line_location
,
None
)
if
dygraph_func_info
:
# TODO(liym27): more information to prompt users that this is the original information.
# Replaces trace stack information about transformed static code with original dygraph code.
traceback_frame
=
self
.
origin_info_map
[
loc
.
line_location
]
else
:
traceback_frame
=
TraceBackFrame
(
loc
,
funcname
,
code
)
message_lines
.
append
(
traceback_frame
.
formated_message
())
if
whether_source_range
:
traceback_frame
=
TraceBackFrameRange
(
dygraph_func_info
.
location
,
dygraph_func_info
.
function_name
)
whether_source_range
=
False
else
:
traceback_frame
=
TraceBackFrame
(
dygraph_func_info
.
location
,
dygraph_func_info
.
function_name
,
dygraph_func_info
.
source_code
)
# Two elements already exist in message_lines: "In transformed code:" and "", so insert in index 2
message_lines
.
insert
(
2
,
traceback_frame
.
formated_message
())
# Step3: Adds error message like "TypeError: dtype must be int32, but received float32".
# NOTE: `format_exception` is a list, its length is 1 in most cases, but sometimes its length
# is gather than 1, for example, the error_type is IndentationError.
format_exception
=
traceback
.
format_exception_only
(
self
.
error_type
,
self
.
error_value
)
error_message
=
[
" "
*
4
+
line
for
line
in
format_exception
]
error_message
=
[
" "
*
BLANK_COUNT_BEFORE_FILE_STR
+
line
for
line
in
format_exception
]
message_lines
.
extend
(
error_message
)
return
'
\n
'
.
join
(
message_lines
)
...
...
@@ -175,7 +225,6 @@ class ErrorData(object):
self
.
error_value
=
self
.
error_type
(
error_value_str
)
def
raise_new_exception
(
self
):
# Raises the origin error if disable dygraph2static error module,
if
int
(
os
.
getenv
(
DISABLE_ERROR_ENV_NAME
,
DEFAULT_DISABLE_NEW_ERROR
)):
raise
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py
浏览文件 @
b24f84c8
...
...
@@ -218,7 +218,10 @@ class TestErrorStaticLayerCallInCompiletime(TestErrorBase):
[
'File "{}", line 35, in func_error_in_compile_time'
.
format
(
self
.
filepath
),
'inner_func()'
,
'File "{}", line 28, in inner_func'
.
format
(
self
.
filepath
),
'def inner_func():'
,
'fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")'
,
'<--- HERE'
,
'return'
,
]
def
set_func_call
(
self
):
...
...
@@ -242,7 +245,11 @@ class TestErrorStaticLayerCallInCompiletime_2(
self
.
expected_message
=
\
[
'File "{}", line 46, in func_error_in_compile_time_2'
.
format
(
self
.
filepath
),
'x = fluid.layers.reshape(x, shape=[1, 2])'
'def func_error_in_compile_time_2(x):'
,
'x = fluid.dygraph.to_variable(x)'
,
'x = fluid.layers.reshape(x, shape=[1, 2])'
,
'<--- HERE'
,
'return x'
]
...
...
@@ -261,7 +268,10 @@ class TestErrorStaticLayerCallInCompiletime_3(
def
set_message
(
self
):
self
.
expected_message
=
\
[
'File "{}", line 91, in forward'
.
format
(
self
.
filepath
),
'@paddle.jit.to_static'
,
'def forward(self):'
,
'self.test_func()'
,
'<--- HERE'
]
def
set_func_call
(
self
):
...
...
@@ -318,7 +328,12 @@ class TestJitSaveInCompiletime(TestErrorBase):
def
set_message
(
self
):
self
.
expected_message
=
\
[
'File "{}", line 80, in forward'
.
format
(
self
.
filepath
),
'fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")'
,
'def forward(self, x):'
,
'y = self._linear(x)'
,
'z = fluid.layers.fill_constant(shape=[1, 2], value=9, dtype="int")'
,
'<--- HERE'
,
'out = fluid.layers.mean(y[z])'
,
'return out'
]
def
set_func_call
(
self
):
...
...
@@ -329,7 +344,7 @@ class TestJitSaveInCompiletime(TestErrorBase):
self
.
_test_raise_new_exception
()
# Situation 4: NotImplementedError
#
#
Situation 4: NotImplementedError
class
TestErrorInOther
(
unittest
.
TestCase
):
def
test
(
self
):
paddle
.
disable_static
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录