From 12ee5014857e751fb429e0d3ebcfd41dcd5da29d Mon Sep 17 00:00:00 2001
From: dongzhihong
Date: Mon, 14 Aug 2017 20:57:46 +0800
Subject: [PATCH] "fix operator grad config"

---
Review notes (below the "---" marker, so not part of the commit message):
  * rowwise_add_op.h: Out@GRAD is an input of the gradient op (InferShape
    validates it through ctx.InputVar), so the kernel reads it with
    context.Input rather than context.Output.
  * rowwise_add_op.h: the bias gradient is the column-wise sum of Out@GRAD,
    i.e. a reduction over dimension 0 (rows), not dimension 1.

 paddle/operators/rowwise_add_op.cc | 23 +++++++++++++++++------
 paddle/operators/rowwise_add_op.h  | 21 +++++++++++----------
 2 files changed, 28 insertions(+), 16 deletions(-)

diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc
index 0c6ae64d0c..60e5d7749c 100644
--- a/paddle/operators/rowwise_add_op.cc
+++ b/paddle/operators/rowwise_add_op.cc
@@ -17,6 +17,8 @@
 namespace paddle {
 namespace operators {
 
+using framework::Tensor;
+
 class RowwiseAddOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -50,14 +52,23 @@ for i in xrange(X.shape[0]):
   }
 };
 class RowwiseAddGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE(ctx.InputSize() == 4UL,
-                   "RowwiseAddGrad inputs is I, O, OG, size must be 4");
-    PADDLE_ENFORCE(ctx.OutputSize() == 2,
-                   "RowwiseAddGrad output is IG, size must be 2");
-    ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
-    ctx.Output<Tensor>(1)->Resize(ctx.Input<Tensor>(1)->dims());
+    // PADDLE_ENFORCE(ctx.InputSize() == 4UL,
+    //                "RowwiseAddGrad inputs is I, O, OG, size must be 4");
+    // PADDLE_ENFORCE(ctx.OutputSize() == 2,
+    //                "RowwiseAddGrad output is IG, size must be 2");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X should not be null");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
+                            "Input(Out@GRAD) should not be null");
+    auto dims0 = ctx.Input<Tensor>("X")->dims();
+    auto dims1 = ctx.Input<Tensor>("b")->dims();
+    ctx.Output<Tensor>(framework::GradVarName("X"))->Resize(dims0);
+    ctx.Output<Tensor>(framework::GradVarName("b"))->Resize(dims1);
   }
 };
 
diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h
index 3ad60172c1..6593d811e4 100644
--- a/paddle/operators/rowwise_add_op.h
+++ b/paddle/operators/rowwise_add_op.h
@@ -51,19 +51,20 @@ template <typename Place, typename T>
 class RowwiseAddGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* XGrad = context.Output<Tensor>(0);
-    auto* bGrad = context.Output<Tensor>(1);
-    XGrad->mutable_data<T>(context.GetPlace());
-    bGrad->mutable_data<T>(context.GetPlace());
+    auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* db = context.Output<Tensor>(framework::GradVarName("b"));
+    auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
+    dX->mutable_data<T>(context.GetPlace());
+    db->mutable_data<T>(context.GetPlace());
 
-    // I, O, OG => [X, b], [Out], [OutGrad]
-    auto OutGrad = EigenMatrix<T>::From(*context.Input<Tensor>(3));
-    EigenMatrix<T>::From(*XGrad).device(context.GetEigenDevice<Place>()) =
-        OutGrad;
+    auto OutGrad = EigenMatrix<T>::From(*dOut);
+    auto place = context.GetEigenDevice<Place>();
+    EigenMatrix<T>::From(*dX).device(place) = OutGrad;
 
     // https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
-    EigenVector<T>::Flatten(*bGrad).device(context.GetEigenDevice<Place>()) =
-        OutGrad.cumsum(1);  // colwise add
+    // colwise add: db[j] = sum over rows i of dOut[i][j]
+    Eigen::array<int, 1> dims{{0}}; /* reduce over rows: one sum per column */
+    EigenVector<T>::Flatten(*db).device(place) = OutGrad.sum(dims);
   }
 };
 }  // namespace operators
-- 
GitLab