// NOTE(review): unified diff touching paddle/function/BufferArg.h and
// paddle/function/FunctionTest.cpp. It adds buffer-less constructors to
// BufferArg, SequenceIdArg and SequenceArg, getSequenceId() accessors on
// SequenceArg, a single-argument testBufferArgs overload and a BufferArg test.
// NOTE(review): the patch had been collapsed onto two physical lines. Line
// breaks, blank context lines and indentation below are reconstructed from
// the hunk line counts and conventional formatting -- verify against the
// original patch before applying.
// NOTE(review): the removed typedef in the first hunk reads
// `std::shared_ptr BufferArgPtr`; presumably a template argument
// (`<BufferArg>`?) was lost in extraction -- TODO confirm.
diff --git a/paddle/function/BufferArg.h b/paddle/function/BufferArg.h
index 12352ba29e33920ba65bd66088b6f7cc53517b52..28542a86574812048ff5f72a2d4b2c188f7e30bf 100644
--- a/paddle/function/BufferArg.h
+++ b/paddle/function/BufferArg.h
@@ -39,7 +39,6 @@ enum SparseDataFormat { SPARSE_CSR_FORMAT = 0, SPARSE_CSC_FORMAT = 1 };
 class BufferArg;
 class SequenceArg;
 class SparseMatrixArg;
-typedef std::shared_ptr BufferArgPtr;
 
 /**
  * \brief BufferArg used as the argument type of Function.
@@ -50,6 +49,11 @@ typedef std::shared_ptr BufferArgPtr;
  * 3. SequenceArg for a Buffer of sequence data.
  * 4. SparseMatrixArg for a Buffer of sparse matrix.
  *
+ * Buffer shape
+ * For most buffers, the first dimension `shape()[0]` represents
+ * the size of the mini-batch.
+ *
+ * Buffer argType
  * There is an ArgType property for the BufferArg used as Function Output.
  * Whether the result of the Function calculation is assigned to the
  * output Buffer or added to the output Buffer is determined by the
@@ -71,6 +75,14 @@ public:
   ArgType getArgType() const { return argType_; }
 
 public:
+  BufferArg(ValueType valueType,
+            const TensorShape& shape,
+            ArgType argType = UNSPECIFIED)
+      : buf_(nullptr),  // buf_ stays null; only type/shape/argType are set
+        valueType_(valueType),
+        shape_(shape),
+        argType_(argType) {}
+
   BufferArg(void* buf,
             ValueType valueType,
             const TensorShape& shape,
@@ -170,6 +182,12 @@ protected:
 // if a < b then value_.buf_[a] < value_.buf_[b]
 class SequenceIdArg : public BufferArg {
 public:
+  SequenceIdArg(const TensorShape& shape, ArgType argType = UNSPECIFIED)
+      : BufferArg(VALUE_TYPE_INT32, shape, argType) {
+    CHECK_EQ(shape_.ndims(), (size_t)1);  // the id buffer must be 1-D
+    numSeqs_ = shape_[0] - 1;  // presumably N+1 ids bound N seqs; confirm
+  }
+
   SequenceIdArg(void* buf,
                 const TensorShape& shape,
                 ArgType argType = UNSPECIFIED)
@@ -190,9 +208,18 @@ private:
   size_t numSeqs_;
 };
 
-// sequence data
+// Sequence data.
+// In mini-batch computation,
+// one batch can contain more than one sequence of data.
+// SequenceArg can be used to represent a batch of multiple sequences
+// of unequal lengths.
 class SequenceArg : public BufferArg {
 public:
+  SequenceArg(ValueType valueType,  // no buf; ids get a default TensorShape
+              const TensorShape& shape,
+              ArgType argType = UNSPECIFIED)
+      : BufferArg(valueType, shape, argType), startPositions_(TensorShape()) {}
+
   SequenceArg(void* buf,
               ValueType valueType,
               const TensorShape& shape,
@@ -210,6 +237,8 @@ public:
 
   void* getIdBuf() const { return startPositions_.data(); }
   size_t numSeqs() const { return startPositions_.numSeqs(); }
+  SequenceIdArg& getSequenceId() { return startPositions_; }  // mutable ref
+  const SequenceIdArg& getSequenceId() const { return startPositions_; }
 
 private:
   SequenceIdArg startPositions_;
diff --git a/paddle/function/FunctionTest.cpp b/paddle/function/FunctionTest.cpp
index eb05ca9a2190d56b925fc063778459315d312d4e..03c609b524277763b73c2e100d9d9c5081c7d3f6 100644
--- a/paddle/function/FunctionTest.cpp
+++ b/paddle/function/FunctionTest.cpp
@@ -84,6 +84,10 @@ void testBufferArgs(const BufferArgs& inputs,
   }
 }
 
+void testBufferArgs(const BufferArgs& inputs, const CheckBufferArg& check) {
+  check(inputs[0]);  // only the first argument is checked
+}
+
 TEST(Arguments, Matrix) {
   MatrixPtr matrix = Matrix::create(100, 200);
   CheckBufferArg check = [=](const BufferArg& arg) {
@@ -144,4 +148,18 @@ TEST(Arguments, CpuSparseMatrix) {
   testBufferArgs(argments, checkFunc);
 }
 
+TEST(Arguments, BufferArg) {
+  BufferArg arg(nullptr, VALUE_TYPE_FLOAT, {1, 2, 3});  // null buf, 3-D shape
+  CheckBufferArg check = [=](const BufferArg& arg) {
+    EXPECT_EQ(arg.shape().ndims(), 3);
+    EXPECT_EQ(arg.shape()[0], 1);
+    EXPECT_EQ(arg.shape()[1], 2);
+    EXPECT_EQ(arg.shape()[2], 3);
+  };
+
+  BufferArgs argments;
+  argments.addArg(arg);
+  testBufferArgs(argments, check);
+}
+
 } // namespace paddle