From f403fb69cef0e1f6400404e8f79d9770ab776d94 Mon Sep 17 00:00:00 2001 From: feng_shuai Date: Tue, 12 Apr 2022 11:41:22 +0800 Subject: [PATCH] add trt supoort for slice op (#41467) * add trt supoort for slice op * fix:output dims bug * fix: test * fix:for c++ coverage * fix:c++ coverage * fix: fix test bug * fix: CI test --- .../inference/tensorrt/convert/slice_op.cc | 8 +++-- paddle/fluid/inference/tensorrt/op_teller.cc | 29 +++++++++++-------- .../tensorrt/plugin/slice_op_plugin.cu | 25 ++++++++++++++-- .../tensorrt/plugin/slice_op_plugin.h | 7 +++-- .../ir/inference/test_trt_convert_slice.py | 20 ++++++++----- 5 files changed, 62 insertions(+), 27 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/slice_op.cc b/paddle/fluid/inference/tensorrt/convert/slice_op.cc index fde80ab42c..dea9a1ec3d 100644 --- a/paddle/fluid/inference/tensorrt/convert/slice_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/slice_op.cc @@ -44,6 +44,8 @@ class SliceOpConverter : public OpConverter { BOOST_GET_CONST(std::vector, op_desc.GetAttr("starts")); std::vector ends = BOOST_GET_CONST(std::vector, op_desc.GetAttr("ends")); + std::vector decrease_axises = + BOOST_GET_CONST(std::vector, op_desc.GetAttr("decrease_axis")); auto input_dims = input->getDimensions(); if (!engine_->with_dynamic_shape()) { @@ -107,8 +109,10 @@ class SliceOpConverter : public OpConverter { } else { bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); - plugin::SlicePluginDynamic* plugin = - new plugin::SlicePluginDynamic(starts, ends, axes, with_fp16); + int decrease_axis = + decrease_axises.size() == 0 ? -1 : decrease_axises[0]; + plugin::SlicePluginDynamic* plugin = new plugin::SlicePluginDynamic( + starts, ends, axes, decrease_axis, with_fp16); layer = engine_->AddDynamicPlugin(&input, 1, plugin); } } else { diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 85c5dc7107..6ccaf80c9f 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -930,10 +930,16 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, if (desc.HasAttr("decrease_axis")) { std::vector decrease_axis = BOOST_GET_CONST(std::vector, desc.GetAttr("decrease_axis")); - if (decrease_axis.size() > 0) { - VLOG(3) << "Invalid slice decrease_axis. decrease_axis.size() > 0" - "is not supported in TensorRT"; - return false; + if (with_dynamic_shape) { + if (decrease_axis.size() > 1) { + return false; + } + } else { + if (decrease_axis.size() > 0) { + VLOG(3) << "Invalid slice decrease_axis. decrease_axis.size() > 0" + "is not supported in TensorRT"; + return false; + } } } @@ -1054,17 +1060,15 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, return false; } if (desc.Input("Ids").size() != desc.Input("Embs").size()) { - VLOG(3) << "The id and emb size of fused EmbEltwiseLayerNormOp " - "should be same "; return false; } } if (op_type == "fused_preln_embedding_eltwise_layernorm") { if (!with_dynamic_shape) { - VLOG(3) - << "fused_preln_embedding_eltwise_layernorm should run on dynamic " - "shape mode."; + VLOG(3) << "fused_preln_embedding_eltwise_layernorm should run on " + "dynamic " + "shape mode."; return false; } if (desc.Input("Ids").size() != desc.Input("Embs").size()) { @@ -1454,7 +1458,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, const auto y_shape = y_var_desc->GetShape(); if (y_shape.size() != 2) { VLOG(3) - << " input_y(fc_op)'shapes must be 2, but input_y(fc_op)'shapes = " + << " input_y(fc_op)'shapes must be 2, but input_y(fc_op)'shapes = + " << y_shape.size(); return false; } @@ -1598,8 +1603,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } #else if (dtype != framework::proto::VarType::FP32) { - VLOG(3) - << "reduce op input data type must be float32 using TensorRT < 7.0"; + VLOG(3) << "reduce op input data type must be float32 using TensorRT " + "< 7.0"; return false; } #endif diff --git a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu index 2b6541c551..4e6b82d2dc 100644 --- a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.cu @@ -205,8 +205,9 @@ void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT { #if IS_TRT_VERSION_GE(6000) SlicePluginDynamic::SlicePluginDynamic(std::vector starts, std::vector ends, - std::vector axes, bool with_fp16) - : starts_(starts), ends_(ends), axes_(axes) { + std::vector axes, int decrease_axis, + bool with_fp16) + : starts_(starts), ends_(ends), axes_(axes), decrease_axis_(decrease_axis) { with_fp16_ = with_fp16; cudaEventCreate(©_event_); cudaStreamCreate(©_stream_); @@ -217,6 +218,7 @@ SlicePluginDynamic::SlicePluginDynamic(void const *serialData, DeserializeValue(&serialData, &serialLength, &starts_); DeserializeValue(&serialData, &serialLength, &ends_); DeserializeValue(&serialData, &serialLength, &axes_); + DeserializeValue(&serialData, &serialLength, &decrease_axis_); DeserializeValue(&serialData, &serialLength, &with_fp16_); cudaEventCreate(©_event_); cudaStreamCreate(©_stream_); @@ -233,7 +235,8 @@ int SlicePluginDynamic::initialize() TRT_NOEXCEPT { return 0; } size_t SlicePluginDynamic::getSerializationSize() const TRT_NOEXCEPT { size_t size = SerializedSize(starts_) + SerializedSize(ends_) + - SerializedSize(axes_) + SerializedSize(with_fp16_); + SerializedSize(axes_) + SerializedSize(decrease_axis_) + + SerializedSize(with_fp16_); return size; } @@ -242,6 +245,7 @@ void SlicePluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT { SerializeValue(&buffer, starts_); SerializeValue(&buffer, ends_); SerializeValue(&buffer, axes_); + SerializeValue(&buffer, decrease_axis_); SerializeValue(&buffer, with_fp16_); } @@ -265,6 +269,17 @@ nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions( ret.d[axes_[i]] = expr_builder.constant(end - start); #endif } + if (decrease_axis_ != -1) { + nvinfer1::DimsExprs res; + res.nbDims = ret.nbDims - 1; + int j = 0; + for (size_t i = 0; i < in_dims.nbDims; i++) { + if (decrease_axis_ == i) continue; + res.d[j++] = expr_builder.operation(nvinfer1::DimensionOperation::kMAX, + *expr_builder.constant(0), *ret.d[i]); + } + return res; + } return ret; } @@ -318,6 +333,10 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc, cudaStream_t stream) TRT_NOEXCEPT { auto input_dims = input_desc[0].dims; auto out_dims = output_desc[0].dims; + if (decrease_axis_ != -1) { + out_dims = input_dims; + out_dims.d[decrease_axis_] = 1; + } auto num_dims = input_dims.nbDims; size_t out_num = ProductDim(out_dims); diff --git a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h index 29f8f7c099..4c07f0be36 100644 --- a/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h @@ -88,10 +88,12 @@ REGISTER_TRT_PLUGIN_V2(SlicePluginCreator); class SlicePluginDynamic : public DynamicPluginTensorRT { public: explicit SlicePluginDynamic(std::vector starts, std::vector ends, - std::vector axes, bool with_fp16); + std::vector axes, int decrease_axis, + bool with_fp16); nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { - return new SlicePluginDynamic(starts_, ends_, axes_, with_fp16_); + return new SlicePluginDynamic(starts_, ends_, axes_, decrease_axis_, + with_fp16_); } SlicePluginDynamic(void const* serialData, size_t serialLength); @@ -140,6 +142,7 @@ class SlicePluginDynamic : public DynamicPluginTensorRT { std::vector starts_; std::vector ends_; std::vector axes_; + int decrease_axis_; int* offset_temp_data_{nullptr}; cudaEvent_t copy_event_; cudaStream_t copy_stream_; diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_slice.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_slice.py index 17a2c9cd74..86c52dad23 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_slice.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_slice.py @@ -55,11 +55,11 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest): def sample_program_configs(self): def generate_input1(attrs: List[Dict[str, Any]]): - return np.ones([1, 3, 64, 64]).astype(np.float32) + return np.ones([6, 6, 64, 64]).astype(np.float32) for axes in [[0, 1], [1, 3], [2, 3]]: - for starts in [[0, 1], [-4, -3]]: - for ends in [[2, 2], [-1, -2], [5, 5]]: + for starts in [[0, 1]]: + for ends in [[2, 2], [5, 5]]: for decrease_axis in [[], [1], [2], [-1], [-100]]: for infer_flags in [[-1]]: dics = [{ @@ -97,8 +97,8 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest): self, program_config) -> (paddle_infer.Config, List[int], float): def generate_dynamic_shape(attrs): self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 32, 32]} - self.dynamic_shape.max_input_shape = {"input_data": [4, 3, 64, 64]} - self.dynamic_shape.opt_input_shape = {"input_data": [1, 3, 64, 64]} + self.dynamic_shape.max_input_shape = {"input_data": [8, 8, 64, 64]} + self.dynamic_shape.opt_input_shape = {"input_data": [6, 6, 64, 64]} def clear_dynamic_shape(): self.dynamic_shape.min_input_shape = {} @@ -107,7 +107,11 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest): def generate_trt_nodes_num(attrs, dynamic_shape): inputs = program_config.inputs - if len(attrs[0]["decrease_axis"]) != 0: + if dynamic_shape == True and len(attrs[0]["decrease_axis"]) == 0: + return 1, 2 + if dynamic_shape == True and len(attrs[0]["decrease_axis"]) != 1: + return 0, 3 + if dynamic_shape == False and len(attrs[0]["decrease_axis"]) != 0: return 0, 3 if dynamic_shape: for i in range(len(attrs[0]["starts"])): @@ -123,7 +127,7 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest): program_config.ops[i].attrs for i in range(len(program_config.ops)) ] - + self.trt_param.max_batch_size = 9 # for static_shape clear_dynamic_shape() self.trt_param.precision = paddle_infer.PrecisionType.Float32 @@ -146,7 +150,7 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest): # TODO(inference): fix. # trt6 and trt7.1 has bug. # trt7.2 deserialize has bug. - # self.run_test() + self.run_test() pass -- GitLab