From 64a08f840fece64397cc9978db1866158eb11b3e Mon Sep 17 00:00:00 2001 From: nhzlx Date: Wed, 1 Aug 2018 05:14:55 +0000 Subject: [PATCH] increase the test batch --- .../inference/tensorrt/convert/test_elementwise_op.cc | 8 ++++---- paddle/fluid/inference/tensorrt/convert/ut_helper.h | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/test_elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/test_elementwise_op.cc index c0254c6162..7537d02a35 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_elementwise_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_elementwise_op.cc @@ -23,7 +23,7 @@ namespace tensorrt { TEST(elementwise_op, add_weight_test) { std::unordered_set parameters({"elementwise_add-Y"}); framework::Scope scope; - TRTConvertValidation validator(1, parameters, scope, 1 << 15); + TRTConvertValidation validator(10, parameters, scope, 1 << 15); validator.DeclInputVar("elementwise_add-X", nvinfer1::DimsCHW(10, 3, 3)); validator.DeclParamVar("elementwise_add-Y", nvinfer1::Dims3(10, 1, 1)); // validator.DeclParamVar("mul-Y", nvinfer1::Dims2(8, 2)); @@ -41,13 +41,13 @@ TEST(elementwise_op, add_weight_test) { validator.SetOp(*desc.Proto()); - validator.Execute(1); + validator.Execute(8); } TEST(elementwise_op, add_tensor_test) { std::unordered_set parameters; framework::Scope scope; - TRTConvertValidation validator(2, parameters, scope, 1 << 15); + TRTConvertValidation validator(8, parameters, scope, 1 << 15); validator.DeclInputVar("elementwise_add-X", nvinfer1::DimsCHW(10, 3, 3)); validator.DeclInputVar("elementwise_add-Y", nvinfer1::Dims3(10, 3, 3)); // validator.DeclParamVar("mul-Y", nvinfer1::Dims2(8, 2)); @@ -64,7 +64,7 @@ TEST(elementwise_op, add_tensor_test) { validator.SetOp(*desc.Proto()); - validator.Execute(1); + validator.Execute(8); } } // namespace tensorrt diff --git a/paddle/fluid/inference/tensorrt/convert/ut_helper.h b/paddle/fluid/inference/tensorrt/convert/ut_helper.h index 39529cc2c7..63c2f978f2 100644 --- a/paddle/fluid/inference/tensorrt/convert/ut_helper.h +++ b/paddle/fluid/inference/tensorrt/convert/ut_helper.h @@ -149,7 +149,7 @@ class TRTConvertValidation { cudaStreamSynchronize(*engine_->stream()); ASSERT_FALSE(op_desc_->OutputArgumentNames().empty()); - const size_t output_space_size = 2000; + const size_t output_space_size = 3000; for (const auto& output : op_desc_->OutputArgumentNames()) { std::vector fluid_out; std::vector trt_out(output_space_size); -- GitLab