Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
842050f2
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
842050f2
编写于
2月 14, 2023
作者:
A
Aurelius84
提交者:
GitHub
2月 14, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Dy2St]Enhance @not_to_static API (#50453)
* [Dy2St]Enhance @not_to_static API * del breakpoint()
上级
c5087da8
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
62 addition
and
84 deletion
+62
-84
.gitignore
.gitignore
+1
-0
python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call.py
...id/tests/unittests/dygraph_to_static/test_convert_call.py
+33
-79
python/paddle/jit/api.py
python/paddle/jit/api.py
+1
-2
python/paddle/jit/dy2static/convert_call_func.py
python/paddle/jit/dy2static/convert_call_func.py
+14
-1
python/paddle/jit/dy2static/program_translator.py
python/paddle/jit/dy2static/program_translator.py
+7
-0
python/paddle/jit/dy2static/utils.py
python/paddle/jit/dy2static/utils.py
+6
-2
未找到文件。
.gitignore
浏览文件 @
842050f2
...
...
@@ -85,3 +85,4 @@ paddle/fluid/pybind/eager_op_function_impl.h
paddle/fluid/pybind/eager_op_function_impl.h
paddle/fluid/pybind/op_function_impl.h
paddle/fluid/pybind/*final_state_op_function_impl.h
paddle/fluid/prim/api/generated/prim_api/*
python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_call.py
浏览文件 @
842050f2
...
...
@@ -16,12 +16,12 @@ import logging
import
unittest
import
numpy
as
np
from
test_program_translator
import
get_source_code
import
paddle
import
paddle.fluid
as
fluid
import
paddle.jit.dy2static
as
_jst
from
paddle.jit.dy2static.convert_call_func
import
CONVERSION_OPTIONS
from
paddle.jit.dy2static.utils
import
func_to_source_code
SEED
=
2020
np
.
random
.
seed
(
SEED
)
...
...
@@ -216,103 +216,57 @@ class TestStaticMethod(TestRecursiveCall2):
# Situation 2 : test not_to_static
def
func_sum
(
x
):
res
=
paddle
.
sum
(
x
)
return
res
@
paddle
.
jit
.
not_to_static
def
func_not_to_static
(
x
):
res
=
func_sum
(
x
)
return
res
class
NotToStaticHelper
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
):
super
(
NotToStaticHelper
,
self
).
__init__
()
@
paddle
.
jit
.
to_static
def
func_convert_then_not_to_static
(
x
):
y
=
func_not_to_static
(
x
)
return
y
def
sum
(
self
,
x
):
if
x
.
shape
[
0
]
>
1
:
res
=
x
+
1
res
=
paddle
.
sum
(
x
)
return
res
def
outer
(
self
,
x
):
res
=
self
.
sum
(
x
)
return
res
class
TestClass
(
paddle
.
nn
.
Layer
):
@
paddle
.
jit
.
not_to_static
def
called_member
(
self
,
x
):
return
paddle
.
sum
(
x
)
@
paddle
.
jit
.
to_static
def
forward
(
self
,
x
):
y
=
self
.
called_member
(
x
)
return
y
def
inner
(
self
,
x
):
return
self
.
outer
(
x
)
class
TestNotToConvert
(
TestRecursiveCall2
):
def
set_func
(
self
):
self
.
dygraph_func
=
func_not_to_static
self
.
net
=
NotToStaticHelper
()
paddle
.
jit
.
not_to_static
(
self
.
net
.
sum
)
self
.
dygraph_func
=
paddle
.
jit
.
to_static
(
self
.
net
.
outer
)
def
test_conversion_options
(
self
):
options
=
getattr
(
self
.
dygraph_func
,
CONVERSION_OPTIONS
,
None
)
options
=
getattr
(
self
.
net
.
sum
,
CONVERSION_OPTIONS
,
None
)
self
.
assertIsNotNone
(
options
)
self
.
assertTrue
(
options
.
not_convert
)
class
TestNotToConvert2
(
TestRecursiveCall2
):
def
set_func
(
self
):
self
.
dygraph_func
=
func_convert_then_not_to_static
class
TestNotToConvert3
(
TestRecursiveCall2
):
def
set_func
(
self
):
self
.
dygraph_func
=
TestClass
()
class
TestDynamicToStaticCode
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
set_func
()
self
.
set_answer_func
()
def
set_func
(
self
):
self
.
func
=
func_not_to_static
def
set_answer_func
(
self
):
class
StaticCode
:
@
paddle
.
jit
.
not_to_static
def
func_not_to_static
(
x
):
res
=
func_sum
(
x
)
return
res
self
.
answer_func
=
StaticCode
.
func_not_to_static
def
_get_answer_code
(
self
):
return
get_source_code
(
self
.
answer_func
)
def
_get_transformed_code
(
self
):
transformed_func
=
_jst
.
Call
(
self
.
func
)
return
get_source_code
(
transformed_func
)
def
test_code
(
self
):
transformed_code
=
self
.
_get_transformed_code
()
answer_code
=
self
.
_get_answer_code
()
self
.
assertEqual
(
answer_code
,
transformed_code
,
msg
=
"
\n
transformed_code :
\n
{}
\n
answer_code :
\n
{}"
.
format
(
transformed_code
,
answer_code
),
# check 'if statement' is not converted
self
.
assertIn
(
"if x.shape[0] > 1"
,
func_to_source_code
(
_jst
.
Call
(
self
.
net
.
sum
))
)
class
Test
DynamicToStaticCode2
(
TestDynamicToStaticCode
):
class
Test
NotToConvert2
(
TestRecursiveCall2
):
def
set_func
(
self
):
self
.
func
=
func_convert_then_not_to_static
self
.
net
=
NotToStaticHelper
()
# for to_static(not_to_static(function)) == enable_static
paddle
.
jit
.
not_to_static
(
self
.
net
.
sum
)
self
.
dygraph_func
=
paddle
.
jit
.
to_static
(
self
.
net
.
sum
)
def
set_answer_func
(
self
):
class
StaticCode
:
def
func_convert_then_not_to_static
(
x
):
__return_value_0
=
None
y
=
_jst
.
Call
(
func_not_to_static
)(
x
)
__return_value_0
=
y
return
__return_value_0
def
test_conversion_options
(
self
):
options
=
getattr
(
self
.
net
.
sum
,
CONVERSION_OPTIONS
,
None
)
self
.
assertIsNotNone
(
options
)
self
.
assertTrue
(
options
.
not_convert
)
self
.
answer_func
=
StaticCode
.
func_convert_then_not_to_static
def
test_code
(
self
):
# check 'if statement' is not converted
self
.
assertIn
(
"if x.shape[0] > 1"
,
self
.
dygraph_func
.
code
)
if
__name__
==
'__main__'
:
...
...
python/paddle/jit/api.py
浏览文件 @
842050f2
...
...
@@ -42,7 +42,6 @@ from paddle.fluid.dygraph.base import (
from
.dy2static
import
logging_utils
from
.dy2static.convert_call_func
import
(
ConversionOptions
,
CONVERSION_OPTIONS
,
add_ignore_module
,
)
from
.dy2static.program_translator
import
(
...
...
@@ -348,7 +347,7 @@ def not_to_static(func=None):
return
not_to_static
options
=
ConversionOptions
(
not_convert
=
True
)
setattr
(
func
,
CONVERSION_OPTIONS
,
options
)
options
.
attach
(
func
)
return
func
...
...
python/paddle/jit/dy2static/convert_call_func.py
浏览文件 @
842050f2
...
...
@@ -42,7 +42,7 @@ __all__ = []
translator_logger
=
TranslatorLogger
()
CONVERSION_OPTIONS
=
"
An attribute for a function that indicates conversion flags of the function in dynamic-to-static.
"
CONVERSION_OPTIONS
=
"
__jst_not_to_static
"
class
ConversionOptions
:
...
...
@@ -58,6 +58,19 @@ class ConversionOptions:
def
__init__
(
self
,
not_convert
=
False
):
self
.
not_convert
=
not_convert
def
attach
(
self
,
func
):
if
inspect
.
ismethod
(
func
):
func
=
func
.
__func__
if
inspect
.
isfunction
(
func
):
setattr
(
func
,
CONVERSION_OPTIONS
,
self
)
else
:
translator_logger
.
warn
(
"Only support @not_to_static to type(function) or type(method), but recevied {}"
.
format
(
type
(
func
)
)
)
def
is_builtin
(
func
,
name
=
None
):
"""predict whether a function is a builtin function with name={name}.
...
...
python/paddle/jit/dy2static/program_translator.py
浏览文件 @
842050f2
...
...
@@ -28,6 +28,7 @@ from paddle.utils import gast
from
.
import
error
,
logging_utils
from
.ast_transformer
import
DygraphToStaticAst
from
.convert_call_func
import
CONVERSION_OPTIONS
from
.function_spec
import
(
FunctionSpec
,
_hash_spec_names
,
...
...
@@ -152,6 +153,12 @@ def convert_to_static(function):
"""
if
getattr
(
function
,
ALREADY_D2S
,
None
):
return
function
# Return directly if decorated with @not_to_static and DO NOT Cache it
options
=
getattr
(
function
,
CONVERSION_OPTIONS
,
None
)
if
options
is
not
None
and
options
.
not_convert
:
return
function
.
__func__
if
inspect
.
ismethod
(
function
)
else
function
with
_CACHE_LOCK
:
static_func
=
_FUNCTION_CACHE
.
convert_with_cache
(
function
)
setattr
(
static_func
,
ALREADY_D2S
,
True
)
...
...
python/paddle/jit/dy2static/utils.py
浏览文件 @
842050f2
...
...
@@ -15,6 +15,7 @@
import
ast
import
atexit
import
copy
import
functools
import
importlib.util
import
inspect
import
os
...
...
@@ -23,7 +24,6 @@ import sys
import
tempfile
import
textwrap
import
warnings
from
functools
import
reduce
from
importlib.machinery
import
SourceFileLoader
import
astor
...
...
@@ -637,6 +637,8 @@ def func_to_source_code(function, dedent=True):
"""
Transforms function into raw string of source code.
"""
if
isinstance
(
function
,
functools
.
partial
):
function
=
function
.
func
if
not
(
inspect
.
isfunction
(
function
)
or
inspect
.
ismethod
(
function
)):
raise
TypeError
(
"The type of 'function' should be a function or method, but received {}."
.
format
(
...
...
@@ -1429,7 +1431,9 @@ class GetterSetterHelper:
def
__init__
(
self
,
getter_func
,
setter_func
,
*
name_lists
):
name_lists
=
map
(
lambda
x
:
[]
if
x
is
None
else
x
,
name_lists
)
name_sets
=
map
(
lambda
x
:
set
(
x
),
name_lists
)
self
.
_union
=
list
(
reduce
(
lambda
x
,
y
:
x
|
y
,
name_sets
,
set
()))
self
.
_union
=
list
(
functools
.
reduce
(
lambda
x
,
y
:
x
|
y
,
name_sets
,
set
())
)
self
.
_union
.
sort
()
self
.
getter
=
getter_func
self
.
setter
=
setter_func
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录