diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc index 2ad2b66c8f385c858eb34c7ea766f168de9c817e..cc763a8cf4565f633ee31b98c2ec7cd583ad77ed 100644 --- a/paddle/operators/rowwise_add_op.cc +++ b/paddle/operators/rowwise_add_op.cc @@ -46,6 +46,17 @@ for i in xrange(X.shape[0]): )DOC"); } }; +class RowWiseAddGradOp : public OperatorWithKernel { +protected: + void InferShape(const InferShapeContext &ctx) const override { + PADDLE_ENFORCE(ctx.InputSize() == 4UL, + "RowWiseAddGrad inputs is I, O, OG, size must be 4"); + PADDLE_ENFORCE(ctx.OutputSize() == 2, + "RowWiseAddGrad output is IG, size must be 2"); + ctx.Output(0)->Resize(ctx.Input(0)->dims()); + ctx.Output(1)->Resize(ctx.Input(1)->dims()); + } +}; } // namespace operators } // namespace paddle @@ -53,3 +64,7 @@ for i in xrange(X.shape[0]): REGISTER_OP(rowwise_add, ops::RowWiseAddOp, ops::RowWiseAddOpMaker); REGISTER_OP_CPU_KERNEL(rowwise_add, ops::RowWiseAddKernel); + +REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, ops::RowWiseAddGradOp); +REGISTER_OP_CPU_KERNEL(rowwise_add_grad, + ops::RowWiseAddGradKernel); diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h index b86dd5463436bf521f9939b1c421b39f11102769..940459e0f176f07e63aef828baf3f9aa599c9a2c 100644 --- a/paddle/operators/rowwise_add_op.h +++ b/paddle/operators/rowwise_add_op.h @@ -38,5 +38,24 @@ public: } }; +template +class RowWiseAddGradKernel : public OpKernel { +public: + void Compute(const ExecutionContext& context) const override { + auto XGrad = context.Output(0); + auto bGrad = context.Output(1); + XGrad->mutable_data(context.GetPlace()); + bGrad->mutable_data(context.GetPlace()); + + // I, O, OG => [X, b], [Out], [OutGrad] + auto OutGrad = EigenMatrix::From(*context.Input(3)); + EigenMatrix::From(*XGrad).device(*(context.GetEigenDevice())) = + OutGrad; + // const int dimension = bGrad.dimension(0); + // https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html + EigenVector::Flatten(*bGrad).device(*(context.GetEigenDevice())) = + OutGrad.cumsum(1); // colwise add + } +}; } // namespace operators } // namespace paddle