未验证 提交 13992de7 编写于 作者: S Sanbu 提交者: GitHub

Add reduce_min prod trt converter (#49615)

上级 acab7daf
...@@ -2350,7 +2350,9 @@ USE_TRT_CONVERTER(reshape2); ...@@ -2350,7 +2350,9 @@ USE_TRT_CONVERTER(reshape2);
USE_TRT_CONVERTER(gather_nd); USE_TRT_CONVERTER(gather_nd);
USE_TRT_CONVERTER(reduce_mean); USE_TRT_CONVERTER(reduce_mean);
USE_TRT_CONVERTER(reduce_max); USE_TRT_CONVERTER(reduce_max);
USE_TRT_CONVERTER(reduce_min);
USE_TRT_CONVERTER(reduce_sum); USE_TRT_CONVERTER(reduce_sum);
USE_TRT_CONVERTER(reduce_prod);
USE_TRT_CONVERTER(tile); USE_TRT_CONVERTER(tile);
USE_TRT_CONVERTER(conv3d); USE_TRT_CONVERTER(conv3d);
USE_TRT_CONVERTER(conv3d_transpose); USE_TRT_CONVERTER(conv3d_transpose);
......
...@@ -43,7 +43,6 @@ class ReduceOpConverter : public OpConverter { ...@@ -43,7 +43,6 @@ class ReduceOpConverter : public OpConverter {
VLOG(4) << "convert a paddle " << op_type << " op to tensorrt reduce layer"; VLOG(4) << "convert a paddle " << op_type << " op to tensorrt reduce layer";
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
auto reduce_type = ops_.find(op_type); auto reduce_type = ops_.find(op_type);
auto* x = engine_->GetITensor(op_desc.Input("X").front()); auto* x = engine_->GetITensor(op_desc.Input("X").front());
nvinfer1::Dims input_shape = x->getDimensions(); nvinfer1::Dims input_shape = x->getDimensions();
int input_dims = input_shape.nbDims; int input_dims = input_shape.nbDims;
...@@ -104,6 +103,8 @@ const std::unordered_map<std::string, std::vector<nvinfer1::ReduceOperation>> ...@@ -104,6 +103,8 @@ const std::unordered_map<std::string, std::vector<nvinfer1::ReduceOperation>>
{"reduce_mean", {nvinfer1::ReduceOperation::kAVG}}, {"reduce_mean", {nvinfer1::ReduceOperation::kAVG}},
{"reduce_sum", {nvinfer1::ReduceOperation::kSUM}}, {"reduce_sum", {nvinfer1::ReduceOperation::kSUM}},
{"reduce_max", {nvinfer1::ReduceOperation::kMAX}}, {"reduce_max", {nvinfer1::ReduceOperation::kMAX}},
{"reduce_min", {nvinfer1::ReduceOperation::kMIN}},
{"reduce_prod", {nvinfer1::ReduceOperation::kPROD}},
}; };
class ReduceSumOpConverter : public ReduceOpConverter { class ReduceSumOpConverter : public ReduceOpConverter {
...@@ -120,6 +121,17 @@ class ReduceMaxOpConverter : public ReduceOpConverter { ...@@ -120,6 +121,17 @@ class ReduceMaxOpConverter : public ReduceOpConverter {
public: public:
ReduceMaxOpConverter() { op_type = "reduce_max"; } ReduceMaxOpConverter() { op_type = "reduce_max"; }
}; };
class ReduceMinOpConverter : public ReduceOpConverter {
public:
ReduceMinOpConverter() { op_type = "reduce_min"; }
};
class ReduceProdOpConverter : public ReduceOpConverter {
public:
ReduceProdOpConverter() { op_type = "reduce_prod"; }
};
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -127,3 +139,5 @@ class ReduceMaxOpConverter : public ReduceOpConverter { ...@@ -127,3 +139,5 @@ class ReduceMaxOpConverter : public ReduceOpConverter {
REGISTER_TRT_OP_CONVERTER(reduce_sum, ReduceSumOpConverter); REGISTER_TRT_OP_CONVERTER(reduce_sum, ReduceSumOpConverter);
REGISTER_TRT_OP_CONVERTER(reduce_mean, ReduceMeanOpConverter); REGISTER_TRT_OP_CONVERTER(reduce_mean, ReduceMeanOpConverter);
REGISTER_TRT_OP_CONVERTER(reduce_max, ReduceMaxOpConverter); REGISTER_TRT_OP_CONVERTER(reduce_max, ReduceMaxOpConverter);
REGISTER_TRT_OP_CONVERTER(reduce_min, ReduceMinOpConverter);
REGISTER_TRT_OP_CONVERTER(reduce_prod, ReduceProdOpConverter);
...@@ -2088,7 +2088,8 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -2088,7 +2088,8 @@ struct SimpleOpTypeSetTeller : public Teller {
} }
if (op_type == "reduce_sum" || op_type == "reduce_mean" || if (op_type == "reduce_sum" || op_type == "reduce_mean" ||
op_type == "reduce_max") { op_type == "reduce_max" || op_type == "reduce_min" ||
op_type == "reduce_prod") {
if (!desc.HasAttr("dim", /*with_attr_var=*/false)) { if (!desc.HasAttr("dim", /*with_attr_var=*/false)) {
VLOG(3) << "Skip to convert into TRT while found Attribute('dim') is " VLOG(3) << "Skip to convert into TRT while found Attribute('dim') is "
"Variable type in " "Variable type in "
......
...@@ -68,10 +68,12 @@ class TrtConvertReduceTest(TrtLayerAutoScanTest): ...@@ -68,10 +68,12 @@ class TrtConvertReduceTest(TrtLayerAutoScanTest):
for out_dtype in [-1, 2, 5]: for out_dtype in [-1, 2, 5]:
for op_type in [ for op_type in [
"reduce_max", "reduce_max",
"reduce_min",
"reduce_mean", "reduce_mean",
"reduce_sum", "reduce_sum",
"reduce_prod",
]: ]:
dics = [ dics1 = [
{ {
"keep_dim": keep_dim, "keep_dim": keep_dim,
"dim": dim, "dim": dim,
...@@ -81,7 +83,14 @@ class TrtConvertReduceTest(TrtLayerAutoScanTest): ...@@ -81,7 +83,14 @@ class TrtConvertReduceTest(TrtLayerAutoScanTest):
}, },
{}, {},
] ]
dics2 = [
{
"out_dtype": out_dtype,
"in_dtype": out_dtype,
},
{},
]
for dics in [dics1, dics2]:
ops_config = [ ops_config = [
{ {
"op_type": op_type, "op_type": op_type,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册