Unverified commit 82630f38 authored by: H Huihuang Zheng, committed by: GitHub

[Dy2stat] Add Support for paddle.grad (#33110)

This PR makes the following changes to support double grad:

1. Translate `paddle.grad` to `paddle.static.gradients` to support double grad for dy2stat (see the sketch below).
2. Fix an IfElseTransformer bug where a value may not be updated when a "Store before Load" variable is written by a "Store" statement that itself sits inside an IfElse conditional statement.
3. Add `DOut` to `run_program_op` to support double grad variables.
4. Add support for renaming double grad variables in `jit.save/load`.
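Roughly, a minimal usage sketch of what change 1 enables, modeled on the GradLayer unit test added at the bottom of this diff (the layer name and tensor shape are illustrative assumptions, not part of the change):

import paddle

class SquareGradLayer(paddle.nn.Layer):
    # Hypothetical layer: computes y = x * x and returns dy/dx inside a
    # @to_static forward, which requires the paddle.grad translation.
    @paddle.jit.to_static
    def forward(self, x):
        x.stop_gradient = False
        y = x * x
        # Written with the dygraph API paddle.grad; under @to_static the new
        # GradTransformer rewrites this call to paddle.static.gradients.
        dx = paddle.grad(outputs=[y], inputs=[x], create_graph=True)[0]
        return dx

layer = SquareGradLayer()
dx = layer(paddle.ones(shape=[2, 3], dtype='float32'))  # dx == 2 * x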
Parent 1e9299aa
......@@ -83,6 +83,13 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker {
"contains at most one scope."
"NOTE: Do not use Scope directly because Scope output is not "
"currently supported.");
AddOutput("DOut",
"(vector<LoDTensor>)"
"The output tensors for GRAD Tensors in RunProgram forward "
"operator, the forward operator contains GRAD Tensors when it "
"computes double grad.")
.AsDuplicable()
.AsDispensable();
AddAttr<BlockDesc*>("global_block",
"(BlockDesc *)"
"The global block of executed program desc.");
......@@ -154,6 +161,7 @@ class RunProgramGradOpMaker : public framework::SingleGradOpMaker<T> {
grad_op->SetInput("Params", this->Input("Params"));
grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
grad_op->SetInput("OutScope", this->Output("OutScope"));
grad_op->SetInput("DOut", this->Output("DOut"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetOutput(framework::GradVarName("Params"),
this->InputGrad("Params"));
......
......@@ -131,6 +131,9 @@ static void ShareVarsIntoScope(const std::vector<Variable *> &vars,
const std::vector<std::string> &var_names,
framework::Scope *scope) {
for (size_t i = 0; i < vars.size(); ++i) {
if (var_names[i] == "Fake_var") {
continue;
}
auto *var = scope->Var(var_names[i]);
CheckInputVarStatus(*vars[i], var_names[i]);
VariableShare(*vars[i], var);
......@@ -141,9 +144,9 @@ static void ShareVarsFromScope(const std::vector<Variable *> &vars,
const std::vector<std::string> &var_names,
framework::Scope *scope) {
for (size_t i = 0; i < vars.size(); ++i) {
if (var_names[i] == framework::kEmptyVarName) {
VLOG(2) << "find variable name is " << framework::kEmptyVarName
<< ", skip it!";
if (var_names[i] == framework::kEmptyVarName ||
var_names[i] == "Fake_var") {
VLOG(2) << "find variable name is " << var_names[i] << ", skip it!";
continue;
}
// NOTE: Here skip not found var is dangerous, if a bug is caused here,
......@@ -170,9 +173,11 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
auto &input_vars = ctx.MultiInputVar("X");
auto &param_vars = ctx.MultiInputVar("Params");
auto output_vars = ctx.MultiOutputVar("Out");
auto dout_vars = ctx.MultiOutputVar("DOut");
auto input_var_names = ctx.InputNames("X");
auto output_var_names = ctx.OutputNames("Out");
auto dout_var_names = ctx.OutputNames("DOut");
// current program may not hold parameters
std::vector<std::string> param_names;
......@@ -195,7 +200,7 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
// Step 2. prepare executor and init persistable variables
framework::Executor exe(ctx.GetPlace());
auto exe_ctx = framework::GetExecutorInfoFromCache(
exe, ctx, {output_var_names}, /*is_grad=*/false);
exe, ctx, {output_var_names, dout_var_names}, /*is_grad=*/false);
// NOTE(Aurelius84): While training some models, forward can be called many
// times and then apply backpropagation all at once, such as Reinforcement
......@@ -219,6 +224,7 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
// Step 4. Get Output
details::ShareVarsFromScope(output_vars, output_var_names, &scope);
details::ShareVarsFromScope(dout_vars, dout_var_names, &scope);
// Debug info: scope info when run end
VLOG(3) << framework::GenScopeTreeDebugInfo(out_scope_vec->front());
......
......@@ -25,6 +25,7 @@ from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import Br
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakTransformOptimizer
from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer
from paddle.fluid.dygraph.dygraph_to_static.cast_transformer import CastTransformer
from paddle.fluid.dygraph.dygraph_to_static.grad_transformer import GradTransformer
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer
from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer
from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import LogicalTransformer
......@@ -86,6 +87,7 @@ class DygraphToStaticAst(gast.NodeTransformer):
PrintTransformer, # print statement
CallTransformer, # transform call recursively
CastTransformer, # type casting statement
GradTransformer, # transform paddle.grad to paddle.static.gradients
]
for index, transformer in enumerate(transformers):
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import gast
import warnings
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static import utils
class GradTransformer(gast.NodeTransformer):
    """
    A class that transforms dygraph paddle.grad into static graph
    paddle.static.gradients. The transformation is applied to support
    double grad mode.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of GradTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node

    def transform(self):
        self.visit(self.root)

    def visit_Call(self, node):
        self.generic_visit(node)
        if not is_grad_api_node(node):
            return node

        dygraph_grad_parameters = [
            "outputs", "inputs", "grad_outputs", "retain_graph", "create_graph",
            "only_inputs", "allow_unused", "no_grad_vars"
        ]
        to_static_grad_param = {
            "outputs": "targets",
            "inputs": "inputs",
            "grad_outputs": "target_gradients",
            "no_grad_vars": "no_grad_set"
        }
        static_keywords = []

        for kw in node.keywords:
            if kw.arg not in dygraph_grad_parameters or kw.arg not in to_static_grad_param:
                warnings.warn("paddle.grad has unsupported parameter in jit: " +
                              kw.arg + ", jit will discard it")
                continue
            dygraph_grad_parameters.remove(kw.arg)
            kw.arg = to_static_grad_param[kw.arg]
            static_keywords.append(kw)

        for i in range(len(node.args)):
            arg_name = dygraph_grad_parameters[i]
            if arg_name not in to_static_grad_param:
                # NOTE: use arg_name here; `kw` from the loop above may be unbound.
                warnings.warn("paddle.grad has unsupported parameter in jit: " +
                              arg_name + ", jit will discard it")
                continue
            kw = gast.keyword(
                arg=to_static_grad_param[arg_name], value=node.args[i])
            static_keywords.append(kw)

        node.func = gast.parse('paddle.static.gradients').body[0].value
        node.keywords = static_keywords
        node.args = []
        return node


def is_grad_api_node(node):
    assert isinstance(node, gast.Call)
    api_name = utils.ast_to_source_code(node.func).strip()
    if utils.is_paddle_api(node):
        return api_name.endswith("grad")
    return False
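Below is a self-contained sketch of the keyword rewrite that visit_Call performs; the driver itself is hypothetical (it manipulates a gast AST directly instead of going through the dy2stat pipeline), but the mapping table matches the code above:

import gast

# Parse a dygraph-style call, then apply the same renaming that
# GradTransformer.visit_Call applies to the AST node.
call = gast.parse(
    "paddle.grad(outputs=[y], inputs=[x], create_graph=True)").body[0].value

to_static_grad_param = {
    "outputs": "targets",
    "inputs": "inputs",
    "grad_outputs": "target_gradients",
    "no_grad_vars": "no_grad_set",
}
call.func = gast.parse("paddle.static.gradients").body[0].value
call.keywords = [
    gast.keyword(arg=to_static_grad_param[kw.arg], value=kw.value)
    for kw in call.keywords if kw.arg in to_static_grad_param
]
# `call` now represents paddle.static.gradients(targets=[y], inputs=[x]);
# create_graph is dropped (the real transformer also emits a warning for it).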
......@@ -402,7 +402,7 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict,
var for var in _vars_with_store(child_dict) if var in parent_dict
])
def _vars_loaded_before_store(ids_dict):
def _vars_loaded(ids_dict):
"""
gast.Param is also a kind of `load` semantic.
"""
......@@ -411,8 +411,6 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict,
for ctx in ctxs:
if isinstance(ctx, (gast.Load, gast.Param)):
new_dict[k].append(ctx)
elif isinstance(ctx, gast.Store):
break
return new_dict
# modified vars
......@@ -439,8 +437,12 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict,
new_vars_in_body_and_orelse = body_new_vars & orelse_new_vars
# 3. new var is created only in one of the If.body or If.orelse nodes, and it is first used as gast.Load after the gast.If node.
#    (A standalone illustration of this case follows this hunk.)
# TODO(zhhsplendid): _vars_loaded could be optimized back to _vars_loaded_before_store. If a variable is stored before it is
# loaded, the store statement overwrites its value, so we would not have to return it from the IfElse just to set it. However,
# the analysis is complex because the store statement may itself sit inside a nested IfElse and thus may not run at all. We
# leave this optimization as a future TODO.
used_vars_after_ifelse = set(
[var for var in _vars_loaded_before_store(after_ifelse_vars_dict)])
[var for var in _vars_loaded(after_ifelse_vars_dict)])
new_vars_to_create = new_vars_in_one_of_body_or_orelse & used_vars_after_ifelse | new_vars_in_body_and_orelse
# 4. generate return_ids of if/else node.
......
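For reference, a hypothetical dygraph function (not taken from this PR) that illustrates case 3 and the reason for switching to _vars_loaded:

def func(x):
    y = 0
    if x > 0:
        res = x + 1      # `res` is created only in If.body (case 3)
    else:
        y = x - 1
    if y > 0:
        res = y          # a Store of `res` that may not run at all
    return res           # first Load of `res` after the outer IfElse
                         # (assume the data guarantees `res` is defined)

With the old _vars_loaded_before_store, the Store inside the second `if` hides the later Load, so `res` is missing from used_vars_after_ifelse and the converted outer cond does not return it; its value can then be wrong when the second `if` is not taken. _vars_loaded keeps `res`, so the outer cond returns it.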
......@@ -135,6 +135,7 @@ class PartialProgramLayer(layers.Layer):
self._origin_main_program = self._verify_program(main_program)
self._inner_scope = core.Scope()
# Set default mode to train
self._double_grads = self._get_double_grads(self._origin_main_program)
self.training = True
@LazyInitialized
......@@ -192,24 +193,44 @@ class PartialProgramLayer(layers.Layer):
"""
required_params = []
for param in self._params:
found_param = False
for block in program.blocks:
if param.name in block.vars:
required_params.append(param)
for op in block.ops:
if param.name in op.input_arg_names or param.name in op.output_arg_names:
required_params.append(param)
found_param = True
break
if found_param:
break
self._params = required_params
def _get_double_grads(self, program):
double_grads = []
for block in program.blocks:
for name in block.vars:
if "@GRAD" in name:
var_desc = block.vars[name].desc
var_base = core.VarBase(var_desc.dtype(),
var_desc.shape(),
var_desc.name(),
var_desc.type(), False)
double_grads.append(var_base)
return double_grads
def forward(self, inputs):
in_vars, out_vars, tmp_scope_vec = self._prepare(inputs)
framework._dygraph_tracer().trace_op(
type='run_program',
inputs={
'X': valid_vars(in_vars),
'Params': valid_vars(self._params)
},
outputs={'Out': valid_vars(out_vars),
'OutScope': tmp_scope_vec},
outputs={
'Out': valid_vars(out_vars),
'OutScope': tmp_scope_vec,
'DOut': valid_vars(self._double_grads)
},
attrs={
'global_block': self.program.desc.block(0),
'start_op_index': 0,
......
......@@ -166,29 +166,46 @@ def _get_loaded_var_new_old(program_desc, all_new_old_dict_all):
def _rename_var_program_desc(program_desc, include=None, exclude=None):
"""
Change the name of the loaded variables.Use 'unique_name.generate' to avoid duplication
e.g. linear_0.tmp_3 ==> linear_0.tmp_1, x ==> x_0.
If 'include' is not `None`,variables that are not in include are not renamed.
If 'exclude' is not `None`,variables that are in exclude will are not renamed.
Change the name of the loaded variables. Use 'unique_name.generate' to avoid duplication.
It is used when loading multiple programs during inference.
e.g. linear_0.tmp_3 ==> linear_0.tmp_1, x ==> x_0. For double grad, x@GRAD ==> x_0@GRAD.
If 'include' is not `None`, variables in include and the corresponding
double grad variables (if they exist) are renamed.
If 'exclude' is not `None`, variables that are in exclude and the
corresponding double grad variables (if they exist) are not renamed.
Args:
program_desc(ProgramDesc):the variables in it will be modified.
include(List):list of names of variables.
exclude(List):list of names of variables.
Returns:
tuple of (dict_rename_var_new_old, dict_rename_var_old_new)
dict_rename_var_new_old is a dict mapping from new name to old name
dict_rename_var_old_new is a dict mapping from old name to new name
"""
dict_rename_var_old_new = dict()
dict_rename_var_new_old = dict()
old_names = []
# Store all old names
for b_idx in six.moves.range(program_desc.num_blocks()):
cur_block = program_desc.block(b_idx)
for var in cur_block.all_vars():
old_names.append(var.name())
# Create dict_rename_var_new_old and dict_rename_var_old_new for non double
# grad variables
has_double_grad = False
for b_idx in six.moves.range(program_desc.num_blocks()):
cur_block = program_desc.block(b_idx)
for var_idx, var in enumerate(cur_block.all_vars()):
name_old = var.name()
is_double_grad_var = "@GRAD" in name_old
has_double_grad = has_double_grad or is_double_grad_var
should_rename = (include is None or name_old in include) and (
exclude is None or name_old not in exclude)
exclude is None or
name_old not in exclude) and not is_double_grad_var
if should_rename:
temp_name = name_old.split('_')
if len(temp_name) > 1 and temp_name[-1].isnumeric():
......@@ -206,9 +223,29 @@ def _rename_var_program_desc(program_desc, include=None, exclude=None):
if name_old != name_new:
cur_block._rename_var(
cpt.to_bytes(name_old), cpt.to_bytes(name_new))
dict_rename_var_old_new[name_old] = name_new
dict_rename_var_new_old[name_new] = name_old
if not is_double_grad_var:
dict_rename_var_old_new[name_old] = name_new
dict_rename_var_new_old[name_new] = name_old
# Handle double grad names
if has_double_grad:
double_grad_rename_dict = {}
for name_old in dict_rename_var_old_new:
for b_idx in six.moves.range(program_desc.num_blocks()):
cur_block = program_desc.block(b_idx)
for var_idx, var in enumerate(cur_block.all_vars()):
var_name = var.name()
if "@GRAD" in var_name and name_old in var_name:
new_var_name = var_name.replace(
name_old, dict_rename_var_old_new[name_old])
double_grad_rename_dict[var_name] = new_var_name
for var_name in double_grad_rename_dict:
dict_rename_var_old_new[var_name] = double_grad_rename_dict[
var_name]
dict_rename_var_new_old[double_grad_rename_dict[
var_name]] = var_name
# Rename on program desc
for b_idx in six.moves.range(program_desc.num_blocks()):
cur_block = program_desc.block(b_idx)
for op_idx in six.moves.range(cur_block.op_size()):
......@@ -220,6 +257,11 @@ def _rename_var_program_desc(program_desc, include=None, exclude=None):
op._rename_input(
input_arg_name,
dict_rename_var_old_new[input_arg_name])
if cur_block.has_var(cpt.to_bytes(input_arg_name)):
cur_block._rename_var(
cpt.to_bytes(input_arg_name),
cpt.to_bytes(dict_rename_var_old_new[
input_arg_name]))
for output_arg_name in op.output_arg_names():
if output_arg_name in dict_rename_var_old_new:
if output_arg_name != dict_rename_var_old_new[
......@@ -227,6 +269,11 @@ def _rename_var_program_desc(program_desc, include=None, exclude=None):
op._rename_output(
output_arg_name,
dict_rename_var_old_new[output_arg_name])
if cur_block.has_var(cpt.to_bytes(output_arg_name)):
cur_block._rename_var(
cpt.to_bytes(output_arg_name),
cpt.to_bytes(dict_rename_var_old_new[
output_arg_name]))
program_desc.flush()
return dict_rename_var_new_old, dict_rename_var_old_new
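A minimal, pure-Python sketch of the double-grad renaming rule added above; the variable names are made-up examples and only the string logic mirrors the code:

# Once a forward variable has been given a unique new name, every "@GRAD"
# variable whose name contains the old name is renamed consistently.
dict_rename_var_old_new = {"x": "x_0", "linear_0.tmp_3": "linear_0.tmp_1"}
all_var_names = ["x", "x@GRAD", "linear_0.tmp_3", "linear_0.tmp_3@GRAD"]

double_grad_rename_dict = {}
for name_old, name_new in dict_rename_var_old_new.items():
    for var_name in all_var_names:
        if "@GRAD" in var_name and name_old in var_name:
            double_grad_rename_dict[var_name] = var_name.replace(name_old,
                                                                 name_new)

print(double_grad_rename_dict)
# {'x@GRAD': 'x_0@GRAD', 'linear_0.tmp_3@GRAD': 'linear_0.tmp_1@GRAD'}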
......@@ -267,9 +314,10 @@ class _ProgramHolder(object):
def __init__(self, program_desc):
super(_ProgramHolder, self).__init__()
# input, output, persistable var info
# input, output, persistable, double_grads var info
self._input_descs = []
self._output_descs = []
self._double_grad_descs = []
self._persistable_names = []
# execution scope
......@@ -277,7 +325,6 @@ class _ProgramHolder(object):
# append suffix var name dict
self._suffix_varname_dict = None
# forward program
self._infer_program_desc = self._preprocess(program_desc)
# forward + backward program
......@@ -304,6 +351,10 @@ class _ProgramHolder(object):
def persistable_names(self):
return self._persistable_names
@property
def double_grad_descs(self):
return self._double_grad_descs
@property
def scope(self):
return self._inner_scope
......@@ -347,6 +398,12 @@ class _ProgramHolder(object):
for op_idx in reversed(ops_to_remove):
root_block._remove_op(op_idx, op_idx + 1)
for i in range(program_desc.num_blocks()):
block_desc = program_desc.block(i)
for var_desc in block_desc.all_vars():
if "@GRAD" in var_desc.name():
self._double_grad_descs.append(var_desc)
# 2. Input processing, reverse feed vars
self._input_descs.reverse()
......@@ -412,7 +469,6 @@ class _ProgramHolder(object):
# rewrite a series of methods for append_backward for program_desc.
# Therefore, in order to reuse the method of backward.py, build the program here.
program = _build_program_by_desc(program_desc_copy)
# 3. Add the outputs which is only used for training and not saved in
# inference program.
for block_idx in six.moves.range(program.num_blocks):
......@@ -738,6 +794,20 @@ def _run_dygraph(instance, input, program_holder):
core.VarDesc.VarType.STEP_SCOPES, True)
tmp_scope_vec.value().set_scope(program_holder.scope)
double_grad_vars = []
for var_desc in program_holder.double_grad_descs:
var = core.VarBase(var_desc.dtype(),
var_desc.shape(),
var_desc.name(), var_desc.type(), False)
double_grad_vars.append(var)
if len(double_grad_vars) == 0:
double_grad_vars = [
core.VarBase(
value=[1],
name='Fake_var',
place=framework._current_expected_place())
]
# 2. run program by op
trace_program = program_holder.infer_program if instance._is_test else program_holder.train_program
end_op_index = program_holder.infer_program.block(0).op_size()
......@@ -745,8 +815,11 @@ def _run_dygraph(instance, input, program_holder):
type='run_program',
inputs={'X': input_vars,
'Params': persistable_vars},
outputs={'Out': output_vars,
'OutScope': tmp_scope_vec},
outputs={
'Out': output_vars,
'OutScope': tmp_scope_vec,
'DOut': double_grad_vars
},
attrs={
'global_block': trace_program.block(0),
'start_op_index': 0,
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import paddle
import unittest
class GradLayer(paddle.nn.Layer):
    def __init__(self):
        super(GradLayer, self).__init__()

    @paddle.jit.to_static
    def forward(self, x):
        x.stop_gradient = False
        y = x * x
        dx = paddle.grad(outputs=[y], inputs=[x])[0]
        return dx


class GradLinearLayer(paddle.nn.Layer):
    def __init__(self):
        super(GradLinearLayer, self).__init__()
        self.linear = paddle.nn.Linear(5, 5, bias_attr=False)

    @paddle.jit.to_static
    def forward(self, x):
        x.stop_gradient = False
        tmp = x + x
        for i in range(10):
            tmp = self.linear(tmp)
        out = tmp
        dx = paddle.grad(
            [out], [x], None, create_graph=True, allow_unused=False)[0]
        return dx


class TestGrad(unittest.TestCase):
    def setUp(self):
        self.func = GradLayer()
        self.x = paddle.ones(shape=[10, 2, 5], dtype='float32')
        self.x.stop_gradient = False

    def _run(self, func, to_static):
        prog_trans = paddle.jit.ProgramTranslator()
        prog_trans.enable(to_static)
        ret = func(self.x).numpy()
        prog_trans.enable(True)
        return ret

    def test_forward(self):
        dygraph_res = self._run(self.func, to_static=False)
        static_res = self._run(self.func, to_static=True)
        self.assertTrue(np.allclose(static_res, dygraph_res))


class TestGradLinear(TestGrad):
    def setUp(self):
        self.func = GradLinearLayer()
        self.x = paddle.ones(shape=[10, 2, 5], dtype='float32')
        self.x.stop_gradient = False

    def test_save_infer_program(self):
        path = "double_grad_infer_model"
        input_spec = [
            paddle.static.InputSpec(
                shape=[10, 2, 5], dtype='float32')
        ]
        paddle.jit.save(self.func, path, input_spec=input_spec)
        load_func = paddle.jit.load(path)

        origin_res = self.func(self.x).numpy()
        load_res = load_func(self.x).numpy()
        self.assertTrue(np.allclose(origin_res, load_res))

    def test_save_train_program(self):
        grad_clip = paddle.nn.ClipGradByGlobalNorm(2.0)
        optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                         grad_clip=grad_clip,
                                         parameters=self.func.parameters())
        for i in range(10):
            out = self.func(self.x)
            avg_loss = paddle.mean(paddle.abs(out - 1))
            avg_loss.backward()
            optimizer.minimize(avg_loss)
            self.func.clear_gradients()

        path = "double_grad_train_model"
        paddle.jit.save(self.func, path)
        load_func = paddle.jit.load(path)

        origin_res = self.func(self.x).numpy()
        load_res = load_func(self.x).numpy()
        self.assertTrue(np.allclose(origin_res, load_res))


if __name__ == '__main__':
    unittest.main()
......@@ -19,10 +19,13 @@ import unittest
import numpy as np
import six
import paddle
import paddle.fluid as fluid
from paddle import compat as cpt
from paddle.fluid import core, framework, executor
paddle.enable_static()
@contextlib.contextmanager
def program_scope_guard():
......@@ -164,6 +167,8 @@ class RunProgramOpTest(unittest.TestCase):
persistable=True)
inner_scope = core.Scope()
outputs['OutScope'].value().set_scope(inner_scope)
outputs['DOut'] = [create_var_base(False, "Fake_var")]
return outputs
def calc_dygraph_output(self, place):
......