未验证 提交 8bc1c82d 编写于 作者: Z zhoutianzi666 提交者: GitHub

[Inference] add squeeze2/unsqueeze2 trt layer (#42782)

* add squeeze2

* add squeeze

* add squeeze2,unsqueeze2

* merge develop

* fix format

* add conditions for squeeze2 and unsqueeze in op_teller

* merge develop

* add squeeze unsqueeze

* add squeeze unsqueeze

* add squeeze unsqueeze

* remove unsqueeze2_eltwise_fuse_pass

* add squeeze/unsqueeze
上级 5cf3f898
......@@ -1962,6 +1962,8 @@ USE_TRT_CONVERTER(recover_padding)
USE_TRT_CONVERTER(remove_padding)
USE_TRT_CONVERTER(top_k)
USE_TRT_CONVERTER(top_k_v2)
USE_TRT_CONVERTER(squeeze2)
USE_TRT_CONVERTER(unsqueeze2)
#if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000)
USE_TRT_CONVERTER(sparse_fc)
USE_TRT_CONVERTER(sparse_multihead_matmul)
......
......@@ -61,7 +61,9 @@ list(
transformer_input_convert_op.cc
remove_padding_op.cc
recover_padding_op.cc
top_k_op.cc)
top_k_op.cc
squeeze2_op.cc
unsqueeze2_op.cc)
if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8)
list(APPEND CONVERT_FILES sparse_fc_op.cc sparse_multihead_matmul_op.cc)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class Squeeze2OpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert a fluid squeeze2 op to tensorrt shuffle layer";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
auto input_dims = input->getDimensions();
auto output_name = op_desc.Output("Out")[0];
// Get Attrs
std::vector<int> axes =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("axes"));
PADDLE_ENFORCE_GT(
axes.size(), 0,
platform::errors::InvalidArgument(
"Attr(axes).size should be > 0 in squeeze2 op in TensorRT,"
"but received axes.size() = %d.",
axes.size()));
std::vector<bool> should_squeeze(input_dims.nbDims, false);
for (size_t i = 0; i < axes.size(); i++) {
if (engine_->with_dynamic_shape()) {
axes[i] += (axes[i] < 0) ? input_dims.nbDims : 0;
} else {
axes[i] += (axes[i] < 0) ? input_dims.nbDims : -1;
}
should_squeeze[axes[i]] = true;
}
nvinfer1::Dims trt_out_dims;
trt_out_dims.nbDims = 0;
std::vector<int32_t> gather_indices;
for (size_t i = 0; i < should_squeeze.size(); i++) {
if (should_squeeze[i]) continue;
gather_indices.push_back(i);
// for static shape
trt_out_dims.d[trt_out_dims.nbDims] = input_dims.d[i];
trt_out_dims.nbDims++;
}
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
if (engine_->with_dynamic_shape()) {
auto* shape_tensor = Shape(input);
auto* real_shape_tensor = Gather(shape_tensor, gather_indices);
layer->setInput(1, *real_shape_tensor);
} else {
layer->setReshapeDimensions(trt_out_dims);
}
RreplenishLayerAndOutput(layer, "squeeze2", {output_name}, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(squeeze2, Squeeze2OpConverter);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class Unsqueeze2OpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(4) << "convert a fluid unsqueeze2 op to tensorrt shuffle layer";
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
auto input_dims = input->getDimensions();
auto output_name = op_desc.Output("Out")[0];
// Get Attrs
std::vector<int> axes =
BOOST_GET_CONST(std::vector<int>, op_desc.GetAttr("axes"));
PADDLE_ENFORCE_GT(
axes.size(), 0,
platform::errors::InvalidArgument(
"Attr(axes).size should be > 0 in unsqueeze2 op in TensorRT,"
"but received axes.size() = %d.",
axes.size()));
std::vector<bool> should_unsqueeze(input_dims.nbDims + axes.size(), false);
int cur_out_rank = input_dims.nbDims;
for (size_t i = 0; i < axes.size(); i++) {
cur_out_rank++;
if (engine_->with_dynamic_shape()) {
axes[i] += (axes[i] < 0) ? cur_out_rank : 0;
} else {
axes[i] += (axes[i] < 0) ? cur_out_rank : -1;
}
// axes[i] is relative to cur_out_rank
// we make [axes[i], cur_out_rank - 2] shift right
// and make (axes[i]) to true!
for (int j = cur_out_rank - 1; j > axes[i]; j--) {
should_unsqueeze[j] = should_unsqueeze[j - 1];
}
if (axes[i] >= cur_out_rank)
should_unsqueeze[cur_out_rank - 1] = true;
else
should_unsqueeze[axes[i]] = true;
}
nvinfer1::Dims trt_out_dims;
trt_out_dims.nbDims = should_unsqueeze.size();
std::vector<int32_t> gather_indices;
int in_rank_i = 0;
for (size_t i = 0; i < should_unsqueeze.size(); i++) {
if (should_unsqueeze[i]) {
trt_out_dims.d[i] = 1;
gather_indices.push_back(input_dims.nbDims);
continue;
}
trt_out_dims.d[i] = input_dims.d[in_rank_i];
gather_indices.push_back(in_rank_i);
in_rank_i++;
}
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
if (engine_->with_dynamic_shape()) {
auto* shape_tensor = Shape(input);
std::vector<int32_t> all_one(axes.size(), 1);
auto* all_one_tensor = Add1DConstantLayer(all_one);
std::vector<nvinfer1::ITensor*> concat_inputs = {shape_tensor,
all_one_tensor};
auto* real_shape_tensor = Gather(Concat(concat_inputs), gather_indices);
layer->setInput(1, *real_shape_tensor);
} else {
layer->setReshapeDimensions(trt_out_dims);
}
RreplenishLayerAndOutput(layer, "unsqueeze2", {output_name}, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(unsqueeze2, Unsqueeze2OpConverter);
......@@ -153,7 +153,9 @@ struct SimpleOpTypeSetTeller : public Teller {
"preln_skip_layernorm",
"transformer_input_convert",
"recover_padding",
"remove_padding"};
"remove_padding",
"squeeze2",
"unsqueeze2"};
std::unordered_set<std::string> teller_set{
"mul",
"matmul",
......@@ -242,7 +244,9 @@ struct SimpleOpTypeSetTeller : public Teller {
"multiclass_nms3",
"transformer_input_convert",
"recover_padding",
"remove_padding"};
"remove_padding",
"squeeze2",
"unsqueeze2"};
};
bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
......@@ -887,6 +891,44 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
}
if (op_type == "squeeze2") {
std::vector<int> axes;
if (desc.HasAttr("axes")) {
axes = BOOST_GET_CONST(std::vector<int>, desc.GetAttr("axes"));
}
if (axes.size() == 0) {
VLOG(3) << "The necessary attributes of the squeeze2 operator axes is "
"missing.";
return false;
}
if (!with_dynamic_shape) {
if (std::find(axes.begin(), axes.end(), 0) != axes.end()) {
VLOG(3) << "Invalid squeeze axes. Axes having batch axis is not "
"supported in static shape";
return false;
}
}
}
if (op_type == "unsqueeze2") {
std::vector<int> axes;
if (desc.HasAttr("axes")) {
axes = BOOST_GET_CONST(std::vector<int>, desc.GetAttr("axes"));
}
if (axes.size() == 0) {
VLOG(3) << "The necessary attributes of the squeeze2 operator axes is "
"missing.";
return false;
}
if (!with_dynamic_shape) {
if (std::find(axes.begin(), axes.end(), 0) != axes.end()) {
VLOG(3) << "Invalid squeeze axes. Axes having batch axis is not "
"supported in static shape";
return false;
}
}
}
if (op_type == "batch_norm") {
const std::vector<std::string> bn_inputs = {"X", "Bias", "Mean", "Scale",
"Variance"};
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig
import unittest
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
class TrtConvertSplitTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
inputs = program_config.inputs
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
if len(inputs['in_data'].shape) <= max(attrs[0]['axes']):
return False
return True
def sample_program_configs(self):
for dims in [2, 3, 4]:
for batch in [3, 4]:
for axes in [[2], [2, 3], [-1]]:
self.batch = batch
self.dims = dims
self.axes = axes
dics = [{"axes": axes}]
ops_config = [{
"op_type": "squeeze2",
"op_inputs": {
"X": ["in_data"]
},
"op_outputs": {
"Out": ["out_data"],
"XShape": ["XShape_data"]
},
"op_attrs": dics[0]
}]
# new_axes is the update of axes
new_axes = list(axes)
for i in range(len(new_axes)):
if (new_axes[i] < 0):
new_axes[i] += dims
if (max(new_axes) >= dims):
continue
# generate input data
self.input_shape = [1] * dims
for i in range(dims):
self.input_shape[i] = np.random.randint(1, 20)
def generate_input1(attrs: List[Dict[str, Any]], batch):
self.input_shape[0] = batch
for i in new_axes:
self.input_shape[i] = 1
return np.random.random(self.input_shape).astype(
np.float32)
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"in_data":
TensorConfig(
data_gen=partial(generate_input1, dics, batch))
},
outputs=["out_data"])
yield program_config
def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
max_shape = list(self.input_shape)
min_shape = list(self.input_shape)
opt_shape = list(self.input_shape)
for i in range(len(self.input_shape)):
max_shape[i] = max_shape[i] + 1
self.dynamic_shape.min_input_shape = {"in_data": min_shape}
self.dynamic_shape.max_input_shape = {"in_data": max_shape}
self.dynamic_shape.opt_input_shape = {"in_data": opt_shape}
def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
return 1, 2
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
self.trt_param.max_batch_size = 9
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), 1e-5
def add_skip_trt_case(self):
pass
def test(self):
self.add_skip_trt_case()
self.run_test()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig
import unittest
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
class TrtConvertSplitTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def sample_program_configs(self):
for dims in [2, 3, 4]:
for batch in [3, 4]:
for axes in [[-2, 3], [1], [2], [2, 3]]:
self.batch = batch
self.dims = dims
self.axes = axes
dics = [{"axes": axes}]
ops_config = [{
"op_type": "unsqueeze2",
"op_inputs": {
"X": ["in_data"]
},
"op_outputs": {
"Out": ["out_data"],
"XShape": ["XShape_data"]
},
"op_attrs": dics[0]
}]
# generate input data
self.input_shape = [1] * dims
for i in range(dims):
self.input_shape[i] = np.random.randint(1, 20)
def generate_input1(attrs: List[Dict[str, Any]], batch):
self.input_shape[0] = batch
return np.random.random(self.input_shape).astype(
np.float32)
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"in_data":
TensorConfig(
data_gen=partial(generate_input1, dics, batch))
},
outputs=["out_data"])
yield program_config
def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
max_shape = list(self.input_shape)
min_shape = list(self.input_shape)
opt_shape = list(self.input_shape)
for i in range(len(self.input_shape)):
max_shape[i] = max_shape[i] + 1
self.dynamic_shape.min_input_shape = {"in_data": min_shape}
self.dynamic_shape.max_input_shape = {"in_data": max_shape}
self.dynamic_shape.opt_input_shape = {"in_data": opt_shape}
def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
return 1, 2
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
self.trt_param.max_batch_size = 9
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), 1e-5
def add_skip_trt_case(self):
pass
def test(self):
self.add_skip_trt_case()
self.run_test()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册