Unverified commit b73594b4, authored by Xiaoxu Chen, committed by GitHub

【Prim】support higher order autodiff for dy2static+composite (#53171)

* [Dy2St] Fix x grad names when computing high order gradients

* Polish error msg

* Add inputs var to backward in dy2st

* Fix error

* Get grad names for backward API

* Fix save load

* Polish code

* Add ut

* [prim] fix bugs where optional grads were not supported in higher order autodiff

* [prim] remove duplicate fill_any_like caused by infershape_for_composite

* fix _strip_grad_suffix_ bugs in higher-order autodiff

* [prim] create output for test_static_prim.cc

---------
Co-authored-by: 0x45f <wangzhen45@baidu.com>
Parent 3e3297c7
......@@ -350,10 +350,7 @@ class CompositeGradOpMakerBase {
framework::VarDesc* SingleOutputGrad(const std::string& name) const {
auto* var = this->SingleForwardOutput(name);
if (!var) {
PADDLE_THROW(platform::errors::InvalidArgument(
"GetSingleOutputGrad for %s_grad faild, if it is Optional input,"
"please use GetOptionalSingleOutputGrad replaced. ",
name));
return nullptr;
}
auto var_name = var->Name();
auto grad_var_name = framework::GradVarName(var_name);
......@@ -371,7 +368,7 @@ class CompositeGradOpMakerBase {
return StaticCompositeContext::Instance().GetBlock()->FindVar(
grad_var_name);
} else {
return StaticCompositeContext::Instance().GetBlock()->Var(grad_var_name);
return nullptr;
}
}
......
......@@ -28,6 +28,8 @@ import warnings
from collections.abc import Sequence
import re
__all__ = [
'append_backward',
'gradients',
......@@ -459,10 +461,14 @@ def _strip_grad_suffix_(name):
"""
Strip the grad suffix from the given variable name
e.g. x@GRAD ==> x
x@GRAD@GRAD ==> x
y@GRAD@RENAME@1 ==> y
z@GRAD_slice_0@GRAD ==> z@GRAD_slice_0
"""
pos = name.find(core.grad_var_suffix())
new_name = name[:pos] if pos != -1 else name
pos = re.search(f'{core.grad_var_suffix()}$', name) or re.search(
f'{core.grad_var_suffix()}@', name
)
new_name = name[: pos.start()] if pos is not None else name
new_pos = name.rfind('grad/')
return new_name[new_pos + 5 :] if new_pos != -1 else new_name
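
A minimal standalone sketch of the suffix-stripping mapping documented in the docstring above. The helper name and the single-pattern regex are assumptions: it presumes core.grad_var_suffix() returns '@GRAD' and condenses the two re.search calls into one pattern.

import re

GRAD_SUFFIX = "@GRAD"  # assumed value of core.grad_var_suffix()

def strip_grad_suffix(name):
    # Cut at the first grad suffix that either ends the name or is
    # immediately followed by another '@...' decoration.
    pos = re.search(f"{GRAD_SUFFIX}(@|$)", name)
    new_name = name[: pos.start()] if pos else name
    # As in the original function, also drop anything up to a 'grad/' prefix.
    new_pos = name.rfind("grad/")
    return new_name[new_pos + 5 :] if new_pos != -1 else new_name

assert strip_grad_suffix("x@GRAD") == "x"
assert strip_grad_suffix("x@GRAD@GRAD") == "x"
assert strip_grad_suffix("y@GRAD@RENAME@1") == "y"
assert strip_grad_suffix("z@GRAD_slice_0@GRAD") == "z@GRAD_slice_0"
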
......@@ -1343,15 +1349,17 @@ def _append_backward_ops_(
if core._is_bwd_prim_enabled():
composite_block = program.clone().current_block()
# Infer shape for operators whose output haven't been created.
# Create outputs and infer shapes for operators whose outputs haven't
# been created.
for op in composite_block.ops:
if not all(
tuple(
composite_block._find_var_recursive(arg)
for arg in op.output_arg_names
)
):
infershape_for_composite(composite_block, op.desc)
for name in op.output_arg_names:
if not (
composite_block.desc.has_var_recursive(name.encode())
or name == core.empty_var_name()
):
composite_block.create_var(name=name)
op.desc.infer_var_type(composite_block.desc)
op.desc.infer_shape(composite_block.desc)
# add grad_op_desc by reversed ops
for op in reversed(ops):
......@@ -1492,6 +1500,15 @@ def _append_backward_ops_(
or name in input_grad_names_set
)
is_append_grad = False
# NOTE: In primitive mode, the intermediate variables generated by
# decomposing a raw grad op do not follow the 'XX@GRAD' naming rule,
# so they would be pruned by the current pruning logic. For simplicity,
# we treat all primitive operators decomposed from one raw operator as
# a single unit and keep the pruning logic consistent with the existing
# one. The drawback is that some primitive operators may escape
# pruning, which still needs to be fixed.
# FIXME: Optimize pruning logic from the perspective of the whole graph.
input_grad_names = []
for op_desc in grad_op_desc:
input_grad_names += [
......@@ -1499,20 +1516,20 @@ def _append_backward_ops_(
for name in op_desc.input_arg_names()
if is_grad_name(name)
]
# some code of gradient ops, like increment, are not very
# standard, there is no @GRAD in these ops' inputs.
if len(input_grad_names) == 0:
is_append_grad = True
break
for op_desc in grad_op_desc:
# some code of gradient ops, like increment, are not very
# standard, there is no @GRAD in these ops' inputs.
continue
if _some_in_set_(input_grad_names, input_grad_names_set):
if _some_in_set_(input_grad_names, input_grad_names_set):
is_append_grad = True
for op_desc in grad_op_desc:
grad_op_descs.append(op_desc)
is_append_grad = True
for name in op_desc.output_arg_names():
input_grad_names_set.add(name)
if is_append_grad:
grad_to_var.update(op_grad_to_var)
else:
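
A rough sketch of the pruning rule described in the NOTE above: all primitive ops decomposed from one raw grad op are treated as a single group, their gradient-style inputs are collected together, and the whole group is either appended or pruned. The helper below is hypothetical and stands in for the real _some_in_set_ check over op_desc objects.

def keep_decomposed_group(grad_op_descs, input_grad_names_set, is_grad_name):
    # Collect every gradient-style input across the whole decomposed group.
    input_grad_names = [
        name
        for op_desc in grad_op_descs
        for name in op_desc.input_arg_names()
        if is_grad_name(name)
    ]
    # Grad ops such as increment have no @GRAD inputs at all; keep them.
    if not input_grad_names:
        return True
    # Otherwise keep the group only if it consumes a gradient we need.
    return any(name in input_grad_names_set for name in input_grad_names)
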
......@@ -1774,17 +1791,19 @@ def infershape_for_composite(block, grad_op_desc):
op_desc.check_attrs()
op_desc.infer_var_type(block.desc)
op_desc.infer_shape(block.desc)
for arg in op_desc.output_arg_names():
if arg in new_vars:
_infer_var_data_type_shape_(arg, block)
grad_op_desc.copy_from(op_desc)
# NOTE: Some operator doesn't infer dtype correctly, this patch set the
# grad_var dtype same with corresponding forward variable.
for arg in grad_op_desc.output_arg_names():
if arg in new_vars:
_infer_var_data_type_shape_(arg, block)
if not framework.OpProtoHolder.instance().has_op_proto(grad_op_desc.type()):
# NOTE: Some raw fluid grad operators that haven't been decomposed may not
# implement the InferVarType method, e.g. elementwise_xx_grad, which can
# make the dtype or shape of the corresponding cotangent incorrect. This
# patch sets the cotangent dtype and shape to match the corresponding
# forward variable. For primitive operators, PR#52818 ensures that all
# InferVarType methods are executed correctly, so we skip this patch
# for primitive operators.
for arg in grad_op_desc.output_arg_names():
if arg in new_vars:
_infer_var_data_type_shape_(arg, block)
def _rename_grad_(
......
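
A condensed reading of the new guard in infershape_for_composite above, as a hedged sketch: the dtype/shape patch only runs for grad ops without a registered op proto (raw fluid grad ops such as elementwise_xx_grad), while ops with a proto, including primitive grad ops, are skipped per the NOTE referencing PR#52818. The free function and injected callables below are hypothetical.

def patch_cotangent_dtype_if_needed(
    grad_op_desc, new_vars, block, has_op_proto, infer_var_dtype_shape
):
    # Ops with a registered proto (including primitive grad ops) are skipped.
    if has_op_proto(grad_op_desc.type()):
        return
    # Raw fluid grad ops may leave freshly created cotangent vars with a wrong
    # dtype/shape; copy dtype/shape from the corresponding forward variable.
    for arg in grad_op_desc.output_arg_names():
        if arg in new_vars:
            infer_var_dtype_shape(arg, block)
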
......@@ -20,6 +20,8 @@ import paddle
from paddle import fluid
from paddle.fluid.backward import calc_gradient
paddle.enable_static()
class TestCalcGradient(unittest.TestCase):
def test_calc_gradient(self):
......
......@@ -206,6 +206,11 @@ TEST(StaticPrim, TanhBackwardComposite) {
auto* forward_opdesc = target_block->AllOps()[0];
std::unordered_map<std::string, std::string> grad_to_var;
std::vector<framework::BlockDesc*> grad_sub_block;
Tensor out_grad = prim::empty<prim::DescTensor>(
shape, phi::DataType::FLOAT32, paddle::Place());
framework::VarDesc* out_grad_desc =
static_cast<prim::DescTensor*>(out_grad.impl().get())->get_ptr();
target_block->RenameVar(out_grad_desc->Name(), "b@GRAD");
std::vector<std::unique_ptr<framework::OpDesc>> grad_ops =
std::move(framework::OpInfoMap::Instance()
.Get(forward_opdesc->Type())
......@@ -288,6 +293,11 @@ TEST(StaticCompositeGradMaker, TestMutiInputMethod) {
auto* forward_opdesc = target_block->AllOps()[0];
std::unordered_map<std::string, std::string> grad_to_var;
std::vector<framework::BlockDesc*> grad_sub_block;
Tensor out_grad = prim::empty<prim::DescTensor>(
shape, phi::DataType::FLOAT32, paddle::Place());
framework::VarDesc* out_grad_desc =
static_cast<prim::DescTensor*>(out_grad.impl().get())->get_ptr();
target_block->RenameVar(out_grad_desc->Name(), "out@GRAD");
auto test = TestCompositeGradMaker(*forward_opdesc,
std::unordered_set<std::string>(),
&grad_to_var,
......@@ -353,6 +363,19 @@ TEST(StaticCompositeGradMaker, TestMutiOutputMethod) {
auto* forward_opdesc = target_block->AllOps()[0];
std::unordered_map<std::string, std::string> grad_to_var;
std::vector<framework::BlockDesc*> grad_sub_block;
Tensor out1_grad = prim::empty<prim::DescTensor>(
shape, phi::DataType::FLOAT32, paddle::Place());
framework::VarDesc* out1_grad_desc =
static_cast<prim::DescTensor*>(out1_grad.impl().get())->get_ptr();
target_block->RenameVar(out1_grad_desc->Name(), "out1@GRAD");
Tensor out2_grad = prim::empty<prim::DescTensor>(
shape, phi::DataType::FLOAT32, paddle::Place());
framework::VarDesc* out2_grad_desc =
static_cast<prim::DescTensor*>(out2_grad.impl().get())->get_ptr();
target_block->RenameVar(out2_grad_desc->Name(), "out2@GRAD");
auto test = TestCompositeGradMaker(*forward_opdesc,
std::unordered_set<std::string>(),
&grad_to_var,
......
......@@ -67,6 +67,11 @@ class TestGetGradOpDescPrimEnabled(unittest.TestCase):
for n, vs in cls.outputs.items()
},
)
for _, outs in cls.outputs.items():
for out in outs:
block.create_var(name=out + core.grad_var_suffix())
cls.fwd = block.ops[0].desc
@classmethod
......
......@@ -46,6 +46,11 @@ class TestGetGradOpDescPrimEnabled(unittest.TestCase):
for n, vs in self.outputs.items()
},
)
for _, outs in self.outputs.items():
for out in outs:
block.create_var(name=out + core.grad_var_suffix())
self.fwd = block.ops[0].desc
def tearDown(self):
......