From c294cca4591f81ba29d543f5e75450d0649f73bc Mon Sep 17 00:00:00 2001 From: wenbin Date: Tue, 25 May 2021 13:59:42 +0800 Subject: [PATCH] =?UTF-8?q?=E7=A6=81=E6=AD=A2=E5=9C=A8=E4=BD=8E=E7=89=88?= =?UTF-8?q?=E6=9C=ACTRT=E4=B8=AD=E4=BD=BF=E7=94=A8strides>1=E7=9A=84conv?= =?UTF-8?q?=20(#32997)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * revert elementwise and disable trt conv if strides > 1 * strides check * remove useless var * comments --- .../tensorrt/convert/activation_op.cc | 5 ---- .../tensorrt/convert/affine_channel_op.cc | 10 ------- .../tensorrt/convert/elementwise_op.cc | 4 --- paddle/fluid/inference/tensorrt/op_teller.cc | 27 +++++++++++++++++++ 4 files changed, 27 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/activation_op.cc b/paddle/fluid/inference/tensorrt/convert/activation_op.cc index 9244b9af0b..e6a0ecf4ae 100644 --- a/paddle/fluid/inference/tensorrt/convert/activation_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/activation_op.cc @@ -52,11 +52,6 @@ class ActivationOpConverter : public OpConverter { engine_->GetITensor(op_desc.Input("X")[0]); auto op_pair = ops.find(op_type_); - if (op_pair == ops.end()) { - PADDLE_THROW(platform::errors::Fatal( - "Wrong activation op type, the trt do not support the %s act type.", - op_type_)); - } nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER( engine_, Activation, *const_cast(input_tensor), diff --git a/paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc b/paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc index 813342c084..eba67c3c09 100644 --- a/paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc @@ -55,16 +55,6 @@ class AffineChannelOpConverter : public OpConverter { auto* bias_t = bias_v->GetMutable(); float* bias_ptr = engine_->GetWeightCPUData(bias_name, bias_t, false); - auto data_layout = framework::StringToDataLayout( - BOOST_GET_CONST(std::string, op_desc.GetAttr("data_layout"))); - - PADDLE_ENFORCE_EQ( - data_layout, framework::DataLayout::kNCHW, - platform::errors::InvalidArgument( - "TensorRT affine channel converter can only convert NCHW format. " - "Other format should be run in fluid mode. Report a bug on github " - "issue if you see this line.")); - // tensorrt scalend layer only support spatial dims >= 2, // so nhwc is not availabe (spatial dims == 0) const int channel_axis = engine_->with_dynamic_shape(); diff --git a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc index 47f5cc97d3..df24008544 100644 --- a/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/elementwise_op.cc @@ -25,10 +25,6 @@ static bool CheckDims(const nvinfer1::Dims& dims_x, return false; } for (int i = 0; i < dims_x.nbDims; i++) { - // conservative judgment - if (dims_x.d[i] == -1 || dims_y.d[i] == -1) { - return false; - } if (dims_x.d[i] != dims_y.d[i]) { return false; } diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 54fc9492b7..9df3ec0445 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -143,6 +143,19 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, BOOST_GET_CONST(std::vector, desc.GetAttr("paddings")); if (paddings.size() > 2) return false; +// strides > 1 is only supported by trt7.0 above +#if !IS_TRT_VERSION_GE(7000) + if (desc.HasAttr("strides")) { + const std::vector strides = + BOOST_GET_CONST(std::vector, desc.GetAttr("strides")); + // there is no issue if strides.size() less than 2 + if (strides.size() > 1) { + for (size_t i = 0; i < strides.size(); i++) { + if (strides[i] > 1) return false; + } + } + } +#endif } if (op_type == "pool2d") { @@ -225,6 +238,20 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, << desc.Output("Output").size() << " output."; return false; } + +// strides > 1 is only supported by trt7.0 above +#if !IS_TRT_VERSION_GE(7000) + if (desc.HasAttr("strides")) { + const std::vector strides = + BOOST_GET_CONST(std::vector, desc.GetAttr("strides")); + // there is no issue if strides.size() less than 2 + if (strides.size() > 1) { + for (size_t i = 0; i < strides.size(); i++) { + if (strides[i] > 1) return false; + } + } + } +#endif } if (op_type == "matmul") { -- GitLab