Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
f9d39b49
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f9d39b49
编写于
9月 02, 2020
作者:
L
liym27
提交者:
GitHub
9月 02, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Dy2Stat] Transforme api 'to_tensor' to 'assign'. (#26873)
上级
932bbe95
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
63 addition
and
43 deletion
+63
-43
python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py
.../fluid/dygraph/dygraph_to_static/basic_api_transformer.py
+43
-8
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
+8
-33
python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic_api_transformation.py
...ttests/dygraph_to_static/test_basic_api_transformation.py
+11
-1
python/paddle/fluid/tests/unittests/dygraph_to_static/test_logging_utils.py
...d/tests/unittests/dygraph_to_static/test_logging_utils.py
+1
-1
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/basic_api_transformer.py
浏览文件 @
f9d39b49
...
...
@@ -16,9 +16,7 @@ import astor
import
gast
from
paddle.fluid.dygraph.dygraph_to_static.static_analysis
import
AstNodeWrapper
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
is_dygraph_api
,
is_to_variable
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
to_assign_node
,
to_static_ast
,
update_args_of_func
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
dygraph_class_to_static_api
from
paddle.fluid.dygraph.dygraph_to_static
import
utils
class
BasicApiTransformer
(
gast
.
NodeTransformer
):
...
...
@@ -56,7 +54,7 @@ class BasicApiTransformer(gast.NodeTransformer):
if
isinstance
(
child_node
,
gast
.
Call
):
# TODO(liym27):
# Considers that a dygraph api which modifies the input or has a output.
if
is_dygraph_api
(
child_node
):
if
utils
.
is_dygraph_api
(
child_node
):
return
else
:
self
.
_visit_Call
(
child_node
)
...
...
@@ -73,7 +71,7 @@ class BasicApiTransformer(gast.NodeTransformer):
if
self
.
_is_dygraph_forward
(
func_name
):
class_node
=
self
.
_get_class_node
(
func_name
)
static_node
=
to_static_ast
(
node
,
class_node
)
static_node
=
utils
.
to_static_ast
(
node
,
class_node
)
return
static_node
else
:
return
node
...
...
@@ -91,14 +89,51 @@ class BasicApiTransformer(gast.NodeTransformer):
if
is_to_variable
(
node_value
):
return
False
if
is_dygraph_api
(
node_value
):
if
utils
.
is_dygraph_api
(
node_value
):
dygraph_api
=
node_value
.
func
.
attr
if
not
dygraph_class_to_static_api
.
get
(
dygraph_api
):
if
not
utils
.
dygraph_class_to_static_api
.
get
(
dygraph_api
):
return
False
update_args_of_func
(
node_value
,
node_value
,
"__init__"
)
u
tils
.
u
pdate_args_of_func
(
node_value
,
node_value
,
"__init__"
)
target_str
=
astor
.
to_source
(
gast
.
gast_to_ast
(
node
.
targets
[
0
]))
self
.
class_node_dict
[
target_str
]
=
node_value
return
True
# TODO: node.value is not dygraph class
return
False
def
is_to_variable
(
node
):
assert
isinstance
(
node
,
gast
.
Call
)
api_name
=
utils
.
ast_to_source_code
(
node
.
func
).
strip
()
if
utils
.
is_dygraph_api
(
node
):
return
api_name
.
endswith
(
"to_variable"
)
if
utils
.
is_paddle_api
(
node
):
return
api_name
.
endswith
(
"to_tensor"
)
return
False
def
to_assign_node
(
node
):
# Transform dygraph api `fluid.dygraph.to_variable` alias `paddle.to_tensor` to static api `fluid.layers.assign`.
# NOTE:
# 1. Api `to_variable` supports data type {float16, float32, float64, int16, int32, int64, uint8, uint16},
# but api `assign` only supports {float32, float64, int32, int64, bool};
# 2. If the input of api `assign` is numpy.ndarray, its size cannot be greater than 1024 * 1024.
assert
isinstance
(
node
,
gast
.
Call
)
assign_api
=
gast
.
parse
(
'fluid.layers.assign'
).
body
[
0
].
value
node
.
func
=
assign_api
if
node
.
args
:
node
.
args
=
[
node
.
args
[
0
]]
node
.
keywords
=
[]
else
:
for
idx
,
kw
in
enumerate
(
node
.
keywords
):
if
kw
.
arg
==
'value'
or
kw
.
arg
==
'data'
:
node
.
keywords
[
idx
].
arg
=
'input'
node
.
keywords
=
[
node
.
keywords
[
idx
]]
node
.
args
=
[]
break
return
node
python/paddle/fluid/dygraph/dygraph_to_static/utils.py
浏览文件 @
f9d39b49
...
...
@@ -136,9 +136,12 @@ def is_api_in_module(node, module_prefix):
# import_str = "".join(import_statements)
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid.dygraph
as
dygraph
import
paddle.fluid.layers
as
layers
from
paddle.fluid.dygraph
import
to_variable
import
paddle.fluid.dygraph
as
dygraph
from
paddle
import
to_tensor
return
eval
(
"_is_api_in_module_helper({}, '{}')"
.
format
(
func_str
,
module_prefix
))
except
NameError
:
...
...
@@ -146,15 +149,18 @@ def is_api_in_module(node, module_prefix):
def
is_dygraph_api
(
node
):
# Note: A api in module dygraph_to_static is not a real dygraph api.
if
is_api_in_module
(
node
,
"paddle.fluid.dygraph.dygraph_to_static"
):
return
False
# TODO(liym27): A better way to determine whether it is a dygraph api.
# Consider the decorator @dygraph_only
return
is_api_in_module
(
node
,
"paddle.fluid.dygraph"
)
def
is_paddle_api
(
node
):
return
is_api_in_module
(
node
,
"paddle
.fluid
"
)
return
is_api_in_module
(
node
,
"paddle"
)
# Is numpy_api cannot reuse is_api_in_module because of numpy module problem
...
...
@@ -233,14 +239,6 @@ def _add_keywords_to(node, dygraph_api_name):
return
def
is_to_variable
(
node
):
assert
isinstance
(
node
,
gast
.
Call
)
if
is_dygraph_api
(
node
):
api_name
=
ast_to_source_code
(
node
.
func
).
strip
()
return
api_name
.
endswith
(
"to_variable"
)
return
False
def
to_static_ast
(
node
,
class_node
):
assert
isinstance
(
node
,
gast
.
Call
)
assert
isinstance
(
class_node
,
gast
.
Call
)
...
...
@@ -268,29 +266,6 @@ def to_static_ast(node, class_node):
return
node
def
to_assign_node
(
node
):
# Transform dygraph api `fluid.dygraph.to_variable` to static api `fluid.layers.assign`.
# NOTE:
# 1. Api `to_variable` supports data type {float16, float32, float64, int16, int32, int64, uint8, uint16},
# but api `assign` only supports {float32, float64, int32, int64, bool};
# 2. If the input of api `assign` is numpy.ndarray, its size cannot be greater than 1024 * 1024.
assert
isinstance
(
node
,
gast
.
Call
)
assign_api
=
gast
.
parse
(
'fluid.layers.assign'
).
body
[
0
].
value
node
.
func
=
assign_api
if
node
.
args
:
node
.
args
=
[
node
.
args
[
0
]]
node
.
keywords
=
[]
else
:
for
idx
,
kw
in
enumerate
(
node
.
keywords
):
if
kw
.
arg
==
'value'
:
node
.
keywords
[
idx
].
arg
=
'input'
node
.
keywords
=
[
node
.
keywords
[
idx
]]
node
.
args
=
[]
break
return
node
def
update_args_of_func
(
node
,
dygraph_node
,
method_name
):
assert
isinstance
(
node
,
gast
.
Call
)
if
method_name
not
in
[
"__init__"
,
"forward"
]:
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_basic_api_transformation.py
浏览文件 @
f9d39b49
...
...
@@ -19,9 +19,11 @@ import unittest
import
inspect
import
gast
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid.dygraph
as
dygraph
from
paddle
import
to_tensor
from
paddle.fluid.dygraph
import
to_variable
from
paddle.fluid.dygraph.jit
import
dygraph_to_static_func
from
paddle.fluid.dygraph.dygraph_to_static.utils
import
is_dygraph_api
...
...
@@ -45,11 +47,19 @@ def dyfunc_to_variable_3(x):
return
res
def
dyfunc_to_tensor
(
x
):
res1
=
paddle
.
to_tensor
(
x
,
dtype
=
None
,
place
=
None
,
stop_gradient
=
True
)
res2
=
paddle
.
tensor
.
to_tensor
(
data
=
res1
)
res3
=
to_tensor
(
data
=
res2
)
return
res3
class
TestDygraphBasicApi_ToVariable
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
input
=
np
.
ones
(
5
).
astype
(
"int32"
)
self
.
test_funcs
=
[
dyfunc_to_variable
,
dyfunc_to_variable_2
,
dyfunc_to_variable_3
dyfunc_to_tensor
,
dyfunc_to_variable
,
dyfunc_to_variable_2
,
dyfunc_to_variable_3
]
self
.
place
=
fluid
.
CUDAPlace
(
0
)
if
fluid
.
is_compiled_with_cuda
(
)
else
fluid
.
CPUPlace
()
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_logging_utils.py
浏览文件 @
f9d39b49
# Copyright (c) 20
19
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 20
20
PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录