From d71396bf870966a14638d5ea108b2b2f8babbe2f Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 5 Sep 2017 15:20:00 -0700 Subject: [PATCH] Add global function `flatten_to_2d()` --- paddle/framework/ddim.cc | 14 +++++++------ paddle/framework/ddim.h | 4 ++-- paddle/framework/eigen.h | 7 ++----- paddle/framework/tensor_impl.h | 6 +----- paddle/operators/mul_op.cc | 36 +++++++++++++++++++--------------- 5 files changed, 33 insertions(+), 34 deletions(-) diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc index c32d66f41c0..47d1f301167 100644 --- a/paddle/framework/ddim.cc +++ b/paddle/framework/ddim.cc @@ -247,12 +247,6 @@ ssize_t product(const DDim& ddim) { return boost::apply_visitor(visitor, ddim); } -ssize_t product(const DDim& ddim, int begin, int end) { - ProductVisitor visitor; - DDim sliced_ddim = slice_ddim(ddim, begin, end); - return boost::apply_visitor(visitor, sliced_ddim); -} - /// \cond HIDDEN struct ArityVisitor : boost::static_visitor { @@ -289,5 +283,13 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) { DDim::DDim(std::initializer_list init_list) { *this = make_ddim(init_list); } + +DDim flatten_to_2d(const DDim& src, int num_row_dims) { + int rank = src.size(); + return make_ddim( + {static_cast(product(slice_ddim(src, 0, rank - num_row_dims))), + static_cast(product(slice_ddim(src, rank - num_row_dims, rank)))}); +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index 7a02af6b8ac..cf786d140ed 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -96,8 +96,6 @@ std::vector vectorize(const DDim& ddim); ssize_t product(const DDim& ddim); -ssize_t product(const DDim& ddim, int begin, int end); - /** * \brief Slice a ddim * @@ -117,6 +115,8 @@ int arity(const DDim& ddim); std::ostream& operator<<(std::ostream&, const DDim&); +DDim flatten_to_2d(const DDim& src, int num_row_dims); + } // namespace framework } // namespace paddle diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h index 47551634a64..656aef42127 100644 --- a/paddle/framework/eigen.h +++ b/paddle/framework/eigen.h @@ -68,11 +68,8 @@ struct EigenMatrix : public EigenTensor { int rank = tensor.dims_.size(); PADDLE_ENFORCE(num_row_dims > 0 && num_row_dims < rank, "`num_row_dims` must be between (0, rank_of_tensor)."); - return EigenMatrix::From( - tensor, make_ddim({static_cast( - product(tensor.dims_, 0, rank - num_row_dims)), - static_cast(product( - tensor.dims_, rank - num_row_dims, rank))})); + return EigenMatrix::From(tensor, + flatten_to_2d(tensor.dims(), num_row_dims)); } }; diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h index 7c47c389a1f..d32fe78f425 100644 --- a/paddle/framework/tensor_impl.h +++ b/paddle/framework/tensor_impl.h @@ -152,11 +152,7 @@ template inline Tensor FlattenToMatrix(const Tensor& src, int num_row_dims) { Tensor res; res.ShareDataWith(src); - DDim src_dim = src.dims(); - int rank = src_dim.size(); - res.Resize(make_ddim( - {static_cast(product(src_dim, 0, rank - num_row_dims)), - static_cast(product(src_dim, rank - num_row_dims, rank))})); + res.Resize(flatten_to_2d(src.dims(), num_row_dims)); return res; } diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 935fe889e5f..dfc22decdc7 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -25,27 +25,27 @@ class MulOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - auto x_dim = ctx.Input("X")->dims(); - auto y_dim = ctx.Input("Y")->dims(); + auto x_dims = ctx.Input("X")->dims(); + auto y_dims = ctx.Input("Y")->dims(); int x_num_row_dims = GetAttr("x_num_row_dims"); int y_num_row_dims = GetAttr("y_num_row_dims"); - PADDLE_ENFORCE(x_dim.size() > x_num_row_dims, + PADDLE_ENFORCE(x_dims.size() > x_num_row_dims, "The rank of input tensor X(%s) should be larger than " "`mul_op`'s `x_num_row_dims`.", ctx.op().Input("X")); - PADDLE_ENFORCE(y_dim.size() > y_num_row_dims, + PADDLE_ENFORCE(y_dims.size() > y_num_row_dims, "The rank of input tensor Y(%s) should be larger than " "`mul_op`'s `y_num_row_dims`.", ctx.op().Input("Y")); + + auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_row_dims); + auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_row_dims); + PADDLE_ENFORCE_EQ( - product(x_dim, x_dim.size() - x_num_row_dims, x_dim.size()), - product(y_dim, 0, y_dim.size() - y_num_row_dims), + x_mat_dims[1], y_mat_dims[0], "First matrix's width must be equal with second matrix's height."); - ctx.Output("Out")->Resize( - {static_cast(product(x_dim, 0, x_dim.size() - x_num_row_dims)), - static_cast( - product(y_dim, y_dim.size() - y_num_row_dims, y_dim.size()))}); + ctx.Output("Out")->Resize({x_mat_dims[0], y_mat_dims[1]}); } }; @@ -96,14 +96,18 @@ class MulOpGrad : public framework::OperatorWithKernel { auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); auto *x_grad = ctx.Output(framework::GradVarName("X")); auto *y_grad = ctx.Output(framework::GradVarName("Y")); - PADDLE_ENFORCE( - product(x_dims, 0, x_dims.size() - GetAttr("x_num_row_dims")) == - out_dims[0], + + auto x_mat_dims = + framework::flatten_to_2d(x_dims, GetAttr("x_num_row_dims")); + auto y_mat_dims = + framework::flatten_to_2d(y_dims, GetAttr("y_num_row_dims")); + + PADDLE_ENFORCE_EQ( + x_mat_dims[0], out_dims[0], "The first dimension of Out@GRAD must equal to the first dimension of " "the first operand."); - PADDLE_ENFORCE( - product(y_dims, y_dims.size() - GetAttr("y_num_row_dims"), - y_dims.size()) == out_dims[1], + PADDLE_ENFORCE_EQ( + y_mat_dims[1], out_dims[1], "The second dimension of Out@GRAD must equal to the second " "dimension of the second operand."); -- GitLab