Unverified commit 2f3ad5ab, authored by: F feng_shuai, committed by: GitHub

optimize: vit static shape (#47280)

Parent commit: 84273aaa
...@@ -460,9 +460,10 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -460,9 +460,10 @@ class MultiheadMatMulOpConverter : public OpConverter {
plugin_inputs.emplace_back(mask_tensor); plugin_inputs.emplace_back(mask_tensor);
// input_2 for plugin // input_2 for plugin
std::vector<int> pos_id = {0}; std::vector<int> pos_id = {0};
int max_batch = 500; int max_batch = 512;
int length = (input_dims.d[1] == -1) ? 1 : input_dims.d[1];
for (int i = 1; i < max_batch; i++) { for (int i = 1; i < max_batch; i++) {
pos_id.push_back(i); pos_id.push_back(i * length);
} }
nvinfer1::ITensor* fake_pos_id_tensor = Add1DConstantLayer(pos_id); nvinfer1::ITensor* fake_pos_id_tensor = Add1DConstantLayer(pos_id);
nvinfer1::ITensor* length_tensor = nvinfer1::ITensor* length_tensor =
...@@ -497,18 +498,26 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -497,18 +498,26 @@ class MultiheadMatMulOpConverter : public OpConverter {
stride.d[0] = 1; stride.d[0] = 1;
size.d[0] = 1; size.d[0] = 1;
nvinfer1::ITensor* pos_id_tensor = (input_dims.d[1] == -1)
? pos_id_layer->getOutput(0)
: fake_pos_id_tensor;
auto* slice_pos_layer = TRT_ENGINE_ADD_LAYER( auto* slice_pos_layer = TRT_ENGINE_ADD_LAYER(
engine_, Slice, *pos_id_layer->getOutput(0), start, size, stride); engine_, Slice, *pos_id_tensor, start, size, stride);
slice_pos_layer->setInput(2, *size_layer->getOutput(0)); slice_pos_layer->setInput(2, *size_layer->getOutput(0));
plugin_inputs.emplace_back(slice_pos_layer->getOutput(0)); plugin_inputs.emplace_back(slice_pos_layer->getOutput(0));
// input_3 for plugin // input_3 for plugin
std::vector<int> data(500, 1); int max_length = (input_dims.d[1] == -1) ? 512 : input_dims.d[1];
std::vector<int> data(max_length, 1);
nvinfer1::ITensor* fake_max_seqlen_tensor = Add1DConstantLayer(data); nvinfer1::ITensor* fake_max_seqlen_tensor = Add1DConstantLayer(data);
auto* slice_max_layer = TRT_ENGINE_ADD_LAYER( auto* slice_max_layer = TRT_ENGINE_ADD_LAYER(
engine_, Slice, *fake_max_seqlen_tensor, start, size, stride); engine_, Slice, *fake_max_seqlen_tensor, start, size, stride);
slice_max_layer->setInput(2, *length_tensor); slice_max_layer->setInput(2, *length_tensor);
plugin_inputs.emplace_back(slice_max_layer->getOutput(0)); nvinfer1::ITensor* max_seqlen_tensor =
(input_dims.d[1] == -1) ? slice_max_layer->getOutput(0)
: fake_max_seqlen_tensor;
plugin_inputs.emplace_back(max_seqlen_tensor);
// plugin_layer // plugin_layer
auto plugin_layer = engine_->network()->addPluginV2( auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin); plugin_inputs.data(), plugin_inputs.size(), *plugin);
......
...@@ -808,7 +808,7 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest): ...@@ -808,7 +808,7 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest):
for batch in [2, 4]: for batch in [2, 4]:
self.batch = batch self.batch = batch
for length in [64, 384]: for length in [197]:
self.length = length self.length = length
ops_config = [ ops_config = [
{ {
...@@ -1006,6 +1006,17 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest): ...@@ -1006,6 +1006,17 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest):
"input_data1": [1, 197, 768], "input_data1": [1, 197, 768],
} }
def generate_static_shape(attrs):
self.dynamic_shape.min_input_shape = {
"input_data1": [1, 197, 768],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [16, 197, 768],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [1, 197, 768],
}
def clear_dynamic_shape(): def clear_dynamic_shape():
self.dynamic_shape.max_input_shape = {} self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.min_input_shape = {} self.dynamic_shape.min_input_shape = {}
...@@ -1035,6 +1046,21 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest): ...@@ -1035,6 +1046,21 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest):
1e-5, 1e-5,
) )
# for static_shape
clear_dynamic_shape()
generate_static_shape(attrs)
self.trt_param.workspace_size = 2013265920
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(), (
1e-3,
1e-3,
)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(), (
1e-5,
1e-5,
)
def add_skip_trt_case(self): def add_skip_trt_case(self):
pass pass
......
Markdown is supported
0% uploaded.
You are about to add 0 people to the discussion. Proceed with caution.
Please finish editing this message first!
To comment, please register or sign in.