From 60be66e2c0ff3e127768de566a6e4925c21ffbec Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Tue, 14 May 2019 11:53:05 +0800 Subject: [PATCH] support fc_op double grad (#17317) * add double grad for mul_op. test=develop * fix format. test=develop * fix format. test=develop * fix format. test=develop * refine code. test=develop * remove setzero. test=develop * fix dx/dy init bug. test=develop * fix format. test=develop --- paddle/fluid/operators/mul_op.cc | 54 ++++++++++- paddle/fluid/operators/mul_op.cu.cc | 4 + paddle/fluid/operators/mul_op.h | 91 +++++++++++++++++++ .../fluid/tests/unittests/test_nn_grad.py | 28 ++++++ 4 files changed, 176 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc index 05afdf5324..6dac9041b6 100644 --- a/paddle/fluid/operators/mul_op.cc +++ b/paddle/fluid/operators/mul_op.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/mul_op.h" +#include #include +#include #include namespace paddle { @@ -178,16 +180,66 @@ class MulOpGradMaker : public framework::SingleGradOpDescMaker { } }; +class MulDoubleGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); + PADDLE_ENFORCE(ctx->HasInput("DOut"), "Input(DOut) should not be null"); + + if (ctx->HasOutput("DX")) { + ctx->ShareDim("X", "DX"); + } + if (ctx->HasOutput("DY")) { + ctx->ShareDim("Y", "DY"); + } + if (ctx->HasOutput("DDOut")) { + ctx->ShareDim("DOut", "DDOut"); + } + } +}; + +class MulDoubleGradMaker : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + std::unique_ptr retv(new framework::OpDesc()); + retv->SetType("mul_grad_grad"); + + retv->SetInput("X", Input("X")); + retv->SetInput("Y", Input("Y")); + retv->SetInput("DOut", Input(framework::GradVarName("Out"))); + retv->SetInput("DDX", OutputGrad(framework::GradVarName("X"))); + retv->SetInput("DDY", OutputGrad(framework::GradVarName("Y"))); + + retv->SetOutput("DDOut", InputGrad(framework::GradVarName("Out"))); + retv->SetOutput("DX", InputGrad("X")); + retv->SetOutput("DY", InputGrad("Y")); + + retv->SetAttrMap(Attrs()); + return retv; + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpInferVarType, ops::MulOpGradMaker); -REGISTER_OPERATOR(mul_grad, ops::MulGradOp); +REGISTER_OPERATOR(mul_grad, ops::MulGradOp, ops::MulDoubleGradMaker); +REGISTER_OPERATOR(mul_grad_grad, ops::MulDoubleGradOp); REGISTER_OP_CPU_KERNEL( mul, ops::MulKernel, ops::MulKernel); REGISTER_OP_CPU_KERNEL( mul_grad, ops::MulGradKernel, ops::MulGradKernel); +REGISTER_OP_CPU_KERNEL( + mul_grad_grad, + ops::MulDoubleGradKernel, + ops::MulDoubleGradKernel); diff --git a/paddle/fluid/operators/mul_op.cu.cc b/paddle/fluid/operators/mul_op.cu.cc index 6c5a83c6a5..6e841712b9 100644 --- a/paddle/fluid/operators/mul_op.cu.cc +++ b/paddle/fluid/operators/mul_op.cu.cc @@ -24,3 +24,7 @@ REGISTER_OP_CUDA_KERNEL( mul_grad, ops::MulGradKernel, ops::MulGradKernel, ops::MulGradKernel); +REGISTER_OP_CUDA_KERNEL( + mul_grad_grad, + ops::MulDoubleGradKernel, + ops::MulDoubleGradKernel); diff --git a/paddle/fluid/operators/mul_op.h b/paddle/fluid/operators/mul_op.h index f72824806e..c77eb5c4cc 100644 --- a/paddle/fluid/operators/mul_op.h +++ b/paddle/fluid/operators/mul_op.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/math_function.h" @@ -109,5 +110,95 @@ class MulGradKernel : public framework::OpKernel { } }; +template +class MulDoubleGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + int x_num_col_dims = ctx.template Attr("x_num_col_dims"); + int y_num_col_dims = ctx.template Attr("y_num_col_dims"); + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto x_mat = x->dims().size() > 2 + ? framework::ReshapeToMatrix(*x, x_num_col_dims) + : static_cast(*x); + auto y_mat = y->dims().size() > 2 + ? framework::ReshapeToMatrix(*y, y_num_col_dims) + : static_cast(*y); + + const int m = framework::flatten_to_2d(x->dims(), x_num_col_dims)[0]; + const int n = framework::flatten_to_2d(y->dims(), y_num_col_dims)[1]; + + auto* dout = ctx.Input("DOut"); + Tensor dout_mat; + dout_mat.ShareDataWith(*dout); + dout_mat.Resize({m, n}); + + auto* ddx = ctx.Input("DDX"); + auto* ddy = ctx.Input("DDY"); + + auto* dx = ctx.Output("DX"); + auto* dy = ctx.Output("DY"); + auto* ddout = ctx.Output("DDOut"); + + Tensor ddout_mat; + if (ddout) { + ddout->set_lod(dout->lod()); + // allocate and reshape ddout + ddout->mutable_data(ctx.GetPlace()); + ddout_mat.ShareDataWith(*ddout); + ddout_mat.Resize({m, n}); + } + + auto& dev_ctx = ctx.template device_context(); + auto blas = math::GetBlas(dev_ctx); + // a flag to specify whether ddout value has been set, if flag + // is false, MatMul beta should be 0 to set ddout, if flag is + // true, MatMul beta should be 1 to add result to ddout. + bool ddout_flag = false; + if (ddx) { + auto ddx_mat = ddx->dims().size() > 2 + ? framework::ReshapeToMatrix(*ddx, x_num_col_dims) + : static_cast(*ddx); + + // dy = ddx' * dout. dy : K x M, ddx' : K x M, dout : M x N + if (dy) { + dy->set_lod(y->lod()); + // allocate and reshape dy + dy->mutable_data(ctx.GetPlace()); + Tensor dy_mat = dy->dims().size() > 2 + ? framework::ReshapeToMatrix(*dy, y_num_col_dims) + : *dy; + blas.MatMul(ddx_mat, true, dout_mat, false, &dy_mat); + } + // ddout1 = ddx * y. ddx : M x K, y : K x N, ddout1 : M x N + if (ddout) { + blas.MatMul(ddx_mat, false, y_mat, false, static_cast(1.0), + &ddout_mat, static_cast(ddout_flag)); + ddout_flag = true; + } + } + if (ddy) { + auto ddy_mat = ddy->dims().size() > 2 + ? framework::ReshapeToMatrix(*ddy, y_num_col_dims) + : static_cast(*ddy); + // dx = dout * ddy'. dout : M x N, ddy' : N x K, dx : M x K + if (dx) { + dx->set_lod(x->lod()); + // allocate and reshape dx + dx->mutable_data(ctx.GetPlace()); + Tensor dx_mat = dx->dims().size() > 2 + ? framework::ReshapeToMatrix(*dx, x_num_col_dims) + : *dx; + blas.MatMul(dout_mat, false, ddy_mat, true, &dx_mat); + } + // ddout2 = x * ddy. x : M x K, ddy : K x N, ddout2 : M x N + if (ddout) { + blas.MatMul(x_mat, false, ddy_mat, false, static_cast(1.0), + &ddout_mat, static_cast(ddout_flag)); + } + } + } +}; + } // namespace operators } // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_nn_grad.py b/python/paddle/fluid/tests/unittests/test_nn_grad.py index dbfdca0e44..2ef722c913 100644 --- a/python/paddle/fluid/tests/unittests/test_nn_grad.py +++ b/python/paddle/fluid/tests/unittests/test_nn_grad.py @@ -193,5 +193,33 @@ class TestElementwiseMulBroadcastDoubleGradCheck(unittest.TestCase): self.func(p) +class TestMulDoubleGradCheck(unittest.TestCase): + @prog_scope() + def func(self, place): + # the shape of input variable shoule be clearly specified, not inlcude -1. + x_shape = [7, 11] + y_shape = [11, 9] + eps = 0.005 + dtype = np.float64 + + x = layers.data('x', x_shape, False, dtype) + x.persistable = True + y = layers.data('y', y_shape, False, dtype) + y.persistable = True + out = layers.mul(x, y) + x_arr = np.random.uniform(-1, 1, x_shape).astype(dtype) + y_arr = np.random.uniform(-1, 1, y_shape).astype(dtype) + + gradient_checker.double_grad_check( + [x, y], out, x_init=[x_arr, y_arr], place=place, eps=eps) + + def test_grad(self): + places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + places.append(fluid.CUDAPlace(0)) + for p in places: + self.func(p) + + if __name__ == "__main__": unittest.main() -- GitLab