Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
3a72408f
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3a72408f
编写于
2月 23, 2021
作者:
H
Huihuang Zheng
提交者:
GitHub
2月 23, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Cherry-pick][Dy2stat] Cherry-pick of PR31082 and PR31051 (#31101)
Cherry-pick of #31051 and #31082
上级
29467060
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
335 addition
and
79 deletion
+335
-79
python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
...ddle/fluid/dygraph/dygraph_to_static/convert_operators.py
+49
-5
python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py
...uid/dygraph/dygraph_to_static/tensor_shape_transformer.py
+163
-63
python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py
...sts/unittests/dygraph_to_static/test_convert_operators.py
+53
-0
python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py
...id/tests/unittests/dygraph_to_static/test_tensor_shape.py
+65
-10
python/paddle/jit/dy2static/convert_operators.py
python/paddle/jit/dy2static/convert_operators.py
+5
-1
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
浏览文件 @
3a72408f
...
...
@@ -267,12 +267,12 @@ def convert_var_shape(x, idx=None, in_control_flow=False):
A function representation of the shape of variable.
"""
def
has_neg
e
tive
(
list_shape
,
idx
=
None
):
def
has_neg
a
tive
(
list_shape
,
idx
=
None
):
if
idx
is
not
None
:
return
list_shape
[
idx
]
<
0
num_neg
e
tive
=
sum
([
1
if
i
<
0
else
0
for
i
in
list_shape
])
return
num_neg
e
tive
>
0
num_neg
a
tive
=
sum
([
1
if
i
<
0
else
0
for
i
in
list_shape
])
return
num_neg
a
tive
>
0
# When `x` is Variable, call nn.shape(x) in following cases:
# (1) The shape of `x` is used in control flow condition.
...
...
@@ -280,18 +280,62 @@ def convert_var_shape(x, idx=None, in_control_flow=False):
# if x.shape[0] == 1:
# y = XX
# ```
# (2) The dim to be used is neg
e
tive
# (2) The dim to be used is neg
a
tive
# ```
# # Assume x.shape=[3, -1] in static mode
# y = paddle.reshape(x, shape=[1, x.shape[1]])
# ```
if
isinstance
(
x
,
Variable
)
and
(
in_control_flow
or
has_neg
e
tive
(
x
.
shape
,
if
isinstance
(
x
,
Variable
)
and
(
in_control_flow
or
has_neg
a
tive
(
x
.
shape
,
idx
)):
return
nn
.
shape
(
x
)
if
idx
is
None
else
nn
.
shape
(
x
)[
idx
]
else
:
return
x
.
shape
if
idx
is
None
else
x
.
shape
[
idx
]
def
convert_var_shape_simple
(
x
):
"""
A function representation of the shape of variable.
"""
if
isinstance
(
x
,
Variable
):
return
nn
.
shape
(
x
)
else
:
return
x
.
shape
def
eval_if_exist_else_none
(
name
):
try
:
return
eval
(
name
)
except
:
return
None
def
choose_shape_attr_or_api
(
attr_shape
,
api_shape
,
idx
=
None
):
"""
Input can be attribute `x.shape` or api `shape(x)`, this function
chooses which one to return to use in dy2stat.
Note: sometimes users write `x.shape[3]`, so attr_shape can be an integer.
"""
if
api_shape
is
None
:
return
attr_shape
if
idx
is
None
else
attr_shape
[
idx
]
if
not
isinstance
(
attr_shape
,
(
list
,
tuple
)):
# some variables like x.shape[0] is no longer a list or tuple
if
isinstance
(
attr_shape
,
int
)
and
attr_shape
<
0
:
return
api_shape
if
idx
is
None
else
api_shape
[
idx
]
return
attr_shape
if
idx
is
None
else
attr_shape
[
idx
]
def
has_negative
(
list_shape
,
idx
=
None
):
if
idx
is
not
None
:
return
list_shape
[
idx
]
<
0
num_negative
=
sum
([
1
if
i
<
0
else
0
for
i
in
list_shape
])
return
num_negative
>
0
if
has_negative
(
attr_shape
,
idx
):
return
api_shape
if
idx
is
None
else
api_shape
[
idx
]
return
attr_shape
if
idx
is
None
else
attr_shape
[
idx
]
def
convert_shape_compare
(
left
,
*
args
):
"""
A function handles comparison difference between Paddle and Python.
...
...
python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py
浏览文件 @
3a72408f
...
...
@@ -17,12 +17,15 @@ from __future__ import print_function
import
copy
import
gast
from
paddle.fluid
import
unique_name
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ast_to_source_code
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
is_paddle_api
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
SplitAssignTransformer
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
AstNodeWrapper
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
StaticAnalysisVisitor
STATIC_CONVERT_VAR_SHAPE_SUFFIX
=
'__static_convert_var_shape_suffix'
def
create_convert_shape_node
(
var_shape_node
,
slice_node
=
None
,
...
...
@@ -31,13 +34,20 @@ def create_convert_shape_node(var_shape_node,
if
isinstance
(
var_shape_node
,
gast
.
Attribute
):
args
=
[
ast_to_source_code
(
var_shape_node
.
value
).
strip
()]
if
slice_node
:
# (1) A slice can be a simple number such as 1, -2, i.e. gast.Index
# (2) A slice can also be represented by bounds such as 2:-1, i.e. not gast.Index
# In (1) case, we pass the number as 'idx' argument in convert_var_shape
# In (2) case, we have to make it like `convert_var_shape(x)[slice]`
if
slice_node
is
not
None
and
isinstance
(
slice_node
,
gast
.
Index
):
args
.
append
(
ast_to_source_code
(
slice_node
).
strip
())
convert_var_shape_func
=
"paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})"
.
format
(
","
.
join
(
args
),
in_control_flow
)
api_shape_node
=
gast
.
parse
(
convert_var_shape_func
).
body
[
0
].
value
if
slice_node
is
not
None
and
not
isinstance
(
slice_node
,
gast
.
Index
):
return
gast
.
Subscript
(
value
=
api_shape_node
,
slice
=
slice_node
,
ctx
=
gast
.
Load
())
return
api_shape_node
if
isinstance
(
var_shape_node
,
gast
.
Subscript
):
...
...
@@ -47,6 +57,39 @@ def create_convert_shape_node(var_shape_node,
return
result_node
def
create_choose_shape_node
(
attr_shape_name
,
api_shape_name
,
slice_node
=
None
):
eval_exist_func
=
"paddle.jit.dy2static.eval_if_exist_else_none('{}')"
.
format
(
api_shape_name
)
args
=
[
attr_shape_name
,
eval_exist_func
]
if
slice_node
is
not
None
and
isinstance
(
slice_node
,
gast
.
Index
):
args
.
append
(
ast_to_source_code
(
slice_node
).
strip
())
choose_shape_func
=
"paddle.jit.dy2static.choose_shape_attr_or_api({})"
.
format
(
","
.
join
(
args
))
choose_shape_node
=
gast
.
parse
(
choose_shape_func
).
body
[
0
].
value
if
slice_node
is
not
None
and
not
isinstance
(
slice_node
,
gast
.
Index
):
return
gast
.
Subscript
(
value
=
choose_shape_node
,
slice
=
slice_node
,
ctx
=
gast
.
Load
())
return
choose_shape_node
class
ShapeAttributeTransformer
(
gast
.
NodeTransformer
):
"""
Input a node like `x.shape` or `x[4].shape[0]` (self._is_var_shape(node) is True),
return a new node changes input to static shape API like `convert_var_shape(x)`,
`convert_var_shape(x[4])[0]`.
"""
def
visit_Attribute
(
self
,
node
):
if
node
.
attr
==
'shape'
:
args
=
ast_to_source_code
(
node
.
value
).
strip
()
convert_var_shape_func
=
"paddle.jit.dy2static.convert_var_shape_simple({})"
.
format
(
args
)
api_shape_node
=
gast
.
parse
(
convert_var_shape_func
).
body
[
0
].
value
return
api_shape_node
return
node
class
TensorShapeTransformer
(
gast
.
NodeTransformer
):
"""
This class transforms variable.shape used in Paddle Apis or control flow conditions into Static Graph Ast.
...
...
@@ -58,6 +101,8 @@ class TensorShapeTransformer(gast.NodeTransformer):
),
"Input non-AstNodeWrapper node for the initialization of TensorShapeTransformer."
self
.
wrapper_root
=
wrapper_root
self
.
root
=
wrapper_root
.
node
# stores origin var string name (like "x" in `x = t.shape`) to
# static shape var string name (like "x_SUFFIX" in `x_SUFFIX = shape(t)`)
self
.
name_to_var_shape
=
{}
self
.
static_analysis_visitor
=
StaticAnalysisVisitor
(
self
.
root
)
...
...
@@ -72,8 +117,11 @@ class TensorShapeTransformer(gast.NodeTransformer):
self
.
visit
(
self
.
root
)
def
visit_Assign
(
self
,
node
):
if
self
.
_update_name_to_var_shape
(
node
):
return
node
update_static_shape_var_node
=
self
.
_update_name_to_var_shape
(
node
)
if
update_static_shape_var_node
is
not
None
:
ret
=
[
node
]
ret
.
extend
(
update_static_shape_var_node
)
return
ret
self
.
generic_visit
(
node
)
return
node
...
...
@@ -81,37 +129,44 @@ class TensorShapeTransformer(gast.NodeTransformer):
value_node
=
node
.
value
slice_node
=
node
.
slice
if
isinstance
(
value_node
,
gast
.
Name
):
if
self
.
_is_var_shape
(
value_node
)
and
self
.
_used_by_paddle_api
(
value_node
):
var_shape_node
=
self
.
name_to_var_shape
[
value_node
.
id
]
return
create_convert_shape_node
(
var_shape_node
,
slice_node
)
if
isinstance
(
value_node
,
gast
.
Attribute
):
if
self
.
_used_by_paddle_api
(
value_node
)
and
self
.
_is_var_shape
(
if
value_node
.
id
in
self
.
name_to_var_shape
and
self
.
_used_by_paddle_api
(
value_node
):
return
create_convert_shape_node
(
value_node
,
slice_node
)
return
create_choose_shape_node
(
value_node
.
id
,
self
.
name_to_var_shape
[
value_node
.
id
],
slice_node
)
elif
isinstance
(
value_node
,
gast
.
Attribute
):
if
self
.
_used_by_paddle_api
(
value_node
):
value_name
=
ast_to_source_code
(
value_node
).
strip
()
if
value_name
in
self
.
name_to_var_shape
:
return
create_choose_shape_node
(
value_name
,
self
.
name_to_var_shape
[
value_name
],
slice_node
)
if
self
.
_is_var_shape
(
value_node
):
return
create_convert_shape_node
(
value_node
,
slice_node
)
return
node
def
visit_Attribute
(
self
,
node
):
if
self
.
_used_by_paddle_api
(
node
):
name
=
ast_to_source_code
(
node
).
strip
()
if
name
in
self
.
name_to_var_shape
:
return
create_choose_shape_node
(
name
,
self
.
name_to_var_shape
[
name
])
if
self
.
_is_var_shape
(
node
):
return
create_convert_shape_node
(
node
)
return
node
def
visit_Name
(
self
,
node
):
if
self
.
_is_var_shape
(
node
)
:
if
node
.
id
in
self
.
name_to_var_shape
:
if
self
.
_used_by_paddle_api
(
node
):
var_shape_node
=
self
.
name_to_var_shape
[
node
.
id
]
return
create_convert_shape_node
(
var_shape_node
)
return
create_choose_shape_node
(
node
.
id
,
self
.
name_to_var_shape
[
node
.
id
]
)
return
node
def
visit_Call
(
self
,
node
):
assert
isinstance
(
node
,
gast
.
Call
)
if
is_paddle_api
(
node
):
# Visit gast.Attribute and gast.Name to replace var.shape if necessary.
self
.
generic_visit
(
node
)
# Don't have to visit other APIs
return
node
def
visit_If
(
self
,
node
):
...
...
@@ -147,22 +202,23 @@ class TensorShapeTransformer(gast.NodeTransformer):
return
False
args
=
node
.
iter
.
args
for
idx
,
arg
in
enumerate
(
args
):
if
isinstance
(
arg
,
gast
.
Name
)
and
self
.
_is_var_shape
(
arg
):
args
[
idx
]
=
create_convert_shape_node
(
self
.
name_to_var_shape
[
arg
.
id
])
if
isinstance
(
arg
,
gast
.
Name
)
and
arg
.
id
in
self
.
name_to_var_shape
:
args
[
idx
]
=
create_choose_shape_node
(
arg
.
id
,
self
.
name_to_var_shape
[
arg
.
id
])
return
True
def
_transform_var_shape_if_necessary
(
self
,
cond
):
need_transformed
=
False
for
child_node
in
gast
.
walk
(
cond
):
var_shape_node
=
None
if
isinstance
(
child_node
,
(
gast
.
Attribute
,
gast
.
Subscript
)):
if
self
.
_is_var_shape
(
child_node
):
if
isinstance
(
child_node
,
(
gast
.
Name
,
gast
.
Attribute
,
gast
.
Subscript
)):
child_name
=
ast_to_source_code
(
child_node
).
strip
()
if
child_name
in
self
.
name_to_var_shape
:
var_shape_node
=
create_choose_shape_node
(
child_name
,
self
.
name_to_var_shape
[
child_name
])
elif
self
.
_is_var_shape
(
child_node
):
var_shape_node
=
child_node
elif
isinstance
(
child_node
,
(
gast
.
Name
)):
if
self
.
_is_var_shape
(
child_node
):
var_shape_node
=
self
.
name_to_var_shape
[
child_node
.
id
]
if
var_shape_node
:
need_transformed
=
True
...
...
@@ -170,17 +226,23 @@ class TensorShapeTransformer(gast.NodeTransformer):
parent_node
=
wrapper_node
.
parent
.
node
for
field
,
value
in
gast
.
iter_fields
(
parent_node
):
if
child_node
is
value
:
setattr
(
parent_node
,
field
,
create_convert_shape_node
(
var_shape_node
,
None
,
True
))
if
var_shape_node
is
child_node
:
setattr
(
parent_node
,
field
,
create_convert_shape_node
(
var_shape_node
,
None
,
True
))
else
:
setattr
(
parent_node
,
field
,
var_shape_node
)
break
# Some child_node may be in a list such as gast.Compare
if
isinstance
(
value
,
list
):
has_converted_shape
=
False
for
i
,
v
in
enumerate
(
value
):
if
child_node
is
v
:
value
[
i
]
=
create_convert_shape_node
(
var_shape_node
,
None
,
True
)
if
var_shape_node
is
child_node
:
value
[
i
]
=
create_convert_shape_node
(
var_shape_node
,
None
,
True
)
else
:
value
[
i
]
=
var_shape_node
has_converted_shape
=
True
break
if
has_converted_shape
:
...
...
@@ -217,19 +279,12 @@ class TensorShapeTransformer(gast.NodeTransformer):
"""
Return True if node is like `x.shape` or `x.shape[0]`, return False otherwise.
"""
if
not
isinstance
(
node
,
(
gast
.
Name
,
gast
.
Attribute
,
gast
.
Subscript
)):
if
not
isinstance
(
node
,
(
gast
.
Attribute
,
gast
.
Subscript
)):
return
False
if
isinstance
(
node
,
gast
.
Name
)
and
node
.
id
in
self
.
name_to_var_shape
:
return
True
if
isinstance
(
node
,
gast
.
Attribute
):
if
node
.
attr
!=
'shape'
:
return
False
if
not
isinstance
(
node
.
value
,
gast
.
Name
):
return
False
return
True
if
isinstance
(
node
,
gast
.
Subscript
):
...
...
@@ -243,49 +298,94 @@ class TensorShapeTransformer(gast.NodeTransformer):
target_node
=
node
.
targets
[
0
]
value_node
=
node
.
value
update_static_shape_var_node
=
None
if
isinstance
(
target_node
,
gast
.
Tuple
):
has_updated
=
False
update_static_shape_var_node
=
[]
for
idx
,
element
in
enumerate
(
target_node
.
elts
):
target_id
=
ast_to_source_code
(
element
).
strip
()
if
isinstance
(
value_node
,
gast
.
Name
):
if
value_node
.
id
in
self
.
name_to_var_shape
:
# TODO(zhhsplendid): is context a problem for the result node of gast.parse?
static_shape_var_name
=
unique_name
.
generate
(
target_id
+
STATIC_CONVERT_VAR_SHAPE_SUFFIX
)
static_shape_var_node
=
gast
.
parse
(
static_shape_var_name
).
body
[
0
].
value
static_shape_value_name
=
self
.
name_to_var_shape
[
value_node
.
id
]
static_shape_value_node
=
gast
.
parse
(
static_shape_value_name
).
body
[
0
].
value
index_value_node
=
gast
.
Constant
(
value
=
idx
,
kind
=
None
)
slice_index_node
=
gast
.
Index
(
value
=
index_value_node
)
var_shape_node
=
self
.
name_to_var_shape
[
value_node
.
id
]
sub_node
=
gast
.
Subscript
(
value
=
var_shap
e_node
,
value
=
static_shape_valu
e_node
,
slice
=
slice_index_node
,
ctx
=
gast
.
Load
())
self
.
name_to_var_shape
[
target_id
]
=
sub_node
has_updated
=
True
update_static_shape_var_node
.
append
(
gast
.
Assign
(
targets
=
[
static_shape_var_node
],
value
=
sub_node
))
self
.
name_to_var_shape
[
target_id
]
=
static_shape_var_name
if
isinstance
(
value_node
,
gast
.
Attribute
):
if
self
.
_is_var_shape
(
value_node
):
# eg: x.shape
static_shape_var_name
=
unique_name
.
generate
(
target_id
+
STATIC_CONVERT_VAR_SHAPE_SUFFIX
)
static_shape_var_node
=
gast
.
parse
(
static_shape_var_name
).
body
[
0
].
value
static_shape_value_node
=
copy
.
deepcopy
(
value_node
)
# x.shape becomes convert_var_shape_simple(x)
ShapeAttributeTransformer
().
visit
(
static_shape_value_node
)
index_value_node
=
gast
.
Constant
(
value
=
idx
,
kind
=
None
)
slice_index_node
=
gast
.
Index
(
value
=
index_value_node
)
sub_node
=
gast
.
Subscript
(
value
=
value_node
,
value
=
static_shape_
value_node
,
slice
=
slice_index_node
,
ctx
=
gast
.
Load
())
self
.
name_to_var_shape
[
target_id
]
=
sub_node
has_updated
=
True
return
has_updated
update_static_shape_var_node
.
append
(
gast
.
Assign
(
targets
=
[
static_shape_var_node
],
value
=
sub_node
))
self
.
name_to_var_shape
[
target_id
]
=
static_shape_var_name
return
update_static_shape_var_node
else
:
target_id
=
ast_to_source_code
(
target_node
).
strip
()
if
isinstance
(
value_node
,
gast
.
Name
):
if
self
.
_is_var_shape
(
value_node
):
self
.
name_to_var_shape
[
target_id
]
=
self
.
name_to_var_shape
[
if
value_node
.
id
in
self
.
name_to_var_shape
:
static_shape_var_name
=
unique_name
.
generate
(
target_id
+
STATIC_CONVERT_VAR_SHAPE_SUFFIX
)
static_shape_var_node
=
gast
.
parse
(
static_shape_var_name
).
body
[
0
].
value
static_shape_value_name
=
self
.
name_to_var_shape
[
value_node
.
id
]
return
True
if
isinstance
(
value_node
,
gast
.
Attribute
):
if
self
.
_is_var_shape
(
value_node
):
# eg: x.shape
self
.
name_to_var_shape
[
target_id
]
=
value_node
return
True
if
isinstance
(
value_node
,
gast
.
Subscript
):
if
isinstance
(
value_node
.
value
,
gast
.
Attribute
):
if
self
.
_is_var_shape
(
value_node
.
value
):
# eg: x.shape[0]
self
.
name_to_var_shape
[
target_id
]
=
value_node
return
True
return
False
static_shape_value_node
=
gast
.
parse
(
static_shape_value_name
).
body
[
0
].
value
update_static_shape_var_node
=
[
gast
.
Assign
(
targets
=
[
static_shape_var_node
],
value
=
static_shape_value_node
)
]
self
.
name_to_var_shape
[
target_id
]
=
static_shape_var_name
elif
self
.
_is_var_shape
(
value_node
):
# eg: x.shape or x.shape[0]
static_shape_var_name
=
unique_name
.
generate
(
target_id
+
STATIC_CONVERT_VAR_SHAPE_SUFFIX
)
static_shape_var_node
=
gast
.
parse
(
static_shape_var_name
).
body
[
0
].
value
static_shape_value_node
=
copy
.
deepcopy
(
value_node
)
# x.shape becomes convert_var_shape_simple(x)
ShapeAttributeTransformer
().
visit
(
static_shape_value_node
)
update_static_shape_var_node
=
[
gast
.
Assign
(
targets
=
[
static_shape_var_node
],
value
=
static_shape_value_node
)
]
self
.
name_to_var_shape
[
target_id
]
=
static_shape_var_name
return
update_static_shape_var_node
python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py
浏览文件 @
3a72408f
...
...
@@ -136,5 +136,58 @@ class TestConvertShapeCompare(unittest.TestCase):
paddle
.
disable_static
()
class
TestChooseShapeAttrOrApi
(
unittest
.
TestCase
):
def
test_api_shape_is_none
(
self
):
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
([
1
,
2
],
None
),
[
1
,
2
])
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
([
1
],
None
),
[
1
])
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
([
2
,
3
,
7
],
None
,
0
),
2
)
def
test_attr_shape_is_int
(
self
):
x
=
paddle
.
zeros
([
1
,
3
,
5
,
7
])
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
(
x
.
shape
[
0
],
paddle
.
shape
(
x
)[
0
]),
1
)
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
(
x
.
shape
[
1
],
paddle
.
shape
(
x
)[
1
]),
3
)
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
(
-
1
,
paddle
.
shape
(
x
)[
0
]),
paddle
.
shape
(
x
)[
0
])
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
(
-
1
,
paddle
.
shape
(
x
),
0
),
paddle
.
shape
(
x
)[
0
])
def
test_positive_attr_shape
(
self
):
x
=
paddle
.
zeros
([
1
,
3
,
5
,
7
])
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
(
x
.
shape
,
paddle
.
shape
(
x
)),
x
.
shape
)
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
(
x
.
shape
,
paddle
.
shape
(
x
),
3
),
x
.
shape
[
3
])
def
test_negative_attr_shape
(
self
):
x
=
paddle
.
zeros
([
7
])
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
([
-
1
],
paddle
.
shape
(
x
),
0
),
paddle
.
shape
(
x
)[
0
])
self
.
assertEqual
(
paddle
.
jit
.
dy2static
.
choose_shape_attr_or_api
([
-
1
],
paddle
.
shape
(
x
)),
paddle
.
shape
(
x
))
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py
浏览文件 @
3a72408f
...
...
@@ -60,6 +60,16 @@ def dyfunc_tensor_shape_5(x):
return
res
def
dyfunc_tensor_shape_6
(
x
):
# `res = fluid.layers.reshape(x, shape=(-1, s))` to
# `res = fluid.layers.reshape(x, shape=(-1,
# paddle.jit.dy2static.convert_var_shape(x)[0:]))`
x
=
fluid
.
dygraph
.
to_variable
(
x
)
s
=
x
.
shape
[
0
:]
res
=
fluid
.
layers
.
reshape
(
x
,
shape
=
s
)
return
res
def
dyfunc_tuple_shape_1
(
x
):
x
=
paddle
.
to_tensor
(
x
)
a
,
b
=
x
.
shape
...
...
@@ -197,6 +207,14 @@ def dyfunc_with_while_4(x):
return
x
def
dyfunc_change_shape_after_assign
(
x
):
x
=
paddle
.
to_tensor
(
x
)
a
,
b
=
x
.
shape
x
=
paddle
.
reshape
(
x
,
shape
=
(
-
1
,
1
))
res
=
paddle
.
reshape
(
x
,
shape
=
(
b
,
a
))
return
res
# 1. Basic tests without control flow
class
TestTensorShapeBasic
(
unittest
.
TestCase
):
def
setUp
(
self
):
...
...
@@ -279,6 +297,21 @@ class TestTensorShapeBasic5(TestTensorShapeBasic):
def
init_test_func
(
self
):
self
.
dygraph_func
=
dyfunc_tensor_shape_5
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
4
self
.
expected_shape_op_num
=
1
self
.
expected_slice_op_num
=
1
class
TestTensorShapeBasic6
(
TestTensorShapeBasic
):
def
init_test_func
(
self
):
self
.
dygraph_func
=
dyfunc_tensor_shape_6
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
4
self
.
expected_shape_op_num
=
1
self
.
expected_slice_op_num
=
1
class
TestTupleShape1
(
TestTensorShapeBasic
):
def
init_test_func
(
self
):
...
...
@@ -312,9 +345,9 @@ class TestTensorShapeInIf1(TestTensorShapeBasic):
self
.
dygraph_func
=
dyfunc_with_if_1
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
26
self
.
expected_shape_op_num
=
2
self
.
expected_slice_op_num
=
2
self
.
expected_op_num
=
4
self
.
expected_shape_op_num
=
1
self
.
expected_slice_op_num
=
1
class
TestTensorShapeInIf2
(
TestTensorShapeBasic
):
...
...
@@ -342,6 +375,11 @@ class TestTensorShapeInFor2(TestTensorShapeInFor1):
def
init_test_func
(
self
):
self
.
dygraph_func
=
dyfunc_with_for_2
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
9
self
.
expected_shape_op_num
=
1
self
.
expected_slice_op_num
=
1
# 4. Tests with control flow while loop
class
TestTensorShapeInWhile1
(
TestTensorShapeInFor1
):
...
...
@@ -353,15 +391,20 @@ class TestTensorShapeInWhile2(TestTensorShapeInFor1):
def
init_test_func
(
self
):
self
.
dygraph_func
=
dyfunc_with_while_2
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
6
self
.
expected_shape_op_num
=
1
self
.
expected_slice_op_num
=
1
class
TestTensorShapeInWhile3
(
TestTensorShapeBasic
):
def
init_test_func
(
self
):
self
.
dygraph_func
=
dyfunc_with_while_3
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
2
5
self
.
expected_shape_op_num
=
6
self
.
expected_slice_op_num
=
3
self
.
expected_op_num
=
2
self
.
expected_shape_op_num
=
0
self
.
expected_slice_op_num
=
0
class
TestTensorShapeInWhile4
(
TestTensorShapeBasic
):
...
...
@@ -431,9 +474,9 @@ class TestOpNumWithTensorShapeTuple1(TestOpNumBasicWithTensorShape):
self
.
dygraph_func
=
dyfunc_tuple_shape_1
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
5
self
.
expected_shape_op_num
=
1
self
.
expected_slice_op_num
=
1
self
.
expected_op_num
=
2
self
.
expected_shape_op_num
=
0
self
.
expected_slice_op_num
=
0
class
TestOpNumWithTensorShapeInIf1
(
TestOpNumBasicWithTensorShape
):
...
...
@@ -441,7 +484,7 @@ class TestOpNumWithTensorShapeInIf1(TestOpNumBasicWithTensorShape):
self
.
dygraph_func
=
dyfunc_with_if_1
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
28
self
.
expected_op_num
=
19
self
.
expected_shape_op_num
=
4
self
.
expected_slice_op_num
=
2
...
...
@@ -466,5 +509,17 @@ class TestOpNumWithTensorShapeInWhile1(TestOpNumBasicWithTensorShape):
self
.
expected_slice_op_num
=
3
class
TestChangeShapeAfterAssign
(
TestTensorShapeBasic
):
def
init_test_func
(
self
):
self
.
input
=
numpy
.
ones
((
2
,
3
)).
astype
(
"int32"
)
self
.
input_spec
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
2
,
3
],
dtype
=
"int32"
)]
self
.
dygraph_func
=
dyfunc_change_shape_after_assign
def
_set_expected_op_num
(
self
):
self
.
expected_op_num
=
3
self
.
expected_shape_op_num
=
0
self
.
expected_slice_op_num
=
0
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/jit/dy2static/convert_operators.py
浏览文件 @
3a72408f
...
...
@@ -25,11 +25,15 @@ from ...fluid.dygraph.dygraph_to_static.convert_operators import convert_print
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
convert_shape_compare
#DEFINE_ALIAS
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
convert_var_dtype
#DEFINE_ALIAS
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
convert_var_shape
#DEFINE_ALIAS
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
convert_var_shape_simple
#DEFINE_ALIAS
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
eval_if_exist_else_none
#DEFINE_ALIAS
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
choose_shape_attr_or_api
#DEFINE_ALIAS
from
...fluid.dygraph.dygraph_to_static.convert_operators
import
convert_while_loop
#DEFINE_ALIAS
__all__
=
[
'cast_bool_if_necessary'
,
'convert_assert'
,
'convert_ifelse'
,
'convert_len'
,
'convert_logical_and'
,
'convert_logical_not'
,
'convert_logical_or'
,
'convert_pop'
,
'convert_print'
,
'convert_shape_compare'
,
'convert_var_dtype'
,
'convert_var_shape'
,
'convert_while_loop'
'convert_var_dtype'
,
'convert_var_shape'
,
'convert_var_shape_simple'
,
'eval_if_exist_else_none'
,
'choose_shape_attr_or_api'
,
'convert_while_loop'
]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录