Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
8d6de440
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8d6de440
编写于
6月 19, 2020
作者:
A
Aurelius84
提交者:
GitHub
6月 19, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine check_type Error Message for @declarative (#25098)
* Refine check_type for @declarative test=develop
上级
b23801a2
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
62 addition
and
14 deletion
+62
-14
python/paddle/fluid/data_feeder.py
python/paddle/fluid/data_feeder.py
+18
-3
python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
...dle/fluid/dygraph/dygraph_to_static/program_translator.py
+35
-11
python/paddle/fluid/tests/unittests/test_imperative_basic.py
python/paddle/fluid/tests/unittests/test_imperative_basic.py
+9
-0
未找到文件。
python/paddle/fluid/data_feeder.py
浏览文件 @
8d6de440
...
...
@@ -60,7 +60,7 @@ def convert_dtype(dtype):
u
'float64'
,
u
'int8'
,
u
'int16'
,
u
'int32'
,
u
'int64'
,
u
'uint8'
]:
# this code is a little bit dangerous, since error could happen
# when casting no-asci code to str in python2.
# when casting no-asci
i
code to str in python2.
# but since the set itself is limited, so currently, it is good.
# however, jointly supporting python2 and python3, (as well as python4 maybe)
# may still be a long-lasting problem.
...
...
@@ -76,8 +76,7 @@ def check_variable_and_dtype(input,
expected_dtype
,
op_name
,
extra_message
=
''
):
check_type
(
input
,
input_name
,
(
Variable
,
core
.
VarBase
),
op_name
,
extra_message
)
check_type
(
input
,
input_name
,
Variable
,
op_name
,
extra_message
)
check_dtype
(
input
.
dtype
,
input_name
,
expected_dtype
,
op_name
,
extra_message
)
...
...
@@ -91,6 +90,22 @@ def check_type(input, input_name, expected_type, op_name, extra_message=''):
# each step in dynamic graph mode, it will bring a heavy performance burden.
if
in_dygraph_mode
():
return
from
.dygraph.dygraph_to_static.program_translator
import
in_declarative_mode
# NOTE: `in_declarative_mode` is used to determined whether this op is called under
# @declarative in transformation from dygrah to static layer. We add VarBase in
# expected_type to skip checking because varBase may be created and used in unusual way.
# Need a better design to be fix this.
if
in_declarative_mode
():
if
not
isinstance
(
expected_type
,
tuple
):
expected_type
=
(
expected_type
,
)
expected_type
+=
(
core
.
VarBase
,
)
elif
isinstance
(
input
,
core
.
VarBase
):
raise
TypeError
(
"Please use `with fluid.dygraph.guard()` as context or `fluid.enable_dygraph()` to switch to imperative mode firstly. "
"Because received '{}' in {} is a imperative Variable."
.
format
(
input_name
,
op_name
))
if
not
isinstance
(
input
,
expected_type
):
raise
TypeError
(
"The type of '%s' in %s must be %s, but received %s. %s"
%
...
...
python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
浏览文件 @
8d6de440
...
...
@@ -31,6 +31,7 @@ from paddle.fluid.dygraph.base import switch_to_static_graph
from
paddle.fluid.dygraph.dygraph_to_static.ast_transformer
import
convert_to_static
from
paddle.fluid.dygraph.dygraph_to_static.ast_transformer
import
DygraphToStaticAst
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ast_to_source_code
from
paddle.fluid.wrapped_decorator
import
signature_safe_contextmanager
from
paddle.fluid.dygraph.base
import
param_guard
from
paddle.fluid.data_feeder
import
check_type
from
paddle.fluid.dygraph.dygraph_to_static.partial_program
import
partial_program_from
...
...
@@ -155,6 +156,28 @@ class FunctionSpec(object):
return
self
.
__key
()
==
self
.
__key
()
# Flag that indicates whether running code under `@declarative`
_in_declarative_mode_
=
False
def
in_declarative_mode
():
"""
Return a bool value that indicates whether running code under `@declarative`
"""
return
_in_declarative_mode_
@
signature_safe_contextmanager
def
_switch_declarative_mode_guard_
(
is_declarative
=
True
):
global
_in_declarative_mode_
original_val
=
_in_declarative_mode_
_in_declarative_mode_
=
is_declarative
yield
_in_declarative_mode_
=
original_val
class
ConcreteProgram
(
object
):
def
__init__
(
self
,
inputs
,
...
...
@@ -190,17 +213,18 @@ class ConcreteProgram(object):
).
random_seed
with
framework
.
program_guard
(
main_program
,
startup_program
):
# 1. Adds `fluid.data` layers for input if needed
inputs
=
func_spec
.
to_static_inputs
(
main_program
)
# 2. Gets all ParamBases in the function
all_parameters
=
list
(
func_spec
.
parameters
().
values
())
# 3. Builds program only once and returns the output Variables.
with
param_guard
(
func_spec
.
parameters
(
False
)):
outputs
=
static_func
(
*
inputs
)
if
not
isinstance
(
outputs
,
(
tuple
,
list
)):
outputs
=
[
outputs
]
if
outputs
else
[]
with
_switch_declarative_mode_guard_
(
is_declarative
=
True
):
# 1. Adds `fluid.data` layers for input if needed
inputs
=
func_spec
.
to_static_inputs
(
main_program
)
# 2. Gets all ParamBases in the function
all_parameters
=
list
(
func_spec
.
parameters
().
values
())
# 3. Builds program only once and returns the output Variables.
with
param_guard
(
func_spec
.
parameters
(
False
)):
outputs
=
static_func
(
*
inputs
)
if
not
isinstance
(
outputs
,
(
tuple
,
list
)):
outputs
=
[
outputs
]
if
outputs
else
[]
return
ConcreteProgram
(
inputs
=
inputs
,
...
...
python/paddle/fluid/tests/unittests/test_imperative_basic.py
浏览文件 @
8d6de440
...
...
@@ -645,5 +645,14 @@ class TestDygraphUtils(unittest.TestCase):
self
.
assertTrue
(
np
.
array_equal
(
res1
.
numpy
(),
res2
.
numpy
()))
class
TestDygraphGuardWithError
(
unittest
.
TestCase
):
def
test_without_guard
(
self
):
with
fluid
.
dygraph
.
guard
():
x
=
fluid
.
dygraph
.
to_variable
(
np
.
zeros
([
10
,
10
]))
with
self
.
assertRaisesRegexp
(
TypeError
,
"Please use `with fluid.dygraph.guard()"
):
y
=
fluid
.
layers
.
matmul
(
x
,
x
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录