未验证 提交 dd304f31 编写于 作者: Z Zhang Jun 提交者: GitHub

[inference][trt] add reduce max for trt (#48684)

* add reduce max for trt
上级 0c7f3575
......@@ -2318,9 +2318,10 @@ USE_TRT_CONVERTER(nearest_interp_v2);
USE_TRT_CONVERTER(bilinear_interp_v2);
USE_TRT_CONVERTER(reshape);
USE_TRT_CONVERTER(reshape2);
USE_TRT_CONVERTER(reduce_sum);
USE_TRT_CONVERTER(gather_nd);
USE_TRT_CONVERTER(reduce_mean);
USE_TRT_CONVERTER(reduce_max);
USE_TRT_CONVERTER(reduce_sum);
USE_TRT_CONVERTER(tile);
USE_TRT_CONVERTER(conv3d);
USE_TRT_CONVERTER(conv3d_transpose);
......
......@@ -42,12 +42,7 @@ class ReduceOpConverter : public OpConverter {
bool test_mode) override {
VLOG(4) << "convert a paddle " << op_type << " op to tensorrt reduce layer";
framework::OpDesc op_desc(op, nullptr);
nvinfer1::ReduceOperation reduce_type = nvinfer1::ReduceOperation::kSUM;
if (op_type == "reduce_sum") {
reduce_type = nvinfer1::ReduceOperation::kSUM;
} else if (op_type == "reduce_mean") {
reduce_type = nvinfer1::ReduceOperation::kAVG;
}
auto reduce_type = ops_.find(op_type);
auto* x = engine_->GetITensor(op_desc.Input("X").front());
nvinfer1::Dims input_shape = x->getDimensions();
......@@ -64,8 +59,12 @@ class ReduceOpConverter : public OpConverter {
for (int i = 0; i < input_dims; ++i) {
reduce_dim |= 1 << i;
}
layer = TRT_ENGINE_ADD_LAYER(
engine_, Reduce, *x, reduce_type, reduce_dim, keep_dim);
layer = TRT_ENGINE_ADD_LAYER(engine_,
Reduce,
*x,
reduce_type->second.front(),
reduce_dim,
keep_dim);
} else {
auto CvtToBitMask = [&](const std::vector<int32_t>& dims) -> uint32_t {
uint32_t res = 0;
......@@ -79,8 +78,12 @@ class ReduceOpConverter : public OpConverter {
}
return res;
};
layer = TRT_ENGINE_ADD_LAYER(
engine_, Reduce, *x, reduce_type, CvtToBitMask(dim), keep_dim);
layer = TRT_ENGINE_ADD_LAYER(engine_,
Reduce,
*x,
reduce_type->second.front(),
CvtToBitMask(dim),
keep_dim);
}
auto output_name = op_desc.Output("Out")[0];
......@@ -91,6 +94,16 @@ class ReduceOpConverter : public OpConverter {
protected:
std::string op_type;
static const std::unordered_map<std::string,
std::vector<nvinfer1::ReduceOperation>>
ops_;
};
// Lookup table from a Paddle reduce op type name to the TensorRT reduce
// operation(s) that implement it. The converter looks an entry up by its
// op_type and passes the first element of the vector to the Reduce layer.
const std::unordered_map<std::string, std::vector<nvinfer1::ReduceOperation>>
ReduceOpConverter::ops_ = {
{"reduce_mean", {nvinfer1::ReduceOperation::kAVG}},
{"reduce_sum", {nvinfer1::ReduceOperation::kSUM}},
{"reduce_max", {nvinfer1::ReduceOperation::kMAX}},
};
class ReduceSumOpConverter : public ReduceOpConverter {
......@@ -103,9 +116,14 @@ class ReduceMeanOpConverter : public ReduceOpConverter {
ReduceMeanOpConverter() { op_type = "reduce_mean"; }
};
// Converter for the Paddle `reduce_max` op. It only sets op_type; the
// conversion logic is inherited from ReduceOpConverter, which uses op_type as
// the key into its ops_ table (where "reduce_max" maps to kMAX).
class ReduceMaxOpConverter : public ReduceOpConverter {
public:
ReduceMaxOpConverter() { op_type = "reduce_max"; }
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Register one converter class per reduce op name (sum, mean, max).
REGISTER_TRT_OP_CONVERTER(reduce_sum, ReduceSumOpConverter);
REGISTER_TRT_OP_CONVERTER(reduce_mean, ReduceMeanOpConverter);
REGISTER_TRT_OP_CONVERTER(reduce_max, ReduceMaxOpConverter);
......@@ -2038,7 +2038,8 @@ struct SimpleOpTypeSetTeller : public Teller {
const auto x_shape = x_var_desc->GetShape();
}
if (op_type == "reduce_sum" || op_type == "reduce_mean") {
if (op_type == "reduce_sum" || op_type == "reduce_mean" ||
op_type == "reduce_max") {
if (!desc.HasAttr("dim", /*with_attr_var=*/false)) {
VLOG(3) << "Skip to convert into TRT while found Attribute('dim') is "
"Variable type in "
......@@ -2470,8 +2471,9 @@ struct SimpleOpTypeSetTeller : public Teller {
"affine_channel",
"nearest_interp",
"anchor_generator",
"reduce_sum",
"reduce_max",
"reduce_mean",
"reduce_sum",
"conv3d",
"conv3d_transpose",
"mish",
......@@ -2610,8 +2612,9 @@ struct SimpleOpTypeSetTeller : public Teller {
"affine_channel",
"nearest_interp",
"anchor_generator",
"reduce_sum",
"reduce_max",
"reduce_mean",
"reduce_sum",
"conv3d",
"conv3d_transpose",
"mish",
......
......@@ -23,7 +23,7 @@ from trt_layer_auto_scan_test import TrtLayerAutoScanTest
import paddle.inference as paddle_infer
class TrtConvertReduceMeanTest(TrtLayerAutoScanTest):
class TrtConvertReduceTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
inputs = program_config.inputs
attrs = [
......@@ -66,44 +66,51 @@ class TrtConvertReduceMeanTest(TrtLayerAutoScanTest):
]:
for reduce_all in [True, False]:
for out_dtype in [-1, 2, 5]:
dics = [
{
"keep_dim": keep_dim,
"dim": dim,
"reduce_all": reduce_all,
"out_dtype": out_dtype,
"in_dtype": out_dtype,
},
{},
]
ops_config = [
{
"op_type": "reduce_mean",
"op_inputs": {"X": ["input_data"]},
"op_outputs": {"Out": ["reduce_output_data"]},
"op_attrs": dics[0],
}
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data": TensorConfig(
data_gen=partial(
generate_input1, out_dtype, dics
for op_type in [
"reduce_max",
"reduce_mean",
"reduce_sum",
]:
dics = [
{
"keep_dim": keep_dim,
"dim": dim,
"reduce_all": reduce_all,
"out_dtype": out_dtype,
"in_dtype": out_dtype,
},
{},
]
ops_config = [
{
"op_type": op_type,
"op_inputs": {"X": ["input_data"]},
"op_outputs": {
"Out": ["reduce_output_data"]
},
"op_attrs": dics[0],
}
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data": TensorConfig(
data_gen=partial(
generate_input1, out_dtype, dics
)
)
)
},
outputs=["reduce_output_data"],
)
},
outputs=["reduce_output_data"],
)
if not self.is_program_valid(program_config):
continue
if not self.is_program_valid(program_config):
continue
yield program_config
yield program_config
def sample_predictor_configs(
self, program_config
......@@ -139,22 +146,22 @@ class TrtConvertReduceMeanTest(TrtLayerAutoScanTest):
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False
), 1e-5
), (1e-5, 1e-5)
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False
), (5e-4, 5e-4)
), (1e-3, 1e-3)
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), 1e-5
), (1e-5, 1e-5)
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True
), (5e-4, 5e-4)
), (1e-3, 1e-3)
def add_skip_trt_case(self):
pass
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
from typing import Any, Dict, List
import numpy as np
from program_config import ProgramConfig, TensorConfig
from trt_layer_auto_scan_test import TrtLayerAutoScanTest
import paddle.inference as paddle_infer
class TrtConvertReduceSumTest(TrtLayerAutoScanTest):
    """Auto-scan test for running the ``reduce_sum`` op through TensorRT.

    Sweeps combinations of ``keep_dim``, ``dim``, ``reduce_all`` and
    ``out_dtype``, and checks each valid program with static and dynamic
    shapes in both Float32 and Half precision.
    """

    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        """Tell whether the sampled ``dim`` attribute is usable.

        Returns:
            True only when ``dim`` is non-empty and every axis lies strictly
            inside ``(-rank, rank)`` of the ``input_data`` tensor.
        """
        op_attrs = [op.attrs for op in program_config.ops]

        # dim should be in (-rank, rank), and not NONE
        rank = len(program_config.inputs['input_data'].shape)
        axes = op_attrs[0]["dim"]
        if any(axis >= rank or axis <= -rank for axis in axes):
            return False
        return len(axes) != 0

    def sample_program_configs(self):
        """Yield one ProgramConfig per valid attribute combination."""

        def generate_input1(dtype, attrs: List[Dict[str, Any]]):
            # dtype codes -1 and 5 produce float32 data, code 2 produces int32
            # data; any other code produces no data at all.
            shape = [1, 3, 32, 32]
            if dtype in (-1, 5):
                return np.random.random(shape).astype(np.float32)
            if dtype == 2:
                # NOTE(review): random() samples lie in [0, 1), so this cast
                # yields an all-zero tensor -- confirm this is intended.
                return np.random.random(shape).astype(np.int32)
            return None

        for keep_dim in [True, False]:
            for dim in [
                [],
                [1],
                [0],
                [0, 1],
                [1, 2, 3],
                [-2, 0, 3],
                [-3],
                [-4, 1],
                [3, 4, 5],
            ]:
                for reduce_all in [True, False]:
                    for out_dtype in [-1, 2, 5]:
                        reduce_attrs = {
                            "keep_dim": keep_dim,
                            "dim": dim,
                            "reduce_all": reduce_all,
                            "out_dtype": out_dtype,
                            "in_dtype": out_dtype,
                        }
                        dics = [reduce_attrs, {}]

                        reduce_op = {
                            "op_type": "reduce_sum",
                            "op_inputs": {"X": ["input_data"]},
                            "op_outputs": {"Out": ["reduce_output_data"]},
                            "op_attrs": reduce_attrs,
                        }
                        ops = self.generate_op_config([reduce_op])

                        input_config = TensorConfig(
                            data_gen=partial(generate_input1, out_dtype, dics)
                        )
                        program_config = ProgramConfig(
                            ops=ops,
                            weights={},
                            inputs={"input_data": input_config},
                            outputs=["reduce_output_data"],
                        )

                        # Out-of-range or empty `dim` combinations are dropped.
                        if self.is_program_valid(program_config):
                            yield program_config

    def sample_predictor_configs(self, program_config):
        """Yield (inference config, expected node counts, tolerances).

        Four cases are produced in order: static shape Float32, static shape
        Half, dynamic shape Float32, dynamic shape Half.
        """

        def generate_dynamic_shape(attrs):
            self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 32, 32]}
            self.dynamic_shape.max_input_shape = {"input_data": [4, 3, 64, 64]}
            self.dynamic_shape.opt_input_shape = {"input_data": [1, 3, 32, 32]}

        def clear_dynamic_shape():
            self.dynamic_shape.min_input_shape = {}
            self.dynamic_shape.max_input_shape = {}
            self.dynamic_shape.opt_input_shape = {}

        def generate_trt_nodes_num(attrs, dynamic_shape):
            # Presumably (TRT engine count, remaining Paddle op count) --
            # confirm against TrtLayerAutoScanTest.
            reduce_attrs = attrs[0]
            if dynamic_shape:
                not_offloaded = (
                    not reduce_attrs['keep_dim']
                ) and reduce_attrs['reduce_all']
            else:
                not_offloaded = (
                    0 in reduce_attrs['dim'] or reduce_attrs['reduce_all']
                )
            return (0, 3) if not_offloaded else (1, 2)

        attrs = [op.attrs for op in program_config.ops]

        # Each precision is paired with its (atol, rtol)-style tolerance tuple.
        precision_plan = (
            (paddle_infer.PrecisionType.Float32, (1e-5, 1e-5)),
            (paddle_infer.PrecisionType.Half, (1e-3, 1e-3)),
        )

        # for static_shape
        clear_dynamic_shape()
        for precision, tolerance in precision_plan:
            self.trt_param.precision = precision
            yield (
                self.create_inference_config(),
                generate_trt_nodes_num(attrs, False),
                tolerance,
            )

        # for dynamic_shape
        generate_dynamic_shape(attrs)
        for precision, tolerance in precision_plan:
            self.trt_param.precision = precision
            yield (
                self.create_inference_config(),
                generate_trt_nodes_num(attrs, True),
                tolerance,
            )

    def add_skip_trt_case(self):
        """No cases are skipped for this op."""
        pass

    def test(self):
        """Entry point picked up by unittest."""
        self.add_skip_trt_case()
        self.run_test()
# Run the tests in this file through unittest when executed as a script.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册