From 7b67f398c33e03930aea8cfb0d330c2c28757100 Mon Sep 17 00:00:00 2001 From: wangxinxin08 <69842442+wangxinxin08@users.noreply.github.com> Date: Tue, 19 Oct 2021 15:06:48 +0800 Subject: [PATCH] add nearest_interp_v2 trt plugin (#34126) * add nearest_interp_v2 trt plugin --- .../fluid/inference/api/analysis_predictor.cc | 1 + .../inference/tensorrt/convert/CMakeLists.txt | 1 + .../tensorrt/convert/nearest_interp_v2_op.cc | 108 +++++++++++++ .../convert/test_nearest_interp_v2_op.cc | 54 +++++++ paddle/fluid/inference/tensorrt/op_teller.cc | 30 +++- .../tests/infer_ut/test_det_mv3_db.cc | 41 +---- .../unittests/ir/inference/CMakeLists.txt | 1 + .../test_trt_convert_nearest_interp_v2.py | 101 ++++++++++++ .../test_trt_nearest_interp_v2_op.py | 151 ++++++++++++++++++ 9 files changed, 450 insertions(+), 38 deletions(-) create mode 100644 paddle/fluid/inference/tensorrt/convert/nearest_interp_v2_op.cc create mode 100644 paddle/fluid/inference/tensorrt/convert/test_nearest_interp_v2_op.cc create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_nearest_interp_v2.py create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_nearest_interp_v2_op.py diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 3136e53e74d..dfa27037205 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1403,6 +1403,7 @@ USE_TRT_CONVERTER(roi_align); USE_TRT_CONVERTER(affine_channel); USE_TRT_CONVERTER(multiclass_nms); USE_TRT_CONVERTER(nearest_interp); +USE_TRT_CONVERTER(nearest_interp_v2); USE_TRT_CONVERTER(reshape); USE_TRT_CONVERTER(reduce_sum); USE_TRT_CONVERTER(gather_nd); diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index f2c7a4b62bb..ef12cb6b366 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -18,6 +18,7 @@ nv_library(tensorrt_converter tile_op.cc conv3d_op.cc mish_op.cc + nearest_interp_v2_op.cc DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry) nv_test(test_op_converter SRCS test_op_converter.cc DEPS diff --git a/paddle/fluid/inference/tensorrt/convert/nearest_interp_v2_op.cc b/paddle/fluid/inference/tensorrt/convert/nearest_interp_v2_op.cc new file mode 100644 index 00000000000..f2e0e0c09c5 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/nearest_interp_v2_op.cc @@ -0,0 +1,108 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/data_layout.h" +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace framework { +class Scope; +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { + +class NearestInterpolateV2OpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(3) << "convert a fluid nearest_interp_v2 op"; + + framework::OpDesc op_desc(op, nullptr); + + std::string input_name = op_desc.Input("X").front(); + std::string output_name = op_desc.Output("Out").front(); + + auto input = engine_->GetITensor(input_name); + + auto data_layout = framework::StringToDataLayout( + BOOST_GET_CONST(std::string, op_desc.GetAttr("data_layout"))); + auto interp_method = + BOOST_GET_CONST(std::string, op_desc.GetAttr("interp_method")); + bool align_corners = + BOOST_GET_CONST(bool, op_desc.GetAttr("align_corners")); + + auto input_names = op_desc.Input("X"); + auto scale = BOOST_GET_CONST(std::vector, op_desc.GetAttr("scale")); + auto out_h = BOOST_GET_CONST(int, op_desc.GetAttr("out_h")); + auto out_w = BOOST_GET_CONST(int, op_desc.GetAttr("out_w")); + + auto layer = TRT_ENGINE_ADD_LAYER(engine_, Resize, *input); + layer->setAlignCorners(align_corners); + + auto in_dim = input->getDimensions(); + + float scale_h = 1.f; + float scale_w = 1.f; + + std::vector scales; + + if (out_h > 0 && out_w > 0) { + // axis are different in static/dynamic mode + bool with_dynamic = engine_->with_dynamic_shape(); + + int h_axis = (data_layout == framework::DataLayout::kNCHW) + with_dynamic; + int w_axis = + (data_layout == framework::DataLayout::kNCHW) + 1 + with_dynamic; + + scale_h = + static_cast(out_h) / static_cast(in_dim.d[h_axis]); + scale_w = + static_cast(out_w) / static_cast(in_dim.d[w_axis]); + } else { + scale_h = scale[0]; + scale_w = scale[1]; + } + + if (engine_->with_dynamic_shape()) { + scales.push_back(1.f); + } + + if (data_layout == framework::DataLayout::kNCHW) { + scales.push_back(1.f); + scales.push_back(scale_h); + scales.push_back(scale_w); + } else if (data_layout == framework::DataLayout::kNHWC) { + // NHWC + scales.push_back(scale_h); + scales.push_back(scale_w); + scales.push_back(1.f); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Data layout must be NCHW or NHWC.")); + } + layer->setScales(scales.data(), scales.size()); + + RreplenishLayerAndOutput(layer, "nearest_interp_v2", {output_name}, + test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(nearest_interp_v2, NearestInterpolateV2OpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/test_nearest_interp_v2_op.cc b/paddle/fluid/inference/tensorrt/convert/test_nearest_interp_v2_op.cc new file mode 100644 index 00000000000..f5ab6a99249 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/test_nearest_interp_v2_op.cc @@ -0,0 +1,54 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +TEST(nearest_interp_v2_op, test_swish) { + std::unordered_set parameters; + framework::Scope scope; + TRTConvertValidation validator(10, parameters, scope, 1000); + validator.DeclInputVar("interp-X", nvinfer1::Dims3(3, 32, 32)); + validator.DeclOutputVar("interp-Out", nvinfer1::Dims3(3, 64, 64)); + + // Prepare Op description + framework::OpDesc desc; + desc.SetType("nearest_interp_v2"); + desc.SetInput("X", {"interp-X"}); + desc.SetOutput("Out", {"interp-Out"}); + + std::vector scale({2.f, 2.f}); + + desc.SetAttr("data_layout", "NCHW"); + desc.SetAttr("interp_method", "nearest"); + desc.SetAttr("align_corners", false); + desc.SetAttr("scale", scale); + desc.SetAttr("out_h", 0); + desc.SetAttr("out_w", 0); + + validator.SetOp(*desc.Proto()); + + validator.Execute(1); +} + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +USE_OP(nearest_interp_v2); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 89159c0bb63..e7318d07611 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -141,7 +141,8 @@ struct SimpleOpTypeSetTeller : public Teller { "reduce_mean", "conv3d", "conv3d_transpose", - "mish"}; + "mish", + "nearest_interp_v2"}; }; bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, @@ -599,6 +600,33 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } } + if (op_type == "nearest_interp_v2") { + std::vector attrs{"data_layout", "interp_method", + "align_corners", "scale", + "out_h", "out_w"}; + for (auto const attr : attrs) { + if (!desc.HasAttr(attr)) return false; + } + auto data_layout = framework::StringToDataLayout( + BOOST_GET_CONST(std::string, desc.GetAttr("data_layout"))); + if (data_layout != framework::DataLayout::kNCHW && + data_layout != framework::DataLayout::kNHWC) + return false; + auto interp_method = + BOOST_GET_CONST(std::string, desc.GetAttr("interp_method")); + if (interp_method != "nearest") return false; + auto scale = BOOST_GET_CONST(std::vector, desc.GetAttr("scale")); + auto out_h = BOOST_GET_CONST(int, desc.GetAttr("out_h")); + auto out_w = BOOST_GET_CONST(int, desc.GetAttr("out_w")); + if (!(out_h > 0 && out_w > 0)) { + if (scale[0] <= 0.f || scale[1] <= 0.f) { + VLOG(3) << "scale factor must be greater than 0 if out_h or out_w is " + "not set."; + return false; + } + } + } + if (op_type == "roi_align") { if (!with_dynamic_shape) return false; diff --git a/paddle/fluid/inference/tests/infer_ut/test_det_mv3_db.cc b/paddle/fluid/inference/tests/infer_ut/test_det_mv3_db.cc index 67c2eeb0be5..cf3398b49ee 100644 --- a/paddle/fluid/inference/tests/infer_ut/test_det_mv3_db.cc +++ b/paddle/fluid/inference/tests/infer_ut/test_det_mv3_db.cc @@ -35,44 +35,11 @@ paddle::test::Record PrepareInput(int batch_size, int image_shape = 640) { void PrepareDynamicShape(paddle_infer::Config* config, int max_batch_size = 4) { // set dynamic shape range std::map> min_input_shape = { - {"x", {1, 3, 50, 50}}, - {"conv2d_92.tmp_0", {1, 120, 20, 20}}, - {"conv2d_91.tmp_0", {1, 24, 10, 10}}, - {"conv2d_59.tmp_0", {1, 96, 20, 20}}, - {"nearest_interp_v2_1.tmp_0", {1, 256, 10, 10}}, - {"nearest_interp_v2_2.tmp_0", {1, 256, 20, 20}}, - {"conv2d_124.tmp_0", {1, 256, 20, 20}}, - {"nearest_interp_v2_3.tmp_0", {1, 64, 20, 20}}, - {"nearest_interp_v2_4.tmp_0", {1, 64, 20, 20}}, - {"nearest_interp_v2_5.tmp_0", {1, 64, 20, 20}}, - {"elementwise_add_7", {1, 56, 2, 2}}, - {"nearest_interp_v2_0.tmp_0", {1, 256, 2, 2}}}; + {"x", {1, 3, 50, 50}}}; std::map> max_input_shape = { - {"x", {max_batch_size, 3, 2000, 2000}}, - {"conv2d_92.tmp_0", {max_batch_size, 120, 400, 400}}, - {"conv2d_91.tmp_0", {max_batch_size, 24, 200, 200}}, - {"conv2d_59.tmp_0", {max_batch_size, 96, 400, 400}}, - {"nearest_interp_v2_1.tmp_0", {max_batch_size, 256, 200, 200}}, - {"nearest_interp_v2_2.tmp_0", {max_batch_size, 256, 400, 400}}, - {"conv2d_124.tmp_0", {max_batch_size, 256, 400, 400}}, - {"nearest_interp_v2_3.tmp_0", {max_batch_size, 64, 400, 400}}, - {"nearest_interp_v2_4.tmp_0", {max_batch_size, 64, 400, 400}}, - {"nearest_interp_v2_5.tmp_0", {max_batch_size, 64, 400, 400}}, - {"elementwise_add_7", {max_batch_size, 56, 400, 400}}, - {"nearest_interp_v2_0.tmp_0", {max_batch_size, 256, 400, 400}}}; + {"x", {max_batch_size, 3, 1600, 1600}}}; std::map> opt_input_shape = { - {"x", {1, 3, 640, 640}}, - {"conv2d_92.tmp_0", {1, 120, 160, 160}}, - {"conv2d_91.tmp_0", {1, 24, 80, 80}}, - {"conv2d_59.tmp_0", {1, 96, 160, 160}}, - {"nearest_interp_v2_1.tmp_0", {1, 256, 80, 80}}, - {"nearest_interp_v2_2.tmp_0", {1, 256, 160, 160}}, - {"conv2d_124.tmp_0", {1, 256, 160, 160}}, - {"nearest_interp_v2_3.tmp_0", {1, 64, 160, 160}}, - {"nearest_interp_v2_4.tmp_0", {1, 64, 160, 160}}, - {"nearest_interp_v2_5.tmp_0", {1, 64, 160, 160}}, - {"elementwise_add_7", {1, 56, 40, 40}}, - {"nearest_interp_v2_0.tmp_0", {1, 256, 40, 40}}}; + {"x", {1, 3, 640, 640}}}; config->SetTRTDynamicShapeInfo(min_input_shape, max_input_shape, opt_input_shape); } @@ -123,7 +90,7 @@ TEST(tensorrt_tester_det_mv3_db, multi_thread2_trt_fp32_dynamic_shape_bz2) { FLAGS_modeldir + "/inference.pdiparams"); config.EnableUseGpu(100, 0); config.EnableTensorRtEngine( - 1 << 20, 2, 3, paddle_infer::PrecisionType::kFloat32, true, false); + 1 << 20, 2, 3, paddle_infer::PrecisionType::kFloat32, false, false); PrepareDynamicShape(&config, 4); // get groudtruth by disbale ir paddle_infer::services::PredictorPool pred_pool_no_ir(config_no_ir, 1); diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt index 54229533935..b951afdfad5 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt @@ -68,4 +68,5 @@ set_tests_properties(test_trt_conv_quant_dequant_pass PROPERTIES TIMEOUT 100) set_tests_properties(test_trt_matmul_quant_dequant PROPERTIES TIMEOUT 100) set_tests_properties(test_trt_conv3d_op PROPERTIES TIMEOUT 60) set_tests_properties(test_trt_conv3d_transpose_op PROPERTIES TIMEOUT 60) +set_tests_properties(test_trt_nearest_interp_v2_op PROPERTIES TIMEOUT 30) endif() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_nearest_interp_v2.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_nearest_interp_v2.py new file mode 100644 index 00000000000..0c7715c9570 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_nearest_interp_v2.py @@ -0,0 +1,101 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set + + +class TrtConvertNearestInterpV2Test(TrtLayerAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_configs(self): + def generate_input(): + return np.ones([1, 3, 32, 32]).astype(np.float32) + + ops_config = [{ + "op_type": "nearest_interp_v2", + "op_inputs": { + "X": ["input_data"] + }, + "op_outputs": { + "Out": ["interp_output_data"] + }, + "op_attrs": { + "data_layout": "NCHW", + "interp_method": "nearest", + "align_corners": False, + "scale": [2., 2.], + "out_h": 0, + "out_w": 0 + } + }] + + ops = self.generate_op_config(ops_config) + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={"input_data": TensorConfig(data_gen=generate_input)}, + outputs=["interp_output_data"]) + + yield program_config + + def sample_predictor_configs( + self, program_config) -> (paddle_infer.Config, List[int], float): + def generate_dynamic_shape(attrs): + self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 32, 32]} + self.dynamic_shape.max_input_shape = {"input_data": [4, 3, 64, 64]} + self.dynamic_shape.opt_input_shape = {"input_data": [1, 3, 64, 64]} + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + return 1, 2 + + attrs = [ + program_config.ops[i].attrs + for i in range(len(program_config.ops)) + ] + + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-2 + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num(attrs, + True), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num(attrs, + True), 1e-2 + + def test(self): + self.run_test() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_nearest_interp_v2_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_nearest_interp_v2_op.py new file mode 100644 index 00000000000..101ace6cd54 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_nearest_interp_v2_op.py @@ -0,0 +1,151 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle.fluid.core as core +from paddle import fluid +import paddle.nn.functional as F +from paddle.fluid.core import PassVersionChecker +from paddle.fluid.core import AnalysisConfig + + +class TRTNearestInterpTest(InferencePassTest): + def setUp(self): + self.set_params() + + with fluid.program_guard(self.main_program, self.startup_program): + if self.data_layout == 'NCHW': + shape = [ + -1, self.channels, self.origin_shape[0], + self.origin_shape[1] + ] + else: + shape = [ + -1, self.origin_shape[0], self.origin_shape[1], + self.channels + ] + data = fluid.data(name='data', shape=shape, dtype='float32') + resize_out = self.append_nearest_interp(data) + out = fluid.layers.batch_norm(resize_out, is_test=True) + + if self.data_layout == 'NCHW': + shape = [ + self.bs, self.channels, self.origin_shape[0], + self.origin_shape[1] + ] + else: + shape = [ + self.bs, self.origin_shape[0], self.origin_shape[1], + self.channels + ] + + self.feeds = {'data': np.random.random(shape).astype('float32'), } + self.enable_trt = True + self.trt_parameters = TRTNearestInterpTest.TensorRTParam( + 1 << 30, self.bs, 1, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [out] + + def set_params(self): + self.bs = 4 + self.scale = -1 + self.channels = 3 + self.origin_shape = (32, 32) # HW + self.resize_shape = (64, 64) # HW + self.align_corners = False + self.data_layout = 'NCHW' + + def append_nearest_interp(self, data): + if self.scale > 0.: + return F.interpolate( + data, + scale_factor=self.scale, + align_corners=self.align_corners, + mode='nearest', + data_format=self.data_layout) + return F.interpolate( + data, + size=self.resize_shape, + align_corners=self.align_corners, + mode='nearest', + data_format=self.data_layout) + + def test_check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu, flatten=True) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) + + +class TRTNearestInterpTest1(TRTNearestInterpTest): + def set_params(self): + self.bs = 4 + self.scale = 2. + self.channels = 3 + self.origin_shape = (32, 32) # HW + self.resize_shape = (64, 64) # HW + self.align_corners = False + self.data_layout = 'NCHW' + + +class TRTNearestInterpTest2(TRTNearestInterpTest): + def set_params(self): + self.bs = 4 + self.scale = -1 + self.channels = 3 + self.origin_shape = (32, 32) # HW + self.resize_shape = (47, 48) # HW + self.align_corners = False + self.data_layout = 'NCHW' + + +class TRTNearestInterpTest3(TRTNearestInterpTest): + def set_params(self): + self.bs = 4 + self.scale = -1 + self.channels = 3 + self.origin_shape = (32, 32) # HW + self.resize_shape = (64, 64) # HW + self.align_corners = False + self.data_layout = 'NHWC' + + +class TRTNearestInterpTest4(TRTNearestInterpTest): + def set_params(self): + self.bs = 4 + self.scale = 2. + self.channels = 3 + self.origin_shape = (32, 32) # HW + self.resize_shape = (64, 64) # HW + self.align_corners = False + self.data_layout = 'NHWC' + + +class TRTNearestInterpTest5(TRTNearestInterpTest): + def set_params(self): + self.bs = 4 + self.scale = -1 + self.channels = 3 + self.origin_shape = (32, 32) # HW + self.resize_shape = (47, 48) # HW + self.align_corners = False + self.data_layout = 'NHWC' + + +if __name__ == "__main__": + unittest.main() -- GitLab