diff --git a/paddle/function/FunctionTest.cpp b/paddle/function/FunctionTest.cpp index 6e44c2f5dbf3472d9651b4c5cfa1f3bc9a80ffc1..eb05ca9a2190d56b925fc063778459315d312d4e 100644 --- a/paddle/function/FunctionTest.cpp +++ b/paddle/function/FunctionTest.cpp @@ -57,55 +57,91 @@ TEST(Function, BufferArgs) { Function(gpuArgments); } -TEST(BufferArgs, asArgument) { +/** + * Some tests case are used to check the consistency between the BufferArg type + * argument received by Function and the original type argument. + * + * Use Case: + * TEST() { + * Matrix matrix(...); + * CheckBufferArg lambda = [=](const BufferArg& arg) { + * // check matrix and arg are equivalent + * EXPECT_EQ(matrix, arg); + * } + * + * BufferArgs argments{matrix...}; + * std::vector checkFunc{lambda...}; + * testBufferArgs(argments, checkFunc); + * } + */ +typedef std::function CheckBufferArg; + +void testBufferArgs(const BufferArgs& inputs, + const std::vector& check) { + EXPECT_EQ(inputs.size(), check.size()); + for (size_t i = 0; i < inputs.size(); i++) { + check[i](inputs[i]); + } +} + +TEST(Arguments, Matrix) { MatrixPtr matrix = Matrix::create(100, 200); - VectorPtr vector = Vector::create(100, false); - CpuSparseMatrix sparse(200, 300, 50); + CheckBufferArg check = [=](const BufferArg& arg) { + EXPECT_EQ(arg.shape().ndims(), 2); + EXPECT_EQ(arg.shape()[0], 100); + EXPECT_EQ(arg.shape()[1], 200); + EXPECT_EQ(arg.data(), matrix->getData()); + + EXPECT_EQ(arg.matrix().getHeight(), matrix->getHeight()); + EXPECT_EQ(arg.matrix().getWidth(), matrix->getWidth()); + EXPECT_EQ(arg.matrix().getData(), matrix->getData()); + }; - // prepare arguments BufferArgs argments; argments.addArg(*matrix); - argments.addArg(*vector); - argments.addArg(sparse); + std::vector checkFunc; + checkFunc.push_back(check); + testBufferArgs(argments, checkFunc); +} + +TEST(Arguments, Vector) { + VectorPtr vector = Vector::create(100, false); + CheckBufferArg check = [=](const BufferArg& arg) { + EXPECT_EQ(arg.shape().ndims(), 1); + EXPECT_EQ(arg.shape()[0], 100); + EXPECT_EQ(arg.data(), vector->getData()); - // function - auto function = [=](const BufferArgs& inputs) { - EXPECT_EQ(inputs.size(), 3); - - // check inputs[0] - EXPECT_EQ(inputs[0].shape().ndims(), 2); - EXPECT_EQ(inputs[0].shape()[0], 100); - EXPECT_EQ(inputs[0].shape()[1], 200); - EXPECT_EQ(inputs[0].data(), matrix->getData()); - - EXPECT_EQ(inputs[0].matrix().getHeight(), - matrix->getHeight()); - EXPECT_EQ(inputs[0].matrix().getWidth(), - matrix->getWidth()); - EXPECT_EQ(inputs[0].matrix().getData(), matrix->getData()); - - // check inputs[1] - EXPECT_EQ(inputs[1].shape().ndims(), 1); - EXPECT_EQ(inputs[1].shape()[0], 100); - EXPECT_EQ(inputs[1].data(), vector->getData()); - CpuVector inVector = inputs[1].vector(); + CpuVector inVector = arg.vector(); EXPECT_EQ(inVector.getSize(), vector->getSize()); EXPECT_EQ(inVector.getData(), vector->getData()); + }; - // check inputs[2] - EXPECT_EQ(inputs[2].shape().ndims(), 2); - EXPECT_EQ(inputs[2].shape()[0], 200); - EXPECT_EQ(inputs[2].shape()[1], 300); - EXPECT_EQ(inputs[2].data(), sparse.getData()); - // CHECK_EQ(inputs[2].sparse().nnz(), 50); - // CHECK_EQ(inputs[2].sparse().dataFormat(), SPARSE_CSR_FORMAT); - // CHECK_EQ(inputs[2].sparse().dataType(), SPARSE_FLOAT_VALUE); - EXPECT_EQ(inputs[2].sparse().getRowBuf(), sparse.getRows()); - EXPECT_EQ(inputs[2].sparse().getColBuf(), sparse.getCols()); + BufferArgs argments; + argments.addArg(*vector); + std::vector checkFunc; + checkFunc.push_back(check); + testBufferArgs(argments, checkFunc); +} + +TEST(Arguments, CpuSparseMatrix) { + CpuSparseMatrix sparse(200, 300, 50); + CheckBufferArg check = [=](const BufferArg& arg) { + EXPECT_EQ(arg.shape().ndims(), 2); + EXPECT_EQ(arg.shape()[0], 200); + EXPECT_EQ(arg.shape()[1], 300); + EXPECT_EQ(arg.data(), sparse.getData()); + // CHECK_EQ(arg.sparse().nnz(), 50); + // CHECK_EQ(arg.sparse().dataFormat(), SPARSE_CSR_FORMAT); + // CHECK_EQ(arg.sparse().dataType(), SPARSE_FLOAT_VALUE); + EXPECT_EQ(arg.sparse().getRowBuf(), sparse.getRows()); + EXPECT_EQ(arg.sparse().getColBuf(), sparse.getCols()); }; - // call function - function(argments); + BufferArgs argments; + argments.addArg(sparse); + std::vector checkFunc; + checkFunc.push_back(check); + testBufferArgs(argments, checkFunc); } } // namespace paddle