From 11f5d0c81606617506bf17ff50ef599559537a81 Mon Sep 17 00:00:00 2001 From: nhzlx Date: Fri, 17 Aug 2018 10:19:41 +0000 Subject: [PATCH] Merge pull request #12761 from NHZlX:global_pooling_trt --- .../tensorrt/convert/test_pool2d_op.cc | 43 +++---------------- 1 file changed, 7 insertions(+), 36 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc b/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc index cc13af0cb..aedd6b62d 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc @@ -20,7 +20,7 @@ namespace paddle { namespace inference { namespace tensorrt { -TEST(Pool2dOpConverter, main) { +void test_pool2d(bool global_pooling) { framework::Scope scope; std::unordered_set parameters; TRTConvertValidation validator(5, parameters, scope, 1 << 15); @@ -28,7 +28,10 @@ TEST(Pool2dOpConverter, main) { // The ITensor's Dims should not contain the batch size. // So, the ITensor's Dims of input and output should be C * H * W. validator.DeclInputVar("pool2d-X", nvinfer1::Dims3(3, 4, 4)); - validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 2, 2)); + if (global_pooling) + validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 1, 1)); + else + validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 2, 2)); // Prepare Op description framework::OpDesc desc; @@ -40,7 +43,6 @@ TEST(Pool2dOpConverter, main) { std::vector strides({2, 2}); std::vector paddings({0, 0}); std::string pooling_t = "max"; - bool global_pooling = false; desc.SetAttr("pooling_type", pooling_t); desc.SetAttr("ksize", ksize); @@ -55,40 +57,9 @@ TEST(Pool2dOpConverter, main) { validator.Execute(3); } -TEST(Pool2dOpConverter, test_global_pooling) { - framework::Scope scope; - std::unordered_set parameters; - TRTConvertValidation validator(5, parameters, scope, 1 << 15); - - // The ITensor's Dims should not contain the batch size. - // So, the ITensor's Dims of input and output should be C * H * W. - validator.DeclInputVar("pool2d-X", nvinfer1::Dims3(3, 4, 4)); - validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 1, 1)); - - // Prepare Op description - framework::OpDesc desc; - desc.SetType("pool2d"); - desc.SetInput("X", {"pool2d-X"}); - desc.SetOutput("Out", {"pool2d-Out"}); - - std::vector ksize({2, 2}); - std::vector strides({2, 2}); - std::vector paddings({0, 0}); - std::string pooling_t = "max"; - bool global_pooling = true; +TEST(Pool2dOpConverter, normal) { test_pool2d(false); } - desc.SetAttr("pooling_type", pooling_t); - desc.SetAttr("ksize", ksize); - desc.SetAttr("strides", strides); - desc.SetAttr("paddings", paddings); - desc.SetAttr("global_pooling", global_pooling); - - LOG(INFO) << "set OP"; - validator.SetOp(*desc.Proto()); - LOG(INFO) << "execute"; - - validator.Execute(3); -} +TEST(Pool2dOpConverter, test_global_pooling) { test_pool2d(true); } } // namespace tensorrt } // namespace inference -- GitLab