提交 e76fa85c 编写于 作者: F fengjiayi

WIP

上级 5e78359f
......@@ -51,6 +51,18 @@ class LargerThanChecker {
T lower_bound_;
};
template <typename T>
class EqualLargerThanChecker {
public:
explicit EqualLargerThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
void operator()(T& value) const {
PADDLE_ENFORCE(value >= lower_bound_, "equal_larger_than check fail");
}
private:
T lower_bound_;
};
// we can provide users more common Checker, like 'LessThanChecker',
// 'BetweenChecker'...
......@@ -114,6 +126,11 @@ class TypedAttrChecker {
return *this;
}
TypedAttrChecker& EqualLargerThan(const T& lower_bound) {
value_checkers_.push_back(EqualLargerThanChecker<T>(lower_bound));
return *this;
}
// we can add more common limits, like LessThan(), Between()...
TypedAttrChecker& SetDefault(const T& default_value) {
......
......@@ -195,18 +195,6 @@ std::vector<int> vectorize(const DDim& ddim) {
return result;
}
struct ProductVisitor : public boost::static_visitor<ssize_t> {
template <int D>
ssize_t operator()(const Dim<D>& dim) {
return product(dim);
}
};
ssize_t product(const DDim& ddim) {
ProductVisitor visitor;
return boost::apply_visitor(visitor, ddim);
}
struct SliceVectorizeVisitor : public boost::static_visitor<> {
std::vector<int>& vector;
int begin;
......@@ -247,6 +235,24 @@ DDim slice_ddim(const DDim& dim, int begin, int end) {
return make_ddim(vec);
}
struct ProductVisitor : public boost::static_visitor<ssize_t> {
template <int D>
ssize_t operator()(const Dim<D>& dim) {
return product(dim);
}
};
ssize_t product(const DDim& ddim) {
ProductVisitor visitor;
return boost::apply_visitor(visitor, ddim);
}
ssize_t product(const DDim& ddim, int begin, int end) {
ProductVisitor visitor;
DDim sliced_ddim = slice_ddim(ddim, begin, end);
return boost::apply_visitor(visitor, sliced_ddim);
}
/// \cond HIDDEN
struct ArityVisitor : boost::static_visitor<int> {
......
......@@ -96,6 +96,8 @@ std::vector<int> vectorize(const DDim& ddim);
ssize_t product(const DDim& ddim);
ssize_t product(const DDim& ddim, int begin, int end);
/**
* \brief Slice a ddim
*
......
......@@ -63,7 +63,18 @@ struct EigenTensor {
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {};
struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {
static typename EigenMatrix::Type Reshape(Tensor& tensor, int num_row_dims) {
int rank = tensor.dims_.size();
PADDLE_ENFORCE(num_row_dims > 0 && num_row_dims < rank,
"`num_row_dims` must be between (0, rank_of_tensor).");
return EigenMatrix::From(
tensor, make_ddim({static_cast<int>(
product(tensor.dims_, 0, rank - num_row_dims)),
static_cast<int>(product(
tensor.dims_, rank - num_row_dims, rank))}));
}
};
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
......
......@@ -108,5 +108,25 @@ TEST(Eigen, Matrix) {
}
}
TEST(Eigen, MatrixReshape) {
Tensor t;
float* p =
t.mutable_data<float>(make_ddim({2, 3, 6, 4}), platform::CPUPlace());
for (int i = 0; i < 2 * 3 * 6 * 4; ++i) {
p[i] = static_cast<float>(i);
}
EigenMatrix<float>::Type em = EigenMatrix<float>::Reshape(t, 2);
ASSERT_EQ(2 * 3, em.dimension(0));
ASSERT_EQ(6 * 4, em.dimension(1));
for (int i = 0; i < 2 * 3; i++) {
for (int j = 0; j < 6 * 4; j++) {
ASSERT_NEAR(i * 6 * 4 + j, em(i, j), 1e-6f);
}
}
}
} // namespace framework
} // namespace paddle
......@@ -43,6 +43,9 @@ class Tensor {
template <typename T, size_t D, int MajorType, typename IndexType>
friend struct EigenTensor;
template <typename T, int MajorType, typename IndexType>
friend struct EigenMatrix;
template <typename T, int MajorType, typename IndexType>
friend struct EigenVector;
......
......@@ -148,5 +148,17 @@ inline Tensor& Tensor::Resize(const DDim& dims) {
inline const DDim& Tensor::dims() const { return dims_; }
template <typename T>
inline Tensor FlattenToMatrix(const Tensor& src, int num_row_dims) {
Tensor res;
res.ShareDataWith<T>(src);
DDim src_dim = src.dims();
int rank = src_dim.size();
res.Resize(make_ddim(
{static_cast<int>(product(src_dim, 0, rank - num_row_dims)),
static_cast<int>(product(src_dim, rank - num_row_dims, rank))}));
return res;
}
} // namespace framework
} // namespace paddle
......@@ -262,3 +262,16 @@ TEST(Tensor, CopyFrom) {
}
#endif
}
TEST(Tensor, FlattenToMatrix) {
using namespace paddle::framework;
using namespace paddle::platform;
Tensor src;
int* src_ptr = src.mutable_data<int>(make_ddim({2, 3, 4, 9}), CPUPlace());
for (int i = 0; i < 2 * 3 * 4 * 9; ++i) {
src_ptr[i] = i;
}
Tensor res = FlattenToMatrix<int>(src, 2);
ASSERT_EQ(res.dims()[0], 2 * 3);
ASSERT_EQ(res.dims()[1], 4 * 9);
}
\ No newline at end of file
......@@ -25,18 +25,26 @@ class MulOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto dim0 = ctx.Input<Tensor>("X")->dims();
auto dim1 = ctx.Input<Tensor>("Y")->dims();
PADDLE_ENFORCE_EQ(dim0.size(), 2,
"input X(%s) should be a tensor with 2 dims, a matrix",
ctx.op_.Input("X"));
PADDLE_ENFORCE_EQ(dim1.size(), 2,
"input Y(%s) should be a tensor with 2 dims, a matrix",
ctx.op_.Input("Y"));
auto x_dim = ctx.Input<Tensor>("X")->dims();
auto y_dim = ctx.Input<Tensor>("Y")->dims();
int x_num_row_dims = GetAttr<int>("X_num_raw_dims");
int y_num_row_dims = GetAttr<int>("Y_num_raw_dims");
PADDLE_ENFORCE(x_dim.size() > x_num_row_dims,
"The rank of input tensor X(%s) should be larger than "
"`mul_op`'s `X_num_raw_dims`.",
ctx.op_.Input("X"));
PADDLE_ENFORCE(y_dim.size() > y_num_row_dims,
"The rank of input tensor Y(%s) should be larger than "
"`mul_op`'s `Y_num_raw_dims`.",
ctx.op_.Input("Y"));
PADDLE_ENFORCE_EQ(
dim0[1], dim1[0],
product(x_dim, x_dim.size() - x_num_row_dims, x_dim.size()),
product(y_dim, 0, y_dim.size() - y_num_row_dims),
"First matrix's width must be equal with second matrix's height.");
ctx.Output<Tensor>("Out")->Resize({dim0[0], dim1[1]});
ctx.Output<Tensor>("Out")->Resize(
{product(x_dim, 0, x_dim.size() - x_num_row_dims),
product(y_dim, y_dim.size() - y_num_row_dims, y_dim.size())});
}
};
......@@ -47,6 +55,23 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "The first input of mul op");
AddInput("Y", "The second input of mul op");
AddOutput("Out", "The output of mul op");
AddAttr<int>(
"x_num_row_dims",
"mul_op can take tensors with more than two dimensions as input `X`, "
"in that case, tensors will be flattened to a matrix. The matrix's "
"second dimension(row length) will be the product of tensor's last "
"`num_row_dims` dimensions, and the matrix's first dimension(column "
"length) will be the product of tensor's first `rank - num_row_dims` "
"dimensions.")
.SetDefault(1)
.EqualLargerThan(1);
AddAttr<int>(
"y_num_row_dims",
"mul_op can take tensors with more than two dimensions as input `Y`, "
"in that case, tensors will be flattened to a matrix. Just like input "
"`X`.")
.SetDefault(1)
.EqualLargerThan(1);
AddComment(R"DOC(
Two Element Mul Operator.
......@@ -70,10 +95,14 @@ class MulOpGrad : public framework::OperatorWithKernel {
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
PADDLE_ENFORCE(x_dims[0] == out_dims[0],
"Out@GRAD M X N must equal to X dims 0, M ");
PADDLE_ENFORCE(y_dims[1] == out_dims[1],
"Out@GRAD M X N must equal to Y dims 1, N ");
PADDLE_ENFORCE(
product(x_dim, 0, x_dims.size() - x_num_row_dims) == out_dims[0],
"The first dimension of Out@GRAD must equal to the first dimension of "
"the first operand.");
PADDLE_ENFORCE(product(y_dim, y_dims.size() - y_num_row_dims,
y_dims.size()) == out_dims[1],
"The second dimension of Out@GRAD must equal to the second "
"dimension of the second operand.");
x_grad->Resize(x_dims);
y_grad->Resize(y_dims);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册