Commit e76fa85c authored by fengjiayi

WIP

Parent 5e78359f
@@ -51,6 +51,18 @@ class LargerThanChecker {
   T lower_bound_;
 };

+template <typename T>
+class EqualLargerThanChecker {
+ public:
+  explicit EqualLargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
+  void operator()(T& value) const {
+    PADDLE_ENFORCE(value >= lower_bound_, "equal_larger_than check fails");
+  }
+
+ private:
+  T lower_bound_;
+};
+
 // we can provide users more common Checker, like 'LessThanChecker',
 // 'BetweenChecker'...
@@ -114,6 +126,11 @@ class TypedAttrChecker {
     return *this;
   }

+  TypedAttrChecker& EqualLargerThan(const T& lower_bound) {
+    value_checkers_.push_back(EqualLargerThanChecker<T>(lower_bound));
+    return *this;
+  }
+
   // we can add more common limits, like LessThan(), Between()...
   TypedAttrChecker& SetDefault(const T& default_value) {
......
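The `EqualLargerThan` limit chains onto `AddAttr` in the same way as the existing `LargerThan` and `SetDefault` calls. A minimal usage sketch, assuming the `OpProtoAndCheckerMaker` interface of this codebase (the op maker and attribute name below are hypothetical, not part of this commit):

class ExampleOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ExampleOpMaker(framework::OpProto *proto,
                 framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // Hypothetical attribute: defaults to 1 and must satisfy value >= 1;
    // otherwise EqualLargerThanChecker fails via PADDLE_ENFORCE.
    AddAttr<int>("num_row_dims", "trailing dims folded into the matrix rows")
        .SetDefault(1)
        .EqualLargerThan(1);
  }
};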
@@ -195,18 +195,6 @@ std::vector<int> vectorize(const DDim& ddim) {
   return result;
 }

-struct ProductVisitor : public boost::static_visitor<ssize_t> {
-  template <int D>
-  ssize_t operator()(const Dim<D>& dim) {
-    return product(dim);
-  }
-};
-
-ssize_t product(const DDim& ddim) {
-  ProductVisitor visitor;
-  return boost::apply_visitor(visitor, ddim);
-}
-
 struct SliceVectorizeVisitor : public boost::static_visitor<> {
   std::vector<int>& vector;
   int begin;
@@ -247,6 +235,24 @@ DDim slice_ddim(const DDim& dim, int begin, int end) {
   return make_ddim(vec);
 }

+struct ProductVisitor : public boost::static_visitor<ssize_t> {
+  template <int D>
+  ssize_t operator()(const Dim<D>& dim) {
+    return product(dim);
+  }
+};
+
+ssize_t product(const DDim& ddim) {
+  ProductVisitor visitor;
+  return boost::apply_visitor(visitor, ddim);
+}
+
+ssize_t product(const DDim& ddim, int begin, int end) {
+  ProductVisitor visitor;
+  DDim sliced_ddim = slice_ddim(ddim, begin, end);
+  return boost::apply_visitor(visitor, sliced_ddim);
+}
+
 /// \cond HIDDEN
 struct ArityVisitor : boost::static_visitor<int> {
......
@@ -96,6 +96,8 @@ std::vector<int> vectorize(const DDim& ddim);

 ssize_t product(const DDim& ddim);

+ssize_t product(const DDim& ddim, int begin, int end);
+
 /**
  * \brief Slice a ddim
  *
......
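The new `product` overload multiplies the extents in the half-open range `[begin, end)`, reusing `slice_ddim` and `ProductVisitor` as shown in the ddim.cc hunk above. A short sketch of the intended semantics (the shape values are illustrative):

// product over a dimension range; [begin, end) is half-open.
auto dims = make_ddim({2, 3, 6, 4});
ssize_t rows = product(dims, 0, 2);  // 2 * 3 == 6
ssize_t cols = product(dims, 2, 4);  // 6 * 4 == 24
// rows * cols == product(dims) == 144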
@@ -63,7 +63,18 @@ struct EigenTensor {

 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
-struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {};
+struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {
+  static typename EigenMatrix::Type Reshape(Tensor& tensor, int num_row_dims) {
+    int rank = tensor.dims_.size();
+    PADDLE_ENFORCE(num_row_dims > 0 && num_row_dims < rank,
+                   "`num_row_dims` must be between (0, rank_of_tensor).");
+    return EigenMatrix::From(
+        tensor, make_ddim({static_cast<int>(
+                               product(tensor.dims_, 0, rank - num_row_dims)),
+                           static_cast<int>(product(
+                               tensor.dims_, rank - num_row_dims, rank))}));
+  }
+};

 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
......
@@ -108,5 +108,25 @@ TEST(Eigen, Matrix) {
   }
 }

+TEST(Eigen, MatrixReshape) {
+  Tensor t;
+  float* p =
+      t.mutable_data<float>(make_ddim({2, 3, 6, 4}), platform::CPUPlace());
+  for (int i = 0; i < 2 * 3 * 6 * 4; ++i) {
+    p[i] = static_cast<float>(i);
+  }
+
+  EigenMatrix<float>::Type em = EigenMatrix<float>::Reshape(t, 2);
+
+  ASSERT_EQ(2 * 3, em.dimension(0));
+  ASSERT_EQ(6 * 4, em.dimension(1));
+
+  for (int i = 0; i < 2 * 3; i++) {
+    for (int j = 0; j < 6 * 4; j++) {
+      ASSERT_NEAR(i * 6 * 4 + j, em(i, j), 1e-6f);
+    }
+  }
+}
+
 }  // namespace framework
 }  // namespace paddle
@@ -43,6 +43,9 @@ class Tensor {
   template <typename T, size_t D, int MajorType, typename IndexType>
   friend struct EigenTensor;

+  template <typename T, int MajorType, typename IndexType>
+  friend struct EigenMatrix;
+
   template <typename T, int MajorType, typename IndexType>
   friend struct EigenVector;
......
@@ -148,5 +148,17 @@ inline Tensor& Tensor::Resize(const DDim& dims) {

 inline const DDim& Tensor::dims() const { return dims_; }

+template <typename T>
+inline Tensor FlattenToMatrix(const Tensor& src, int num_row_dims) {
+  Tensor res;
+  res.ShareDataWith<T>(src);
+  DDim src_dim = src.dims();
+  int rank = src_dim.size();
+  res.Resize(make_ddim(
+      {static_cast<int>(product(src_dim, 0, rank - num_row_dims)),
+       static_cast<int>(product(src_dim, rank - num_row_dims, rank))}));
+  return res;
+}
+
 }  // namespace framework
 }  // namespace paddle
@@ -262,3 +262,16 @@ TEST(Tensor, CopyFrom) {
   }
 #endif
 }
+
+TEST(Tensor, FlattenToMatrix) {
+  using namespace paddle::framework;
+  using namespace paddle::platform;
+  Tensor src;
+  int* src_ptr = src.mutable_data<int>(make_ddim({2, 3, 4, 9}), CPUPlace());
+  for (int i = 0; i < 2 * 3 * 4 * 9; ++i) {
+    src_ptr[i] = i;
+  }
+  Tensor res = FlattenToMatrix<int>(src, 2);
+  ASSERT_EQ(res.dims()[0], 2 * 3);
+  ASSERT_EQ(res.dims()[1], 4 * 9);
+}
\ No newline at end of file
@@ -25,18 +25,26 @@ class MulOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    auto dim0 = ctx.Input<Tensor>("X")->dims();
-    auto dim1 = ctx.Input<Tensor>("Y")->dims();
-    PADDLE_ENFORCE_EQ(dim0.size(), 2,
-                      "input X(%s) should be a tensor with 2 dims, a matrix",
-                      ctx.op_.Input("X"));
-    PADDLE_ENFORCE_EQ(dim1.size(), 2,
-                      "input Y(%s) should be a tensor with 2 dims, a matrix",
-                      ctx.op_.Input("Y"));
+    auto x_dim = ctx.Input<Tensor>("X")->dims();
+    auto y_dim = ctx.Input<Tensor>("Y")->dims();
+    int x_num_row_dims = GetAttr<int>("x_num_row_dims");
+    int y_num_row_dims = GetAttr<int>("y_num_row_dims");
+    PADDLE_ENFORCE(x_dim.size() > x_num_row_dims,
+                   "The rank of input tensor X(%s) should be larger than "
+                   "`mul_op`'s `x_num_row_dims`.",
+                   ctx.op_.Input("X"));
+    PADDLE_ENFORCE(y_dim.size() > y_num_row_dims,
+                   "The rank of input tensor Y(%s) should be larger than "
+                   "`mul_op`'s `y_num_row_dims`.",
+                   ctx.op_.Input("Y"));
     PADDLE_ENFORCE_EQ(
-        dim0[1], dim1[0],
+        product(x_dim, x_dim.size() - x_num_row_dims, x_dim.size()),
+        product(y_dim, 0, y_dim.size() - y_num_row_dims),
         "First matrix's width must be equal with second matrix's height.");
-    ctx.Output<Tensor>("Out")->Resize({dim0[0], dim1[1]});
+    ctx.Output<Tensor>("Out")->Resize(
+        {product(x_dim, 0, x_dim.size() - x_num_row_dims),
+         product(y_dim, y_dim.size() - y_num_row_dims, y_dim.size())});
   }
 };
@@ -47,6 +55,23 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("X", "The first input of mul op");
     AddInput("Y", "The second input of mul op");
     AddOutput("Out", "The output of mul op");
+    AddAttr<int>(
+        "x_num_row_dims",
+        "mul_op can take tensors with more than two dimensions as its input "
+        "`X`. In that case, the tensor will be flattened to a matrix: the "
+        "matrix's second dimension (row length) will be the product of the "
+        "tensor's last `num_row_dims` dimensions, and the matrix's first "
+        "dimension (column length) will be the product of the tensor's first "
+        "`rank - num_row_dims` dimensions.")
+        .SetDefault(1)
+        .EqualLargerThan(1);
+    AddAttr<int>(
+        "y_num_row_dims",
+        "mul_op can take tensors with more than two dimensions as its input "
+        "`Y`. In that case, the tensor will be flattened to a matrix, just "
+        "like input `X`.")
+        .SetDefault(1)
+        .EqualLargerThan(1);
     AddComment(R"DOC(
 Two Element Mul Operator.
@@ -70,10 +95,14 @@ class MulOpGrad : public framework::OperatorWithKernel {
     auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
     auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    PADDLE_ENFORCE(x_dims[0] == out_dims[0],
-                   "Out@GRAD M X N must equal to X dims 0, M ");
-    PADDLE_ENFORCE(y_dims[1] == out_dims[1],
-                   "Out@GRAD M X N must equal to Y dims 1, N ");
+    int x_num_row_dims = GetAttr<int>("x_num_row_dims");
+    int y_num_row_dims = GetAttr<int>("y_num_row_dims");
+    PADDLE_ENFORCE(
+        product(x_dims, 0, x_dims.size() - x_num_row_dims) == out_dims[0],
+        "The first dimension of Out@GRAD must equal the first dimension of "
+        "the first operand.");
+    PADDLE_ENFORCE(product(y_dims, y_dims.size() - y_num_row_dims,
+                           y_dims.size()) == out_dims[1],
+                   "The second dimension of Out@GRAD must equal the second "
+                   "dimension of the second operand.");
     x_grad->Resize(x_dims);
     y_grad->Resize(y_dims);
......
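Taken together, `InferShape` derives the output shape of `mul_op` from the two `product` ranges. A worked example with illustrative shapes (not taken from this commit):

// X: {2, 3, 4} with x_num_row_dims = 1  ->  flattened to {2 * 3, 4} = {6, 4}
// Y: {4, 5}    with y_num_row_dims = 1  ->  stays {4, 5}
// enforced: product(x_dim, 2, 3) == 4 == product(y_dim, 0, 1)
// Out: {product(x_dim, 0, 2), product(y_dim, 1, 2)} = {6, 5}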