未验证 提交 f403fb69 编写于 作者: F feng_shuai 提交者: GitHub

add trt supoort for slice op (#41467)

* add trt supoort for slice op

* fix:output dims bug

* fix: test

* fix:for c++ coverage

* fix:c++ coverage

* fix: fix test bug

* fix: CI test
上级 137dc3e3
......@@ -44,6 +44,8 @@ class SliceOpConverter : public OpConverter {
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("starts"));
std::vector<int> ends =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("ends"));
std::vector<int> decrease_axises =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("decrease_axis"));
auto input_dims = input->getDimensions();
if (!engine_->with_dynamic_shape()) {
......@@ -107,8 +109,10 @@ class SliceOpConverter : public OpConverter {
} else {
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
plugin::SlicePluginDynamic* plugin =
new plugin::SlicePluginDynamic(starts, ends, axes, with_fp16);
int decrease_axis =
decrease_axises.size() == 0 ? -1 : decrease_axises[0];
plugin::SlicePluginDynamic* plugin = new plugin::SlicePluginDynamic(
starts, ends, axes, decrease_axis, with_fp16);
layer = engine_->AddDynamicPlugin(&input, 1, plugin);
}
} else {
......
......@@ -930,12 +930,18 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if (desc.HasAttr("decrease_axis")) {
std::vector<int> decrease_axis =
BOOST_GET_CONST(std::vector<int>, desc.GetAttr("decrease_axis"));
if (with_dynamic_shape) {
if (decrease_axis.size() > 1) {
return false;
}
} else {
if (decrease_axis.size() > 0) {
VLOG(3) << "Invalid slice decrease_axis. decrease_axis.size() > 0"
"is not supported in TensorRT";
return false;
}
}
}
if (!desc.HasAttr("axes") || !desc.HasAttr("starts") ||
!desc.HasAttr("ends")) {
......@@ -1054,16 +1060,14 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return false;
}
if (desc.Input("Ids").size() != desc.Input("Embs").size()) {
VLOG(3) << "The id and emb size of fused EmbEltwiseLayerNormOp "
"should be same ";
return false;
}
}
if (op_type == "fused_preln_embedding_eltwise_layernorm") {
if (!with_dynamic_shape) {
VLOG(3)
<< "fused_preln_embedding_eltwise_layernorm should run on dynamic "
VLOG(3) << "fused_preln_embedding_eltwise_layernorm should run on "
"dynamic "
"shape mode.";
return false;
}
......@@ -1454,7 +1458,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
const auto y_shape = y_var_desc->GetShape();
if (y_shape.size() != 2) {
VLOG(3)
<< " input_y(fc_op)'shapes must be 2, but input_y(fc_op)'shapes = "
<< " input_y(fc_op)'shapes must be 2, but input_y(fc_op)'shapes =
"
<< y_shape.size();
return false;
}
......@@ -1598,8 +1603,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
#else
if (dtype != framework::proto::VarType::FP32) {
VLOG(3)
<< "reduce op input data type must be float32 using TensorRT < 7.0";
VLOG(3) << "reduce op input data type must be float32 using TensorRT "
"< 7.0";
return false;
}
#endif
......
......@@ -205,8 +205,9 @@ void SlicePlugin::serialize(void *buffer) const TRT_NOEXCEPT {
#if IS_TRT_VERSION_GE(6000)
SlicePluginDynamic::SlicePluginDynamic(std::vector<int> starts,
std::vector<int> ends,
std::vector<int> axes, bool with_fp16)
: starts_(starts), ends_(ends), axes_(axes) {
std::vector<int> axes, int decrease_axis,
bool with_fp16)
: starts_(starts), ends_(ends), axes_(axes), decrease_axis_(decrease_axis) {
with_fp16_ = with_fp16;
cudaEventCreate(&copy_event_);
cudaStreamCreate(&copy_stream_);
......@@ -217,6 +218,7 @@ SlicePluginDynamic::SlicePluginDynamic(void const *serialData,
DeserializeValue(&serialData, &serialLength, &starts_);
DeserializeValue(&serialData, &serialLength, &ends_);
DeserializeValue(&serialData, &serialLength, &axes_);
DeserializeValue(&serialData, &serialLength, &decrease_axis_);
DeserializeValue(&serialData, &serialLength, &with_fp16_);
cudaEventCreate(&copy_event_);
cudaStreamCreate(&copy_stream_);
......@@ -233,7 +235,8 @@ int SlicePluginDynamic::initialize() TRT_NOEXCEPT { return 0; }
size_t SlicePluginDynamic::getSerializationSize() const TRT_NOEXCEPT {
size_t size = SerializedSize(starts_) + SerializedSize(ends_) +
SerializedSize(axes_) + SerializedSize(with_fp16_);
SerializedSize(axes_) + SerializedSize(decrease_axis_) +
SerializedSize(with_fp16_);
return size;
}
......@@ -242,6 +245,7 @@ void SlicePluginDynamic::serialize(void *buffer) const TRT_NOEXCEPT {
SerializeValue(&buffer, starts_);
SerializeValue(&buffer, ends_);
SerializeValue(&buffer, axes_);
SerializeValue(&buffer, decrease_axis_);
SerializeValue(&buffer, with_fp16_);
}
......@@ -265,6 +269,17 @@ nvinfer1::DimsExprs SlicePluginDynamic::getOutputDimensions(
ret.d[axes_[i]] = expr_builder.constant(end - start);
#endif
}
if (decrease_axis_ != -1) {
nvinfer1::DimsExprs res;
res.nbDims = ret.nbDims - 1;
int j = 0;
for (size_t i = 0; i < in_dims.nbDims; i++) {
if (decrease_axis_ == i) continue;
res.d[j++] = expr_builder.operation(nvinfer1::DimensionOperation::kMAX,
*expr_builder.constant(0), *ret.d[i]);
}
return res;
}
return ret;
}
......@@ -318,6 +333,10 @@ int SlicePluginDynamic::enqueue(const nvinfer1::PluginTensorDesc *input_desc,
cudaStream_t stream) TRT_NOEXCEPT {
auto input_dims = input_desc[0].dims;
auto out_dims = output_desc[0].dims;
if (decrease_axis_ != -1) {
out_dims = input_dims;
out_dims.d[decrease_axis_] = 1;
}
auto num_dims = input_dims.nbDims;
size_t out_num = ProductDim(out_dims);
......
......@@ -88,10 +88,12 @@ REGISTER_TRT_PLUGIN_V2(SlicePluginCreator);
class SlicePluginDynamic : public DynamicPluginTensorRT {
public:
explicit SlicePluginDynamic(std::vector<int> starts, std::vector<int> ends,
std::vector<int> axes, bool with_fp16);
std::vector<int> axes, int decrease_axis,
bool with_fp16);
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
return new SlicePluginDynamic(starts_, ends_, axes_, with_fp16_);
return new SlicePluginDynamic(starts_, ends_, axes_, decrease_axis_,
with_fp16_);
}
SlicePluginDynamic(void const* serialData, size_t serialLength);
......@@ -140,6 +142,7 @@ class SlicePluginDynamic : public DynamicPluginTensorRT {
std::vector<int> starts_;
std::vector<int> ends_;
std::vector<int> axes_;
int decrease_axis_;
int* offset_temp_data_{nullptr};
cudaEvent_t copy_event_;
cudaStream_t copy_stream_;
......
......@@ -55,11 +55,11 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
def sample_program_configs(self):
def generate_input1(attrs: List[Dict[str, Any]]):
return np.ones([1, 3, 64, 64]).astype(np.float32)
return np.ones([6, 6, 64, 64]).astype(np.float32)
for axes in [[0, 1], [1, 3], [2, 3]]:
for starts in [[0, 1], [-4, -3]]:
for ends in [[2, 2], [-1, -2], [5, 5]]:
for starts in [[0, 1]]:
for ends in [[2, 2], [5, 5]]:
for decrease_axis in [[], [1], [2], [-1], [-100]]:
for infer_flags in [[-1]]:
dics = [{
......@@ -97,8 +97,8 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
self, program_config) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 32, 32]}
self.dynamic_shape.max_input_shape = {"input_data": [4, 3, 64, 64]}
self.dynamic_shape.opt_input_shape = {"input_data": [1, 3, 64, 64]}
self.dynamic_shape.max_input_shape = {"input_data": [8, 8, 64, 64]}
self.dynamic_shape.opt_input_shape = {"input_data": [6, 6, 64, 64]}
def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
......@@ -107,7 +107,11 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
def generate_trt_nodes_num(attrs, dynamic_shape):
inputs = program_config.inputs
if len(attrs[0]["decrease_axis"]) != 0:
if dynamic_shape == True and len(attrs[0]["decrease_axis"]) == 0:
return 1, 2
if dynamic_shape == True and len(attrs[0]["decrease_axis"]) != 1:
return 0, 3
if dynamic_shape == False and len(attrs[0]["decrease_axis"]) != 0:
return 0, 3
if dynamic_shape:
for i in range(len(attrs[0]["starts"])):
......@@ -123,7 +127,7 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
program_config.ops[i].attrs
for i in range(len(program_config.ops))
]
self.trt_param.max_batch_size = 9
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
......@@ -146,7 +150,7 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest):
# TODO(inference): fix.
# trt6 and trt7.1 has bug.
# trt7.2 deserialize has bug.
# self.run_test()
self.run_test()
pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册