未验证 提交 a5875319 编写于 作者: W Wangzheee 提交者: GitHub

fix_slice_convert_varlen (#46874)

上级 4596b9a2
...@@ -11,6 +11,7 @@ limitations under the License. */ ...@@ -11,6 +11,7 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h" #include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h"
namespace paddle { namespace paddle {
namespace inference { namespace inference {
...@@ -73,7 +74,39 @@ class SliceOpConverter : public OpConverter { ...@@ -73,7 +74,39 @@ class SliceOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr; nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) { if (engine_->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000) if (engine_->use_oss() && engine_->with_ernie() &&
input_dims.nbDims == 4) {
std::vector<nvinfer1::ITensor*> plugin_inputs;
if (engine_->with_interleaved()) {
auto* shuffler_slice = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
nvinfer1::Permutation transpose_embed{2, 1, 0, 3};
shuffler_slice->setSecondTranspose(transpose_embed);
engine_->SetTensorDynamicRange(shuffler_slice->getOutput(0),
out_scale);
shuffler_slice->setName(
("SpecialSlice_interleaved: transpose: (Output: " + output_name +
")")
.c_str());
plugin_inputs.emplace_back(shuffler_slice->getOutput(0));
} else {
plugin_inputs.emplace_back(input);
}
std::string pos_name;
if (engine_->Has("ernie_pos_name")) {
pos_name = engine_->Get<std::string>("ernie_pos_name");
} else {
// hard code for compatibility
pos_name = engine_->network()->getInput(2)->getName();
}
plugin_inputs.emplace_back(
engine_->GetITensor(pos_name)); // cu_seqlens, eval_placeholder_2
// bool ban_fp16 = engine_->disable_trt_plugin_fp16();
plugin::SpecialSlicePluginDynamic* plugin =
new plugin::SpecialSlicePluginDynamic();
layer = engine_->AddDynamicPlugin(
plugin_inputs.data(), plugin_inputs.size(), plugin);
} else {
auto nchw_input_dims = input->getDimensions(); auto nchw_input_dims = input->getDimensions();
nvinfer1::Dims trt_start_dims; nvinfer1::Dims trt_start_dims;
trt_start_dims.nbDims = nchw_input_dims.nbDims; trt_start_dims.nbDims = nchw_input_dims.nbDims;
...@@ -120,8 +153,12 @@ class SliceOpConverter : public OpConverter { ...@@ -120,8 +153,12 @@ class SliceOpConverter : public OpConverter {
auto* size_tensor = Sub(Concat(end_vec_tensor), start_tensor); auto* size_tensor = Sub(Concat(end_vec_tensor), start_tensor);
#endif #endif
layer = TRT_ENGINE_ADD_LAYER( layer = TRT_ENGINE_ADD_LAYER(engine_,
engine_, Slice, *input, trt_start_dims, trt_size_dims, trt_step_dims); Slice,
*input,
trt_start_dims,
trt_size_dims,
trt_step_dims);
layer->setInput(1, *start_tensor); layer->setInput(1, *start_tensor);
layer->setInput(2, *size_tensor); layer->setInput(2, *size_tensor);
...@@ -139,14 +176,14 @@ class SliceOpConverter : public OpConverter { ...@@ -139,14 +176,14 @@ class SliceOpConverter : public OpConverter {
layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *layer->getOutput(0)); layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *layer->getOutput(0));
layer->setInput(1, *real_size_tensor); layer->setInput(1, *real_size_tensor);
} }
#else
bool with_fp16 = bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
int decrease_axis = decrease_axises.size() == 0 ? -1 : decrease_axises[0]; int decrease_axis =
decrease_axises.size() == 0 ? -1 : decrease_axises[0];
plugin::SlicePluginDynamic* plugin = new plugin::SlicePluginDynamic( plugin::SlicePluginDynamic* plugin = new plugin::SlicePluginDynamic(
starts, ends, axes, decrease_axis, with_fp16); starts, ends, axes, decrease_axis, with_fp16);
layer = engine_->AddDynamicPlugin(&input, 1, plugin); layer = engine_->AddDynamicPlugin(&input, 1, plugin);
#endif }
} else { } else {
#if IS_TRT_VERSION_GE(6000) #if IS_TRT_VERSION_GE(6000)
auto chw_input_dims = input->getDimensions(); auto chw_input_dims = input->getDimensions();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册