Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
d82d5b8c
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
d82d5b8c
编写于
6月 27, 2022
作者:
A
Aurelius84
提交者:
GitHub
6月 27, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Dy2Stat]Refactor convert_shape transformer logic (#43846)
* [Dy2Stat]Refactor convert_shape transformer logic * clean usless unittest
上级
a5dc0a79
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
73 addition
and
610 deletion
+73
-610
python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
...ddle/fluid/dygraph/dygraph_to_static/convert_operators.py
+16
-74
python/paddle/fluid/dygraph/dygraph_to_static/logical_transformer.py
...le/fluid/dygraph/dygraph_to_static/logical_transformer.py
+0
-22
python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py
...uid/dygraph/dygraph_to_static/tensor_shape_transformer.py
+10
-354
python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py
...sts/unittests/dygraph_to_static/test_convert_operators.py
+0
-103
python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py
...id/tests/unittests/dygraph_to_static/test_tensor_shape.py
+45
-49
python/paddle/jit/dy2static/__init__.py
python/paddle/jit/dy2static/__init__.py
+1
-4
python/paddle/jit/dy2static/convert_operators.py
python/paddle/jit/dy2static/convert_operators.py
+1
-4
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
浏览文件 @
d82d5b8c
...
...
@@ -338,88 +338,30 @@ def convert_zip(*args):
return
zip
(
*
args
)
def
convert_
var_shape
(
x
,
idx
=
None
,
in_control_flow
=
False
):
def
convert_
shape
(
x
):
"""
A function representation of the shape of variable.
"""
def
has_negative
(
list_shape
,
idx
=
None
):
if
idx
is
not
None
:
return
list_shape
[
idx
]
<
0
num_negative
=
sum
([
1
if
i
<
0
else
0
for
i
in
list_shape
])
return
num_negative
>
0
# When `x` is Variable, call nn.shape(x) in following cases:
# (1) The shape of `x` is used in control flow condition.
# ```
# if x.shape[0] == 1:
# y = XX
# ```
# (2) The dim to be used is negative
# ```
# # Assume x.shape=[3, -1] in static mode
# y = paddle.reshape(x, shape=[1, x.shape[1]])
# ```
if
isinstance
(
x
,
Variable
)
and
has_negative
(
x
.
shape
,
idx
):
return
nn
.
shape
(
x
)
if
idx
is
None
else
nn
.
shape
(
x
)[
idx
]
else
:
return
list
(
x
.
shape
)
if
idx
is
None
else
x
.
shape
[
idx
]
def
has_negative
(
list_shape
):
return
any
([
x
<
0
for
x
in
list_shape
])
# When `x` is Variable:
# (1) if x.shape contains -1, such as [2, -1, 64], returns [2, var, 64],
# where var = paddle.shape(x)[1]
# (2) if x.shape does not contains -1, return lsit(x.shape) directly
def
convert_var_shape_simple
(
x
):
"""
A function representation of the shape of variable.
"""
if
isinstance
(
x
,
Variable
):
return
nn
.
shape
(
x
)
values
=
list
(
x
.
shape
)
if
has_negative
(
values
):
shape_tensor
=
nn
.
shape
(
x
)
for
i
,
v
in
enumerate
(
values
):
if
v
is
None
or
v
<
0
:
values
[
i
]
=
shape_tensor
[
i
]
return
values
else
:
# Use list() to make returned type consistant with dygraph
return
list
(
x
.
shape
)
def
eval_if_exist_else_none
(
name
,
global_symbol_table
):
"""
Args:
name([str]): Expression passed into `eval`.
local_symbol_table(dict): Specified from `globals()`. DO NOT use `locals()`,
because all STATIC_CONVERT_VAR_SHAPE_SUFFIX vars is
declared with keyword `global`.
Returns:
Return the variable if found in global_symbol_table else None.
"""
try
:
return
eval
(
name
,
global_symbol_table
)
except
:
return
None
def
choose_shape_attr_or_api
(
attr_shape
,
api_shape
,
idx
=
None
):
"""
Input can be attribute `x.shape` or api `shape(x)`, this function
chooses which one to return to use in dy2stat.
Note: sometimes users write `x.shape[3]`, so attr_shape can be an integer.
"""
if
api_shape
is
None
:
return
attr_shape
if
idx
is
None
else
attr_shape
[
idx
]
if
not
isinstance
(
attr_shape
,
(
list
,
tuple
)):
# some variables like x.shape[0] is no longer a list or tuple
if
isinstance
(
attr_shape
,
int
)
and
attr_shape
<
0
:
return
api_shape
if
idx
is
None
else
api_shape
[
idx
]
return
attr_shape
if
idx
is
None
else
attr_shape
[
idx
]
def
has_negative
(
list_shape
,
idx
=
None
):
if
idx
is
not
None
:
return
list_shape
[
idx
]
<
0
num_negative
=
sum
([
1
if
i
<
0
else
0
for
i
in
list_shape
])
return
num_negative
>
0
if
has_negative
(
attr_shape
,
idx
):
return
api_shape
if
idx
is
None
else
api_shape
[
idx
]
return
attr_shape
if
idx
is
None
else
attr_shape
[
idx
]
return
x
.
shape
def
convert_shape_compare
(
left
,
*
args
):
...
...
python/paddle/fluid/dygraph/dygraph_to_static/logical_transformer.py
浏览文件 @
d82d5b8c
...
...
@@ -63,28 +63,6 @@ class LogicalTransformer(gast.NodeTransformer):
return
new_node
return
node
def
visit_Compare
(
self
,
node
):
self
.
generic_visit
(
node
)
left_str
=
ast_to_source_code
(
node
.
left
).
strip
()
if
left_str
.
startswith
(
"_jst.convert_var_shape"
):
# check left and comparators are all converted var shape
compare_arg_strs
=
left_str
for
i
,
comparator
in
enumerate
(
node
.
comparators
):
comparator_str
=
ast_to_source_code
(
comparator
).
strip
()
if
not
comparator_str
.
startswith
(
"_jst.convert_var_shape"
):
return
node
op_str
=
cmpop_node_to_str
(
node
.
ops
[
i
])
compare_arg_strs
+=
(
", '"
+
op_str
+
"', "
+
comparator_str
)
# Now all left and comparators are converted shape
# Replace some comparsion operation because of difference between
# Python and Paddle
new_node_str
=
"_jst.convert_shape_compare({})"
.
format
(
compare_arg_strs
)
new_node
=
gast
.
parse
(
new_node_str
).
body
[
0
].
value
return
new_node
return
node
def
visit_BoolOp
(
self
,
node
):
self
.
generic_visit
(
node
)
if
isinstance
(
node
.
op
,
gast
.
And
):
...
...
python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py
浏览文件 @
d82d5b8c
...
...
@@ -25,77 +25,11 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import SplitAssignTransformer
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
AstNodeWrapper
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
StaticAnalysisVisitor
STATIC_CONVERT_VAR_SHAPE_SUFFIX
=
'__static_convert_var_shape_suffix'
def
create_convert_shape_node
(
var_shape_node
,
slice_node
=
None
,
in_control_flow
=
False
):
assert
isinstance
(
var_shape_node
,
(
gast
.
Attribute
,
gast
.
Subscript
))
if
isinstance
(
var_shape_node
,
gast
.
Attribute
):
args
=
[
ast_to_source_code
(
var_shape_node
.
value
).
strip
()]
# (1) A slice can be a simple number such as 1, -2, i.e. gast.Index or gast.Constant
# (2) A slice can also be represented by bounds such as 2:-1, i.e. not gast.Index or gast.Constant
# In (1) case, we pass the number as 'idx' argument in convert_var_shape
# In (2) case, we have to make it like `convert_var_shape(x)[slice]`
if
slice_node
is
not
None
and
slice_is_num
(
slice_node
):
args
.
append
(
ast_to_source_code
(
slice_node
.
slice
).
strip
())
convert_var_shape_func
=
"_jst.convert_var_shape({}, in_control_flow={})"
.
format
(
","
.
join
(
args
),
in_control_flow
)
api_shape_node
=
gast
.
parse
(
convert_var_shape_func
).
body
[
0
].
value
if
slice_node
is
not
None
and
not
slice_is_num
(
slice_node
):
return
gast
.
Subscript
(
value
=
api_shape_node
,
slice
=
slice_node
.
slice
,
ctx
=
gast
.
Load
())
return
api_shape_node
if
isinstance
(
var_shape_node
,
gast
.
Subscript
):
result_node
=
copy
.
deepcopy
(
var_shape_node
)
result_node
=
create_convert_shape_node
(
result_node
.
value
,
result_node
,
in_control_flow
)
return
result_node
def
create_choose_shape_node
(
attr_shape_name
,
api_shape_name
,
slice_node
=
None
):
eval_exist_func
=
"_jst.eval_if_exist_else_none('{}', globals())"
.
format
(
api_shape_name
)
args
=
[
attr_shape_name
,
eval_exist_func
]
if
slice_node
is
not
None
and
slice_is_num
(
slice_node
):
args
.
append
(
ast_to_source_code
(
slice_node
.
slice
).
strip
())
choose_shape_func
=
"_jst.choose_shape_attr_or_api({})"
.
format
(
","
.
join
(
args
))
choose_shape_node
=
gast
.
parse
(
choose_shape_func
).
body
[
0
].
value
if
slice_node
is
not
None
and
not
slice_is_num
(
slice_node
):
return
gast
.
Subscript
(
value
=
choose_shape_node
,
slice
=
slice_node
.
slice
,
ctx
=
gast
.
Load
())
return
choose_shape_node
class
ShapeAttributeTransformer
(
gast
.
NodeTransformer
):
"""
Input a node like `x.shape` or `x[4].shape[0]` (self._is_var_shape(node) is True),
return a new node changes input to static shape API like `convert_var_shape(x)`,
`convert_var_shape(x[4])[0]`.
"""
def
visit_Attribute
(
self
,
node
):
if
node
.
attr
==
'shape'
:
args
=
ast_to_source_code
(
node
.
value
).
strip
()
convert_var_shape_func
=
"_jst.convert_var_shape_simple({})"
.
format
(
args
)
api_shape_node
=
gast
.
parse
(
convert_var_shape_func
).
body
[
0
].
value
return
api_shape_node
return
node
class
TensorShapeTransformer
(
gast
.
NodeTransformer
):
"""
This class transforms variable.shape used in Paddle Apis or control flow conditions into Static Graph Ast.
This class transforms variable.shape into Static Graph Ast.
All 'xxx.shape' will be converted int '_jst.convert_shape(x)'.
"""
def
__init__
(
self
,
wrapper_root
):
...
...
@@ -104,295 +38,17 @@ class TensorShapeTransformer(gast.NodeTransformer):
),
"Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
self
.
wrapper_root
=
wrapper_root
self
.
root
=
wrapper_root
.
node
# stores origin var string name (like "x" in `x = t.shape`) to
# static shape var string name (like "x_SUFFIX" in `x_SUFFIX = shape(t)`)
self
.
name_to_var_shape
=
{}
self
.
static_analysis_visitor
=
StaticAnalysisVisitor
(
self
.
root
)
self
.
node_to_wrapper_map
=
self
.
static_analysis_visitor
.
get_node_to_wrapper_map
(
)
var_env
=
self
.
static_analysis_visitor
.
get_var_env
()
var_env
.
cur_scope
=
var_env
.
cur_scope
.
sub_scopes
[
0
]
self
.
scope_var_type_dict
=
var_env
.
get_scope_var_type
()
def
transform
(
self
):
SplitAssignTransformer
(
self
.
root
).
transform
()
self
.
visit
(
self
.
root
)
def
visit_Assign
(
self
,
node
):
update_static_shape_var_node
=
self
.
_update_name_to_var_shape
(
node
)
if
update_static_shape_var_node
is
not
None
:
ret
=
[
node
]
ret
.
extend
(
update_static_shape_var_node
)
return
ret
self
.
generic_visit
(
node
)
return
node
def
visit_Subscript
(
self
,
node
):
value_node
=
node
.
value
slice_node
=
node
.
slice
if
isinstance
(
value_node
,
gast
.
Name
):
if
value_node
.
id
in
self
.
name_to_var_shape
and
self
.
_used_by_paddle_api
(
value_node
):
return
create_choose_shape_node
(
value_node
.
id
,
self
.
name_to_var_shape
[
value_node
.
id
],
node
)
elif
isinstance
(
value_node
,
gast
.
Attribute
):
if
self
.
_used_by_paddle_api
(
value_node
):
value_name
=
ast_to_source_code
(
value_node
).
strip
()
if
value_name
in
self
.
name_to_var_shape
:
return
create_choose_shape_node
(
value_name
,
self
.
name_to_var_shape
[
value_name
],
node
)
if
self
.
_is_var_shape
(
value_node
):
return
create_convert_shape_node
(
value_node
,
node
)
return
node
def
visit_Attribute
(
self
,
node
):
if
self
.
_used_by_paddle_api
(
node
):
name
=
ast_to_source_code
(
node
).
strip
()
if
name
in
self
.
name_to_var_shape
:
return
create_choose_shape_node
(
name
,
self
.
name_to_var_shape
[
name
])
if
self
.
_is_var_shape
(
node
):
return
create_convert_shape_node
(
node
)
return
node
def
visit_Name
(
self
,
node
):
if
node
.
id
in
self
.
name_to_var_shape
:
if
self
.
_used_by_paddle_api
(
node
):
return
create_choose_shape_node
(
node
.
id
,
self
.
name_to_var_shape
[
node
.
id
])
return
node
def
visit_Call
(
self
,
node
):
if
is_paddle_api
(
node
):
# Visit gast.Attribute and gast.Name to replace var.shape if necessary.
self
.
generic_visit
(
node
)
# Don't have to visit other APIs
return
node
def
visit_If
(
self
,
node
):
# Call generic_visit first to transform var.shape that is used in Paddle Api.
self
.
generic_visit
(
node
)
cond
=
node
.
test
self
.
_transform_var_shape_if_necessary
(
cond
)
return
node
def
visit_While
(
self
,
node
):
self
.
generic_visit
(
node
)
cond
=
node
.
test
self
.
_transform_var_shape_if_necessary
(
cond
)
return
node
def
visit_For
(
self
,
node
):
self
.
generic_visit
(
node
)
iter
=
node
.
iter
self
.
_transform_var_shape_if_necessary
(
iter
)
# If var.shape is a gast.Name and it is used in range function, transform it
self
.
_transform_var_shape_in_range
(
node
)
if
node
.
attr
==
'shape'
:
args
=
ast_to_source_code
(
node
.
value
).
strip
()
# NOTE(dev): we can deal with paddle.shape in this case, but it's
# not pretty to modify into 'convert_shape(paddle)(x)[0]'.
if
args
!=
'paddle'
:
convert_shape_func
=
"_jst.convert_shape({})"
.
format
(
args
)
shape_node
=
gast
.
parse
(
convert_shape_func
).
body
[
0
].
value
return
shape_node
return
node
def
_transform_var_shape_in_range
(
self
,
node
):
assert
isinstance
(
node
,
gast
.
For
)
if
not
isinstance
(
node
.
iter
,
gast
.
Call
):
return
False
if
not
isinstance
(
node
.
iter
.
func
,
gast
.
Name
):
return
False
if
node
.
iter
.
func
.
id
!=
"range"
:
return
False
args
=
node
.
iter
.
args
for
idx
,
arg
in
enumerate
(
args
):
if
isinstance
(
arg
,
gast
.
Name
)
and
arg
.
id
in
self
.
name_to_var_shape
:
args
[
idx
]
=
create_choose_shape_node
(
arg
.
id
,
self
.
name_to_var_shape
[
arg
.
id
])
return
True
def
_transform_var_shape_if_necessary
(
self
,
cond
):
need_transformed
=
False
for
child_node
in
gast
.
walk
(
cond
):
var_shape_node
=
None
if
isinstance
(
child_node
,
(
gast
.
Name
,
gast
.
Attribute
,
gast
.
Subscript
)):
child_name
=
ast_to_source_code
(
child_node
).
strip
()
if
child_name
in
self
.
name_to_var_shape
:
var_shape_node
=
create_choose_shape_node
(
child_name
,
self
.
name_to_var_shape
[
child_name
])
elif
self
.
_is_var_shape
(
child_node
):
var_shape_node
=
child_node
if
var_shape_node
:
need_transformed
=
True
wrapper_node
=
self
.
node_to_wrapper_map
.
get
(
child_node
)
parent_node
=
wrapper_node
.
parent
.
node
for
field
,
value
in
gast
.
iter_fields
(
parent_node
):
if
child_node
is
value
:
if
var_shape_node
is
child_node
:
setattr
(
parent_node
,
field
,
create_convert_shape_node
(
var_shape_node
,
None
,
True
))
else
:
setattr
(
parent_node
,
field
,
var_shape_node
)
break
# Some child_node may be in a list such as gast.Compare
if
isinstance
(
value
,
list
):
has_converted_shape
=
False
for
i
,
v
in
enumerate
(
value
):
if
child_node
is
v
:
if
var_shape_node
is
child_node
:
value
[
i
]
=
create_convert_shape_node
(
var_shape_node
,
None
,
True
)
else
:
value
[
i
]
=
var_shape_node
has_converted_shape
=
True
break
if
has_converted_shape
:
break
return
need_transformed
def
_used_by_paddle_api
(
self
,
node
):
"""
Whether node is used in paddle api as arguments.
For example:
1) Return True in `paddle.relu(x)` where node is `x` (gast.Name)
2) Return True in `paddle.add(self.x)` where node is `self.x` (gast.Attribute)
3) Return False in `paddle.add(self.x)` where node is `paddle.add` (gast.Attribute),
because the role of node is not arguments but `gast.Call.func`.
"""
assert
isinstance
(
node
,
(
gast
.
Attribute
,
gast
.
Name
))
wrapper_node
=
self
.
node_to_wrapper_map
.
get
(
node
)
if
not
wrapper_node
:
# Transformed node is not in node_to_wrapper_map
return
False
while
wrapper_node
.
parent
:
parent_node
=
wrapper_node
.
parent
.
node
if
isinstance
(
parent_node
,
gast
.
Call
):
# Note(Aurelius84): Filter the case when the role of node is `gast.Call.func`.
if
is_paddle_api
(
parent_node
)
and
parent_node
.
func
!=
node
:
return
True
else
:
return
False
wrapper_node
=
wrapper_node
.
parent
return
False
def
_is_var_shape
(
self
,
node
):
"""
Return True if node is like `x.shape` or `x.shape[0]`, return False otherwise.
"""
if
not
isinstance
(
node
,
(
gast
.
Attribute
,
gast
.
Subscript
)):
return
False
if
isinstance
(
node
,
gast
.
Attribute
):
# If node is `paddle.shape`, return False
if
(
node
.
attr
==
'shape'
and
isinstance
(
node
.
value
,
gast
.
Name
)
and
node
.
value
.
id
==
'paddle'
):
return
False
if
node
.
attr
!=
'shape'
:
return
False
return
True
if
isinstance
(
node
,
gast
.
Subscript
):
value_node
=
node
.
value
return
self
.
_is_var_shape
(
value_node
)
return
False
def
_update_name_to_var_shape
(
self
,
node
):
assert
isinstance
(
node
,
gast
.
Assign
)
target_node
=
node
.
targets
[
0
]
value_node
=
node
.
value
update_static_shape_var_node
=
None
if
isinstance
(
target_node
,
gast
.
Tuple
):
update_static_shape_var_node
=
[]
for
idx
,
element
in
enumerate
(
target_node
.
elts
):
target_id
=
ast_to_source_code
(
element
).
strip
()
if
isinstance
(
value_node
,
gast
.
Name
):
if
value_node
.
id
in
self
.
name_to_var_shape
:
# TODO(zhhsplendid): is context a problem for the result node of gast.parse?
static_shape_var_name
=
unique_name
.
generate
(
STATIC_CONVERT_VAR_SHAPE_SUFFIX
)
static_shape_var_node
=
gast
.
parse
(
static_shape_var_name
).
body
[
0
].
value
static_shape_value_name
=
self
.
name_to_var_shape
[
value_node
.
id
]
sub_node_str
=
"{}[{}]"
.
format
(
static_shape_value_name
,
idx
)
sub_node
=
gast
.
parse
(
sub_node_str
).
body
[
0
].
value
update_static_shape_var_node
.
append
(
gast
.
Assign
(
targets
=
[
static_shape_var_node
],
value
=
sub_node
))
self
.
name_to_var_shape
[
target_id
]
=
static_shape_var_name
if
isinstance
(
value_node
,
gast
.
Attribute
):
if
self
.
_is_var_shape
(
value_node
):
# eg: x.shape
static_shape_var_name
=
unique_name
.
generate
(
STATIC_CONVERT_VAR_SHAPE_SUFFIX
)
static_shape_var_node
=
gast
.
parse
(
static_shape_var_name
).
body
[
0
].
value
static_shape_value_node
=
copy
.
deepcopy
(
value_node
)
# x.shape becomes convert_var_shape_simple(x)
static_shape_value_node
=
ShapeAttributeTransformer
(
).
visit
(
static_shape_value_node
)
sub_node_str
=
"{}[{}]"
.
format
(
ast_to_source_code
(
static_shape_value_node
).
strip
(),
idx
)
sub_node
=
gast
.
parse
(
sub_node_str
).
body
[
0
].
value
# Note(Aurelius84): Becuase static_shape_var_name is used in
# eval_if_exist_else_none() as plain string, so it will not
# be pasred as argument in convert_loop/ifelse. We delcare it
# as global var because it has unique name.
update_static_shape_var_node
.
append
(
gast
.
Global
(
names
=
[
static_shape_var_name
]))
update_static_shape_var_node
.
append
(
gast
.
Assign
(
targets
=
[
static_shape_var_node
],
value
=
sub_node
))
self
.
name_to_var_shape
[
target_id
]
=
static_shape_var_name
return
update_static_shape_var_node
else
:
target_id
=
ast_to_source_code
(
target_node
).
strip
()
if
isinstance
(
value_node
,
gast
.
Name
):
if
value_node
.
id
in
self
.
name_to_var_shape
:
static_shape_var_name
=
unique_name
.
generate
(
STATIC_CONVERT_VAR_SHAPE_SUFFIX
)
static_shape_var_node
=
gast
.
parse
(
static_shape_var_name
).
body
[
0
].
value
static_shape_value_name
=
self
.
name_to_var_shape
[
value_node
.
id
]
static_shape_value_node
=
gast
.
parse
(
static_shape_value_name
).
body
[
0
].
value
update_static_shape_var_node
=
[
gast
.
Assign
(
targets
=
[
static_shape_var_node
],
value
=
static_shape_value_node
)
]
self
.
name_to_var_shape
[
target_id
]
=
static_shape_var_name
elif
self
.
_is_var_shape
(
value_node
):
# eg: x.shape or x.shape[0]
static_shape_var_name
=
unique_name
.
generate
(
STATIC_CONVERT_VAR_SHAPE_SUFFIX
)
static_shape_var_node
=
gast
.
parse
(
static_shape_var_name
).
body
[
0
].
value
static_shape_value_node
=
copy
.
deepcopy
(
value_node
)
# x.shape becomes convert_var_shape_simple(x)
static_shape_value_node
=
ShapeAttributeTransformer
().
visit
(
static_shape_value_node
)
# Declare static_shape_var_name as global var
update_static_shape_var_node
=
[
gast
.
Global
(
names
=
[
static_shape_var_name
])
]
update_static_shape_var_node
.
append
(
gast
.
Assign
(
targets
=
[
static_shape_var_node
],
value
=
static_shape_value_node
))
self
.
name_to_var_shape
[
target_id
]
=
static_shape_var_name
return
update_static_shape_var_node
python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py
浏览文件 @
d82d5b8c
...
...
@@ -15,7 +15,6 @@
import
numpy
as
np
import
paddle
import
unittest
from
paddle.jit.dy2static.convert_operators
import
eval_if_exist_else_none
class
CallNotExist
(
paddle
.
nn
.
Layer
):
...
...
@@ -143,108 +142,6 @@ class TestConvertShapeCompare(unittest.TestCase):
paddle
.
disable_static
()
class
TestChooseShapeAttrOrApi
(
unittest
.
TestCase
):
def
test_api_shape_is_none
(
self
):
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
([
1
,
2
],
None
),
[
1
,
2
])
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
([
1
],
None
),
[
1
])
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
([
2
,
3
,
7
],
None
,
0
),
2
)
def
test_attr_shape_is_int
(
self
):
x
=
paddle
.
zeros
([
1
,
3
,
5
,
7
])
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
(
x
.
shape
[
0
],
paddle
.
shape
(
x
)[
0
]),
1
)
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
(
x
.
shape
[
1
],
paddle
.
shape
(
x
)[
1
]),
3
)
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
(
-
1
,
paddle
.
shape
(
x
)[
0
]),
paddle
.
shape
(
x
)[
0
])
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
(
-
1
,
paddle
.
shape
(
x
),
0
),
paddle
.
shape
(
x
)[
0
])
def
test_positive_attr_shape
(
self
):
x
=
paddle
.
zeros
([
1
,
3
,
5
,
7
])
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
(
x
.
shape
,
paddle
.
shape
(
x
)),
x
.
shape
)
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
(
x
.
shape
,
paddle
.
shape
(
x
),
3
),
x
.
shape
[
3
])
def
test_negative_attr_shape
(
self
):
x
=
paddle
.
zeros
([
7
])
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
([
-
1
],
paddle
.
shape
(
x
),
0
),
paddle
.
shape
(
x
)[
0
])
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
([
-
1
],
paddle
.
shape
(
x
)),
paddle
.
shape
(
x
))
class
TestEvaIfExistElseNone
(
unittest
.
TestCase
):
def
test_globals
(
self
):
global
x_shape
x_shape
=
[
1
,
2
,
3
]
self
.
assertEqual
(
eval_if_exist_else_none
(
'x_shape'
,
locals
()),
None
)
self
.
assertEqual
(
eval_if_exist_else_none
(
'x_shape'
,
globals
()),
x_shape
)
del
x_shape
def
test_enclosing_scope
(
self
):
global
x_shape
x_shape
=
[
1
,
2
,
3
]
def
foo
():
y_shape
=
[
2
,
3
,
4
]
self
.
assertEqual
(
eval_if_exist_else_none
(
'x_shape'
,
globals
()),
[
1
,
2
,
3
])
self
.
assertEqual
(
eval_if_exist_else_none
(
'y_shape'
,
locals
()),
[
2
,
3
,
4
])
foo
()
del
x_shape
def
test_global_in_func
(
self
):
x_shape
=
[
1
,
2
,
3
]
def
foo
():
global
y_shape
y_shape
=
[
2
,
3
,
4
]
self
.
assertEqual
(
eval_if_exist_else_none
(
'y_shape'
,
globals
()),
[
2
,
3
,
4
])
self
.
assertEqual
(
eval_if_exist_else_none
(
'x_shape'
,
locals
()),
None
)
self
.
assertEqual
(
eval_if_exist_else_none
(
'x_shape'
,
globals
()),
None
)
del
y_shape
foo
()
def
test_none
(
self
):
def
foo
():
x_shape
=
[
2
,
3
,
4
]
return
x_shape
self
.
assertEqual
(
eval_if_exist_else_none
(
'x_shape'
,
locals
()),
None
)
class
ShapeLayer
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
):
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py
浏览文件 @
d82d5b8c
...
...
@@ -275,6 +275,7 @@ class TestTensorShapeBasic(unittest.TestCase):
self
.
expected_slice_op_num
=
0
def
_compute_op_num
(
self
,
program
):
print
(
program
)
self
.
op_num
=
sum
([
len
(
block
.
ops
)
for
block
in
program
.
blocks
])
self
.
shape_op_num
=
0
self
.
slice_op_num
=
0
...
...
@@ -300,8 +301,8 @@ class TestTensorShapeBasic2(TestTensorShapeBasic):
self
.
dygraph_func
=
dyfunc_tensor_shape_2
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
3
self
.
expected_shape_op_num
=
1
self
.
expected_op_num
=
2
self
.
expected_shape_op_num
=
0
self
.
expected_slice_op_num
=
0
...
...
@@ -323,9 +324,9 @@ class TestTensorShapeBasic5(TestTensorShapeBasic):
self
.
dygraph_func
=
dyfunc_tensor_shape_5
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
4
self
.
expected_shape_op_num
=
1
self
.
expected_slice_op_num
=
1
self
.
expected_op_num
=
2
self
.
expected_shape_op_num
=
0
self
.
expected_slice_op_num
=
0
class
TestTensorShapeBasic6
(
TestTensorShapeBasic
):
...
...
@@ -334,21 +335,23 @@ class TestTensorShapeBasic6(TestTensorShapeBasic):
self
.
dygraph_func
=
dyfunc_tensor_shape_6
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
4
self
.
expected_shape_op_num
=
1
self
.
expected_slice_op_num
=
1
self
.
expected_op_num
=
2
self
.
expected_shape_op_num
=
0
self
.
expected_slice_op_num
=
0
class
TestTupleShape1
(
TestTensorShapeBasic
):
def
init_test_func
(
self
):
self
.
input
=
numpy
.
ones
((
5
,
7
)).
astype
(
"int32"
)
self
.
input_spec
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
5
,
7
],
dtype
=
"int32"
)]
self
.
input_spec
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
-
1
,
-
1
],
dtype
=
"int32"
)
]
self
.
dygraph_func
=
dyfunc_tuple_shape_1
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
6
self
.
expected_shape_op_num
=
2
self
.
expected_op_num
=
5
self
.
expected_shape_op_num
=
1
self
.
expected_slice_op_num
=
2
...
...
@@ -356,13 +359,15 @@ class TestTupleShape2(TestTensorShapeBasic):
def
init_test_func
(
self
):
self
.
input
=
numpy
.
ones
((
5
,
7
)).
astype
(
"int32"
)
self
.
input_spec
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
5
,
7
],
dtype
=
"int32"
)]
self
.
input_spec
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
-
1
,
7
],
dtype
=
"int32"
)
]
self
.
dygraph_func
=
dyfunc_tuple_shape_2
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
5
self
.
expected_shape_op_num
=
1
self
.
expected_slice_op_num
=
2
self
.
expected_slice_op_num
=
1
class
TestTupleShape3
(
TestTensorShapeBasic
):
...
...
@@ -398,9 +403,9 @@ class TestTensorShapeInIf1(TestTensorShapeBasic):
self
.
dygraph_func
=
dyfunc_with_if_1
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
4
self
.
expected_shape_op_num
=
1
self
.
expected_slice_op_num
=
1
self
.
expected_op_num
=
2
self
.
expected_shape_op_num
=
0
self
.
expected_slice_op_num
=
0
class
TestTensorShapeInIf2
(
TestTensorShapeBasic
):
...
...
@@ -432,9 +437,9 @@ class TestTensorShapeInFor2(TestTensorShapeInFor1):
self
.
dygraph_func
=
dyfunc_with_for_2
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
9
self
.
expected_shape_op_num
=
1
self
.
expected_slice_op_num
=
1
self
.
expected_op_num
=
7
self
.
expected_shape_op_num
=
0
self
.
expected_slice_op_num
=
0
class
TestTensorShapeInFor3
(
TestTensorShapeInFor1
):
...
...
@@ -466,9 +471,9 @@ class TestTensorShapeInWhile2(TestTensorShapeInFor1):
self
.
dygraph_func
=
dyfunc_with_while_2
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
6
self
.
expected_shape_op_num
=
1
self
.
expected_slice_op_num
=
1
self
.
expected_op_num
=
4
self
.
expected_shape_op_num
=
0
self
.
expected_slice_op_num
=
0
class
TestTensorShapeInWhile3
(
TestTensorShapeBasic
):
...
...
@@ -477,8 +482,8 @@ class TestTensorShapeInWhile3(TestTensorShapeBasic):
self
.
dygraph_func
=
dyfunc_with_while_3
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
3
self
.
expected_shape_op_num
=
1
self
.
expected_op_num
=
2
self
.
expected_shape_op_num
=
0
self
.
expected_slice_op_num
=
0
...
...
@@ -510,9 +515,9 @@ class TestOpNumBasicWithTensorShape(unittest.TestCase):
self
.
dygraph_func
=
dyfunc_tensor_shape_1
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
3
self
.
expected_op_num
=
5
self
.
expected_shape_op_num
=
1
self
.
expected_slice_op_num
=
0
self
.
expected_slice_op_num
=
1
def
_compute_op_num
(
self
,
program
):
self
.
op_num
=
sum
([
len
(
block
.
ops
)
for
block
in
program
.
blocks
])
...
...
@@ -541,9 +546,9 @@ class TestOpNumBasicWithTensorShape4(TestOpNumBasicWithTensorShape):
self
.
dygraph_func
=
dyfunc_tensor_shape_4
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
6
self
.
expected_shape_op_num
=
1
self
.
expected_slice_op_num
=
1
self
.
expected_op_num
=
8
self
.
expected_shape_op_num
=
2
self
.
expected_slice_op_num
=
2
class
TestOpNumWithTensorShapeTuple1
(
TestOpNumBasicWithTensorShape
):
...
...
@@ -552,9 +557,9 @@ class TestOpNumWithTensorShapeTuple1(TestOpNumBasicWithTensorShape):
self
.
dygraph_func
=
dyfunc_tuple_shape_1
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
7
self
.
expected_shape_op_num
=
2
self
.
expected_slice_op_num
=
2
self
.
expected_op_num
=
5
self
.
expected_shape_op_num
=
1
self
.
expected_slice_op_num
=
1
class
TestOpNumWithTensorShapeInIf1
(
TestOpNumBasicWithTensorShape
):
...
...
@@ -563,9 +568,9 @@ class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape):
self
.
dygraph_func
=
dyfunc_with_if_1
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
28
self
.
expected_op_num
=
32
self
.
expected_shape_op_num
=
4
self
.
expected_slice_op_num
=
2
self
.
expected_slice_op_num
=
4
class
TestOpNumWithTensorShapeInFor1
(
TestOpNumBasicWithTensorShape
):
...
...
@@ -594,13 +599,15 @@ class TestChangeShapeAfterAssign(TestTensorShapeBasic):
def
init_test_func
(
self
):
self
.
input
=
numpy
.
ones
((
2
,
3
)).
astype
(
"int32"
)
self
.
input_spec
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
2
,
3
],
dtype
=
"int32"
)]
self
.
input_spec
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
-
1
,
3
],
dtype
=
"int32"
)
]
self
.
dygraph_func
=
dyfunc_change_shape_after_assign
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
7
self
.
expected_shape_op_num
=
2
self
.
expected_slice_op_num
=
2
self
.
expected_op_num
=
6
self
.
expected_shape_op_num
=
1
self
.
expected_slice_op_num
=
1
def
dyfunc_with_static_convert_var_shape
(
x
):
...
...
@@ -627,16 +634,5 @@ class TestFindStatiConvertVarShapeSuffixVar(unittest.TestCase):
func
.
concrete_program
class
TestPaddleShape
(
unittest
.
TestCase
):
def
test_paddle_shape
(
self
):
func
=
paddle
.
jit
.
to_static
(
dyfunc_len_paddle_shape
)
func_code
=
func
.
code
.
replace
(
"
\n
"
,
""
).
replace
(
" "
,
""
)
self
.
assertEqual
(
'paddle.shape(x)'
in
func_code
,
True
)
func
=
paddle
.
jit
.
to_static
(
dyfunc_dict_assign_shape
)
func_code
=
func
.
code
.
replace
(
"
\n
"
,
""
).
replace
(
" "
,
""
)
self
.
assertEqual
(
"__static_convert_var_shape_suffix"
in
func_code
,
True
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/jit/dy2static/__init__.py
浏览文件 @
d82d5b8c
...
...
@@ -26,10 +26,7 @@ from .convert_operators import convert_pop # noqa: F401
from
.convert_operators
import
convert_print
# noqa: F401
from
.convert_operators
import
convert_shape_compare
# noqa: F401
from
.convert_operators
import
convert_var_dtype
# noqa: F401
from
.convert_operators
import
convert_var_shape
# noqa: F401
from
.convert_operators
import
convert_var_shape_simple
# noqa: F401
from
.convert_operators
import
eval_if_exist_else_none
# noqa: F401
from
.convert_operators
import
choose_shape_attr_or_api
# noqa: F401
from
.convert_operators
import
convert_shape
# noqa: F401
from
.convert_operators
import
convert_while_loop
# noqa: F401
from
.variable_trans_func
import
create_bool_as_type
# noqa: F401
from
.variable_trans_func
import
create_fill_constant_node
# noqa: F401
...
...
python/paddle/jit/dy2static/convert_operators.py
浏览文件 @
d82d5b8c
...
...
@@ -24,10 +24,7 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_pop #
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
convert_print
# noqa: F401
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
convert_shape_compare
# noqa: F401
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
convert_var_dtype
# noqa: F401
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
convert_var_shape
# noqa: F401
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
convert_var_shape_simple
# noqa: F401
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
eval_if_exist_else_none
# noqa: F401
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
choose_shape_attr_or_api
# noqa: F401
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
convert_shape
# noqa: F401
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
convert_while_loop
# noqa: F401
__all__
=
[]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录