Commit 632b320e authored by dongzhihong

"refine argument with new style "

Parent 426d7328
@@ -77,6 +77,15 @@ void matmul(const framework::Tensor& matrix_a, bool trans_a,
             framework::Tensor* matrix_out, T beta,
             platform::DeviceContext* context);
+
+// // matrix multiply with continuous memory
+// template <typename Place, typename T>
+// void matmul(const framework::Tensor& matrix_a, bool trans_a,
+//             const framework::Tensor& matrix_b, bool trans_b,
+//             framework::Tensor* matrix_out,
+//             platform::DeviceContext* context) {
+//   matmul(matrix_a, matrix_b, trans_a, trans_b, 1, matrix_out, 0, context);
+// }
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
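The `alpha`/`beta` arguments in the `matmul` declaration above follow the usual BLAS GEMM convention; read that way (this is an assumption about the interface, not spelled out in the diff):

$$\text{matrix\_out} \leftarrow \alpha \cdot \mathrm{op}(\text{matrix\_a}) \cdot \mathrm{op}(\text{matrix\_b}) + \beta \cdot \text{matrix\_out}, \qquad \mathrm{op}(M) = M^{\top} \text{ if the corresponding trans flag is set, else } M.$$

With $\alpha = 1$ and $\beta = 0$ this reduces to a plain matrix product, which is exactly how the kernels below call it.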
@@ -18,6 +18,8 @@
 namespace paddle {
 namespace operators {
+using framework::Tensor;
+
 class MulOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -60,19 +62,19 @@ class MulOpGrad : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx.InputSize(), 3UL,
-                      "Input of MulOpGrad should be 3, X, Y, Out@GRAD");
-    PADDLE_ENFORCE_EQ(ctx.OutputSize(), 2UL,
-                      "Output of MulOpGrad should be 2, X@GRAD, Y@GRAD");
+    // PADDLE_ENFORCE_EQ(ctx.InputSize(), 3UL,
+    //                   "Input of MulOpGrad should be 3, X, Y, Out@GRAD");
+    // PADDLE_ENFORCE_EQ(ctx.OutputSize(), 2UL,
+    //                   "Output of MulOpGrad should be 2, X@GRAD, Y@GRAD");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                             "Input(Out@GRAD) should not be null");
-    auto *x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
-    auto *y_grad = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
-    auto dim0 = ctx.Input<Tensor>(0)->dims();
-    auto dim1 = ctx.Input<Tensor>(1)->dims();
-    auto out_dims = ctx.Input<Tensor>(2)->dims();
+    auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    auto dim0 = ctx.Input<Tensor>(framework::GradVarName("X"))->dims();
+    auto dim1 = ctx.Input<Tensor>(framework::GradVarName("Y"))->dims();
+    auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
     PADDLE_ENFORCE(dim0[0] * dim1[0] == out_dims[0],
                    "Out@GRAD[0] must equal to X[0] * Y[0]");
     PADDLE_ENFORCE(dim0[1] * dim1[1] == out_dims[1],
...
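For reference, the dimension checks above concern a plain two-dimensional product; in general,

$$X \in \mathbb{R}^{M \times K},\; Y \in \mathbb{R}^{K \times N} \;\Longrightarrow\; Out = XY \in \mathbb{R}^{M \times N},$$

and the gradients with respect to $X$ and $Y$ carry the same shapes as $X$ and $Y$ respectively (a standard fact, independent of how this operator enumerates its inputs).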
@@ -31,18 +31,22 @@ template <typename Place, typename T>
 class MulKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
-        {Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
-    auto* input0 = context.Input<Tensor>("X");
-    auto* input1 = context.Input<Tensor>("Y");
-    auto* output = context.Output<Tensor>("Out");
-    output->mutable_data<T>(context.GetPlace());
-    auto X = EigenMatrix<T>::From(*input0);
-    auto Y = EigenMatrix<T>::From(*input1);
-    auto Z = EigenMatrix<T>::From(*output);
-    auto& place = context.GetEigenDevice<Place>();
-
-    Z.device(place) = X.contract(Y, dim_pair);
+    // Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
+    //     {Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
+    auto* X = context.Input<Tensor>("X");
+    auto* Y = context.Input<Tensor>("Y");
+    auto* Z = context.Output<Tensor>("Out");
+    Z->mutable_data<T>(context.GetPlace());
+    auto* device_context =
+        const_cast<platform::DeviceContext*>(context.device_context_);
+    math::matmul<Place, T>(*X, false, *Y, false, 1, Z, 0, device_context);
+
+    // auto X = EigenMatrix<T>::From(*input0);
+    // auto Y = EigenMatrix<T>::From(*input1);
+    // auto Z = EigenMatrix<T>::From(*output);
+    // auto& place = context.GetEigenDevice<Place>();
+    // Z.device(place) = X.contract(Y, dim_pair);
   }
 };
@@ -50,27 +54,31 @@ template <typename Place, typename T>
 class MulGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* input0 = ctx.Input<Tensor>("X");
-    auto* input1 = ctx.Input<Tensor>("Y");
-    auto* input2 = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto* output0 = ctx.Output<Tensor>(0);
-    auto* output1 = ctx.Output<Tensor>(1);
-    output0->mutable_data<T>(ctx.GetPlace());
-    output1->mutable_data<T>(ctx.GetPlace());
-
-    auto X = EigenMatrix<T>::From(*input0);
-    auto Y = EigenMatrix<T>::From(*input1);
-    auto dOut = EigenMatrix<T>::From(*input2);
-    auto dX = EigenMatrix<T>::From(*output0);
-    auto dY = EigenMatrix<T>::From(*output1);
+    auto* X = ctx.Input<Tensor>("X");
+    auto* Y = ctx.Input<Tensor>("Y");
+    auto* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    // auto* dXdata = dX->template mutable_data<T>(ctx.GetPlace());
+    // auto* dYdata = dY->template mutable_data<T>(ctx.GetPlace());
+    auto* device_context =
+        const_cast<platform::DeviceContext*>(ctx.device_context_);
+    math::matmul<Place, T>(*dOut, false, *Y, true, 1, dX, 0, device_context);
+    math::matmul<Place, T>(*X, true, *dOut, false, 1, dY, 0, device_context);
+
+    // auto X = EigenMatrix<T>::From(*input0);
+    // auto Y = EigenMatrix<T>::From(*input1);
+    // auto dOut = EigenMatrix<T>::From(*input2);
+    // auto dX = EigenMatrix<T>::From(*output0);
+    // auto dY = EigenMatrix<T>::From(*output1);
     // dX = Out@G * Y'
     // dY = X' * Out@G
-    auto place = ctx.GetEigenDevice<Place>();
+    // auto place = ctx.GetEigenDevice<Place>();
     // TODO(dzh,qijun) : need transpose feature of blas library
     // Eigen Tensor does not support it very well
-    // dX.device(place) = dOut.contract(dOut, transpose)
+    // dX.device(place) = matmul(input2, )
   }
 };
...
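The two `math::matmul` calls in `MulGradKernel` implement the gradient formulas noted in the comments (`dX = Out@G * Y'`, `dY = X' * Out@G`); written out for $Out = XY$ with loss $L$:

$$\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Out}\, Y^{\top}, \qquad \frac{\partial L}{\partial Y} = X^{\top}\, \frac{\partial L}{\partial Out}.$$

The transposes are requested through the `trans_a`/`trans_b` flags rather than materialized as separate tensors, which is what the TODO about a transpose feature of the BLAS library refers to.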
 import unittest
-from op_test_util import OpTestMeta
 import numpy as np
+from gradient_checker import GradientChecker, create_op
+from op_test_util import OpTestMeta


 class TestMulOp(unittest.TestCase):
@@ -15,6 +16,16 @@ class TestMulOp(unittest.TestCase):
         self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}

+
+class MulGradOpTest(GradientChecker):
+    def test_mul(self):
+        op = create_op("mul")
+        inputs = {
+            'X': np.random.random((32, 84)).astype("float32"),
+            'Y': np.random.random((84, 100)).astype("float32")
+        }
+        self.check_grad(op, inputs, set(["X", "Y"]), "Out")
+
 # TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library

 if __name__ == '__main__':
...
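The new `MulGradOpTest` relies on `GradientChecker.check_grad`, whose internals are not shown in this diff; a checker of this kind typically compares the operator's analytic gradients against a central-difference estimate,

$$\frac{\partial f}{\partial x_i} \approx \frac{f(x + h\,e_i) - f(x - h\,e_i)}{2h},$$

evaluated for a small step $h$ along each unit direction $e_i$ of the inputs named in the check (this description of the checking technique is an assumption, not something established by the diff).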