Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
522c91ec
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
522c91ec
编写于
3月 04, 2021
作者:
L
liym27
提交者:
GitHub
3月 04, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Dy2Stat] Remove gast.Index for compatibility of gast 0.4.0 (#31358)
上级
62289fcc
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
67 addition
and
54 deletion
+67
-54
python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py
...addle/fluid/dygraph/dygraph_to_static/list_transformer.py
+6
-2
python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py
...uid/dygraph/dygraph_to_static/tensor_shape_transformer.py
+25
-32
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
+36
-19
python/paddle/fluid/tests/unittests/test_gast_with_compatibility.py
...dle/fluid/tests/unittests/test_gast_with_compatibility.py
+0
-1
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py
浏览文件 @
522c91ec
...
@@ -18,7 +18,10 @@ import astor
...
@@ -18,7 +18,10 @@ import astor
import
gast
import
gast
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
AstNodeWrapper
,
StaticAnalysisVisitor
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
AstNodeWrapper
,
StaticAnalysisVisitor
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ast_to_source_code
,
is_control_flow_to_transform
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ast_to_source_code
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
slice_is_num
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
is_control_flow_to_transform
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
SplitAssignTransformer
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
SplitAssignTransformer
...
@@ -116,12 +119,13 @@ class ListTransformer(gast.NodeTransformer):
...
@@ -116,12 +119,13 @@ class ListTransformer(gast.NodeTransformer):
def
_transform_slice_to_tensor_write
(
self
,
node
):
def
_transform_slice_to_tensor_write
(
self
,
node
):
assert
isinstance
(
node
,
gast
.
Assign
)
assert
isinstance
(
node
,
gast
.
Assign
)
target_node
=
node
.
targets
[
0
]
target_node
=
node
.
targets
[
0
]
target_name
=
target_node
.
value
.
id
target_name
=
target_node
.
value
.
id
slice_node
=
target_node
.
slice
slice_node
=
target_node
.
slice
if
isinstance
(
slice_node
,
gast
.
Slice
):
if
isinstance
(
slice_node
,
gast
.
Slice
):
pass
pass
elif
isinstance
(
slice_node
,
gast
.
Index
):
elif
slice_is_num
(
target_node
):
value_code
=
ast_to_source_code
(
node
.
value
)
value_code
=
ast_to_source_code
(
node
.
value
)
i
=
"paddle.cast("
\
i
=
"paddle.cast("
\
"x=paddle.jit.dy2static.to_static_variable({}),"
\
"x=paddle.jit.dy2static.to_static_variable({}),"
\
...
...
python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py
浏览文件 @
522c91ec
...
@@ -19,6 +19,7 @@ import gast
...
@@ -19,6 +19,7 @@ import gast
from
paddle.fluid
import
unique_name
from
paddle.fluid
import
unique_name
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ast_to_source_code
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ast_to_source_code
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
slice_is_num
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
is_paddle_api
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
is_paddle_api
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
SplitAssignTransformer
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
SplitAssignTransformer
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
AstNodeWrapper
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
AstNodeWrapper
...
@@ -34,43 +35,42 @@ def create_convert_shape_node(var_shape_node,
...
@@ -34,43 +35,42 @@ def create_convert_shape_node(var_shape_node,
if
isinstance
(
var_shape_node
,
gast
.
Attribute
):
if
isinstance
(
var_shape_node
,
gast
.
Attribute
):
args
=
[
ast_to_source_code
(
var_shape_node
.
value
).
strip
()]
args
=
[
ast_to_source_code
(
var_shape_node
.
value
).
strip
()]
# (1) A slice can be a simple number such as 1, -2, i.e. gast.Index
# (1) A slice can be a simple number such as 1, -2, i.e. gast.Index
or gast.Constant
# (2) A slice can also be represented by bounds such as 2:-1, i.e. not gast.Index
# (2) A slice can also be represented by bounds such as 2:-1, i.e. not gast.Index
or gast.Constant
# In (1) case, we pass the number as 'idx' argument in convert_var_shape
# In (1) case, we pass the number as 'idx' argument in convert_var_shape
# In (2) case, we have to make it like `convert_var_shape(x)[slice]`
# In (2) case, we have to make it like `convert_var_shape(x)[slice]`
if
slice_node
is
not
None
and
isinstance
(
slice_node
,
gast
.
Index
):
if
slice_node
is
not
None
and
slice_is_num
(
slice_node
):
args
.
append
(
ast_to_source_code
(
slice_node
).
strip
())
args
.
append
(
ast_to_source_code
(
slice_node
.
slice
).
strip
())
convert_var_shape_func
=
"paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})"
.
format
(
convert_var_shape_func
=
"paddle.jit.dy2static.convert_var_shape({}, in_control_flow={})"
.
format
(
","
.
join
(
args
),
in_control_flow
)
","
.
join
(
args
),
in_control_flow
)
api_shape_node
=
gast
.
parse
(
convert_var_shape_func
).
body
[
0
].
value
api_shape_node
=
gast
.
parse
(
convert_var_shape_func
).
body
[
0
].
value
if
slice_node
is
not
None
and
not
isinstance
(
slice_node
,
gast
.
Index
):
if
slice_node
is
not
None
and
not
slice_is_num
(
slice_node
):
return
gast
.
Subscript
(
return
gast
.
Subscript
(
value
=
api_shape_node
,
slice
=
slice_node
,
ctx
=
gast
.
Load
())
value
=
api_shape_node
,
slice
=
slice_node
.
slice
,
ctx
=
gast
.
Load
())
return
api_shape_node
return
api_shape_node
if
isinstance
(
var_shape_node
,
gast
.
Subscript
):
if
isinstance
(
var_shape_node
,
gast
.
Subscript
):
result_node
=
copy
.
deepcopy
(
var_shape_node
)
result_node
=
copy
.
deepcopy
(
var_shape_node
)
result_node
=
create_convert_shape_node
(
result_node
=
create_convert_shape_node
(
result_node
.
value
,
result_node
,
result_node
.
value
,
result_node
.
slice
,
in_control_flow
)
in_control_flow
)
return
result_node
return
result_node
def
create_choose_shape_node
(
attr_shape_name
,
api_shape_name
,
slice_node
=
None
):
def
create_choose_shape_node
(
attr_shape_name
,
api_shape_name
,
slice_node
=
None
):
# Note(Aurelius84): Add `locals()` to help `eval` to locate the variable correctly.
eval_exist_func
=
"paddle.jit.dy2static.eval_if_exist_else_none('{}', locals())"
.
format
(
eval_exist_func
=
"paddle.jit.dy2static.eval_if_exist_else_none('{}', locals())"
.
format
(
api_shape_name
)
api_shape_name
)
args
=
[
attr_shape_name
,
eval_exist_func
]
args
=
[
attr_shape_name
,
eval_exist_func
]
if
slice_node
is
not
None
and
isinstance
(
slice_node
,
gast
.
Index
):
if
slice_node
is
not
None
and
slice_is_num
(
slice_node
):
args
.
append
(
ast_to_source_code
(
slice_node
).
strip
())
args
.
append
(
ast_to_source_code
(
slice_node
.
slice
).
strip
())
choose_shape_func
=
"paddle.jit.dy2static.choose_shape_attr_or_api({})"
.
format
(
choose_shape_func
=
"paddle.jit.dy2static.choose_shape_attr_or_api({})"
.
format
(
","
.
join
(
args
))
","
.
join
(
args
))
choose_shape_node
=
gast
.
parse
(
choose_shape_func
).
body
[
0
].
value
choose_shape_node
=
gast
.
parse
(
choose_shape_func
).
body
[
0
].
value
if
slice_node
is
not
None
and
not
isinstance
(
slice_node
,
gast
.
Index
):
if
slice_node
is
not
None
and
not
slice_is_num
(
slice_node
):
return
gast
.
Subscript
(
return
gast
.
Subscript
(
value
=
choose_shape_node
,
slice
=
slice_node
,
ctx
=
gast
.
Load
())
value
=
choose_shape_node
,
slice
=
slice_node
.
slice
,
ctx
=
gast
.
Load
())
return
choose_shape_node
return
choose_shape_node
...
@@ -133,17 +133,15 @@ class TensorShapeTransformer(gast.NodeTransformer):
...
@@ -133,17 +133,15 @@ class TensorShapeTransformer(gast.NodeTransformer):
if
value_node
.
id
in
self
.
name_to_var_shape
and
self
.
_used_by_paddle_api
(
if
value_node
.
id
in
self
.
name_to_var_shape
and
self
.
_used_by_paddle_api
(
value_node
):
value_node
):
return
create_choose_shape_node
(
return
create_choose_shape_node
(
value_node
.
id
,
self
.
name_to_var_shape
[
value_node
.
id
],
value_node
.
id
,
self
.
name_to_var_shape
[
value_node
.
id
],
node
)
slice_node
)
elif
isinstance
(
value_node
,
gast
.
Attribute
):
elif
isinstance
(
value_node
,
gast
.
Attribute
):
if
self
.
_used_by_paddle_api
(
value_node
):
if
self
.
_used_by_paddle_api
(
value_node
):
value_name
=
ast_to_source_code
(
value_node
).
strip
()
value_name
=
ast_to_source_code
(
value_node
).
strip
()
if
value_name
in
self
.
name_to_var_shape
:
if
value_name
in
self
.
name_to_var_shape
:
return
create_choose_shape_node
(
return
create_choose_shape_node
(
value_name
,
self
.
name_to_var_shape
[
value_name
],
value_name
,
self
.
name_to_var_shape
[
value_name
],
node
)
slice_node
)
if
self
.
_is_var_shape
(
value_node
):
if
self
.
_is_var_shape
(
value_node
):
return
create_convert_shape_node
(
value_node
,
slice_
node
)
return
create_convert_shape_node
(
value_node
,
node
)
return
node
return
node
def
visit_Attribute
(
self
,
node
):
def
visit_Attribute
(
self
,
node
):
...
@@ -315,14 +313,10 @@ class TensorShapeTransformer(gast.NodeTransformer):
...
@@ -315,14 +313,10 @@ class TensorShapeTransformer(gast.NodeTransformer):
static_shape_value_name
=
self
.
name_to_var_shape
[
static_shape_value_name
=
self
.
name_to_var_shape
[
value_node
.
id
]
value_node
.
id
]
static_shape_value_node
=
gast
.
parse
(
static_shape_value_name
).
body
[
0
].
value
sub_node_str
=
"{}[{}]"
.
format
(
static_shape_value_name
,
index_value_node
=
gast
.
Constant
(
value
=
idx
,
kind
=
None
)
idx
)
slice_index_node
=
gast
.
Index
(
value
=
index_value_node
)
sub_node
=
gast
.
parse
(
sub_node_str
).
body
[
0
].
value
sub_node
=
gast
.
Subscript
(
value
=
static_shape_value_node
,
slice
=
slice_index_node
,
ctx
=
gast
.
Load
())
update_static_shape_var_node
.
append
(
update_static_shape_var_node
.
append
(
gast
.
Assign
(
gast
.
Assign
(
...
@@ -342,12 +336,11 @@ class TensorShapeTransformer(gast.NodeTransformer):
...
@@ -342,12 +336,11 @@ class TensorShapeTransformer(gast.NodeTransformer):
# x.shape becomes convert_var_shape_simple(x)
# x.shape becomes convert_var_shape_simple(x)
static_shape_value_node
=
ShapeAttributeTransformer
(
static_shape_value_node
=
ShapeAttributeTransformer
(
).
visit
(
static_shape_value_node
)
).
visit
(
static_shape_value_node
)
index_value_node
=
gast
.
Constant
(
value
=
idx
,
kind
=
None
)
slice_index_node
=
gast
.
Index
(
value
=
index_value_node
)
sub_node_str
=
"{}[{}]"
.
format
(
sub_node
=
gast
.
Subscript
(
ast_to_source_code
(
static_shape_value_node
).
strip
(),
value
=
static_shape_value_node
,
idx
)
slice
=
slice_index_node
,
sub_node
=
gast
.
parse
(
sub_node_str
).
body
[
0
].
value
ctx
=
gast
.
Load
())
update_static_shape_var_node
.
append
(
update_static_shape_var_node
.
append
(
gast
.
Assign
(
gast
.
Assign
(
...
...
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
浏览文件 @
522c91ec
...
@@ -921,18 +921,15 @@ class ForLoopTuplePreTransformer(gast.NodeTransformer):
...
@@ -921,18 +921,15 @@ class ForLoopTuplePreTransformer(gast.NodeTransformer):
def
tuple_to_stmts
(
self
,
node
,
tuple_name
,
idx
=
[]):
def
tuple_to_stmts
(
self
,
node
,
tuple_name
,
idx
=
[]):
if
not
isinstance
(
node
,
(
gast
.
Tuple
,
gast
.
List
)):
if
not
isinstance
(
node
,
(
gast
.
Tuple
,
gast
.
List
)):
value_node
=
gast
.
Name
(
value_node_str
=
tuple_name
id
=
tuple_name
,
ctx
=
gast
.
Load
(),
annotation
=
None
,
type_comment
=
None
)
for
i
in
idx
:
for
i
in
idx
:
value_node
=
gast
.
Subscript
(
value_node_str
=
value_node_str
+
"[{}]"
.
format
(
i
)
value
=
value_node
,
slice
=
gast
.
Index
(
value
=
gast
.
Constant
(
node_str
=
ast_to_source_code
(
node
).
strip
()
value
=
i
,
kind
=
None
)),
assign_node_str
=
"{} = {}"
.
format
(
node_str
,
value_node_str
)
ctx
=
gast
.
Load
())
assign_node
=
gast
.
parse
(
assign_node_str
).
body
[
0
]
return
[
gast
.
Assign
(
targets
=
[
node
],
value
=
value_node
)]
return
[
assign_node
]
# isinstance(node, (gast.Tuple, gast.List))
# isinstance(node, (gast.Tuple, gast.List))
ret
=
[]
ret
=
[]
for
i
,
element
in
enumerate
(
node
.
elts
):
for
i
,
element
in
enumerate
(
node
.
elts
):
...
@@ -1240,14 +1237,9 @@ class ForNodeVisitor(object):
...
@@ -1240,14 +1237,9 @@ class ForNodeVisitor(object):
value
=
step_node
)
value
=
step_node
)
def
_build_assign_var_slice_node
(
self
):
def
_build_assign_var_slice_node
(
self
):
var_slice_node
=
gast
.
Subscript
(
var_slice_str
=
"{}[{}]"
.
format
(
value
=
self
.
iter_node
,
ast_to_source_code
(
self
.
iter_node
).
strip
(),
self
.
iter_idx_name
)
slice
=
gast
.
Index
(
value
=
gast
.
Name
(
var_slice_node
=
gast
.
parse
(
var_slice_str
).
body
[
0
].
value
id
=
self
.
iter_idx_name
,
ctx
=
gast
.
Load
(),
annotation
=
None
,
type_comment
=
None
)),
ctx
=
gast
.
Load
(),
)
new_iter_var_name
=
unique_name
.
generate
(
FOR_ITER_VAR_NAME_PREFIX
)
new_iter_var_name
=
unique_name
.
generate
(
FOR_ITER_VAR_NAME_PREFIX
)
target_node
,
assign_node
=
create_assign_node
(
new_iter_var_name
,
target_node
,
assign_node
=
create_assign_node
(
new_iter_var_name
,
var_slice_node
)
var_slice_node
)
...
@@ -1422,3 +1414,28 @@ def input_specs_compatible(src_input_specs, desired_input_specs):
...
@@ -1422,3 +1414,28 @@ def input_specs_compatible(src_input_specs, desired_input_specs):
return
False
return
False
return
True
return
True
def
slice_is_num
(
slice_node
):
# A slice_node.slice can be a:
# (1) ast.Index, which is a simple number such as [1], [-2]
# (2) ast.Slice, which is represented by bounds such as [2:-1]
# (3) ast.Tuple, which includes the above two cases such as [2:-1, 1]
# If slice node is case (1), return True, Otherwise, return False.
#
# NOTE: In (1) case, when gast>=0.4.0, gast.Index is not used, which is replaced
# other gast node such as gast.Constant, gast.Name, gast.UnaryOp and so on.
# Considering the compatibility of gast, here use ast note to check whether the
# node is a num. For more details, please visit https://github.com/serge-sans-paille/gast
assert
isinstance
(
slice_node
,
gast
.
Subscript
)
slice_node_str
=
ast_to_source_code
(
slice_node
).
strip
()
ast_node
=
ast
.
parse
(
slice_node_str
).
body
[
0
].
value
if
isinstance
(
ast_node
.
slice
,
(
ast
.
Tuple
,
ast
.
Slice
)):
return
False
if
isinstance
(
ast_node
.
slice
,
ast
.
Index
):
return
True
return
False
python/paddle/fluid/tests/unittests/test_gast_with_compatibility.py
浏览文件 @
522c91ec
...
@@ -97,7 +97,6 @@ class GastNodeTransformer(gast.NodeTransformer):
...
@@ -97,7 +97,6 @@ class GastNodeTransformer(gast.NodeTransformer):
It will be generally represented by gast.Index or gast.Slice in gast.
It will be generally represented by gast.Index or gast.Slice in gast.
Note: Paddle doesn't support PY3.8 currently.
Note: Paddle doesn't support PY3.8 currently.
"""
"""
assert
isinstance
(
node
.
slice
,
(
gast
.
Index
,
gast
.
Slice
))
self
.
generic_visit
(
node
)
self
.
generic_visit
(
node
)
return
node
return
node
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录