Unverified commit 1ed8baf9 authored by Aurelius84 and committed by GitHub

[dy2static] Support for static graph training with @declarative decorator (#24259)

* support training in static graph mode

* support independent decorator

* remove in_dygraph_mode condition in ProgramTranslator

* fix import param_guard and add train/eval test=develop

* Modify into ShareVarsFromScope and rm __all__ in partial_program test=develop
Parent 2424297f
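For context, a minimal training sketch of what the `@declarative` decorator enables after this change, pieced together from the tests in this PR. The layer, shapes, data, and optimizer settings below are illustrative assumptions, not part of the patch.

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import declarative


class SimpleNet(fluid.dygraph.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = fluid.dygraph.Linear(10, 3)

    @declarative
    def forward(self, x):
        pred = self.fc(x)
        loss = fluid.layers.mean(pred)
        return pred, loss


with fluid.dygraph.guard():
    net = SimpleNet()
    adam = fluid.optimizer.AdamOptimizer(
        learning_rate=0.001, parameter_list=net.parameters())
    for _ in range(3):
        x = fluid.dygraph.to_variable(
            np.random.random((4, 10)).astype('float32'))
        pred, loss = net(x)
        # The decorated forward is traced into a static sub-program and run by
        # the run_program op; backward and optimizer updates stay dygraph-style.
        loss.backward()
        adam.minimize(loss)
        net.clear_gradients()
```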
......@@ -102,74 +102,50 @@ static void CheckOutputVarStatus(const Variable &src_var,
}
static void VariableShare(const Variable &src_var, Variable *dst_var) {
// The previous check ensures that the variable type can only be LoDTensor
auto *lod_tensor = dst_var->GetMutable<LoDTensor>();
lod_tensor->ShareDataWith(src_var.Get<LoDTensor>());
lod_tensor->set_lod(src_var.Get<LoDTensor>().lod());
}
static void ShareVarsIntoScope(const std::vector<Variable *> &vars,
const std::vector<std::string> &var_names,
framework::Scope *scope) {
for (size_t i = 0; i < vars.size(); ++i) {
auto *var = scope->Var(var_names[i]);
CheckInputVarStatus(*vars[i], var_names[i]);
VariableShare(*vars[i], var);
}
}
static void VariableCopy(const Variable &src_var,
const platform::Place &dst_place, Variable *dst_var) {
// The previous check ensures that the variable type can only be LoDTensor or
// SelectedRows
// SelectedRows.
if (src_var.IsType<LoDTensor>()) {
auto *lod_tensor = dst_var->GetMutable<LoDTensor>();
TensorCopySync(src_var.Get<LoDTensor>(), dst_place, lod_tensor);
lod_tensor->ShareDataWith(src_var.Get<LoDTensor>());
lod_tensor->set_lod(src_var.Get<LoDTensor>().lod());
} else if (src_var.IsType<SelectedRows>()) {
auto *selected_rows = dst_var->GetMutable<SelectedRows>();
TensorCopySync(src_var.Get<SelectedRows>().value(), dst_place,
selected_rows->mutable_value());
selected_rows->mutable_value()->ShareDataWith(
src_var.Get<SelectedRows>().value());
selected_rows->set_rows(src_var.Get<SelectedRows>().rows());
selected_rows->set_height(src_var.Get<SelectedRows>().height());
}
}
static void ShareVarsFromScope(const std::vector<Variable *> &vars,
static void ShareVarsIntoScope(const std::vector<Variable *> &vars,
const std::vector<std::string> &var_names,
framework::Scope *scope) {
for (size_t i = 0; i < vars.size(); ++i) {
auto *var = scope->FindVar(var_names[i]);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("The output variable %s is not in "
"RunProgram(Grad)Op(StaticModelRunner)'"
"s internal scope.",
var_names[i]));
CheckOutputVarStatus(*var, *vars[i], var_names[i]);
VariableShare(*var, vars[i]);
auto *var = scope->Var(var_names[i]);
CheckInputVarStatus(*vars[i], var_names[i]);
VariableShare(*vars[i], var);
}
}
static void CopyVarsFromScope(const std::vector<Variable *> &vars,
const std::vector<std::string> &var_names,
const platform::Place &dst_place,
framework::Scope *scope) {
static void ShareVarsFromScope(const std::vector<Variable *> &vars,
const std::vector<std::string> &var_names,
framework::Scope *scope) {
for (size_t i = 0; i < vars.size(); ++i) {
if (var_names[i] == framework::kEmptyVarName) {
VLOG(2) << "find variable name is " << framework::kEmptyVarName
<< ", skip it!";
continue;
}
auto *var = scope->FindVar(var_names[i]);
// NOTE: Silently skipping variables that are not found here would be
// dangerous; a bug here would produce wrong gradients and be very hard to track down!
auto *var = scope->FindVar(var_names[i]);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("The output variable %s is not in "
"RunProgram(Grad)Op(StaticModelRunner)'"
"s internal scope.",
var_names[i]));
CheckOutputVarStatus(*var, *vars[i], var_names[i]);
VariableCopy(*var, dst_place, vars[i]);
VariableShare(*var, vars[i]);
}
}
......@@ -306,11 +282,9 @@ class RunProgramGradOpKernel : public framework::OpKernel<T> {
end_op_index, /*create_local_scope=*/false,
/*create_vars=*/true, /*keep_kids=*/false);
// Step 4. copy outputs
details::CopyVarsFromScope(input_grad_vars, input_grad_var_names,
ctx.GetPlace(), &scope);
details::CopyVarsFromScope(param_grad_vars, param_grad_names,
ctx.GetPlace(), &scope);
// Step 4. get outputs
details::ShareVarsFromScope(input_grad_vars, input_grad_var_names, &scope);
details::ShareVarsFromScope(param_grad_vars, param_grad_names, &scope);
}
};
......
......@@ -76,7 +76,8 @@ def check_variable_and_dtype(input,
expected_dtype,
op_name,
extra_message=''):
check_type(input, input_name, Variable, op_name, extra_message)
check_type(input, input_name, (Variable, core.VarBase), op_name,
extra_message)
check_dtype(input.dtype, input_name, expected_dtype, op_name, extra_message)
......
......@@ -61,6 +61,26 @@ def program_desc_tracing_guard(enable):
_functional_dygraph_context_manager = None
@signature_safe_contextmanager
def param_guard(parameters):
# Note: parameters is a reference of self._parameters
if not framework.in_dygraph_mode() and parameters:
origin_parameters = parameters.copy()
for name, var_base in parameters.items():
if isinstance(var_base, core.VarBase):
new_var = framework.Parameter(
var_base.block,
var_base.shape,
var_base.dtype,
var_base.type,
name=var_base.name)
parameters[name] = new_var
yield
parameters.update(origin_parameters)
else:
yield
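To show where this guard is meant to sit, here is the call-site fragment it pairs with; it mirrors the `Layer.__call__` change later in this diff (`self` stands for any dygraph `Layer` whose `@declarative` forward is being traced), so it is context rather than a standalone snippet.

```python
# Inside Layer.__call__ (see the layers.py hunk below): while a @declarative
# forward is traced under static graph mode, param_guard temporarily exposes
# the layer's VarBase parameters as framework.Parameter objects, and restores
# the original VarBase objects when the block exits.
with param_guard(self._parameters):
    outputs = self.forward(*inputs, **kwargs)
```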
def enabled():
"""
This function checks whether the program runs in dynamic graph mode or not.
......
......@@ -14,8 +14,6 @@
from __future__ import print_function
import astor
import copy
# gast is a generic AST to represent Python2 and Python3's Abstract Syntax Tree(AST).
# It provides a compatibility layer between the AST of various Python versions,
# as produced by ast.parse from the standard ast module.
......@@ -24,8 +22,6 @@ import gast
import inspect
import textwrap
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.basic_api_transformer import BasicApiTransformer
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakContinueTransformer
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer
......@@ -35,14 +31,9 @@ from paddle.fluid.dygraph.dygraph_to_static.tensor_shape_transformer import Tens
from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer
from paddle.fluid.dygraph.dygraph_to_static.print_transformer import PrintTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static_api
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_api, is_dygraph_api, is_to_variable
from paddle.fluid.dygraph.dygraph_to_static.utils import to_assign_node, to_static_ast, update_args_of_func
__all__ = ['DygraphToStaticAst', 'convert_to_static']
......@@ -146,6 +137,9 @@ def convert_to_static(dyfunc):
Converts dygraph function into static function.
"""
# Get AST from dygraph function
# Note: In Python 2, inspecting a decorated function directly raises OSError,
# and dyfunc.__wrapped__ holds the actual function.
dyfunc = getattr(dyfunc, '__wrapped__', dyfunc)
raw_code = inspect.getsource(dyfunc)
code = textwrap.dedent(raw_code)
root = gast.parse(code)
......
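A standalone sketch of the unwrapping step above; only `gast` is required, and `toy_decorator`/`dyfunc` are made-up names for illustration.

```python
import functools
import inspect
import textwrap

import gast


def toy_decorator(fn):
    # functools.wraps (on Python 3) sets inner.__wrapped__ = fn, which is what
    # convert_to_static relies on to reach the original source.
    @functools.wraps(fn)
    def inner(*args, **kwargs):
        return fn(*args, **kwargs)

    return inner


@toy_decorator
def dyfunc(x):
    return x + 1


# Mirrors the first steps of convert_to_static: unwrap the decorator so that
# inspect.getsource sees the original function, then parse it with gast.
func = getattr(dyfunc, '__wrapped__', dyfunc)
code = textwrap.dedent(inspect.getsource(func))
root = gast.parse(code)
print(type(root).__name__)  # Module
```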
......@@ -17,9 +17,6 @@ from __future__ import print_function
import gast
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import NodeTestTransformer
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
......
......@@ -103,17 +103,8 @@ def convert_call(func):
return func
try:
if func in func.__globals__.values():
if six.PY3:
source_code = inspect.getsource(func)
if any(decorator in source_code
for decorator in DECORATOR_NAMES):
converted_call = None
else:
converted_call = to_static_func(func)
func_self = getattr(func, '__self__', None)
else:
converted_call = to_static_func(func)
func_self = getattr(func, '__self__', None)
converted_call = to_static_func(func)
func_self = getattr(func, '__self__', None)
except AttributeError:
# NOTE:
# If func is not in __globals__, it does not need to be transformed
......@@ -121,45 +112,25 @@ def convert_call(func):
converted_call = None
except (IOError, OSError):
# NOTE:
# If func has beed decorated, its source code can not be get
# If func has been decorated, its source code cannot be retrieved,
# so it cannot be transformed into a static function.
converted_call = None
elif inspect.ismethod(func):
try:
if six.PY3:
source_code = inspect.getsource(func)
if any(decorator in source_code
for decorator in DECORATOR_NAMES):
converted_call = None
else:
converted_call = to_static_func(func)
func_self = getattr(func, '__self__', None)
else:
converted_call = to_static_func(func)
func_self = getattr(func, '__self__', None)
converted_call = to_static_func(func)
func_self = getattr(func, '__self__', None)
except (IOError, OSError):
# NOTE: func may have beed decorated.
# NOTE: func may have been decorated.
converted_call = None
elif hasattr(func, '__class__') and hasattr(func.__class__, '__call__'):
if hasattr(func, 'forward') and isinstance(func, Layer):
try:
if six.PY3:
source_code = inspect.getsource(func.forward)
if any(decorator in source_code
for decorator in DECORATOR_NAMES):
converted_call = None
else:
forward_func = to_static_func(func.forward)
setattr(func, 'forward', forward_func)
func_self = func
else:
forward_func = to_static_func(func.forward)
setattr(func, 'forward', forward_func)
func_self = func
forward_func = to_static_func(func.forward)
setattr(func, 'forward', forward_func)
func_self = func
except Exception:
# NOTE: func.forward may have beed decorated.
# NOTE: func.forward may have been decorated.
func_self = None if func_self else func_self
converted_call = func
else:
......
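As a generic illustration of why the IOError/OSError branches above exist: `inspect.getsource` raises when a function has no retrievable source, and `convert_call` simply gives up on conversion (`converted_call = None`) in that case. The snippet below is self-contained and uses a made-up function name.

```python
import inspect

# Compile a function from a string so it has no retrievable source file.
namespace = {}
exec(compile("def no_source(x):\n    return x * 2\n", "<string>", "exec"),
     namespace)

try:
    inspect.getsource(namespace["no_source"])
except (IOError, OSError) as exc:  # OSError on Python 3, IOError on Python 2
    print("cannot get source:", exc)
```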
......@@ -22,7 +22,6 @@ from collections import defaultdict
# as produced by ast.parse from the standard ast module.
# See details in https://github.com/serge-sans-paille/gast/
import gast
import six
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import compare_with_none
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
from paddle.fluid import framework, backward, core
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph.base import switch_to_static_graph
import paddle.compat as cpt
class PartialProgramLayer(layers.Layer):
"""
PartialProgramLayer wraps all the ops from layers decorated by `@declarative`
and executes them as a static subgraph.
.. note::
**1. It should not be called directly; it is used to train a dygraph model in static graph mode.
**2. LoDTensorArray is not currently supported in the output.
Args:
main_program(Program): The main program that contains ops need to be executed.
inputs(list[Variable]): The input list of the decorated function by `@declarative`.
outputs(list[Variable]): The output list of the decorated function by `@declarative`.
parameters(list[VarBase]|None): All trainable parameters included in the program. Default None.
Returns:
Layer: A Layer object that run all ops internally in static mode.
"""
def __init__(self, main_program, inputs, outputs, parameters=None):
super(PartialProgramLayer, self).__init__()
self.inputs = inputs
self.outputs = outputs
self._params = parameters
self._infer_program = main_program
self._train_program = self._append_backward_desc()
# Switch infer or train by train() and eval()
self._trace_program = None
self._set_grad_type(self._params)
self._inner_scope = core.Scope()
# Set default mode to train
self.train()
@switch_to_static_graph
def _append_backward_desc(self):
program = self._infer_program.clone()
targets = []
for out in self.outputs:
if isinstance(out, framework.Variable):
targets.append(program.global_block().var(out.name))
if targets and self._params:
backward.gradients(targets=targets, inputs=[])
return program
def train(self):
# self.training is inherited from layers.Layer
self.training = True
self._trace_program = self._train_program
def eval(self):
self.training = False
self._trace_program = self._infer_program
def forward(self, inputs):
in_vars, out_vars, tmp_scope_vec = self._prepare(inputs)
framework._dygraph_tracer().trace_op(
type='run_program',
inputs={
'X': valid_vars(in_vars),
'Params': valid_vars(self._params)
},
outputs={'Out': valid_vars(out_vars),
'OutScope': tmp_scope_vec},
attrs={
'global_block': self._trace_program.desc.block(0),
'start_op_index': 0,
'end_op_index': self._infer_program.desc.block(0).op_size(),
'is_test': not self.training
})
outs = out_vars
if len(outs) == 1:
outs = outs[0]
return outs
def _prepare(self, inputs):
"""
Prepare inputs, outputs, attrs.
"""
assert isinstance(inputs, (tuple, list))
# Convert variable into VarBase and feed in training data.
input_vars = []
for i, value in enumerate(inputs):
if isinstance(value, np.ndarray):
var = core.VarBase(
value=value,
name=self.inputs[i].desc.name(),
persistable=False,
place=framework._current_expected_place(),
zero_copy=True)
elif isinstance(value, core.VarBase):
var = value
var.name = self.inputs[i].desc.name()
else:
continue
input_vars.append(var)
# Create VarBase to receive output data.
out_vars = []
for var in self.outputs:
if not isinstance(var, framework.Variable):
continue
var_desc = var.desc
var_base = core.VarBase(var_desc.dtype(),
var_desc.shape(),
var_desc.name(), var_desc.type(), False)
out_vars.append(var_base)
# Hold forward variables
tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [],
"program_out_scope",
core.VarDesc.VarType.STEP_SCOPES, True)
tmp_scope_vec.value().set_scope(self._inner_scope)
return input_vars, out_vars, tmp_scope_vec
def _set_grad_type(self, params):
# NOTE: if the user enables sparse gradients, a param's gradient
# will be a SelectedRows instead of a LoDTensor, but the tracer
# creates the param's grad VarBase from the forward VarBase (LoDTensor).
# If we didn't change the grad_var type here, RunProgramOp would have to
# forcibly transform SelectedRows to LoDTensor, which may not be
# the result the user wants.
for param in params:
grad_name = param.name + core.grad_var_suffix()
grad_var = self._train_program.desc.block(0).find_var(
cpt.to_bytes(grad_name))
# NOTE: a missing grad var desc may be fine, e.g. for some variables in batch_norm
if grad_var is None:
continue
param._set_grad_type(grad_var.type())
def valid_vars(vars):
"""
Note: run_program_op.InferShape requires that `X`/`Out` are not empty.
Empty inputs/outputs are common in dy2static, so a fake VarBase is
created to work around this.
"""
if vars:
return vars
return [
core.VarBase(
value=[1],
name='Fake_var',
place=framework._current_expected_place())
]
def append_grad_suffix(name):
"""
Append grad suffix to the given variable name.
e.g. x ==> x@GRAD
"""
suffix = core.kGradVarSuffix()
name = cpt.to_text(name)
if suffix not in name:
name = name + suffix
return name
def partial_program_from(concrete_program):
inputs = concrete_program.inputs
if inputs and isinstance(inputs[0], layers.Layer):
inputs = inputs[1:]
return PartialProgramLayer(concrete_program.main_program, inputs,
concrete_program.outputs,
concrete_program.parameters)
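The `_append_backward_desc` step above can be illustrated with a small standalone static-graph program; the toy network and shapes are assumptions, but the clone-then-`backward.gradients(targets, inputs=[])` pattern is the one the layer uses.

```python
import paddle.fluid as fluid
from paddle.fluid import backward

main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    x = fluid.data(name='x', shape=[None, 4], dtype='float32')
    y = fluid.layers.fc(input=x, size=2)
    loss = fluid.layers.mean(y)

# Clone the forward (infer) program and append gradient ops for the target,
# without asking for the gradient of any particular input (inputs=[]).
train_prog = main.clone()
targets = [train_prog.global_block().var(loss.name)]
backward.gradients(targets=targets, inputs=[])

# The cloned program now has extra *_grad ops appended after the forward ops.
print(len(main.global_block().ops), len(train_prog.global_block().ops))
```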
......@@ -15,7 +15,6 @@
from __future__ import print_function
import gast
import astor
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper, NodeVarType, StaticAnalysisVisitor
......
......@@ -17,7 +17,6 @@ from __future__ import print_function
__all__ = ['TracedLayer', 'declarative', 'dygraph_to_static_func']
import logging
from paddle.fluid import core
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.dygraph.base import program_desc_tracing_guard, switch_to_static_graph
......@@ -156,13 +155,11 @@ def _declarative_(dygraph_func):
def __impl__(*args, **kwargs):
program_translator = ProgramTranslator()
if in_dygraph_mode() or not program_translator.enable_declarative:
if not program_translator.enable_declarative:
logger.info(
"The decorator 'declarative' doesn't work in dygraph "
"mode or set ProgramTranslator.enable to False. We will "
"just return dygraph output.")
"The decorator 'declarative' doesn't work when setting ProgramTranslator.enable=False. "
"We will just return dygraph output.")
return dygraph_func(*args, **kwargs)
program_translator = ProgramTranslator()
return program_translator.get_output(dygraph_func, *args, **kwargs)
return __impl__
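The behavioural change above (the decorator no longer bails out merely because we are in dygraph mode; only `ProgramTranslator().enable(False)` disables it) can be exercised roughly as follows; the toy function is an assumption.

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import declarative
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator


@declarative
def add_one(x):
    x = fluid.dygraph.to_variable(x)
    return x + 1


x = np.ones([2, 2]).astype('float32')
with fluid.dygraph.guard():
    ProgramTranslator().enable(False)  # decorator is a no-op, runs eagerly
    dy_out = add_one(x)
    ProgramTranslator().enable(True)   # convert and run as a static sub-program
    st_out = add_one(x)
    print(np.allclose(dy_out.numpy(), st_out.numpy()))
```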
......@@ -228,6 +225,7 @@ class TracedLayer(object):
self._program = program
self._feed_names = feed_names
self._fetch_names = fetch_names
self._params = parameters
self._place = _current_expected_place()
......
......@@ -23,7 +23,7 @@ from . import parallel_helper
from .. import unique_name
from paddle.fluid import core
from .layer_object_helper import LayerObjectHelper
from .base import program_desc_tracing_guard
from .base import program_desc_tracing_guard, param_guard
from paddle.fluid import framework
from ..param_attr import ParamAttr
import copy
......@@ -457,7 +457,8 @@ class Layer(core.Layer):
self._parameters.values())
self._built = True
outputs = self.forward(*inputs, **kwargs)
with param_guard(self._parameters):
outputs = self.forward(*inputs, **kwargs)
for forward_post_hook in self._forward_post_hooks.values():
hook_result = forward_post_hook(self, inputs, outputs)
......
......@@ -1961,7 +1961,7 @@ class Operator(object):
in_arg_names.append(arg)
elif isinstance(arg, six.binary_type):
in_arg_names.append(arg.decode())
elif isinstance(arg, Variable):
elif isinstance(arg, (Variable, core.VarBase)):
in_arg_names.append(cpt.to_text(arg.name))
else:
raise TypeError(
......
......@@ -15,7 +15,6 @@
from __future__ import print_function
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import dygraph_to_static_func
def add_fn(x):
......@@ -142,7 +141,6 @@ class NetWithControlFlowIf(fluid.dygraph.Layer):
self.alpha = 10.
self.constant_vars = {}
@dygraph_to_static_func
def forward(self, input):
hidden_dim = input.shape[-1]
if hidden_dim != self.hidden_dim:
......
......@@ -17,7 +17,7 @@ from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import dygraph_to_static_func
from paddle.fluid.dygraph.jit import declarative
SEED = 2020
np.random.seed(SEED)
......@@ -160,13 +160,9 @@ class TestContinueInFor(unittest.TestCase):
return res.numpy()
def run_static_mode(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
res = dygraph_to_static_func(self.dygraph_func)(self.input)
exe = fluid.Executor(self.place)
static_res = exe.run(main_program, fetch_list=[res])
return static_res[0]
with fluid.dygraph.guard():
res = declarative(self.dygraph_func)(self.input)
return res.numpy()
def test_transformed_static_result(self):
static_res = self.run_static_mode()
......
......@@ -36,8 +36,7 @@ class TestCacheProgram(unittest.TestCase):
def test_cache(self):
prev_ops, cur_ops = Counter(), Counter()
prev_out, cur_out = None, None
main_program = fluid.Program()
with fluid.program_guard(main_program):
with fluid.dygraph.guard(fluid.CPUPlace()):
static_net = self.dygraph_class()
for batch_id in range(self.batch_num):
out = static_net(self.data)
......@@ -51,9 +50,9 @@ class TestCacheProgram(unittest.TestCase):
])
if batch_id > 0:
prev_out_numpy = prev_out[0].numpy() if isinstance(
prev_out, tuple) else prev_out.numpy()
prev_out, (tuple, list)) else prev_out.numpy()
cur_out_numpy = cur_out[0].numpy() if isinstance(
cur_out, tuple) else cur_out.numpy()
cur_out, (tuple, list)) else cur_out.numpy()
self.assertTrue(
np.allclose(prev_out_numpy, cur_out_numpy),
msg='Output in previous batch is {}\n Output in current batch is \n{}'
......@@ -75,29 +74,23 @@ class TestCacheProgramWithOptimizer(unittest.TestCase):
self.batch_num = 5
def train_static(self):
main_program = fluid.Program()
loss_data = []
with fluid.program_guard(main_program):
static_net = self.dygraph_class()
adam = fluid.optimizer.AdamOptimizer(learning_rate=0.001)
# set optimizer
program_translator = ProgramTranslator()
program_translator.set_optimizer(adam, index_of_loss=1)
return self.train(to_static=True)
for batch_id in range(self.batch_num):
pred, avg_loss = static_net(self.data)
loss_data.append(np.array(avg_loss.numpy()))
def train_dygraph(self):
return self.train(to_static=False)
return loss_data
def train(self, to_static=False):
prog_trans = ProgramTranslator()
prog_trans.enable(to_static)
def train_dygraph(self):
with fluid.dygraph.guard(fluid.CPUPlace()):
dygraph_net = self.dygraph_class()
adam = fluid.optimizer.AdamOptimizer(
learning_rate=0.001, parameter_list=dygraph_net.parameters())
loss_data = []
for batch_id in range(self.batch_num):
pred, avg_loss = dygraph_net(self.data)
input = fluid.dygraph.to_variable(self.data)
pred, avg_loss = dygraph_net(input)
loss_data.append(avg_loss.numpy())
avg_loss.backward()
......@@ -114,20 +107,6 @@ class TestCacheProgramWithOptimizer(unittest.TestCase):
msg='dygraph is {}\n static_res is \n{}'.format(dygraph_loss,
static_loss))
def test_exception(self):
main_program = fluid.Program()
loss_data = []
with fluid.program_guard(main_program):
static_net = self.dygraph_class()
adam = fluid.optimizer.AdamOptimizer(learning_rate=0.001)
# set optimizer
program_translator = ProgramTranslator()
with self.assertRaisesRegexp(ValueError, "has already been set"):
for batch_id in range(self.batch_num):
program_translator.set_optimizer(adam, index_of_loss=1)
static_net(self.data)
def simple_func(x):
inputs = fluid.dygraph.to_variable(x)
......@@ -156,7 +135,6 @@ def sum_even_util_limit(max_len, limit):
return ret_sum
@declarative
def sum_under_while(limit):
i = fluid.dygraph.to_variable(np.zeros((1)).astype('int32'))
ret_sum = fluid.dygraph.to_variable(np.zeros((1)).astype('int32'))
......@@ -168,11 +146,12 @@ def sum_under_while(limit):
class TestToOutputWithCache(unittest.TestCase):
def test_output(self):
ret = sum_even_util_limit(80, 10)
self.assertEqual(ret.numpy(), 30)
with fluid.dygraph.guard():
ret = sum_even_util_limit(80, 10)
self.assertEqual(ret.numpy(), 30)
ret = sum_under_while(100)
self.assertEqual(ret.numpy(), 5050)
ret = declarative(sum_under_while)(100)
self.assertEqual(ret.numpy(), 5050)
if __name__ == '__main__':
......
......@@ -19,7 +19,8 @@ import numpy as np
import unittest
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import dygraph_to_static_func
from paddle.fluid.dygraph.jit import declarative
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator
PLACE = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
)
......@@ -75,7 +76,7 @@ class MainNetWithDict(fluid.dygraph.Layer):
self.output_size = output_size
self.sub_net = SubNetWithDict(hidden_size, output_size)
@dygraph_to_static_func
@declarative
def forward(self, input, max_len=4):
input = fluid.dygraph.to_variable(input)
cache = {
......@@ -121,17 +122,14 @@ class TestNetWithDict(unittest.TestCase):
self.batch_size = self.x.shape[0]
def _run_static(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
net = MainNetWithDict(batch_size=self.batch_size)
# Transform into static graph
out = net(self.x)
exe = fluid.Executor(PLACE)
exe.run(fluid.default_startup_program())
ret = exe.run(main_program, fetch_list=out)
return ret[0]
return self.train(to_static=True)
def _run_dygraph(self):
return self.train(to_static=False)
def train(self, to_static=False):
prog_trans = ProgramTranslator()
prog_trans.enable(to_static)
with fluid.dygraph.guard(PLACE):
net = MainNetWithDict(batch_size=self.batch_size)
ret = net(self.x)
......
......@@ -14,12 +14,12 @@
from __future__ import print_function
from paddle.fluid.dygraph.jit import declarative
import numpy as np
import unittest
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import declarative
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
SEED = 2020
......@@ -32,22 +32,20 @@ class Pool2D(fluid.dygraph.Layer):
@declarative
def forward(self, x):
inputs = fluid.dygraph.to_variable(x)
# Add func `get_result` for testing arg_name_to_idx in ast transformation.
def get_result(x):
return self.pool2d(x)
pre = get_result(inputs)
pre = get_result(x)
return pre
class Linear(fluid.dygraph.Layer):
def __init__(self):
def __init__(self, input_dim=10, output_dim=5):
super(Linear, self).__init__()
self.fc = fluid.dygraph.Linear(
input_dim=10,
output_dim=5,
input_dim,
output_dim,
act='relu',
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
value=0.99)),
......@@ -56,8 +54,7 @@ class Linear(fluid.dygraph.Layer):
@declarative
def forward(self, x):
inputs = fluid.dygraph.to_variable(x)
pre = self.fc(inputs)
pre = self.fc(x)
loss = fluid.layers.mean(pre)
return pre, loss
......@@ -67,28 +64,28 @@ class TestPool2D(unittest.TestCase):
self.dygraph_class = Pool2D
self.data = np.random.random((1, 2, 4, 4)).astype('float32')
def run_dygraph_mode(self):
def train(self, to_static=False):
program_translator = ProgramTranslator()
program_translator.enable(to_static)
with fluid.dygraph.guard():
dy_layer = self.dygraph_class()
prediction = dy_layer(x=self.data)
x = fluid.dygraph.to_variable(self.data)
prediction = dy_layer(x)
if isinstance(prediction, (list, tuple)):
prediction = prediction[0]
return prediction.numpy()
def run_static_mode(self):
startup_prog = fluid.Program()
main_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
dy_layer = self.dygraph_class()
out = dy_layer(x=self.data)
if isinstance(out, tuple):
return out[0].numpy()
return out.numpy()
def test_static_output(self):
dygraph_res = self.run_dygraph_mode()
static_res = self.run_static_mode()
def train_static(self):
return self.train(to_static=True)
def train_dygraph(self):
return self.train(to_static=False)
def test_declarative(self):
dygraph_res = self.train_dygraph()
static_res = self.train_static()
self.assertTrue(
np.allclose(dygraph_res, static_res),
......
......@@ -17,8 +17,7 @@ from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
import unittest
from paddle.fluid.dygraph.jit import declarative
from paddle.fluid.dygraph import declarative
@fluid.dygraph.declarative
......@@ -31,7 +30,7 @@ def dygraph_decorated_func(x):
return x_v
@fluid.dygraph.jit.declarative
@fluid.dygraph.declarative
def jit_decorated_func(x):
x = fluid.dygraph.to_variable(x)
if fluid.layers.mean(x) > 0:
......@@ -62,18 +61,14 @@ class TestFullNameDecorator(unittest.TestCase):
def test_run_success(self):
x = np.ones([1, 2]).astype("float32")
answer = np.zeros([1, 2]).astype("float32")
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.dygraph.guard():
self.assertTrue(
np.allclose(dygraph_decorated_func(x).numpy(), answer))
with fluid.program_guard(fluid.Program(), fluid.Program()):
self.assertTrue(np.allclose(jit_decorated_func(x).numpy(), answer))
with fluid.program_guard(fluid.Program(), fluid.Program()):
self.assertTrue(
np.allclose(decorated_call_decorated(x).numpy(), answer))
with fluid.program_guard(fluid.Program(), fluid.Program()):
with self.assertRaises(NotImplementedError):
DoubleDecorated().double_decorated_func1(x)
with fluid.program_guard(fluid.Program(), fluid.Program()):
with self.assertRaises(NotImplementedError):
DoubleDecorated().double_decorated_func2(x)
......
......@@ -15,10 +15,10 @@
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
import unittest
from paddle.fluid.dygraph.jit import dygraph_to_static_func
from paddle.fluid.dygraph.jit import declarative
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator
from ifelse_simple_func import *
......@@ -41,19 +41,16 @@ class TestDygraphIfElse(unittest.TestCase):
self.dyfunc = dyfunc_with_if_else
def _run_static(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
x_v = fluid.layers.assign(self.x)
# Transform into static graph
out = dygraph_to_static_func(self.dyfunc)(x_v)
exe = fluid.Executor(place)
ret = exe.run(main_program, fetch_list=out)
return ret
return self._run_dygraph(to_static=True)
def _run_dygraph(self, to_static=False):
def _run_dygraph(self):
with fluid.dygraph.guard(place):
x_v = fluid.dygraph.to_variable(self.x)
ret = self.dyfunc(x_v)
if to_static:
ret = declarative(self.dyfunc)(x_v)
else:
ret = self.dyfunc(x_v)
return ret.numpy()
def test_ast_to_func(self):
......@@ -187,18 +184,15 @@ class TestDygraphIfElseNet(unittest.TestCase):
self.Net = NetWithControlFlowIf
def _run_static(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
net = self.Net()
x_v = fluid.layers.assign(self.x)
# Transform into static graph
out = net(x_v)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
ret = exe.run(main_program, fetch_list=out)
return ret[0]
return self._run(to_static=True)
def _run_dygraph(self):
return self._run(to_static=False)
def _run(self, to_static=False):
prog_trans = ProgramTranslator()
prog_trans.enable(to_static)
with fluid.dygraph.guard(place):
net = self.Net()
x_v = fluid.dygraph.to_variable(self.x)
......@@ -234,7 +228,7 @@ class TestAst2FuncWithExternalFunc(TestDygraphIfElse):
class NetWithExternalFunc(fluid.dygraph.Layer):
@dygraph_to_static_func
@declarative
def forward(self, x, label=None):
if fluid.layers.mean(x) < 0:
x_v = x - 1
......
......@@ -15,7 +15,6 @@
from __future__ import print_function
import unittest
from functools import partial
import numpy as np
import paddle.fluid as fluid
......@@ -27,7 +26,6 @@ np.random.seed(SEED)
# Situation 1: Test list append
@declarative
def test_list_append_without_control_flow(x):
# Python list will not be transformed.
x = fluid.dygraph.to_variable(x)
......@@ -38,7 +36,6 @@ def test_list_append_without_control_flow(x):
return a
@declarative
def test_list_append_in_if(x):
x = fluid.dygraph.to_variable(x)
a = []
......@@ -48,10 +45,10 @@ def test_list_append_in_if(x):
a.append(
fluid.layers.fill_constant(
shape=[1, 2], value=9, dtype="int64"))
return a
# TODO(Aurelius84): Currently, run_program_op doesn't support LoDTensorArray outputs.
return a[0]
@declarative
def test_list_append_in_for_loop(x, iter_num):
x = fluid.dygraph.to_variable(x)
# Use `fill_constant` so that static analysis can analyze the type of iter_num is Tensor
......@@ -61,10 +58,9 @@ def test_list_append_in_for_loop(x, iter_num):
a = []
for i in range(iter_num):
a.append(x)
return a
return a[0]
@declarative
def test_list_append_in_for_loop_with_concat(x, iter_num):
x = fluid.dygraph.to_variable(x)
a = []
......@@ -78,7 +74,6 @@ def test_list_append_in_for_loop_with_concat(x, iter_num):
return a
@declarative
def test_list_append_in_while_loop(x, iter_num):
x = fluid.dygraph.to_variable(x)
iter_num = fluid.layers.fill_constant(
......@@ -88,10 +83,9 @@ def test_list_append_in_while_loop(x, iter_num):
while i < iter_num:
a.append(x)
i += 1
return a
return a[0]
@declarative
def test_list_append_in_while_loop_with_stack(x, iter_num):
x = fluid.dygraph.to_variable(x)
iter_num = fluid.layers.fill_constant(
......@@ -106,7 +100,6 @@ def test_list_append_in_while_loop_with_stack(x, iter_num):
# Situation 2: Test list pop
@declarative
def test_list_pop_without_control_flow_1(x):
x = fluid.dygraph.to_variable(x)
a = []
......@@ -116,18 +109,16 @@ def test_list_pop_without_control_flow_1(x):
return a
@declarative
def test_list_pop_without_control_flow_2(x):
x = fluid.dygraph.to_variable(x)
a = []
if 2 > 1:
a.append(x)
a.append(x + 1)
last_tiem = a.pop(1)
return last_tiem
last_item = a.pop(1)
return last_item
@declarative
def test_list_pop_in_if(x):
x = fluid.dygraph.to_variable(x)
a = []
......@@ -138,11 +129,9 @@ def test_list_pop_in_if(x):
a.append(x + 1)
a.append(fluid.layers.fill_constant(shape=[2], value=2, dtype="int64"))
item1 = a.pop(1)
a.pop()
return a, item1
return item1
@declarative
def test_list_pop_in_for_loop(x, iter_num):
x = fluid.dygraph.to_variable(x)
# Use `fill_constant` so that static analysis can analyze the type of iter_num is Tensor
......@@ -158,10 +147,9 @@ def test_list_pop_in_for_loop(x, iter_num):
for i in range(one.numpy()[0]):
item = a.pop()
return a, item
return a[0], item
@declarative
def test_list_pop_in_while_loop(x, iter_num):
x = fluid.dygraph.to_variable(x)
iter_num = fluid.layers.fill_constant(
......@@ -173,7 +161,7 @@ def test_list_pop_in_while_loop(x, iter_num):
i += 1
if i % 2 == 1:
a.pop()
return a
return a[0]
class TestListWithoutControlFlow(unittest.TestCase):
......@@ -201,15 +189,19 @@ class TestListWithoutControlFlow(unittest.TestCase):
res = [res.numpy()]
return res
def run_static_mode(self):
return self.train(to_static=True)
def run_dygraph_mode(self):
with fluid.dygraph.guard():
res = self.dygraph_func(self.input)
return self.varbase_to_numpy(res)
return self.train(to_static=False)
def run_static_mode(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
res = self.dygraph_func(self.input)
def train(self, to_static=False):
with fluid.dygraph.guard():
if to_static:
res = declarative(self.dygraph_func)(self.input)
else:
res = self.dygraph_func(self.input)
return self.varbase_to_numpy(res)
def test_transformed_static_result(self):
......@@ -238,39 +230,34 @@ class TestListInWhileLoop(TestListWithoutControlFlow):
def init_dygraph_func(self):
self.all_dygraph_funcs = [
partial(
test_list_append_in_while_loop, iter_num=self.iter_num),
partial(
test_list_pop_in_while_loop, iter_num=self.iter_num),
test_list_append_in_while_loop, test_list_pop_in_while_loop
]
def train(self, to_static=False):
with fluid.dygraph.guard():
if to_static:
res = declarative(self.dygraph_func)(self.input, self.iter_num)
else:
res = self.dygraph_func(self.input, self.iter_num)
return self.varbase_to_numpy(res)
class TestListInWhileLoopWithStack(TestListInWhileLoop):
def init_dygraph_func(self):
self.all_dygraph_funcs = [
partial(
test_list_append_in_while_loop_with_stack,
iter_num=self.iter_num)
]
self.all_dygraph_funcs = [test_list_append_in_while_loop_with_stack]
class TestListInForLoop(TestListInWhileLoop):
def init_dygraph_func(self):
self.all_dygraph_funcs = [
partial(
test_list_append_in_for_loop, iter_num=self.iter_num),
partial(
test_list_pop_in_for_loop, iter_num=self.iter_num),
test_list_append_in_for_loop, test_list_pop_in_for_loop
]
class TestListInForLoopWithConcat(TestListInWhileLoopWithStack):
def init_dygraph_func(self):
self.all_dygraph_funcs = [
partial(
test_list_append_in_for_loop_with_concat,
iter_num=self.iter_num)
]
self.all_dygraph_funcs = [test_list_append_in_for_loop_with_concat, ]
if __name__ == '__main__':
......
......@@ -20,8 +20,8 @@ import numpy as np
import paddle.fluid as fluid
import unittest
from paddle.fluid.dygraph.jit import dygraph_to_static_func
from paddle.fluid.dygraph.dygraph_to_static.loop_transformer import NameVisitor
from paddle.fluid.dygraph.jit import declarative
SEED = 2020
np.random.seed(SEED)
......@@ -167,19 +167,17 @@ class TestTransformWhileLoop(unittest.TestCase):
self.dyfunc = while_loop_dyfunc
def _run_static(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
x_var = fluid.layers.assign(self.x)
static_func = dygraph_to_static_func(self.dyfunc)
out = static_func(x_var)
exe = fluid.Executor(self.place)
ret = exe.run(main_program, fetch_list=out)
return ret
return self._run(to_static=True)
def _run_dygraph(self):
return self._run(to_static=False)
def _run(self, to_static):
with fluid.dygraph.guard(self.place):
ret = self.dyfunc(fluid.dygraph.to_variable(self.x))
if to_static:
ret = declarative(self.dyfunc)(self.x)
else:
ret = self.dyfunc(self.x)
return ret.numpy()
def test_ast_to_func(self):
......@@ -219,22 +217,20 @@ class TestTransformForLoop(unittest.TestCase):
self.dyfunc = for_loop_dyfunc
def _run_static(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
static_func = dygraph_to_static_func(self.dyfunc)
out = static_func(self.len)
exe = fluid.Executor(self.place)
ret = exe.run(main_program, fetch_list=out)
return ret
return self._run(to_static=True)
def _run_dygraph(self):
return self._run(to_static=False)
def _run(self, to_static):
with fluid.dygraph.guard(self.place):
ret = self.dyfunc(self.len)
if to_static:
ret = declarative(self.dyfunc)(self.len)
else:
ret = self.dyfunc(self.len)
return ret.numpy()
def test_ast_to_func(self):
static_numpy = self._run_static()
self._run_dygraph()
self.assertTrue(np.allclose(self._run_dygraph(), self._run_static()))
......
......@@ -21,9 +21,13 @@ import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import dygraph_to_static_func
from paddle.fluid.dygraph import to_variable
from paddle.fluid.dygraph.nn import Conv2D, Linear, Pool2D
from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.dygraph.jit import declarative
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
SEED = 2020
class SimpleImgConvPool(fluid.dygraph.Layer):
......@@ -94,7 +98,7 @@ class MNIST(fluid.dygraph.Layer):
loc=0.0, scale=scale)),
act="softmax")
@dygraph_to_static_func
@declarative
def forward(self, inputs, label=None):
x = self.inference(inputs)
if label is not None:
......@@ -125,62 +129,73 @@ class TestMNIST(unittest.TestCase):
drop_last=True)
class TestMNISTWithStaticMode(TestMNIST):
class TestMNISTWithDeclarative(TestMNIST):
"""
Tests model when using `dygraph_to_static_func` to convert dygraph into static
model. It allows user to add customized code to train static model, such as `with`
and `Executor` statement.
Tests the model when we do not change its layers but only decorate them
with `dygraph_to_static_output`. In this case, everything should still
work as if the model were trained in dygraph mode.
"""
def test_train(self):
def train_static(self):
return self.train(to_static=True)
def train_dygraph(self):
return self.train(to_static=False)
def test_mnist_declarative(self):
dygraph_loss = self.train_dygraph()
static_loss = self.train_static()
self.assertTrue(
np.allclose(dygraph_loss, static_loss),
msg='dygraph is {}\n static_res is \n{}'.format(dygraph_loss,
static_loss))
main_prog = fluid.Program()
with fluid.program_guard(main_prog):
def train(self, to_static=False):
prog_trans = ProgramTranslator()
prog_trans.enable(to_static)
loss_data = []
with fluid.dygraph.guard(self.place):
fluid.default_main_program().random_seed = SEED
fluid.default_startup_program().random_seed = SEED
mnist = MNIST()
adam = AdamOptimizer(
learning_rate=0.001, parameter_list=mnist.parameters())
exe = fluid.Executor(self.place)
start = time()
img = fluid.data(
name='img', shape=[None, 1, 28, 28], dtype='float32')
label = fluid.data(name='label', shape=[None, 1], dtype='int64')
label.stop_gradient = True
prediction, acc, avg_loss = mnist(img, label)
adam.minimize(avg_loss)
exe.run(fluid.default_startup_program())
for epoch in range(self.epoch_num):
for batch_id, data in enumerate(self.train_reader()):
dy_x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)
out = exe.run(main_prog,
fetch_list=[avg_loss, acc],
feed={'img': dy_x_data,
'label': y_data})
if batch_id % 100 == 0:
print(
"Loss at epoch {} step {}: loss: {:}, acc: {}, cost: {}"
.format(epoch, batch_id,
np.array(out[0]),
np.array(out[1]), time() - start))
if batch_id == 300:
# The accuracy of mnist should converge over 0.9 after 300 batch.
accuracy = np.array(out[1])
self.assertGreater(
accuracy,
0.9,
msg="The accuracy {} of mnist should converge over 0.9 after 300 batch."
.format(accuracy))
for epoch in range(self.epoch_num):
start = time()
for batch_id, data in enumerate(self.train_reader()):
dy_x_data = np.array(
[x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)
img = to_variable(dy_x_data)
label = to_variable(y_data)
label.stop_gradient = True
prediction, acc, avg_loss = mnist(img, label=label)
avg_loss.backward()
adam.minimize(avg_loss)
loss_data.append(avg_loss.numpy()[0])
# save checkpoint
mnist.clear_gradients()
if batch_id % 10 == 0:
print(
"Loss at epoch {} step {}: loss: {:}, acc: {}, cost: {}"
.format(epoch, batch_id,
avg_loss.numpy(),
acc.numpy(), time() - start))
start = time()
if batch_id == 50:
mnist.eval()
prediction, acc, avg_loss = mnist(img, label)
loss_data.append(avg_loss.numpy()[0])
break
return loss_data
# TODO: TestCase with cached program is required when building program in `for` loop.
if __name__ == "__main__":
unittest.main()
......@@ -18,8 +18,11 @@ import numpy
import unittest
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.jit import declarative
program_translator = ProgramTranslator()
# 1. print VarBase
@declarative
......@@ -160,14 +163,17 @@ class TestPrintBase(unittest.TestCase):
def set_test_func(self):
raise NotImplementedError("Print test should implement set_test_func")
def get_dygraph_output(self):
def _run(self, to_static):
program_translator.enable(to_static)
with fluid.dygraph.guard():
self.dygraph_func(self.input)
def get_dygraph_output(self):
self._run(to_static=False)
def get_static_output(self):
with fluid.program_guard(fluid.Program()):
# TODO: How to catch C++ stdout to python
self.dygraph_func(self.input)
self._run(to_static=True)
class TestPrintVariable(TestPrintBase):
......
......@@ -31,22 +31,24 @@ from ifelse_simple_func import dyfunc_with_if_else
np.random.seed(0)
# TODO(Aurelius): Currently, `declarative` doesn't support decorating functions
# that create layers with initialization ops inside, like `fc = linear(10, 3)`,
# because the initialization ops would be added into the program and executed
# many times. The parameters are assumed to be initialized outside the function.
def simple_func(x, weight_numpy):
weight_initalizer = fluid.initializer.NumpyArrayInitializer(weight_numpy)
linear = Linear(32, 64, param_attr=weight_initalizer, bias_attr=False)
x = fluid.dygraph.to_variable(x)
y = linear(x)
z = linear(x)
w = fluid.dygraph.to_variable(weight_numpy)
y = fluid.layers.matmul(x, w)
z = fluid.layers.mean(y)
return z
@declarative
def decorated_simple_func(x, weight_numpy):
weight_initalizer = fluid.initializer.NumpyArrayInitializer(weight_numpy)
linear = Linear(32, 64, param_attr=weight_initalizer, bias_attr=False)
x = fluid.dygraph.to_variable(x)
y = linear(x)
z = linear(x)
w = fluid.dygraph.to_variable(weight_numpy)
y = fluid.layers.matmul(x, w)
z = fluid.layers.mean(y)
return z
......@@ -125,7 +127,7 @@ class TestEnableDeclarative(unittest.TestCase):
weight = np.random.randn(32, 64).astype('float32')
program_translator = ProgramTranslator()
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.dygraph.guard():
program_translator.enable(True)
static_output = program_translator.get_output(simple_func, x,
weight)
......@@ -143,7 +145,7 @@ class TestEnableDeclarative(unittest.TestCase):
weight = np.random.randn(32, 64).astype('float32')
program_translator = ProgramTranslator()
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.dygraph.guard():
program_translator.enable(True)
static_func = program_translator.get_func(simple_func)
self.assertTrue(callable(static_func))
......@@ -162,14 +164,12 @@ class TestEnableDeclarative(unittest.TestCase):
weight = np.random.randn(32, 64).astype('float32')
program_translator = ProgramTranslator()
with fluid.program_guard(fluid.Program(), fluid.Program()):
program_translator.enable(True)
static_output = program_translator.get_program(simple_func, x,
weight)
self.assertTrue(isinstance(static_output, tuple))
self.assertEqual(len(static_output), 4)
self.assertTrue(isinstance(static_output[0], fluid.Program))
self.assertTrue(isinstance(static_output[1], fluid.Program))
program_translator.enable(True)
static_output = program_translator.get_program(simple_func, x, weight)
self.assertTrue(isinstance(static_output, tuple))
self.assertEqual(len(static_output), 4)
self.assertTrue(isinstance(static_output[0], fluid.Program))
self.assertTrue(isinstance(static_output[1], fluid.Program))
program_translator.enable(False)
with fluid.dygraph.guard():
......@@ -182,7 +182,7 @@ class TestEnableDeclarative(unittest.TestCase):
weight = np.random.randn(32, 64).astype('float32')
program_translator = ProgramTranslator()
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.dygraph.guard():
program_translator.enable(True)
static_output = decorated_simple_func(x, weight)
......
......@@ -21,7 +21,7 @@ import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import ProgramTranslator
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.jit import declarative
from paddle.fluid.dygraph.nn import Embedding
......@@ -30,6 +30,8 @@ from paddle.fluid.optimizer import SGDOptimizer
PRINT_STEP = 20
SEED = 2020
program_translator = ProgramTranslator()
class SimpleLSTMRNN(fluid.Layer):
def __init__(self,
......@@ -169,13 +171,6 @@ class PtbModel(fluid.Layer):
@declarative
def forward(self, input, label, init_hidden, init_cell):
# TODO(liym27): Call `to_variable` to feed data successfully.
# Remove to_variable statements later
input = to_variable(input)
label = to_variable(label)
init_hidden = to_variable(init_hidden)
init_cell = to_variable(init_cell)
init_h = fluid.layers.reshape(
init_hidden, shape=[self.num_layers, -1, self.hidden_size])
......@@ -210,7 +205,8 @@ class PtbModel(fluid.Layer):
np.save("emb_grad", self.x_emb.gradient())
def train_dygraph(place):
def train(place):
num_layers = 1
batch_size = 4
hidden_size = 10
......@@ -286,78 +282,14 @@ def train_dygraph(place):
return out_loss, last_hidden.numpy(), last_cell.numpy()
def train_static(place):
num_layers = 1
batch_size = 4
hidden_size = 10
num_steps = 3
init_scale = 0.1
max_epoch = 1
dropout = 0.0
vocab_size = 1000
batch_num = 200
main_prog = fluid.Program()
startup_prog = fluid.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
with fluid.program_guard(main_prog, startup_prog):
ptb_model = PtbModel(
hidden_size=hidden_size,
vocab_size=vocab_size,
num_layers=num_layers,
num_steps=num_steps,
init_scale=init_scale,
dropout=dropout)
sgd = SGDOptimizer(
learning_rate=1e-3, parameter_list=ptb_model.parameters())
program_translator = ProgramTranslator()
program_translator.set_optimizer(sgd, index_of_loss=0)
for epoch_id in range(max_epoch):
total_loss = 0.0
iters = 0.0
total_sample = 0
init_hidden_data = np.zeros(
(num_layers, batch_size, hidden_size), dtype='float32')
init_cell_data = np.zeros(
(num_layers, batch_size, hidden_size), dtype='float32')
for step_id in range(batch_num):
x_data = np.arange(12).reshape(4, 3).astype('int64')
y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
y_data = y_data.reshape((-1, 1))
x_data = x_data.reshape((-1, num_steps, 1))
y_data = y_data.reshape((-1, num_steps, 1))
dy_loss, last_hidden, last_cell = ptb_model(
x_data, y_data, init_hidden_data, init_cell_data)
out_loss = dy_loss.numpy()
total_loss += out_loss
iters += num_steps
total_sample += 1
def train_dygraph(place):
program_translator.enable(False)
return train(place)
if step_id % PRINT_STEP == 0:
if step_id == 0:
logging.info(
"epoch %d | step %d, loss %0.3f" %
(epoch_id, step_id, total_loss / total_sample))
avg_batch_time = time.time()
else:
speed = PRINT_STEP / (time.time() - avg_batch_time)
logging.info(
"epoch %d | step %d, loss %0.3f, speed %.3f steps/s"
% (epoch_id, step_id, total_loss / total_sample,
speed))
avg_batch_time = time.time()
return out_loss, last_hidden.numpy(), last_cell.numpy()
def train_static(place):
program_translator.enable(True)
return train(place)
class TestPtb(unittest.TestCase):
......
......@@ -19,14 +19,17 @@ import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import dygraph_to_static_func
from paddle.fluid.dygraph import ProgramTranslator
from paddle.fluid.dygraph import declarative
program_translator = ProgramTranslator()
SEED = 2020
np.random.seed(SEED)
# Use a decorator to test exception
@dygraph_to_static_func
@declarative
def dyfunc_with_if(x_v):
if fluid.layers.mean(x_v).numpy()[0] > 5:
x_v = x_v - 1
......@@ -35,7 +38,7 @@ def dyfunc_with_if(x_v):
return x_v
@dygraph_to_static_func
@declarative
def nested_func(x_v):
x_v = fluid.dygraph.to_variable(x_v)
......@@ -57,17 +60,16 @@ class TestRecursiveCall1(unittest.TestCase):
self.dyfunc = nested_func
def get_dygraph_output(self):
program_translator.enable(False)
with fluid.dygraph.guard():
res = self.dyfunc(self.input).numpy()
return res
def get_static_output(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
static_out = self.dyfunc(self.input)
exe = fluid.Executor(self.place)
static_res = exe.run(main_program, fetch_list=static_out)
return static_res[0]
program_translator.enable(True)
with fluid.dygraph.guard():
res = self.dyfunc(self.input).numpy()
return res
def test_transformed_static_result(self):
static_res = self.get_static_output()
......@@ -93,14 +95,14 @@ class MyConvLayer(fluid.dygraph.Layer):
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.5)))
@dygraph_to_static_func
@declarative
def forward(self, inputs):
y = dyfunc_with_if(inputs)
y = lambda_fun(y)
y = self.dymethod(y)
return y
@dygraph_to_static_func
@declarative
def dymethod(self, x_v):
x_v = fluid.layers.assign(x_v)
return x_v
......@@ -120,7 +122,7 @@ class MyLayer(fluid.dygraph.Layer):
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.5)))
@dygraph_to_static_func
@declarative
def forward(self, inputs):
h = self.conv(inputs)
out = self.fc(h)
......@@ -134,7 +136,7 @@ class TestRecursiveCall2(unittest.TestCase):
self.place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
) else fluid.CPUPlace()
def get_dygraph_output(self):
def _run(self):
with fluid.dygraph.guard():
self.dygraph_func = self.Layer()
fluid.default_startup_program.random_seed = SEED
......@@ -144,21 +146,13 @@ class TestRecursiveCall2(unittest.TestCase):
return res.numpy()
def get_static_output(self):
startup_program = fluid.Program()
startup_program.random_seed = SEED
main_program = fluid.Program()
main_program.random_seed = SEED
with fluid.program_guard(main_program, startup_program):
self.dygraph_func = self.Layer()
data = fluid.layers.assign(self.input)
static_out = self.dygraph_func(data)
def get_dygraph_output(self):
program_translator.enable(False)
return self._run()
exe = fluid.Executor(self.place)
exe.run(startup_program)
static_res = exe.run(main_program, fetch_list=static_out)
return static_res[0]
def get_static_output(self):
program_translator.enable(True)
return self._run()
def test_transformed_static_result(self):
dygraph_res = self.get_dygraph_output()
......
......@@ -44,7 +44,8 @@ class SimpleFcLayer(fluid.dygraph.Layer):
class TestDyToStaticSaveInferenceModel(unittest.TestCase):
def test_save_inference_model(self):
# TODO(Aurelius84): disabled temporarily, needs the new save_inference interface
def _test_save_inference_model(self):
fc_size = 20
x = np.random.random((fc_size, fc_size)).astype('float32')
......
......@@ -18,10 +18,10 @@ import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.framework as framework
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.optimizer import AdamOptimizer
from test_fetch_feed import Linear
np.random.seed(2020)
......@@ -29,53 +29,50 @@ place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
)
def simple_func(x, weight_numpy):
weight_initalizer = fluid.initializer.NumpyArrayInitializer(weight_numpy)
linear = Linear(32, 64, param_attr=weight_initalizer)
x = fluid.dygraph.to_variable(x)
y = linear(x)
z = linear(x)
return z
class TestDyToStaticSaveLoad(unittest.TestCase):
def test_save_load_same_result(self):
program_translator = ProgramTranslator()
x_data = np.random.randn(30, 10, 32).astype('float32')
batch_num = 3
with fluid.dygraph.guard(place):
program_translator.enable(True)
x = fluid.dygraph.to_variable(x_data)
net = Linear(32, 64)
adam = AdamOptimizer(
learning_rate=0.1, parameter_list=net.parameters())
for i in range(batch_num):
static_out, static_loss = net(x)
# Update parameters
static_loss.backward()
adam.minimize(static_loss)
net.clear_gradients()
# Save parameters
fluid.save_dygraph(net.state_dict(), "./test_dy2stat_save_load")
# minimize() will update parameter, call net() to get output and avg_loss.
# Switch into eval mode.
net.eval()
static_out, static_loss = net(x)
# load parameters into dygraph
with fluid.dygraph.guard(place):
dygraph_net = Linear(32, 64)
def decorated_simple_func(x, weight_numpy):
weight_initalizer = fluid.initializer.NumpyArrayInitializer(weight_numpy)
linear = Linear(32, 64, param_attr=weight_initalizer)
x = fluid.dygraph.to_variable(x)
y = linear(x)
z = linear(x)
return z
# Load parameters
model_dict, _ = fluid.load_dygraph("./test_dy2stat_save_load")
dygraph_net.set_dict(model_dict)
# Switch into eval mode.
dygraph_net.eval()
x = fluid.dygraph.to_variable(x_data)
# predict output
program_translator.enable(False)
dygraph_out, dygraph_loss = dygraph_net(x)
class TestDyToStaticSaveLoad(unittest.TestCase):
def test_save_load_same_result(self):
x = np.random.randn(30, 10, 32).astype('float32')
weight = np.random.randn(32, 64).astype('float32')
with fluid.dygraph.guard(place):
dygraph_result = simple_func(x, weight)
main_program, startup_program, inputs, outputs = ProgramTranslator(
).get_program(decorated_simple_func, x, weight)
exe = fluid.Executor(place)
exe.run(startup_program)
fluid.save(main_program, "./test_dy2stat_save_load")
# set vars to zero so that we can test load in same file
for var in main_program.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
tensor = fluid.global_scope().find_var(var.name).get_tensor()
tensor.set(np.zeros_like(np.array(tensor)), place)
# make sure all the paramerter or optimizer var have been set to zero
tensor_np = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
self.assertEqual(0, np.sum(np.abs(tensor_np)))
fluid.load(main_program, "./test_dy2stat_save_load")
static_result = exe.run(main_program,
feed={inputs[0].name: x},
fetch_list=outputs)
self.assertTrue(np.allclose(dygraph_result.numpy(), static_result))
self.assertTrue(np.allclose(dygraph_out.numpy(), static_out.numpy()))
self.assertTrue(np.allclose(dygraph_loss.numpy(), static_loss.numpy()))
if __name__ == '__main__':
......
......@@ -22,8 +22,9 @@ import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.jit import dygraph_to_static_func
from paddle.fluid.dygraph.nn import BatchNorm, Conv2D, Linear, Pool2D
from paddle.fluid.dygraph import declarative
from paddle.fluid.dygraph import ProgramTranslator
SEED = 2020
np.random.seed(SEED)
......@@ -286,7 +287,7 @@ class SeResNeXt(fluid.dygraph.Layer):
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Uniform(-stdv, stdv)))
@dygraph_to_static_func
@declarative
def forward(self, inputs, label):
if self.layers == 50 or self.layers == 101:
y = self.conv0(inputs)
......@@ -314,7 +315,10 @@ class SeResNeXt(fluid.dygraph.Layer):
return out, avg_loss, acc_top1, acc_top5
def train_dygraph(train_reader):
def train(train_reader, to_static):
program_translator = ProgramTranslator()
program_translator.enable(to_static)
np.random.seed(SEED)
with fluid.dygraph.guard(place):
......@@ -374,75 +378,6 @@ def train_dygraph(train_reader):
)
def train_static(train_reader):
np.random.seed(SEED)
exe = fluid.Executor(place)
main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
with fluid.unique_name.guard():
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
img = fluid.data(
name="img", shape=[None, 3, 224, 224], dtype="float32")
label = fluid.data(name="label", shape=[None, 1], dtype="int64")
label.stop_gradient = True
se_resnext = SeResNeXt()
pred, avg_loss_, acc_top1_, acc_top5_ = se_resnext(img, label)
optimizer = optimizer_setting(train_parameters,
se_resnext.parameters())
optimizer.minimize(avg_loss_)
exe.run(startup_prog)
for epoch_id in range(EPOCH_NUM):
total_loss = 0.0
total_acc1 = 0.0
total_acc5 = 0.0
total_sample = 0
step_idx = 0
speed_list = []
for step_id, data in enumerate(train_reader()):
dy_x_data = np.array([x[0].reshape(3, 224, 224)
for x in data]).astype('float32')
y_data = np.array([x[1] for x in data]).astype('int64').reshape(
BATCH_SIZE, 1)
pred_, avg_loss, acc_top1, acc_top5 = exe.run(
main_prog,
feed={"img": dy_x_data,
"label": y_data},
fetch_list=[pred, avg_loss_, acc_top1_, acc_top5_])
total_loss += avg_loss
total_acc1 += acc_top1
total_acc5 += acc_top5
total_sample += 1
if step_id % PRINT_STEP == 0:
if step_id == 0:
logging.info( "epoch %d | step %d, loss %0.3f, acc1 %0.3f, acc5 %0.3f" % \
( epoch_id, step_id, total_loss / total_sample, \
total_acc1 / total_sample, total_acc5 / total_sample))
avg_batch_time = time.time()
else:
speed = PRINT_STEP / (time.time() - avg_batch_time)
speed_list.append(speed)
logging.info( "epoch %d | step %d, loss %0.3f, acc1 %0.3f, acc5 %0.3f, speed %.3f steps/s" % \
( epoch_id, step_id, total_loss / total_sample, \
total_acc1 / total_sample, total_acc5 / total_sample, speed))
avg_batch_time = time.time()
step_idx += 1
if step_idx == STEP_NUM:
break
return pred_, avg_loss, acc_top1, acc_top5
class TestSeResnet(unittest.TestCase):
def setUp(self):
self.train_reader = paddle.batch(
......@@ -452,8 +387,10 @@ class TestSeResnet(unittest.TestCase):
drop_last=True)
def test_check_result(self):
pred_1, loss_1, acc1_1, acc5_1 = train_static(self.train_reader)
pred_2, loss_2, acc1_2, acc5_2 = train_dygraph(self.train_reader)
pred_1, loss_1, acc1_1, acc5_1 = train(
self.train_reader, to_static=False)
pred_2, loss_2, acc1_2, acc5_2 = train(
self.train_reader, to_static=True)
self.assertTrue(
np.allclose(pred_1, pred_2),
......
......@@ -18,7 +18,7 @@ import numpy
import unittest
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import dygraph_to_static_func
from paddle.fluid.dygraph.jit import declarative
def dyfunc_tensor_shape_1(x):
......@@ -171,20 +171,19 @@ class TestTensorShapeBasic(unittest.TestCase):
def init_test_func(self):
self.dygraph_func = dyfunc_tensor_shape_1
def get_dygraph_output(self):
def _run(self, to_static):
with fluid.dygraph.guard():
res = self.dygraph_func(self.input).numpy()
if to_static:
res = declarative(self.dygraph_func)(self.input).numpy()
else:
res = self.dygraph_func(self.input).numpy()
return res
def get_static_output(self):
main_program = fluid.Program()
with fluid.program_guard(main_program):
static_out = dygraph_to_static_func(self.dygraph_func)(self.input)
exe = fluid.Executor(self.place)
static_res = exe.run(main_program, fetch_list=static_out)
def get_dygraph_output(self):
return self._run(to_static=False)
return static_res[0]
def get_static_output(self):
return self._run(to_static=False)
def test_transformed_static_result(self):
static_res = self.get_static_output()
......