From a9f3719bdf2f60b0ae67bfa573c77a1f4b09b785 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <39978853+zhoutianzi666@users.noreply.github.com> Date: Wed, 3 Aug 2022 11:27:29 +0800 Subject: [PATCH] remove stack plugin (#44756) --- .../inference/tensorrt/convert/stack_op.cc | 48 +++++++++++-------- .../ir/inference/test_trt_convert_stack.py | 24 +++++----- 2 files changed, 40 insertions(+), 32 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/stack_op.cc b/paddle/fluid/inference/tensorrt/convert/stack_op.cc index e4d3003b534..c60d2578ec0 100644 --- a/paddle/fluid/inference/tensorrt/convert/stack_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/stack_op.cc @@ -41,11 +41,10 @@ class StackOpConverter : public OpConverter { framework::OpDesc op_desc(op, nullptr); auto input = op_desc.Input("X"); int input_num = input.size(); - nvinfer1::ITensor** inputs = - (nvinfer1::ITensor**)malloc(input_num * sizeof(nvinfer1::ITensor*)); + std::vector inputs; for (int i = 0; i < input_num; ++i) { - inputs[i] = engine_->GetITensor(input[i]); + inputs.push_back(engine_->GetITensor(input[i])); if (op_desc.HasAttr("out_threshold")) { float out_scale = PADDLE_GET_CONST(float, op_desc.GetAttr("out_threshold")); @@ -54,28 +53,37 @@ class StackOpConverter : public OpConverter { } int axis = PADDLE_GET_CONST(int, op_desc.GetAttr("axis")); + int output_rank = inputs[0]->getDimensions().nbDims + 1; if (axis < 0) { - axis = axis + inputs[0]->getDimensions().nbDims + 1; + axis = axis + output_rank; } + // Now, axis is relative to output_rank. + + auto* shape_tensor = Shape(inputs[0]); + std::vector shape_tensor_vec; + for (int i = 0; i < output_rank; i++) { + if (i < axis) { + shape_tensor_vec.push_back(GetEleTensorOfShape(shape_tensor, i)); + } else if (i > axis) { + shape_tensor_vec.push_back(GetEleTensorOfShape(shape_tensor, i - 1)); + } else { + shape_tensor_vec.push_back(Add1DConstantLayer(1)); + } + } + auto* after_shape_tensor = Concat(shape_tensor_vec); + + for (int i = 0; i < input_num; ++i) { + auto* reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *inputs[i]); + reshape_layer->setInput(1, *after_shape_tensor); + inputs[i] = reshape_layer->getOutput(0); + } + + auto* layer = TRT_ENGINE_ADD_LAYER( + engine_, Concatenation, inputs.data(), inputs.size()); + layer->setAxis(axis); - nvinfer1::ILayer* layer = nullptr; -#if IS_TRT_VERSION_GE(6000) - bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); - plugin::StackPluginDynamic* plugin = - new plugin::StackPluginDynamic(axis, input_num, with_fp16); - layer = engine_->AddDynamicPlugin(inputs, input_num, plugin); - PADDLE_ENFORCE_NOT_NULL( - layer, - platform::errors::InvalidArgument( - "trt stack layer in converter could not be created.")); -#else - PADDLE_THROW(platform::errors::Fatal( - "You are running the TRT Dynamic Shape mode, need to confirm that " - "your TRT version is no less than 6.0")); -#endif auto output_name = op_desc.Output("Y").front(); RreplenishLayerAndOutput(layer, "stack", {output_name}, test_mode); - free(inputs); } }; diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_stack.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_stack.py index f9641bad34c..cfae56fc2b6 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_stack.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_stack.py @@ -41,33 +41,33 @@ class TrtConvertStackTest(TrtLayerAutoScanTest): def generate_input1(attrs: List[Dict[str, Any]], batch): if self.dims == 4: - return np.ones([batch, 3, 24, 24]).astype(np.float32) + return np.random.random([batch, 3, 24, 24]).astype(np.float32) elif self.dims == 3: - return np.ones([batch, 3, 24]).astype(np.float32) + return np.random.random([batch, 3, 24]).astype(np.float32) elif self.dims == 2: - return np.ones([batch, 24]).astype(np.float32) + return np.random.random([batch, 24]).astype(np.float32) elif self.dims == 1: - return np.ones([24]).astype(np.float32) + return np.random.random([24]).astype(np.float32) def generate_input2(attrs: List[Dict[str, Any]], batch): if self.dims == 4: - return np.ones([batch, 3, 24, 24]).astype(np.float32) + return np.random.random([batch, 3, 24, 24]).astype(np.float32) elif self.dims == 3: - return np.ones([batch, 3, 24]).astype(np.float32) + return np.random.random([batch, 3, 24]).astype(np.float32) elif self.dims == 2: - return np.ones([batch, 24]).astype(np.float32) + return np.random.random([batch, 24]).astype(np.float32) elif self.dims == 1: - return np.ones([24]).astype(np.float32) + return np.random.random([24]).astype(np.float32) def generate_input3(attrs: List[Dict[str, Any]], batch): if self.dims == 4: - return np.ones([batch, 3, 24, 24]).astype(np.float32) + return np.random.random([batch, 3, 24, 24]).astype(np.float32) elif self.dims == 3: - return np.ones([batch, 3, 24]).astype(np.float32) + return np.random.random([batch, 3, 24]).astype(np.float32) elif self.dims == 2: - return np.ones([batch, 24]).astype(np.float32) + return np.random.random([batch, 24]).astype(np.float32) elif self.dims == 1: - return np.ones([24]).astype(np.float32) + return np.random.random([24]).astype(np.float32) for dims in [1, 2, 3, 4]: for batch in [1, 4]: -- GitLab