Unverified commit d8a10977, authored by Zhanlue Yang, committed by GitHub

[DoubleGrad PR #8] Enabled triple grads for sigmoid and matmul (#41387)

* [Refactor] refactored eager_gen.py PR #2

* [DoubleGrad PR #1] Decoupled code generation logics for Dygraph ForwardFunctions and GradNodes

* Fixed minor issue

* Adjusted logics of GenerateNodeCreationCodes and GenerateForwardDefinition

* Fixed issues

* Supported higher-order grad node generation

* [DoubleGrad PR #4] Supported higher-order GradNode generation

* [DoubleGrad #4] Bug Fixes to Double Grad Node Generation

* Fixed yaml typo

* Fixed yaml typo

* fixed minor issues

* [DoubleGrad PR #5] Enabled gradient computations for grad_tensors passed to paddle.grad()

* Fixed minor issue

* Fixed CI-Inference issue

* Fixed CI-inference issues

* [DoubleGrad PR #7] paddle.grad() to copy backward graph before backward run

* Fixed minor issues

* Fixed issue with backward graph construction logic

* Fixed implementation issues with backward graph reconstruction

* Fixed unittest issue

* Fixed issues

* [DoubleGrad PR #8] Enabled triple grads for sigmoid and matmul

* Fixed issues with phi kernel

* Added triple grad test case

* Fixed minor issue
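As a rough illustration of what this change enables (a hedged sketch, not taken from the PR; the eager-mode guard import may differ between versions), a third-order gradient can now be requested by nesting paddle.grad calls with create_graph=True:

    import paddle
    from paddle.fluid.framework import _test_eager_guard  # same guard the new unit test uses

    with _test_eager_guard():
        x = paddle.to_tensor([0.5], stop_gradient=False)
        y = paddle.nn.functional.sigmoid(x)
        (g1,) = paddle.grad([y], [x], create_graph=True)   # dy/dx
        (g2,) = paddle.grad([g1], [x], create_graph=True)  # d2y/dx2
        (g3,) = paddle.grad([g2], [x])                      # d3y/dx3, the triple grad enabled here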
Parent 84e8ae77
@@ -21,8 +21,10 @@ import os
########################
### Global Variables ###
########################
ops_to_fill_zero_for_empty_grads = set(
["split_grad", "rnn_grad", "matmul_double_grad"])
ops_to_fill_zero_for_empty_grads = set([
"split_grad", "rnn_grad", "matmul_double_grad", "matmul_triple_grad",
"sigmoid_triple_grad"
])
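# Note (an assumption, inferred from the set name and the optional grad inputs
# declared in backward.yaml below): grads entering these GradNodes may be empty
# because higher-order kernels such as matmul_triple_grad and
# sigmoid_triple_grad take optional grad inputs; the generated code fills such
# empty slots with zeros of the matching shape before calling the kernel.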
# For API dispatch used at python-level
# { op_name : [arg_name, ...] }
@@ -171,12 +173,6 @@ def GetForwardFunctionName(string):
return f"{string}_final_state_dygraph_function"
def TransformGradVarNameForDoubleGradGeneration(string):
if IsGradName(string):
string = "grad_" + string[:-5]
return string
def GetIndent(num):
tab = " "
return "".join([tab for i in range(num)])
......
@@ -31,7 +31,6 @@ from codegen_utils import ParseYamlArgs, ParseYamlReturns, ParseYamlForwardFromB
from codegen_utils import ParseYamlForward, ParseYamlBackward
from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase
from codegen_utils import ops_to_fill_zero_for_empty_grads
from codegen_utils import TransformGradVarNameForDoubleGradGeneration
from codegen_utils import AssertMessage, GetIndent
@@ -483,10 +482,8 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
orig_forward_returns_list = self.orig_forward_returns_list
for i in range(len(forward_inputs_list)):
forward_input_name = forward_inputs_list[i][0]
forward_input_type = forward_inputs_list[i][1]
forward_input_pos = forward_inputs_list[i][2]
orig_input_name = orig_forward_inputs_list[i][0]
orig_input_type = orig_forward_inputs_list[i][1]
orig_input_pos = orig_forward_inputs_list[i][2]
@@ -496,11 +493,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
forward_input_pos, orig_input_pos)
for i in range(len(forward_attrs_list)):
orig_attr_name = orig_forward_attrs_list[i][0]
orig_attr_type = orig_forward_attrs_list[i][1]
orig_attr_default = orig_forward_attrs_list[i][2]
orig_attr_pos = orig_forward_attrs_list[i][3]
forward_attr_name = forward_attrs_list[i][0]
forward_attr_type = forward_attrs_list[i][1]
forward_attr_default = forward_attrs_list[i][2]
forward_attr_pos = forward_attrs_list[i][3]
@@ -1133,11 +1128,20 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
DygraphFunctionGeneratorBase.__init__(self, forward_api_contents,
grad_api_contents, namespace)
# Record name mapping from forward_api_name to grad_api_names
self.to_next_grad_name_mapping = {} # {name : name}
# Generated Results
self.node_declaration_str = ""
self.node_definition_str = ""
self.next_grad_api_contents = next_grad_api_contents
def TransformToNextGradName(self, string):
name_mapping = self.to_next_grad_name_mapping
if string in name_mapping.keys():
return name_mapping[string]
return string
def ResetOptionalInputs(self):
namespace = self.namespace
grad_api_contents = self.grad_api_contents
@@ -1147,6 +1151,22 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
self.optional_inputs = base_generator.optional_inputs
def RecordGrad2NextGradNameMapping(self, next_node_generator):
next_orig_inputs_list = next_node_generator.orig_forward_inputs_list
next_orig_returns_list = next_node_generator.orig_forward_returns_list
next_forward_inputs_list = next_node_generator.forward_inputs_list
next_forward_returns_list = next_node_generator.forward_returns_list
for i in range(len(next_orig_inputs_list)):
grad_name = next_orig_inputs_list[i][0]
next_forward_name = next_forward_inputs_list[i][0]
self.to_next_grad_name_mapping[grad_name] = next_forward_name
for i in range(len(next_orig_returns_list)):
grad_ret_name = next_orig_returns_list[i][0]
next_ret_name = next_forward_returns_list[i][0]
self.to_next_grad_name_mapping[grad_ret_name] = next_ret_name
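# Illustrative example (an assumption about typical contents, not generated
# output): for matmul, the double-grad yaml takes args
# (x, y, grad_out, grad_x_grad, grad_y_grad), while the triple-grad entry
# re-declares that signature as its forward using
# (x, y, fwd_grad_out, fwd_grad_grad_x, fwd_grad_grad_y), so the recorded
# mapping would look roughly like
#   {"grad_out": "fwd_grad_out", "grad_x_grad": "fwd_grad_grad_x", ...}
# and replaces the removed string rule
# TransformGradVarNameForDoubleGradGeneration ("x_grad" -> "grad_x").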
def GenerateHigherOrderNodeCreationCode(self):
namespace = self.namespace
grad_api_contents = self.grad_api_contents
@@ -1164,6 +1184,8 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
next_node_generator.GenerateNodeCreationCodes()
grad_node_creation_str = next_node_generator.node_creation_str
self.RecordGrad2NextGradNameMapping(next_node_generator)
return grad_node_creation_str
def GenerateNodeDeclaration(self):
@@ -1253,8 +1275,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
for name, (_, is_fwd_input,
grad_api_position), in backward_forward_inputs_map.items():
tensor_wrapper_name = GetSavedName(name)
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)
is_optional = (name in self.optional_inputs)
tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, this->shared_from_this());"
@@ -1274,8 +1295,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
# Grad Ins from grads
for name, (ttype, fwd_position,
grad_api_position) in backward_grad_inputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)
is_optional = (name in self.optional_inputs)
if IsPlainTensorType(ttype):
@@ -1316,8 +1336,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
num_outputs = len(backward_grad_outputs_map.keys())
for name, (ttype, fwd_position,
grad_api_position) in backward_grad_outputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)
if num_outputs == 1:
get_tensor_str = f"{indent}auto& {transformed_tensor_name} = grad_api_result;"
@@ -1339,8 +1358,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
compute_require_grad_args_list = ["trace_backward"]
for name, (ttype, pos,
grad_api_position) in backward_grad_inputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)
input_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name)
@@ -1358,8 +1376,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
# 2. Get TensorWrapper AutoGradMeta
for name, (ttype, _, pos), in backward_forward_inputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)
input_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name)
@@ -1382,8 +1399,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
outputs_autograd_meta_list = []
num_fwd_outputs = len(backward_grad_outputs_map.keys())
for name, (rtype, pos, _) in backward_grad_outputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)
output_autograd_meta_name = GetAutoGradMetaName(
transformed_tensor_name)
@@ -1417,8 +1433,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
returns_str = f"{indent}std::vector<std::vector<paddle::experimental::Tensor>> returns({slot_num_bwd_outputs});\n"
for name, (ttype, fwd_position,
grad_api_position) in backward_grad_outputs_map.items():
transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
name)
transformed_tensor_name = self.TransformToNextGradName(name)
# Infer Grad API Return Type
if num_bwd_outputs == 1:
@@ -1441,6 +1456,9 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
grad_node_name = GetGradNodeName(forward_api_name)
if len(grad_node_creation_str) == 0:
grad_node_creation_str = f"if(create_graph) VLOG(3) << \"Higher order grad node for {grad_node_name} has not been implemented yet.\";"
self.node_definition_str = GRAD_FUNCTION_TEMPLATE.format(
grad_node_name, fill_zero_str, get_grad_in_args_str, grad_node_name,
grad_function_call_str, get_outputs_str, inputs_autograd_meta_str,
@@ -1457,11 +1475,11 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
#####################
## Code Generation ##
#####################
self.GenerateNodeDeclaration()
# Higher-order GradNode generation
grad_node_creation_str = self.GenerateHigherOrderNodeCreationCode()
self.GenerateNodeDeclaration()
self.GenerateNodeDefinition(grad_node_creation_str)
......
@@ -206,6 +206,54 @@ void GeneralTernaryGradInferMeta(const MetaTensor& x,
dz->share_meta(z);
}
}
void GeneralQuaternaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& z,
const MetaTensor& k,
MetaTensor* dx,
MetaTensor* dy,
MetaTensor* dz,
MetaTensor* dk) {
if (dx) {
dx->share_meta(x);
}
if (dy) {
dy->share_meta(y);
}
if (dz) {
dz->share_meta(z);
}
if (dk) {
dk->share_meta(k);
}
}
void GeneralQuinaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& z,
const MetaTensor& k,
const MetaTensor& l,
MetaTensor* dx,
MetaTensor* dy,
MetaTensor* dz,
MetaTensor* dk,
MetaTensor* dl) {
if (dx) {
dx->share_meta(x);
}
if (dy) {
dy->share_meta(y);
}
if (dz) {
dz->share_meta(z);
}
if (dk) {
dk->share_meta(k);
}
if (dl) {
dl->share_meta(l);
}
}
void GeneralUnaryGradInferMeta(const MetaTensor& x, MetaTensor* dx) {
if (dx) {
......
@@ -96,6 +96,26 @@ void GeneralTernaryGradInferMeta(const MetaTensor& x,
MetaTensor* dy,
MetaTensor* dz);
void GeneralQuaternaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& z,
const MetaTensor& k,
MetaTensor* dx,
MetaTensor* dy,
MetaTensor* dz,
MetaTensor* dk);
void GeneralQuinaryGradInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& z,
const MetaTensor& k,
const MetaTensor& l,
MetaTensor* dx,
MetaTensor* dy,
MetaTensor* dz,
MetaTensor* dk,
MetaTensor* dl);
void GeneralUnaryGradInferMeta(const MetaTensor& x, MetaTensor* dx);
void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
......
@@ -125,18 +125,18 @@ void EluDoubleGradKernel(const Context& dev_ctx,
template <typename T, typename Context>
void SigmoidDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& ddx,
const DenseTensor& dout,
const DenseTensor& ddx,
DenseTensor* dout_new,
DenseTensor* ddout);
template <typename T, typename Context>
void SigmoidTripleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& ddx,
const DenseTensor& dout,
const DenseTensor& d_ddout,
const DenseTensor& ddx,
const DenseTensor& d_dout_new,
const DenseTensor& d_ddout,
DenseTensor* d_out_new,
DenseTensor* d_dout,
DenseTensor* d_ddx);
......
@@ -243,8 +243,8 @@ void LogitGradKernel(const Context& dev_ctx,
template <typename T, typename Context>
void SigmoidDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& ddx,
const DenseTensor& dout,
const DenseTensor& ddx,
DenseTensor* dout_new,
DenseTensor* ddout) {
if (dout_new) {
@@ -262,10 +262,10 @@ void SigmoidDoubleGradKernel(const Context& dev_ctx,
template <typename T, typename Context>
void SigmoidTripleGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& ddx,
const DenseTensor& dout,
const DenseTensor& d_ddout,
const DenseTensor& ddx,
const DenseTensor& d_dout_new,
const DenseTensor& d_ddout,
DenseTensor* d_out_new,
DenseTensor* d_dout,
DenseTensor* d_ddx) {
......
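For orientation, the quantities these sigmoid kernels exchange follow from applying the chain rule to the first-order rule dx = dout * out * (1 - out). The relations below are a sketch derived here (writing y for out); the actual phi functors may arrange the terms differently:

\[
\texttt{ddout} = \texttt{ddx}\, y(1-y), \qquad
\texttt{dout\_new} = \texttt{ddx}\,\texttt{dout}\,(1-2y)
\]

and, differentiating once more with incoming grads d_dout_new and d_ddout:

\[
\begin{aligned}
\texttt{d\_out\_new} &= -2\,\texttt{dout}\,\texttt{ddx}\,\texttt{d\_dout\_new} + (1-2y)\,\texttt{ddx}\,\texttt{d\_ddout},\\
\texttt{d\_dout} &= (1-2y)\,\texttt{ddx}\,\texttt{d\_dout\_new},\\
\texttt{d\_ddx} &= (1-2y)\,\texttt{dout}\,\texttt{d\_dout\_new} + y(1-y)\,\texttt{d\_ddout}.
\end{aligned}
\]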
@@ -139,13 +139,13 @@ KernelSignature TanhTripleGradOpArgumentMapping(
KernelSignature SigmoidDoubleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"sigmoid_double_grad", {"Out", "DDX", "DOut"}, {}, {"DOutNew", "DDOut"});
"sigmoid_double_grad", {"Out", "DOut", "DDX"}, {}, {"DOutNew", "DDOut"});
}
KernelSignature SigmoidTripleGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("sigmoid_triple_grad",
{"Out", "DDX", "DOut", "D_DDOut", "D_DOut_New"},
{"Out", "DOut", "DDX", "D_DOut_New", "D_DDOut"},
{},
{"D_OutNew", "D_DOut", "D_DDx"});
}
......
@@ -42,6 +42,68 @@ def random_var(size, low=-1, high=1, dtype='float32'):
return fluid.dygraph.to_variable(x_np)
class TestDygraphTripleGradMatmul(TestCase):
def test_matmul_triple_grad(self):
input_numpy = np.ones([3, 3]) * 2
with _test_eager_guard():
x = paddle.to_tensor(
input_numpy, stop_gradient=False, dtype='float32')
y = paddle.to_tensor(
input_numpy, stop_gradient=False, dtype='float32')
out = paddle.matmul(x, y, False, False)
new_out_g = paddle.to_tensor(
np.ones([3, 3]), stop_gradient=False, dtype='float32')
new_x_g, new_y_g = paddle.grad(
[out], [x, y], [new_out_g],
retain_graph=True,
create_graph=True)
new_x_g_g = paddle.to_tensor(
np.ones([3, 3]), stop_gradient=False, dtype='float32')
new_y_g_g = paddle.to_tensor(
np.ones([3, 3]), stop_gradient=False, dtype='float32')
new_a, new_b, new_c = paddle.grad(
[new_x_g, new_y_g], [x, y, new_out_g], [new_x_g_g, new_y_g_g],
retain_graph=True,
create_graph=True)
new_a.backward()
out_ref = np.ones([3, 3]) * 12.0
self.assertTrue(np.array_equal(out.numpy(), out_ref))
new_x_g_ref = np.ones([3, 3]) * 6.0
new_y_g_ref = np.ones([3, 3]) * 6.0
self.assertTrue(np.array_equal(new_x_g.numpy(), new_x_g_ref))
self.assertTrue(np.array_equal(new_y_g.numpy(), new_y_g_ref))
new_a_ref = np.ones([3, 3]) * 3.0
new_b_ref = np.ones([3, 3]) * 3.0
new_c_ref = np.ones([3, 3]) * 12.0
self.assertTrue(np.array_equal(new_a.numpy(), new_a_ref))
self.assertTrue(np.array_equal(new_b.numpy(), new_b_ref))
self.assertTrue(np.array_equal(new_c.numpy(), new_c_ref))
x_grad_ref = np.ones([3, 3]) * 0.0
self.assertTrue(np.array_equal(x.grad.numpy(), x_grad_ref))
y_grad_ref = np.ones([3, 3]) * 0.0
self.assertTrue(np.array_equal(y.grad.numpy(), y_grad_ref))
new_out_g_ref = np.ones([3, 3]) * 3.0
self.assertTrue(
np.array_equal(new_out_g.grad.numpy(), new_out_g_ref))
new_x_g_g_ref = np.ones([3, 3]) * 0.0
new_y_g_g_ref = np.ones([3, 3]) * 3.0
self.assertTrue(
np.array_equal(new_x_g_g.grad.numpy(), new_x_g_g_ref))
self.assertTrue(
np.array_equal(new_y_g_g.grad.numpy(), new_y_g_g_ref))
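# Rough NumPy sanity check (illustration only, not part of the test above) of
# where the expected constants in TestDygraphTripleGradMatmul come from, with
# x = y = 2 * ones(3, 3) and every gradient seed equal to ones(3, 3):
import numpy as np

_x = _y = 2.0 * np.ones([3, 3])
_seed = np.ones([3, 3])
_out = _x @ _y                      # every entry 3 * (2 * 2) = 12   -> out_ref
_new_x_g = _seed @ _y.T             # d(out)/dx, seeded with ones    -> 6  (new_x_g_ref)
_new_y_g = _x.T @ _seed             # d(out)/dy, seeded with ones    -> 6  (new_y_g_ref)
_new_a = _seed @ _seed.T            # grad of new_y_g w.r.t. x       -> 3  (new_a_ref)
_new_b = _seed.T @ _seed            # grad of new_x_g w.r.t. y       -> 3  (new_b_ref)
_new_c = _seed @ _y + _x @ _seed    # grad w.r.t. new_out_g          -> 12 (new_c_ref)
# new_a.backward() then leaves x.grad, y.grad and new_x_g_g.grad at zero, while
# new_out_g.grad = new_y_g_g.grad = ones @ ones = 3 * ones, matching the refs.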
class TestDygraphTripleGrad(TestCase):
def setUp(self):
self.sort_sum_gradient = False
......
@@ -706,6 +706,7 @@
param : [x, y, grad_out]
kernel :
func : matmul_double_grad
backward : matmul_triple_grad
optional : grad_x_grad, grad_y_grad
- backward_api : matmul_grad
@@ -719,6 +720,17 @@
func : matmul_grad
backward : matmul_double_grad
- backward_api : matmul_triple_grad
forward : matmul_double_grad (Tensor x, Tensor y, Tensor fwd_grad_out, Tensor fwd_grad_grad_x, Tensor fwd_grad_grad_y, bool transpose_x=false, bool transpose_y=false) -> Tensor(grad_x), Tensor(grad_y), Tensor(grad_grad_out)
args : (Tensor x, Tensor y, Tensor fwd_grad_out, Tensor fwd_grad_grad_x, Tensor fwd_grad_grad_y, Tensor grad_x_grad, Tensor grad_y_grad, Tensor grad_grad_out_grad, bool transpose_x=false, bool transpose_y=false)
output : Tensor(x_grad), Tensor(y_grad), Tensor(fwd_grad_out_grad), Tensor(fwd_grad_grad_x_grad), Tensor(fwd_grad_grad_y_grad)
infer_meta :
func : GeneralQuinaryGradInferMeta
param : [x, y, fwd_grad_out, fwd_grad_grad_x, fwd_grad_grad_y]
kernel :
func : matmul_triple_grad
optional : grad_x_grad, grad_y_grad, grad_grad_out_grad
- backward_api : matrix_power_grad
forward : matrix_power (Tensor x, int n) -> Tensor(out)
args : (Tensor x, Tensor out, Tensor out_grad, int n)
@@ -1090,6 +1102,17 @@
kernel :
func : sigmoid_cross_entropy_with_logits_grad
- backward_api : sigmoid_double_grad
forward : sigmoid_grad (Tensor out, Tensor fwd_grad_out) -> Tensor(grad_x)
args : (Tensor out, Tensor fwd_grad_out, Tensor grad_x_grad)
output : Tensor(out_grad), Tensor(fwd_grad_out_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [out, fwd_grad_out]
kernel :
func : sigmoid_double_grad
backward : sigmoid_triple_grad
- backward_api : sigmoid_grad
forward : sigmoid (Tensor x) -> Tensor(out)
args : (Tensor out, Tensor out_grad)
@@ -1099,6 +1122,17 @@
param : [out]
kernel :
func : sigmoid_grad
backward : sigmoid_double_grad
- backward_api : sigmoid_triple_grad
forward : sigmoid_double_grad (Tensor out, Tensor fwd_grad_out, Tensor grad_grad_x) -> Tensor(grad_out), Tensor(grad_grad_out)
args : (Tensor out, Tensor fwd_grad_out, Tensor grad_grad_x, Tensor grad_out_grad, Tensor grad_grad_out_grad)
output : Tensor(out_grad), Tensor(fwd_grad_out_grad), Tensor(grad_grad_x_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
param : [out, fwd_grad_out, grad_grad_x]
kernel :
func : sigmoid_triple_grad
- backward_api : silu_grad
forward : silu (Tensor x) -> Tensor(out)
......