未验证 提交 d8b8c2d8 编写于 作者: X xiaoxiaohehe001 提交者: GitHub

[Paddle Inference] Support split sections list and axis = 0 input of TRT. (#50957)

* split_list
上级 ccfe7681
......@@ -29,15 +29,15 @@ class SplitOpConverter : public OpConverter {
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
auto inputs = op_desc.Inputs();
auto input_dims = input->getDimensions();
size_t output_num = op_desc.Output("Out").size();
int output_num = op_desc.Output("Out").size();
// Get Attrs
int axis = PADDLE_GET_CONST(int, op_desc.GetAttr("axis"));
int num = 0;
std::vector<int> output_lengths =
PADDLE_GET_CONST(std::vector<int>, op_desc.GetAttr("sections"));
int num = 0;
if (op_desc.HasAttr("num")) {
num = PADDLE_GET_CONST(int, op_desc.GetAttr("num"));
}
......@@ -50,19 +50,34 @@ class SplitOpConverter : public OpConverter {
axis += (axis < 0) ? input_dims.nbDims : -1;
}
bool in_axis_dim_dynamic = false;
nvinfer1::ITensor* avg_len_tensor = nullptr;
bool sections_tensor_list = false;
nvinfer1::ITensor* sections_tensor = nullptr;
// need infer output_lengths
if (num > 0 && output_lengths.empty()) {
if (inputs.find("SectionsTensorList") != inputs.end() &&
op_desc.Input("SectionsTensorList").size() >= 1) {
int32_t sections_size = op_desc.Input("SectionsTensorList").size();
std::vector<nvinfer1::ITensor*> sections_tensors;
for (int32_t i = 0; i < sections_size; ++i) {
sections_tensors.push_back(
engine_->GetITensor(op_desc.Input("SectionsTensorList")[i]));
}
sections_tensor = Concat(sections_tensors);
sections_tensor_list = true;
} else if (!output_lengths.empty()) {
sections_tensor = Add1DConstantLayer(output_lengths);
} else if (num > 0 && output_lengths.empty()) {
if (input_dims.d[axis] > 0) {
int64_t in_axis_dim = input_dims.d[axis];
size_t out_axis_dim = in_axis_dim / num;
for (int i = 0; i < num; ++i) {
output_lengths.push_back(out_axis_dim);
}
sections_tensor = Add1DConstantLayer(output_lengths);
} else {
in_axis_dim_dynamic = true;
auto* num_tensor = Add1DConstantLayer(num);
avg_len_tensor =
sections_tensor =
Div(GetEleTensorOfShape(shape_tensor, axis), num_tensor);
}
}
......@@ -79,20 +94,20 @@ class SplitOpConverter : public OpConverter {
std::iota(gather_indices.begin(), gather_indices.end(), 0);
gather_indices[axis] = gather_indices.size();
std::vector<int32_t> zeros(trt_step_dims.nbDims, 0);
auto* zeros_tensor = Add1DConstantLayer(zeros);
std::vector<int32_t> stride(trt_step_dims.nbDims, 1);
auto zeros_tensor = Add1DConstantLayer(zeros);
auto stride_tensor = Add1DConstantLayer(stride);
// input : [N,C,H,W]
int start_point = 0;
for (size_t i = 0; i < output_num; i++) {
nvinfer1::ITensor* this_len_tensor = nullptr;
nvinfer1::ITensor* start_point_tensor = nullptr;
if (!in_axis_dim_dynamic) {
this_len_tensor = Add1DConstantLayer(output_lengths[i]);
start_point_tensor = Add1DConstantLayer(start_point);
start_point += output_lengths[i];
nvinfer1::ITensor* start_point_tensor = zeros_tensor;
nvinfer1::ITensor* this_len_tensor = zeros_tensor;
for (int i = 0; i < output_num; i++) {
if (sections_tensor_list || !in_axis_dim_dynamic) {
start_point_tensor = Sum(start_point_tensor, this_len_tensor);
this_len_tensor = Gather(sections_tensor, std::vector<int32_t>{i});
} else {
this_len_tensor = avg_len_tensor;
this_len_tensor = sections_tensor;
auto* i_tensor = Add1DConstantLayer(static_cast<int>(i));
start_point_tensor = Prod(i_tensor, avg_len_tensor);
start_point_tensor = Prod(i_tensor, sections_tensor);
}
std::vector<nvinfer1::ITensor*> concat_inputs1 = {zeros_tensor,
......@@ -104,11 +119,12 @@ class SplitOpConverter : public OpConverter {
layer = TRT_ENGINE_ADD_LAYER(engine_,
Slice,
*input,
trt_step_dims,
trt_step_dims,
trt_step_dims);
nvinfer1::Dims{},
nvinfer1::Dims{},
nvinfer1::Dims{});
layer->setInput(1, *start_tensor);
layer->setInput(2, *size_tensor);
layer->setInput(3, *stride_tensor);
auto output_name = op_desc.Output("Out")[i];
RreplenishLayerAndOutput(layer, "split", {output_name}, test_mode);
......@@ -124,7 +140,7 @@ class SplitOpConverter : public OpConverter {
for (int i = 0; i < trt_step_dims.nbDims; i++) trt_step_dims.d[i] = 1;
// input : [C,H,W]
for (size_t i = 0; i < output_num; i++) {
for (int i = 0; i < output_num; i++) {
trt_start_dims.d[axis] = std::accumulate(
output_lengths.begin(), output_lengths.begin() + i, 0);
trt_size_dims.d[axis] = output_lengths[i];
......@@ -153,7 +169,7 @@ class SplitOpConverter : public OpConverter {
layer = engine_->AddPluginV2Ext(&input, 1, plugin);
}
std::vector<std::string> output_names;
for (size_t i = 0; i < output_num; i++) {
for (int i = 0; i < output_num; i++) {
output_names.push_back(op_desc.Output("Out")[i]);
}
RreplenishLayerAndOutput(layer, "split", output_names, test_mode);
......
......@@ -1079,17 +1079,19 @@ struct SimpleOpTypeSetTeller : public Teller {
}
if (split_inputs.find("SectionsTensorList") != split_inputs.end()) {
if (desc.Input("SectionsTensorList").size() >= 1) {
if (!with_dynamic_shape) {
return false;
}
}
}
if (!desc.HasAttr("axis")) {
return false;
}
int axis = PADDLE_GET_CONST(int, desc.GetAttr("axis"));
if (axis == 0) {
if (!with_dynamic_shape && axis == 0) {
VLOG(3) << "Invalid split axis. Split on batch is not supported in "
"TensorRT";
"TensorRT with static shape";
return false;
}
auto* block = desc.Block();
......
......@@ -70,6 +70,14 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
else:
return False
if self.dims == 2:
if self.batch != 3:
return False
if len(attrs[0]['sections']) != 0 and attrs[0]['axis'] == 0:
if self.dims != 2 or self.batch != 3:
return False
return True
def sample_program_configs(self):
......@@ -81,7 +89,7 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
elif self.dims == 2:
return np.random.random([batch, 24]).astype(np.float32)
elif self.dims == 1:
return np.random.random([24]).astype(np.float32)
return np.random.random([24]).astype(np.int32)
def generate_AxisTensor(attrs: List[Dict[str, Any]]):
return np.ones([1]).astype(np.int32)
......@@ -204,13 +212,9 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
}
self.dynamic_shape.opt_input_shape = {"split_input": [1, 3, 24]}
elif self.dims == 2:
self.dynamic_shape.min_input_shape = {
"split_input": [1, 24 - 1]
}
self.dynamic_shape.max_input_shape = {
"split_input": [9, 24 + 1]
}
self.dynamic_shape.opt_input_shape = {"split_input": [1, 24]}
self.dynamic_shape.min_input_shape = {"split_input": [3, 24]}
self.dynamic_shape.max_input_shape = {"split_input": [3, 24]}
self.dynamic_shape.opt_input_shape = {"split_input": [3, 24]}
elif self.dims == 1:
self.dynamic_shape.min_input_shape = {"split_input": [24 - 1]}
self.dynamic_shape.max_input_shape = {"split_input": [24 + 1]}
......@@ -223,10 +227,16 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
def generate_trt_nodes_num(attrs, dynamic_shape):
if len(program_config.outputs) == 2:
if dynamic_shape:
return 1, 3
else:
if attrs[0]['axis'] != 0:
return 1, 3
else:
return 0, 4
else:
if dynamic_shape:
return 1, 4
else:
if attrs[0]['axis'] != 0:
return 1, 4
......@@ -276,5 +286,135 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
self.run_test()
class TrtConvertSplitTest2(TrtLayerAutoScanTest):
    """Auto-scan test for `split` fed with a SectionsTensorList input.

    The three section lengths arrive as int32 tensors produced by
    `fill_constant` ops (each holding the value 1), exercising the TRT
    converter path where sections are tensors rather than attributes.
    Input is a fixed [3, 3, 3, 24] float32 tensor split along axis 0 or 1.
    """

    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        # Every sampled combination here is expected to be convertible.
        return True

    def sample_program_configs(self):
        def generate_input1(attrs: List[Dict[str, Any]]):
            # attrs is unused but required by the auto-scan data_gen signature.
            return np.random.random([3, 3, 3, 24]).astype(np.float32)

        def section_tensor_config(out_name):
            # A scalar int32 tensor (dtype=2) holding section length 1.
            return {
                "op_type": "fill_constant",
                "op_inputs": {},
                "op_outputs": {"Out": [out_name]},
                "op_attrs": {
                    "dtype": 2,
                    "str_value": "1",
                    "shape": [1],
                },
            }

        for sections in [[-1, -1, -1], [1, 1, 1]]:
            for num in [0]:
                for axis in [0, 1]:
                    dics = [
                        {
                            "sections": sections,
                            "num": num,
                            "axis": axis,
                        }
                    ]
                    dics_input = [
                        {
                            "X": ["split_input"],
                            "SectionsTensorList": [
                                "shapeT1_data",
                                "shapeT2_data",
                                "shapeT3_data",
                            ],
                        },
                    ]
                    ops_config = [
                        section_tensor_config("shapeT1_data"),
                        section_tensor_config("shapeT2_data"),
                        section_tensor_config("shapeT3_data"),
                        {
                            "op_type": "split",
                            "op_inputs": dics_input[0],
                            "op_outputs": {
                                "Out": [
                                    "output_var0",
                                    "output_var1",
                                    "output_var2",
                                ]
                            },
                            "op_attrs": dics[0],
                        },
                    ]
                    ops = self.generate_op_config(ops_config)
                    program_config = ProgramConfig(
                        ops=ops,
                        weights={},
                        inputs={
                            "split_input": TensorConfig(
                                data_gen=partial(generate_input1, dics)
                            )
                        },
                        outputs=[
                            "output_var0",
                            "output_var1",
                            "output_var2",
                        ],
                    )
                    yield program_config

    def sample_predictor_configs(
        self, program_config
    ) -> (paddle_infer.Config, List[int], float):
        def generate_dynamic_shape(attrs):
            self.dynamic_shape.min_input_shape = {"split_input": [1, 3, 3, 24]}
            self.dynamic_shape.max_input_shape = {"split_input": [9, 3, 3, 24]}
            self.dynamic_shape.opt_input_shape = {"split_input": [3, 3, 3, 24]}

        def clear_dynamic_shape():
            self.dynamic_shape.min_input_shape = {}
            self.dynamic_shape.max_input_shape = {}
            self.dynamic_shape.opt_input_shape = {}

        def generate_trt_nodes_num(attrs, dynamic_shape):
            # Dynamic shape: 1 TRT engine node + 4 remaining ops;
            # otherwise nothing converts (0 engines, 5 ops).
            if dynamic_shape:
                return 1, 4
            return 0, 5

        attrs = [
            program_config.ops[i].attrs for i in range(len(program_config.ops))
        ]
        self.trt_param.max_batch_size = 9

        # for dynamic_shape
        generate_dynamic_shape(attrs)
        self.trt_param.precision = paddle_infer.PrecisionType.Float32
        yield self.create_inference_config(), generate_trt_nodes_num(
            attrs, True
        ), 1e-5
        self.trt_param.precision = paddle_infer.PrecisionType.Half
        yield self.create_inference_config(), generate_trt_nodes_num(
            attrs, True
        ), 1e-3

    def add_skip_trt_case(self):
        # No known-bad combinations to skip for this scenario.
        pass

    def test(self):
        self.add_skip_trt_case()
        self.run_test()
# Run the auto-scan test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册