From af0264aa6b420f1401823792854c3a5c1e889cd2 Mon Sep 17 00:00:00 2001
From: fengjiayi
Date: Tue, 5 Sep 2017 10:50:58 -0700
Subject: [PATCH] Add global function `FlattenToMatrix` and add `axis` for
 MulOp

---
 paddle/operators/mul_op.cc | 25 ++++++++++---------
 paddle/operators/mul_op.h  | 50 +++++++++++++++++++++++++++++---------
 2 files changed, 53 insertions(+), 22 deletions(-)

diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc
index d301a8619f9..be1782bb6b9 100644
--- a/paddle/operators/mul_op.cc
+++ b/paddle/operators/mul_op.cc
@@ -27,24 +27,25 @@ class MulOp : public framework::OperatorWithKernel {
   void InferShape(const framework::InferShapeContext &ctx) const override {
     auto x_dim = ctx.Input<Tensor>("X")->dims();
     auto y_dim = ctx.Input<Tensor>("Y")->dims();
-    int x_num_row_dims = GetAttr<int>("X_num_raw_dims");
-    int y_num_row_dims = GetAttr<int>("Y_num_raw_dims");
+    int x_num_row_dims = GetAttr<int>("x_num_row_dims");
+    int y_num_row_dims = GetAttr<int>("y_num_row_dims");
     PADDLE_ENFORCE(x_dim.size() > x_num_row_dims,
                    "The rank of input tensor X(%s) should be larger than "
-                   "`mul_op`'s `X_num_raw_dims`.",
+                   "`mul_op`'s `x_num_row_dims`.",
                    ctx.op().Input("X"));
     PADDLE_ENFORCE(y_dim.size() > y_num_row_dims,
                    "The rank of input tensor Y(%s) should be larger than "
-                   "`mul_op`'s `Y_num_raw_dims`.",
+                   "`mul_op`'s `y_num_row_dims`.",
                    ctx.op().Input("Y"));
     PADDLE_ENFORCE_EQ(
         product(x_dim, x_dim.size() - x_num_row_dims, x_dim.size()),
         product(y_dim, 0, y_dim.size() - y_num_row_dims),
         "First matrix's width must be equal with second matrix's height.");
     ctx.Output<Tensor>("Out")->Resize(
-        {product(x_dim, 0, x_dim.size() - x_num_row_dims),
-         product(y_dim, y_dim.size() - y_num_row_dims, y_dim.size())});
+        {static_cast<int>(product(x_dim, 0, x_dim.size() - x_num_row_dims)),
+         static_cast<int>(
+             product(y_dim, y_dim.size() - y_num_row_dims, y_dim.size()))});
   }
 };
@@ -96,13 +97,15 @@ class MulOpGrad : public framework::OperatorWithKernel {
     auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
     PADDLE_ENFORCE(
-        product(x_dim, 0, x_dims.size() - x_num_row_dims) == out_dims[0],
+        product(x_dims, 0, x_dims.size() - GetAttr<int>("x_num_row_dims")) ==
+            out_dims[0],
         "The first dimension of Out@GRAD must equal to the first dimension of "
         "the first operand.");
-    PADDLE_ENFORCE(product(y_dim, y_dims.size() - y_num_row_dims,
-                           y_dims.size()) == out_dims[1],
-                   "The second dimension of Out@GRAD must equal to the second "
-                   "dimension of the second operand.");
+    PADDLE_ENFORCE(
+        product(y_dims, y_dims.size() - GetAttr<int>("y_num_row_dims"),
+                y_dims.size()) == out_dims[1],
+        "The second dimension of Out@GRAD must equal to the second "
+        "dimension of the second operand.");
     x_grad->Resize(x_dims);
     y_grad->Resize(y_dims);
diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h
index 8facc028144..73a53798e0c 100644
--- a/paddle/operators/mul_op.h
+++ b/paddle/operators/mul_op.h
@@ -31,13 +31,25 @@ template <typename Place, typename T>
 class MulKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* X = context.Input<Tensor>("X");
-    auto* Y = context.Input<Tensor>("Y");
-    auto* Z = context.Output<Tensor>("Out");
+    const Tensor* X = context.Input<Tensor>("X");
+    const Tensor* Y = context.Input<Tensor>("Y");
+    Tensor* Z = context.Output<Tensor>("Out");
+    const Tensor X_matrix =
+        X->dims().size() > 2
+            ? framework::FlattenToMatrix<T>(
+                  *X, context.template GetAttr<int>("x_num_row_dims"))
+            : *X;
+    const Tensor Y_matrix =
+        Y->dims().size() > 2
+            ? framework::FlattenToMatrix<T>(
+                  *Y, context.template GetAttr<int>("y_num_row_dims"))
+            : *Y;
+
     Z->mutable_data<T>(context.GetPlace());
     auto* device_context =
         const_cast<platform::DeviceContext*>(context.device_context_);
-    math::matmul<Place, T>(*X, false, *Y, false, 1, Z, 0, device_context);
+    math::matmul<Place, T>(X_matrix, false, Y_matrix, false, 1, Z, 0,
+                           device_context);
   }
 };
@@ -45,20 +57,36 @@ template <typename Place, typename T>
 class MulGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* X = ctx.Input<Tensor>("X");
-    auto* Y = ctx.Input<Tensor>("Y");
-    auto* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    int x_num_row_dims = ctx.template GetAttr<int>("x_num_row_dims");
+    int y_num_row_dims = ctx.template GetAttr<int>("y_num_row_dims");
+    const Tensor* X = ctx.Input<Tensor>("X");
+    const Tensor* Y = ctx.Input<Tensor>("Y");
+    const Tensor X_matrix =
+        X->dims().size() > 2 ? framework::FlattenToMatrix<T>(*X, x_num_row_dims)
+                             : *X;
+    const Tensor Y_matrix =
+        Y->dims().size() > 2 ? framework::FlattenToMatrix<T>(*Y, y_num_row_dims)
+                             : *Y;
+    const Tensor* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
 
-    auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
+    Tensor* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
+    Tensor* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
     dX->mutable_data<T>(ctx.GetPlace());
     dY->mutable_data<T>(ctx.GetPlace());
+    Tensor dX_matrix = dX->dims().size() > 2
+                           ? framework::FlattenToMatrix<T>(*dX, x_num_row_dims)
+                           : *dX;
+    Tensor dY_matrix = dY->dims().size() > 2
+                           ? framework::FlattenToMatrix<T>(*dY, y_num_row_dims)
+                           : *dY;
     auto* device_context =
         const_cast<platform::DeviceContext*>(ctx.device_context_);
     // dX = dOut * Y'. dX: M x K, dOut : M x N, Y : K x N
-    math::matmul<Place, T>(*dOut, false, *Y, true, 1, dX, 0, device_context);
+    math::matmul<Place, T>(*dOut, false, Y_matrix, true, 1, &dX_matrix, 0,
+                           device_context);
     // dY = X' * dOut. dY: K x N, dOut : M x N, X : M x K
-    math::matmul<Place, T>(*X, true, *dOut, false, 1, dY, 0, device_context);
+    math::matmul<Place, T>(X_matrix, true, *dOut, false, 1, &dY_matrix, 0,
+                           device_context);
   }
 };
-- 
GitLab
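
Note: the diffstat covers only mul_op.cc and mul_op.h, so the `FlattenToMatrix`
helper named in the subject is defined outside this patch and never shown here.
Its required behavior is clear from the call sites and from MulOp::InferShape:
given a rank-N tensor and num_row_dims, return a 2-D view whose row count is
the product of the leading N - num_row_dims dimensions and whose column count
is the product of the trailing num_row_dims dimensions, without copying data.
Below is a minimal sketch under those assumptions; `ShareDataWith<T>` is
assumed from the Tensor API of this era, and the body is an illustration, not
the patch's actual definition:

    // Hypothetical sketch (not the patch's actual definition): flatten a
    // rank-N tensor into a 2-D matrix that aliases src's memory.
    //   rows = product of the leading (rank - num_row_dims) dims
    //   cols = product of the trailing num_row_dims dims
    // Uses the same product(ddim, begin, end) helper the patch calls above.
    template <typename T>
    Tensor FlattenToMatrix(const Tensor& src, int num_row_dims) {
      Tensor res;
      res.ShareDataWith<T>(src);  // share the underlying buffer; no copy
      auto dims = src.dims();
      int rank = dims.size();
      res.Resize({static_cast<int>(product(dims, 0, rank - num_row_dims)),
                  static_cast<int>(product(dims, rank - num_row_dims, rank))});
      return res;
    }

For example, with x_num_row_dims = 2, an input X of shape [2, 3, 4, 5] is
multiplied as a [6, 20] matrix (rows 2*3, columns 4*5), which is exactly the
shape arithmetic MulOp::InferShape performs with product().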