Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
5a202af9
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5a202af9
编写于
3月 18, 2020
作者:
L
liym27
提交者:
GitHub
3月 18, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support slice write in dygraph_to_static. test=develop (#23055)
上级
52575304
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
210 addition
and
8 deletion
+210
-8
python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py
...addle/fluid/dygraph/dygraph_to_static/list_transformer.py
+36
-6
python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py
...ddle/fluid/tests/unittests/dygraph_to_static/test_list.py
+0
-2
python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py
...dle/fluid/tests/unittests/dygraph_to_static/test_slice.py
+174
-0
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py
浏览文件 @
5a202af9
...
...
@@ -17,12 +17,12 @@ from __future__ import print_function
import
gast
import
astor
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
AstNodeWrapper
,
NodeVarType
,
StaticAnalysisVisitor
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
is_control_flow_to_transform
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
is_control_flow_to_transform
,
ast_to_source_code
class
ListTransformer
(
gast
.
NodeTransformer
):
"""
This class transforms python list used in control flow into Static Graph Ast
This class transforms python list used in control flow into Static Graph Ast
.
"""
def
__init__
(
self
,
wrapper_root
):
...
...
@@ -31,8 +31,8 @@ class ListTransformer(gast.NodeTransformer):
),
"Input non-AstNodeWrapper node for the initialization of ListTransformer."
self
.
wrapper_root
=
wrapper_root
self
.
root
=
wrapper_root
.
node
self
.
name_of_list_set
=
set
()
self
.
list_name_to_updated
=
dict
()
self
.
list_nodes
=
set
()
self
.
static_analysis_visitor
=
StaticAnalysisVisitor
(
self
.
root
)
self
.
node_to_wrapper_map
=
self
.
static_analysis_visitor
.
get_node_to_wrapper_map
(
...
...
@@ -46,7 +46,11 @@ class ListTransformer(gast.NodeTransformer):
self
.
replace_list_with_tensor_array
(
self
.
root
)
def
visit_Assign
(
self
,
node
):
self
.
_update_list_name_to_updated
(
node
)
if
self
.
_update_list_name_to_updated
(
node
):
return
node
if
self
.
_need_to_array_write_node
(
node
):
return
self
.
_transform_slice_to_tensor_write
(
node
)
return
node
def
visit_If
(
self
,
node
):
...
...
@@ -85,8 +89,33 @@ class ListTransformer(gast.NodeTransformer):
if
self
.
_is_list_append_tensor
(
node
.
value
):
return
True
if
isinstance
(
node
,
gast
.
Assign
):
target_node
=
node
.
targets
[
0
]
if
isinstance
(
target_node
,
gast
.
Subscript
):
list_name
=
ast_to_source_code
(
target_node
.
value
).
strip
()
if
list_name
in
self
.
list_name_to_updated
:
if
self
.
list_name_to_updated
[
list_name
]
==
True
:
return
True
return
False
def
_transform_slice_to_tensor_write
(
self
,
node
):
assert
isinstance
(
node
,
gast
.
Assign
)
target_node
=
node
.
targets
[
0
]
target_name
=
target_node
.
value
.
id
slice_node
=
target_node
.
slice
if
isinstance
(
slice_node
,
gast
.
Slice
):
pass
elif
isinstance
(
slice_node
,
gast
.
Index
):
value_code
=
ast_to_source_code
(
node
.
value
)
i
=
"fluid.layers.cast("
\
"x=fluid.dygraph.dygraph_to_static.variable_trans_func.to_static_variable({}),"
\
"dtype='int64')"
.
format
(
ast_to_source_code
(
slice_node
))
assign_code
=
"{} = fluid.layers.array_write(x={}, i={}, array={})"
\
.
format
(
target_name
,
value_code
,
i
,
target_name
)
assign_node
=
gast
.
parse
(
assign_code
).
body
[
0
]
return
assign_node
def
_is_list_append_tensor
(
self
,
node
):
"""
a.append(b): a is list, b is Tensor
...
...
@@ -135,7 +164,7 @@ class ListTransformer(gast.NodeTransformer):
target_id
=
target_node
.
id
except
AttributeError
:
return
False
if
self
.
list_name_to_updated
.
get
(
target_id
):
if
self
.
list_name_to_updated
.
get
(
target_id
)
and
node
in
self
.
list_nodes
:
return
True
return
False
...
...
@@ -165,7 +194,8 @@ class ListTransformer(gast.NodeTransformer):
value_node
=
node
.
value
if
isinstance
(
value_node
,
gast
.
List
):
self
.
list_name_to_updated
[
target_id
]
=
False
self
.
list_nodes
.
add
(
node
)
return
True
elif
target_id
in
self
.
name_of_list_set
:
elif
target_id
in
self
.
list_name_to_updated
:
del
self
.
list_name_to_updated
[
target_id
]
return
False
python/paddle/fluid/tests/unittests/dygraph_to_static/test_list.py
浏览文件 @
5a202af9
...
...
@@ -44,7 +44,6 @@ def test_list_in_if(x):
def
test_list_in_for_loop
(
x
,
iter_num
):
# Note: for_loop can't be transformed before PR22867 merged.
x
=
fluid
.
dygraph
.
to_variable
(
x
)
a
=
[]
for
i
in
range
(
iter_num
):
...
...
@@ -53,7 +52,6 @@ def test_list_in_for_loop(x, iter_num):
def
test_list_in_for_loop_with_concat
(
x
,
iter_num
):
# Note: for_loop can't be transformed before PR22867 merged.
x
=
fluid
.
dygraph
.
to_variable
(
x
)
a
=
[]
for
i
in
range
(
iter_num
):
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_slice.py
0 → 100644
浏览文件 @
5a202af9
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph.jit
import
dygraph_to_static_graph
SEED
=
2020
np
.
random
.
seed
(
SEED
)
def
test_slice_without_control_flow
(
x
):
# Python slice will not be transformed.
x
=
fluid
.
dygraph
.
to_variable
(
x
)
a
=
[
x
]
a
[
0
]
=
fluid
.
layers
.
fill_constant
(
shape
=
[
2
],
value
=
2
,
dtype
=
"float32"
)
return
a
def
test_slice_in_if
(
x
):
x
=
fluid
.
dygraph
.
to_variable
(
x
)
a
=
[]
if
x
.
numpy
()[
0
]
>
0
:
a
.
append
(
x
)
else
:
a
.
append
(
fluid
.
layers
.
fill_constant
(
shape
=
[
1
,
2
],
value
=
9
,
dtype
=
"int64"
))
if
x
.
numpy
()[
0
]
>
0
:
a
[
0
]
=
x
return
a
def
test_slice_in_while_loop
(
x
,
iter_num
):
x
=
fluid
.
dygraph
.
to_variable
(
x
)
iter_num
=
fluid
.
layers
.
fill_constant
(
shape
=
[
1
],
value
=
iter_num
,
dtype
=
"int32"
)
a
=
[]
i
=
0
# Note: `i < iter_num` can't be supported in dygraph mode now,
# but PR22892 is fixing it https://github.com/PaddlePaddle/Paddle/pull/22892.
# If PR22892 merged, change `i < iter_num.numpy()[0]` to `i < iter_num`.
while
i
<
iter_num
.
numpy
()[
0
]:
a
.
append
(
x
)
i
+=
1
i
=
0
while
i
<
iter_num
.
numpy
()[
0
]:
a
[
i
]
=
fluid
.
layers
.
fill_constant
(
shape
=
[
2
],
value
=
2
,
dtype
=
"float32"
)
i
+=
1
return
a
def
test_slice_in_for_loop
(
x
,
iter_num
):
x
=
fluid
.
dygraph
.
to_variable
(
x
)
a
=
[]
for
i
in
range
(
iter_num
):
a
.
append
(
x
)
for
i
in
range
(
iter_num
):
a
[
i
]
=
x
return
a
class
TestSliceWithoutControlFlow
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
input
=
np
.
random
.
random
((
3
)).
astype
(
'int32'
)
self
.
place
=
fluid
.
CUDAPlace
(
0
)
if
fluid
.
is_compiled_with_cuda
(
)
else
fluid
.
CPUPlace
()
self
.
init_dygraph_func
()
def
init_dygraph_func
(
self
):
self
.
dygraph_func
=
test_slice_without_control_flow
def
run_dygraph_mode
(
self
):
with
fluid
.
dygraph
.
guard
():
res
=
self
.
dygraph_func
(
self
.
input
)
if
isinstance
(
res
,
(
list
,
tuple
)):
res
=
res
[
0
]
return
res
.
numpy
()
def
run_static_mode
(
self
):
main_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main_program
):
tensor_list
=
dygraph_to_static_graph
(
self
.
dygraph_func
)(
self
.
input
)
exe
=
fluid
.
Executor
(
self
.
place
)
static_res
=
exe
.
run
(
main_program
,
fetch_list
=
tensor_list
[
0
])
return
static_res
[
0
]
def
test_transformed_static_result
(
self
):
static_res
=
self
.
run_static_mode
()
dygraph_res
=
self
.
run_dygraph_mode
()
self
.
assertTrue
(
np
.
allclose
(
dygraph_res
,
static_res
),
msg
=
'dygraph res is {}
\n
static_res is {}'
.
format
(
dygraph_res
,
static_res
))
class
TestSliceInIf
(
TestSliceWithoutControlFlow
):
def
init_dygraph_func
(
self
):
self
.
dygraph_func
=
test_slice_in_if
def
run_static_mode
(
self
):
main_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main_program
):
tensor_array
=
dygraph_to_static_graph
(
self
.
dygraph_func
)(
self
.
input
)
static_out
=
fluid
.
layers
.
array_read
(
tensor_array
,
i
=
fluid
.
layers
.
fill_constant
(
shape
=
[
1
],
value
=
0
,
dtype
=
'int64'
))
exe
=
fluid
.
Executor
(
self
.
place
)
numpy_res
=
exe
.
run
(
main_program
,
fetch_list
=
static_out
)
return
numpy_res
[
0
]
class
TestSliceInWhileLoop
(
TestSliceWithoutControlFlow
):
def
setUp
(
self
):
self
.
iter_num
=
3
self
.
input
=
np
.
random
.
random
((
3
)).
astype
(
'int32'
)
self
.
place
=
fluid
.
CUDAPlace
(
0
)
if
fluid
.
is_compiled_with_cuda
(
)
else
fluid
.
CPUPlace
()
self
.
init_dygraph_func
()
def
init_dygraph_func
(
self
):
self
.
dygraph_func
=
test_slice_in_while_loop
def
run_dygraph_mode
(
self
):
with
fluid
.
dygraph
.
guard
():
var_res
=
self
.
dygraph_func
(
self
.
input
,
self
.
iter_num
)
numpy_res
=
[
ele
.
numpy
()
for
ele
in
var_res
]
return
numpy_res
def
run_static_mode
(
self
):
main_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main_program
):
tensor_array
=
dygraph_to_static_graph
(
self
.
dygraph_func
)(
self
.
input
,
self
.
iter_num
)
static_outs
=
[]
for
i
in
range
(
self
.
iter_num
):
static_outs
.
append
(
fluid
.
layers
.
array_read
(
tensor_array
,
i
=
fluid
.
layers
.
fill_constant
(
shape
=
[
1
],
value
=
i
,
dtype
=
'int64'
)))
exe
=
fluid
.
Executor
(
self
.
place
)
numpy_res
=
exe
.
run
(
main_program
,
fetch_list
=
static_outs
)
return
numpy_res
class
TestSliceInForLoop
(
TestSliceInWhileLoop
):
def
init_dygraph_func
(
self
):
self
.
dygraph_func
=
test_slice_in_for_loop
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录