未验证 提交 13992de7 编写于 作者: S Sanbu 提交者: GitHub

Add reduce_min prod trt converter (#49615)

上级 acab7daf
......@@ -2350,7 +2350,9 @@ USE_TRT_CONVERTER(reshape2);
USE_TRT_CONVERTER(gather_nd);
USE_TRT_CONVERTER(reduce_mean);
USE_TRT_CONVERTER(reduce_max);
USE_TRT_CONVERTER(reduce_min);
USE_TRT_CONVERTER(reduce_sum);
USE_TRT_CONVERTER(reduce_prod);
USE_TRT_CONVERTER(tile);
USE_TRT_CONVERTER(conv3d);
USE_TRT_CONVERTER(conv3d_transpose);
......
......@@ -43,7 +43,6 @@ class ReduceOpConverter : public OpConverter {
VLOG(4) << "convert a paddle " << op_type << " op to tensorrt reduce layer";
framework::OpDesc op_desc(op, nullptr);
auto reduce_type = ops_.find(op_type);
auto* x = engine_->GetITensor(op_desc.Input("X").front());
nvinfer1::Dims input_shape = x->getDimensions();
int input_dims = input_shape.nbDims;
......@@ -104,6 +103,8 @@ const std::unordered_map<std::string, std::vector<nvinfer1::ReduceOperation>>
{"reduce_mean", {nvinfer1::ReduceOperation::kAVG}},
{"reduce_sum", {nvinfer1::ReduceOperation::kSUM}},
{"reduce_max", {nvinfer1::ReduceOperation::kMAX}},
{"reduce_min", {nvinfer1::ReduceOperation::kMIN}},
{"reduce_prod", {nvinfer1::ReduceOperation::kPROD}},
};
class ReduceSumOpConverter : public ReduceOpConverter {
......@@ -120,6 +121,17 @@ class ReduceMaxOpConverter : public ReduceOpConverter {
public:
ReduceMaxOpConverter() { op_type = "reduce_max"; }
};
class ReduceMinOpConverter : public ReduceOpConverter {
public:
ReduceMinOpConverter() { op_type = "reduce_min"; }
};
class ReduceProdOpConverter : public ReduceOpConverter {
public:
ReduceProdOpConverter() { op_type = "reduce_prod"; }
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -127,3 +139,5 @@ class ReduceMaxOpConverter : public ReduceOpConverter {
REGISTER_TRT_OP_CONVERTER(reduce_sum, ReduceSumOpConverter);
REGISTER_TRT_OP_CONVERTER(reduce_mean, ReduceMeanOpConverter);
REGISTER_TRT_OP_CONVERTER(reduce_max, ReduceMaxOpConverter);
REGISTER_TRT_OP_CONVERTER(reduce_min, ReduceMinOpConverter);
REGISTER_TRT_OP_CONVERTER(reduce_prod, ReduceProdOpConverter);
......@@ -2088,7 +2088,8 @@ struct SimpleOpTypeSetTeller : public Teller {
}
if (op_type == "reduce_sum" || op_type == "reduce_mean" ||
op_type == "reduce_max") {
op_type == "reduce_max" || op_type == "reduce_min" ||
op_type == "reduce_prod") {
if (!desc.HasAttr("dim", /*with_attr_var=*/false)) {
VLOG(3) << "Skip to convert into TRT while found Attribute('dim') is "
"Variable type in "
......
......@@ -68,10 +68,12 @@ class TrtConvertReduceTest(TrtLayerAutoScanTest):
for out_dtype in [-1, 2, 5]:
for op_type in [
"reduce_max",
"reduce_min",
"reduce_mean",
"reduce_sum",
"reduce_prod",
]:
dics = [
dics1 = [
{
"keep_dim": keep_dim,
"dim": dim,
......@@ -81,7 +83,14 @@ class TrtConvertReduceTest(TrtLayerAutoScanTest):
},
{},
]
dics2 = [
{
"out_dtype": out_dtype,
"in_dtype": out_dtype,
},
{},
]
for dics in [dics1, dics2]:
ops_config = [
{
"op_type": op_type,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册