Commit 632b320e authored by dongzhihong

"refine argument with new style "

Parent 426d7328
......@@ -77,6 +77,15 @@ void matmul(const framework::Tensor& matrix_a, bool trans_a,
framework::Tensor* matrix_out, T beta,
platform::DeviceContext* context);
// // matrix multiply with continuous memory
// template <typename Place, typename T>
// void matmul(const framework::Tensor& matrix_a, bool trans_a,
// const framework::Tensor& matrix_b, bool trans_b,
// framework::Tensor* matrix_out,
// platform::DeviceContext* context) {
// matmul(matrix_a, matrix_b, trans_a, trans_b, 1, matrix_out, 0, context);
// }
} // namespace math
} // namespace operators
} // namespace paddle
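The matmul declared above follows the usual BLAS GEMM convention: judging from the parameter order and the call sites later in this commit, it computes matrix_out = alpha * op(matrix_a) * op(matrix_b) + beta * matrix_out, where trans_a / trans_b choose whether each operand is transposed, and the commented-out convenience overload would forward with alpha = 1 and beta = 0. A minimal NumPy sketch of that contract (illustration only, not the Paddle API; the alpha position is inferred from the forward-kernel call below):

import numpy as np

def gemm_like(a, trans_a, b, trans_b, alpha, out, beta):
    # Same argument order as the matmul declaration above (device context omitted):
    # out = alpha * op(a) * op(b) + beta * out
    op_a = a.T if trans_a else a
    op_b = b.T if trans_b else b
    out[...] = alpha * (op_a @ op_b) + beta * out
    return out

A = np.random.rand(2, 3).astype("float32")
B = np.random.rand(3, 4).astype("float32")
C = np.zeros((2, 4), dtype="float32")
gemm_like(A, False, B, False, 1.0, C, 0.0)   # alpha = 1, beta = 0: plain product
assert np.allclose(C, A @ B)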
......@@ -18,6 +18,8 @@
namespace paddle {
namespace operators {
using framework::Tensor;
class MulOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -60,19 +62,19 @@ class MulOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 3UL,
"Input of MulOpGrad should be 3, X, Y, Out@GRAD");
PADDLE_ENFORCE_EQ(ctx.OutputSize(), 2UL,
"Output of MulOpGrad should be 2, X@GRAD, Y@GRAD");
// PADDLE_ENFORCE_EQ(ctx.InputSize(), 3UL,
// "Input of MulOpGrad should be 3, X, Y, Out@GRAD");
// PADDLE_ENFORCE_EQ(ctx.OutputSize(), 2UL,
// "Output of MulOpGrad should be 2, X@GRAD, Y@GRAD");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto *x_grad = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
auto dim0 = ctx.Input<Tensor>(0)->dims();
auto dim1 = ctx.Input<Tensor>(1)->dims();
auto out_dims = ctx.Input<Tensor>(2)->dims();
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto dim0 = ctx.Input<Tensor>(framework::GradVarName("X"))->dims();
auto dim1 = ctx.Input<Tensor>(framework::GradVarName("Y"))->dims();
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
PADDLE_ENFORCE(dim0[0] * dim1[0] == out_dims[0],
"Out@GRAD[0] must equal to X[0] * Y[0]");
PADDLE_ENFORCE(dim0[1] * dim1[1] == out_dims[1],
......
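The InferShape hunk above switches from positional input indices (0, 1, 2) to name-based lookup via framework::GradVarName, which derives a gradient variable's name from the forward variable's name using the @GRAD suffix that also appears in the enforce messages (X@GRAD, Y@GRAD, Out@GRAD). A tiny Python stand-in for that naming convention, for orientation only (not the real implementation):

GRAD_SUFFIX = "@GRAD"   # suffix visible in the enforce messages above

def grad_var_name(name):
    # Illustrative counterpart of framework::GradVarName.
    return name + GRAD_SUFFIX

assert grad_var_name("Out") == "Out@GRAD"
assert grad_var_name("X") == "X@GRAD"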
......@@ -31,18 +31,22 @@ template <typename Place, typename T>
class MulKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
{Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
auto* input0 = context.Input<Tensor>("X");
auto* input1 = context.Input<Tensor>("Y");
auto* output = context.Output<Tensor>("Out");
output->mutable_data<T>(context.GetPlace());
auto X = EigenMatrix<T>::From(*input0);
auto Y = EigenMatrix<T>::From(*input1);
auto Z = EigenMatrix<T>::From(*output);
auto& place = context.GetEigenDevice<Place>();
Z.device(place) = X.contract(Y, dim_pair);
// Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
// {Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
auto* X = context.Input<Tensor>("X");
auto* Y = context.Input<Tensor>("Y");
auto* Z = context.Output<Tensor>("Out");
Z->mutable_data<T>(context.GetPlace());
auto* device_context =
const_cast<platform::DeviceContext*>(context.device_context_);
math::matmul<Place, T>(*X, false, *Y, false, 1, Z, 0, device_context);
// auto X = EigenMatrix<T>::From(*input0);
// auto Y = EigenMatrix<T>::From(*input1);
// auto Z = EigenMatrix<T>::From(*output);
// auto& place = context.GetEigenDevice<Place>();
// Z.device(place) = X.contract(Y, dim_pair);
}
};
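The rewritten forward Compute replaces the Eigen contraction with a single math::matmul call; with both transpose flags false, alpha = 1 and beta = 0, it is an ordinary matrix product. A NumPy sketch of the equivalence (illustrative shapes taken from the test below, not the kernel itself):

import numpy as np

X = np.random.rand(32, 84).astype("float32")
Y = np.random.rand(84, 100).astype("float32")

# math::matmul<Place, T>(*X, false, *Y, false, 1, Z, 0, device_context) computes
# Z = 1 * X * Y + 0 * Z, i.e. a plain matrix product with no transpose:
Z = X @ Y
assert np.allclose(Z, np.dot(X, Y))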
......@@ -50,27 +54,31 @@ template <typename Place, typename T>
class MulGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input0 = ctx.Input<Tensor>("X");
auto* input1 = ctx.Input<Tensor>("Y");
auto* input2 = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* X = ctx.Input<Tensor>("X");
auto* Y = ctx.Input<Tensor>("Y");
auto* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* output0 = ctx.Output<Tensor>(0);
auto* output1 = ctx.Output<Tensor>(1);
output0->mutable_data<T>(ctx.GetPlace());
output1->mutable_data<T>(ctx.GetPlace());
auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
// auto* dXdata = dX->template mutable_data<T>(ctx.GetPlace());
// auto* dYdata = dY->template mutable_data<T>(ctx.GetPlace());
auto* device_context =
const_cast<platform::DeviceContext*>(ctx.device_context_);
math::matmul<Place, T>(*dOut, false, *Y, true, 1, dX, 0, device_context);
math::matmul<Place, T>(*X, true, *dOut, false, 1, dY, 0, device_context);
auto X = EigenMatrix<T>::From(*input0);
auto Y = EigenMatrix<T>::From(*input1);
auto dOut = EigenMatrix<T>::From(*input2);
auto dX = EigenMatrix<T>::From(*output0);
auto dY = EigenMatrix<T>::From(*output1);
// auto X = EigenMatrix<T>::From(*input0);
// auto Y = EigenMatrix<T>::From(*input1);
// auto dOut = EigenMatrix<T>::From(*input2);
// auto dX = EigenMatrix<T>::From(*output0);
// auto dY = EigenMatrix<T>::From(*output1);
// dX = Out@G * Y'
// dY = X' * Out@G
auto place = ctx.GetEigenDevice<Place>();
// auto place = ctx.GetEigenDevice<Place>();
// TODO(dzh,qijun) : need transpose feature of blas library
// Eigen Tensor does not support it very well
// dX.device(place) = dOut.contract(dOut, transpose)
// dX.device(place) = matmul(input2, )
}
};
......
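The two matmul calls in MulGradKernel implement the standard gradients of Out = X * Y noted in the comments above: dX = Out@G * Y' (so the second operand is transposed) and dY = X' * Out@G (so the first operand is transposed). A NumPy sketch of the same computation, as a reference only:

import numpy as np

X = np.random.rand(32, 84).astype("float32")
Y = np.random.rand(84, 100).astype("float32")
dOut = np.random.rand(32, 100).astype("float32")   # same shape as Out = X @ Y

dX = dOut @ Y.T   # matmul(*dOut, false, *Y, true, 1, dX, 0, ...)  ->  dX = Out@G * Y'
dY = X.T @ dOut   # matmul(*X, true, *dOut, false, 1, dY, 0, ...)  ->  dY = X' * Out@G

assert dX.shape == X.shape and dY.shape == Y.shape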
import unittest
from op_test_util import OpTestMeta
import numpy as np
from gradient_checker import GradientChecker, create_op
from op_test_util import OpTestMeta
class TestMulOp(unittest.TestCase):
......@@ -15,6 +16,16 @@ class TestMulOp(unittest.TestCase):
self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}
class MulGradOpTest(GradientChecker):
def test_mul(self):
op = create_op("mul")
inputs = {
'X': np.random.random((32, 84)).astype("float32"),
'Y': np.random.random((84, 100)).astype("float32")
}
self.check_grad(op, inputs, set(["X", "Y"]), "Out")
# TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library
if __name__ == '__main__':
......
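The new MulGradOpTest drives GradientChecker.check_grad on X (32x84) and Y (84x100). Such checkers typically compare the operator's analytic gradients against finite-difference estimates; the sketch below shows that idea for the mul forward function using central differences on a sum-reduced loss. This is an assumption about the general technique, not the actual gradient_checker code, and all names in it are hypothetical:

import numpy as np

def numeric_grad_wrt_x(x, y, delta=1e-3):
    # Central-difference estimate of d(sum(x @ y)) / dx, one element at a time.
    grad = np.zeros_like(x)
    for i in range(x.size):
        orig = x.flat[i]
        x.flat[i] = orig + delta
        plus = (x @ y).sum()
        x.flat[i] = orig - delta
        minus = (x @ y).sum()
        x.flat[i] = orig
        grad.flat[i] = (plus - minus) / (2 * delta)
    return grad

X = np.random.random((4, 3)).astype("float64")
Y = np.random.random((3, 5)).astype("float64")
# Analytic gradient of sum(X @ Y) w.r.t. X is ones(4, 5) @ Y.T.
analytic = np.ones((4, 5)) @ Y.T
assert np.allclose(numeric_grad_wrt_x(X, Y), analytic, atol=1e-4)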