diff --git a/paddle/function/BufferArgTest.cpp b/paddle/function/BufferArgTest.cpp index b345597435c9911ce95b596f5f7f2add47f4cd03..1744f377808f137dcda4a28acce336dc22be3d01 100644 --- a/paddle/function/BufferArgTest.cpp +++ b/paddle/function/BufferArgTest.cpp @@ -14,9 +14,7 @@ limitations under the License. */ #include "BufferArg.h" #include -#include "Function.h" #include "paddle/math/MemoryHandle.h" -#include "paddle/math/SparseMatrix.h" namespace paddle { @@ -37,55 +35,4 @@ TEST(BufferTest, SequenceIdArg) { EXPECT_EQ(buffer.numSeqs(), 9); } -TEST(BufferTest, asArgument) { - MatrixPtr matrix = Matrix::create(100, 200); - VectorPtr vector = Vector::create(100, false); - CpuSparseMatrix sparse(200, 300, 50); - - // prepare arguments - BufferArgs argments; - argments.addArg(*matrix); - argments.addArg(*vector); - argments.addArg(sparse); - - // function - auto function = [=](const BufferArgs& inputs) { - EXPECT_EQ(inputs.size(), 3); - - // check inputs[0] - EXPECT_EQ(inputs[0].shape().ndims(), 2); - EXPECT_EQ(inputs[0].shape()[0], 100); - EXPECT_EQ(inputs[0].shape()[1], 200); - EXPECT_EQ(inputs[0].data(), matrix->getData()); - - EXPECT_EQ(inputs[0].matrix().getHeight(), - matrix->getHeight()); - EXPECT_EQ(inputs[0].matrix().getWidth(), - matrix->getWidth()); - EXPECT_EQ(inputs[0].matrix().getData(), matrix->getData()); - - // check inputs[1] - EXPECT_EQ(inputs[1].shape().ndims(), 1); - EXPECT_EQ(inputs[1].shape()[0], 100); - EXPECT_EQ(inputs[1].data(), vector->getData()); - CpuVector inVector = inputs[1].vector(); - EXPECT_EQ(inVector.getSize(), vector->getSize()); - EXPECT_EQ(inVector.getData(), vector->getData()); - - // check inputs[2] - EXPECT_EQ(inputs[2].shape().ndims(), 2); - EXPECT_EQ(inputs[2].shape()[0], 200); - EXPECT_EQ(inputs[2].shape()[1], 300); - EXPECT_EQ(inputs[2].data(), sparse.getData()); - // CHECK_EQ(inputs[2].sparse().nnz(), 50); - // CHECK_EQ(inputs[2].sparse().dataFormat(), SPARSE_CSR_FORMAT); - // CHECK_EQ(inputs[2].sparse().dataType(), SPARSE_FLOAT_VALUE); - EXPECT_EQ(inputs[2].sparse().getRowBuf(), sparse.getRows()); - EXPECT_EQ(inputs[2].sparse().getColBuf(), sparse.getCols()); - }; - - // call function - function(argments); -} - } // namespace paddle diff --git a/paddle/function/FunctionTest.cpp b/paddle/function/FunctionTest.cpp index 7ce908320a6f6f764e8fdacc96432aca78d7b2df..6e44c2f5dbf3472d9651b4c5cfa1f3bc9a80ffc1 100644 --- a/paddle/function/FunctionTest.cpp +++ b/paddle/function/FunctionTest.cpp @@ -14,6 +14,7 @@ limitations under the License. */ #include "Function.h" #include +#include "paddle/math/SparseMatrix.h" namespace paddle { @@ -56,4 +57,55 @@ TEST(Function, BufferArgs) { Function(gpuArgments); } +TEST(BufferArgs, asArgument) { + MatrixPtr matrix = Matrix::create(100, 200); + VectorPtr vector = Vector::create(100, false); + CpuSparseMatrix sparse(200, 300, 50); + + // prepare arguments + BufferArgs argments; + argments.addArg(*matrix); + argments.addArg(*vector); + argments.addArg(sparse); + + // function + auto function = [=](const BufferArgs& inputs) { + EXPECT_EQ(inputs.size(), 3); + + // check inputs[0] + EXPECT_EQ(inputs[0].shape().ndims(), 2); + EXPECT_EQ(inputs[0].shape()[0], 100); + EXPECT_EQ(inputs[0].shape()[1], 200); + EXPECT_EQ(inputs[0].data(), matrix->getData()); + + EXPECT_EQ(inputs[0].matrix().getHeight(), + matrix->getHeight()); + EXPECT_EQ(inputs[0].matrix().getWidth(), + matrix->getWidth()); + EXPECT_EQ(inputs[0].matrix().getData(), matrix->getData()); + + // check inputs[1] + EXPECT_EQ(inputs[1].shape().ndims(), 1); + EXPECT_EQ(inputs[1].shape()[0], 100); + EXPECT_EQ(inputs[1].data(), vector->getData()); + CpuVector inVector = inputs[1].vector(); + EXPECT_EQ(inVector.getSize(), vector->getSize()); + EXPECT_EQ(inVector.getData(), vector->getData()); + + // check inputs[2] + EXPECT_EQ(inputs[2].shape().ndims(), 2); + EXPECT_EQ(inputs[2].shape()[0], 200); + EXPECT_EQ(inputs[2].shape()[1], 300); + EXPECT_EQ(inputs[2].data(), sparse.getData()); + // CHECK_EQ(inputs[2].sparse().nnz(), 50); + // CHECK_EQ(inputs[2].sparse().dataFormat(), SPARSE_CSR_FORMAT); + // CHECK_EQ(inputs[2].sparse().dataType(), SPARSE_FLOAT_VALUE); + EXPECT_EQ(inputs[2].sparse().getRowBuf(), sparse.getRows()); + EXPECT_EQ(inputs[2].sparse().getColBuf(), sparse.getCols()); + }; + + // call function + function(argments); +} + } // namespace paddle