From 632b320e9dc11c6991d95187631c311cae7f7162 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Mon, 14 Aug 2017 17:19:15 +0800 Subject: [PATCH] "refine argument with new style " --- paddle/operators/math/math_function.h | 9 +++ paddle/operators/mul_op.cc | 20 ++++--- paddle/operators/mul_op.h | 60 +++++++++++-------- .../paddle/v2/framework/tests/test_mul_op.py | 13 +++- 4 files changed, 66 insertions(+), 36 deletions(-) diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index 155589fadb3..c7c603929bf 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -77,6 +77,15 @@ void matmul(const framework::Tensor& matrix_a, bool trans_a, framework::Tensor* matrix_out, T beta, platform::DeviceContext* context); +// // matrix multiply with continuous memory +// template +// void matmul(const framework::Tensor& matrix_a, bool trans_a, +// const framework::Tensor& matrix_b, bool trans_b, +// framework::Tensor* matrix_out, +// platform::DeviceContext* context) { +// matmul(matrix_a, matrix_b, trans_a, trans_b, 1, matrix_out, 0, context); +// } + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index a1ca66a24d9..d77c0607a08 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -18,6 +18,8 @@ namespace paddle { namespace operators { +using framework::Tensor; + class MulOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -60,19 +62,19 @@ class MulOpGrad : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_EQ(ctx.InputSize(), 3UL, - "Input of MulOpGrad should be 3, X, Y, Out@GRAD"); - PADDLE_ENFORCE_EQ(ctx.OutputSize(), 2UL, - "Output of MulOpGrad should be 2, X@GRAD, Y@GRAD"); + // PADDLE_ENFORCE_EQ(ctx.InputSize(), 3UL, + // "Input of MulOpGrad should be 3, X, Y, Out@GRAD"); + // PADDLE_ENFORCE_EQ(ctx.OutputSize(), 2UL, + // "Output of MulOpGrad should be 2, X@GRAD, Y@GRAD"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null"); PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), "Input(Out@GRAD) should not be null"); - auto *x_grad = ctx.Output(framework::GradVarName("X")); - auto *y_grad = ctx.Output(framework::GradVarName("Y")); - auto dim0 = ctx.Input(0)->dims(); - auto dim1 = ctx.Input(1)->dims(); - auto out_dims = ctx.Input(2)->dims(); + auto *x_grad = ctx.Output(framework::GradVarName("X")); + auto *y_grad = ctx.Output(framework::GradVarName("Y")); + auto dim0 = ctx.Input(framework::GradVarName("X"))->dims(); + auto dim1 = ctx.Input(framework::GradVarName("Y"))->dims(); + auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); PADDLE_ENFORCE(dim0[0] * dim1[0] == out_dims[0], "Out@GRAD[0] must equal to X[0] * Y[0]"); PADDLE_ENFORCE(dim0[1] * dim1[1] == out_dims[1], diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index ad40e3cf115..279454c7f3f 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -31,18 +31,22 @@ template class MulKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - Eigen::array, 1> dim_pair = { - {Eigen::IndexPair(1, 0)}}; - auto* input0 = context.Input("X"); - auto* input1 = context.Input("Y"); - auto* output = context.Output("Out"); - output->mutable_data(context.GetPlace()); - auto X = EigenMatrix::From(*input0); - auto Y = EigenMatrix::From(*input1); - auto Z = EigenMatrix::From(*output); - auto& place = context.GetEigenDevice(); - - Z.device(place) = X.contract(Y, dim_pair); + // Eigen::array, 1> dim_pair = { + // {Eigen::IndexPair(1, 0)}}; + auto* X = context.Input("X"); + auto* Y = context.Input("Y"); + auto* Z = context.Output("Out"); + Z->mutable_data(context.GetPlace()); + auto* device_context = + const_cast(context.device_context_); + math::matmul(*X, false, *Y, false, 1, Z, 0, device_context); + + // auto X = EigenMatrix::From(*input0); + // auto Y = EigenMatrix::From(*input1); + // auto Z = EigenMatrix::From(*output); + // auto& place = context.GetEigenDevice(); + + // Z.device(place) = X.contract(Y, dim_pair); } }; @@ -50,27 +54,31 @@ template class MulGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* input0 = ctx.Input("X"); - auto* input1 = ctx.Input("Y"); - auto* input2 = ctx.Input(framework::GradVarName("Out")); + auto* X = ctx.Input("X"); + auto* Y = ctx.Input("Y"); + auto* dOut = ctx.Input(framework::GradVarName("Out")); - auto* output0 = ctx.Output(0); - auto* output1 = ctx.Output(1); - output0->mutable_data(ctx.GetPlace()); - output1->mutable_data(ctx.GetPlace()); + auto* dX = ctx.Output(framework::GradVarName("X")); + auto* dY = ctx.Output(framework::GradVarName("Y")); + // auto* dXdata = dX->template mutable_data(ctx.GetPlace()); + // auto* dYdata = dY->template mutable_data(ctx.GetPlace()); + auto* device_context = + const_cast(ctx.device_context_); + math::matmul(*dOut, false, *Y, true, 1, dX, 0, device_context); + math::matmul(*X, true, *dOut, false, 1, dY, 0, device_context); - auto X = EigenMatrix::From(*input0); - auto Y = EigenMatrix::From(*input1); - auto dOut = EigenMatrix::From(*input2); - auto dX = EigenMatrix::From(*output0); - auto dY = EigenMatrix::From(*output1); + // auto X = EigenMatrix::From(*input0); + // auto Y = EigenMatrix::From(*input1); + // auto dOut = EigenMatrix::From(*input2); + // auto dX = EigenMatrix::From(*output0); + // auto dY = EigenMatrix::From(*output1); // dX = Out@G * Y' // dY = X' * Out@G - auto place = ctx.GetEigenDevice(); + // auto place = ctx.GetEigenDevice(); // TODO(dzh,qijun) : need transpose feature of blas library // Eigen Tensor does not support it very well - // dX.device(place) = dOut.contract(dOut, transpose) + // dX.device(place) = matmul(input2, ) } }; diff --git a/python/paddle/v2/framework/tests/test_mul_op.py b/python/paddle/v2/framework/tests/test_mul_op.py index 126a7f39851..eef5a4f9617 100644 --- a/python/paddle/v2/framework/tests/test_mul_op.py +++ b/python/paddle/v2/framework/tests/test_mul_op.py @@ -1,6 +1,7 @@ import unittest -from op_test_util import OpTestMeta import numpy as np +from gradient_checker import GradientChecker, create_op +from op_test_util import OpTestMeta class TestMulOp(unittest.TestCase): @@ -15,6 +16,16 @@ class TestMulOp(unittest.TestCase): self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])} +class MulGradOpTest(GradientChecker): + def test_mul(self): + op = create_op("mul") + inputs = { + 'X': np.random.random((32, 84)).astype("float32"), + 'Y': np.random.random((84, 100)).astype("float32") + } + self.check_grad(op, inputs, set(["X", "Y"]), "Out") + + # TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library if __name__ == '__main__': -- GitLab