Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a5036775
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a5036775
编写于
3月 04, 2020
作者:
A
Aurelius84
提交者:
GitHub
3月 04, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add program_cache in dygrapht_to_static (#22766)
上级
5ff2439f
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
512 addition
and
67 deletion
+512
-67
python/paddle/fluid/dygraph/dygraph_to_static/__init__.py
python/paddle/fluid/dygraph/dygraph_to_static/__init__.py
+4
-0
python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
...paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
+23
-2
python/paddle/fluid/dygraph/dygraph_to_static/cache_program.py
...n/paddle/fluid/dygraph/dygraph_to_static/cache_program.py
+345
-0
python/paddle/fluid/dygraph/jit.py
python/paddle/fluid/dygraph/jit.py
+19
-60
python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py
...d/tests/unittests/dygraph_to_static/test_cache_program.py
+115
-0
python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py
...luid/tests/unittests/dygraph_to_static/test_fetch_feed.py
+6
-5
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/__init__.py
浏览文件 @
a5036775
...
...
@@ -20,6 +20,10 @@ from .ast_transformer import *
from
.
import
static_analysis
from
.static_analysis
import
*
from
.
import
cache_program
from
.cache_program
import
*
__all__
=
[]
__all__
+=
ast_transformer
.
__all__
__all__
+=
static_analysis
.
__all__
__all__
+=
cache_program
.
__all__
python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
浏览文件 @
a5036775
...
...
@@ -15,15 +15,17 @@
from
__future__
import
print_function
from
.utils
import
*
import
gast
import
textwrap
import
inspect
# gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST).
# It provides a compatibility layer between the AST of various Python versions,
# as produced by ast.parse from the standard ast module.
# See details in https://github.com/serge-sans-paille/gast/
from
.ast_utils
import
is_control_flow_if
,
create_cond_node
,
transform_if_else
from
.ast_utils
import
is_control_flow_if
,
create_cond_node
,
transform_if_else
,
ast_to_func
from
paddle.fluid
import
unique_name
from
.static_analysis
import
AstNodeWrapper
,
StaticAnalysisVisitor
__all__
=
[
'DygraphToStaticAst'
]
__all__
=
[
'DygraphToStaticAst'
,
'convert_to_static'
]
DECORATOR_NAMES
=
[
'dygraph_to_static_output'
,
'dygraph_to_static_graph'
]
...
...
@@ -253,3 +255,22 @@ class BasicApiTransformer(gast.NodeTransformer):
def
get_feed_name_to_arg_id
(
self
):
return
self
.
feed_name_to_arg_id
def
convert_to_static
(
dyfunc
):
"""
Converts dygraph function into static function.
"""
# Get AST from dygraph function
raw_code
=
inspect
.
getsource
(
dyfunc
)
code
=
textwrap
.
dedent
(
raw_code
)
root
=
gast
.
parse
(
code
)
# Transform AST
dygraph_to_static
=
DygraphToStaticAst
()
root_wrapper
=
dygraph_to_static
.
get_static_ast
(
root
)
# Get static_func from AST
func_name
=
dygraph_to_static
.
get_module_name
()
static_func
,
file_name
=
ast_to_func
(
root_wrapper
.
node
,
func_name
)
return
static_func
,
dygraph_to_static
python/paddle/fluid/dygraph/dygraph_to_static/cache_program.py
0 → 100644
浏览文件 @
a5036775
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
inspect
import
textwrap
import
threading
import
numpy
import
six
from
paddle.fluid
import
framework
from
paddle.fluid.layers
import
io
from
paddle.fluid
import
core
,
executor
from
paddle.fluid.dygraph.dygraph_to_static
import
convert_to_static
__all__
=
[
'AutoTracer'
]
class
FunctionCache
(
object
):
"""
Caches the transformed functions to avoid redundant conversions of the same function.
"""
def
__init__
(
self
):
self
.
_cache_funcs
=
dict
()
self
.
_func_to_transformer
=
dict
()
def
__call__
(
self
,
func
):
static_func
=
self
.
_get_or_cache_func
(
func
)
return
static_func
def
_get_or_cache_func
(
self
,
func
):
cache_key
=
self
.
hash_key
(
func
)
static_func
=
self
.
_cache_funcs
.
get
(
cache_key
,
None
)
if
static_func
is
None
:
static_func
,
dygraph_to_static
=
convert_to_static
(
func
)
self
.
_cache_funcs
[
cache_key
]
=
static_func
self
.
_func_to_transformer
[
static_func
]
=
dygraph_to_static
return
static_func
def
transformer
(
self
,
func
):
return
self
.
_func_to_transformer
.
get
(
func
,
None
)
def
hash_key
(
self
,
func
):
raw_code
=
inspect
.
getsource
(
func
)
code
=
textwrap
.
dedent
(
raw_code
)
return
hash
(
code
)
def
exist
(
self
,
func
):
return
self
.
_cache_funcs
.
get
(
self
.
hash_key
(
func
),
None
)
is
not
None
def
synchronized
(
func
):
func
.
__lock__
=
threading
.
Lock
()
def
lock_func
(
*
args
,
**
kwargs
):
with
func
.
__lock__
:
return
func
(
*
args
,
**
kwargs
)
return
lock_func
class
ProgramCache
(
object
):
"""
Wrapper class for the program functions defined by dygraph function.
"""
def
__init__
(
self
):
self
.
_inputs
=
[]
self
.
_outputs
=
[]
# Always set program to default_main_program. Because once `__call__` is called,
# it means layers(or Ops) are added into default_main_program switched by outer
# `with` statement.
self
.
_program
=
framework
.
default_main_program
()
self
.
_func_cache
=
FunctionCache
()
# Stores the entry function of Net or Model.
self
.
_forward_func
=
None
self
.
_feed_name_to_idx
=
{}
self
.
_is_repeated
=
False
# Indicates whether the function call is still building program.
# Because `__call__` can be called recursively when `Net` has
# sub class in `forward()`.
self
.
_in_build_process
=
True
def
__call__
(
self
,
dyfunc
,
*
args
,
**
kwargs
):
"""
Executes the main_program with specialized inputs.
"""
# Transfroms dygraph function into static functions and caches them.
static_func
=
self
.
_transform_or_cache_layers
(
dyfunc
)
# 1. Adds `fluid.data` layers for input if needed
if
not
self
.
_inputs
:
self
.
_add_feed_layers
(
args
,
kwargs
)
# 2. Avoids inserting forward ops repeatedly.
if
self
.
_is_repeated
:
return
self
.
outputs
# 3. Builds program only once and returns the output Variables.
outputs
=
self
.
_get_or_build_program
(
static_func
,
args
,
kwargs
)
if
static_func
==
self
.
_forward_func
:
self
.
_in_build_process
=
False
return
outputs
def
_transform_or_cache_layers
(
self
,
dyfunc
):
"""
Transforms dygraph function into static function.
"""
static_func
=
self
.
_func_cache
(
dyfunc
)
# self._forward_func is entry function of Net or Model.
# It can be called for multiple times, but layers from these functions
# call stack will be added into self._program only once.
# After that, cached program will be always returned by default.
if
static_func
==
self
.
_forward_func
:
self
.
_is_repeated
=
True
if
self
.
_forward_func
is
None
:
self
.
_forward_func
=
static_func
return
static_func
def
_get_or_build_program
(
self
,
func
,
args
,
kwargs
):
"""
Returns program of the input function. If called at first time,
builds a new program and caches it.
"""
with
framework
.
program_guard
(
self
.
_program
):
if
func
==
self
.
_forward_func
:
# Replaces input data with `layers.data`
args
=
list
(
args
)
for
feed_layer
in
self
.
_inputs
:
idx
=
self
.
feed_name_to_idx
[
feed_layer
.
name
]
args
[
idx
]
=
feed_layer
fetch_list
=
func
(
*
args
,
**
kwargs
)
self
.
_outputs
=
fetch_list
else
:
fetch_list
=
func
(
*
args
,
**
kwargs
)
return
fetch_list
def
_add_feed_layers
(
self
,
args
,
kwargs
):
"""
Adds `fluid.data` if the input `numpy.ndarray` is converted into `Variable`
by `to_variable()`, it makes program to be executed dynamically.
"""
if
not
self
.
_feed_name_to_idx
:
self
.
_feed_name_to_idx
=
self
.
_get_name_to_idx
(
self
.
_forward_func
)
with
framework
.
program_guard
(
self
.
_program
):
for
feed_name
,
idx
in
self
.
feed_name_to_idx
.
items
():
batch_data
=
args
[
idx
]
assert
isinstance
(
batch_data
,
numpy
.
ndarray
),
"Input {} should be numpy.ndarray, but received {}."
.
format
(
feed_name
,
type
(
batch_data
))
feed_layer
=
io
.
data
(
name
=
feed_name
,
shape
=
list
(
batch_data
.
shape
[
1
:]),
dtype
=
str
(
batch_data
.
dtype
))
self
.
_inputs
.
append
(
feed_layer
)
def
_get_name_to_idx
(
self
,
func
):
"""
Returns name and index of input args from `forward(args)`
that need to be replaced with `fluid.data`.
"""
transformer
=
self
.
_func_cache
.
transformer
(
func
)
feed_name_to_idx
=
transformer
.
get_feed_name_to_idx
()
return
feed_name_to_idx
@
property
def
program
(
self
):
return
self
.
_program
@
property
def
inputs
(
self
):
return
self
.
_inputs
@
property
def
outputs
(
self
):
return
self
.
_outputs
@
property
def
feed_name_to_idx
(
self
):
return
self
.
_feed_name_to_idx
@
property
def
in_build_process
(
self
):
return
self
.
_in_build_process
class
AutoTracer
(
object
):
_instance
=
None
@
synchronized
def
__new__
(
cls
,
*
args
,
**
kwargs
):
if
cls
.
_instance
is
None
:
cls
.
_instance
=
object
.
__new__
(
cls
,
*
args
,
**
kwargs
)
cls
.
_instance
.
__initialized
=
False
return
cls
.
_instance
@
classmethod
def
get_instance
(
cls
):
if
cls
.
_instance
is
None
:
raise
ValueError
(
"FuncProgram hasn
\'
t been created!"
)
return
cls
.
_instance
@
classmethod
def
reset
(
cls
):
if
cls
.
_instance
is
not
None
:
cls
.
_instance
.
__initialized
=
False
cls
.
_instance
.
__init__
()
def
__init__
(
self
,
exe
=
None
,
place
=
None
):
# To make sure that calls __init__ only once.
if
self
.
__initialized
:
return
self
.
__initialized
=
True
self
.
_place
=
core
.
CPUPlace
()
if
place
is
None
else
place
if
exe
is
None
:
self
.
_exe
=
executor
.
Executor
(
self
.
_place
)
else
:
self
.
_exe
=
exe
self
.
_cached_program
=
ProgramCache
()
self
.
_optimizer
=
None
self
.
_already_minimized
=
False
# Once main_program is changed, should run startup_program.
self
.
_need_startup
=
True
def
run
(
self
,
*
args
,
**
kwargs
):
"""
Executes main_program and returns output Tensors.
"""
feed_dict
,
fetch_list
=
self
.
_prepare
(
args
)
main_program
=
self
.
_cached_program
.
program
outputs
=
self
.
_exe
.
run
(
main_program
,
feed
=
feed_dict
,
fetch_list
=
fetch_list
)
return
outputs
def
_prepare
(
self
,
args
):
"""
Prepares with feed_dict, fetch_list, optimizer and initialize vars
by running startup_program.
"""
# Updates batch_data for feed_dict
feed_dict
=
self
.
_update_batch_data
(
args
)
fetch_list
=
self
.
_cached_program
.
outputs
# Adds optimizer if needed.
if
self
.
_optimizer
and
not
self
.
_already_minimized
:
self
.
_add_optimizer
()
if
self
.
_need_startup
:
self
.
_exe
.
run
(
framework
.
default_startup_program
())
self
.
_need_startup
=
False
return
feed_dict
,
fetch_list
def
_check_cache_valid
(
self
):
"""
Checks whether the current program is consistent with `default_main_program`.
In some models and unittest, program will be switched frequently by `program_guard`.
If does, the cached program and other properties are not available and should be reset.
"""
if
self
.
_cached_program
.
program
:
if
self
.
_cached_program
.
program
!=
framework
.
default_main_program
():
AutoTracer
.
reset
()
def
_update_batch_data
(
self
,
args
):
"""
Updates cached batch data while training program.
"""
feed_name_to_idx
=
self
.
_cached_program
.
feed_name_to_idx
feed_vars
=
self
.
_cached_program
.
inputs
feed_dict
=
{}
for
feed_var
in
feed_vars
:
idx
=
feed_name_to_idx
[
feed_var
.
name
]
feed_dict
[
feed_var
.
name
]
=
args
[
idx
]
return
feed_dict
def
set_optimizer
(
self
,
optimizer
,
loss_name
):
"""
Supports to set or update the optimizer used to minimize loss.
"""
self
.
_check_cache_valid
()
self
.
_optimizer
=
optimizer
if
not
isinstance
(
loss_name
,
six
.
string_types
):
raise
ValueError
(
"Type of input loss_name should type(str), but received {}."
.
format
(
type
(
loss_name
)))
self
.
_loss_name
=
loss_name
def
_add_optimizer
(
self
):
"""
Supports to set or update the optimizer used to minimize loss.
"""
main_program
=
self
.
_cached_program
.
program
all_vars
=
main_program
.
block
(
0
).
vars
loss_var
=
all_vars
.
get
(
self
.
_loss_name
,
None
)
if
loss_var
is
None
:
raise
ValueError
(
"Can't find {} in main_program, please confirm whether the loss input is correct"
.
format
(
self
.
_loss_name
))
# Adds optimizer to minimize loss
with
framework
.
program_guard
(
main_program
):
self
.
_optimizer
.
minimize
(
loss_var
)
# Avoids to set optimizer repeatedly.
self
.
_already_minimized
=
True
def
get_cached_program
(
self
):
"""
Returns the ProgramCache instance.
"""
self
.
_check_cache_valid
()
return
self
.
_cached_program
@
property
def
program
(
self
):
return
self
.
_cached_program
.
program
python/paddle/fluid/dygraph/jit.py
浏览文件 @
a5036775
...
...
@@ -16,21 +16,16 @@ from __future__ import print_function
__all__
=
[
'TracedLayer'
,
'dygraph_to_static_output'
,
'dygraph_to_static_graph'
]
import
gast
import
inspect
import
textwrap
import
warnings
from
..wrapped_decorator
import
wrap_decorator
from
.base
import
program_desc_tracing_guard
,
switch_to_static_graph
from
.dygraph_to_static
import
DygraphToStaticAst
from
.dygraph_to_static.ast_utils
import
ast_to_func
from
.dygraph_to_static
import
AutoTracer
,
convert_to_static
from
.layers
import
Layer
from
paddle.fluid
import
core
from
paddle.fluid.framework
import
Program
,
Block
,
Variable
,
_dygraph_tracer
,
dygraph_only
,
_dygraph_guard
,
_current_expected_place
,
in_dygraph_mode
from
paddle.fluid.executor
import
Executor
,
scope_guard
from
paddle.fluid.compiler
import
CompiledProgram
from
paddle.fluid
import
program_guard
,
data
,
default_startup_program
,
default_main_program
def
create_program_from_desc
(
program_desc
):
...
...
@@ -56,23 +51,6 @@ def extract_vars(inputs):
return
result_list
def
to_static_func
(
dygraph_func
):
# Get AST from dygraph function
dygraph_code
=
inspect
.
getsource
(
dygraph_func
)
dygraph_code
=
textwrap
.
dedent
(
dygraph_code
)
root
=
gast
.
parse
(
dygraph_code
)
# Transform AST
dygraph_to_static
=
DygraphToStaticAst
()
root_wrapper
=
dygraph_to_static
.
get_static_ast
(
root
)
# Get static_func from AST
func_name
=
dygraph_to_static
.
get_module_name
()
static_func
,
file_name
=
ast_to_func
(
root_wrapper
.
node
,
func_name
)
return
static_func
,
dygraph_to_static
def
_dygraph_to_static_graph_
(
dygraph_func
):
def
__impl__
(
*
args
,
**
kwargs
):
if
in_dygraph_mode
():
...
...
@@ -80,13 +58,20 @@ def _dygraph_to_static_graph_(dygraph_func):
"The decorator 'dygraph_to_static_graph' doesn't work in dygraph mode."
" Please use it in static mode."
)
return
dygraph_func
(
*
args
,
**
kwargs
)
static_func
,
dygraph_to_static
=
to_static_fun
c
(
dygraph_func
)
static_func
,
ast_transformer
=
convert_to_stati
c
(
dygraph_func
)
return
static_func
(
*
args
,
**
kwargs
)
return
__impl__
dygraph_to_static_graph
=
wrap_decorator
(
_dygraph_to_static_graph_
)
def
_dygraph_to_static_output_
(
dygraph_func
):
# Singleton object to cache main_program to avoid inserting ops repeatedly.
# TODO: Need a better class name
auto_tracer
=
AutoTracer
()
def
__impl__
(
*
args
,
**
kwargs
):
if
in_dygraph_mode
():
warnings
.
warn
(
...
...
@@ -94,45 +79,19 @@ def _dygraph_to_static_output_(dygraph_func):
" Please use it in static mode."
)
return
dygraph_func
(
*
args
,
**
kwargs
)
static_func
,
dygraph_to_static
=
to_static_func
(
dygraph_func
)
feed_name_to_idx
=
dygraph_to_static
.
get_feed_name_to_idx
()
feed_dict
=
{}
for
feed_name
,
idx
in
feed_name_to_idx
.
items
():
feed_dict
[
feed_name
]
=
args
[
idx
]
# Run static_func in static mode
startup_program
=
default_main_program
()
main_program
=
default_startup_program
()
static_res
=
run_static_func
(
main_program
,
startup_program
,
static_func
,
args
,
kwargs
,
feed_dict
,
feed_name_to_idx
)
return
static_res
return
__impl__
cached_program
=
auto_tracer
.
get_cached_program
()
outputs
=
cached_program
(
dygraph_func
,
*
args
,
**
kwargs
)
# Run program to fetch output Tensors once building successfully.
if
not
cached_program
.
in_build_process
:
outputs
=
auto_tracer
.
run
(
*
args
,
**
kwargs
)
def
run_static_func
(
main_program
,
startup_program
,
static_func
,
args
,
kwargs
,
feed_dict
,
feed_name_to_idx
):
return
outputs
with
program_guard
(
main_program
,
startup_program
):
args_list
=
list
(
args
)
for
var_name
,
value
in
feed_dict
.
items
():
idx
=
feed_name_to_idx
[
var_name
]
args_list
[
idx
]
=
data
(
name
=
var_name
,
shape
=
value
.
shape
,
dtype
=
str
(
value
.
dtype
))
args
=
tuple
(
args_list
)
static_out
=
static_func
(
*
args
,
**
kwargs
)
if
not
isinstance
(
static_out
,
(
list
,
tuple
)):
static_out
=
[
static_out
]
exe
=
Executor
(
core
.
CPUPlace
())
exe
.
run
(
startup_program
)
static_res
=
exe
.
run
(
main_program
,
fetch_list
=
static_out
,
feed
=
feed_dict
)
return
static_res
return
__impl__
dygraph_to_static_output
=
wrap_decorator
(
_dygraph_to_static_output_
)
dygraph_to_static_graph
=
wrap_decorator
(
_dygraph_to_static_graph_
)
@
dygraph_only
...
...
@@ -394,11 +353,11 @@ class TracedLayer(object):
in_var = to_variable(in_np)
out_dygraph, static_layer = TracedLayer.trace(layer, inputs=[in_var])
static_layer.save_inference_model(save_dirname, feed=[0], fetch=[0])
place = fluid.CPUPlace()
place = fluid.CPUPlace()
exe = fluid.Executor(place)
program, feed_vars, fetch_vars = fluid.io.load_inference_model(save_dirname,
exe)
exe)
fetch, = exe.run(program, feed={feed_vars[0]: in_np}, fetch_list=fetch_vars)
print(fetch.shape) # (2, 10)
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py
0 → 100644
浏览文件 @
a5036775
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
from
collections
import
Counter
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph.dygraph_to_static
import
AutoTracer
from
paddle.fluid.dygraph.jit
import
dygraph_to_static_output
from
test_fetch_feed
import
Pool2D
,
Linear
class
TestCacheProgram
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
batch_num
=
5
self
.
dygraph_class
=
Pool2D
self
.
data
=
np
.
random
.
random
((
1
,
2
,
4
,
4
)).
astype
(
'float32'
)
def
test_cache
(
self
):
prev_ops
,
cur_ops
=
Counter
(),
Counter
()
prev_out
,
cur_out
=
None
,
None
main_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main_program
):
static_net
=
self
.
dygraph_class
()
for
batch_id
in
range
(
self
.
batch_num
):
out
=
static_net
(
self
.
data
)
# Check outputs
prev_out
=
cur_out
cur_out
=
out
# Check forward ops
prev_ops
=
cur_ops
cur_ops
=
Counter
([
op
.
type
for
op
in
fluid
.
default_main_program
().
block
(
0
).
ops
])
if
batch_id
>
0
:
self
.
assertTrue
(
np
.
allclose
(
prev_out
[
0
],
cur_out
[
0
]),
msg
=
'Output in previous batch is {}
\n
Output in current batch is
\n
{}'
.
format
(
prev_out
,
cur_out
))
self
.
assertEqual
(
prev_ops
,
cur_ops
)
class
TestCacheProgram2
(
TestCacheProgram
):
def
setUp
(
self
):
self
.
batch_num
=
5
self
.
dygraph_class
=
Linear
self
.
data
=
np
.
random
.
random
((
4
,
10
)).
astype
(
'float32'
)
class
TestCacheProgramWithOptimizer
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
dygraph_class
=
Linear
self
.
data
=
np
.
random
.
random
((
4
,
10
)).
astype
(
'float32'
)
self
.
batch_num
=
5
def
train_static
(
self
):
main_program
=
fluid
.
Program
()
loss_data
=
[]
with
fluid
.
program_guard
(
main_program
):
static_net
=
self
.
dygraph_class
()
adam
=
fluid
.
optimizer
.
AdamOptimizer
(
learning_rate
=
0.001
)
# set optimizer
# TODO: Need a better interfaces to set optimizer.
auto_tracer
=
AutoTracer
()
auto_tracer
.
set_optimizer
(
adam
,
'avg_loss'
)
for
batch_id
in
range
(
self
.
batch_num
):
pred
,
avg_loss
=
static_net
(
self
.
data
)
loss_data
.
append
(
np
.
array
(
avg_loss
))
return
loss_data
def
train_dygraph
(
self
):
with
fluid
.
dygraph
.
guard
(
fluid
.
CPUPlace
()):
dygraph_net
=
self
.
dygraph_class
()
adam
=
fluid
.
optimizer
.
AdamOptimizer
(
learning_rate
=
0.001
,
parameter_list
=
dygraph_net
.
parameters
())
loss_data
=
[]
for
batch_id
in
range
(
self
.
batch_num
):
pred
,
avg_loss
=
dygraph_net
(
self
.
data
)
loss_data
.
append
(
avg_loss
.
numpy
())
avg_loss
.
backward
()
adam
.
minimize
(
avg_loss
)
dygraph_net
.
clear_gradients
()
return
loss_data
def
test_with_optimizer
(
self
):
dygraph_loss
=
self
.
train_dygraph
()
static_loss
=
self
.
train_static
()
self
.
assertTrue
(
np
.
allclose
(
dygraph_loss
,
static_loss
),
msg
=
'dygraph is {}
\n
static_res is
\n
{}'
.
format
(
dygraph_loss
,
static_loss
))
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py
浏览文件 @
a5036775
...
...
@@ -58,7 +58,8 @@ class Linear(fluid.dygraph.Layer):
def
forward
(
self
,
x
):
inputs
=
fluid
.
dygraph
.
to_variable
(
x
)
pre
=
self
.
fc
(
inputs
)
return
pre
loss
=
fluid
.
layers
.
mean
(
pre
,
name
=
'avg_loss'
)
return
pre
,
loss
class
TestPool2D
(
unittest
.
TestCase
):
...
...
@@ -69,10 +70,11 @@ class TestPool2D(unittest.TestCase):
def
run_dygraph_mode
(
self
):
with
fluid
.
dygraph
.
guard
():
dy_layer
=
self
.
dygraph_class
()
for
_
in
range
(
1
):
prediction
=
dy_layer
(
x
=
self
.
data
)
if
isinstance
(
prediction
,
(
list
,
tuple
)):
prediction
=
prediction
[
0
]
prediction
=
dy_layer
(
x
=
self
.
data
)
return
prediction
.
numpy
()
return
prediction
.
numpy
()
def
run_static_mode
(
self
):
startup_prog
=
fluid
.
Program
()
...
...
@@ -90,7 +92,6 @@ class TestPool2D(unittest.TestCase):
np
.
allclose
(
dygraph_res
,
static_res
),
msg
=
'dygraph_res is {}
\n
static_res is
\n
{}'
.
format
(
dygraph_res
,
static_res
))
return
class
TestLinear
(
TestPool2D
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录