Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
3e20ddf7
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3e20ddf7
编写于
9月 11, 2020
作者:
L
liym27
提交者:
GitHub
9月 11, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Dy2Stat - Error Handling] Fix bug and optimize dy2stat error. (#27225)
上级
ac8afe18
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
140 addition
and
56 deletion
+140
-56
python/paddle/fluid/dygraph/dygraph_to_static/error.py
python/paddle/fluid/dygraph/dygraph_to_static/error.py
+32
-1
python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
...dle/fluid/dygraph/dygraph_to_static/program_translator.py
+44
-41
python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py
...dle/fluid/tests/unittests/dygraph_to_static/test_error.py
+64
-14
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/error.py
浏览文件 @
3e20ddf7
...
...
@@ -13,6 +13,7 @@
# limitations under the License.
import
os
import
six
import
sys
import
traceback
...
...
@@ -20,6 +21,14 @@ from paddle.fluid.dygraph.dygraph_to_static.origin_info import Location, OriginI
ERROR_DATA
=
"Error data about original source code information and traceback."
# A flag to set whether to open the dygraph2static error reporting module
SIMPLIFY_ERROR_ENV_NAME
=
"TRANSLATOR_SIMPLIFY_NEW_ERROR"
DEFAULT_SIMPLIFY_NEW_ERROR
=
1
# A flag to set whether to display the simplified error stack
DISABLE_ERROR_ENV_NAME
=
"TRANSLATOR_DISABLE_NEW_ERROR"
DEFAULT_DISABLE_NEW_ERROR
=
0
def
attach_error_data
(
error
,
in_runtime
=
False
):
"""
...
...
@@ -103,7 +112,10 @@ class ErrorData(object):
# Simplify error value to improve readability if error is raised in runtime
if
self
.
in_runtime
:
self
.
_simplify_error_value
()
if
int
(
os
.
getenv
(
SIMPLIFY_ERROR_ENV_NAME
,
DEFAULT_SIMPLIFY_NEW_ERROR
)):
self
.
_simplify_error_value
()
message_lines
.
append
(
str
(
self
.
error_value
))
return
'
\n
'
.
join
(
message_lines
)
...
...
@@ -150,3 +162,22 @@ class ErrorData(object):
error_value_str
=
'
\n
'
.
join
(
error_value_lines
)
self
.
error_value
=
self
.
error_type
(
error_value_str
)
def
raise_new_exception
(
self
):
# Raises the origin error if disable dygraph2static error module,
if
int
(
os
.
getenv
(
DISABLE_ERROR_ENV_NAME
,
DEFAULT_DISABLE_NEW_ERROR
)):
raise
new_exception
=
self
.
create_exception
()
if
six
.
PY3
:
# NOTE(liym27):
# 1. Why `raise new_exception from None`?
# In Python 3, by default, an new exception is raised with trace information of the caught exception.
# This only raises new_exception and hides unwanted implementation details from tracebacks of the
# caught exception.
# 2. Use exec to bypass syntax error checking in Python 2.
six
.
exec_
(
"raise new_exception from None"
)
else
:
raise
new_exception
python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
浏览文件 @
3e20ddf7
...
...
@@ -32,8 +32,7 @@ from paddle.fluid.layers.utils import flatten
from
paddle.fluid.dygraph.base
import
param_guard
from
paddle.fluid.dygraph.base
import
switch_to_static_graph
from
paddle.fluid.dygraph.dygraph_to_static
import
DygraphToStaticAst
from
paddle.fluid.dygraph.dygraph_to_static.error
import
ERROR_DATA
from
paddle.fluid.dygraph.dygraph_to_static.error
import
attach_error_data
from
paddle.fluid.dygraph.dygraph_to_static
import
error
from
paddle.fluid.dygraph.dygraph_to_static
import
logging_utils
from
paddle.fluid.dygraph.dygraph_to_static.origin_info
import
attach_origin_info
from
paddle.fluid.dygraph.dygraph_to_static.origin_info
import
create_and_update_origin_info_map
...
...
@@ -315,6 +314,7 @@ class StaticLayer(object):
# 2. trace ops from dygraph layers and cache the generated program.
args
,
kwargs
=
self
.
_function_spec
.
unified_args_and_kwargs
(
args
,
kwargs
)
try
:
concrete_program
,
partial_program_layer
=
self
.
get_concrete_program
(
*
args
,
**
kwargs
)
...
...
@@ -324,27 +324,22 @@ class StaticLayer(object):
partial_program_layer
.
training
=
self
.
_class_instance
.
training
# 4. return outputs.
return
partial_program_layer
(
args
)
try
:
return
partial_program_layer
(
args
)
except
Exception
as
e
:
if
not
hasattr
(
e
,
error
.
ERROR_DATA
):
# runtime error
error
.
attach_error_data
(
e
,
in_runtime
=
True
)
raise
except
Exception
as
e
:
if
not
hasattr
(
e
,
ERROR_DATA
):
# runtime error
attach_error_data
(
e
,
in_runtime
=
True
)
error_data
=
getattr
(
e
,
ERROR_DATA
,
None
)
error_data
=
getattr
(
e
,
error
.
ERROR_DATA
,
None
)
if
error_data
:
new_exception
=
error_data
.
create_exception
()
if
six
.
PY3
:
# NOTE(liym27):
# 1. Why `raise new_exception from None`?
# In Python 3, by default, an new exception is raised with trace information of the caught exception.
# This only raises new_exception and hides unwanted implementation details from tracebacks of the
# caught exception.
# 2. Use exec to bypass syntax error checking in Python 2.
six
.
exec_
(
"raise new_exception from None"
)
else
:
raise
new_exception
error_data
.
raise_new_exception
()
else
:
raise
logging_utils
.
warn
(
"Please file an issue at 'https://github.com/PaddlePaddle/Paddle/issues'"
" if you can't handle this {} yourself."
.
format
(
type
(
e
)))
raise
e
def
_call_dygraph_function
(
self
,
*
args
,
**
kwargs
):
"""
...
...
@@ -593,7 +588,7 @@ class ConcreteProgram(object):
outputs
=
static_func
(
*
inputs
)
except
BaseException
as
e
:
# NOTE: If e is raised in compile time, e should be attached to ERROR_DATA here.
attach_error_data
(
e
)
error
.
attach_error_data
(
e
)
raise
if
not
isinstance
(
outputs
,
...
...
@@ -813,28 +808,36 @@ class ProgramTranslator(object):
"The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable = False. "
"We will just return dygraph output."
)
return
dygraph_func
(
*
args
,
**
kwargs
)
function_spec
=
FunctionSpec
(
dygraph_func
)
cache_key
=
CacheKey
.
from_func_and_args
(
function_spec
,
args
,
kwargs
,
getattr
(
dygraph_func
,
'__self__'
,
None
))
_
,
partial_program_layer
=
self
.
_program_cache
[
cache_key
]
if
args
and
isinstance
(
args
[
0
],
layers
.
Layer
):
# Synchronize self.training attribute.
partial_program_layer
.
training
=
args
[
0
].
training
args
=
args
[
1
:]
try
:
return
partial_program_layer
(
args
)
function_spec
=
FunctionSpec
(
dygraph_func
)
cache_key
=
CacheKey
.
from_func_and_args
(
function_spec
,
args
,
kwargs
,
getattr
(
dygraph_func
,
'__self__'
,
None
))
_
,
partial_program_layer
=
self
.
_program_cache
[
cache_key
]
if
args
and
isinstance
(
args
[
0
],
layers
.
Layer
):
# Synchronize self.training attribute.
partial_program_layer
.
training
=
args
[
0
].
training
args
=
args
[
1
:]
try
:
return
partial_program_layer
(
args
)
except
BaseException
as
e
:
# NOTE:
# 1. If e is raised in compile time, e should have been attached to ERROR_DATA before;
# 2. If e raised in runtime, e should be attached to ERROR_DATA here.
if
not
hasattr
(
e
,
error
.
ERROR_DATA
):
# runtime error
error
.
attach_error_data
(
e
,
in_runtime
=
True
)
raise
except
BaseException
as
e
:
# NOTE:
# 1. If e is raised in compile time, e should have been attached to ERROR_DATA before;
# 2. If e raised in runtime, e should be attached to ERROR_DATA here.
if
not
hasattr
(
e
,
ERROR_DATA
):
# runtime error
attach_error_data
(
e
,
in_runtime
=
True
)
raise
error_data
=
getattr
(
e
,
error
.
ERROR_DATA
,
None
)
if
error_data
:
error_data
.
raise_new_exception
()
else
:
logging_utils
.
warn
(
"Please file an issue at 'https://github.com/PaddlePaddle/Paddle/issues'"
" if you can't handle this {} yourself."
.
format
(
type
(
e
)))
raise
e
def
get_func
(
self
,
dygraph_func
):
"""
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py
浏览文件 @
3e20ddf7
...
...
@@ -14,15 +14,15 @@
from
__future__
import
print_function
import
os
import
inspect
import
unittest
import
numpy
as
np
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid.core
import
EnforceNotMet
from
paddle.fluid.dygraph.dygraph_to_static
.error
import
ERROR_DATA
,
ErrorData
from
paddle.fluid.dygraph.dygraph_to_static
import
error
from
paddle.fluid.dygraph.dygraph_to_static.origin_info
import
unwrap
from
paddle.fluid.dygraph.jit
import
declarative
def
inner_func
():
...
...
@@ -30,7 +30,7 @@ def inner_func():
return
@
declarative
@
paddle
.
jit
.
to_static
def
func_error_in_compile_time
(
x
):
x
=
fluid
.
dygraph
.
to_variable
(
x
)
inner_func
()
...
...
@@ -41,14 +41,14 @@ def func_error_in_compile_time(x):
return
x_v
@
declarative
@
paddle
.
jit
.
to_static
def
func_error_in_compile_time_2
(
x
):
x
=
fluid
.
dygraph
.
to_variable
(
x
)
x
=
fluid
.
layers
.
reshape
(
x
,
shape
=
[
1
,
2
])
return
x
@
declarative
@
paddle
.
jit
.
to_static
def
func_error_in_runtime
(
x
,
iter_num
=
3
):
x
=
fluid
.
dygraph
.
to_variable
(
x
)
two
=
fluid
.
layers
.
fill_constant
(
shape
=
[
1
],
value
=
2
,
dtype
=
"int32"
)
...
...
@@ -61,6 +61,9 @@ class TestErrorInCompileTime(unittest.TestCase):
self
.
set_func
()
self
.
set_input
()
self
.
set_exception_type
()
self
.
prog_trans
=
paddle
.
jit
.
ProgramTranslator
()
self
.
simplify_error
=
1
self
.
disable_error
=
0
def
set_func
(
self
):
self
.
func
=
func_error_in_compile_time
...
...
@@ -88,14 +91,38 @@ class TestErrorInCompileTime(unittest.TestCase):
for
m
in
self
.
expected_message
:
self
.
assertIn
(
m
,
error_message
)
def
test
(
self
):
with
fluid
.
dygraph
.
guard
():
with
self
.
assertRaises
(
self
.
exception_type
)
as
cm
:
self
.
func
(
self
.
input
)
exception
=
cm
.
exception
error_data
=
getattr
(
exception
,
ERROR_DATA
)
self
.
assertIsInstance
(
error_data
,
ErrorData
)
self
.
_test_create_message
(
error_data
)
def
_test_attach_and_raise_new_exception
(
self
,
func_call
):
paddle
.
disable_static
()
with
self
.
assertRaises
(
self
.
exception_type
)
as
cm
:
func_call
()
exception
=
cm
.
exception
error_data
=
getattr
(
exception
,
error
.
ERROR_DATA
,
None
)
self
.
assertIsInstance
(
error_data
,
error
.
ErrorData
)
self
.
_test_create_message
(
error_data
)
def
test_static_layer_call
(
self
):
# NOTE: self.func(self.input) is the StaticLayer().__call__(self.input)
call_dy2static
=
lambda
:
self
.
func
(
self
.
input
)
self
.
set_flags
(
0
)
self
.
_test_attach_and_raise_new_exception
(
call_dy2static
)
def
test_program_translator_get_output
(
self
):
call_dy2static
=
lambda
:
self
.
prog_trans
.
get_output
(
unwrap
(
self
.
func
),
self
.
input
)
self
.
set_flags
(
0
)
self
.
_test_attach_and_raise_new_exception
(
call_dy2static
)
def
set_flags
(
self
,
disable_error
=
0
,
simplify_error
=
1
):
os
.
environ
[
error
.
DISABLE_ERROR_ENV_NAME
]
=
str
(
disable_error
)
self
.
disable_error
=
int
(
os
.
getenv
(
error
.
DISABLE_ERROR_ENV_NAME
,
0
))
self
.
assertEqual
(
self
.
disable_error
,
disable_error
)
os
.
environ
[
error
.
SIMPLIFY_ERROR_ENV_NAME
]
=
str
(
simplify_error
)
self
.
simplify_error
=
int
(
os
.
getenv
(
error
.
SIMPLIFY_ERROR_ENV_NAME
,
1
))
self
.
assertEqual
(
self
.
simplify_error
,
simplify_error
)
class
TestErrorInCompileTime2
(
TestErrorInCompileTime
):
...
...
@@ -143,5 +170,28 @@ class TestErrorInRuntime(TestErrorInCompileTime):
self
.
assertIn
(
m
,
error_message
)
@
unwrap
@
paddle
.
jit
.
to_static
()
def
func_decorated_by_other_1
():
return
1
@
paddle
.
jit
.
to_static
()
@
unwrap
def
func_decorated_by_other_2
():
return
1
class
TestErrorInOther
(
unittest
.
TestCase
):
def
test
(
self
):
paddle
.
disable_static
()
prog_trans
=
paddle
.
jit
.
ProgramTranslator
()
with
self
.
assertRaises
(
NotImplementedError
):
prog_trans
.
get_output
(
func_decorated_by_other_1
)
with
self
.
assertRaises
(
NotImplementedError
):
func_decorated_by_other_2
()
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录