提交 039c0bf2 编写于 作者: H hedaoyuan

Add some constructors for generating object that only contains shape (do not contains data).

上级 2a20fdc1
......@@ -39,7 +39,6 @@ enum SparseDataFormat { SPARSE_CSR_FORMAT = 0, SPARSE_CSC_FORMAT = 1 };
class BufferArg;
class SequenceArg;
class SparseMatrixArg;
typedef std::shared_ptr<BufferArg> BufferArgPtr;
/**
* \brief BufferArg used as the argument type of Function.
......@@ -50,6 +49,11 @@ typedef std::shared_ptr<BufferArg> BufferArgPtr;
* 3. SequenceArg for a Buffer of sequence data.
* 4. SparseMatrixArg for a Buffer of sparse matrix.
*
* Buffer shape
* For most buffers, the first dimension `shape()[0]` represents
* the size of the mini-batch.
*
* Buffer argType
* There is an ArgType property for the BufferArg used as Function Output.
* Whether the result of the Function calculation is assigned to the
* output Buffer or added to the output Buffer is determined by the
......@@ -71,6 +75,14 @@ public:
ArgType getArgType() const { return argType_; }
public:
BufferArg(ValueType valueType,
const TensorShape& shape,
ArgType argType = UNSPECIFIED)
: buf_(nullptr),
valueType_(valueType),
shape_(shape),
argType_(argType) {}
BufferArg(void* buf,
ValueType valueType,
const TensorShape& shape,
......@@ -170,6 +182,12 @@ protected:
// if a < b then value_.buf_[a] < value_.buf_[b]
class SequenceIdArg : public BufferArg {
public:
SequenceIdArg(const TensorShape& shape, ArgType argType = UNSPECIFIED)
: BufferArg(VALUE_TYPE_INT32, shape, argType) {
CHECK_EQ(shape_.ndims(), (size_t)1);
numSeqs_ = shape_[0] - 1;
}
SequenceIdArg(void* buf,
const TensorShape& shape,
ArgType argType = UNSPECIFIED)
......@@ -190,9 +208,18 @@ private:
size_t numSeqs_;
};
// sequence data
// sequences data
// For mini-batch calculate,
// one batch can contain more than one sequence of data.
// SequenceArg can be used to represent sequences that contain multiple
// unequal lengths.
class SequenceArg : public BufferArg {
public:
SequenceArg(ValueType valueType,
const TensorShape& shape,
ArgType argType = UNSPECIFIED)
: BufferArg(valueType, shape, argType), startPositions_(TensorShape()) {}
SequenceArg(void* buf,
ValueType valueType,
const TensorShape& shape,
......@@ -210,6 +237,8 @@ public:
void* getIdBuf() const { return startPositions_.data(); }
size_t numSeqs() const { return startPositions_.numSeqs(); }
SequenceIdArg& getSequenceId() { return startPositions_; }
const SequenceIdArg& getSequenceId() const { return startPositions_; }
private:
SequenceIdArg startPositions_;
......
......@@ -84,6 +84,10 @@ void testBufferArgs(const BufferArgs& inputs,
}
}
void testBufferArgs(const BufferArgs& inputs, const CheckBufferArg& check) {
check(inputs[0]);
}
TEST(Arguments, Matrix) {
MatrixPtr matrix = Matrix::create(100, 200);
CheckBufferArg check = [=](const BufferArg& arg) {
......@@ -144,4 +148,18 @@ TEST(Arguments, CpuSparseMatrix) {
testBufferArgs(argments, checkFunc);
}
TEST(Arguments, BufferArg) {
BufferArg arg(nullptr, VALUE_TYPE_FLOAT, {1, 2, 3});
CheckBufferArg check = [=](const BufferArg& arg) {
EXPECT_EQ(arg.shape().ndims(), 3);
EXPECT_EQ(arg.shape()[0], 1);
EXPECT_EQ(arg.shape()[1], 2);
EXPECT_EQ(arg.shape()[2], 3);
};
BufferArgs argments;
argments.addArg(arg);
testBufferArgs(argments, check);
}
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册