提交 039c0bf2 编写于 作者: H hedaoyuan

Add some constructors for generating object that only contains shape (do not contains data).

上级 2a20fdc1
...@@ -39,7 +39,6 @@ enum SparseDataFormat { SPARSE_CSR_FORMAT = 0, SPARSE_CSC_FORMAT = 1 }; ...@@ -39,7 +39,6 @@ enum SparseDataFormat { SPARSE_CSR_FORMAT = 0, SPARSE_CSC_FORMAT = 1 };
class BufferArg; class BufferArg;
class SequenceArg; class SequenceArg;
class SparseMatrixArg; class SparseMatrixArg;
typedef std::shared_ptr<BufferArg> BufferArgPtr;
/** /**
* \brief BufferArg used as the argument type of Function. * \brief BufferArg used as the argument type of Function.
...@@ -50,6 +49,11 @@ typedef std::shared_ptr<BufferArg> BufferArgPtr; ...@@ -50,6 +49,11 @@ typedef std::shared_ptr<BufferArg> BufferArgPtr;
* 3. SequenceArg for a Buffer of sequence data. * 3. SequenceArg for a Buffer of sequence data.
* 4. SparseMatrixArg for a Buffer of sparse matrix. * 4. SparseMatrixArg for a Buffer of sparse matrix.
* *
* Buffer shape
* For most buffers, the first dimension `shape()[0]` represents
* the size of the mini-batch.
*
* Buffer argType
* There is an ArgType property for the BufferArg used as Function Output. * There is an ArgType property for the BufferArg used as Function Output.
* Whether the result of the Function calculation is assigned to the * Whether the result of the Function calculation is assigned to the
* output Buffer or added to the output Buffer is determined by the * output Buffer or added to the output Buffer is determined by the
...@@ -71,6 +75,14 @@ public: ...@@ -71,6 +75,14 @@ public:
ArgType getArgType() const { return argType_; } ArgType getArgType() const { return argType_; }
public: public:
BufferArg(ValueType valueType,
const TensorShape& shape,
ArgType argType = UNSPECIFIED)
: buf_(nullptr),
valueType_(valueType),
shape_(shape),
argType_(argType) {}
BufferArg(void* buf, BufferArg(void* buf,
ValueType valueType, ValueType valueType,
const TensorShape& shape, const TensorShape& shape,
...@@ -170,6 +182,12 @@ protected: ...@@ -170,6 +182,12 @@ protected:
// if a < b then value_.buf_[a] < value_.buf_[b] // if a < b then value_.buf_[a] < value_.buf_[b]
class SequenceIdArg : public BufferArg { class SequenceIdArg : public BufferArg {
public: public:
SequenceIdArg(const TensorShape& shape, ArgType argType = UNSPECIFIED)
: BufferArg(VALUE_TYPE_INT32, shape, argType) {
CHECK_EQ(shape_.ndims(), (size_t)1);
numSeqs_ = shape_[0] - 1;
}
SequenceIdArg(void* buf, SequenceIdArg(void* buf,
const TensorShape& shape, const TensorShape& shape,
ArgType argType = UNSPECIFIED) ArgType argType = UNSPECIFIED)
...@@ -190,9 +208,18 @@ private: ...@@ -190,9 +208,18 @@ private:
size_t numSeqs_; size_t numSeqs_;
}; };
// sequence data // sequences data
// For mini-batch calculate,
// one batch can contain more than one sequence of data.
// SequenceArg can be used to represent sequences that contain multiple
// unequal lengths.
class SequenceArg : public BufferArg { class SequenceArg : public BufferArg {
public: public:
SequenceArg(ValueType valueType,
const TensorShape& shape,
ArgType argType = UNSPECIFIED)
: BufferArg(valueType, shape, argType), startPositions_(TensorShape()) {}
SequenceArg(void* buf, SequenceArg(void* buf,
ValueType valueType, ValueType valueType,
const TensorShape& shape, const TensorShape& shape,
...@@ -210,6 +237,8 @@ public: ...@@ -210,6 +237,8 @@ public:
void* getIdBuf() const { return startPositions_.data(); } void* getIdBuf() const { return startPositions_.data(); }
size_t numSeqs() const { return startPositions_.numSeqs(); } size_t numSeqs() const { return startPositions_.numSeqs(); }
SequenceIdArg& getSequenceId() { return startPositions_; }
const SequenceIdArg& getSequenceId() const { return startPositions_; }
private: private:
SequenceIdArg startPositions_; SequenceIdArg startPositions_;
......
...@@ -84,6 +84,10 @@ void testBufferArgs(const BufferArgs& inputs, ...@@ -84,6 +84,10 @@ void testBufferArgs(const BufferArgs& inputs,
} }
} }
void testBufferArgs(const BufferArgs& inputs, const CheckBufferArg& check) {
check(inputs[0]);
}
TEST(Arguments, Matrix) { TEST(Arguments, Matrix) {
MatrixPtr matrix = Matrix::create(100, 200); MatrixPtr matrix = Matrix::create(100, 200);
CheckBufferArg check = [=](const BufferArg& arg) { CheckBufferArg check = [=](const BufferArg& arg) {
...@@ -144,4 +148,18 @@ TEST(Arguments, CpuSparseMatrix) { ...@@ -144,4 +148,18 @@ TEST(Arguments, CpuSparseMatrix) {
testBufferArgs(argments, checkFunc); testBufferArgs(argments, checkFunc);
} }
TEST(Arguments, BufferArg) {
BufferArg arg(nullptr, VALUE_TYPE_FLOAT, {1, 2, 3});
CheckBufferArg check = [=](const BufferArg& arg) {
EXPECT_EQ(arg.shape().ndims(), 3);
EXPECT_EQ(arg.shape()[0], 1);
EXPECT_EQ(arg.shape()[1], 2);
EXPECT_EQ(arg.shape()[2], 3);
};
BufferArgs argments;
argments.addArg(arg);
testBufferArgs(argments, check);
}
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册