提交 af0264aa 编写于 作者: F fengjiayi

Add global function `FalttenToMatrix` and add `axis` for MulOp

上级 86655cb9
......@@ -27,24 +27,25 @@ class MulOp : public framework::OperatorWithKernel {
void InferShape(const framework::InferShapeContext &ctx) const override {
auto x_dim = ctx.Input<Tensor>("X")->dims();
auto y_dim = ctx.Input<Tensor>("Y")->dims();
int x_num_row_dims = GetAttr<int>("X_num_raw_dims");
int y_num_row_dims = GetAttr<int>("Y_num_raw_dims");
int x_num_row_dims = GetAttr<int>("x_num_row_dims");
int y_num_row_dims = GetAttr<int>("y_num_row_dims");
PADDLE_ENFORCE(x_dim.size() > x_num_row_dims,
"The rank of input tensor X(%s) should be larger than "
"`mul_op`'s `X_num_raw_dims`.",
"`mul_op`'s `x_num_row_dims`.",
ctx.op().Input("X"));
PADDLE_ENFORCE(y_dim.size() > y_num_row_dims,
"The rank of input tensor Y(%s) should be larger than "
"`mul_op`'s `Y_num_raw_dims`.",
"`mul_op`'s `y_num_row_dims`.",
ctx.op().Input("Y"));
PADDLE_ENFORCE_EQ(
product(x_dim, x_dim.size() - x_num_row_dims, x_dim.size()),
product(y_dim, 0, y_dim.size() - y_num_row_dims),
"First matrix's width must be equal with second matrix's height.");
ctx.Output<Tensor>("Out")->Resize(
{product(x_dim, 0, x_dim.size() - x_num_row_dims),
product(y_dim, y_dim.size() - y_num_row_dims, y_dim.size())});
{static_cast<int>(product(x_dim, 0, x_dim.size() - x_num_row_dims)),
static_cast<int>(
product(y_dim, y_dim.size() - y_num_row_dims, y_dim.size()))});
}
};
......@@ -96,13 +97,15 @@ class MulOpGrad : public framework::OperatorWithKernel {
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
PADDLE_ENFORCE(
product(x_dim, 0, x_dims.size() - x_num_row_dims) == out_dims[0],
product(x_dims, 0, x_dims.size() - GetAttr<int>("x_num_row_dims")) ==
out_dims[0],
"The first dimension of Out@GRAD must equal to the first dimension of "
"the first operand.");
PADDLE_ENFORCE(product(y_dim, y_dims.size() - y_num_row_dims,
y_dims.size()) == out_dims[1],
"The second dimension of Out@GRAD must equal to the second "
"dimension of the second operand.");
PADDLE_ENFORCE(
product(y_dims, y_dims.size() - GetAttr<int>("y_num_row_dims"),
y_dims.size()) == out_dims[1],
"The second dimension of Out@GRAD must equal to the second "
"dimension of the second operand.");
x_grad->Resize(x_dims);
y_grad->Resize(y_dims);
......
......@@ -31,13 +31,25 @@ template <typename Place, typename T>
class MulKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<Tensor>("X");
auto* Y = context.Input<Tensor>("Y");
auto* Z = context.Output<Tensor>("Out");
const Tensor* X = context.Input<Tensor>("X");
const Tensor* Y = context.Input<Tensor>("Y");
Tensor* Z = context.Output<Tensor>("Out");
const Tensor X_matrix =
X->dims().size() > 2
? framework::FlattenToMatrix<T>(
*X, context.template GetAttr<int>("x_num_row_dims"))
: *X;
const Tensor Y_matrix =
Y->dims().size() > 2
? framework::FlattenToMatrix<T>(
*Y, context.template GetAttr<int>("y_num_row_dims"))
: *Y;
Z->mutable_data<T>(context.GetPlace());
auto* device_context =
const_cast<platform::DeviceContext*>(context.device_context_);
math::matmul<Place, T>(*X, false, *Y, false, 1, Z, 0, device_context);
math::matmul<Place, T>(X_matrix, false, Y_matrix, false, 1, Z, 0,
device_context);
}
};
......@@ -45,20 +57,36 @@ template <typename Place, typename T>
class MulGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* X = ctx.Input<Tensor>("X");
auto* Y = ctx.Input<Tensor>("Y");
auto* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
int x_num_row_dims = ctx.template GetAttr<int>("x_num_row_dims");
int y_num_row_dims = ctx.template GetAttr<int>("y_num_row_dims");
const Tensor* X = ctx.Input<Tensor>("X");
const Tensor* Y = ctx.Input<Tensor>("Y");
const Tensor X_matrix =
X->dims().size() > 2 ? framework::FlattenToMatrix<T>(*X, x_num_row_dims)
: *X;
const Tensor Y_matrix =
Y->dims().size() > 2 ? framework::FlattenToMatrix<T>(*Y, y_num_row_dims)
: *Y;
const Tensor* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
Tensor* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
Tensor* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
dX->mutable_data<T>(ctx.GetPlace());
dY->mutable_data<T>(ctx.GetPlace());
Tensor dX_matrix = dX->dims().size() > 2
? framework::FlattenToMatrix<T>(*dX, x_num_row_dims)
: *dX;
Tensor dY_matrix = dY->dims().size() > 2
? framework::FlattenToMatrix<T>(*dY, y_num_row_dims)
: *dY;
auto* device_context =
const_cast<platform::DeviceContext*>(ctx.device_context_);
// dX = dOut * Y'. dX: M x K, dOut : M x N, Y : K x N
math::matmul<Place, T>(*dOut, false, *Y, true, 1, dX, 0, device_context);
math::matmul<Place, T>(*dOut, false, Y_matrix, true, 1, &dX_matrix, 0,
device_context);
// dY = X' * dOut. dY: K x N, dOut : M x N, X : M x K
math::matmul<Place, T>(*X, true, *dOut, false, 1, dY, 0, device_context);
math::matmul<Place, T>(X_matrix, true, *dOut, false, 1, &dY_matrix, 0,
device_context);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册