Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
a5036775
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a5036775
编写于
3月 04, 2020
作者:
A
Aurelius84
提交者:
GitHub
3月 04, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add program_cache in dygrapht_to_static (#22766)
上级
5ff2439f
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
512 addition
and
67 deletion
+512
-67
python/paddle/fluid/dygraph/dygraph_to_static/__init__.py
python/paddle/fluid/dygraph/dygraph_to_static/__init__.py
+4
-0
python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
...paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
+23
-2
python/paddle/fluid/dygraph/dygraph_to_static/cache_program.py
...n/paddle/fluid/dygraph/dygraph_to_static/cache_program.py
+345
-0
python/paddle/fluid/dygraph/jit.py
python/paddle/fluid/dygraph/jit.py
+19
-60
python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py
...d/tests/unittests/dygraph_to_static/test_cache_program.py
+115
-0
python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py
...luid/tests/unittests/dygraph_to_static/test_fetch_feed.py
+6
-5
未找到文件。
python/paddle/fluid/dygraph/dygraph_to_static/__init__.py
浏览文件 @
a5036775
...
...
@@ -20,6 +20,10 @@ from .ast_transformer import *
from
.
import
static_analysis
from
.static_analysis
import
*
from
.
import
cache_program
from
.cache_program
import
*
__all__
=
[]
__all__
+=
ast_transformer
.
__all__
__all__
+=
static_analysis
.
__all__
__all__
+=
cache_program
.
__all__
python/paddle/fluid/dygraph/dygraph_to_static/ast_transformer.py
浏览文件 @
a5036775
...
...
@@ -15,15 +15,17 @@
from
__future__
import
print_function
from
.utils
import
*
import
gast
import
textwrap
import
inspect
# gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST).
# It provides a compatibility layer between the AST of various Python versions,
# as produced by ast.parse from the standard ast module.
# See details in https://github.com/serge-sans-paille/gast/
from
.ast_utils
import
is_control_flow_if
,
create_cond_node
,
transform_if_else
from
.ast_utils
import
is_control_flow_if
,
create_cond_node
,
transform_if_else
,
ast_to_func
from
paddle.fluid
import
unique_name
from
.static_analysis
import
AstNodeWrapper
,
StaticAnalysisVisitor
__all__
=
[
'DygraphToStaticAst'
]
__all__
=
[
'DygraphToStaticAst'
,
'convert_to_static'
]
DECORATOR_NAMES
=
[
'dygraph_to_static_output'
,
'dygraph_to_static_graph'
]
...
...
@@ -253,3 +255,22 @@ class BasicApiTransformer(gast.NodeTransformer):
def
get_feed_name_to_arg_id
(
self
):
return
self
.
feed_name_to_arg_id
def
convert_to_static
(
dyfunc
):
"""
Converts dygraph function into static function.
"""
# Get AST from dygraph function
raw_code
=
inspect
.
getsource
(
dyfunc
)
code
=
textwrap
.
dedent
(
raw_code
)
root
=
gast
.
parse
(
code
)
# Transform AST
dygraph_to_static
=
DygraphToStaticAst
()
root_wrapper
=
dygraph_to_static
.
get_static_ast
(
root
)
# Get static_func from AST
func_name
=
dygraph_to_static
.
get_module_name
()
static_func
,
file_name
=
ast_to_func
(
root_wrapper
.
node
,
func_name
)
return
static_func
,
dygraph_to_static
python/paddle/fluid/dygraph/dygraph_to_static/cache_program.py
0 → 100644
浏览文件 @
a5036775
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
inspect
import
textwrap
import
threading
import
numpy
import
six
from
paddle.fluid
import
framework
from
paddle.fluid.layers
import
io
from
paddle.fluid
import
core
,
executor
from
paddle.fluid.dygraph.dygraph_to_static
import
convert_to_static
__all__
=
[
'AutoTracer'
]
class
FunctionCache
(
object
):
"""
Caches the transformed functions to avoid redundant conversions of the same function.
"""
def
__init__
(
self
):
self
.
_cache_funcs
=
dict
()
self
.
_func_to_transformer
=
dict
()
def
__call__
(
self
,
func
):
static_func
=
self
.
_get_or_cache_func
(
func
)
return
static_func
def
_get_or_cache_func
(
self
,
func
):
cache_key
=
self
.
hash_key
(
func
)
static_func
=
self
.
_cache_funcs
.
get
(
cache_key
,
None
)
if
static_func
is
None
:
static_func
,
dygraph_to_static
=
convert_to_static
(
func
)
self
.
_cache_funcs
[
cache_key
]
=
static_func
self
.
_func_to_transformer
[
static_func
]
=
dygraph_to_static
return
static_func
def
transformer
(
self
,
func
):
return
self
.
_func_to_transformer
.
get
(
func
,
None
)
def
hash_key
(
self
,
func
):
raw_code
=
inspect
.
getsource
(
func
)
code
=
textwrap
.
dedent
(
raw_code
)
return
hash
(
code
)
def
exist
(
self
,
func
):
return
self
.
_cache_funcs
.
get
(
self
.
hash_key
(
func
),
None
)
is
not
None
def
synchronized
(
func
):
func
.
__lock__
=
threading
.
Lock
()
def
lock_func
(
*
args
,
**
kwargs
):
with
func
.
__lock__
:
return
func
(
*
args
,
**
kwargs
)
return
lock_func
class
ProgramCache
(
object
):
"""
Wrapper class for the program functions defined by dygraph function.
"""
def
__init__
(
self
):
self
.
_inputs
=
[]
self
.
_outputs
=
[]
# Always set program to default_main_program. Because once `__call__` is called,
# it means layers(or Ops) are added into default_main_program switched by outer
# `with` statement.
self
.
_program
=
framework
.
default_main_program
()
self
.
_func_cache
=
FunctionCache
()
# Stores the entry function of Net or Model.
self
.
_forward_func
=
None
self
.
_feed_name_to_idx
=
{}
self
.
_is_repeated
=
False
# Indicates whether the function call is still building program.
# Because `__call__` can be called recursively when `Net` has
# sub class in `forward()`.
self
.
_in_build_process
=
True
def
__call__
(
self
,
dyfunc
,
*
args
,
**
kwargs
):
"""
Executes the main_program with specialized inputs.
"""
# Transfroms dygraph function into static functions and caches them.
static_func
=
self
.
_transform_or_cache_layers
(
dyfunc
)
# 1. Adds `fluid.data` layers for input if needed
if
not
self
.
_inputs
:
self
.
_add_feed_layers
(
args
,
kwargs
)
# 2. Avoids inserting forward ops repeatedly.
if
self
.
_is_repeated
:
return
self
.
outputs
# 3. Builds program only once and returns the output Variables.
outputs
=
self
.
_get_or_build_program
(
static_func
,
args
,
kwargs
)
if
static_func
==
self
.
_forward_func
:
self
.
_in_build_process
=
False
return
outputs
def
_transform_or_cache_layers
(
self
,
dyfunc
):
"""
Transforms dygraph function into static function.
"""
static_func
=
self
.
_func_cache
(
dyfunc
)
# self._forward_func is entry function of Net or Model.
# It can be called for multiple times, but layers from these functions
# call stack will be added into self._program only once.
# After that, cached program will be always returned by default.
if
static_func
==
self
.
_forward_func
:
self
.
_is_repeated
=
True
if
self
.
_forward_func
is
None
:
self
.
_forward_func
=
static_func
return
static_func
def
_get_or_build_program
(
self
,
func
,
args
,
kwargs
):
"""
Returns program of the input function. If called at first time,
builds a new program and caches it.
"""
with
framework
.
program_guard
(
self
.
_program
):
if
func
==
self
.
_forward_func
:
# Replaces input data with `layers.data`
args
=
list
(
args
)
for
feed_layer
in
self
.
_inputs
:
idx
=
self
.
feed_name_to_idx
[
feed_layer
.
name
]
args
[
idx
]
=
feed_layer
fetch_list
=
func
(
*
args
,
**
kwargs
)
self
.
_outputs
=
fetch_list
else
:
fetch_list
=
func
(
*
args
,
**
kwargs
)
return
fetch_list
def
_add_feed_layers
(
self
,
args
,
kwargs
):
"""
Adds `fluid.data` if the input `numpy.ndarray` is converted into `Variable`
by `to_variable()`, it makes program to be executed dynamically.
"""
if
not
self
.
_feed_name_to_idx
:
self
.
_feed_name_to_idx
=
self
.
_get_name_to_idx
(
self
.
_forward_func
)
with
framework
.
program_guard
(
self
.
_program
):
for
feed_name
,
idx
in
self
.
feed_name_to_idx
.
items
():
batch_data
=
args
[
idx
]
assert
isinstance
(
batch_data
,
numpy
.
ndarray
),
"Input {} should be numpy.ndarray, but received {}."
.
format
(
feed_name
,
type
(
batch_data
))
feed_layer
=
io
.
data
(
name
=
feed_name
,
shape
=
list
(
batch_data
.
shape
[
1
:]),
dtype
=
str
(
batch_data
.
dtype
))
self
.
_inputs
.
append
(
feed_layer
)
def
_get_name_to_idx
(
self
,
func
):
"""
Returns name and index of input args from `forward(args)`
that need to be replaced with `fluid.data`.
"""
transformer
=
self
.
_func_cache
.
transformer
(
func
)
feed_name_to_idx
=
transformer
.
get_feed_name_to_idx
()
return
feed_name_to_idx
@
property
def
program
(
self
):
return
self
.
_program
@
property
def
inputs
(
self
):
return
self
.
_inputs
@
property
def
outputs
(
self
):
return
self
.
_outputs
@
property
def
feed_name_to_idx
(
self
):
return
self
.
_feed_name_to_idx
@
property
def
in_build_process
(
self
):
return
self
.
_in_build_process
class
AutoTracer
(
object
):
_instance
=
None
@
synchronized
def
__new__
(
cls
,
*
args
,
**
kwargs
):
if
cls
.
_instance
is
None
:
cls
.
_instance
=
object
.
__new__
(
cls
,
*
args
,
**
kwargs
)
cls
.
_instance
.
__initialized
=
False
return
cls
.
_instance
@
classmethod
def
get_instance
(
cls
):
if
cls
.
_instance
is
None
:
raise
ValueError
(
"FuncProgram hasn
\'
t been created!"
)
return
cls
.
_instance
@
classmethod
def
reset
(
cls
):
if
cls
.
_instance
is
not
None
:
cls
.
_instance
.
__initialized
=
False
cls
.
_instance
.
__init__
()
def
__init__
(
self
,
exe
=
None
,
place
=
None
):
# To make sure that calls __init__ only once.
if
self
.
__initialized
:
return
self
.
__initialized
=
True
self
.
_place
=
core
.
CPUPlace
()
if
place
is
None
else
place
if
exe
is
None
:
self
.
_exe
=
executor
.
Executor
(
self
.
_place
)
else
:
self
.
_exe
=
exe
self
.
_cached_program
=
ProgramCache
()
self
.
_optimizer
=
None
self
.
_already_minimized
=
False
# Once main_program is changed, should run startup_program.
self
.
_need_startup
=
True
def
run
(
self
,
*
args
,
**
kwargs
):
"""
Executes main_program and returns output Tensors.
"""
feed_dict
,
fetch_list
=
self
.
_prepare
(
args
)
main_program
=
self
.
_cached_program
.
program
outputs
=
self
.
_exe
.
run
(
main_program
,
feed
=
feed_dict
,
fetch_list
=
fetch_list
)
return
outputs
def
_prepare
(
self
,
args
):
"""
Prepares with feed_dict, fetch_list, optimizer and initialize vars
by running startup_program.
"""
# Updates batch_data for feed_dict
feed_dict
=
self
.
_update_batch_data
(
args
)
fetch_list
=
self
.
_cached_program
.
outputs
# Adds optimizer if needed.
if
self
.
_optimizer
and
not
self
.
_already_minimized
:
self
.
_add_optimizer
()
if
self
.
_need_startup
:
self
.
_exe
.
run
(
framework
.
default_startup_program
())
self
.
_need_startup
=
False
return
feed_dict
,
fetch_list
def
_check_cache_valid
(
self
):
"""
Checks whether the current program is consistent with `default_main_program`.
In some models and unittest, program will be switched frequently by `program_guard`.
If does, the cached program and other properties are not available and should be reset.
"""
if
self
.
_cached_program
.
program
:
if
self
.
_cached_program
.
program
!=
framework
.
default_main_program
():
AutoTracer
.
reset
()
def
_update_batch_data
(
self
,
args
):
"""
Updates cached batch data while training program.
"""
feed_name_to_idx
=
self
.
_cached_program
.
feed_name_to_idx
feed_vars
=
self
.
_cached_program
.
inputs
feed_dict
=
{}
for
feed_var
in
feed_vars
:
idx
=
feed_name_to_idx
[
feed_var
.
name
]
feed_dict
[
feed_var
.
name
]
=
args
[
idx
]
return
feed_dict
def
set_optimizer
(
self
,
optimizer
,
loss_name
):
"""
Supports to set or update the optimizer used to minimize loss.
"""
self
.
_check_cache_valid
()
self
.
_optimizer
=
optimizer
if
not
isinstance
(
loss_name
,
six
.
string_types
):
raise
ValueError
(
"Type of input loss_name should type(str), but received {}."
.
format
(
type
(
loss_name
)))
self
.
_loss_name
=
loss_name
def
_add_optimizer
(
self
):
"""
Supports to set or update the optimizer used to minimize loss.
"""
main_program
=
self
.
_cached_program
.
program
all_vars
=
main_program
.
block
(
0
).
vars
loss_var
=
all_vars
.
get
(
self
.
_loss_name
,
None
)
if
loss_var
is
None
:
raise
ValueError
(
"Can't find {} in main_program, please confirm whether the loss input is correct"
.
format
(
self
.
_loss_name
))
# Adds optimizer to minimize loss
with
framework
.
program_guard
(
main_program
):
self
.
_optimizer
.
minimize
(
loss_var
)
# Avoids to set optimizer repeatedly.
self
.
_already_minimized
=
True
def
get_cached_program
(
self
):
"""
Returns the ProgramCache instance.
"""
self
.
_check_cache_valid
()
return
self
.
_cached_program
@
property
def
program
(
self
):
return
self
.
_cached_program
.
program
python/paddle/fluid/dygraph/jit.py
浏览文件 @
a5036775
...
...
@@ -16,21 +16,16 @@ from __future__ import print_function
__all__
=
[
'TracedLayer'
,
'dygraph_to_static_output'
,
'dygraph_to_static_graph'
]
import
gast
import
inspect
import
textwrap
import
warnings
from
..wrapped_decorator
import
wrap_decorator
from
.base
import
program_desc_tracing_guard
,
switch_to_static_graph
from
.dygraph_to_static
import
DygraphToStaticAst
from
.dygraph_to_static.ast_utils
import
ast_to_func
from
.dygraph_to_static
import
AutoTracer
,
convert_to_static
from
.layers
import
Layer
from
paddle.fluid
import
core
from
paddle.fluid.framework
import
Program
,
Block
,
Variable
,
_dygraph_tracer
,
dygraph_only
,
_dygraph_guard
,
_current_expected_place
,
in_dygraph_mode
from
paddle.fluid.executor
import
Executor
,
scope_guard
from
paddle.fluid.compiler
import
CompiledProgram
from
paddle.fluid
import
program_guard
,
data
,
default_startup_program
,
default_main_program
def
create_program_from_desc
(
program_desc
):
...
...
@@ -56,23 +51,6 @@ def extract_vars(inputs):
return
result_list
def
to_static_func
(
dygraph_func
):
# Get AST from dygraph function
dygraph_code
=
inspect
.
getsource
(
dygraph_func
)
dygraph_code
=
textwrap
.
dedent
(
dygraph_code
)
root
=
gast
.
parse
(
dygraph_code
)
# Transform AST
dygraph_to_static
=
DygraphToStaticAst
()
root_wrapper
=
dygraph_to_static
.
get_static_ast
(
root
)
# Get static_func from AST
func_name
=
dygraph_to_static
.
get_module_name
()
static_func
,
file_name
=
ast_to_func
(
root_wrapper
.
node
,
func_name
)
return
static_func
,
dygraph_to_static
def
_dygraph_to_static_graph_
(
dygraph_func
):
def
__impl__
(
*
args
,
**
kwargs
):
if
in_dygraph_mode
():
...
...
@@ -80,13 +58,20 @@ def _dygraph_to_static_graph_(dygraph_func):
"The decorator 'dygraph_to_static_graph' doesn't work in dygraph mode."
" Please use it in static mode."
)
return
dygraph_func
(
*
args
,
**
kwargs
)
static_func
,
dygraph_to_static
=
to_static_fun
c
(
dygraph_func
)
static_func
,
ast_transformer
=
convert_to_stati
c
(
dygraph_func
)
return
static_func
(
*
args
,
**
kwargs
)
return
__impl__
dygraph_to_static_graph
=
wrap_decorator
(
_dygraph_to_static_graph_
)
def
_dygraph_to_static_output_
(
dygraph_func
):
# Singleton object to cache main_program to avoid inserting ops repeatedly.
# TODO: Need a better class name
auto_tracer
=
AutoTracer
()
def
__impl__
(
*
args
,
**
kwargs
):
if
in_dygraph_mode
():
warnings
.
warn
(
...
...
@@ -94,45 +79,19 @@ def _dygraph_to_static_output_(dygraph_func):
" Please use it in static mode."
)
return
dygraph_func
(
*
args
,
**
kwargs
)
static_func
,
dygraph_to_static
=
to_static_func
(
dygraph_func
)
feed_name_to_idx
=
dygraph_to_static
.
get_feed_name_to_idx
()
feed_dict
=
{}
for
feed_name
,
idx
in
feed_name_to_idx
.
items
():
feed_dict
[
feed_name
]
=
args
[
idx
]
# Run static_func in static mode
startup_program
=
default_main_program
()
main_program
=
default_startup_program
()
static_res
=
run_static_func
(
main_program
,
startup_program
,
static_func
,
args
,
kwargs
,
feed_dict
,
feed_name_to_idx
)
return
static_res
return
__impl__
cached_program
=
auto_tracer
.
get_cached_program
()
outputs
=
cached_program
(
dygraph_func
,
*
args
,
**
kwargs
)
# Run program to fetch output Tensors once building successfully.
if
not
cached_program
.
in_build_process
:
outputs
=
auto_tracer
.
run
(
*
args
,
**
kwargs
)
def
run_static_func
(
main_program
,
startup_program
,
static_func
,
args
,
kwargs
,
feed_dict
,
feed_name_to_idx
):
return
outputs
with
program_guard
(
main_program
,
startup_program
):
args_list
=
list
(
args
)
for
var_name
,
value
in
feed_dict
.
items
():
idx
=
feed_name_to_idx
[
var_name
]
args_list
[
idx
]
=
data
(
name
=
var_name
,
shape
=
value
.
shape
,
dtype
=
str
(
value
.
dtype
))
args
=
tuple
(
args_list
)
static_out
=
static_func
(
*
args
,
**
kwargs
)
if
not
isinstance
(
static_out
,
(
list
,
tuple
)):
static_out
=
[
static_out
]
exe
=
Executor
(
core
.
CPUPlace
())
exe
.
run
(
startup_program
)
static_res
=
exe
.
run
(
main_program
,
fetch_list
=
static_out
,
feed
=
feed_dict
)
return
static_res
return
__impl__
dygraph_to_static_output
=
wrap_decorator
(
_dygraph_to_static_output_
)
dygraph_to_static_graph
=
wrap_decorator
(
_dygraph_to_static_graph_
)
@
dygraph_only
...
...
@@ -394,11 +353,11 @@ class TracedLayer(object):
in_var = to_variable(in_np)
out_dygraph, static_layer = TracedLayer.trace(layer, inputs=[in_var])
static_layer.save_inference_model(save_dirname, feed=[0], fetch=[0])
place = fluid.CPUPlace()
place = fluid.CPUPlace()
exe = fluid.Executor(place)
program, feed_vars, fetch_vars = fluid.io.load_inference_model(save_dirname,
exe)
exe)
fetch, = exe.run(program, feed={feed_vars[0]: in_np}, fetch_list=fetch_vars)
print(fetch.shape) # (2, 10)
...
...
python/paddle/fluid/tests/unittests/dygraph_to_static/test_cache_program.py
0 → 100644
浏览文件 @
a5036775
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
from
collections
import
Counter
import
paddle.fluid
as
fluid
from
paddle.fluid.dygraph.dygraph_to_static
import
AutoTracer
from
paddle.fluid.dygraph.jit
import
dygraph_to_static_output
from
test_fetch_feed
import
Pool2D
,
Linear
class
TestCacheProgram
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
batch_num
=
5
self
.
dygraph_class
=
Pool2D
self
.
data
=
np
.
random
.
random
((
1
,
2
,
4
,
4
)).
astype
(
'float32'
)
def
test_cache
(
self
):
prev_ops
,
cur_ops
=
Counter
(),
Counter
()
prev_out
,
cur_out
=
None
,
None
main_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main_program
):
static_net
=
self
.
dygraph_class
()
for
batch_id
in
range
(
self
.
batch_num
):
out
=
static_net
(
self
.
data
)
# Check outputs
prev_out
=
cur_out
cur_out
=
out
# Check forward ops
prev_ops
=
cur_ops
cur_ops
=
Counter
([
op
.
type
for
op
in
fluid
.
default_main_program
().
block
(
0
).
ops
])
if
batch_id
>
0
:
self
.
assertTrue
(
np
.
allclose
(
prev_out
[
0
],
cur_out
[
0
]),
msg
=
'Output in previous batch is {}
\n
Output in current batch is
\n
{}'
.
format
(
prev_out
,
cur_out
))
self
.
assertEqual
(
prev_ops
,
cur_ops
)
class
TestCacheProgram2
(
TestCacheProgram
):
def
setUp
(
self
):
self
.
batch_num
=
5
self
.
dygraph_class
=
Linear
self
.
data
=
np
.
random
.
random
((
4
,
10
)).
astype
(
'float32'
)
class
TestCacheProgramWithOptimizer
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
dygraph_class
=
Linear
self
.
data
=
np
.
random
.
random
((
4
,
10
)).
astype
(
'float32'
)
self
.
batch_num
=
5
def
train_static
(
self
):
main_program
=
fluid
.
Program
()
loss_data
=
[]
with
fluid
.
program_guard
(
main_program
):
static_net
=
self
.
dygraph_class
()
adam
=
fluid
.
optimizer
.
AdamOptimizer
(
learning_rate
=
0.001
)
# set optimizer
# TODO: Need a better interfaces to set optimizer.
auto_tracer
=
AutoTracer
()
auto_tracer
.
set_optimizer
(
adam
,
'avg_loss'
)
for
batch_id
in
range
(
self
.
batch_num
):
pred
,
avg_loss
=
static_net
(
self
.
data
)
loss_data
.
append
(
np
.
array
(
avg_loss
))
return
loss_data
def
train_dygraph
(
self
):
with
fluid
.
dygraph
.
guard
(
fluid
.
CPUPlace
()):
dygraph_net
=
self
.
dygraph_class
()
adam
=
fluid
.
optimizer
.
AdamOptimizer
(
learning_rate
=
0.001
,
parameter_list
=
dygraph_net
.
parameters
())
loss_data
=
[]
for
batch_id
in
range
(
self
.
batch_num
):
pred
,
avg_loss
=
dygraph_net
(
self
.
data
)
loss_data
.
append
(
avg_loss
.
numpy
())
avg_loss
.
backward
()
adam
.
minimize
(
avg_loss
)
dygraph_net
.
clear_gradients
()
return
loss_data
def
test_with_optimizer
(
self
):
dygraph_loss
=
self
.
train_dygraph
()
static_loss
=
self
.
train_static
()
self
.
assertTrue
(
np
.
allclose
(
dygraph_loss
,
static_loss
),
msg
=
'dygraph is {}
\n
static_res is
\n
{}'
.
format
(
dygraph_loss
,
static_loss
))
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/dygraph_to_static/test_fetch_feed.py
浏览文件 @
a5036775
...
...
@@ -58,7 +58,8 @@ class Linear(fluid.dygraph.Layer):
def
forward
(
self
,
x
):
inputs
=
fluid
.
dygraph
.
to_variable
(
x
)
pre
=
self
.
fc
(
inputs
)
return
pre
loss
=
fluid
.
layers
.
mean
(
pre
,
name
=
'avg_loss'
)
return
pre
,
loss
class
TestPool2D
(
unittest
.
TestCase
):
...
...
@@ -69,10 +70,11 @@ class TestPool2D(unittest.TestCase):
def
run_dygraph_mode
(
self
):
with
fluid
.
dygraph
.
guard
():
dy_layer
=
self
.
dygraph_class
()
for
_
in
range
(
1
):
prediction
=
dy_layer
(
x
=
self
.
data
)
if
isinstance
(
prediction
,
(
list
,
tuple
)):
prediction
=
prediction
[
0
]
prediction
=
dy_layer
(
x
=
self
.
data
)
return
prediction
.
numpy
()
return
prediction
.
numpy
()
def
run_static_mode
(
self
):
startup_prog
=
fluid
.
Program
()
...
...
@@ -90,7 +92,6 @@ class TestPool2D(unittest.TestCase):
np
.
allclose
(
dygraph_res
,
static_res
),
msg
=
'dygraph_res is {}
\n
static_res is
\n
{}'
.
format
(
dygraph_res
,
static_res
))
return
class
TestLinear
(
TestPool2D
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录