Unverified · Commit 2d17df97 · Author: WangZhen · Committer: GitHub

[Dy2St]Get grad names when call append backward to fix high order gradient (#53250)

Parent: a3a91682
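This change teaches the `run_program` op the names of the input gradients (`x_grad_names`), collected directly when backward is appended, so that higher-order gradients work when a function containing `paddle.grad(..., create_graph=True)` is compiled with `paddle.jit.to_static`. A minimal reproduction sketch, adapted from the `tanh_high_order_grad` test added at the bottom of this commit (the pre-fix failure mode is not shown and is assumed):

```python
import paddle


def tanh_high_order_grad(x):
    y = paddle.tanh(x)
    # First-order gradient with create_graph=True so a second-order
    # gradient can be taken through it.
    return paddle.grad(y, x, create_graph=True)[0]


x = paddle.ones((1,))
x.stop_gradient = False

# Compile to static graph; run_program now receives x_grad_names, so the
# second paddle.grad call below can locate the input gradient variables.
static_fn = paddle.jit.to_static(tanh_high_order_grad)
out = static_fn(x)
(second_order_grad,) = paddle.grad(out, [x])
print(second_order_grad.numpy())
```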
......@@ -689,15 +689,26 @@ class GradNodeRunProgram : public egr::GradNodeBase {
protected:
void ConstructXGradTensors(const std::vector<paddle::Tensor> &x,
std::vector<paddle::Tensor> *x_grad) {
auto x_grad_names =
PADDLE_GET_CONST(std::vector<std::string>, attrs_.at("x_grad_names"));
PADDLE_ENFORCE_EQ(
x.size(),
x_grad_names.size(),
paddle::platform::errors::InvalidArgument(
"The x.size() and x_grad_names.size() should be equal. "
"But received x.size() = %d, x_grad_names.size() = %d",
x.size(),
x_grad_names.size()));
// TODO(dev): Need an elegant way to determine the information of grad_tensor,
// such as: name, tensor type (DenseTensor or SelectedRows).
for (auto &t : x) {
if (t.is_dense_tensor()) {
for (size_t i = 0; i < x.size(); i++) {
if (x[i].is_dense_tensor()) {
x_grad->emplace_back(std::make_shared<phi::DenseTensor>());
} else if (t.is_selected_rows()) {
} else if (x[i].is_selected_rows()) {
x_grad->emplace_back(std::make_shared<phi::SelectedRows>());
}
x_grad->back().set_name(t.name() + "@GRAD");
x_grad->back().set_name(x_grad_names[i]);
}
}
......
......@@ -139,6 +139,10 @@ class RunProgramOpMaker : public framework::OpProtoAndCheckerMaker {
"std::vector<std::string>"
"The names of output gradients.")
.SetDefault({});
AddAttr<std::vector<std::string>>("x_grad_names",
"std::vector<std::string>"
"The names of input gradients.")
.SetDefault({});
AddComment(R"DOC(
RunProgram operator.
......
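The new `x_grad_names` attribute sits alongside the existing `param_grad_names` and `out_grad_names` attributes on the op proto. A small inspection sketch, assuming the registered op type string is `run_program`:

```python
from paddle.fluid.framework import OpProtoHolder

proto = OpProtoHolder.instance().get_op_proto("run_program")
attr_names = [attr.name for attr in proto.attrs]
# After this commit, all three grad-name attributes should be registered.
print([name for name in attr_names if name.endswith("grad_names")])
```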
......@@ -2376,28 +2376,12 @@ def _find_op_path_(
return op_path
def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
"""
Backpropagate the gradients of targets to inputs.
Args:
targets(Tensor|list[Tensor]|tuple[Tensor]): The target Tensors
inputs(Tensor|list[Tensor]|tuple[Tensor]): The input Tensors
target_gradients (Tensor|list[Tensor]|tuple[Tensor], optional): The gradient Tensors
of targets, which have the same shape as targets. If None, ones will
be created for them.
no_grad_set(set[Tensor|str], optional): Set of Tensors or Tensor.names in the :ref:`api_guide_Block_en` 0 whose gradients
should be ignored. All Tensors with
`stop_gradient=True` from all blocks will
be automatically added into this set.
If this parameter is not None, the Tensors or Tensor.names in this set will be added to the default set.
Default: None.
Return:
(list[Tensor]): A list of gradients for inputs
If an input does not affect targets, the corresponding gradient Tensor
will be None
"""
def calc_gradient_helper(
targets, inputs, target_gradients=None, no_grad_set=None
):
'''
Calculate gradient and return grad_info_map
'''
targets = _as_list(targets)
inputs = _as_list(inputs)
target_gradients = _as_list(target_gradients)
......@@ -2510,7 +2494,11 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
_append_backward_vars_(block, fwd_op_num, grad_to_var, grad_info_map)
prog._sync_with_cpp()
return grad_info_map
def _get_grad_vars(grad_info_map, inputs):
inputs = _as_list(inputs)
grad_vars = []
for input_var in inputs:
if input_var.name not in grad_info_map:
......@@ -2520,6 +2508,43 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
grad_block = grad_info[1]
grad_var = grad_block.var(grad_info[0])
grad_vars.append(grad_var)
return grad_vars
def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
"""
Backpropagate the gradients of targets to inputs.
Args:
targets(Tensor|list[Tensor]|tuple[Tensor]): The target Tensors
inputs(Tensor|list[Tensor]|tuple[Tensor]): The input Tensors
target_gradients (Tensor|list[Tensor]|tuple[Tensor], optional): The gradient Tensors
of targets, which have the same shape as targets. If None, ones will
be created for them.
no_grad_set(set[Tensor|str], optional): Set of Tensors or Tensor.names in the :ref:`api_guide_Block_en` 0 whose gradients
should be ignored. All Tensors with
`stop_gradient=True` from all blocks will
be automatically added into this set.
If this parameter is not None, the Tensors or Tensor.names in this set will be added to the default set.
Default: None.
Return:
(list[Tensor]): A list of gradients for inputs
If an input does not affect targets, the corresponding gradient Tensor
will be None
"""
# NOTE: If you want to modify the logic of calc_gradient, please modify
# it inside the calc_gradient_helper and _get_grad_vars functions
# to ensure the correctness of dy2st mode.
grad_info_map = calc_gradient_helper(
targets,
inputs,
target_gradients=target_gradients,
no_grad_set=no_grad_set,
)
grad_vars = _get_grad_vars(grad_info_map, inputs)
if len(grad_vars) == 1:
return grad_vars[0]
......
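For reference, splitting `calc_gradient` into `calc_gradient_helper` and `_get_grad_vars` above leaves the public behavior unchanged; the usual static-graph entry point still returns `@GRAD`-suffixed variables. A short sketch (the printed gradient name follows the usual convention and is an assumption):

```python
import paddle

paddle.enable_static()
main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
    x.stop_gradient = False
    y = paddle.tanh(x)
    # paddle.static.gradients routes through calc_gradient internally.
    (x_grad,) = paddle.static.gradients([y], [x])
    print(x_grad.name)  # conventionally "x@GRAD"
```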
......@@ -83,7 +83,8 @@ class TestDropoutOp(OpTest):
self.check_output(check_prim=True)
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out', check_prim=True)
# In dy2st mode x_grad is currently [], so set check_prim=False
self.check_grad(['X'], 'Out', check_prim=False)
class TestDropoutOpInput1d(OpTest):
......@@ -107,7 +108,8 @@ class TestDropoutOpInput1d(OpTest):
self.check_output(check_prim=True)
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out', check_prim=True)
# In dy2st mode x_grad is currently [], so set check_prim=False
self.check_grad(['X'], 'Out', check_prim=False)
class TestDropoutOp2(TestDropoutOp):
......@@ -283,7 +285,8 @@ class TestDropoutOpWithSeed(OpTest):
self.check_output(check_prim=True)
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out', max_relative_error=0.05, check_prim=True)
# In dy2st mode x_grad is currently [], so set check_prim=False
self.check_grad(['X'], 'Out', max_relative_error=0.05, check_prim=False)
@unittest.skipIf(
......
......@@ -134,6 +134,8 @@ class TestRunProgram(unittest.TestCase):
['Fake_var@GRAD'],
'out_grad_names',
[out.name + '@GRAD'],
'x_grad_names',
[x_t.name + '@GRAD', y_t.name + '@GRAD'],
]
use_interpretorcore = True
......
......@@ -254,6 +254,8 @@ class RunProgramOpTest(unittest.TestCase):
[p.name + '@GRAD' for p in inputs['Params']],
'out_grad_names',
[out.name + '@GRAD' for out in outputs['Out']],
'x_grad_names',
[p.name + '@GRAD' for p in inputs['X']],
)
)
......@@ -303,6 +305,8 @@ class RunProgramOpTest(unittest.TestCase):
[p.name + '@GRAD' for p in inputs['Params']],
'out_grad_names',
[out.name + '@GRAD' for out in outputs['Out']],
'x_grad_names',
[p.name + '@GRAD' for p in inputs['X']],
)
)
......
......@@ -21,7 +21,7 @@ from paddle import _legacy_C_ops
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
from paddle.fluid import backward, core, framework, program_guard
from paddle.fluid.compiler import BuildStrategy
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.data_feeder import check_type, convert_dtype
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.framework import _apply_pass
from paddle.optimizer.lr import LRScheduler
......@@ -29,9 +29,8 @@ from paddle.optimizer.lr import LRScheduler
from . import logging_utils
from .utils import (
RETURN_NO_VALUE_MAGIC_NUM,
_out_grad_names,
_param_grad_names,
backend_guard,
construct_grad_names,
)
__all__ = []
......@@ -208,6 +207,7 @@ class PartialProgramLayer:
self._scope_cache = {}
self._hooker = None
self._backend = kwargs.get('backend', None)
self._grad_var_names = {}
def __call__(self, inputs):
"""
......@@ -443,23 +443,11 @@ class PartialProgramLayer:
def _infer_pure_fp16_program_id(self):
return paddle.utils._hash_with_id(self._infer_pure_fp16_program, self)
@LazyInitialized
def _param_grad_names(self):
return _param_grad_names(self._train_program.desc, self._params)
def get_forward_end_op_idx(self, program):
return self._forward_end_index_map[
paddle.utils._hash_with_id(program, self)
]
@LazyInitialized
def _out_grad_names(self):
return _out_grad_names(
self._train_program.desc,
self.get_forward_end_op_idx(self._train_program),
len(self._outputs.var_ids),
)
@property
def program(self):
"""
......@@ -649,7 +637,33 @@ class PartialProgramLayer:
if targets:
start_idx = len(program.block(0).ops) + len(self._outputs.tolist())
with backend_guard(self._backend):
backward.gradients(targets=targets, inputs=[])
check_type(
targets,
'targets',
(framework.Variable, list, tuple),
'paddle.static.gradients',
)
grad_info_map = backward.calc_gradient_helper(
targets=targets, inputs=[]
)
x_vars = [
program.block(0).var(var.name)
for var in self._inputs
if isinstance(var, framework.Variable)
]
param_vars = [
program.block(0).var(param.name) for param in self._params
]
out_vars = [
program.block(0).var(var.name)
for var in self._outputs
if isinstance(var, framework.Variable)
]
self._grad_var_names = construct_grad_names(
grad_info_map, x_vars, param_vars, out_vars
)
if self._hooker:
program, start_idx = self._hooker.after_append_backward(
......@@ -720,9 +734,11 @@ class PartialProgramLayer:
attrs.extend(
(
'param_grad_names',
self._param_grad_names,
self._grad_var_names.get('param', []),
'out_grad_names',
self._out_grad_names,
self._grad_var_names.get('out', []),
'x_grad_names',
self._grad_var_names.get('x', []),
)
)
if self._cuda_graph_capture_mode:
......@@ -761,9 +777,9 @@ class PartialProgramLayer:
backward_end_op_index = whole_program.desc.block(0).op_size()
# For Backward process in CINN, all param@GRAD should be skipped for GC, because
# they will be shared in scope and used by optimizer.
backward_skip_vars = (
self._parse_skip_gc_vars(whole_program) + self._param_grad_names
)
backward_skip_vars = self._parse_skip_gc_vars(
whole_program
) + self._grad_var_names.get('param', [])
backward_builded_program = add_build_strategy_for(
whole_program,
backward_start_op_index,
......
......@@ -32,7 +32,7 @@ import numpy as np
import paddle
from paddle import fluid # noqa: F401
from paddle.fluid import core, unique_name
from paddle.fluid import backward, core, framework, unique_name
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
......@@ -1462,61 +1462,6 @@ def create_name_str(name_ids):
return "(%s, )" % ','.join(names_str)
def _param_grad_names(program_desc, params):
"""
Parse PARAM@GRAD name from original train and infer program.
"""
names = []
# NOTE: `names` and `params` must be in the same order so that
# the param grad name can be set correctly in the run_program.
for param in params:
candidate = []
for var in program_desc.block(0).all_vars():
var_name = var.name()
if param.name not in var_name:
continue
suf_count = var_name.count(GRAD_SUFFIX)
if suf_count > 0:
suffix = param.name + GRAD_SUFFIX * suf_count
pre_count = var_name.count(GRAD_PREFIX)
if GRAD_PREFIX * pre_count + suffix == var_name:
candidate.append(var_name)
if candidate:
names.append(
max(
candidate,
key=lambda name: name.count(GRAD_PREFIX)
if GRAD_PREFIX in name
else name.count(GRAD_SUFFIX),
)
)
else:
names.append(param.name + GRAD_SUFFIX)
return names
def _out_grad_names(program_desc, fwd_end_op_index, out_size):
"""
Parse Out@GRAD name from original train and infer program.
"""
names = []
for i in range(
fwd_end_op_index,
min(fwd_end_op_index + out_size, program_desc.block(0).op_size()),
):
op = program_desc.block(0).op(i)
# If prim forward is enabled, fill_any_like will be decomposed into fill_constant.
if core._is_fwd_prim_enabled():
target = ('fill_any_like', 'fill_constant')
else:
target = 'fill_any_like'
if op.type() in target:
var_name = op.output('Out')[0]
names.append(var_name)
return names
def prim_or_cinn_is_enabled(build_strategy, backend):
if backend == 'CINN':
return True
......@@ -1571,3 +1516,19 @@ def backend_guard(backend):
finally:
core._set_prim_forward_enabled(orign_fwd)
core._set_prim_backward_enabled(orign_bwd)
def construct_grad_names(grad_info_map, x_vars, param_vars, out_vars):
grad_var_names = {}
fn = (
lambda grad_var: grad_var.name
if isinstance(grad_var, framework.Variable)
else framework.EMPTY_VAR_NAME
)
x_grad_vars = backward._get_grad_vars(grad_info_map, x_vars)
grad_var_names['x'] = list(map(fn, x_grad_vars))
param_grad_vars = backward._get_grad_vars(grad_info_map, param_vars)
grad_var_names['param'] = list(map(fn, param_grad_vars))
out_grad_vars = backward._get_grad_vars(grad_info_map, out_vars)
grad_var_names['out'] = list(map(fn, out_grad_vars))
return grad_var_names
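A usage sketch of the new helper, assuming static-graph mode and the post-commit import path `paddle.jit.dy2static.utils.construct_grad_names`; the layer, shapes, and variable names are illustrative:

```python
import paddle
from paddle.fluid import backward
from paddle.jit.dy2static.utils import construct_grad_names

paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[4, 8], dtype="float32")
    x.stop_gradient = False
    linear = paddle.nn.Linear(8, 2)
    out = linear(x)
    loss = paddle.mean(out)

    # Append backward ops and collect the grad info map, mirroring what the
    # PartialProgramLayer changes above now do.
    grad_info_map = backward.calc_gradient_helper(targets=[loss], inputs=[])
    names = construct_grad_names(
        grad_info_map,
        x_vars=[x],
        param_vars=list(main_prog.global_block().all_parameters()),
        out_vars=[out],
    )
    # A dict with keys 'x', 'param', and 'out'; variables that receive no
    # gradient are mapped to framework.EMPTY_VAR_NAME.
    print(names)
```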
......@@ -20,16 +20,16 @@ import numpy as np
import paddle
from paddle import _legacy_C_ops
from paddle.fluid import backward, core, framework, unique_name
from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.framework import OpProtoHolder, _non_static_mode
from paddle.jit.dy2static.partial_program import (
LazyInitialized,
add_build_strategy_for,
)
from paddle.jit.dy2static.utils import construct_grad_names
from paddle.nn.layer import layers
from .dy2static.utils import _out_grad_names, _param_grad_names
__all__ = []
INFER_MODEL_SUFFIX = ".pdmodel"
......@@ -349,6 +349,7 @@ class _ProgramHolder:
self._train_program_desc = self._append_backward_desc(
self._infer_program_desc
)
self._grad_var_names = {}
# forward:
@switch_to_static_graph
......@@ -419,6 +420,10 @@ class _ProgramHolder:
def scope(self):
return self._inner_scope
@property
def grad_var_names(self):
return self._grad_var_names
def _preprocess(self, program_desc):
# rename persistable variables of 'program_desc'
list_persistable_var = _get_persistable_var_names(program_desc)
......@@ -599,7 +604,29 @@ class _ProgramHolder:
targets.append(program.global_block().var(out.name()))
# 3. append backward
backward.gradients(targets=targets, inputs=[])
check_type(
targets,
'targets',
(framework.Variable, list, tuple),
'paddle.static.gradients',
)
grad_info_map = backward.calc_gradient_helper(
targets=targets, inputs=[]
)
x_vars = [
program.block(0).var(desc.name()) for desc in self._input_descs
]
param_vars = [
program.block(0).var(name) for name in self._persistable_names
]
out_vars = [
program.block(0).var(desc.name()) for desc in self._output_descs
]
self._grad_var_names = construct_grad_names(
grad_info_map, x_vars, param_vars, out_vars
)
return program.desc
......@@ -964,9 +991,11 @@ def _run_dygraph(instance, input, program_holder):
attrs.extend(
(
'param_grad_names',
_param_grad_names(trace_program, persistable_vars),
program_holder.grad_var_names.get('param', []),
'out_grad_names',
_out_grad_names(trace_program, end_op_index, len(output_vars)),
program_holder.grad_var_names.get('out', []),
'x_grad_names',
program_holder.grad_var_names.get('x', []),
)
)
......
......@@ -15,6 +15,8 @@
import unittest
import numpy as np
import paddle
from paddle import ParamAttr
from paddle.nn import BatchNorm, Linear
......@@ -65,5 +67,99 @@ class TestGradNameParse(unittest.TestCase):
opt.minimize(loss)
def tanh_high_order_grad(x):
y = paddle.tanh(x)
return paddle.grad(y, x, create_graph=True)[0]
class TestTanhHighOrderGrad(unittest.TestCase):
def setUp(self):
self.func = tanh_high_order_grad
x1 = paddle.ones((1,))
x1.stop_gradient = False
self.dy_input = (x1,)
self.dy_grad_input = (x1,)
x2 = paddle.ones((1,))
x2.stop_gradient = False
self.dy2st_input = (x2,)
self.dy2st_grad_input = (x2,)
def test_run(self):
try:
dy_out = self.func(*self.dy_input)
dy_grad = paddle.grad(dy_out, self.dy_grad_input)
except:
dy_grad = [None for i in self.dy_grad_input]
dy_grad = [
t.numpy() if isinstance(t, paddle.Tensor) else t for t in dy_grad
]
dy2st_out = paddle.jit.to_static(self.func)(*self.dy2st_input)
dy2st_grad = paddle.grad(dy2st_out, self.dy2st_grad_input)
dy2st_grad = [
t.numpy() if isinstance(t, paddle.Tensor) else t for t in dy2st_grad
]
np.testing.assert_equal(dy_grad, dy2st_grad)
dy_input_grad = [
t.grad.numpy() if isinstance(t.grad, paddle.Tensor) else None
for t in self.dy_input
]
dy2st_input_grad = [
t.grad.numpy() if isinstance(t.grad, paddle.Tensor) else None
for t in self.dy2st_input
]
np.testing.assert_equal(dy_input_grad, dy2st_input_grad)
def matmul_high_order_grad(x, y):
z = paddle.matmul(x, y)
g = paddle.grad(z, [x, y], create_graph=False)
return g[0]
class TestMatMulHighOrderGrad1(TestTanhHighOrderGrad):
def setUp(self):
self.func = matmul_high_order_grad
x1 = paddle.ones([1])
x1.stop_gradient = False
y1 = paddle.ones([1])
y1.stop_gradient = False
self.dy_input = (x1, y1)
self.dy_grad_input = (x1,)
x2 = paddle.ones([1])
x2.stop_gradient = False
y2 = paddle.ones([1])
y2.stop_gradient = False
self.dy2st_input = (x2, y2)
self.dy2st_grad_input = (x2,)
class TestMatMulHighOrderGrad2(TestTanhHighOrderGrad):
def setUp(self):
self.func = matmul_high_order_grad
x = np.random.randn(5, 5)
y = np.random.randn(5, 5)
x1 = paddle.to_tensor(x)
x1.stop_gradient = False
y1 = paddle.to_tensor(y)
y1.stop_gradient = True
self.dy_input = (x1, y1)
self.dy_grad_input = (x1,)
x2 = paddle.to_tensor(x)
x2.stop_gradient = False
y2 = paddle.to_tensor(y)
y2.stop_gradient = True
self.dy2st_input = (x2, y2)
self.dy2st_grad_input = (x2,)
if __name__ == "__main__":
unittest.main()