diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc
index db81fd555d1c7bea7c0c3bbd70266b4952ed3724..fb79796f36d392d1c5f5914ba7dca4ff5f09edb8 100644
--- a/paddle/operators/mul_op.cc
+++ b/paddle/operators/mul_op.cc
@@ -54,10 +54,30 @@ The equation is: Out = X * Y
 
 class MulOpGrad : public framework::OperatorWithKernel {
  protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {}
-  std::string DebugString() const override {
-    LOG(INFO) << "MulGrad";
-    return "";
+  // Checks the gradient op's inputs and sets the shapes of X@GRAD and Y@GRAD.
+  // The forward op is the matrix product Out = X * Y, so Out@GRAD has X's
+  // first dimension and Y's second dimension.
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx.InputSize(), 3UL,
+                      "Input of MulOpGrad should be 3, X, Y, Out@GRAD");
+    PADDLE_ENFORCE_EQ(ctx.OutputSize(), 2UL,
+                      "Output of MulOpGrad should be 2, X@GRAD, Y@GRAD");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
+                            "Input(Out@GRAD) should not be null");
+    auto *x_grad = ctx.Output(framework::GradVarName("X"));
+    auto *y_grad = ctx.Output(framework::GradVarName("Y"));
+    auto dim0 = ctx.Input(0)->dims();
+    auto dim1 = ctx.Input(1)->dims();
+    auto out_dims = ctx.Input(2)->dims();
+    // Out@GRAD is shaped like Out: rows come from X, columns come from Y.
+    PADDLE_ENFORCE(dim0[0] == out_dims[0], "Out@GRAD[0] must equal to X[0]");
+    PADDLE_ENFORCE(dim1[1] == out_dims[1], "Out@GRAD[1] must equal to Y[1]");
+
+    // Each gradient has the same shape as the variable it differentiates.
+    x_grad->Resize(dim0);
+    y_grad->Resize(dim1);
   }
 };
 
@@ -69,3 +89,5 @@ REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker);
 REGISTER_GRADIENT_OP(mul, mul_grad, ops::MulOpGrad);
 
 REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel);
+REGISTER_OP_CPU_KERNEL(mul_grad,
+                       ops::MulGradKernel);
diff --git a/paddle/operators/mul_op.cu b/paddle/operators/mul_op.cu
index 43debbc21a365a15c914e60e151f7782b82080cb..a81444dbe63edeecedc5d822c65ff56c42b5db90 100644
--- a/paddle/operators/mul_op.cu
+++ b/paddle/operators/mul_op.cu
@@ -16,5 +16,6 @@
 #include "paddle/operators/mul_op.h"
 
 namespace ops = paddle::operators;
-
 REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel);
+REGISTER_OP_GPU_KERNEL(mul_grad,
+                       ops::MulGradKernel);
diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h
index ab12631c03453a18fbb067e2d12c2bc332acd567..2032a2addd2a34c9ba150355ca390d6ce57c134d 100644
--- a/paddle/operators/mul_op.h
+++ b/paddle/operators/mul_op.h
@@ -46,5 +46,35 @@ class MulKernel : public framework::OpKernel {
   }
 };
 
+// Gradient kernel of the mul op. It currently only calls mutable_data() on
+// both outputs; the gradient values themselves are not written yet (see TODO).
+template
+class MulGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input0 = ctx.Input("X");
+    auto* input1 = ctx.Input("Y");
+    auto* input2 = ctx.Input(framework::GradVarName("Out"));
+
+    auto* output0 = ctx.Output(0);
+    auto* output1 = ctx.Output(1);
+    output0->mutable_data(ctx.GetPlace());
+    output1->mutable_data(ctx.GetPlace());
+
+    auto X = EigenMatrix::From(*input0);
+    auto Y = EigenMatrix::From(*input1);
+    auto dOut = EigenMatrix::From(*input2);
+    auto dX = EigenMatrix::From(*output0);
+    auto dY = EigenMatrix::From(*output1);
+
+    // dX = Out@G * Y'
+    // dY = X' * Out@G
+    auto place = ctx.GetEigenDevice();
+    // TODO(dzh,qijun) : need transpose feature of blas library
+    // Eigen Tensor does not support it very well
+    // dX.device(place) = dOut.contract(dOut, transpose)
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
diff --git a/python/paddle/v2/framework/tests/test_mul_op.py b/python/paddle/v2/framework/tests/test_mul_op.py
index ec0ac99156a546dd3fb7b27778032bece38ab5a9..126a7f398510f12d6ce854c6d7ae75558616b93a 100644
--- a/python/paddle/v2/framework/tests/test_mul_op.py
+++ b/python/paddle/v2/framework/tests/test_mul_op.py
@@ -15,5 +15,7 @@ class TestMulOp(unittest.TestCase):
         self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}
 
+# TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library
+
 
 if __name__ == '__main__':
     unittest.main()