Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
def27bc8
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
def27bc8
编写于
3月 11, 2021
作者:
A
Aurelius84
提交者:
GitHub
3月 11, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Dy2stat]Fix bug with static_convert_var_shape in locals scope (#31556)
* Fix bug with static_convert_var_shape * replace dot with dash
上级
49c3d2a9
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
76 addition
and
23 deletion
+76
-23
python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
...ddle/fluid/dygraph/dygraph_to_static/convert_operators.py
+6
-6
python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py
...uid/dygraph/dygraph_to_static/tensor_shape_transformer.py
+23
-7
python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py
...sts/unittests/dygraph_to_static/test_convert_operators.py
+25
-10
python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py
...id/tests/unittests/dygraph_to_static/test_tensor_shape.py
+22
-0
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
浏览文件 @
def27bc8
...
...
@@ -302,19 +302,19 @@ def convert_var_shape_simple(x):
return
x
.
shape
def
eval_if_exist_else_none
(
name
,
loc
al_symbol_table
):
def
eval_if_exist_else_none
(
name
,
glob
al_symbol_table
):
"""
Args:
name([str]): Expression passed into `eval`.
local_symbol_table(dict): Specified from `
locals()`. DO NOT use `glob
als()`,
it has a higher priority and will hide away variable
s
from `locals()
`.
local_symbol_table(dict): Specified from `
globals()`. DO NOT use `loc
als()`,
because all STATIC_CONVERT_VAR_SHAPE_SUFFIX vars i
s
declared with keyword `global
`.
Returns:
Return the variable if found in
loc
al_symbol_table else None.
Return the variable if found in
glob
al_symbol_table else None.
"""
try
:
return
eval
(
name
,
loc
al_symbol_table
)
return
eval
(
name
,
glob
al_symbol_table
)
except
:
return
None
...
...
python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py
浏览文件 @
def27bc8
...
...
@@ -59,7 +59,7 @@ def create_convert_shape_node(var_shape_node,
def
create_choose_shape_node
(
attr_shape_name
,
api_shape_name
,
slice_node
=
None
):
eval_exist_func
=
"paddle.jit.dy2static.eval_if_exist_else_none('{}',
loc
als())"
.
format
(
eval_exist_func
=
"paddle.jit.dy2static.eval_if_exist_else_none('{}',
glob
als())"
.
format
(
api_shape_name
)
args
=
[
attr_shape_name
,
eval_exist_func
]
...
...
@@ -293,6 +293,10 @@ class TensorShapeTransformer(gast.NodeTransformer):
return
False
def
_update_name_to_var_shape
(
self
,
node
):
def
replace_dot
(
name
):
# replace all '.' into '_'
return
name
.
replace
(
'.'
,
'_'
)
assert
isinstance
(
node
,
gast
.
Assign
)
target_node
=
node
.
targets
[
0
]
value_node
=
node
.
value
...
...
@@ -307,7 +311,8 @@ class TensorShapeTransformer(gast.NodeTransformer):
if
value_node
.
id
in
self
.
name_to_var_shape
:
# TODO(zhhsplendid): is context a problem for the result node of gast.parse?
static_shape_var_name
=
unique_name
.
generate
(
target_id
+
STATIC_CONVERT_VAR_SHAPE_SUFFIX
)
replace_dot
(
target_id
)
+
STATIC_CONVERT_VAR_SHAPE_SUFFIX
)
static_shape_var_node
=
gast
.
parse
(
static_shape_var_name
).
body
[
0
].
value
...
...
@@ -328,7 +333,8 @@ class TensorShapeTransformer(gast.NodeTransformer):
if
isinstance
(
value_node
,
gast
.
Attribute
):
if
self
.
_is_var_shape
(
value_node
):
# eg: x.shape
static_shape_var_name
=
unique_name
.
generate
(
target_id
+
STATIC_CONVERT_VAR_SHAPE_SUFFIX
)
replace_dot
(
target_id
)
+
STATIC_CONVERT_VAR_SHAPE_SUFFIX
)
static_shape_var_node
=
gast
.
parse
(
static_shape_var_name
).
body
[
0
].
value
...
...
@@ -341,6 +347,12 @@ class TensorShapeTransformer(gast.NodeTransformer):
ast_to_source_code
(
static_shape_value_node
).
strip
(),
idx
)
sub_node
=
gast
.
parse
(
sub_node_str
).
body
[
0
].
value
# Note(Aurelius84): Becuase static_shape_var_name is used in
# eval_if_exist_else_none() as plain string, so it will not
# be pasred as argument in convert_loop/ifelse. We delcare it
# as global var because it has unique name.
update_static_shape_var_node
.
append
(
gast
.
Global
(
names
=
[
static_shape_var_name
]))
update_static_shape_var_node
.
append
(
gast
.
Assign
(
...
...
@@ -354,7 +366,8 @@ class TensorShapeTransformer(gast.NodeTransformer):
if
isinstance
(
value_node
,
gast
.
Name
):
if
value_node
.
id
in
self
.
name_to_var_shape
:
static_shape_var_name
=
unique_name
.
generate
(
target_id
+
STATIC_CONVERT_VAR_SHAPE_SUFFIX
)
replace_dot
(
target_id
)
+
STATIC_CONVERT_VAR_SHAPE_SUFFIX
)
static_shape_var_node
=
gast
.
parse
(
static_shape_var_name
).
body
[
0
].
value
static_shape_value_name
=
self
.
name_to_var_shape
[
...
...
@@ -370,17 +383,20 @@ class TensorShapeTransformer(gast.NodeTransformer):
self
.
name_to_var_shape
[
target_id
]
=
static_shape_var_name
elif
self
.
_is_var_shape
(
value_node
):
# eg: x.shape or x.shape[0]
static_shape_var_name
=
unique_name
.
generate
(
target_id
+
STATIC_CONVERT_VAR_SHAPE_SUFFIX
)
replace_dot
(
target_id
)
+
STATIC_CONVERT_VAR_SHAPE_SUFFIX
)
static_shape_var_node
=
gast
.
parse
(
static_shape_var_name
).
body
[
0
].
value
static_shape_value_node
=
copy
.
deepcopy
(
value_node
)
# x.shape becomes convert_var_shape_simple(x)
static_shape_value_node
=
ShapeAttributeTransformer
().
visit
(
static_shape_value_node
)
# Declare static_shape_var_name as global var
update_static_shape_var_node
=
[
gast
.
Global
(
names
=
[
static_shape_var_name
])
]
update_static_shape_var_node
.
append
(
gast
.
Assign
(
targets
=
[
static_shape_var_node
],
value
=
static_shape_value_node
)
]
value
=
static_shape_value_node
))
self
.
name_to_var_shape
[
target_id
]
=
static_shape_var_name
return
update_static_shape_var_node
python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py
浏览文件 @
def27bc8
...
...
@@ -191,29 +191,44 @@ class TestChooseShapeAttrOrApi(unittest.TestCase):
class
TestEvaIfExistElseNone
(
unittest
.
TestCase
):
def
test_locals
(
self
):
def
test_globals
(
self
):
global
x_shape
x_shape
=
[
1
,
2
,
3
]
self
.
assertEqual
(
eval_if_exist_else_none
(
'x_shape'
,
locals
()),
x_shape
)
self
.
assertEqual
(
eval_if_exist_else_none
(
'x_shape'
,
locals
()),
None
)
self
.
assertEqual
(
eval_if_exist_else_none
(
'x_shape'
,
globals
()),
x_shape
)
def
test_globals
(
self
):
del
x_shape
def
test_enclosing_scope
(
self
):
global
x_shape
x_shape
=
[
1
,
2
,
3
]
def
foo
():
x_shape
=
[
2
,
3
,
4
]
y_shape
=
[
2
,
3
,
4
]
self
.
assertEqual
(
eval_if_exist_else_none
(
'x_shape'
,
globals
()),
[
1
,
2
,
3
])
self
.
assertEqual
(
eval_if_exist_else_none
(
'
x
_shape'
,
locals
()),
[
2
,
3
,
4
])
eval_if_exist_else_none
(
'
y
_shape'
,
locals
()),
[
2
,
3
,
4
])
foo
()
del
x_shape
def
test_
invisible_of
_func
(
self
):
def
test_
global_in
_func
(
self
):
x_shape
=
[
1
,
2
,
3
]
def
foo
():
x_shape
=
[
2
,
3
,
4
]
return
x_shape
global
y_shape
y_shape
=
[
2
,
3
,
4
]
self
.
assertEqual
(
eval_if_exist_else_none
(
'x_shape'
,
locals
()),
[
1
,
2
,
3
])
eval_if_exist_else_none
(
'y_shape'
,
globals
()),
[
2
,
3
,
4
])
self
.
assertEqual
(
eval_if_exist_else_none
(
'x_shape'
,
locals
()),
None
)
self
.
assertEqual
(
eval_if_exist_else_none
(
'x_shape'
,
globals
()),
None
)
del
y_shape
foo
()
def
test_none
(
self
):
def
foo
():
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py
浏览文件 @
def27bc8
...
...
@@ -541,5 +541,27 @@ class TestChangeShapeAfterAssign(TestTensorShapeBasic):
self
.
expected_slice_op_num
=
2
def
dyfunc_with_static_convert_var_shape
(
x
):
# Note: this will create `batch_size__static_convert_var_shape_suffix_0` firstly.
batch_size
=
x
.
shape
[
0
]
if
len
(
x
.
shape
)
<
1
:
res
=
x
else
:
# Test for correctly to find `batch_size__static_convert_var_shape_suffix_0` in
# deeply nested scope.
res
=
fluid
.
layers
.
fill_constant
(
value
=
8
,
shape
=
[
batch_size
],
dtype
=
"int32"
)
return
res
class
TestFindStatiConvertVarShapeSuffixVar
(
unittest
.
TestCase
):
def
test
(
self
):
x_spec
=
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
10
])
func
=
paddle
.
jit
.
to_static
(
dyfunc_with_if_2
,
input_spec
=
[
x_spec
])
# Call this function to trigger program translation.
func
.
concrete_program
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录