未验证 提交 2f3ad5ab 编写于 作者: F feng_shuai 提交者: GitHub

optimize: vit static shape (#47280)

上级 84273aaa
......@@ -460,9 +460,10 @@ class MultiheadMatMulOpConverter : public OpConverter {
plugin_inputs.emplace_back(mask_tensor);
// input_2 for plugin
std::vector<int> pos_id = {0};
int max_batch = 500;
int max_batch = 512;
int length = (input_dims.d[1] == -1) ? 1 : input_dims.d[1];
for (int i = 1; i < max_batch; i++) {
pos_id.push_back(i);
pos_id.push_back(i * length);
}
nvinfer1::ITensor* fake_pos_id_tensor = Add1DConstantLayer(pos_id);
nvinfer1::ITensor* length_tensor =
......@@ -497,18 +498,26 @@ class MultiheadMatMulOpConverter : public OpConverter {
stride.d[0] = 1;
size.d[0] = 1;
nvinfer1::ITensor* pos_id_tensor = (input_dims.d[1] == -1)
? pos_id_layer->getOutput(0)
: fake_pos_id_tensor;
auto* slice_pos_layer = TRT_ENGINE_ADD_LAYER(
engine_, Slice, *pos_id_layer->getOutput(0), start, size, stride);
engine_, Slice, *pos_id_tensor, start, size, stride);
slice_pos_layer->setInput(2, *size_layer->getOutput(0));
plugin_inputs.emplace_back(slice_pos_layer->getOutput(0));
// input_3 for plugin
std::vector<int> data(500, 1);
int max_length = (input_dims.d[1] == -1) ? 512 : input_dims.d[1];
std::vector<int> data(max_length, 1);
nvinfer1::ITensor* fake_max_seqlen_tensor = Add1DConstantLayer(data);
auto* slice_max_layer = TRT_ENGINE_ADD_LAYER(
engine_, Slice, *fake_max_seqlen_tensor, start, size, stride);
slice_max_layer->setInput(2, *length_tensor);
plugin_inputs.emplace_back(slice_max_layer->getOutput(0));
nvinfer1::ITensor* max_seqlen_tensor =
(input_dims.d[1] == -1) ? slice_max_layer->getOutput(0)
: fake_max_seqlen_tensor;
plugin_inputs.emplace_back(max_seqlen_tensor);
// plugin_layer
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin);
......
......@@ -808,7 +808,7 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest):
for batch in [2, 4]:
self.batch = batch
for length in [64, 384]:
for length in [197]:
self.length = length
ops_config = [
{
......@@ -1006,6 +1006,17 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest):
"input_data1": [1, 197, 768],
}
# Populate the min / max / opt input-shape tables on self.dynamic_shape for
# the single input "input_data1".
# min and opt are both [1, 197, 768]; max is [16, 197, 768], so only the
# first dimension is allowed to vary (1..16) while the last two are pinned.
# NOTE(review): presumably the dims are (batch, seq_len, hidden) and the
# "static" in the name refers to the fixed 197 sequence length - confirm.
# `attrs` is accepted but not read anywhere in this body.
def generate_static_shape(attrs):
self.dynamic_shape.min_input_shape = {
"input_data1": [1, 197, 768],
}
self.dynamic_shape.max_input_shape = {
"input_data1": [16, 197, 768],
}
self.dynamic_shape.opt_input_shape = {
"input_data1": [1, 197, 768],
}
def clear_dynamic_shape():
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.min_input_shape = {}
......@@ -1035,6 +1046,21 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest):
1e-5,
)
# for static_shape
clear_dynamic_shape()
generate_static_shape(attrs)
self.trt_param.workspace_size = 2013265920
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(), (
1e-3,
1e-3,
)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(), (
1e-5,
1e-5,
)
# Intentionally empty: the body is only `pass`, so calling this does nothing.
# NOTE(review): presumably a hook where skip rules for TRT cases would be
# registered, and this test registers none - confirm against the base class.
def add_skip_trt_case(self):
pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册