diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py
index e16bcb187f85a7fc59fc06a0d8564841c384c7e2..21b6b882a6f15ebcb532c37643d394ede9b55646 100644
--- a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py
+++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py
@@ -21,8 +21,10 @@ import os
 ########################
 ### Global Variables ###
 ########################
-ops_to_fill_zero_for_empty_grads = set(
-    ["split_grad", "rnn_grad", "matmul_double_grad"])
+ops_to_fill_zero_for_empty_grads = set([
+    "split_grad", "rnn_grad", "matmul_double_grad", "matmul_triple_grad",
+    "sigmoid_triple_grad"
+])
 
 # For API dispatch used at python-level
 # { op_name : [arg_name, ...] }
@@ -171,12 +173,6 @@ def GetForwardFunctionName(string):
     return f"{string}_final_state_dygraph_function"
 
 
-def TransformGradVarNameForDoubleGradGeneration(string):
-    if IsGradName(string):
-        string = "grad_" + string[:-5]
-    return string
-
-
 def GetIndent(num):
     tab = "  "
     return "".join([tab for i in range(num)])
diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
index b2db256f6026a42b9d1c875842419b35a2d5b619..19e42e1bdf6406557e6a515fe658acc97d73fae7 100644
--- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
+++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
@@ -31,7 +31,6 @@ from codegen_utils import ParseYamlArgs, ParseYamlReturns, ParseYamlForwardFromB
 from codegen_utils import ParseYamlForward, ParseYamlBackward
 from codegen_utils import FunctionGeneratorBase, YamlGeneratorBase
 from codegen_utils import ops_to_fill_zero_for_empty_grads
-from codegen_utils import TransformGradVarNameForDoubleGradGeneration
 from codegen_utils import AssertMessage, GetIndent
 
 
@@ -483,10 +482,8 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
         orig_forward_returns_list = self.orig_forward_returns_list
 
         for i in range(len(forward_inputs_list)):
-            forward_input_name = forward_inputs_list[i][0]
             forward_input_type = forward_inputs_list[i][1]
             forward_input_pos = forward_inputs_list[i][2]
-            orig_input_name = orig_forward_inputs_list[i][0]
             orig_input_type = orig_forward_inputs_list[i][1]
             orig_input_pos = orig_forward_inputs_list[i][2]
 
@@ -496,11 +493,9 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
                 forward_input_pos, orig_input_pos)
 
         for i in range(len(forward_attrs_list)):
-            orig_attr_name = orig_forward_attrs_list[i][0]
             orig_attr_type = orig_forward_attrs_list[i][1]
             orig_attr_default = orig_forward_attrs_list[i][2]
             orig_attr_pos = orig_forward_attrs_list[i][3]
-            forward_attr_name = forward_attrs_list[i][0]
             forward_attr_type = forward_attrs_list[i][1]
             forward_attr_default = forward_attrs_list[i][2]
             forward_attr_pos = forward_attrs_list[i][3]
@@ -1133,11 +1128,20 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         DygraphFunctionGeneratorBase.__init__(self, forward_api_contents,
                                               grad_api_contents, namespace)
 
+        # Record name mapping from forward_api_name to grad_api_names
+        self.to_next_grad_name_mapping = {}  # {name : name}
+
         # Generated Results
         self.node_declaration_str = ""
         self.node_definition_str = ""
 
         self.next_grad_api_contents = next_grad_api_contents
 
+    def TransformToNextGradName(self, string):
+        name_mapping = self.to_next_grad_name_mapping
+        if string in name_mapping.keys():
+            return name_mapping[string]
+        return string
+
     def ResetOptionalInputs(self):
         namespace = self.namespace
         grad_api_contents = self.grad_api_contents
@@ -1147,6 +1151,22 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
 
         self.optional_inputs = base_generator.optional_inputs
 
+    def RecordGrad2NextGradNameMapping(self, next_node_generator):
+        next_orig_inputs_list = next_node_generator.orig_forward_inputs_list
+        next_orig_returns_list = next_node_generator.orig_forward_returns_list
+
+        next_forward_inputs_list = next_node_generator.forward_inputs_list
+        next_forward_returns_list = next_node_generator.forward_returns_list
+        for i in range(len(next_orig_inputs_list)):
+            grad_name = next_orig_inputs_list[i][0]
+            next_forward_name = next_forward_inputs_list[i][0]
+            self.to_next_grad_name_mapping[grad_name] = next_forward_name
+
+        for i in range(len(next_orig_returns_list)):
+            grad_ret_name = next_orig_returns_list[i][0]
+            next_ret_name = next_forward_returns_list[i][0]
+            self.to_next_grad_name_mapping[grad_ret_name] = next_ret_name
+
     def GenerateHigherOrderNodeCreationCode(self):
         namespace = self.namespace
         grad_api_contents = self.grad_api_contents
@@ -1164,6 +1184,8 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
             next_node_generator.GenerateNodeCreationCodes()
             grad_node_creation_str = next_node_generator.node_creation_str
 
+        self.RecordGrad2NextGradNameMapping(next_node_generator)
+
         return grad_node_creation_str
 
     def GenerateNodeDeclaration(self):
@@ -1253,8 +1275,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         for name, (_, is_fwd_input,
                    grad_api_position), in backward_forward_inputs_map.items():
             tensor_wrapper_name = GetSavedName(name)
-            transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
-                name)
+            transformed_tensor_name = self.TransformToNextGradName(name)
 
             is_optional = (name in self.optional_inputs)
             tensor_wrapper_recover_str = f"{indent}auto {transformed_tensor_name} = egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, this->shared_from_this());"
@@ -1274,8 +1295,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         # Grad Ins from grads
         for name, (ttype, fwd_position,
                    grad_api_position) in backward_grad_inputs_map.items():
-            transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
-                name)
+            transformed_tensor_name = self.TransformToNextGradName(name)
 
             is_optional = (name in self.optional_inputs)
             if IsPlainTensorType(ttype):
@@ -1316,8 +1336,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         num_outputs = len(backward_grad_outputs_map.keys())
         for name, (ttype, fwd_position,
                    grad_api_position) in backward_grad_outputs_map.items():
-            transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
-                name)
+            transformed_tensor_name = self.TransformToNextGradName(name)
 
             if num_outputs == 1:
                 get_tensor_str = f"{indent}auto& {transformed_tensor_name} = grad_api_result;"
@@ -1339,8 +1358,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         compute_require_grad_args_list = ["trace_backward"]
         for name, (ttype, pos,
                    grad_api_position) in backward_grad_inputs_map.items():
-            transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
-                name)
+            transformed_tensor_name = self.TransformToNextGradName(name)
 
             input_autograd_meta_name = GetAutoGradMetaName(
                 transformed_tensor_name)
@@ -1358,8 +1376,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         # 2. Get TensorWrapper AutoGradMeta
         for name, (ttype, _,
                    pos), in backward_forward_inputs_map.items():
-            transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
-                name)
+            transformed_tensor_name = self.TransformToNextGradName(name)
 
             input_autograd_meta_name = GetAutoGradMetaName(
                 transformed_tensor_name)
@@ -1382,8 +1399,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         outputs_autograd_meta_list = []
         num_fwd_outputs = len(backward_grad_outputs_map.keys())
         for name, (rtype, pos, _) in backward_grad_outputs_map.items():
-            transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
-                name)
+            transformed_tensor_name = self.TransformToNextGradName(name)
 
             output_autograd_meta_name = GetAutoGradMetaName(
                 transformed_tensor_name)
@@ -1417,8 +1433,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         returns_str = f"{indent}std::vector<std::vector<paddle::experimental::Tensor>> returns({slot_num_bwd_outputs});\n"
         for name, (ttype, fwd_position,
                    grad_api_position) in backward_grad_outputs_map.items():
-            transformed_tensor_name = TransformGradVarNameForDoubleGradGeneration(
-                name)
+            transformed_tensor_name = self.TransformToNextGradName(name)
 
             # Infer Grad API Return Type
             if num_bwd_outputs == 1:
@@ -1441,6 +1456,9 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
 
         grad_node_name = GetGradNodeName(forward_api_name)
 
+        if len(grad_node_creation_str) == 0:
+            grad_node_creation_str = f"if(create_graph) VLOG(3) << \"Higher order grad node for {grad_node_name} has not been implemented yet.\";"
+
         self.node_definition_str = GRAD_FUNCTION_TEMPLATE.format(
             grad_node_name, fill_zero_str, get_grad_in_args_str, grad_node_name,
             grad_function_call_str, get_outputs_str, inputs_autograd_meta_str,
@@ -1457,11 +1475,11 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
         #####################
         ## Code Generation ##
         #####################
-        self.GenerateNodeDeclaration()
-
         # Higher-order GradNode generation
         grad_node_creation_str = self.GenerateHigherOrderNodeCreationCode()
 
+        self.GenerateNodeDeclaration()
+
         self.GenerateNodeDefinition(grad_node_creation_str)
 
 
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index 64acc887b42c09b4d2736821f8b27172e4b15ac0..4e029d4c27c034df702e517f167ee8d1cd210fad 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -206,6 +206,54 @@ void GeneralTernaryGradInferMeta(const MetaTensor& x,
     dz->share_meta(z);
   }
 }
+void GeneralQuaternaryGradInferMeta(const MetaTensor& x,
+                                    const MetaTensor& y,
+                                    const MetaTensor& z,
+                                    const MetaTensor& k,
+                                    MetaTensor* dx,
+                                    MetaTensor* dy,
+                                    MetaTensor* dz,
+                                    MetaTensor* dk) {
+  if (dx) {
+    dx->share_meta(x);
+  }
+  if (dy) {
+    dy->share_meta(y);
+  }
+  if (dz) {
+    dz->share_meta(z);
+  }
+  if (dk) {
+    dk->share_meta(k);
+  }
+}
+
+void GeneralQuinaryGradInferMeta(const MetaTensor& x,
+                                 const MetaTensor& y,
+                                 const MetaTensor& z,
+                                 const MetaTensor& k,
+                                 const MetaTensor& l,
+                                 MetaTensor* dx,
+                                 MetaTensor* dy,
+                                 MetaTensor* dz,
+                                 MetaTensor* dk,
+                                 MetaTensor* dl) {
+  if (dx) {
+    dx->share_meta(x);
+  }
+  if (dy) {
+    dy->share_meta(y);
+  }
+  if (dz) {
+    dz->share_meta(z);
+  }
+  if (dk) {
+    dk->share_meta(k);
+  }
+  if (dl) {
+    dl->share_meta(l);
+  }
+}
 
 void GeneralUnaryGradInferMeta(const MetaTensor& x, MetaTensor* dx) {
   if (dx) {
diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h
index c0eb478168988f6fdbb782049974ebb0336cd4a8..3cd4875e99923ffebecbeba3282e337129e17eea 100644
--- a/paddle/phi/infermeta/backward.h
+++ b/paddle/phi/infermeta/backward.h
@@ -96,6 +96,26 @@ void GeneralTernaryGradInferMeta(const MetaTensor& x,
                                  MetaTensor* dy,
                                  MetaTensor* dz);
 
+void GeneralQuaternaryGradInferMeta(const MetaTensor& x,
+                                    const MetaTensor& y,
+                                    const MetaTensor& z,
+                                    const MetaTensor& k,
+                                    MetaTensor* dx,
+                                    MetaTensor* dy,
+                                    MetaTensor* dz,
+                                    MetaTensor* dk);
+
+void GeneralQuinaryGradInferMeta(const MetaTensor& x,
+                                 const MetaTensor& y,
+                                 const MetaTensor& z,
+                                 const MetaTensor& k,
+                                 const MetaTensor& l,
+                                 MetaTensor* dx,
+                                 MetaTensor* dy,
+                                 MetaTensor* dz,
+                                 MetaTensor* dk,
+                                 MetaTensor* dl);
+
 void GeneralUnaryGradInferMeta(const MetaTensor& x, MetaTensor* dx);
 
 void GumbelSoftmaxGradInferMeta(const MetaTensor& out,
diff --git a/paddle/phi/kernels/activation_grad_kernel.h b/paddle/phi/kernels/activation_grad_kernel.h
index be6f97ad7c96e33ee3ca4c18191cdbb1c186a395..82e168a3c630b319d3d53d8d6761e7ed55d8528a 100644
--- a/paddle/phi/kernels/activation_grad_kernel.h
+++ b/paddle/phi/kernels/activation_grad_kernel.h
@@ -125,18 +125,18 @@ void EluDoubleGradKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void SigmoidDoubleGradKernel(const Context& dev_ctx,
                              const DenseTensor& out,
-                             const DenseTensor& ddx,
                              const DenseTensor& dout,
+                             const DenseTensor& ddx,
                              DenseTensor* dout_new,
                              DenseTensor* ddout);
 
 template <typename T, typename Context>
 void SigmoidTripleGradKernel(const Context& dev_ctx,
                              const DenseTensor& out,
-                             const DenseTensor& ddx,
                              const DenseTensor& dout,
-                             const DenseTensor& d_ddout,
+                             const DenseTensor& ddx,
                              const DenseTensor& d_dout_new,
+                             const DenseTensor& d_ddout,
                              DenseTensor* d_out_new,
                              DenseTensor* d_dout,
                              DenseTensor* d_ddx);
diff --git a/paddle/phi/kernels/impl/activation_grad_impl.h b/paddle/phi/kernels/impl/activation_grad_impl.h
index 37273b7944edea142ad1df27d1e9ee367d33ce48..bf9b7cdf559d3bdc2cd8866066642e6d75414cfe 100644
--- a/paddle/phi/kernels/impl/activation_grad_impl.h
+++ b/paddle/phi/kernels/impl/activation_grad_impl.h
@@ -243,8 +243,8 @@ void LogitGradKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void SigmoidDoubleGradKernel(const Context& dev_ctx,
                              const DenseTensor& out,
-                             const DenseTensor& ddx,
                              const DenseTensor& dout,
+                             const DenseTensor& ddx,
                              DenseTensor* dout_new,
                              DenseTensor* ddout) {
   if (dout_new) {
@@ -262,10 +262,10 @@ void SigmoidDoubleGradKernel(const Context& dev_ctx,
 template <typename T, typename Context>
 void SigmoidTripleGradKernel(const Context& dev_ctx,
                              const DenseTensor& out,
-                             const DenseTensor& ddx,
                              const DenseTensor& dout,
-                             const DenseTensor& d_ddout,
+                             const DenseTensor& ddx,
                              const DenseTensor& d_dout_new,
+                             const DenseTensor& d_ddout,
                              DenseTensor* d_out_new,
                              DenseTensor* d_dout,
                              DenseTensor* d_ddx) {
diff --git a/paddle/phi/ops/compat/activation_sig.cc b/paddle/phi/ops/compat/activation_sig.cc
index 34f830abe7ea374e29b1aa747765eb3d488aa83e..8add832c366cfdc6bdf9e4cfdbe5b025afcf9b13 100644
--- a/paddle/phi/ops/compat/activation_sig.cc
+++ b/paddle/phi/ops/compat/activation_sig.cc
@@ -139,13 +139,13 @@ KernelSignature TanhTripleGradOpArgumentMapping(
 KernelSignature SigmoidDoubleGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature(
-      "sigmoid_double_grad", {"Out", "DDX", "DOut"}, {}, {"DOutNew", "DDOut"});
+      "sigmoid_double_grad", {"Out", "DOut", "DDX"}, {}, {"DOutNew", "DDOut"});
 }
 
 KernelSignature SigmoidTripleGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature("sigmoid_triple_grad",
-                         {"Out", "DDX", "DOut", "D_DDOut", "D_DOut_New"},
+                         {"Out", "DOut", "DDX", "D_DOut_New", "D_DDOut"},
                          {},
                          {"D_OutNew", "D_DOut", "D_DDx"});
 }
diff --git a/python/paddle/fluid/tests/unittests/test_imperative_triple_grad.py b/python/paddle/fluid/tests/unittests/test_imperative_triple_grad.py
index 027c0002c7103997a5f58ead3079349d77616a46..f0c5316412f1e4c24efba99c6fde20c91ae1a0ca 100644
--- a/python/paddle/fluid/tests/unittests/test_imperative_triple_grad.py
+++ b/python/paddle/fluid/tests/unittests/test_imperative_triple_grad.py
@@ -42,6 +42,68 @@ def random_var(size, low=-1, high=1, dtype='float32'):
     return fluid.dygraph.to_variable(x_np)
 
 
+class TestDygraphTripleGradMatmul(TestCase):
+    def test_matmul_triple_grad(self):
+        input_numpy = np.ones([3, 3]) * 2
+        with _test_eager_guard():
+            x = paddle.to_tensor(
+                input_numpy, stop_gradient=False, dtype='float32')
+            y = paddle.to_tensor(
+                input_numpy, stop_gradient=False, dtype='float32')
+            out = paddle.matmul(x, y, False, False)
+
+            new_out_g = paddle.to_tensor(
+                np.ones([3, 3]), stop_gradient=False, dtype='float32')
+            new_x_g, new_y_g = paddle.grad(
+                [out], [x, y], [new_out_g],
+                retain_graph=True,
+                create_graph=True)
+
+            new_x_g_g = paddle.to_tensor(
+                np.ones([3, 3]), stop_gradient=False, dtype='float32')
+            new_y_g_g = paddle.to_tensor(
+                np.ones([3, 3]), stop_gradient=False, dtype='float32')
+            new_a, new_b, new_c = paddle.grad(
+                [new_x_g, new_y_g], [x, y, new_out_g], [new_x_g_g, new_y_g_g],
+                retain_graph=True,
+                create_graph=True)
+
+            new_a.backward()
+
+            out_ref = np.ones([3, 3]) * 12.0
+            self.assertTrue(np.array_equal(out.numpy(), out_ref))
+
+            new_x_g_ref = np.ones([3, 3]) * 6.0
+            new_y_g_ref = np.ones([3, 3]) * 6.0
+            self.assertTrue(np.array_equal(new_x_g.numpy(), new_x_g_ref))
+            self.assertTrue(np.array_equal(new_y_g.numpy(), new_y_g_ref))
+
+            new_a_ref = np.ones([3, 3]) * 3.0
+            new_b_ref = np.ones([3, 3]) * 3.0
+            new_c_ref = np.ones([3, 3]) * 12.0
+
+            self.assertTrue(np.array_equal(new_a.numpy(), new_a_ref))
+            self.assertTrue(np.array_equal(new_b.numpy(), new_b_ref))
+            self.assertTrue(np.array_equal(new_c.numpy(), new_c_ref))
+
+            x_grad_ref = np.ones([3, 3]) * 0.0
+            self.assertTrue(np.array_equal(x.grad.numpy(), x_grad_ref))
+
+            y_grad_ref = np.ones([3, 3]) * 0.0
+            self.assertTrue(np.array_equal(y.grad.numpy(), y_grad_ref))
+
+            new_out_g_ref = np.ones([3, 3]) * 3.0
+            self.assertTrue(
+                np.array_equal(new_out_g.grad.numpy(), new_out_g_ref))
+
+            new_x_g_g_ref = np.ones([3, 3]) * 0.0
+            new_y_g_g_ref = np.ones([3, 3]) * 3.0
+            self.assertTrue(
+                np.array_equal(new_x_g_g.grad.numpy(), new_x_g_g_ref))
+            self.assertTrue(
+                np.array_equal(new_y_g_g.grad.numpy(), new_y_g_g_ref))
+
+
 class TestDygraphTripleGrad(TestCase):
     def setUp(self):
         self.sort_sum_gradient = False
diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml
index 5908e05a514d701a5be1c6d24e690e56567d62c8..e268675bdcfae7e2686f900866ceacaacda022ba 100644
--- a/python/paddle/utils/code_gen/backward.yaml
+++ b/python/paddle/utils/code_gen/backward.yaml
@@ -706,6 +706,7 @@
     param : [x, y, grad_out]
   kernel :
     func : matmul_double_grad
+  backward : matmul_triple_grad
   optional : grad_x_grad, grad_y_grad
 
 - backward_api : matmul_grad
@@ -719,6 +720,17 @@
     func : matmul_grad
   backward : matmul_double_grad
 
+- backward_api : matmul_triple_grad
+  forward : matmul_double_grad (Tensor x, Tensor y, Tensor fwd_grad_out, Tensor fwd_grad_grad_x, Tensor fwd_grad_grad_y, bool transpose_x=false, bool transpose_y=false) -> Tensor(grad_x), Tensor(grad_y), Tensor(grad_grad_out)
+  args : (Tensor x, Tensor y, Tensor fwd_grad_out, Tensor fwd_grad_grad_x, Tensor fwd_grad_grad_y, Tensor grad_x_grad, Tensor grad_y_grad, Tensor grad_grad_out_grad, bool transpose_x=false, bool transpose_y=false)
+  output : Tensor(x_grad), Tensor(y_grad), Tensor(fwd_grad_out_grad), Tensor(fwd_grad_grad_x_grad), Tensor(fwd_grad_grad_y_grad)
+  infer_meta :
+    func : GeneralQuinaryGradInferMeta
+    param : [x, y, fwd_grad_out, fwd_grad_grad_x, fwd_grad_grad_y]
+  kernel :
+    func : matmul_triple_grad
+  optional : grad_x_grad, grad_y_grad, grad_grad_out_grad
+
 - backward_api : matrix_power_grad
   forward : matrix_power (Tensor x, int n) -> Tensor(out)
   args : (Tensor x, Tensor out, Tensor out_grad, int n)
@@ -1090,6 +1102,17 @@
   kernel :
     func : sigmoid_cross_entropy_with_logits_grad
 
+- backward_api : sigmoid_double_grad
+  forward : sigmoid_grad (Tensor out, Tensor fwd_grad_out) -> Tensor(grad_x)
+  args : (Tensor out, Tensor fwd_grad_out, Tensor grad_x_grad)
+  output : Tensor(out_grad), Tensor(fwd_grad_out_grad)
+  infer_meta :
+    func : GeneralBinaryGradInferMeta
+    param : [out, fwd_grad_out]
+  kernel :
+    func : sigmoid_double_grad
+  backward : sigmoid_triple_grad
+
 - backward_api : sigmoid_grad
   forward : sigmoid (Tensor x) -> Tensor(out)
   args : (Tensor out, Tensor out_grad)
@@ -1099,6 +1122,17 @@
     param : [out]
   kernel :
     func : sigmoid_grad
+  backward : sigmoid_double_grad
+
+- backward_api : sigmoid_triple_grad
+  forward : sigmoid_double_grad (Tensor out, Tensor fwd_grad_out, Tensor grad_grad_x) -> Tensor(grad_out), Tensor(grad_grad_out)
+  args : (Tensor out, Tensor fwd_grad_out, Tensor grad_grad_x, Tensor grad_out_grad, Tensor grad_grad_out_grad)
+  output : Tensor(out_grad), Tensor(fwd_grad_out_grad), Tensor(grad_grad_x_grad)
+  infer_meta :
+    func : GeneralTernaryGradInferMeta
+    param : [out, fwd_grad_out, grad_grad_x]
+  kernel :
+    func : sigmoid_triple_grad
 
 - backward_api : silu_grad
   forward : silu (Tensor x) -> Tensor(out)
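
Note: the snippet below is a hypothetical usage sketch, not part of the patch. It shows how the new sigmoid_double_grad / sigmoid_triple_grad entries registered in backward.yaml would be exercised from eager-mode dygraph, following the same nested paddle.grad idiom as the matmul test added above. The helper name sigmoid_triple_grad_example and the chosen input values are illustrative only.

import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.fluid.framework import _test_eager_guard


def sigmoid_triple_grad_example():
    with _test_eager_guard():
        x = paddle.to_tensor(
            np.ones([3, 3]), stop_gradient=False, dtype='float32')
        out = F.sigmoid(x)

        # 1st-order grad; create_graph=True keeps the grad graph alive so the
        # "backward : sigmoid_double_grad" entry can be reached on the next call.
        dx, = paddle.grad([out], [x], retain_graph=True, create_graph=True)

        # 2nd-order grad is expected to run through sigmoid_double_grad.
        ddx, = paddle.grad([dx], [x], retain_graph=True, create_graph=True)

        # 3rd-order grad is expected to run through sigmoid_triple_grad and
        # accumulate into x.grad, analogous to new_a.backward() in the matmul test.
        ddx.backward()
        return x.grad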