diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc index 8375d988045dc24fa1109646b46ff477e2a78132..6825dce332adc0dc11dda187d1bd367875b8603e 100644 --- a/paddle/operators/rowwise_add_op.cc +++ b/paddle/operators/rowwise_add_op.cc @@ -17,7 +17,9 @@ namespace paddle { namespace operators { -class RowWiseAddOp : public framework::OperatorWithKernel { +using framework::Tensor; + +class RowwiseAddOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -34,9 +36,9 @@ class RowWiseAddOp : public framework::OperatorWithKernel { } }; -class RowWiseAddOpMaker : public framework::OpProtoAndCheckerMaker { +class RowwiseAddOpMaker : public framework::OpProtoAndCheckerMaker { public: - RowWiseAddOpMaker(framework::OpProto *proto, + RowwiseAddOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "The left input of row-wise add op, must be matrix"); @@ -49,12 +51,32 @@ for i in xrange(X.shape[0]): )DOC"); } }; +class RowwiseAddGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "X should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto dims0 = ctx.Input("X")->dims(); + auto dims1 = ctx.Input("b")->dims(); + PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1") + ctx.Output(framework::GradVarName("X"))->Resize(dims0); + ctx.Output(framework::GradVarName("b"))->Resize(dims1); + } +}; } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(rowwise_add, ops::RowWiseAddOp, - ops::RowWiseAddOpMaker); +REGISTER_OP(rowwise_add, ops::RowwiseAddOp, ops::RowwiseAddOpMaker, + rowwise_add_grad, ops::RowwiseAddGradOp); +REGISTER_OP_CPU_KERNEL( + rowwise_add, ops::RowwiseAddKernel); REGISTER_OP_CPU_KERNEL( - rowwise_add, ops::RowWiseAddKernel); + rowwise_add_grad, + ops::RowwiseAddGradKernel); diff --git a/paddle/operators/rowwise_add_op.cu b/paddle/operators/rowwise_add_op.cu index 86f80b81228a69ac4c05a4693901570f2b9966e0..cbc61ad3e117fc79a674ca21831d3fec59d1ec5b 100644 --- a/paddle/operators/rowwise_add_op.cu +++ b/paddle/operators/rowwise_add_op.cu @@ -17,4 +17,4 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL( - rowwise_add, ops::RowWiseAddKernel); + rowwise_add, ops::RowwiseAddKernel); diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h index 01f88f2198774fbaa4c98ff9bf286f2f08496a9a..232135c38de68d4002e044972b282b43a1374c72 100644 --- a/paddle/operators/rowwise_add_op.h +++ b/paddle/operators/rowwise_add_op.h @@ -28,7 +28,7 @@ template ; template -class RowWiseAddKernel : public framework::OpKernel { +class RowwiseAddKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto out = context.Output("Out"); @@ -47,5 +47,25 @@ class RowWiseAddKernel : public framework::OpKernel { } }; +template +class RowwiseAddGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* dOut = context.Input(framework::GradVarName("Out")); + auto* dX = context.Output(framework::GradVarName("X")); + auto* db = context.Output(framework::GradVarName("b")); + dX->mutable_data(context.GetPlace()); + db->mutable_data(context.GetPlace()); + + auto OutGrad = EigenMatrix::From(*dOut); + auto place = context.GetEigenDevice(); + EigenMatrix::From(*dX).device(place) = OutGrad; + + // https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html + // colwise add + Eigen::array dims{{1}}; /* dimension to reduce */ + EigenVector::Flatten(*db).device(place) = OutGrad.sum(dims); + } +}; } // namespace operators } // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_rowwise_add_op.py b/python/paddle/v2/framework/tests/test_rowwise_add_op.py index f8521eb517057fbeb104b28af7da4fffe54f37de..29d72e850099734a9828ccceed47cc0a57fc3d6b 100644 --- a/python/paddle/v2/framework/tests/test_rowwise_add_op.py +++ b/python/paddle/v2/framework/tests/test_rowwise_add_op.py @@ -1,6 +1,7 @@ import unittest -from op_test_util import OpTestMeta import numpy as np +from op_test_util import OpTestMeta +from gradient_checker import GradientChecker, create_op class TestRowwiseAddOp(unittest.TestCase): @@ -15,5 +16,15 @@ class TestRowwiseAddOp(unittest.TestCase): self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])} +class RowwiseAddGradOpTest(GradientChecker): + def test_rowwise_add(self): + op = create_op("rowwise_add") + inputs = { + "X": np.random.uniform(0.1, 1, [10, 10]).astype("float32"), + "b": np.random.uniform(0.1, 1, [10]).astype("float32") + } + self.check_grad(op, inputs, set(["X", "b"]), "Out") + + if __name__ == '__main__': unittest.main()