提交 d71396bf 编写于 作者: F fengjiayi

Add global function `flatten_to_2d()`

上级 69fbc542
...@@ -247,12 +247,6 @@ ssize_t product(const DDim& ddim) { ...@@ -247,12 +247,6 @@ ssize_t product(const DDim& ddim) {
return boost::apply_visitor(visitor, ddim); return boost::apply_visitor(visitor, ddim);
} }
// Returns the product of the dimensions of `ddim` in the half-open
// index range [begin, end).
ssize_t product(const DDim& ddim, int begin, int end) {
  DDim sub_dims = slice_ddim(ddim, begin, end);
  ProductVisitor product_visitor;
  return boost::apply_visitor(product_visitor, sub_dims);
}
/// \cond HIDDEN /// \cond HIDDEN
struct ArityVisitor : boost::static_visitor<int> { struct ArityVisitor : boost::static_visitor<int> {
...@@ -289,5 +283,13 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) { ...@@ -289,5 +283,13 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
DDim::DDim(std::initializer_list<int> init_list) { DDim::DDim(std::initializer_list<int> init_list) {
*this = make_ddim(init_list); *this = make_ddim(init_list);
} }
// Collapses `src` into a 2-D shape: the first dimension is the product of
// the leading (rank - num_row_dims) extents, the second is the product of
// the trailing num_row_dims extents.
DDim flatten_to_2d(const DDim& src, int num_row_dims) {
  const int rank = src.size();
  const int split = rank - num_row_dims;
  const int leading = static_cast<int>(product(slice_ddim(src, 0, split)));
  const int trailing = static_cast<int>(product(slice_ddim(src, split, rank)));
  return make_ddim({leading, trailing});
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -96,8 +96,6 @@ std::vector<int> vectorize(const DDim& ddim); ...@@ -96,8 +96,6 @@ std::vector<int> vectorize(const DDim& ddim);
ssize_t product(const DDim& ddim); ssize_t product(const DDim& ddim);
ssize_t product(const DDim& ddim, int begin, int end);
/** /**
* \brief Slice a ddim * \brief Slice a ddim
* *
...@@ -117,6 +115,8 @@ int arity(const DDim& ddim); ...@@ -117,6 +115,8 @@ int arity(const DDim& ddim);
std::ostream& operator<<(std::ostream&, const DDim&); std::ostream& operator<<(std::ostream&, const DDim&);
DDim flatten_to_2d(const DDim& src, int num_row_dims);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -68,11 +68,8 @@ struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> { ...@@ -68,11 +68,8 @@ struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {
int rank = tensor.dims_.size(); int rank = tensor.dims_.size();
PADDLE_ENFORCE(num_row_dims > 0 && num_row_dims < rank, PADDLE_ENFORCE(num_row_dims > 0 && num_row_dims < rank,
"`num_row_dims` must be between (0, rank_of_tensor)."); "`num_row_dims` must be between (0, rank_of_tensor).");
return EigenMatrix::From( return EigenMatrix::From(tensor,
tensor, make_ddim({static_cast<int>( flatten_to_2d(tensor.dims(), num_row_dims));
product(tensor.dims_, 0, rank - num_row_dims)),
static_cast<int>(product(
tensor.dims_, rank - num_row_dims, rank))}));
} }
}; };
......
...@@ -152,11 +152,7 @@ template <typename T> ...@@ -152,11 +152,7 @@ template <typename T>
inline Tensor FlattenToMatrix(const Tensor& src, int num_row_dims) { inline Tensor FlattenToMatrix(const Tensor& src, int num_row_dims) {
Tensor res; Tensor res;
res.ShareDataWith<T>(src); res.ShareDataWith<T>(src);
DDim src_dim = src.dims(); res.Resize(flatten_to_2d(src.dims(), num_row_dims));
int rank = src_dim.size();
res.Resize(make_ddim(
{static_cast<int>(product(src_dim, 0, rank - num_row_dims)),
static_cast<int>(product(src_dim, rank - num_row_dims, rank))}));
return res; return res;
} }
......
...@@ -25,27 +25,27 @@ class MulOp : public framework::OperatorWithKernel { ...@@ -25,27 +25,27 @@ class MulOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
auto x_dim = ctx.Input<Tensor>("X")->dims(); auto x_dims = ctx.Input<Tensor>("X")->dims();
auto y_dim = ctx.Input<Tensor>("Y")->dims(); auto y_dims = ctx.Input<Tensor>("Y")->dims();
int x_num_row_dims = GetAttr<int>("x_num_row_dims"); int x_num_row_dims = GetAttr<int>("x_num_row_dims");
int y_num_row_dims = GetAttr<int>("y_num_row_dims"); int y_num_row_dims = GetAttr<int>("y_num_row_dims");
PADDLE_ENFORCE(x_dim.size() > x_num_row_dims, PADDLE_ENFORCE(x_dims.size() > x_num_row_dims,
"The rank of input tensor X(%s) should be larger than " "The rank of input tensor X(%s) should be larger than "
"`mul_op`'s `x_num_row_dims`.", "`mul_op`'s `x_num_row_dims`.",
ctx.op().Input("X")); ctx.op().Input("X"));
PADDLE_ENFORCE(y_dim.size() > y_num_row_dims, PADDLE_ENFORCE(y_dims.size() > y_num_row_dims,
"The rank of input tensor Y(%s) should be larger than " "The rank of input tensor Y(%s) should be larger than "
"`mul_op`'s `y_num_row_dims`.", "`mul_op`'s `y_num_row_dims`.",
ctx.op().Input("Y")); ctx.op().Input("Y"));
auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_row_dims);
auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_row_dims);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
product(x_dim, x_dim.size() - x_num_row_dims, x_dim.size()), x_mat_dims[1], y_mat_dims[0],
product(y_dim, 0, y_dim.size() - y_num_row_dims),
"First matrix's width must be equal with second matrix's height."); "First matrix's width must be equal with second matrix's height.");
ctx.Output<Tensor>("Out")->Resize( ctx.Output<Tensor>("Out")->Resize({x_mat_dims[0], y_mat_dims[1]});
{static_cast<int>(product(x_dim, 0, x_dim.size() - x_num_row_dims)),
static_cast<int>(
product(y_dim, y_dim.size() - y_num_row_dims, y_dim.size()))});
} }
}; };
...@@ -96,14 +96,18 @@ class MulOpGrad : public framework::OperatorWithKernel { ...@@ -96,14 +96,18 @@ class MulOpGrad : public framework::OperatorWithKernel {
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims(); auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X")); auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y")); auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
PADDLE_ENFORCE(
product(x_dims, 0, x_dims.size() - GetAttr<int>("x_num_row_dims")) == auto x_mat_dims =
out_dims[0], framework::flatten_to_2d(x_dims, GetAttr<int>("x_num_row_dims"));
auto y_mat_dims =
framework::flatten_to_2d(y_dims, GetAttr<int>("y_num_row_dims"));
PADDLE_ENFORCE_EQ(
x_mat_dims[0], out_dims[0],
"The first dimension of Out@GRAD must equal to the first dimension of " "The first dimension of Out@GRAD must equal to the first dimension of "
"the first operand."); "the first operand.");
PADDLE_ENFORCE( PADDLE_ENFORCE_EQ(
product(y_dims, y_dims.size() - GetAttr<int>("y_num_row_dims"), y_mat_dims[1], out_dims[1],
y_dims.size()) == out_dims[1],
"The second dimension of Out@GRAD must equal to the second " "The second dimension of Out@GRAD must equal to the second "
"dimension of the second operand."); "dimension of the second operand.");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册