提交 af0264aa 编写于 作者: F fengjiayi

Add global function `FalttenToMatrix` and add `axis` for MulOp

上级 86655cb9
...@@ -27,24 +27,25 @@ class MulOp : public framework::OperatorWithKernel { ...@@ -27,24 +27,25 @@ class MulOp : public framework::OperatorWithKernel {
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
auto x_dim = ctx.Input<Tensor>("X")->dims(); auto x_dim = ctx.Input<Tensor>("X")->dims();
auto y_dim = ctx.Input<Tensor>("Y")->dims(); auto y_dim = ctx.Input<Tensor>("Y")->dims();
int x_num_row_dims = GetAttr<int>("X_num_raw_dims"); int x_num_row_dims = GetAttr<int>("x_num_row_dims");
int y_num_row_dims = GetAttr<int>("Y_num_raw_dims"); int y_num_row_dims = GetAttr<int>("y_num_row_dims");
PADDLE_ENFORCE(x_dim.size() > x_num_row_dims, PADDLE_ENFORCE(x_dim.size() > x_num_row_dims,
"The rank of input tensor X(%s) should be larger than " "The rank of input tensor X(%s) should be larger than "
"`mul_op`'s `X_num_raw_dims`.", "`mul_op`'s `x_num_row_dims`.",
ctx.op().Input("X")); ctx.op().Input("X"));
PADDLE_ENFORCE(y_dim.size() > y_num_row_dims, PADDLE_ENFORCE(y_dim.size() > y_num_row_dims,
"The rank of input tensor Y(%s) should be larger than " "The rank of input tensor Y(%s) should be larger than "
"`mul_op`'s `Y_num_raw_dims`.", "`mul_op`'s `y_num_row_dims`.",
ctx.op().Input("Y")); ctx.op().Input("Y"));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
product(x_dim, x_dim.size() - x_num_row_dims, x_dim.size()), product(x_dim, x_dim.size() - x_num_row_dims, x_dim.size()),
product(y_dim, 0, y_dim.size() - y_num_row_dims), product(y_dim, 0, y_dim.size() - y_num_row_dims),
"First matrix's width must be equal with second matrix's height."); "First matrix's width must be equal with second matrix's height.");
ctx.Output<Tensor>("Out")->Resize( ctx.Output<Tensor>("Out")->Resize(
{product(x_dim, 0, x_dim.size() - x_num_row_dims), {static_cast<int>(product(x_dim, 0, x_dim.size() - x_num_row_dims)),
product(y_dim, y_dim.size() - y_num_row_dims, y_dim.size())}); static_cast<int>(
product(y_dim, y_dim.size() - y_num_row_dims, y_dim.size()))});
} }
}; };
...@@ -96,10 +97,12 @@ class MulOpGrad : public framework::OperatorWithKernel { ...@@ -96,10 +97,12 @@ class MulOpGrad : public framework::OperatorWithKernel {
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X")); auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y")); auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
PADDLE_ENFORCE( PADDLE_ENFORCE(
product(x_dim, 0, x_dims.size() - x_num_row_dims) == out_dims[0], product(x_dims, 0, x_dims.size() - GetAttr<int>("x_num_row_dims")) ==
out_dims[0],
"The first dimension of Out@GRAD must equal to the first dimension of " "The first dimension of Out@GRAD must equal to the first dimension of "
"the first operand."); "the first operand.");
PADDLE_ENFORCE(product(y_dim, y_dims.size() - y_num_row_dims, PADDLE_ENFORCE(
product(y_dims, y_dims.size() - GetAttr<int>("y_num_row_dims"),
y_dims.size()) == out_dims[1], y_dims.size()) == out_dims[1],
"The second dimension of Out@GRAD must equal to the second " "The second dimension of Out@GRAD must equal to the second "
"dimension of the second operand."); "dimension of the second operand.");
......
...@@ -31,13 +31,25 @@ template <typename Place, typename T> ...@@ -31,13 +31,25 @@ template <typename Place, typename T>
class MulKernel : public framework::OpKernel { class MulKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<Tensor>("X"); const Tensor* X = context.Input<Tensor>("X");
auto* Y = context.Input<Tensor>("Y"); const Tensor* Y = context.Input<Tensor>("Y");
auto* Z = context.Output<Tensor>("Out"); Tensor* Z = context.Output<Tensor>("Out");
const Tensor X_matrix =
X->dims().size() > 2
? framework::FlattenToMatrix<T>(
*X, context.template GetAttr<int>("x_num_row_dims"))
: *X;
const Tensor Y_matrix =
Y->dims().size() > 2
? framework::FlattenToMatrix<T>(
*Y, context.template GetAttr<int>("y_num_row_dims"))
: *Y;
Z->mutable_data<T>(context.GetPlace()); Z->mutable_data<T>(context.GetPlace());
auto* device_context = auto* device_context =
const_cast<platform::DeviceContext*>(context.device_context_); const_cast<platform::DeviceContext*>(context.device_context_);
math::matmul<Place, T>(*X, false, *Y, false, 1, Z, 0, device_context); math::matmul<Place, T>(X_matrix, false, Y_matrix, false, 1, Z, 0,
device_context);
} }
}; };
...@@ -45,20 +57,36 @@ template <typename Place, typename T> ...@@ -45,20 +57,36 @@ template <typename Place, typename T>
class MulGradKernel : public framework::OpKernel { class MulGradKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* X = ctx.Input<Tensor>("X"); int x_num_row_dims = ctx.template GetAttr<int>("x_num_row_dims");
auto* Y = ctx.Input<Tensor>("Y"); int y_num_row_dims = ctx.template GetAttr<int>("y_num_row_dims");
auto* dOut = ctx.Input<Tensor>(framework::GradVarName("Out")); const Tensor* X = ctx.Input<Tensor>("X");
const Tensor* Y = ctx.Input<Tensor>("Y");
const Tensor X_matrix =
X->dims().size() > 2 ? framework::FlattenToMatrix<T>(*X, x_num_row_dims)
: *X;
const Tensor Y_matrix =
Y->dims().size() > 2 ? framework::FlattenToMatrix<T>(*Y, y_num_row_dims)
: *Y;
const Tensor* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dX = ctx.Output<Tensor>(framework::GradVarName("X")); Tensor* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dY = ctx.Output<Tensor>(framework::GradVarName("Y")); Tensor* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
dX->mutable_data<T>(ctx.GetPlace()); dX->mutable_data<T>(ctx.GetPlace());
dY->mutable_data<T>(ctx.GetPlace()); dY->mutable_data<T>(ctx.GetPlace());
Tensor dX_matrix = dX->dims().size() > 2
? framework::FlattenToMatrix<T>(*dX, x_num_row_dims)
: *dX;
Tensor dY_matrix = dY->dims().size() > 2
? framework::FlattenToMatrix<T>(*dY, y_num_row_dims)
: *dY;
auto* device_context = auto* device_context =
const_cast<platform::DeviceContext*>(ctx.device_context_); const_cast<platform::DeviceContext*>(ctx.device_context_);
// dX = dOut * Y'. dX: M x K, dOut : M x N, Y : K x N // dX = dOut * Y'. dX: M x K, dOut : M x N, Y : K x N
math::matmul<Place, T>(*dOut, false, *Y, true, 1, dX, 0, device_context); math::matmul<Place, T>(*dOut, false, Y_matrix, true, 1, &dX_matrix, 0,
device_context);
// dY = X' * dOut. dY: K x N, dOut : M x N, X : M x K // dY = X' * dOut. dY: K x N, dOut : M x N, X : M x K
math::matmul<Place, T>(*X, true, *dOut, false, 1, dY, 0, device_context); math::matmul<Place, T>(X_matrix, true, *dOut, false, 1, &dY_matrix, 0,
device_context);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册