Unverified commit 82630f38, authored by Huihuang Zheng, committed by GitHub

[Dy2stat] Add Support for paddle.grad (#33110)

This PR makes the following changes to support double grad:

1. Translate `paddle.grad` to `paddle.static.gradients` so that dy2stat supports double grad (see the usage sketch below).
2. Fix an IfElseTransformer bug: a "Store before Load" variable whose first use after the IfElse statement is itself a "Store" (for example, inside a later conditional) could fail to pick up the value assigned in the IfElse branches (an illustrative sketch follows the ifelse_transformer diff below).
3. Add a `DOut` output to `run_program_op` to hold double grad variables.
4. Add support for renaming double grad variables in `jit.save/load`.
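
A minimal usage sketch of what change 1 enables (adapted from the unit test added in this commit; the trivial layer and the input shape are only illustrative):

import paddle

class GradLayer(paddle.nn.Layer):
    def __init__(self):
        super(GradLayer, self).__init__()

    @paddle.jit.to_static
    def forward(self, x):
        x.stop_gradient = False
        y = x * x
        # Under @to_static, GradTransformer rewrites this call into
        # paddle.static.gradients(targets=[y], inputs=[x]).
        dx = paddle.grad(outputs=[y], inputs=[x])[0]
        return dx

layer = GradLayer()
dx = layer(paddle.ones(shape=[10, 2, 5], dtype='float32'))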
Parent 1e9299aa
@@ -83,6 +83,13 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker {
              "contains at most one scope."
              "NOTE: Do not use Scope directly because Scope output is not "
              "currently supported.");
    AddOutput("DOut",
              "(vector<LoDTensor>)"
              "The output tensors for GRAD Tensors in RunProgram forward "
              "operator, the forward operator contains GRAD Tensors when it "
              "computes double grad.")
        .AsDuplicable()
        .AsDispensable();
    AddAttr<BlockDesc*>("global_block",
                        "(BlockDesc *)"
                        "The global block of executed program desc.");
@@ -154,6 +161,7 @@ class RunProgramGradOpMaker : public framework::SingleGradOpMaker<T> {
    grad_op->SetInput("Params", this->Input("Params"));
    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    grad_op->SetInput("OutScope", this->Output("OutScope"));
    grad_op->SetInput("DOut", this->Output("DOut"));
    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    grad_op->SetOutput(framework::GradVarName("Params"),
                       this->InputGrad("Params"));
...
@@ -131,6 +131,9 @@ static void ShareVarsIntoScope(const std::vector<Variable *> &vars,
                               const std::vector<std::string> &var_names,
                               framework::Scope *scope) {
  for (size_t i = 0; i < vars.size(); ++i) {
    if (var_names[i] == "Fake_var") {
      continue;
    }
    auto *var = scope->Var(var_names[i]);
    CheckInputVarStatus(*vars[i], var_names[i]);
    VariableShare(*vars[i], var);
@@ -141,9 +144,9 @@ static void ShareVarsFromScope(const std::vector<Variable *> &vars,
                               const std::vector<std::string> &var_names,
                               framework::Scope *scope) {
  for (size_t i = 0; i < vars.size(); ++i) {
    if (var_names[i] == framework::kEmptyVarName ||
        var_names[i] == "Fake_var") {
      VLOG(2) << "find variable name is " << var_names[i] << ", skip it!";
      continue;
    }
    // NOTE: Here skip not found var is dangerous, if a bug is caused here,
@@ -170,9 +173,11 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
    auto &input_vars = ctx.MultiInputVar("X");
    auto &param_vars = ctx.MultiInputVar("Params");
    auto output_vars = ctx.MultiOutputVar("Out");
    auto dout_vars = ctx.MultiOutputVar("DOut");

    auto input_var_names = ctx.InputNames("X");
    auto output_var_names = ctx.OutputNames("Out");
    auto dout_var_names = ctx.OutputNames("DOut");

    // current program may not hold parameters
    std::vector<std::string> param_names;
@@ -195,7 +200,7 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
    // Step 2. prepare executor and init persistable variables
    framework::Executor exe(ctx.GetPlace());
    auto exe_ctx = framework::GetExecutorInfoFromCache(
        exe, ctx, {output_var_names, dout_var_names}, /*is_grad=*/false);

    // NOTE(Aurelius84): While training some models, forward can be called many
    // times and then apply backpropagation all at once, such as Reinforcement
@@ -219,6 +224,7 @@ class RunProgramOpKernel : public framework::OpKernel<T> {
    // Step 4. Get Output
    details::ShareVarsFromScope(output_vars, output_var_names, &scope);
    details::ShareVarsFromScope(dout_vars, dout_var_names, &scope);

    // Debug info: scope info when run end
    VLOG(3) << framework::GenScopeTreeDebugInfo(out_scope_vec->front());
...
@@ -25,6 +25,7 @@ from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import Br
from paddle.fluid.dygraph.dygraph_to_static.break_continue_transformer import BreakTransformOptimizer
from paddle.fluid.dygraph.dygraph_to_static.call_transformer import CallTransformer
from paddle.fluid.dygraph.dygraph_to_static.cast_transformer import CastTransformer
from paddle.fluid.dygraph.dygraph_to_static.grad_transformer import GradTransformer
from paddle.fluid.dygraph.dygraph_to_static.ifelse_transformer import IfElseTransformer
from paddle.fluid.dygraph.dygraph_to_static.list_transformer import ListTransformer
from paddle.fluid.dygraph.dygraph_to_static.logical_transformer import LogicalTransformer
@@ -86,6 +87,7 @@ class DygraphToStaticAst(gast.NodeTransformer):
            PrintTransformer,  # print statement
            CallTransformer,  # transform call recursively
            CastTransformer,  # type casting statement
            GradTransformer,  # transform paddle.grad to paddle.gradients
        ]

        for index, transformer in enumerate(transformers):
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import gast
import warnings

from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static import utils


class GradTransformer(gast.NodeTransformer):
    """
    A class that transforms dygraph paddle.grad into static graph
    paddle.static.gradients. The transformation is applied to support double
    grad mode.
    """

    def __init__(self, wrapper_root):
        assert isinstance(
            wrapper_root, AstNodeWrapper
        ), "Input non-AstNodeWrapper node for the initialization of GradTransformer."
        self.wrapper_root = wrapper_root
        self.root = wrapper_root.node

    def transform(self):
        self.visit(self.root)

    def visit_Call(self, node):
        self.generic_visit(node)
        if not is_grad_api_node(node):
            return node

        dygraph_grad_parameters = [
            "outputs", "inputs", "grad_outputs", "retain_graph", "create_graph",
            "only_inputs", "allow_unused", "no_grad_vars"
        ]
        # Map dygraph paddle.grad keyword names to paddle.static.gradients names.
        to_static_grad_param = {
            "outputs": "targets",
            "inputs": "inputs",
            "grad_outputs": "target_gradients",
            "no_grad_vars": "no_grad_set"
        }
        static_keywords = []

        # Keyword arguments: rename supported ones, drop the rest with a warning.
        for kw in node.keywords:
            if kw.arg not in dygraph_grad_parameters or kw.arg not in to_static_grad_param:
                warnings.warn("paddle.grad has unsupported parameter in jit: " +
                              kw.arg + ", jit will discard it")
                continue
            dygraph_grad_parameters.remove(kw.arg)
            kw.arg = to_static_grad_param[kw.arg]
            static_keywords.append(kw)

        # Positional arguments: convert them to the corresponding static keywords.
        for i in range(len(node.args)):
            arg_name = dygraph_grad_parameters[i]
            if arg_name not in to_static_grad_param:
                warnings.warn("paddle.grad has unsupported parameter in jit: " +
                              arg_name + ", jit will discard it")
                continue
            kw = gast.keyword(
                arg=to_static_grad_param[arg_name], value=node.args[i])
            static_keywords.append(kw)

        node.func = gast.parse('paddle.static.gradients').body[0].value
        node.keywords = static_keywords
        node.args = []
        return node


def is_grad_api_node(node):
    assert isinstance(node, gast.Call)
    api_name = utils.ast_to_source_code(node.func).strip()
    if utils.is_paddle_api(node):
        return api_name.endswith("grad")
    return False
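
For reference, a rough before/after sketch of the rewrite performed by visit_Call (hand-written; the real transformation operates on the gast AST of the decorated function, and the variables here are only illustrative):

import paddle

paddle.enable_static()

x = paddle.static.data(name='x', shape=[3], dtype='float32')
x.stop_gradient = False
y = x * x

# What the user writes in the dygraph function (and what GradTransformer rewrites):
#     dx = paddle.grad(outputs=[y], inputs=[x], create_graph=True)[0]
# What the converted static program effectively calls; keyword arguments without
# a static counterpart (e.g. create_graph, retain_graph) are dropped with a warning:
dx = paddle.static.gradients(targets=[y], inputs=[x])[0]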
@@ -402,7 +402,7 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict,
        var for var in _vars_with_store(child_dict) if var in parent_dict
    ])

    def _vars_loaded(ids_dict):
        """
        gast.Param is also a kind of `load` semantic.
        """
@@ -411,8 +411,6 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict,
            for ctx in ctxs:
                if isinstance(ctx, (gast.Load, gast.Param)):
                    new_dict[k].append(ctx)
        return new_dict

    # modified vars
@@ -439,8 +437,12 @@ def parse_cond_return(parent_vars_dict, if_vars_dict, else_vars_dict,
    new_vars_in_body_and_orelse = body_new_vars & orelse_new_vars

    # 3. new var is created only in one of If.body or If.orelse node, and it used as gast.Load firstly after gast.If node.
    # TODO(zhhsplendid): the _vars_loaded can be optimized as _vars_loaded_before_store. Because if a variable is stored before load,
    # the value would change by the store statement, we don't have to return to change the value. However, analysis is
    # complex because if the IfElse is nested and outer IfElse store statement may not run at all. We will put this optimization
    # as the future TODO
    used_vars_after_ifelse = set(
        [var for var in _vars_loaded(after_ifelse_vars_dict)])
    new_vars_to_create = new_vars_in_one_of_body_or_orelse & used_vars_after_ifelse | new_vars_in_body_and_orelse

    # 4. generate return_ids of if/else node.
...
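
To make item 2 of the PR description concrete, here is a rough, hand-written sketch of the kind of pattern the relaxed analysis is meant to cover. This is our reading of the fix, not a test case from the commit, and the exact failing pattern may differ: a variable created in only one IfElse branch whose first appearance afterwards is a Store sitting in another conditional that may not run.

import paddle

@paddle.jit.to_static
def fn(x):
    if paddle.mean(x) > 0:
        out = x + 1  # `out` is created only inside this branch
    if paddle.mean(x) > 100:
        out = x + 2  # first occurrence of `out` after the IfElse is a Store, and it may not execute
    # `out` is loaded here, so the converted cond must still return it; the old
    # _vars_loaded_before_store analysis stopped at the Store above and missed this load.
    return out

print(fn(paddle.ones([2])))  # mean(x) = 1 > 0, so the expected result is x + 1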
@@ -135,6 +135,7 @@ class PartialProgramLayer(layers.Layer):
        self._origin_main_program = self._verify_program(main_program)
        self._inner_scope = core.Scope()
        # Set default mode to train
        self._double_grads = self._get_double_grads(self._origin_main_program)
        self.training = True

    @LazyInitialized
@@ -192,24 +193,44 @@ class PartialProgramLayer(layers.Layer):
        """
        required_params = []
        for param in self._params:
            found_param = False
            for block in program.blocks:
                for op in block.ops:
                    if param.name in op.input_arg_names or param.name in op.output_arg_names:
                        required_params.append(param)
                        found_param = True
                        break
                if found_param:
                    break

        self._params = required_params

    def _get_double_grads(self, program):
        double_grads = []
        for block in program.blocks:
            for name in block.vars:
                if "@GRAD" in name:
                    var_desc = block.vars[name].desc
                    var_base = core.VarBase(var_desc.dtype(),
                                            var_desc.shape(),
                                            var_desc.name(),
                                            var_desc.type(), False)
                    double_grads.append(var_base)
        return double_grads

    def forward(self, inputs):
        in_vars, out_vars, tmp_scope_vec = self._prepare(inputs)

        framework._dygraph_tracer().trace_op(
            type='run_program',
            inputs={
                'X': valid_vars(in_vars),
                'Params': valid_vars(self._params)
            },
            outputs={
                'Out': valid_vars(out_vars),
                'OutScope': tmp_scope_vec,
                'DOut': valid_vars(self._double_grads)
            },
            attrs={
                'global_block': self.program.desc.block(0),
                'start_op_index': 0,
...
@@ -166,29 +166,46 @@ def _get_loaded_var_new_old(program_desc, all_new_old_dict_all):
def _rename_var_program_desc(program_desc, include=None, exclude=None):
    """
    Change the name of the loaded variables. Use 'unique_name.generate' to avoid duplication.
    It is used when loading multiple programs during inference.

    e.g. linear_0.tmp_3 ==> linear_0.tmp_1, x ==> x_0. For double grad, x@GRAD ==> x_0@GRAD
    If 'include' is not `None`, variables in include and the corresponding
        double grad variables (if they exist) are renamed.
    If 'exclude' is not `None`, variables that are in exclude and the
        corresponding double grad variables (if they exist) are not renamed.

    Args:
        program_desc(ProgramDesc): the variables in it will be modified.
        include(List): list of names of variables.
        exclude(List): list of names of variables.

    Returns:
        tuple of (dict_rename_var_new_old, dict_rename_var_old_new)
        dict_rename_var_new_old is a dict mapping from new name to old name
        dict_rename_var_old_new is a dict mapping from old name to new name
    """
    dict_rename_var_old_new = dict()
    dict_rename_var_new_old = dict()
    old_names = []
    # Store all old names
    for b_idx in six.moves.range(program_desc.num_blocks()):
        cur_block = program_desc.block(b_idx)
        for var in cur_block.all_vars():
            old_names.append(var.name())

    # Create dict_rename_var_new_old and dict_rename_var_old_new for non double
    # grad variables
    has_double_grad = False
    for b_idx in six.moves.range(program_desc.num_blocks()):
        cur_block = program_desc.block(b_idx)
        for var_idx, var in enumerate(cur_block.all_vars()):
            name_old = var.name()
            is_double_grad_var = "@GRAD" in name_old
            has_double_grad = has_double_grad or is_double_grad_var
            should_rename = (include is None or name_old in include) and (
                exclude is None or
                name_old not in exclude) and not is_double_grad_var
            if should_rename:
                temp_name = name_old.split('_')
                if len(temp_name) > 1 and temp_name[-1].isnumeric():
@@ -206,9 +223,29 @@ def _rename_var_program_desc(program_desc, include=None, exclude=None):
                if name_old != name_new:
                    cur_block._rename_var(
                        cpt.to_bytes(name_old), cpt.to_bytes(name_new))
                if not is_double_grad_var:
                    dict_rename_var_old_new[name_old] = name_new
                    dict_rename_var_new_old[name_new] = name_old

    # Handle double grad names
    if has_double_grad:
        double_grad_rename_dict = {}
        for name_old in dict_rename_var_old_new:
            for b_idx in six.moves.range(program_desc.num_blocks()):
                cur_block = program_desc.block(b_idx)
                for var_idx, var in enumerate(cur_block.all_vars()):
                    var_name = var.name()
                    if "@GRAD" in var_name and name_old in var_name:
                        new_var_name = var_name.replace(
                            name_old, dict_rename_var_old_new[name_old])
                        double_grad_rename_dict[var_name] = new_var_name
        for var_name in double_grad_rename_dict:
            dict_rename_var_old_new[var_name] = double_grad_rename_dict[
                var_name]
            dict_rename_var_new_old[double_grad_rename_dict[
                var_name]] = var_name

    # Rename on program desc
    for b_idx in six.moves.range(program_desc.num_blocks()):
        cur_block = program_desc.block(b_idx)
        for op_idx in six.moves.range(cur_block.op_size()):
@@ -220,6 +257,11 @@ def _rename_var_program_desc(program_desc, include=None, exclude=None):
                    op._rename_input(
                        input_arg_name,
                        dict_rename_var_old_new[input_arg_name])
                    if cur_block.has_var(cpt.to_bytes(input_arg_name)):
                        cur_block._rename_var(
                            cpt.to_bytes(input_arg_name),
                            cpt.to_bytes(dict_rename_var_old_new[
                                input_arg_name]))
            for output_arg_name in op.output_arg_names():
                if output_arg_name in dict_rename_var_old_new:
                    if output_arg_name != dict_rename_var_old_new[
@@ -227,6 +269,11 @@ def _rename_var_program_desc(program_desc, include=None, exclude=None):
                        op._rename_output(
                            output_arg_name,
                            dict_rename_var_old_new[output_arg_name])
                        if cur_block.has_var(cpt.to_bytes(output_arg_name)):
                            cur_block._rename_var(
                                cpt.to_bytes(output_arg_name),
                                cpt.to_bytes(dict_rename_var_old_new[
                                    output_arg_name]))

    program_desc.flush()
    return dict_rename_var_new_old, dict_rename_var_old_new
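
The double grad handling above boils down to extending the base-name rename map to the matching @GRAD variables. A detached sketch of that step (plain dicts and lists instead of ProgramDesc; the helper name and the example variables are made up for illustration):

def extend_rename_map_with_grads(rename_old_new, all_var_names):
    """Given a base-name rename map, add entries for the matching @GRAD variables."""
    double_grad_renames = {}
    for name_old, name_new in rename_old_new.items():
        for var_name in all_var_names:
            if "@GRAD" in var_name and name_old in var_name:
                double_grad_renames[var_name] = var_name.replace(name_old, name_new)
    rename_old_new.update(double_grad_renames)
    return rename_old_new

# Renaming x -> x_0 also renames its double grad variable:
print(extend_rename_map_with_grads({"x": "x_0"}, ["x", "x@GRAD", "y"]))
# {'x': 'x_0', 'x@GRAD': 'x_0@GRAD'}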
@@ -267,9 +314,10 @@ class _ProgramHolder(object):
    def __init__(self, program_desc):
        super(_ProgramHolder, self).__init__()

        # input, output, persistable, double_grads var info
        self._input_descs = []
        self._output_descs = []
        self._double_grad_descs = []
        self._persistable_names = []

        # execution scope
@@ -277,7 +325,6 @@ class _ProgramHolder(object):
        # append suffix var name dict
        self._suffix_varname_dict = None
        # forward program
        self._infer_program_desc = self._preprocess(program_desc)
        # forward + backward program
@@ -304,6 +351,10 @@ class _ProgramHolder(object):
    def persistable_names(self):
        return self._persistable_names

    @property
    def double_grad_descs(self):
        return self._double_grad_descs

    @property
    def scope(self):
        return self._inner_scope
@@ -347,6 +398,12 @@ class _ProgramHolder(object):
        for op_idx in reversed(ops_to_remove):
            root_block._remove_op(op_idx, op_idx + 1)

        for i in range(program_desc.num_blocks()):
            block_desc = program_desc.block(i)
            for var_desc in block_desc.all_vars():
                if "@GRAD" in var_desc.name():
                    self._double_grad_descs.append(var_desc)

        # 2. Input processing, reverse feed vars
        self._input_descs.reverse()
@@ -412,7 +469,6 @@ class _ProgramHolder(object):
        # rewrite a series of methods for append_backward for program_desc.
        # Therefore, in order to reuse the method of backward.py, build the program here.
        program = _build_program_by_desc(program_desc_copy)
        # 3. Add the outputs which is only used for training and not saved in
        # inference program.
        for block_idx in six.moves.range(program.num_blocks):
@@ -738,6 +794,20 @@ def _run_dygraph(instance, input, program_holder):
                                 core.VarDesc.VarType.STEP_SCOPES, True)
    tmp_scope_vec.value().set_scope(program_holder.scope)

    double_grad_vars = []
    for var_desc in program_holder.double_grad_descs:
        var = core.VarBase(var_desc.dtype(),
                           var_desc.shape(),
                           var_desc.name(), var_desc.type(), False)
        double_grad_vars.append(var)
    if len(double_grad_vars) == 0:
        double_grad_vars = [
            core.VarBase(
                value=[1],
                name='Fake_var',
                place=framework._current_expected_place())
        ]

    # 2. run program by op
    trace_program = program_holder.infer_program if instance._is_test else program_holder.train_program
    end_op_index = program_holder.infer_program.block(0).op_size()
@@ -745,8 +815,11 @@ def _run_dygraph(instance, input, program_holder):
        type='run_program',
        inputs={'X': input_vars,
                'Params': persistable_vars},
        outputs={
            'Out': output_vars,
            'OutScope': tmp_scope_vec,
            'DOut': double_grad_vars
        },
        attrs={
            'global_block': trace_program.block(0),
            'start_op_index': 0,
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import unittest

import numpy as np
import paddle


class GradLayer(paddle.nn.Layer):
    def __init__(self):
        super(GradLayer, self).__init__()

    @paddle.jit.to_static
    def forward(self, x):
        x.stop_gradient = False
        y = x * x
        dx = paddle.grad(outputs=[y], inputs=[x])[0]
        return dx


class GradLinearLayer(paddle.nn.Layer):
    def __init__(self):
        super(GradLinearLayer, self).__init__()
        self.linear = paddle.nn.Linear(5, 5, bias_attr=False)

    @paddle.jit.to_static
    def forward(self, x):
        x.stop_gradient = False
        tmp = x + x
        for i in range(10):
            tmp = self.linear(tmp)
        out = tmp
        dx = paddle.grad(
            [out], [x], None, create_graph=True, allow_unused=False)[0]
        return dx


class TestGrad(unittest.TestCase):
    def setUp(self):
        self.func = GradLayer()
        self.x = paddle.ones(shape=[10, 2, 5], dtype='float32')
        self.x.stop_gradient = False

    def _run(self, func, to_static):
        prog_trans = paddle.jit.ProgramTranslator()
        prog_trans.enable(to_static)
        ret = func(self.x).numpy()
        prog_trans.enable(True)
        return ret

    def test_forward(self):
        dygraph_res = self._run(self.func, to_static=False)
        static_res = self._run(self.func, to_static=True)
        self.assertTrue(np.allclose(static_res, dygraph_res))


class TestGradLinear(TestGrad):
    def setUp(self):
        self.func = GradLinearLayer()
        self.x = paddle.ones(shape=[10, 2, 5], dtype='float32')
        self.x.stop_gradient = False

    def test_save_infer_program(self):
        path = "double_grad_infer_model"
        input_spec = [
            paddle.static.InputSpec(
                shape=[10, 2, 5], dtype='float32')
        ]
        paddle.jit.save(self.func, path, input_spec=input_spec)
        load_func = paddle.jit.load(path)

        origin_res = self.func(self.x).numpy()
        load_res = load_func(self.x).numpy()
        self.assertTrue(np.allclose(origin_res, load_res))

    def test_save_train_program(self):
        grad_clip = paddle.nn.ClipGradByGlobalNorm(2.0)
        optimizer = paddle.optimizer.SGD(learning_rate=0.01,
                                         grad_clip=grad_clip,
                                         parameters=self.func.parameters())
        for i in range(10):
            out = self.func(self.x)
            avg_loss = paddle.mean(paddle.abs(out - 1))
            avg_loss.backward()
            optimizer.minimize(avg_loss)
            self.func.clear_gradients()

        path = "double_grad_train_model"
        paddle.jit.save(self.func, path)
        load_func = paddle.jit.load(path)

        origin_res = self.func(self.x).numpy()
        load_res = load_func(self.x).numpy()
        self.assertTrue(np.allclose(origin_res, load_res))


if __name__ == '__main__':
    unittest.main()
@@ -19,10 +19,13 @@ import unittest
import numpy as np
import six

import paddle
import paddle.fluid as fluid
from paddle import compat as cpt
from paddle.fluid import core, framework, executor

paddle.enable_static()


@contextlib.contextmanager
def program_scope_guard():
@@ -164,6 +167,8 @@ class RunProgramOpTest(unittest.TestCase):
            persistable=True)
        inner_scope = core.Scope()
        outputs['OutScope'].value().set_scope(inner_scope)
        outputs['DOut'] = [create_var_base(False, "Fake_var")]
        return outputs

    def calc_dygraph_output(self, place):
...