diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc
index 21d1ebddbd45990f39e41e16d5428d36fec83daf..a5683c9e88a56d822393969161aba1d1bc0dc679 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc
@@ -110,6 +110,35 @@ class ElementwiseMulDoubleGradMaker : public framework::SingleGradOpMaker<T> {
   }
 };
 
+template <typename T>
+class ElementwiseMulTripleGradMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    op->SetType("elementwise_mul_triple_grad");
+    // get input from double grad
+    op->SetInput("X", this->Input("X"));
+    op->SetInput("Y", this->Input("Y"));
+    op->SetInput("DOut", this->Input("DOut"));
+    op->SetInput("DDX", this->Input("DDX"));
+    op->SetInput("DDY", this->Input("DDY"));
+    op->SetInput("D_DX", this->OutputGrad(framework::GradVarName("X")));
+    op->SetInput("D_DY", this->OutputGrad(framework::GradVarName("Y")));
+    op->SetInput("D_DDOut", this->OutputGrad("DDOut"));
+
+    op->SetAttrMap(this->Attrs());
+
+    // set outputs
+    op->SetOutput("D_X", this->InputGrad("X"));
+    op->SetOutput("D_Y", this->InputGrad("Y"));
+    op->SetOutput("D_DOut", this->InputGrad("DOut"));
+    op->SetOutput("D_DDX", this->InputGrad("DDX"));
+    op->SetOutput("D_DDY", this->InputGrad("DDY"));
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -123,8 +152,13 @@ REGISTER_OPERATOR(
     ops::ElementwiseMulDoubleGradMaker<paddle::framework::OpDesc>,
     ops::ElementwiseMulDoubleGradMaker<paddle::imperative::OpBase>);
 
-REGISTER_OPERATOR(elementwise_mul_grad_grad, ops::ElementwiseOpDoubleGrad,
-                  ops::ElementwiseDoubleGradOpInplaceInferer);
+REGISTER_OPERATOR(
+    elementwise_mul_grad_grad, ops::ElementwiseOpDoubleGrad,
+    ops::ElementwiseDoubleGradOpInplaceInferer,
+    ops::ElementwiseMulTripleGradMaker<paddle::framework::OpDesc>,
+    ops::ElementwiseMulTripleGradMaker<paddle::imperative::OpBase>);
+
+REGISTER_OPERATOR(elementwise_mul_triple_grad, ops::ElementwiseOpTripleGrad);
 
 REGISTER_OP_CPU_KERNEL(
     elementwise_mul,
@@ -164,6 +198,22 @@ REGISTER_OP_CPU_KERNEL(
                                         paddle::platform::complex<float>>,
     ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                         paddle::platform::complex<double>>);
+REGISTER_OP_CPU_KERNEL(
+    elementwise_mul_triple_grad,
+    ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
+                                        float>,
+    ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
+                                        double>,
+    ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
+                                        int>,
+    ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
+                                        int64_t>,
+    ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
+                                        bool>,
+    ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
+                                        paddle::platform::complex<float>>,
+    ops::ElementwiseMulTripleGradKernel<paddle::platform::CPUDeviceContext,
+                                        paddle::platform::complex<double>>);
 
 REGISTER_OP_VERSION(elementwise_mul)
     .AddCheckpoint(
diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
index 1a9ac4bd9157fa6fb50b538444666d610398ead8..375063813ede8addc095c7d8a32d429740446e94 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cu
@@ -141,3 +141,15 @@ REGISTER_OP_CUDA_KERNEL(
                                         plat::complex<float>>,
     ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
                                         plat::complex<double>>);
+REGISTER_OP_CUDA_KERNEL(
+    elementwise_mul_triple_grad,
+    ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, float>,
+    ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, double>,
+    ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, int>,
+    ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, int64_t>,
+    ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, bool>,
+    ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext, plat::float16>,
+    ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext,
+                                        plat::complex<float>>,
+    ops::ElementwiseMulTripleGradKernel<plat::CUDADeviceContext,
+                                        plat::complex<double>>);
diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h
index 80fa430c44307c0b1bc9fe67d211214078655f07..211bf6e3fb539dc486a264b1774f6d320c6970d5 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h
@@ -283,5 +283,96 @@ class ElementwiseMulDoubleGradKernel : public framework::OpKernel<T> {
   }
 };
 
+template <typename DeviceContext, typename T>
+class ElementwiseMulTripleGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    using Tensor = framework::Tensor;
+    // get input
+    auto* x = ctx.Input<Tensor>("X");
+    auto* y = ctx.Input<Tensor>("Y");
+    auto* dout = ctx.Input<Tensor>("DOut");
+    auto* ddx = ctx.Input<Tensor>("DDX");
+    auto* ddy = ctx.Input<Tensor>("DDY");
+
+    auto* d_dx = ctx.Input<Tensor>("D_DX");
+    auto* d_dy = ctx.Input<Tensor>("D_DY");
+    auto* d_ddout = ctx.Input<Tensor>("D_DDOut");
+
+    // get output
+    auto* out_d_x = ctx.Output<Tensor>("D_X");
+    auto* out_d_y = ctx.Output<Tensor>("D_Y");
+    auto* out_d_dout = ctx.Output<Tensor>("D_DOut");
+
+    auto* out_d_ddx = ctx.Output<Tensor>("D_DDX");
+    auto* out_d_ddy = ctx.Output<Tensor>("D_DDY");
+
+    if (out_d_x) out_d_x->mutable_data<T>(x->dims(), ctx.GetPlace());
+    if (out_d_y) out_d_y->mutable_data<T>(y->dims(), ctx.GetPlace());
+    if (out_d_dout) out_d_dout->mutable_data<T>(dout->dims(), ctx.GetPlace());
+    if (out_d_ddx) out_d_ddx->mutable_data<T>(x->dims(), ctx.GetPlace());
+    if (out_d_ddy) out_d_ddy->mutable_data<T>(y->dims(), ctx.GetPlace());
+
+    auto& place =
+        *ctx.template device_context<DeviceContext>().eigen_device();
+
+    Tensor ddx_safe, ddy_safe;
+    GetDoubleGradSafeTensor<DeviceContext, T>(ctx, x, ddx, &ddx_safe);
+    GetDoubleGradSafeTensor<DeviceContext, T>(ctx, y, ddy, &ddy_safe);
+
+    if (d_ddout) {
+      if (out_d_x) {
+        // out_d_x = ddy * d_ddout
+        default_elementwise_mul<DeviceContext, T>(ctx, &ddy_safe, d_ddout,
+                                                  out_d_x);
+      }
+      if (out_d_y) {
+        // out_d_y = ddx * d_ddout
+        default_elementwise_mul<DeviceContext, T>(ctx, &ddx_safe, d_ddout,
+                                                  out_d_y);
+      }
+    }
+
+    if (out_d_dout) {
+      // get out_d_dout
+      // out_d_dout = ddy * d_dx + d_dy * ddx
+      Tensor out_d_dout_tmp;
+      out_d_dout_tmp.mutable_data<T>(dout->dims(), ctx.GetPlace());
+      default_elementwise_mul<DeviceContext, T>(ctx, d_dy, &ddx_safe,
+                                                out_d_dout);
+      default_elementwise_mul<DeviceContext, T>(ctx, &ddy_safe, d_dx,
+                                                &out_d_dout_tmp);
+      auto out_d_dout_t = framework::EigenVector<T>::Flatten(*out_d_dout);
+      auto out_d_dout_tmp_t =
+          framework::EigenVector<T>::Flatten(out_d_dout_tmp);
+      out_d_dout_t.device(place) = out_d_dout_t + out_d_dout_tmp_t;
+    }
+
+    if (out_d_ddx) {
+      // get out_d_ddx
+      // out_d_ddx = dout * d_dy + y * d_ddout
+      Tensor out_d_ddx_tmp;
+      out_d_ddx_tmp.mutable_data<T>(ddx->dims(), ctx.GetPlace());
+      default_elementwise_mul<DeviceContext, T>(ctx, dout, d_dy, out_d_ddx);
+      default_elementwise_mul<DeviceContext, T>(ctx, y, d_ddout,
+                                                &out_d_ddx_tmp);
+      auto out_d_ddx_t = framework::EigenVector<T>::Flatten(*out_d_ddx);
+      auto out_d_ddx_tmp_t = framework::EigenVector<T>::Flatten(out_d_ddx_tmp);
+      out_d_ddx_t.device(place) = out_d_ddx_t + out_d_ddx_tmp_t;
+    }
+
+    if (out_d_ddy) {
+      // get out_d_ddy
+      // out_d_ddy = dout * d_dx + x * d_ddout
+      Tensor out_d_ddy_tmp;
+      out_d_ddy_tmp.mutable_data<T>(ddy->dims(), ctx.GetPlace());
+      default_elementwise_mul<DeviceContext, T>(ctx, dout, d_dx, out_d_ddy);
+      default_elementwise_mul<DeviceContext, T>(ctx, x, d_ddout,
+                                                &out_d_ddy_tmp);
+      auto out_d_ddy_t = framework::EigenVector<T>::Flatten(*out_d_ddy);
+      auto out_d_ddy_tmp_t = framework::EigenVector<T>::Flatten(out_d_ddy_tmp);
+      out_d_ddy_t.device(place) = out_d_ddy_t + out_d_ddy_tmp_t;
+    }
+  }
+};
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h
index 90ad8ce682bdb2eb39cf66504b4888e832d9bc51..e7a013e267d2b0780a02f0c70dbc583f9be0c1a9 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op.h
@@ -451,6 +451,18 @@ class ElementwiseOpTripleGrad : public framework::OperatorWithKernel {
       ctx->ShareDim("DDY", "D_DDY");
       ctx->ShareLoD("DDY", "D_DDY");
     }
+    if (ctx->HasOutput("D_X")) {
+      ctx->ShareDim("X", "D_X");
+      ctx->ShareLoD("X", "D_X");
+    }
+    if (ctx->HasOutput("D_Y")) {
+      ctx->ShareDim("Y", "D_Y");
+      ctx->ShareLoD("Y", "D_Y");
+    }
+    if (ctx->HasOutput("D_DOut")) {
+      ctx->ShareDim("DOut", "D_DOut");
+      ctx->ShareLoD("DOut", "D_DOut");
+    }
   }
 
   framework::OpKernelType GetExpectedKernelType(
diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_nn_grad.py b/python/paddle/fluid/tests/unittests/test_elementwise_nn_grad.py
index 0dba2b1924d2490cb0848188da9aff2537cb05c4..c51c8098706a6190426a55cfa21f4d553696387b 100644
--- a/python/paddle/fluid/tests/unittests/test_elementwise_nn_grad.py
+++ b/python/paddle/fluid/tests/unittests/test_elementwise_nn_grad.py
@@ -297,5 +297,59 @@ class TestElementwiseAddBroadcastTripleGradCheck(unittest.TestCase):
             self.func(p)
 
 
+class TestElementwiseMulTripleGradCheck(unittest.TestCase):
+    @prog_scope()
+    def func(self, place):
+        # the shape of input variable should be clearly specified, not include -1.
+        shape = [2, 3, 4, 5]
+        eps = 0.005
+        dtype = np.float64
+
+        x = layers.data('x', shape, False, dtype)
+        y = layers.data('y', shape, False, dtype)
+        x.persistable = True
+        y.persistable = True
+        out = layers.elementwise_mul(x, y)
+        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
+        y_arr = np.random.uniform(-1, 1, shape).astype(dtype)
+
+        gradient_checker.triple_grad_check(
+            [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
+
+    def test_grad(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
+class TestElementwiseMulBroadcastTripleGradCheck(unittest.TestCase):
+    @prog_scope()
+    def func(self, place):
+        # the shape of input variable should be clearly specified, not include -1.
+        shape = [2, 3, 4, 5]
+        eps = 0.005
+        dtype = np.float64
+
+        x = layers.data('x', shape, False, dtype)
+        y = layers.data('y', shape[:-1], False, dtype)
+        x.persistable = True
+        y.persistable = True
+        out = layers.elementwise_mul(x, y, axis=0)
+        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
+        y_arr = np.random.uniform(-1, 1, shape[:-1]).astype(dtype)
+
+        gradient_checker.triple_grad_check(
+            [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps)
+
+    def test_grad(self):
+        places = [fluid.CPUPlace()]
+        if core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+        for p in places:
+            self.func(p)
+
+
 if __name__ == "__main__":
     unittest.main()
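
As a side note for reviewers, the closed-form expressions in `ElementwiseMulTripleGradKernel` are the VJP of the double-grad mapping DX = DOut * DDY, DY = DOut * DDX, DDOut = DDX * Y + DDY * X. Below is a minimal NumPy sketch (not part of the patch; `double_grad` and `triple_grad` are ad-hoc helper names, not Paddle APIs, and broadcasting is ignored) that cross-checks those formulas with a finite-difference dot-product test:

```python
import numpy as np


def double_grad(x, y, dout, ddx, ddy):
    # Forward map of elementwise_mul_grad_grad (same-shape case):
    # DX = DOut * DDY, DY = DOut * DDX, DDOut = DDX * Y + DDY * X.
    return dout * ddy, dout * ddx, ddx * y + ddy * x


def triple_grad(x, y, dout, ddx, ddy, d_dx, d_dy, d_ddout):
    # VJP of double_grad, mirroring the formulas in the C++ kernel.
    d_x = ddy * d_ddout
    d_y = ddx * d_ddout
    d_dout = ddy * d_dx + ddx * d_dy
    d_ddx = dout * d_dy + y * d_ddout
    d_ddy = dout * d_dx + x * d_ddout
    return d_x, d_y, d_dout, d_ddx, d_ddy


rng = np.random.default_rng(0)
shape = (2, 3, 4, 5)
inputs = [rng.standard_normal(shape) for _ in range(5)]      # x, y, dout, ddx, ddy
tangents = [rng.standard_normal(shape) for _ in range(5)]    # input directions
cotangents = [rng.standard_normal(shape) for _ in range(3)]  # weights on (DX, DY, DDOut)

# Dot-product test: <cotangents, J @ tangents> should equal <tangents, J^T @ cotangents>.
eps = 1e-6
f0 = double_grad(*inputs)
f1 = double_grad(*[a + eps * t for a, t in zip(inputs, tangents)])
lhs = sum(np.vdot(c, (b - a) / eps) for c, a, b in zip(cotangents, f0, f1))
rhs = sum(np.vdot(t, v) for t, v in zip(tangents, triple_grad(*inputs, *cotangents)))
print(abs(lhs - rhs))  # ~0, up to the O(eps) finite-difference error
```

This is the same consistency that `gradient_checker.triple_grad_check` verifies end to end for the registered op in the new unit tests.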