未验证 提交 7ec1e9af 编写于 作者: F feng_shuai 提交者: GitHub

add trt supoort for slice op (#41467) (#41911)

上级 15d30815
...@@ -44,6 +44,8 @@ class SliceOpConverter : public OpConverter { ...@@ -44,6 +44,8 @@ class SliceOpConverter : public OpConverter {
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("starts")); BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("starts"));
std::vector<int> ends = std::vector<int> ends =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("ends")); BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("ends"));
std::vector<int> decrease_axises =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("decrease_axis"));
auto input_dims = input->getDimensions(); auto input_dims = input->getDimensions();
if (!engine_->with_dynamic_shape()) { if (!engine_->with_dynamic_shape()) {
...@@ -107,8 +109,10 @@ class SliceOpConverter : public OpConverter { ...@@ -107,8 +109,10 @@ class SliceOpConverter : public OpConverter {
} else { } else {
bool with_fp16 = bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
plugin::SlicePluginDynamic* plugin = int decrease_axis =
new plugin::SlicePluginDynamic(starts, ends, axes, with_fp16); decrease_axises.size() == 0 ? -1 : decrease_axises[0];
plugin::SlicePluginDynamic* plugin = new plugin::SlicePluginDynamic(
starts, ends, axes, decrease_axis, with_fp16);
layer = engine_->AddDynamicPlugin(&input, 1, plugin); layer = engine_->AddDynamicPlugin(&input, 1, plugin);
} }
} else { } else {
......
...@@ -930,10 +930,16 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -930,10 +930,16 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if (desc.HasAttr("decrease_axis")) { if (desc.HasAttr("decrease_axis")) {
std::vector<int> decrease_axis = std::vector<int> decrease_axis =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("decrease_axis")); BOOST_GET_CONST(std::vector<int>, desc.GetAttr("decrease_axis"));
if (decrease_axis.size() > 0) { if (with_dynamic_shape) {
VLOG(3) << "Invalid slice decrease_axis. decrease_axis.size() > 0" if (decrease_axis.size() > 1) {
"is not supported in TensorRT"; return false;
return false; }
} else {
if (decrease_axis.size() > 0) {
VLOG(3) << "Invalid slice decrease_axis. decrease_axis.size() > 0"
"is not supported in TensorRT";
return false;
}
} }
} }
...@@ -1054,17 +1060,15 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -1054,17 +1060,15 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return false; return false;
} }
if (desc.Input("Ids").size() != desc.Input("Embs").size()) { if (desc.Input("Ids").size() != desc.Input("Embs").size()) {
VLOG(3) << "The id and emb size of fused EmbEltwiseLayerNormOp "
"should be same ";
return false; return false;
} }
} }
if (op_type == "fused_preln_embedding_eltwise_layernorm") { if (op_type == "fused_preln_embedding_eltwise_layernorm") {
if (!with_dynamic_shape) { if (!with_dynamic_shape) {
VLOG(3) VLOG(3) << "fused_preln_embedding_eltwise_layernorm should run on "
<< "fused_preln_embedding_eltwise_layernorm should run on dynamic " "dynamic "
"shape mode."; "shape mode.";
return false; return false;
} }
if (desc.Input("Ids").size() != desc.Input("Embs").size()) { if (desc.Input("Ids").size() != desc.Input("Embs").size()) {
...@@ -1454,7 +1458,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -1454,7 +1458,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
const auto y_shape = y_var_desc->GetShape(); const auto y_shape = y_var_desc->GetShape();
if (y_shape.size() != 2) { if (y_shape.size() != 2) {
VLOG(3) VLOG(3)
<< " input_y(fc_op)'shapes must be 2, but input_y(fc_op)'shapes = " << " input_y(fc_op)'shapes must be 2, but input_y(fc_op)'shapes =
"
<< y_shape.size(); << y_shape.size();
return false; return false;
} }
...@@ -1598,8 +1603,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -1598,8 +1603,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
} }
#else #else
if (dtype != framework::proto::VarType::FP32) { if (dtype != framework::proto::VarType::FP32) {
VLOG(3) VLOG(3) << "reduce op input data type must be float32 using TensorRT "
<< "reduce op input data type must be float32 using TensorRT < 7.0"; "< 7.0";
return false; return false;
} }
#endif #endif
......
...@@ -205,8 +205,9 @@ void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT { ...@@ -205,8 +205,9 @@ void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT {
#if IS_TRT_VERSION_GE(6000) #if IS_TRT_VERSION_GE(6000)
SlicePluginDynamic::SlicePluginDynamic(std::vector<int> starts, SlicePluginDynamic::SlicePluginDynamic(std::vector<int> starts,
std::vector<int> ends, std::vector<int> ends,
std::vector<int> axes, bool with_fp16) std::vector<int> axes, int decrease_axis,
: starts_(starts), ends_(ends), axes_(axes) { bool with_fp16)
: starts_(starts), ends_(ends), axes_(axes), decrease_axis_(decrease_axis) {
with_fp16_ = with_fp16; with_fp16_ = with_fp16;
cudaEventCreate(&copy_event_); cudaEventCreate(&copy_event_);
cudaStreamCreate(&copy_stream_); cudaStreamCreate(&copy_stream_);
...@@ -217,6 +218,7 @@ SlicePluginDynamic::SlicePluginDynamic(void const *serialData, ...@@ -217,6 +218,7 @@ SlicePluginDynamic::SlicePluginDynamic(void const *serialData,
DeserializeValue(&serialData, &serialLength, &starts_); DeserializeValue(&serialData, &serialLength, &starts_);
DeserializeValue(&serialData, &serialLength, &ends_); DeserializeValue(&serialData, &serialLength, &ends_);
DeserializeValue(&serialData, &serialLength, &axes_); DeserializeValue(&serialData, &serialLength, &axes_);
DeserializeValue(&serialData, &serialLength, &decrease_axis_);
DeserializeValue(&serialData, &serialLength, &with_fp16_); DeserializeValue(&serialData, &serialLength, &with_fp16_);
cudaEventCreate(&copy_event_); cudaEventCreate(&copy_event_);
cudaStreamCreate(&copy_stream_); cudaStreamCreate(&copy_stream_);
...@@ -233,7 +235,8 @@ int SlicePluginDynamic::initialize() TRT_NOEXCEPT { return 0; } ...@@ -233,7 +235,8 @@ int SlicePluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
size_t SlicePluginDynamic::getSerializationSize() const TRT_NOEXCEPT { size_t SlicePluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
size_t size = SerializedSize(starts_) + SerializedSize(ends_) + size_t size = SerializedSize(starts_) + SerializedSize(ends_) +
SerializedSize(axes_) + SerializedSize(with_fp16_); SerializedSize(axes_) + SerializedSize(decrease_axis_) +
SerializedSize(with_fp16_);
return size; return size;
} }
...@@ -242,6 +245,7 @@ void SlicePluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT { ...@@ -242,6 +245,7 @@ void SlicePluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {
SerializeValue(&buffer, starts_); SerializeValue(&buffer, starts_);
SerializeValue(&buffer, ends_); SerializeValue(&buffer, ends_);
SerializeValue(&buffer, axes_); SerializeValue(&buffer, axes_);
SerializeValue(&buffer, decrease_axis_);
SerializeValue(&buffer, with_fp16_); SerializeValue(&buffer, with_fp16_);
} }
...@@ -265,6 +269,17 @@ nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions( ...@@ -265,6 +269,17 @@ nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions(
ret.d[axes_[i]] = expr_builder.constant(end - start); ret.d[axes_[i]] = expr_builder.constant(end - start);
#endif #endif
} }
if (decrease_axis_ != -1) {
nvinfer1::DimsExprs res;
res.nbDims = ret.nbDims - 1;
int j = 0;
for (size_t i = 0; i < in_dims.nbDims; i++) {
if (decrease_axis_ == i) continue;
res.d[j++] = expr_builder.operation(nvinfer1::DimensionOperation::kMAX,
*expr_builder.constant(0), *ret.d[i]);
}
return res;
}
return ret; return ret;
} }
...@@ -318,6 +333,10 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc, ...@@ -318,6 +333,10 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
cudaStream_t stream) TRT_NOEXCEPT { cudaStream_t stream) TRT_NOEXCEPT {
auto input_dims = input_desc[0].dims; auto input_dims = input_desc[0].dims;
auto out_dims = output_desc[0].dims; auto out_dims = output_desc[0].dims;
if (decrease_axis_ != -1) {
out_dims = input_dims;
out_dims.d[decrease_axis_] = 1;
}
auto num_dims = input_dims.nbDims; auto num_dims = input_dims.nbDims;
size_t out_num = ProductDim(out_dims); size_t out_num = ProductDim(out_dims);
......
...@@ -88,10 +88,12 @@ REGISTER_TRT_PLUGIN_V2(SlicePluginCreator); ...@@ -88,10 +88,12 @@ REGISTER_TRT_PLUGIN_V2(SlicePluginCreator);
class SlicePluginDynamic : public DynamicPluginTensorRT { class SlicePluginDynamic : public DynamicPluginTensorRT {
public: public:
explicit SlicePluginDynamic(std::vector<int> starts, std::vector<int> ends, explicit SlicePluginDynamic(std::vector<int> starts, std::vector<int> ends,
std::vector<int> axes, bool with_fp16); std::vector<int> axes, int decrease_axis,
bool with_fp16);
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override { nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
return new SlicePluginDynamic(starts_, ends_, axes_, with_fp16_); return new SlicePluginDynamic(starts_, ends_, axes_, decrease_axis_,
with_fp16_);
} }
SlicePluginDynamic(void const* serialData, size_t serialLength); SlicePluginDynamic(void const* serialData, size_t serialLength);
...@@ -140,6 +142,7 @@ class SlicePluginDynamic : public DynamicPluginTensorRT { ...@@ -140,6 +142,7 @@ class SlicePluginDynamic : public DynamicPluginTensorRT {
std::vector<int> starts_; std::vector<int> starts_;
std::vector<int> ends_; std::vector<int> ends_;
std::vector<int> axes_; std::vector<int> axes_;
int decrease_axis_;
int* offset_temp_data_{nullptr}; int* offset_temp_data_{nullptr};
cudaEvent_t copy_event_; cudaEvent_t copy_event_;
cudaStream_t copy_stream_; cudaStream_t copy_stream_;
......
...@@ -55,11 +55,11 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest): ...@@ -55,11 +55,11 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
def sample_program_configs(self): def sample_program_configs(self):
def generate_input1(attrs: List[Dict[str, Any]]): def generate_input1(attrs: List[Dict[str, Any]]):
return np.ones([1, 3, 64, 64]).astype(np.float32) return np.ones([6, 6, 64, 64]).astype(np.float32)
for axes in [[0, 1], [1, 3], [2, 3]]: for axes in [[0, 1], [1, 3], [2, 3]]:
for starts in [[0, 1], [-4, -3]]: for starts in [[0, 1]]:
for ends in [[2, 2], [-1, -2], [5, 5]]: for ends in [[2, 2], [5, 5]]:
for decrease_axis in [[], [1], [2], [-1], [-100]]: for decrease_axis in [[], [1], [2], [-1], [-100]]:
for infer_flags in [[-1]]: for infer_flags in [[-1]]:
dics = [{ dics = [{
...@@ -97,8 +97,8 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest): ...@@ -97,8 +97,8 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
self, program_config) -> (paddle_infer.Config, List[int], float): self, program_config) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs): def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 32, 32]} self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 32, 32]}
self.dynamic_shape.max_input_shape = {"input_data": [4, 3, 64, 64]} self.dynamic_shape.max_input_shape = {"input_data": [8, 8, 64, 64]}
self.dynamic_shape.opt_input_shape = {"input_data": [1, 3, 64, 64]} self.dynamic_shape.opt_input_shape = {"input_data": [6, 6, 64, 64]}
def clear_dynamic_shape(): def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {} self.dynamic_shape.min_input_shape = {}
...@@ -107,7 +107,11 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest): ...@@ -107,7 +107,11 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
def generate_trt_nodes_num(attrs, dynamic_shape): def generate_trt_nodes_num(attrs, dynamic_shape):
inputs = program_config.inputs inputs = program_config.inputs
if len(attrs[0]["decrease_axis"]) != 0: if dynamic_shape == True and len(attrs[0]["decrease_axis"]) == 0:
return 1, 2
if dynamic_shape == True and len(attrs[0]["decrease_axis"]) != 1:
return 0, 3
if dynamic_shape == False and len(attrs[0]["decrease_axis"]) != 0:
return 0, 3 return 0, 3
if dynamic_shape: if dynamic_shape:
for i in range(len(attrs[0]["starts"])): for i in range(len(attrs[0]["starts"])):
...@@ -123,7 +127,7 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest): ...@@ -123,7 +127,7 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
program_config.ops[i].attrs program_config.ops[i].attrs
for i in range(len(program_config.ops)) for i in range(len(program_config.ops))
] ]
self.trt_param.max_batch_size = 9
# for static_shape # for static_shape
clear_dynamic_shape() clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32 self.trt_param.precision = paddle_infer.PrecisionType.Float32
...@@ -146,7 +150,7 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest): ...@@ -146,7 +150,7 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
# TODO(inference): fix. # TODO(inference): fix.
# trt6 and trt7.1 has bug. # trt6 and trt7.1 has bug.
# trt7.2 deserialize has bug. # trt7.2 deserialize has bug.
# self.run_test() self.run_test()
pass pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册