diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index fc436311f0796c2211f447822741f33c4ed4549c..8f2b217a2fde0a05ccb5e09d867d0dc9a892511b 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -1192,6 +1192,8 @@ USE_TRT_CONVERTER(scale); USE_TRT_CONVERTER(stack); USE_TRT_CONVERTER(clip); USE_TRT_CONVERTER(gather); + +USE_TRT_CONVERTER(nearest_interp); #endif namespace paddle_infer { diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 59205529ef4c029ce7d08e382a02c868d7e94db1..b0d0229ec0531f9b907fc449b62b71752b4b17a5 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -6,6 +6,8 @@ nv_library(tensorrt_converter shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc gather_op.cc + + nearest_interp_op.cc DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry) nv_test(test_op_converter SRCS test_op_converter.cc DEPS diff --git a/paddle/fluid/inference/tensorrt/convert/nearest_interp_op.cc b/paddle/fluid/inference/tensorrt/convert/nearest_interp_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..e91a2ee13f4c2d29274cfb70462e2ff1badce525 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/nearest_interp_op.cc @@ -0,0 +1,114 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/data_layout.h" +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace framework { +class Scope; +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { + +class NearestInterpolateOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, bool test_mode) override { + VLOG(3) << "convert a fluid nearest_interp op"; + + framework::OpDesc op_desc(op, nullptr); + + std::string input_name = op_desc.Input("X").front(); + std::string output_name = op_desc.Output("Out").front(); + + auto input = engine_->GetITensor(input_name); + + auto data_layout = framework::StringToDataLayout( + BOOST_GET_CONST(std::string, op_desc.GetAttr("data_layout"))); + auto interp_method = + BOOST_GET_CONST(std::string, op_desc.GetAttr("interp_method")); + bool align_corners = + BOOST_GET_CONST(bool, op_desc.GetAttr("align_corners")); + + auto input_names = op_desc.Input("X"); + auto scale = BOOST_GET_CONST(float, op_desc.GetAttr("scale")); + auto out_h = BOOST_GET_CONST(int, op_desc.GetAttr("out_h")); + auto out_w = BOOST_GET_CONST(int, op_desc.GetAttr("out_w")); + + auto layer = TRT_ENGINE_ADD_LAYER(engine_, Resize, *input); + layer->setAlignCorners(align_corners); + + auto in_dim = input->getDimensions(); + + float scale_h = 1.f; + float scale_w = 1.f; + + std::vector scales; + + if (scale > 0.f && (out_h <= 0 && out_w <= 0)) { + scale_h = scale; + scale_w = scale; + } else { + // axis are different in static/dynamic mode + PADDLE_ENFORCE_GT( + out_h, 0, platform::errors::InvalidArgument( + "out_h must be greater than 0 if scale is not set.")); + PADDLE_ENFORCE_GT( + out_w, 0, platform::errors::InvalidArgument( + "out_w must be greater than 0 if scale is not set.")); + + bool with_dynamic = engine_->with_dynamic_shape(); + + int h_axis = (data_layout == framework::DataLayout::kNCHW) + with_dynamic; + int w_axis = + (data_layout == framework::DataLayout::kNCHW) + 1 + with_dynamic; + + scale_h = + static_cast(out_h) / static_cast(in_dim.d[h_axis]); + scale_w = + static_cast(out_w) / static_cast(in_dim.d[w_axis]); + } + + if (engine_->with_dynamic_shape()) { + scales.push_back(1.f); + } + + if (data_layout == framework::DataLayout::kNCHW) { + scales.push_back(1.f); + scales.push_back(scale_h); + scales.push_back(scale_w); + } else if (data_layout == framework::DataLayout::kNHWC) { + // NHWC + scales.push_back(scale_h); + scales.push_back(scale_w); + scales.push_back(1.f); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Data layout must be NCHW or NHWC.")); + } + layer->setScales(scales.data(), scales.size()); + + RreplenishLayerAndOutput(layer, "nearest_interp", {output_name}, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(nearest_interp, NearestInterpolateOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 44939606b49c3578d5bb50c5e3c0f658d09b6eb8..2ec94f5f98c8d8e4ee9f6facd140d8869d5dedda 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/inference/tensorrt/op_teller.h" #include "paddle/fluid/framework/block_desc.h" +#include "paddle/fluid/framework/data_layout.h" namespace paddle { namespace framework { @@ -110,6 +111,8 @@ struct SimpleOpTypeSetTeller : public Teller { "flatten2", "flatten", "gather", + + "nearest_interp", }; }; @@ -187,10 +190,29 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, if (axis != 1) return false; } } + if (op_type == "gather") { // current not support axis from input, use default 0 if (!with_dynamic_shape || desc.Input("Axis").size() > 0) return false; } + + if (op_type == "nearest_interp") { + std::vector attrs{"data_layout", "interp_method", + "align_corners", "scale", + "out_h", "out_w"}; + for (auto const attr : attrs) { + if (!desc.HasAttr(attr)) return false; + } + auto data_layout = framework::StringToDataLayout( + BOOST_GET_CONST(std::string, desc.GetAttr("data_layout"))); + if (data_layout != framework::DataLayout::kNCHW && + data_layout != framework::DataLayout::kNHWC) + return false; + auto interp_method = + BOOST_GET_CONST(std::string, desc.GetAttr("interp_method")); + if (interp_method != "nearest") return false; + } + if ((*teller)(op_type, desc, use_no_calib_int8)) return true; } return false; diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_nearest_interp_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_nearest_interp_op.py new file mode 100644 index 0000000000000000000000000000000000000000..1a58a6c9dda7d16824edc38776fdf243794e4391 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_nearest_interp_op.py @@ -0,0 +1,192 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.core import PassVersionChecker +from paddle.fluid.core import AnalysisConfig + + +class TRTNearestInterpTest(InferencePassTest): + def setUp(self): + self.set_params() + + with fluid.program_guard(self.main_program, self.startup_program): + if self.data_layout == 'NCHW': + shape = [ + -1, self.channels, self.origin_shape[0], + self.origin_shape[1] + ] + else: + shape = [ + -1, self.origin_shape[0], self.origin_shape[1], + self.channels + ] + data = fluid.data(name='data', shape=shape, dtype='float32') + resize_out = self.append_nearest_interp(data) + out = fluid.layers.batch_norm(resize_out, is_test=True) + + if self.data_layout == 'NCHW': + shape = [ + self.bs, self.channels, self.origin_shape[0], + self.origin_shape[1] + ] + else: + shape = [ + self.bs, self.origin_shape[0], self.origin_shape[1], + self.channels + ] + + self.feeds = {'data': np.random.random(shape).astype('float32'), } + self.enable_trt = True + self.trt_parameters = TRTNearestInterpTest.TensorRTParam( + 1 << 30, self.bs, 1, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [out] + + def set_params(self): + self.bs = 4 + self.scale = 1 + self.channels = 3 + self.origin_shape = (32, 32) # HW + self.resize_shape = (64, 64) # HW + self.align_corners = True + self.data_layout = 'NCHW' + + def append_nearest_interp(self, data): + if self.scale > 0.: + return fluid.layers.resize_nearest( + data, + scale=self.scale, + align_corners=self.align_corners, + data_format=self.data_layout) + return fluid.layers.resize_nearest( + data, + out_shape=self.resize_shape, + align_corners=self.align_corners, + data_format=self.data_layout) + + def test_check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu, flatten=True) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) + + +class TRTNearestInterpTest1(TRTNearestInterpTest): + def set_params(self): + self.bs = 4 + self.scale = -1 + self.channels = 3 + self.origin_shape = (32, 32) # HW + self.resize_shape = (64, 64) # HW + self.align_corners = True + self.data_layout = 'NCHW' + + +class TRTNearestInterpTest2(TRTNearestInterpTest): + def set_params(self): + self.bs = 4 + self.scale = 2. + self.channels = 3 + self.origin_shape = (32, 32) # HW + self.resize_shape = (64, 64) # HW + self.align_corners = False + self.data_layout = 'NCHW' + + +class TRTNearestInterpTest3(TRTNearestInterpTest): + def set_params(self): + self.bs = 4 + self.scale = -1 + self.channels = 3 + self.origin_shape = (32, 32) # HW + self.resize_shape = (64, 64) # HW + self.align_corners = False + self.data_layout = 'NCHW' + + +class TRTNearestInterpTest4(TRTNearestInterpTest): + def set_params(self): + self.bs = 4 + self.scale = -1 + self.channels = 3 + self.origin_shape = (32, 32) # HW + self.resize_shape = (47, 48) # HW + self.align_corners = False + self.data_layout = 'NCHW' + + +class TRTNearestInterpTest5(TRTNearestInterpTest): + def set_params(self): + self.bs = 4 + self.scale = -1 + self.channels = 3 + self.origin_shape = (32, 32) # HW + self.resize_shape = (64, 64) # HW + self.align_corners = True + self.data_layout = 'NHWC' + + +class TRTNearestInterpTest6(TRTNearestInterpTest): + def set_params(self): + self.bs = 4 + self.scale = 2. + self.channels = 3 + self.origin_shape = (32, 32) # HW + self.resize_shape = (64, 64) # HW + self.align_corners = False + self.data_layout = 'NHWC' + + +class TRTNearestInterpTest7(TRTNearestInterpTest): + def set_params(self): + self.bs = 4 + self.scale = -1 + self.channels = 3 + self.origin_shape = (32, 32) # HW + self.resize_shape = (64, 64) # HW + self.align_corners = False + self.data_layout = 'NHWC' + + +class TRTNearestInterpTest8(TRTNearestInterpTest): + def set_params(self): + self.bs = 4 + self.scale = -1 + self.channels = 3 + self.origin_shape = (32, 32) # HW + self.resize_shape = (47, 48) # HW + self.align_corners = False + self.data_layout = 'NHWC' + + +class TRTNearestInterpTest9(TRTNearestInterpTest): + def set_params(self): + self.bs = 4 + self.scale = -1 + self.channels = 3 + self.origin_shape = (32, 32) # HW + self.resize_shape = (47, 48) # HW + self.align_corners = False + self.data_layout = 'NHWC' + + +if __name__ == "__main__": + unittest.main()