Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
3a72408f
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3a72408f
编写于
2月 23, 2021
作者:
H
Huihuang Zheng
提交者:
GitHub
2月 23, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Cherry-pick][Dy2stat] Cherry-pick of PR31082 and PR31051 (#31101)
Cherry-pick of #31051 and #31082
上级
29467060
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
335 addition
and
79 deletion
+335
-79
python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
...ddle/fluid/dygraph/dygraph_to_static/convert_operators.py
+49
-5
python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py
...uid/dygraph/dygraph_to_static/tensor_shape_transformer.py
+163
-63
python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py
...sts/unittests/dygraph_to_static/test_convert_operators.py
+53
-0
python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py
...id/tests/unittests/dygraph_to_static/test_tensor_shape.py
+65
-10
python/paddle/jit/dy2static/convert_operators.py
python/paddle/jit/dy2static/convert_operators.py
+5
-1
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
浏览文件 @
3a72408f
...
@@ -267,12 +267,12 @@ def convert_var_shape(x, idx=None, in_control_flow=False):
...
@@ -267,12 +267,12 @@ def convert_var_shape(x, idx=None, in_control_flow=False):
A function representation of the shape of variable.
A function representation of the shape of variable.
"""
"""
def
has_neg
e
tive
(
list_shape
,
idx
=
None
):
def
has_neg
a
tive
(
list_shape
,
idx
=
None
):
if
idx
is
not
None
:
if
idx
is
not
None
:
return
list_shape
[
idx
]
<
0
return
list_shape
[
idx
]
<
0
num_neg
e
tive
=
sum
([
1
if
i
<
0
else
0
for
i
in
list_shape
])
num_neg
a
tive
=
sum
([
1
if
i
<
0
else
0
for
i
in
list_shape
])
return
num_neg
e
tive
>
0
return
num_neg
a
tive
>
0
# When `x` is Variable, call nn.shape(x) in following cases:
# When `x` is Variable, call nn.shape(x) in following cases:
# (1) The shape of `x` is used in control flow condition.
# (1) The shape of `x` is used in control flow condition.
...
@@ -280,18 +280,62 @@ def convert_var_shape(x, idx=None, in_control_flow=False):
...
@@ -280,18 +280,62 @@ def convert_var_shape(x, idx=None, in_control_flow=False):
# if x.shape[0] == 1:
# if x.shape[0] == 1:
# y = XX
# y = XX
# ```
# ```
# (2) The dim to be used is neg
e
tive
# (2) The dim to be used is neg
a
tive
# ```
# ```
# # Assume x.shape=[3, -1] in static mode
# # Assume x.shape=[3, -1] in static mode
# y = paddle.reshape(x, shape=[1, x.shape[1]])
# y = paddle.reshape(x, shape=[1, x.shape[1]])
# ```
# ```
if
isinstance
(
x
,
Variable
)
and
(
in_control_flow
or
has_neg
e
tive
(
x
.
shape
,
if
isinstance
(
x
,
Variable
)
and
(
in_control_flow
or
has_neg
a
tive
(
x
.
shape
,
idx
)):
idx
)):
return
nn
.
shape
(
x
)
if
idx
is
None
else
nn
.
shape
(
x
)[
idx
]
return
nn
.
shape
(
x
)
if
idx
is
None
else
nn
.
shape
(
x
)[
idx
]
else
:
else
:
return
x
.
shape
if
idx
is
None
else
x
.
shape
[
idx
]
return
x
.
shape
if
idx
is
None
else
x
.
shape
[
idx
]
def
convert_var_shape_simple
(
x
):
"""
A function representation of the shape of variable.
"""
if
isinstance
(
x
,
Variable
):
return
nn
.
shape
(
x
)
else
:
return
x
.
shape
def
eval_if_exist_else_none
(
name
):
try
:
return
eval
(
name
)
except
:
return
None
def
choose_shape_attr_or_api
(
attr_shape
,
api_shape
,
idx
=
None
):
"""
Input can be attribute `x.shape` or api `shape(x)`, this function
chooses which one to return to use in dy2stat.
Note: sometimes users write `x.shape[3]`, so attr_shape can be an integer.
"""
if
api_shape
is
None
:
return
attr_shape
if
idx
is
None
else
attr_shape
[
idx
]
if
not
isinstance
(
attr_shape
,
(
list
,
tuple
)):
# some variables like x.shape[0] is no longer a list or tuple
if
isinstance
(
attr_shape
,
int
)
and
attr_shape
<
0
:
return
api_shape
if
idx
is
None
else
api_shape
[
idx
]
return
attr_shape
if
idx
is
None
else
attr_shape
[
idx
]
def
has_negative
(
list_shape
,
idx
=
None
):
if
idx
is
not
None
:
return
list_shape
[
idx
]
<
0
num_negative
=
sum
([
1
if
i
<
0
else
0
for
i
in
list_shape
])
return
num_negative
>
0
if
has_negative
(
attr_shape
,
idx
):
return
api_shape
if
idx
is
None
else
api_shape
[
idx
]
return
attr_shape
if
idx
is
None
else
attr_shape
[
idx
]
def
convert_shape_compare
(
left
,
*
args
):
def
convert_shape_compare
(
left
,
*
args
):
"""
"""
A function handles comparison difference between Paddle and Python.
A function handles comparison difference between Paddle and Python.
...
...
python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py
浏览文件 @
3a72408f
...
@@ -17,12 +17,15 @@ from __future__ import print_function
...
@@ -17,12 +17,15 @@ from __future__ import print_function
import
copy
import
copy
import
gast
import
gast
from
paddle.fluid
import
unique_name
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ast_to_source_code
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ast_to_source_code
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
is_paddle_api
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
is_paddle_api
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
SplitAssignTransformer
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
SplitAssignTransformer
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
AstNodeWrapper
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
AstNodeWrapper
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
StaticAnalysisVisitor
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
StaticAnalysisVisitor
STATIC_CONVERT_VAR_SHAPE_SUFFIX
=
'__static_convert_var_shape_suffix'
def
create_convert_shape_node
(
var_shape_node
,
def
create_convert_shape_node
(
var_shape_node
,
slice_node
=
None
,
slice_node
=
None
,
...
@@ -31,13 +34,20 @@ def create_convert_shape_node(var_shape_node,
...
@@ -31,13 +34,20 @@ def create_convert_shape_node(var_shape_node,
if
isinstance
(
var_shape_node
,
gast
.
Attribute
):
if
isinstance
(
var_shape_node
,
gast
.
Attribute
):
args
=
[
ast_to_source_code
(
var_shape_node
.
value
).
strip
()]
args
=
[
ast_to_source_code
(
var_shape_node
.
value
).
strip
()]
if
slice_node
:
# (1) A slice can be a simple number such as 1, -2, i.e. gast.Index
# (2) A slice can also be represented by bounds such as 2:-1, i.e. not gast.Index
# In (1) case, we pass the number as 'idx' argument in convert_var_shape
# In (2) case, we have to make it like `convert_var_shape(x)[slice]`
if
slice_node
is
not
None
and
isinstance
(
slice_node
,
gast
.
Index
):
args
.
append
(
ast_to_source_code
(
slice_node
).
strip
())
args
.
append
(
ast_to_source_code
(
slice_node
).
strip
())
convert_var_shape_func
=
"paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})"
.
format
(
convert_var_shape_func
=
"paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})"
.
format
(
","
.
join
(
args
),
in_control_flow
)
","
.
join
(
args
),
in_control_flow
)
api_shape_node
=
gast
.
parse
(
convert_var_shape_func
).
body
[
0
].
value
api_shape_node
=
gast
.
parse
(
convert_var_shape_func
).
body
[
0
].
value
if
slice_node
is
not
None
and
not
isinstance
(
slice_node
,
gast
.
Index
):
return
gast
.
Subscript
(
value
=
api_shape_node
,
slice
=
slice_node
,
ctx
=
gast
.
Load
())
return
api_shape_node
return
api_shape_node
if
isinstance
(
var_shape_node
,
gast
.
Subscript
):
if
isinstance
(
var_shape_node
,
gast
.
Subscript
):
...
@@ -47,6 +57,39 @@ def create_convert_shape_node(var_shape_node,
...
@@ -47,6 +57,39 @@ def create_convert_shape_node(var_shape_node,
return
result_node
return
result_node
def
create_choose_shape_node
(
attr_shape_name
,
api_shape_name
,
slice_node
=
None
):
eval_exist_func
=
"paddle.jit.dy2static.eval_if_exist_else_none('{}')"
.
format
(
api_shape_name
)
args
=
[
attr_shape_name
,
eval_exist_func
]
if
slice_node
is
not
None
and
isinstance
(
slice_node
,
gast
.
Index
):
args
.
append
(
ast_to_source_code
(
slice_node
).
strip
())
choose_shape_func
=
"paddle.jit.dy2static.choose_shape_attr_or_api({})"
.
format
(
","
.
join
(
args
))
choose_shape_node
=
gast
.
parse
(
choose_shape_func
).
body
[
0
].
value
if
slice_node
is
not
None
and
not
isinstance
(
slice_node
,
gast
.
Index
):
return
gast
.
Subscript
(
value
=
choose_shape_node
,
slice
=
slice_node
,
ctx
=
gast
.
Load
())
return
choose_shape_node
class
ShapeAttributeTransformer
(
gast
.
NodeTransformer
):
"""
Input a node like `x.shape` or `x[4].shape[0]` (self._is_var_shape(node) is True),
return a new node changes input to static shape API like `convert_var_shape(x)`,
`convert_var_shape(x[4])[0]`.
"""
def
visit_Attribute
(
self
,
node
):
if
node
.
attr
==
'shape'
:
args
=
ast_to_source_code
(
node
.
value
).
strip
()
convert_var_shape_func
=
"paddle.jit.dy2static.convert_var_shape_simple({})"
.
format
(
args
)
api_shape_node
=
gast
.
parse
(
convert_var_shape_func
).
body
[
0
].
value
return
api_shape_node
return
node
class
TensorShapeTransformer
(
gast
.
NodeTransformer
):
class
TensorShapeTransformer
(
gast
.
NodeTransformer
):
"""
"""
This class transforms variable.shape used in Paddle Apis or control flow conditions into Static Graph Ast.
This class transforms variable.shape used in Paddle Apis or control flow conditions into Static Graph Ast.
...
@@ -58,6 +101,8 @@ class TensorShapeTransformer(gast.NodeTransformer):
...
@@ -58,6 +101,8 @@ class TensorShapeTransformer(gast.NodeTransformer):
),
"Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
),
"Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
self
.
wrapper_root
=
wrapper_root
self
.
wrapper_root
=
wrapper_root
self
.
root
=
wrapper_root
.
node
self
.
root
=
wrapper_root
.
node
# stores origin var string name (like "x" in `x = t.shape`) to
# static shape var string name (like "x_SUFFIX" in `x_SUFFIX = shape(t)`)
self
.
name_to_var_shape
=
{}
self
.
name_to_var_shape
=
{}
self
.
static_analysis_visitor
=
StaticAnalysisVisitor
(
self
.
root
)
self
.
static_analysis_visitor
=
StaticAnalysisVisitor
(
self
.
root
)
...
@@ -72,8 +117,11 @@ class TensorShapeTransformer(gast.NodeTransformer):
...
@@ -72,8 +117,11 @@ class TensorShapeTransformer(gast.NodeTransformer):
self
.
visit
(
self
.
root
)
self
.
visit
(
self
.
root
)
def
visit_Assign
(
self
,
node
):
def
visit_Assign
(
self
,
node
):
if
self
.
_update_name_to_var_shape
(
node
):
update_static_shape_var_node
=
self
.
_update_name_to_var_shape
(
node
)
return
node
if
update_static_shape_var_node
is
not
None
:
ret
=
[
node
]
ret
.
extend
(
update_static_shape_var_node
)
return
ret
self
.
generic_visit
(
node
)
self
.
generic_visit
(
node
)
return
node
return
node
...
@@ -81,37 +129,44 @@ class TensorShapeTransformer(gast.NodeTransformer):
...
@@ -81,37 +129,44 @@ class TensorShapeTransformer(gast.NodeTransformer):
value_node
=
node
.
value
value_node
=
node
.
value
slice_node
=
node
.
slice
slice_node
=
node
.
slice
if
isinstance
(
value_node
,
gast
.
Name
):
if
isinstance
(
value_node
,
gast
.
Name
):
if
self
.
_is_var_shape
(
value_node
)
and
self
.
_used_by_paddle_api
(
if
value_node
.
id
in
self
.
name_to_var_shape
and
self
.
_used_by_paddle_api
(
value_node
):
var_shape_node
=
self
.
name_to_var_shape
[
value_node
.
id
]
return
create_convert_shape_node
(
var_shape_node
,
slice_node
)
if
isinstance
(
value_node
,
gast
.
Attribute
):
if
self
.
_used_by_paddle_api
(
value_node
)
and
self
.
_is_var_shape
(
value_node
):
value_node
):
return
create_choose_shape_node
(
value_node
.
id
,
self
.
name_to_var_shape
[
value_node
.
id
],
slice_node
)
elif
isinstance
(
value_node
,
gast
.
Attribute
):
if
self
.
_used_by_paddle_api
(
value_node
):
value_name
=
ast_to_source_code
(
value_node
).
strip
()
if
value_name
in
self
.
name_to_var_shape
:
return
create_choose_shape_node
(
value_name
,
self
.
name_to_var_shape
[
value_name
],
slice_node
)
if
self
.
_is_var_shape
(
value_node
):
return
create_convert_shape_node
(
value_node
,
slice_node
)
return
create_convert_shape_node
(
value_node
,
slice_node
)
return
node
return
node
def
visit_Attribute
(
self
,
node
):
def
visit_Attribute
(
self
,
node
):
if
self
.
_used_by_paddle_api
(
node
):
if
self
.
_used_by_paddle_api
(
node
):
name
=
ast_to_source_code
(
node
).
strip
()
if
name
in
self
.
name_to_var_shape
:
return
create_choose_shape_node
(
name
,
self
.
name_to_var_shape
[
name
])
if
self
.
_is_var_shape
(
node
):
if
self
.
_is_var_shape
(
node
):
return
create_convert_shape_node
(
node
)
return
create_convert_shape_node
(
node
)
return
node
return
node
def
visit_Name
(
self
,
node
):
def
visit_Name
(
self
,
node
):
if
self
.
_is_var_shape
(
node
)
:
if
node
.
id
in
self
.
name_to_var_shape
:
if
self
.
_used_by_paddle_api
(
node
):
if
self
.
_used_by_paddle_api
(
node
):
var_shape_node
=
self
.
name_to_var_shape
[
node
.
id
]
return
create_choose_shape_node
(
node
.
id
,
return
create_convert_shape_node
(
var_shape_node
)
self
.
name_to_var_shape
[
node
.
id
]
)
return
node
return
node
def
visit_Call
(
self
,
node
):
def
visit_Call
(
self
,
node
):
assert
isinstance
(
node
,
gast
.
Call
)
if
is_paddle_api
(
node
):
if
is_paddle_api
(
node
):
# Visit gast.Attribute and gast.Name to replace var.shape if necessary.
# Visit gast.Attribute and gast.Name to replace var.shape if necessary.
self
.
generic_visit
(
node
)
self
.
generic_visit
(
node
)
# Don't have to visit other APIs
return
node
return
node
def
visit_If
(
self
,
node
):
def
visit_If
(
self
,
node
):
...
@@ -147,22 +202,23 @@ class TensorShapeTransformer(gast.NodeTransformer):
...
@@ -147,22 +202,23 @@ class TensorShapeTransformer(gast.NodeTransformer):
return
False
return
False
args
=
node
.
iter
.
args
args
=
node
.
iter
.
args
for
idx
,
arg
in
enumerate
(
args
):
for
idx
,
arg
in
enumerate
(
args
):
if
isinstance
(
arg
,
gast
.
Name
)
and
self
.
_is_var_shape
(
arg
):
if
isinstance
(
arg
,
gast
.
Name
)
and
arg
.
id
in
self
.
name_to_var_shape
:
args
[
idx
]
=
create_convert_shape_node
(
self
.
name_to_var_shape
[
args
[
idx
]
=
create_choose_shape_node
(
arg
.
id
])
arg
.
id
,
self
.
name_to_var_shape
[
arg
.
id
])
return
True
return
True
def
_transform_var_shape_if_necessary
(
self
,
cond
):
def
_transform_var_shape_if_necessary
(
self
,
cond
):
need_transformed
=
False
need_transformed
=
False
for
child_node
in
gast
.
walk
(
cond
):
for
child_node
in
gast
.
walk
(
cond
):
var_shape_node
=
None
var_shape_node
=
None
if
isinstance
(
child_node
,
(
gast
.
Attribute
,
gast
.
Subscript
)):
if
isinstance
(
child_node
,
if
self
.
_is_var_shape
(
child_node
):
(
gast
.
Name
,
gast
.
Attribute
,
gast
.
Subscript
)):
child_name
=
ast_to_source_code
(
child_node
).
strip
()
if
child_name
in
self
.
name_to_var_shape
:
var_shape_node
=
create_choose_shape_node
(
child_name
,
self
.
name_to_var_shape
[
child_name
])
elif
self
.
_is_var_shape
(
child_node
):
var_shape_node
=
child_node
var_shape_node
=
child_node
elif
isinstance
(
child_node
,
(
gast
.
Name
)):
if
self
.
_is_var_shape
(
child_node
):
var_shape_node
=
self
.
name_to_var_shape
[
child_node
.
id
]
if
var_shape_node
:
if
var_shape_node
:
need_transformed
=
True
need_transformed
=
True
...
@@ -170,17 +226,23 @@ class TensorShapeTransformer(gast.NodeTransformer):
...
@@ -170,17 +226,23 @@ class TensorShapeTransformer(gast.NodeTransformer):
parent_node
=
wrapper_node
.
parent
.
node
parent_node
=
wrapper_node
.
parent
.
node
for
field
,
value
in
gast
.
iter_fields
(
parent_node
):
for
field
,
value
in
gast
.
iter_fields
(
parent_node
):
if
child_node
is
value
:
if
child_node
is
value
:
if
var_shape_node
is
child_node
:
setattr
(
parent_node
,
field
,
setattr
(
parent_node
,
field
,
create_convert_shape_node
(
var_shape_node
,
None
,
create_convert_shape_node
(
var_shape_node
,
True
))
None
,
True
))
else
:
setattr
(
parent_node
,
field
,
var_shape_node
)
break
break
# Some child_node may be in a list such as gast.Compare
# Some child_node may be in a list such as gast.Compare
if
isinstance
(
value
,
list
):
if
isinstance
(
value
,
list
):
has_converted_shape
=
False
has_converted_shape
=
False
for
i
,
v
in
enumerate
(
value
):
for
i
,
v
in
enumerate
(
value
):
if
child_node
is
v
:
if
child_node
is
v
:
if
var_shape_node
is
child_node
:
value
[
i
]
=
create_convert_shape_node
(
value
[
i
]
=
create_convert_shape_node
(
var_shape_node
,
None
,
True
)
var_shape_node
,
None
,
True
)
else
:
value
[
i
]
=
var_shape_node
has_converted_shape
=
True
has_converted_shape
=
True
break
break
if
has_converted_shape
:
if
has_converted_shape
:
...
@@ -217,19 +279,12 @@ class TensorShapeTransformer(gast.NodeTransformer):
...
@@ -217,19 +279,12 @@ class TensorShapeTransformer(gast.NodeTransformer):
"""
"""
Return True if node is like `x.shape` or `x.shape[0]`, return False otherwise.
Return True if node is like `x.shape` or `x.shape[0]`, return False otherwise.
"""
"""
if
not
isinstance
(
node
,
(
gast
.
Name
,
gast
.
Attribute
,
gast
.
Subscript
)):
if
not
isinstance
(
node
,
(
gast
.
Attribute
,
gast
.
Subscript
)):
return
False
return
False
if
isinstance
(
node
,
gast
.
Name
)
and
node
.
id
in
self
.
name_to_var_shape
:
return
True
if
isinstance
(
node
,
gast
.
Attribute
):
if
isinstance
(
node
,
gast
.
Attribute
):
if
node
.
attr
!=
'shape'
:
if
node
.
attr
!=
'shape'
:
return
False
return
False
if
not
isinstance
(
node
.
value
,
gast
.
Name
):
return
False
return
True
return
True
if
isinstance
(
node
,
gast
.
Subscript
):
if
isinstance
(
node
,
gast
.
Subscript
):
...
@@ -243,49 +298,94 @@ class TensorShapeTransformer(gast.NodeTransformer):
...
@@ -243,49 +298,94 @@ class TensorShapeTransformer(gast.NodeTransformer):
target_node
=
node
.
targets
[
0
]
target_node
=
node
.
targets
[
0
]
value_node
=
node
.
value
value_node
=
node
.
value
update_static_shape_var_node
=
None
if
isinstance
(
target_node
,
gast
.
Tuple
):
if
isinstance
(
target_node
,
gast
.
Tuple
):
has_updated
=
False
update_static_shape_var_node
=
[]
for
idx
,
element
in
enumerate
(
target_node
.
elts
):
for
idx
,
element
in
enumerate
(
target_node
.
elts
):
target_id
=
ast_to_source_code
(
element
).
strip
()
target_id
=
ast_to_source_code
(
element
).
strip
()
if
isinstance
(
value_node
,
gast
.
Name
):
if
isinstance
(
value_node
,
gast
.
Name
):
if
value_node
.
id
in
self
.
name_to_var_shape
:
if
value_node
.
id
in
self
.
name_to_var_shape
:
# TODO(zhhsplendid): is context a problem for the result node of gast.parse?
static_shape_var_name
=
unique_name
.
generate
(
target_id
+
STATIC_CONVERT_VAR_SHAPE_SUFFIX
)
static_shape_var_node
=
gast
.
parse
(
static_shape_var_name
).
body
[
0
].
value
static_shape_value_name
=
self
.
name_to_var_shape
[
value_node
.
id
]
static_shape_value_node
=
gast
.
parse
(
static_shape_value_name
).
body
[
0
].
value
index_value_node
=
gast
.
Constant
(
value
=
idx
,
kind
=
None
)
index_value_node
=
gast
.
Constant
(
value
=
idx
,
kind
=
None
)
slice_index_node
=
gast
.
Index
(
value
=
index_value_node
)
slice_index_node
=
gast
.
Index
(
value
=
index_value_node
)
var_shape_node
=
self
.
name_to_var_shape
[
value_node
.
id
]
sub_node
=
gast
.
Subscript
(
sub_node
=
gast
.
Subscript
(
value
=
var_shap
e_node
,
value
=
static_shape_valu
e_node
,
slice
=
slice_index_node
,
slice
=
slice_index_node
,
ctx
=
gast
.
Load
())
ctx
=
gast
.
Load
())
self
.
name_to_var_shape
[
target_id
]
=
sub_node
has_updated
=
True
update_static_shape_var_node
.
append
(
gast
.
Assign
(
targets
=
[
static_shape_var_node
],
value
=
sub_node
))
self
.
name_to_var_shape
[
target_id
]
=
static_shape_var_name
if
isinstance
(
value_node
,
gast
.
Attribute
):
if
isinstance
(
value_node
,
gast
.
Attribute
):
if
self
.
_is_var_shape
(
value_node
):
# eg: x.shape
if
self
.
_is_var_shape
(
value_node
):
# eg: x.shape
static_shape_var_name
=
unique_name
.
generate
(
target_id
+
STATIC_CONVERT_VAR_SHAPE_SUFFIX
)
static_shape_var_node
=
gast
.
parse
(
static_shape_var_name
).
body
[
0
].
value
static_shape_value_node
=
copy
.
deepcopy
(
value_node
)
# x.shape becomes convert_var_shape_simple(x)
ShapeAttributeTransformer
().
visit
(
static_shape_value_node
)
index_value_node
=
gast
.
Constant
(
value
=
idx
,
kind
=
None
)
index_value_node
=
gast
.
Constant
(
value
=
idx
,
kind
=
None
)
slice_index_node
=
gast
.
Index
(
value
=
index_value_node
)
slice_index_node
=
gast
.
Index
(
value
=
index_value_node
)
sub_node
=
gast
.
Subscript
(
sub_node
=
gast
.
Subscript
(
value
=
value_node
,
value
=
static_shape_
value_node
,
slice
=
slice_index_node
,
slice
=
slice_index_node
,
ctx
=
gast
.
Load
())
ctx
=
gast
.
Load
())
self
.
name_to_var_shape
[
target_id
]
=
sub_node
has_updated
=
True
return
has_updated
update_static_shape_var_node
.
append
(
gast
.
Assign
(
targets
=
[
static_shape_var_node
],
value
=
sub_node
))
self
.
name_to_var_shape
[
target_id
]
=
static_shape_var_name
return
update_static_shape_var_node
else
:
else
:
target_id
=
ast_to_source_code
(
target_node
).
strip
()
target_id
=
ast_to_source_code
(
target_node
).
strip
()
if
isinstance
(
value_node
,
gast
.
Name
):
if
isinstance
(
value_node
,
gast
.
Name
):
if
self
.
_is_var_shape
(
value_node
):
if
value_node
.
id
in
self
.
name_to_var_shape
:
self
.
name_to_var_shape
[
target_id
]
=
self
.
name_to_var_shape
[
static_shape_var_name
=
unique_name
.
generate
(
target_id
+
STATIC_CONVERT_VAR_SHAPE_SUFFIX
)
static_shape_var_node
=
gast
.
parse
(
static_shape_var_name
).
body
[
0
].
value
static_shape_value_name
=
self
.
name_to_var_shape
[
value_node
.
id
]
value_node
.
id
]
return
True
static_shape_value_node
=
gast
.
parse
(
if
isinstance
(
value_node
,
gast
.
Attribute
):
static_shape_value_name
).
body
[
0
].
value
if
self
.
_is_var_shape
(
value_node
):
# eg: x.shape
self
.
name_to_var_shape
[
target_id
]
=
value_node
update_static_shape_var_node
=
[
return
True
gast
.
Assign
(
if
isinstance
(
value_node
,
gast
.
Subscript
):
targets
=
[
static_shape_var_node
],
if
isinstance
(
value_node
.
value
,
gast
.
Attribute
):
value
=
static_shape_value_node
)
if
self
.
_is_var_shape
(
value_node
.
value
):
# eg: x.shape[0]
]
self
.
name_to_var_shape
[
target_id
]
=
value_node
self
.
name_to_var_shape
[
target_id
]
=
static_shape_var_name
return
True
elif
self
.
_is_var_shape
(
value_node
):
# eg: x.shape or x.shape[0]
return
False
static_shape_var_name
=
unique_name
.
generate
(
target_id
+
STATIC_CONVERT_VAR_SHAPE_SUFFIX
)
static_shape_var_node
=
gast
.
parse
(
static_shape_var_name
).
body
[
0
].
value
static_shape_value_node
=
copy
.
deepcopy
(
value_node
)
# x.shape becomes convert_var_shape_simple(x)
ShapeAttributeTransformer
().
visit
(
static_shape_value_node
)
update_static_shape_var_node
=
[
gast
.
Assign
(
targets
=
[
static_shape_var_node
],
value
=
static_shape_value_node
)
]
self
.
name_to_var_shape
[
target_id
]
=
static_shape_var_name
return
update_static_shape_var_node
python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py
浏览文件 @
3a72408f
...
@@ -136,5 +136,58 @@ class TestConvertShapeCompare(unittest.TestCase):
...
@@ -136,5 +136,58 @@ class TestConvertShapeCompare(unittest.TestCase):
paddle
.
disable_static
()
paddle
.
disable_static
()
class
TestChooseShapeAttrOrApi
(
unittest
.
TestCase
):
def
test_api_shape_is_none
(
self
):
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
([
1
,
2
],
None
),
[
1
,
2
])
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
([
1
],
None
),
[
1
])
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
([
2
,
3
,
7
],
None
,
0
),
2
)
def
test_attr_shape_is_int
(
self
):
x
=
paddle
.
zeros
([
1
,
3
,
5
,
7
])
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
(
x
.
shape
[
0
],
paddle
.
shape
(
x
)[
0
]),
1
)
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
(
x
.
shape
[
1
],
paddle
.
shape
(
x
)[
1
]),
3
)
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
(
-
1
,
paddle
.
shape
(
x
)[
0
]),
paddle
.
shape
(
x
)[
0
])
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
(
-
1
,
paddle
.
shape
(
x
),
0
),
paddle
.
shape
(
x
)[
0
])
def
test_positive_attr_shape
(
self
):
x
=
paddle
.
zeros
([
1
,
3
,
5
,
7
])
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
(
x
.
shape
,
paddle
.
shape
(
x
)),
x
.
shape
)
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
(
x
.
shape
,
paddle
.
shape
(
x
),
3
),
x
.
shape
[
3
])
def
test_negative_attr_shape
(
self
):
x
=
paddle
.
zeros
([
7
])
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
([
-
1
],
paddle
.
shape
(
x
),
0
),
paddle
.
shape
(
x
)[
0
])
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
([
-
1
],
paddle
.
shape
(
x
)),
paddle
.
shape
(
x
))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py
浏览文件 @
3a72408f
...
@@ -60,6 +60,16 @@ def dyfunc_tensor_shape_5(x):
...
@@ -60,6 +60,16 @@ def dyfunc_tensor_shape_5(x):
return
res
return
res
def
dyfunc_tensor_shape_6
(
x
):
# `res = fluid.layers.reshape(x, shape=(-1, s))` to
# `res = fluid.layers.reshape(x, shape=(-1,
# paddle.jit.dy2static.convert_var_shape(x)[0:]))`
x
=
fluid
.
dygraph
.
to_variable
(
x
)
s
=
x
.
shape
[
0
:]
res
=
fluid
.
layers
.
reshape
(
x
,
shape
=
s
)
return
res
def
dyfunc_tuple_shape_1
(
x
):
def
dyfunc_tuple_shape_1
(
x
):
x
=
paddle
.
to_tensor
(
x
)
x
=
paddle
.
to_tensor
(
x
)
a
,
b
=
x
.
shape
a
,
b
=
x
.
shape
...
@@ -197,6 +207,14 @@ def dyfunc_with_while_4(x):
...
@@ -197,6 +207,14 @@ def dyfunc_with_while_4(x):
return
x
return
x
def
dyfunc_change_shape_after_assign
(
x
):
x
=
paddle
.
to_tensor
(
x
)
a
,
b
=
x
.
shape
x
=
paddle
.
reshape
(
x
,
shape
=
(
-
1
,
1
))
res
=
paddle
.
reshape
(
x
,
shape
=
(
b
,
a
))
return
res
# 1. Basic tests without control flow
# 1. Basic tests without control flow
class
TestTensorShapeBasic
(
unittest
.
TestCase
):
class
TestTensorShapeBasic
(
unittest
.
TestCase
):
def
setUp
(
self
):
def
setUp
(
self
):
...
@@ -279,6 +297,21 @@ class TestTensorShapeBasic5(TestTensorShapeBasic):
...
@@ -279,6 +297,21 @@ class TestTensorShapeBasic5(TestTensorShapeBasic):
def
init_test_func
(
self
):
def
init_test_func
(
self
):
self
.
dygraph_func
=
dyfunc_tensor_shape_5
self
.
dygraph_func
=
dyfunc_tensor_shape_5
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
4
self
.
expected_shape_op_num
=
1
self
.
expected_slice_op_num
=
1
class
TestTensorShapeBasic6
(
TestTensorShapeBasic
):
def
init_test_func
(
self
):
self
.
dygraph_func
=
dyfunc_tensor_shape_6
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
4
self
.
expected_shape_op_num
=
1
self
.
expected_slice_op_num
=
1
class
TestTupleShape1
(
TestTensorShapeBasic
):
class
TestTupleShape1
(
TestTensorShapeBasic
):
def
init_test_func
(
self
):
def
init_test_func
(
self
):
...
@@ -312,9 +345,9 @@ class TestTensorShapeInIf1(TestTensorShapeBasic):
...
@@ -312,9 +345,9 @@ class TestTensorShapeInIf1(TestTensorShapeBasic):
self
.
dygraph_func
=
dyfunc_with_if_1
self
.
dygraph_func
=
dyfunc_with_if_1
def
_set_expected_op_num
(
self
):
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
26
self
.
expected_op_num
=
4
self
.
expected_shape_op_num
=
2
self
.
expected_shape_op_num
=
1
self
.
expected_slice_op_num
=
2
self
.
expected_slice_op_num
=
1
class
TestTensorShapeInIf2
(
TestTensorShapeBasic
):
class
TestTensorShapeInIf2
(
TestTensorShapeBasic
):
...
@@ -342,6 +375,11 @@ class TestTensorShapeInFor2(TestTensorShapeInFor1):
...
@@ -342,6 +375,11 @@ class TestTensorShapeInFor2(TestTensorShapeInFor1):
def
init_test_func
(
self
):
def
init_test_func
(
self
):
self
.
dygraph_func
=
dyfunc_with_for_2
self
.
dygraph_func
=
dyfunc_with_for_2
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
9
self
.
expected_shape_op_num
=
1
self
.
expected_slice_op_num
=
1
# 4. Tests with control flow while loop
# 4. Tests with control flow while loop
class
TestTensorShapeInWhile1
(
TestTensorShapeInFor1
):
class
TestTensorShapeInWhile1
(
TestTensorShapeInFor1
):
...
@@ -353,15 +391,20 @@ class TestTensorShapeInWhile2(TestTensorShapeInFor1):
...
@@ -353,15 +391,20 @@ class TestTensorShapeInWhile2(TestTensorShapeInFor1):
def
init_test_func
(
self
):
def
init_test_func
(
self
):
self
.
dygraph_func
=
dyfunc_with_while_2
self
.
dygraph_func
=
dyfunc_with_while_2
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
6
self
.
expected_shape_op_num
=
1
self
.
expected_slice_op_num
=
1
class
TestTensorShapeInWhile3
(
TestTensorShapeBasic
):
class
TestTensorShapeInWhile3
(
TestTensorShapeBasic
):
def
init_test_func
(
self
):
def
init_test_func
(
self
):
self
.
dygraph_func
=
dyfunc_with_while_3
self
.
dygraph_func
=
dyfunc_with_while_3
def
_set_expected_op_num
(
self
):
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
2
5
self
.
expected_op_num
=
2
self
.
expected_shape_op_num
=
6
self
.
expected_shape_op_num
=
0
self
.
expected_slice_op_num
=
3
self
.
expected_slice_op_num
=
0
class
TestTensorShapeInWhile4
(
TestTensorShapeBasic
):
class
TestTensorShapeInWhile4
(
TestTensorShapeBasic
):
...
@@ -431,9 +474,9 @@ class TestOpNumWithTensorShapeTuple1(TestOpNumBasicWithTensorShape):
...
@@ -431,9 +474,9 @@ class TestOpNumWithTensorShapeTuple1(TestOpNumBasicWithTensorShape):
self
.
dygraph_func
=
dyfunc_tuple_shape_1
self
.
dygraph_func
=
dyfunc_tuple_shape_1
def
_set_expected_op_num
(
self
):
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
5
self
.
expected_op_num
=
2
self
.
expected_shape_op_num
=
1
self
.
expected_shape_op_num
=
0
self
.
expected_slice_op_num
=
1
self
.
expected_slice_op_num
=
0
class
TestOpNumWithTensorShapeInIf1
(
TestOpNumBasicWithTensorShape
):
class
TestOpNumWithTensorShapeInIf1
(
TestOpNumBasicWithTensorShape
):
...
@@ -441,7 +484,7 @@ class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape):
...
@@ -441,7 +484,7 @@ class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape):
self
.
dygraph_func
=
dyfunc_with_if_1
self
.
dygraph_func
=
dyfunc_with_if_1
def
_set_expected_op_num
(
self
):
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
28
self
.
expected_op_num
=
19
self
.
expected_shape_op_num
=
4
self
.
expected_shape_op_num
=
4
self
.
expected_slice_op_num
=
2
self
.
expected_slice_op_num
=
2
...
@@ -466,5 +509,17 @@ class TestOpNumWithTensorShapeInWhile1(TestOpNumBasicWithTensorShape):
...
@@ -466,5 +509,17 @@ class TestOpNumWithTensorShapeInWhile1(TestOpNumBasicWithTensorShape):
self
.
expected_slice_op_num
=
3
self
.
expected_slice_op_num
=
3
class
TestChangeShapeAfterAssign
(
TestTensorShapeBasic
):
def
init_test_func
(
self
):
self
.
input
=
numpy
.
ones
((
2
,
3
)).
astype
(
"int32"
)
self
.
input_spec
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
2
,
3
],
dtype
=
"int32"
)]
self
.
dygraph_func
=
dyfunc_change_shape_after_assign
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
3
self
.
expected_shape_op_num
=
0
self
.
expected_slice_op_num
=
0
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
python/paddle/jit/dy2static/convert_operators.py
浏览文件 @
3a72408f
...
@@ -25,11 +25,15 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_print
...
@@ -25,11 +25,15 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_print
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
convert_shape_compare
#DEFINE_ALIAS
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
convert_shape_compare
#DEFINE_ALIAS
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
convert_var_dtype
#DEFINE_ALIAS
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
convert_var_dtype
#DEFINE_ALIAS
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
convert_var_shape
#DEFINE_ALIAS
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
convert_var_shape
#DEFINE_ALIAS
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
convert_var_shape_simple
#DEFINE_ALIAS
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
eval_if_exist_else_none
#DEFINE_ALIAS
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
choose_shape_attr_or_api
#DEFINE_ALIAS
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
convert_while_loop
#DEFINE_ALIAS
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
convert_while_loop
#DEFINE_ALIAS
__all__
=
[
__all__
=
[
'cast_bool_if_necessary'
,
'convert_assert'
,
'convert_ifelse'
,
'convert_len'
,
'cast_bool_if_necessary'
,
'convert_assert'
,
'convert_ifelse'
,
'convert_len'
,
'convert_logical_and'
,
'convert_logical_not'
,
'convert_logical_or'
,
'convert_logical_and'
,
'convert_logical_not'
,
'convert_logical_or'
,
'convert_pop'
,
'convert_print'
,
'convert_shape_compare'
,
'convert_pop'
,
'convert_print'
,
'convert_shape_compare'
,
'convert_var_dtype'
,
'convert_var_shape'
,
'convert_while_loop'
'convert_var_dtype'
,
'convert_var_shape'
,
'convert_var_shape_simple'
,
'eval_if_exist_else_none'
,
'choose_shape_attr_or_api'
,
'convert_while_loop'
]
]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录