From 264b644718c14da348114bb9a44afddcd7166f11 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Wed, 2 Aug 2017 21:26:29 +0800 Subject: [PATCH] "add rowwise add backward op" --- paddle/operators/rowwise_add_op.cc | 15 +++++++++++++++ paddle/operators/rowwise_add_op.h | 19 +++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc index 2ad2b66c8..cc763a8cf 100644 --- a/paddle/operators/rowwise_add_op.cc +++ b/paddle/operators/rowwise_add_op.cc @@ -46,6 +46,17 @@ for i in xrange(X.shape[0]): )DOC"); } }; +class RowWiseAddGradOp : public OperatorWithKernel { +protected: + void InferShape(const InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 4UL, + "RowWiseAddGrad inputs is I, O, OG, size must be 4"); + PADDLE_ENFORCE(ctx.OutputSize() == 2, + "RowWiseAddGrad output is IG, size must be 2"); + ctx.Output(0)->Resize(ctx.Input(0)->dims()); + ctx.Output(1)->Resize(ctx.Input(1)->dims()); + } +}; } // namespace operators } // namespace paddle @@ -53,3 +64,7 @@ for i in xrange(X.shape[0]): REGISTER_OP(rowwise_add, ops::RowWiseAddOp, ops::RowWiseAddOpMaker); REGISTER_OP_CPU_KERNEL(rowwise_add, ops::RowWiseAddKernel); + +REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, ops::RowWiseAddGradOp); +REGISTER_OP_CPU_KERNEL(rowwise_add_grad, + ops::RowWiseAddGradKernel); diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h index b86dd5463..940459e0f 100644 --- a/paddle/operators/rowwise_add_op.h +++ b/paddle/operators/rowwise_add_op.h @@ -38,5 +38,24 @@ public: } }; +template +class RowWiseAddGradKernel : public OpKernel { +public: + void Compute(const ExecutionContext& context) const override { + auto XGrad = context.Output(0); + auto bGrad = context.Output(1); + XGrad->mutable_data(context.GetPlace()); + bGrad->mutable_data(context.GetPlace()); + + // I, O, OG => [X, b], [Out], [OutGrad] + auto OutGrad = EigenMatrix::From(*context.Input(3)); + EigenMatrix::From(*XGrad).device(*(context.GetEigenDevice())) = + OutGrad; + // const int dimension = bGrad.dimension(0); + // https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html + EigenVector::Flatten(*bGrad).device(*(context.GetEigenDevice())) = + OutGrad.cumsum(1); // colwise add + } +}; } // namespace operators } // namespace paddle -- GitLab