Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
a3436672
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a3436672
编写于
10月 20, 2022
作者:
X
xiongkun
提交者:
GitHub
10月 20, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Dy2Static] Remove deprecated code in dy2static (#47148)
上级
0e552c08
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
0 addition
and
392 deletion
+0
-392
python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
...paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
+0
-4
python/paddle/fluid/dygraph/dygraph_to_static/grad_transformer.py
...addle/fluid/dygraph/dygraph_to_static/grad_transformer.py
+0
-91
python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py
...addle/fluid/dygraph/dygraph_to_static/list_transformer.py
+0
-256
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
+0
-41
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
浏览文件 @
a3436672
...
...
@@ -27,10 +27,8 @@ from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import Br
from
paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer
import
BreakTransformOptimizer
from
paddle.fluid.dygraph.dygraph_to_static.call_transformer
import
CallTransformer
from
paddle.fluid.dygraph.dygraph_to_static.cast_transformer
import
CastTransformer
from
paddle.fluid.dygraph.dygraph_to_static.grad_transformer
import
GradTransformer
from
paddle.fluid.dygraph.dygraph_to_static.typehint_transformer
import
TypeHintTransformer
from
paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer
import
IfElseTransformer
from
paddle.fluid.dygraph.dygraph_to_static.list_transformer
import
ListTransformer
from
paddle.fluid.dygraph.dygraph_to_static.logical_transformer
import
LogicalTransformer
from
paddle.fluid.dygraph.dygraph_to_static.loop_transformer
import
LoopTransformer
from
paddle.fluid.dygraph.dygraph_to_static.print_transformer
import
PrintTransformer
...
...
@@ -92,7 +90,6 @@ class DygraphToStaticAst(BaseTransformer):
EarlyReturnTransformer
,
BasicApiTransformer
,
# Basic Api
TensorShapeTransformer
,
# Tensor.shape -> layers.shape(Tensor)
#ListTransformer, # List used in control flow
BreakContinueTransformer
,
# break/continue in loops
ReturnTransformer
,
# return in functions
LogicalTransformer
,
# logical and/or/not
...
...
@@ -103,7 +100,6 @@ class DygraphToStaticAst(BaseTransformer):
PrintTransformer
,
# print statement
CallTransformer
,
# transform call recursively
CastTransformer
,
# type casting statement
#GradTransformer, # transform paddle.grad to paddle.gradients
DecoratorTransformer
,
# transform decorators to function call
TypeHintTransformer
,
# remove all typehint in gast.Name
]
...
...
python/paddle/fluid/dygraph/dygraph_to_static/grad_transformer.py
已删除
100644 → 0
浏览文件 @
0e552c08
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
paddle.utils
import
gast
import
warnings
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
AstNodeWrapper
from
paddle.fluid.dygraph.dygraph_to_static
import
utils
from
paddle.fluid.dygraph.dygraph_to_static.base_transformer
import
BaseTransformer
class
GradTransformer
(
BaseTransformer
):
"""
A class transforms dygraph paddle.grad to static graph paddle.gradients. The
transformation is applied to support double grad mode.
"""
def
__init__
(
self
,
wrapper_root
):
assert
isinstance
(
wrapper_root
,
AstNodeWrapper
),
"Input non-AstNodeWrapper node for the initialization of GradTransformer."
self
.
wrapper_root
=
wrapper_root
self
.
root
=
wrapper_root
.
node
def
transform
(
self
):
self
.
visit
(
self
.
root
)
def
visit_Call
(
self
,
node
):
self
.
generic_visit
(
node
)
if
not
is_grad_api_node
(
node
):
return
node
dygraph_grad_parameters
=
[
"outputs"
,
"inputs"
,
"grad_outputs"
,
"retain_graph"
,
"create_graph"
,
"only_inputs"
,
"allow_unused"
,
"no_grad_vars"
]
to_static_grad_param
=
{
"outputs"
:
"targets"
,
"inputs"
:
"inputs"
,
"grad_outputs"
:
"target_gradients"
,
"no_grad_vars"
:
"no_grad_set"
}
static_keywords
=
[]
for
kw
in
node
.
keywords
:
if
kw
.
arg
not
in
dygraph_grad_parameters
or
kw
.
arg
not
in
to_static_grad_param
:
warnings
.
warn
(
"paddle.grad has unsupported parameter in jit: "
+
kw
.
arg
+
", jit will discard it"
)
continue
dygraph_grad_parameters
.
remove
(
kw
.
arg
)
kw
.
arg
=
to_static_grad_param
[
kw
.
arg
]
static_keywords
.
append
(
kw
)
for
i
in
range
(
len
(
node
.
args
)):
arg_name
=
dygraph_grad_parameters
[
i
]
if
arg_name
not
in
to_static_grad_param
:
warnings
.
warn
(
"paddle.grad has unsupported parameter in jit: "
+
kw
.
arg
+
", jit will discard it"
)
continue
kw
=
gast
.
keyword
(
arg
=
to_static_grad_param
[
arg_name
],
value
=
node
.
args
[
i
])
static_keywords
.
append
(
kw
)
node
.
func
=
gast
.
parse
(
'paddle.static.gradients'
).
body
[
0
].
value
node
.
keywords
=
static_keywords
node
.
args
=
[]
return
node
def
is_grad_api_node
(
node
):
assert
isinstance
(
node
,
gast
.
Call
)
api_name
=
utils
.
ast_to_source_code
(
node
.
func
).
strip
()
if
utils
.
is_paddle_api
(
node
):
if
'no_grad'
in
api_name
:
warnings
.
warn
(
"paddle.no_grad is only supported for inference model, and not supported for training under @to_static."
)
return
False
return
api_name
.
endswith
(
"grad"
)
return
False
python/paddle/fluid/dygraph/dygraph_to_static/list_transformer.py
已删除
100644 → 0
浏览文件 @
0e552c08
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
astor
from
paddle.utils
import
gast
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
AstNodeWrapper
,
StaticAnalysisVisitor
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
ast_to_source_code
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
slice_is_num
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
is_control_flow_to_transform
from
paddle.fluid.dygraph.dygraph_to_static.base_transformer
import
BaseTransformer
class
ListTransformer
(
BaseTransformer
):
"""
This class transforms python list used in control flow into Static Graph Ast.
"""
def
__init__
(
self
,
wrapper_root
):
assert
isinstance
(
wrapper_root
,
AstNodeWrapper
),
"Input non-AstNodeWrapper node for the initialization of ListTransformer."
self
.
wrapper_root
=
wrapper_root
self
.
root
=
wrapper_root
.
node
self
.
list_name_to_updated
=
dict
()
self
.
list_nodes
=
set
()
self
.
static_analysis_visitor
=
StaticAnalysisVisitor
(
self
.
root
)
self
.
node_to_wrapper_map
=
self
.
static_analysis_visitor
.
get_node_to_wrapper_map
(
)
var_env
=
self
.
static_analysis_visitor
.
get_var_env
()
var_env
.
cur_scope
=
var_env
.
cur_scope
.
sub_scopes
[
0
]
self
.
scope_var_type_dict
=
var_env
.
get_scope_var_type
()
def
transform
(
self
):
self
.
visit
(
self
.
root
)
self
.
replace_list_with_tensor_array
(
self
.
root
)
def
visit_Call
(
self
,
node
):
if
isinstance
(
node
.
func
,
gast
.
Attribute
):
func_name
=
node
.
func
.
attr
if
func_name
==
"pop"
:
node
=
self
.
_replace_pop
(
node
)
return
node
def
visit_Assign
(
self
,
node
):
if
self
.
_update_list_name_to_updated
(
node
):
return
node
if
self
.
_need_to_array_write_node
(
node
):
return
self
.
_transform_slice_to_tensor_write
(
node
)
self
.
generic_visit
(
node
)
return
node
def
visit_If
(
self
,
node
):
self
.
generic_visit
(
node
)
if
is_control_flow_to_transform
(
node
,
self
.
static_analysis_visitor
,
self
.
scope_var_type_dict
):
self
.
_transform_list_append_in_control_flow
(
node
)
return
node
def
visit_While
(
self
,
node
):
self
.
generic_visit
(
node
)
if
is_control_flow_to_transform
(
node
,
self
.
static_analysis_visitor
,
self
.
scope_var_type_dict
):
self
.
_transform_list_append_in_control_flow
(
node
)
return
node
def
visit_For
(
self
,
node
):
self
.
generic_visit
(
node
)
if
is_control_flow_to_transform
(
node
,
self
.
static_analysis_visitor
,
self
.
scope_var_type_dict
):
self
.
_transform_list_append_in_control_flow
(
node
)
return
node
def
replace_list_with_tensor_array
(
self
,
node
):
for
child_node
in
gast
.
walk
(
node
):
if
isinstance
(
child_node
,
gast
.
Assign
):
if
self
.
_need_to_create_tensor_array
(
child_node
):
child_node
.
value
=
self
.
_create_tensor_array
(
child_node
.
value
)
def
_transform_list_append_in_control_flow
(
self
,
node
):
for
child_node
in
gast
.
walk
(
node
):
if
self
.
_need_to_array_write_node
(
child_node
):
child_node
.
value
=
\
self
.
_to_array_write_node
(
child_node
.
value
)
def
_need_to_array_write_node
(
self
,
node
):
if
isinstance
(
node
,
gast
.
Expr
):
if
isinstance
(
node
.
value
,
gast
.
Call
):
if
self
.
_is_list_append_tensor
(
node
.
value
):
return
True
if
isinstance
(
node
,
gast
.
Assign
):
target_node
=
node
.
targets
[
0
]
if
isinstance
(
target_node
,
gast
.
Subscript
):
list_name
=
ast_to_source_code
(
target_node
.
value
).
strip
()
if
list_name
in
self
.
list_name_to_updated
:
if
self
.
list_name_to_updated
[
list_name
]
==
True
:
return
True
return
False
def
_transform_slice_to_tensor_write
(
self
,
node
):
assert
isinstance
(
node
,
gast
.
Assign
)
target_node
=
node
.
targets
[
0
]
target_name
=
target_node
.
value
.
id
slice_node
=
target_node
.
slice
if
isinstance
(
slice_node
,
gast
.
Slice
):
pass
elif
slice_is_num
(
target_node
):
value_code
=
ast_to_source_code
(
node
.
value
)
i
=
"paddle.cast("
\
"x=_jst.to_static_variable({}),"
\
"dtype='int64')"
.
format
(
ast_to_source_code
(
slice_node
))
assign_code
=
"{} = paddle.tensor.array_write(x={}, i={}, array={})"
\
.
format
(
target_name
,
value_code
,
i
,
target_name
)
assign_node
=
gast
.
parse
(
assign_code
).
body
[
0
]
return
assign_node
def
_is_list_append_tensor
(
self
,
node
):
"""
a.append(b): a is list, b is Tensor
self.x.append(b): self.x is list, b is Tensor
"""
assert
isinstance
(
node
,
gast
.
Call
)
# 1. The func is `append`.
if
not
isinstance
(
node
.
func
,
gast
.
Attribute
):
return
False
if
node
.
func
.
attr
!=
'append'
:
return
False
# 2. It's a `python list` to call append().
value_name
=
astor
.
to_source
(
gast
.
gast_to_ast
(
node
.
func
.
value
)).
strip
()
if
value_name
not
in
self
.
list_name_to_updated
:
return
False
# 3. The number of arg of append() is one
# Only one argument is supported in Python list.append()
if
len
(
node
.
args
)
!=
1
:
return
False
# TODO(liym27): The arg of append() should be Tensor. But because the type of arg is often wrong with static analysis,
# the arg is not required to be Tensor here.
# 4. The arg of append() is Tensor
# arg = node.args[0]
# if isinstance(arg, gast.Name):
# # TODO: `arg.id` may be not in scope_var_type_dict if `arg.id` is the arg of decorated function
# # Need a better way to confirm whether `arg.id` is a Tensor.
# try:
# var_type_set = self.scope_var_type_dict[arg.id]
# except KeyError:
# return False
# if NodeVarType.NUMPY_NDARRAY in var_type_set:
# return False
# if NodeVarType.TENSOR not in var_type_set and NodeVarType.PADDLE_RETURN_TYPES not in var_type_set:
# return False
# # TODO: Consider that `arg` may be a gast.Call about Paddle Api. eg: list_a.append(paddle.reshape(x))
# # else:
# # return True
self
.
list_name_to_updated
[
value_name
.
strip
()]
=
True
return
True
def
_need_to_create_tensor_array
(
self
,
node
):
assert
isinstance
(
node
,
gast
.
Assign
)
target_node
=
node
.
targets
[
0
]
try
:
target_id
=
target_node
.
id
except
AttributeError
:
return
False
if
self
.
list_name_to_updated
.
get
(
target_id
)
and
node
in
self
.
list_nodes
:
return
True
return
False
def
_create_tensor_array
(
self
,
value_node
):
# Although `dtype='float32'`, other types such as `int32` can also be supported
init_value
=
ast_to_source_code
(
value_node
).
strip
()
func_code
=
"paddle.tensor.create_array('float32', {})"
.
format
(
init_value
)
func_node
=
gast
.
parse
(
func_code
).
body
[
0
].
value
return
func_node
def
_to_array_write_node
(
self
,
node
):
assert
isinstance
(
node
,
gast
.
Call
)
array
=
astor
.
to_source
(
gast
.
gast_to_ast
(
node
.
func
.
value
))
x
=
astor
.
to_source
(
gast
.
gast_to_ast
(
node
.
args
[
0
]))
i
=
"paddle.tensor.array_length({})"
.
format
(
array
)
func_code
=
"paddle.tensor.array_write(x={}, i={}, array={})"
.
format
(
x
,
i
,
array
)
return
gast
.
parse
(
func_code
).
body
[
0
].
value
def
_update_list_name_to_updated
(
self
,
node
):
assert
isinstance
(
node
,
gast
.
Assign
)
target_node
=
node
.
targets
[
0
]
# NOTE: Code like `x, y = a, []` has been transformed to `x=a; y=[]`
try
:
target_id
=
target_node
.
id
except
AttributeError
:
return
False
value_node
=
node
.
value
if
isinstance
(
value_node
,
gast
.
List
):
self
.
list_name_to_updated
[
target_id
]
=
False
self
.
list_nodes
.
add
(
node
)
return
True
elif
target_id
in
self
.
list_name_to_updated
and
\
self
.
list_name_to_updated
[
target_id
]
==
False
:
del
self
.
list_name_to_updated
[
target_id
]
return
False
def
_replace_pop
(
self
,
node
):
"""
Replace a pop statement for a list or dict.
For example:
list_a = [0,1,2,3,4]
x = list_a.pop() # --> convert_pop(list_a)
y = list_a.pop(1) # --> convert_pop(list_a, 1)
dict_a = {"red":0, "blue":1, "yellow":2}
m = dict_a.pop("red") # --> convert_pop(dict_a, "red")
n = dict_a.pop("black", 3) # --> convert_pop(dict_a, "black", 3)
"""
assert
isinstance
(
node
,
gast
.
Call
)
assert
isinstance
(
node
.
func
,
gast
.
Attribute
)
target_node
=
node
.
func
.
value
target_str
=
ast_to_source_code
(
target_node
).
strip
()
args_str
=
[
ast_to_source_code
(
arg
).
strip
()
for
arg
in
node
.
args
]
# NOTE(liym27):
# 1. pop stmt for a list if len(args_str) == 0
# 2. pop stmt for a list or dict if len(args_str) == 1
# 3. pop stmt for a dict if len(args_str) == 2
if
len
(
args_str
)
<=
2
:
new_pop_str
=
"_jst.Pop({}, {})"
\
.
format
(
target_str
,
","
.
join
(
args_str
))
new_pop_node
=
gast
.
parse
(
new_pop_str
).
body
[
0
].
value
return
new_pop_node
else
:
return
node
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
浏览文件 @
a3436672
...
...
@@ -330,22 +330,6 @@ def is_numpy_api(node):
return
False
def
is_control_flow_to_transform
(
node
,
static_analysis_visitor
=
None
,
var_name_to_type
=
None
):
"""
Determines whether the node is a PaddlePaddle control flow statement which needs to
be transformed into a static graph control flow statement.
"""
assert
isinstance
(
node
,
gast
.
AST
),
\
"The type of input node must be gast.AST, but received %s."
%
type
(
node
)
visitor
=
IsControlFlowVisitor
(
node
,
static_analysis_visitor
,
node_var_type_map
=
var_name_to_type
)
need_to_transform
=
visitor
.
transform
()
return
need_to_transform
def
_delete_keywords_from
(
node
):
assert
isinstance
(
node
,
gast
.
Call
)
func_src
=
astor
.
to_source
(
gast
.
gast_to_ast
(
node
.
func
))
...
...
@@ -1001,31 +985,6 @@ def _compatible_non_tensor_spec(src_spec, desired_spec):
return
True
def
slice_is_num
(
slice_node
):
# A slice_node.slice can be a:
# (1) ast.Index, which is a simple number such as [1], [-2]
# (2) ast.Slice, which is represented by bounds such as [2:-1]
# (3) ast.Tuple, which includes the above two cases such as [2:-1, 1]
# If slice node is case (1), return True, Otherwise, return False.
#
# NOTE: In (1) case, when gast>=0.4.0, gast.Index is not used, which is replaced
# other gast node such as gast.Constant, gast.Name, gast.UnaryOp and so on.
# Considering the compatibility of gast, here use ast note to check whether the
# node is a num. For more details, please visit https://github.com/serge-sans-paille/gast
assert
isinstance
(
slice_node
,
gast
.
Subscript
)
slice_node_str
=
ast_to_source_code
(
slice_node
).
strip
()
ast_node
=
ast
.
parse
(
slice_node_str
).
body
[
0
].
value
if
isinstance
(
ast_node
.
slice
,
(
ast
.
Tuple
,
ast
.
Slice
)):
return
False
if
isinstance
(
ast_node
.
slice
,
ast
.
Index
):
return
True
return
False
class
NameScope
:
def
__init__
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录