未验证 提交 a5875319 编写于 作者: W Wangzheee 提交者: GitHub

fix_slice_convert_varlen (#46874)

上级 4596b9a2
......@@ -11,6 +11,7 @@ limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/slice_op_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.h"
namespace paddle {
namespace inference {
......@@ -73,80 +74,116 @@ class SliceOpConverter : public OpConverter {
nvinfer1::ILayer* layer = nullptr;
if (engine_->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
auto nchw_input_dims = input->getDimensions();
nvinfer1::Dims trt_start_dims;
trt_start_dims.nbDims = nchw_input_dims.nbDims;
memset(trt_start_dims.d, 0, sizeof(int32_t) * nchw_input_dims.nbDims);
nvinfer1::Dims trt_size_dims = trt_start_dims;
nvinfer1::Dims trt_end_dims = trt_start_dims;
nvinfer1::Dims trt_step_dims = trt_start_dims;
for (int i = 0; i < trt_step_dims.nbDims; i++) trt_step_dims.d[i] = 1;
// input : [N,C,H,W]
bool has_neg_indices = false;
for (size_t i = 0; i < axes.size(); i++) {
int trt_axis = axes[i];
trt_start_dims.d[trt_axis] = starts[i];
trt_end_dims.d[trt_axis] = ends[i];
if (starts[i] < 0 || ends[i] < 0) has_neg_indices = true;
}
auto* shape_tensor = Shape(input);
auto* start_tensor = Add1DConstantLayer(trt_start_dims);
if (has_neg_indices) {
start_tensor = FixNegIndices(shape_tensor, start_tensor);
}
if (engine_->use_oss() && engine_->with_ernie() &&
input_dims.nbDims == 4) {
std::vector<nvinfer1::ITensor*> plugin_inputs;
if (engine_->with_interleaved()) {
auto* shuffler_slice = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
nvinfer1::Permutation transpose_embed{2, 1, 0, 3};
shuffler_slice->setSecondTranspose(transpose_embed);
engine_->SetTensorDynamicRange(shuffler_slice->getOutput(0),
out_scale);
shuffler_slice->setName(
("SpecialSlice_interleaved: transpose: (Output: " + output_name +
")")
.c_str());
plugin_inputs.emplace_back(shuffler_slice->getOutput(0));
} else {
plugin_inputs.emplace_back(input);
}
std::string pos_name;
if (engine_->Has("ernie_pos_name")) {
pos_name = engine_->Get<std::string>("ernie_pos_name");
} else {
// hard-coded for compatibility
pos_name = engine_->network()->getInput(2)->getName();
}
plugin_inputs.emplace_back(
engine_->GetITensor(pos_name)); // cu_seqlens, eval_placeholder_2
// bool ban_fp16 = engine_->disable_trt_plugin_fp16();
plugin::SpecialSlicePluginDynamic* plugin =
new plugin::SpecialSlicePluginDynamic();
layer = engine_->AddDynamicPlugin(
plugin_inputs.data(), plugin_inputs.size(), plugin);
} else {
auto nchw_input_dims = input->getDimensions();
nvinfer1::Dims trt_start_dims;
trt_start_dims.nbDims = nchw_input_dims.nbDims;
memset(trt_start_dims.d, 0, sizeof(int32_t) * nchw_input_dims.nbDims);
nvinfer1::Dims trt_size_dims = trt_start_dims;
nvinfer1::Dims trt_end_dims = trt_start_dims;
nvinfer1::Dims trt_step_dims = trt_start_dims;
for (int i = 0; i < trt_step_dims.nbDims; i++) trt_step_dims.d[i] = 1;
// input : [N,C,H,W]
bool has_neg_indices = false;
for (size_t i = 0; i < axes.size(); i++) {
int trt_axis = axes[i];
trt_start_dims.d[trt_axis] = starts[i];
trt_end_dims.d[trt_axis] = ends[i];
if (starts[i] < 0 || ends[i] < 0) has_neg_indices = true;
}
auto* shape_tensor = Shape(input);
auto* start_tensor = Add1DConstantLayer(trt_start_dims);
if (has_neg_indices) {
start_tensor = FixNegIndices(shape_tensor, start_tensor);
}
std::vector<nvinfer1::ITensor*> end_vec_tensor;
for (int i = 0; i < trt_end_dims.nbDims; i++) {
end_vec_tensor.push_back(GetEleTensorOfShape(shape_tensor, i));
}
std::vector<nvinfer1::ITensor*> end_vec_tensor;
for (int i = 0; i < trt_end_dims.nbDims; i++) {
end_vec_tensor.push_back(GetEleTensorOfShape(shape_tensor, i));
}
for (size_t i = 0; i < axes.size(); i++) {
int trt_axis = axes[i];
if (ends[i] >= 0) {
end_vec_tensor[trt_axis] = Add1DConstantLayer(ends[i]);
} else {
end_vec_tensor[trt_axis] =
Sum(end_vec_tensor[trt_axis], Add1DConstantLayer(ends[i]));
for (size_t i = 0; i < axes.size(); i++) {
int trt_axis = axes[i];
if (ends[i] >= 0) {
end_vec_tensor[trt_axis] = Add1DConstantLayer(ends[i]);
} else {
end_vec_tensor[trt_axis] =
Sum(end_vec_tensor[trt_axis], Add1DConstantLayer(ends[i]));
}
}
}
// CI failed on TRT 6015 but succeeded on 7134; this may be a TRT bug
#if IS_TRT_VERSION_GE(7134)
auto* size_tensor =
Sub(Min(Concat(end_vec_tensor), shape_tensor), start_tensor);
auto* size_tensor =
Sub(Min(Concat(end_vec_tensor), shape_tensor), start_tensor);
#else
auto* size_tensor = Sub(Concat(end_vec_tensor), start_tensor);
auto* size_tensor = Sub(Concat(end_vec_tensor), start_tensor);
#endif
layer = TRT_ENGINE_ADD_LAYER(
engine_, Slice, *input, trt_start_dims, trt_size_dims, trt_step_dims);
layer->setInput(1, *start_tensor);
layer->setInput(2, *size_tensor);
if (decrease_axises.size() > 0) {
std::vector<int32_t> gather_indices;
for (int i = 0; i < trt_size_dims.nbDims; i++) {
if (decrease_axises.end() !=
std::find(decrease_axises.begin(), decrease_axises.end(), i))
continue;
gather_indices.push_back(i);
layer = TRT_ENGINE_ADD_LAYER(engine_,
Slice,
*input,
trt_start_dims,
trt_size_dims,
trt_step_dims);
layer->setInput(1, *start_tensor);
layer->setInput(2, *size_tensor);
if (decrease_axises.size() > 0) {
std::vector<int32_t> gather_indices;
for (int i = 0; i < trt_size_dims.nbDims; i++) {
if (decrease_axises.end() !=
std::find(decrease_axises.begin(), decrease_axises.end(), i))
continue;
gather_indices.push_back(i);
}
if (gather_indices.empty())
gather_indices.push_back(decrease_axises[0]);
auto real_size_tensor = Gather(size_tensor, gather_indices);
layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *layer->getOutput(0));
layer->setInput(1, *real_size_tensor);
}
if (gather_indices.empty())
gather_indices.push_back(decrease_axises[0]);
auto real_size_tensor = Gather(size_tensor, gather_indices);
layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *layer->getOutput(0));
layer->setInput(1, *real_size_tensor);
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
int decrease_axis =
decrease_axises.size() == 0 ? -1 : decrease_axises[0];
plugin::SlicePluginDynamic* plugin = new plugin::SlicePluginDynamic(
starts, ends, axes, decrease_axis, with_fp16);
layer = engine_->AddDynamicPlugin(&input, 1, plugin);
}
#else
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
int decrease_axis = decrease_axises.size() == 0 ? -1 : decrease_axises[0];
plugin::SlicePluginDynamic* plugin = new plugin::SlicePluginDynamic(
starts, ends, axes, decrease_axis, with_fp16);
layer = engine_->AddDynamicPlugin(&input, 1, plugin);
#endif
} else {
#if IS_TRT_VERSION_GE(6000)
auto chw_input_dims = input->getDimensions();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册