Unverified commit 12406cad, authored by Zhang Jun, committed by GitHub

[inference][trt] add reduce_all and reduce_any (#53088)

Parent 3e7be9c9
@@ -2615,6 +2615,8 @@ USE_TRT_CONVERTER(reduce_max);
 USE_TRT_CONVERTER(reduce_min);
 USE_TRT_CONVERTER(reduce_sum);
 USE_TRT_CONVERTER(reduce_prod);
+USE_TRT_CONVERTER(reduce_any);
+USE_TRT_CONVERTER(reduce_all);
 USE_TRT_CONVERTER(tile);
 USE_TRT_CONVERTER(conv3d);
 USE_TRT_CONVERTER(conv3d_transpose);
...
@@ -95,6 +95,8 @@ const std::unordered_map<std::string, std::vector<nvinfer1::ReduceOperation>>
     {"reduce_max", {nvinfer1::ReduceOperation::kMAX}},
     {"reduce_min", {nvinfer1::ReduceOperation::kMIN}},
     {"reduce_prod", {nvinfer1::ReduceOperation::kPROD}},
+    {"reduce_any", {nvinfer1::ReduceOperation::kMAX}},
+    {"reduce_all", {nvinfer1::ReduceOperation::kMIN}},
 };
 
 class ReduceSumOpConverter : public ReduceOpConverter {
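A note on the mapping above (annotation, not part of the patch): TensorRT's `IReduceLayer` has no boolean reduce operation, so the new converters reuse the numeric ones. With a bool tensor viewed as 0/1 integers, `kMAX` over the chosen axes is a logical OR (`reduce_any`) and `kMIN` is a logical AND (`reduce_all`). A minimal numpy sketch of that identity:

```python
import numpy as np

x = np.array([[True, False], [False, False]])
ints = x.astype(np.int32)  # bool viewed as 0/1 integers
assert bool(ints.max()) == bool(x.any())  # kMAX acts as logical OR
assert bool(ints.min()) == bool(x.all())  # kMIN acts as logical AND
```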
@@ -122,6 +124,80 @@ class ReduceProdOpConverter : public ReduceOpConverter {
   ReduceProdOpConverter() { op_type = "reduce_prod"; }
 };
 
+class ReduceAnyOpConverter : public ReduceOpConverter {
+ public:
+  ReduceAnyOpConverter() { op_type = "reduce_any"; }
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope,
+                  bool test_mode) override {
+    VLOG(4) << "convert a paddle " << op_type
+            << " op to tensorrt reduce layer";
+    framework::OpDesc op_desc(op, nullptr);
+    auto reduce_type = ops_.find(op_type);
+    auto* x = engine_->GetITensor(op_desc.Input("X").front());
+    // Cast the bool input to int32, since TensorRT reduce layers do not
+    // accept bool tensors.
+    nvinfer1::IReduceLayer* reduce_layer = nullptr;
+    auto* cast_layer = TRT_ENGINE_ADD_LAYER(engine_, Identity, *x);
+    cast_layer->setOutputType(0, nvinfer1::DataType::kINT32);
+    cast_layer->getOutput(0)->setType(nvinfer1::DataType::kINT32);
+    nvinfer1::Dims input_shape = x->getDimensions();
+    int input_dims = input_shape.nbDims;
+    bool keep_dim = PADDLE_GET_CONST(bool, op_desc.GetAttr("keep_dim"));
+    std::vector<int32_t> dim =
+        PADDLE_GET_CONST(std::vector<int32_t>, op_desc.GetAttr("dim"));
+    bool reduce_all = PADDLE_GET_CONST(bool, op_desc.GetAttr("reduce_all"));
+    if (reduce_all) {
+      // Reduce over every axis of the input.
+      uint32_t reduce_dim = 0;
+      for (int i = 0; i < input_dims; ++i) {
+        reduce_dim |= 1 << i;
+      }
+      reduce_layer = TRT_ENGINE_ADD_LAYER(engine_,
+                                          Reduce,
+                                          *cast_layer->getOutput(0),
+                                          reduce_type->second.front(),
+                                          reduce_dim,
+                                          keep_dim);
+    } else {
+      // Encode the "dim" attribute as the uint32 axis bitmask that
+      // TensorRT expects.
+      auto CvtToBitMask = [&](const std::vector<int32_t>& dims) -> uint32_t {
+        uint32_t res = 0;
+        for (auto x : dims) {
+          if (x < 0) {
+            res |= 1 << (x + input_dims);
+          } else {
+            if (!engine_->with_dynamic_shape()) x = x - 1;
+            res |= 1 << x;
+          }
+        }
+        return res;
+      };
+      reduce_layer = TRT_ENGINE_ADD_LAYER(engine_,
+                                          Reduce,
+                                          *cast_layer->getOutput(0),
+                                          reduce_type->second.front(),
+                                          CvtToBitMask(dim),
+                                          keep_dim);
+    }
+    auto output_name = op_desc.Output("Out")[0];
+    auto* layer =
+        TRT_ENGINE_ADD_LAYER(engine_, Identity, *reduce_layer->getOutput(0));
+    layer->setOutputType(0, nvinfer1::DataType::kBOOL);
+    layer->getOutput(0)->setType(nvinfer1::DataType::kBOOL);
+    // Ensure that the output type and input type are consistent.
+    layer->getOutput(0)->setType(cast_layer->getInput(0)->getType());
+    RreplenishLayerAndOutput(layer, op_type, {output_name}, test_mode);
+  }
+};
+
+class ReduceAllOpConverter : public ReduceAnyOpConverter {
+ public:
+  ReduceAllOpConverter() { op_type = "reduce_all"; }
+};
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
@@ -131,3 +207,5 @@ REGISTER_TRT_OP_CONVERTER(reduce_mean, ReduceMeanOpConverter);
 REGISTER_TRT_OP_CONVERTER(reduce_max, ReduceMaxOpConverter);
 REGISTER_TRT_OP_CONVERTER(reduce_min, ReduceMinOpConverter);
 REGISTER_TRT_OP_CONVERTER(reduce_prod, ReduceProdOpConverter);
+REGISTER_TRT_OP_CONVERTER(reduce_any, ReduceAnyOpConverter);
+REGISTER_TRT_OP_CONVERTER(reduce_all, ReduceAllOpConverter);
...
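The converter encodes the axes to reduce as the uint32 bitmask that TensorRT's reduce layer expects: negative Paddle axes wrap around, and in static-shape mode the implicit batch dimension is stripped, so positive axes shift down by one. A Python restatement of the `CvtToBitMask` lambda above (`axes_to_bitmask` is a hypothetical helper name, for illustration only):

```python
def axes_to_bitmask(dims, input_rank, dynamic_shape=True):
    """Sketch mirroring CvtToBitMask in the converter above."""
    mask = 0
    for axis in dims:
        if axis < 0:
            axis += input_rank  # negative axes count from the back
        elif not dynamic_shape:
            axis -= 1  # static shape: TRT's batch dim is implicit
        mask |= 1 << axis
    return mask

# dim=[1, -1] on a rank-4 tensor reduces axes 1 and 3 -> mask 0b1010
assert axes_to_bitmask([1, -1], input_rank=4) == 0b1010
```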
@@ -2193,7 +2193,8 @@ struct SimpleOpTypeSetTeller : public Teller {
     if (op_type == "reduce_sum" || op_type == "reduce_mean" ||
         op_type == "reduce_max" || op_type == "reduce_min" ||
-        op_type == "reduce_prod") {
+        op_type == "reduce_prod" || op_type == "reduce_any" ||
+        op_type == "reduce_all") {
       if (!desc.HasAttr("dim", /*with_attr_var=*/false)) {
         VLOG(3) << "Skip to convert into TRT while found Attribute('dim') is "
                    "Variable type in "
@@ -2234,8 +2235,21 @@ struct SimpleOpTypeSetTeller : public Teller {
         return false;
       }
-#if IS_TRT_VERSION_LT(7000)
       auto dtype = x_var_desc->GetDataType();
+      if (op_type == "reduce_all" || op_type == "reduce_any") {
+        if (dtype != framework::proto::VarType::BOOL) {
+          VLOG(3)
+              << "reduce_all and reduce_any op input data type must be bool";
+          return false;
+        }
+      } else {
+#if IS_TRT_VERSION_GE(7000)
+        if (dtype != framework::proto::VarType::INT32 &&
+            dtype != framework::proto::VarType::FP32) {
+          VLOG(3) << "reduce op input data type must be int32 or float32";
+          return false;
+        }
+#else
       if (dtype != framework::proto::VarType::FP32) {
         VLOG(3) << "reduce op input data type must be float32 using TensorRT "
                    "< 7.0";
@@ -2243,6 +2257,7 @@ struct SimpleOpTypeSetTeller : public Teller {
       }
 #endif
       }
+    }
 #if IS_TRT_VERSION_GE(7000)
     if (op_type == "tile") {
       // Paddle-TRT does not support the input tensors.
@@ -2804,8 +2819,12 @@ struct SimpleOpTypeSetTeller : public Teller {
       "nearest_interp",
       "anchor_generator",
       "reduce_max",
+      "reduce_min",
       "reduce_mean",
       "reduce_sum",
+      "reduce_prod",
+      "reduce_any",
+      "reduce_all",
       "conv3d",
       "conv3d_transpose",
       "mish",
@@ -2961,8 +2980,12 @@ struct SimpleOpTypeSetTeller : public Teller {
       "nearest_interp",
       "anchor_generator",
       "reduce_max",
+      "reduce_min",
       "reduce_mean",
       "reduce_sum",
+      "reduce_prod",
+      "reduce_any",
+      "reduce_all",
       "conv3d",
       "conv3d_transpose",
       "mish",
...
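For context, the teller changes admit `reduce_any`/`reduce_all` only when the input is bool and the `dim` attribute is a literal rather than a Variable. A usage sketch with the standard paddle API (assumed here, not taken from the patch); under static-graph export these calls lower to the reduce_any/reduce_all ops that the rules above now accept:

```python
import paddle

x = paddle.to_tensor([[True, False], [False, False]])
print(paddle.any(x, axis=-1, keepdim=True))  # [[True], [False]]
print(paddle.all(x, axis=-1))                # [False, False]
```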
@@ -51,6 +51,8 @@ class TrtConvertReduceTest(TrtLayerAutoScanTest):
                 return np.random.random([1, 3, 64, 64]).astype(np.float32)
             elif dtype == 2:
                 return np.random.random([1, 3, 64, 64]).astype(np.int32)
+            elif dtype == 0:
+                return np.random.random([1, 3, 64, 64]).astype(np.bool_)
 
         for keep_dim in [True, False]:
             for dim in [
@@ -65,25 +67,24 @@ class TrtConvertReduceTest(TrtLayerAutoScanTest):
                 [3, 4, 5],
             ]:
                 for reduce_all in [True, False]:
-                    for out_dtype in [-1, 2, 5]:
-                        for op_type in [
-                            "reduce_max",
-                            "reduce_min",
-                            "reduce_mean",
-                            "reduce_sum",
-                            "reduce_prod",
-                        ]:
-                            dics1 = [
-                                {
-                                    "keep_dim": keep_dim,
-                                    "dim": dim,
-                                    "reduce_all": reduce_all,
-                                    "out_dtype": out_dtype,
-                                    "in_dtype": out_dtype,
-                                },
-                                {},
-                            ]
-                            dics2 = [
+                    for out_dtype in [-1, 0, 2, 5]:
+                        if out_dtype != 0:
+                            reduce_type_list = [
+                                "reduce_max",
+                                "reduce_min",
+                                "reduce_mean",
+                                "reduce_sum",
+                                "reduce_prod",
+                            ]
+                        else:
+                            reduce_type_list = [
+                                "reduce_all",
+                                "reduce_any",
+                            ]
+                        for op_type in reduce_type_list:
+                            dics = [
                                 {
                                     "keep_dim": keep_dim,
                                     "dim": dim,
@@ -93,7 +94,7 @@ class TrtConvertReduceTest(TrtLayerAutoScanTest):
                                 },
                                 {},
                             ]
-                            for dics in [dics1, dics2]:
                             ops_config = [
                                 {
                                     "op_type": op_type,
@@ -104,6 +105,10 @@ class TrtConvertReduceTest(TrtLayerAutoScanTest):
                                     "op_attrs": dics[0],
                                 }
                             ]
+                            if op_type in ["reduce_any", "reduce_all"]:
+                                ops_config[0]["outputs_dtype"] = {
+                                    "reduce_output_data": np.bool_
+                                }
                             ops = self.generate_op_config(ops_config)
                             program_config = ProgramConfig(
...
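The integer `out_dtype` codes cycled through above follow paddle's `framework.proto` VarType enum (an assumption about the framework, not stated in the patch): -1 means unset, 0 is BOOL, 2 is INT32, and 5 is FP32, which is why `out_dtype == 0` selects the bool inputs and the reduce_all/reduce_any op list. A compact restatement of the test's input generator under that assumption:

```python
import numpy as np

# Assumed VarType codes: -1 = unset, 0 = BOOL, 2 = INT32, 5 = FP32.
NP_DTYPES = {-1: np.float32, 5: np.float32, 2: np.int32, 0: np.bool_}

def generate_input(dtype):
    # Mirrors the branches added to generate_input above.
    return np.random.random([1, 3, 64, 64]).astype(NP_DTYPES[dtype])
```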