Commit d71396bf authored by F fengjiayi

Add global function `flatten_to_2d()`

Parent 69fbc542
@@ -247,12 +247,6 @@ ssize_t product(const DDim& ddim) {
return boost::apply_visitor(visitor, ddim);
}
-ssize_t product(const DDim& ddim, int begin, int end) {
-  ProductVisitor visitor;
-  DDim sliced_ddim = slice_ddim(ddim, begin, end);
-  return boost::apply_visitor(visitor, sliced_ddim);
-}
/// \cond HIDDEN
struct ArityVisitor : boost::static_visitor<int> {
@@ -289,5 +283,13 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
DDim::DDim(std::initializer_list<int> init_list) {
*this = make_ddim(init_list);
}
+DDim flatten_to_2d(const DDim& src, int num_row_dims) {
+  int rank = src.size();
+  return make_ddim(
+      {static_cast<int>(product(slice_ddim(src, 0, rank - num_row_dims))),
+       static_cast<int>(product(slice_ddim(src, rank - num_row_dims, rank)))});
+}
} // namespace framework
} // namespace paddle
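A minimal usage sketch of the new helper (the include path is an assumption; the functions used are the ones declared in the header hunk below): with `num_row_dims = 2`, a rank-4 shape {2, 3, 4, 5} is flattened to {6, 20}, i.e. the product of the leading `rank - num_row_dims` dimensions and the product of the trailing `num_row_dims` dimensions.

// Sketch only, not part of this commit. The include path
// "paddle/framework/ddim.h" is an assumption; flatten_to_2d(), make_ddim()
// and operator<<(std::ostream&, const DDim&) are declared in the header
// changed below.
#include <iostream>
#include "paddle/framework/ddim.h"

int main() {
  using namespace paddle::framework;
  DDim src = make_ddim({2, 3, 4, 5});
  // num_row_dims = 2 keeps 2 * 3 = 6 rows and 4 * 5 = 20 columns.
  DDim mat = flatten_to_2d(src, 2);
  std::cout << mat << std::endl;  // prints the flattened 2-D shape (6 and 20)
  return 0;
}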
@@ -96,8 +96,6 @@ std::vector<int> vectorize(const DDim& ddim);
ssize_t product(const DDim& ddim);
-ssize_t product(const DDim& ddim, int begin, int end);
/**
* \brief Slice a ddim
*
@@ -117,6 +115,8 @@ int arity(const DDim& ddim);
std::ostream& operator<<(std::ostream&, const DDim&);
+DDim flatten_to_2d(const DDim& src, int num_row_dims);
} // namespace framework
} // namespace paddle
@@ -68,11 +68,8 @@ struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {
int rank = tensor.dims_.size();
PADDLE_ENFORCE(num_row_dims > 0 && num_row_dims < rank,
"`num_row_dims` must be between (0, rank_of_tensor).");
-    return EigenMatrix::From(
-        tensor, make_ddim({static_cast<int>(
-                               product(tensor.dims_, 0, rank - num_row_dims)),
-                           static_cast<int>(product(
-                               tensor.dims_, rank - num_row_dims, rank))}));
+    return EigenMatrix::From(tensor,
+                             flatten_to_2d(tensor.dims(), num_row_dims));
}
};
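After this change, `EigenMatrix::Reshape` delegates the shape computation to `flatten_to_2d()`. For example, a tensor of shape {2, 3, 4} reshaped with `num_row_dims = 1` is viewed as a 6 x 4 matrix, while a `num_row_dims` outside (0, rank) still fails the `PADDLE_ENFORCE` above.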
@@ -152,11 +152,7 @@ template <typename T>
inline Tensor FlattenToMatrix(const Tensor& src, int num_row_dims) {
Tensor res;
res.ShareDataWith<T>(src);
-  DDim src_dim = src.dims();
-  int rank = src_dim.size();
-  res.Resize(make_ddim(
-      {static_cast<int>(product(src_dim, 0, rank - num_row_dims)),
-       static_cast<int>(product(src_dim, rank - num_row_dims, rank))}));
+  res.Resize(flatten_to_2d(src.dims(), num_row_dims));
return res;
}
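A hedged sketch of how the simplified `FlattenToMatrix` is used (the include path and the element type are assumptions; the function shares the data of `src` and only rewrites the shape metadata):

// Sketch only, not part of this commit. Assumes FlattenToMatrix<T>() is
// visible via paddle/framework/tensor.h.
#include "paddle/framework/tensor.h"

paddle::framework::Tensor ViewAsMatrix(const paddle::framework::Tensor& src) {
  // For a src of shape {2, 3, 4, 5} and num_row_dims = 2, the returned tensor
  // shares src's memory and reports dims {6, 20}.
  return paddle::framework::FlattenToMatrix<float>(src, 2);
}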
@@ -25,27 +25,27 @@ class MulOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
-    auto x_dim = ctx.Input<Tensor>("X")->dims();
-    auto y_dim = ctx.Input<Tensor>("Y")->dims();
+    auto x_dims = ctx.Input<Tensor>("X")->dims();
+    auto y_dims = ctx.Input<Tensor>("Y")->dims();
int x_num_row_dims = GetAttr<int>("x_num_row_dims");
int y_num_row_dims = GetAttr<int>("y_num_row_dims");
-    PADDLE_ENFORCE(x_dim.size() > x_num_row_dims,
+    PADDLE_ENFORCE(x_dims.size() > x_num_row_dims,
"The rank of input tensor X(%s) should be larger than "
"`mul_op`'s `x_num_row_dims`.",
ctx.op().Input("X"));
-    PADDLE_ENFORCE(y_dim.size() > y_num_row_dims,
+    PADDLE_ENFORCE(y_dims.size() > y_num_row_dims,
"The rank of input tensor Y(%s) should be larger than "
"`mul_op`'s `y_num_row_dims`.",
ctx.op().Input("Y"));
+    auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_row_dims);
+    auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_row_dims);
PADDLE_ENFORCE_EQ(
-        product(x_dim, x_dim.size() - x_num_row_dims, x_dim.size()),
-        product(y_dim, 0, y_dim.size() - y_num_row_dims),
+        x_mat_dims[1], y_mat_dims[0],
"First matrix's width must be equal with second matrix's height.");
ctx.Output<Tensor>("Out")->Resize(
{static_cast<int>(product(x_dim, 0, x_dim.size() - x_num_row_dims)),
static_cast<int>(
product(y_dim, y_dim.size() - y_num_row_dims, y_dim.size()))});
ctx.Output<Tensor>("Out")->Resize({x_mat_dims[0], y_mat_dims[1]});
}
};
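To make the new shape inference concrete (shapes chosen only for illustration): if X has dims {2, 3, 4, 5} with `x_num_row_dims = 2` and Y has dims {4, 5, 7} with `y_num_row_dims = 1`, then `x_mat_dims` is {6, 20} and `y_mat_dims` is {20, 7}. The `PADDLE_ENFORCE_EQ` compares the width 20 of the first matrix with the height 20 of the second, and `Out` is resized to {6, 7}.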
@@ -96,14 +96,18 @@ class MulOpGrad : public framework::OperatorWithKernel {
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    PADDLE_ENFORCE(
-        product(x_dims, 0, x_dims.size() - GetAttr<int>("x_num_row_dims")) ==
-            out_dims[0],
+    auto x_mat_dims =
+        framework::flatten_to_2d(x_dims, GetAttr<int>("x_num_row_dims"));
+    auto y_mat_dims =
+        framework::flatten_to_2d(y_dims, GetAttr<int>("y_num_row_dims"));
+    PADDLE_ENFORCE_EQ(
+        x_mat_dims[0], out_dims[0],
"The first dimension of Out@GRAD must equal to the first dimension of "
"the first operand.");
-    PADDLE_ENFORCE(
-        product(y_dims, y_dims.size() - GetAttr<int>("y_num_row_dims"),
-                y_dims.size()) == out_dims[1],
+    PADDLE_ENFORCE_EQ(
+        y_mat_dims[1], out_dims[1],
"The second dimension of Out@GRAD must equal to the second "
"dimension of the second operand.");