From e76fa85cc842a5b9f21e66f64cd34d6f7fa00719 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Mon, 4 Sep 2017 16:55:29 -0700 Subject: [PATCH] WIP --- paddle/framework/attribute.h | 17 ++++++++++ paddle/framework/ddim.cc | 30 ++++++++++------- paddle/framework/ddim.h | 2 ++ paddle/framework/eigen.h | 13 +++++++- paddle/framework/eigen_test.cc | 20 ++++++++++++ paddle/framework/tensor.h | 3 ++ paddle/framework/tensor_impl.h | 12 +++++++ paddle/framework/tensor_test.cc | 13 ++++++++ paddle/operators/mul_op.cc | 57 +++++++++++++++++++++++++-------- 9 files changed, 140 insertions(+), 27 deletions(-) diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index 08b47cabd4c..7da34e3f2b6 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -51,6 +51,18 @@ class LargerThanChecker { T lower_bound_; }; +template +class EqualLargerThanChecker { + public: + explicit EqualLargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {} + void operator()(T& value) const { + PADDLE_ENFORCE(value >= lower_bound_, "equal_larger_than check fail"); + } + + private: + T lower_bound_; +}; + // we can provide users more common Checker, like 'LessThanChecker', // 'BetweenChecker'... @@ -114,6 +126,11 @@ class TypedAttrChecker { return *this; } + TypedAttrChecker& EqualLargerThan(const T& lower_bound) { + value_checkers_.push_back(EqualLargerThanChecker(lower_bound)); + return *this; + } + // we can add more common limits, like LessThan(), Between()... 
TypedAttrChecker& SetDefault(const T& default_value) { diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc index cfd3e8dfdec..c32d66f41c0 100644 --- a/paddle/framework/ddim.cc +++ b/paddle/framework/ddim.cc @@ -195,18 +195,6 @@ std::vector vectorize(const DDim& ddim) { return result; } -struct ProductVisitor : public boost::static_visitor { - template - ssize_t operator()(const Dim& dim) { - return product(dim); - } -}; - -ssize_t product(const DDim& ddim) { - ProductVisitor visitor; - return boost::apply_visitor(visitor, ddim); -} - struct SliceVectorizeVisitor : public boost::static_visitor<> { std::vector& vector; int begin; @@ -247,6 +235,24 @@ DDim slice_ddim(const DDim& dim, int begin, int end) { return make_ddim(vec); } +struct ProductVisitor : public boost::static_visitor { + template + ssize_t operator()(const Dim& dim) { + return product(dim); + } +}; + +ssize_t product(const DDim& ddim) { + ProductVisitor visitor; + return boost::apply_visitor(visitor, ddim); +} + +ssize_t product(const DDim& ddim, int begin, int end) { + ProductVisitor visitor; + DDim sliced_ddim = slice_ddim(ddim, begin, end); + return boost::apply_visitor(visitor, sliced_ddim); +} + /// \cond HIDDEN struct ArityVisitor : boost::static_visitor { diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index 95f294b6273..7a02af6b8ac 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -96,6 +96,8 @@ std::vector vectorize(const DDim& ddim); ssize_t product(const DDim& ddim); +ssize_t product(const DDim& ddim, int begin, int end); + /** * \brief Slice a ddim * diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h index a4667cc51fa..47551634a64 100644 --- a/paddle/framework/eigen.h +++ b/paddle/framework/eigen.h @@ -63,7 +63,18 @@ struct EigenTensor { template -struct EigenMatrix : public EigenTensor {}; +struct EigenMatrix : public EigenTensor { + static typename EigenMatrix::Type Reshape(Tensor& tensor, int num_row_dims) { + int rank = 
tensor.dims_.size(); + PADDLE_ENFORCE(num_row_dims > 0 && num_row_dims < rank, + "`num_row_dims` must be between (0, rank_of_tensor)."); + return EigenMatrix::From( + tensor, make_ddim({static_cast( + product(tensor.dims_, 0, rank - num_row_dims)), + static_cast(product( + tensor.dims_, rank - num_row_dims, rank))})); + } +}; template diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc index dc1957691b1..bae82fdb7d4 100644 --- a/paddle/framework/eigen_test.cc +++ b/paddle/framework/eigen_test.cc @@ -108,5 +108,25 @@ TEST(Eigen, Matrix) { } } +TEST(Eigen, MatrixReshape) { + Tensor t; + float* p = + t.mutable_data(make_ddim({2, 3, 6, 4}), platform::CPUPlace()); + for (int i = 0; i < 2 * 3 * 6 * 4; ++i) { + p[i] = static_cast(i); + } + + EigenMatrix::Type em = EigenMatrix::Reshape(t, 2); + + ASSERT_EQ(2 * 3, em.dimension(0)); + ASSERT_EQ(6 * 4, em.dimension(1)); + + for (int i = 0; i < 2 * 3; i++) { + for (int j = 0; j < 6 * 4; j++) { + ASSERT_NEAR(i * 6 * 4 + j, em(i, j), 1e-6f); + } + } +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 643f8754917..ce938b21437 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -43,6 +43,9 @@ class Tensor { template friend struct EigenTensor; + template + friend struct EigenMatrix; + template friend struct EigenVector; diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h index 7893e233b77..7c47c389a1f 100644 --- a/paddle/framework/tensor_impl.h +++ b/paddle/framework/tensor_impl.h @@ -148,5 +148,17 @@ inline Tensor& Tensor::Resize(const DDim& dims) { inline const DDim& Tensor::dims() const { return dims_; } +template +inline Tensor FlattenToMatrix(const Tensor& src, int num_row_dims) { + Tensor res; + res.ShareDataWith(src); + DDim src_dim = src.dims(); + int rank = src_dim.size(); + res.Resize(make_ddim( + {static_cast(product(src_dim, 0, rank - num_row_dims)), + 
static_cast(product(src_dim, rank - num_row_dims, rank))})); + return res; +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc index 7db38d5caee..cdd68b303c1 100644 --- a/paddle/framework/tensor_test.cc +++ b/paddle/framework/tensor_test.cc @@ -262,3 +262,16 @@ TEST(Tensor, CopyFrom) { } #endif } + +TEST(Tensor, FlattenToMatrix) { + using namespace paddle::framework; + using namespace paddle::platform; + Tensor src; + int* src_ptr = src.mutable_data(make_ddim({2, 3, 4, 9}), CPUPlace()); + for (int i = 0; i < 2 * 3 * 4 * 9; ++i) { + src_ptr[i] = i; + } + Tensor res = FlattenToMatrix(src, 2); + ASSERT_EQ(res.dims()[0], 2 * 3); + ASSERT_EQ(res.dims()[1], 4 * 9); +} \ No newline at end of file diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 559d19e6bdc..f668008a10f 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -25,18 +25,26 @@ class MulOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { - auto dim0 = ctx.Input("X")->dims(); - auto dim1 = ctx.Input("Y")->dims(); - PADDLE_ENFORCE_EQ(dim0.size(), 2, - "input X(%s) should be a tensor with 2 dims, a matrix", - ctx.op_.Input("X")); - PADDLE_ENFORCE_EQ(dim1.size(), 2, - "input Y(%s) should be a tensor with 2 dims, a matrix", - ctx.op_.Input("Y")); + auto x_dim = ctx.Input("X")->dims(); + auto y_dim = ctx.Input("Y")->dims(); + int x_num_row_dims = GetAttr("x_num_row_dims"); + int y_num_row_dims = GetAttr("y_num_row_dims"); + + PADDLE_ENFORCE(x_dim.size() > x_num_row_dims, + "The rank of input tensor X(%s) should be larger than " + "`mul_op`'s `x_num_row_dims`.", + ctx.op_.Input("X")); + PADDLE_ENFORCE(y_dim.size() > y_num_row_dims, + "The rank of input tensor Y(%s) should be larger than " + "`mul_op`'s `y_num_row_dims`.", + ctx.op_.Input("Y")); PADDLE_ENFORCE_EQ( - dim0[1], dim1[0], + product(x_dim, x_dim.size() - 
x_num_row_dims, x_dim.size()), + product(y_dim, 0, y_dim.size() - y_num_row_dims), "First matrix's width must be equal with second matrix's height."); - ctx.Output("Out")->Resize({dim0[0], dim1[1]}); + ctx.Output("Out")->Resize( + {product(x_dim, 0, x_dim.size() - x_num_row_dims), + product(y_dim, y_dim.size() - y_num_row_dims, y_dim.size())}); } }; @@ -47,6 +55,23 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "The first input of mul op"); AddInput("Y", "The second input of mul op"); AddOutput("Out", "The output of mul op"); + AddAttr( + "x_num_row_dims", + "mul_op can take tensors with more than two dimensions as input `X`, " + "in that case, tensors will be flattened to a matrix. The matrix's " + "second dimension(row length) will be the product of tensor's last " + "`num_row_dims` dimensions, and the matrix's first dimension(column " + "length) will be the product of tensor's first `rank - num_row_dims` " + "dimensions.") + .SetDefault(1) + .EqualLargerThan(1); + AddAttr( + "y_num_row_dims", + "mul_op can take tensors with more than two dimensions as input `Y`, " + "in that case, tensors will be flattened to a matrix. Just like input " + "`X`.") + .SetDefault(1) + .EqualLargerThan(1); AddComment(R"DOC( Two Element Mul Operator. 
@@ -70,10 +95,14 @@ class MulOpGrad : public framework::OperatorWithKernel { auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); auto *x_grad = ctx.Output(framework::GradVarName("X")); auto *y_grad = ctx.Output(framework::GradVarName("Y")); - PADDLE_ENFORCE(x_dims[0] == out_dims[0], - "Out@GRAD M X N must equal to X dims 0, M "); - PADDLE_ENFORCE(y_dims[1] == out_dims[1], - "Out@GRAD M X N must equal to Y dims 1, N "); + PADDLE_ENFORCE( + product(x_dims, 0, x_dims.size() - x_num_row_dims) == out_dims[0], + "The first dimension of Out@GRAD must equal to the first dimension of " + "the first operand."); + PADDLE_ENFORCE(product(y_dims, y_dims.size() - y_num_row_dims, + y_dims.size()) == out_dims[1], + "The second dimension of Out@GRAD must equal to the second " + "dimension of the second operand."); x_grad->Resize(x_dims); y_grad->Resize(y_dims); -- GitLab