未验证 提交 12406cad 编写于 作者: Z Zhang Jun 提交者: GitHub

[inference][trt] add reduce_all and reduce_any (#53088)

上级 3e7be9c9
...@@ -2615,6 +2615,8 @@ USE_TRT_CONVERTER(reduce_max); ...@@ -2615,6 +2615,8 @@ USE_TRT_CONVERTER(reduce_max);
USE_TRT_CONVERTER(reduce_min); USE_TRT_CONVERTER(reduce_min);
USE_TRT_CONVERTER(reduce_sum); USE_TRT_CONVERTER(reduce_sum);
USE_TRT_CONVERTER(reduce_prod); USE_TRT_CONVERTER(reduce_prod);
USE_TRT_CONVERTER(reduce_any);
USE_TRT_CONVERTER(reduce_all);
USE_TRT_CONVERTER(tile); USE_TRT_CONVERTER(tile);
USE_TRT_CONVERTER(conv3d); USE_TRT_CONVERTER(conv3d);
USE_TRT_CONVERTER(conv3d_transpose); USE_TRT_CONVERTER(conv3d_transpose);
......
...@@ -95,6 +95,8 @@ const std::unordered_map<std::string, std::vector<nvinfer1::ReduceOperation>> ...@@ -95,6 +95,8 @@ const std::unordered_map<std::string, std::vector<nvinfer1::ReduceOperation>>
{"reduce_max", {nvinfer1::ReduceOperation::kMAX}}, {"reduce_max", {nvinfer1::ReduceOperation::kMAX}},
{"reduce_min", {nvinfer1::ReduceOperation::kMIN}}, {"reduce_min", {nvinfer1::ReduceOperation::kMIN}},
{"reduce_prod", {nvinfer1::ReduceOperation::kPROD}}, {"reduce_prod", {nvinfer1::ReduceOperation::kPROD}},
{"reduce_any", {nvinfer1::ReduceOperation::kMAX}},
{"reduce_all", {nvinfer1::ReduceOperation::kMIN}},
}; };
class ReduceSumOpConverter : public ReduceOpConverter { class ReduceSumOpConverter : public ReduceOpConverter {
...@@ -122,6 +124,80 @@ class ReduceProdOpConverter : public ReduceOpConverter { ...@@ -122,6 +124,80 @@ class ReduceProdOpConverter : public ReduceOpConverter {
ReduceProdOpConverter() { op_type = "reduce_prod"; } ReduceProdOpConverter() { op_type = "reduce_prod"; }
}; };
// Converts paddle's reduce_any op to a TensorRT Reduce layer.
// TensorRT's Reduce layer does not operate on kBOOL tensors, so the plan is:
//   input(bool) -> Identity cast to INT32 -> Reduce -> Identity cast back.
// Over values restricted to {0, 1}, kMAX is logical OR (any) and kMIN is
// logical AND (all); ReduceAllOpConverter reuses this operator() and only
// changes op_type, which selects kMIN from the ops_ table.
class ReduceAnyOpConverter : public ReduceOpConverter {
 public:
  ReduceAnyOpConverter() { op_type = "reduce_any"; }
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope,
                  bool test_mode) override {
    VLOG(4) << "convert a paddle " << op_type << " op to tensorrt reduce layer";
    framework::OpDesc op_desc(op, nullptr);
    // Look up the ReduceOperation for this op (kMAX for reduce_any,
    // kMIN for reduce_all).
    auto reduce_type = ops_.find(op_type);
    auto* x = engine_->GetITensor(op_desc.Input("X").front());
    // Cast the boolean input to INT32 first; TensorRT cannot reduce kBOOL.
    nvinfer1::IReduceLayer* reduce_layer = nullptr;
    auto* cast_layer = TRT_ENGINE_ADD_LAYER(engine_, Identity, *x);
    cast_layer->setOutputType(0, nvinfer1::DataType::kINT32);
    cast_layer->getOutput(0)->setType(nvinfer1::DataType::kINT32);
    nvinfer1::Dims input_shape = x->getDimensions();
    int input_dims = input_shape.nbDims;
    // Read the reduction attributes from the op description.
    bool keep_dim = PADDLE_GET_CONST(bool, op_desc.GetAttr("keep_dim"));
    std::vector<int32_t> dim =
        PADDLE_GET_CONST(std::vector<int32_t>, op_desc.GetAttr("dim"));
    bool reduce_all = PADDLE_GET_CONST(bool, op_desc.GetAttr("reduce_all"));
    if (reduce_all) {
      // Reduce over every axis: set one bit per input dimension in the mask.
      uint32_t reduce_dim = 0;
      for (int i = 0; i < input_dims; ++i) {
        reduce_dim |= 1 << i;
      }
      reduce_layer = TRT_ENGINE_ADD_LAYER(engine_,
                                          Reduce,
                                          *cast_layer->getOutput(0),
                                          reduce_type->second.front(),
                                          reduce_dim,
                                          keep_dim);
    } else {
      // Translate the paddle "dim" attribute into TensorRT's axis bitmask.
      auto CvtToBitMask = [&](const std::vector<int32_t>& dims) -> uint32_t {
        uint32_t res = 0;
        for (auto x : dims) {
          if (x < 0) {
            // Negative axes count back from the last dimension.
            res |= 1 << (x + input_dims);
          } else {
            // In static-shape (implicit batch) mode TensorRT omits the batch
            // dimension, so paddle axis i maps to TRT axis i - 1.
            // NOTE(review): axis 0 would shift to -1 here and produce an
            // undefined shift; presumably the op teller rejects that case —
            // TODO confirm.
            if (!engine_->with_dynamic_shape()) x = x - 1;
            res |= 1 << x;
          }
        }
        return res;
      };
      reduce_layer = TRT_ENGINE_ADD_LAYER(engine_,
                                          Reduce,
                                          *cast_layer->getOutput(0),
                                          reduce_type->second.front(),
                                          CvtToBitMask(dim),
                                          keep_dim);
    }
    auto output_name = op_desc.Output("Out")[0];
    // Cast the INT32 reduction result back for the consumer of "Out".
    auto* layer =
        TRT_ENGINE_ADD_LAYER(engine_, Identity, *reduce_layer->getOutput(0));
    layer->setOutputType(0, nvinfer1::DataType::kBOOL);
    layer->getOutput(0)->setType(nvinfer1::DataType::kBOOL);
    // Ensure that the output type and input type are consistent: this setType
    // call overrides the kBOOL one above with the original input's type
    // (which the op teller restricts to bool for reduce_any/reduce_all).
    layer->getOutput(0)->setType(cast_layer->getInput(0)->getType());
    RreplenishLayerAndOutput(layer, op_type, {output_name}, test_mode);
  };
};
// Converter for paddle's reduce_all op. Inherits the full conversion logic
// from ReduceAnyOpConverter; only op_type differs, which selects
// ReduceOperation::kMIN (logical AND over {0, 1}) from the ops_ table
// instead of kMAX.
class ReduceAllOpConverter : public ReduceAnyOpConverter {
 public:
  ReduceAllOpConverter() { op_type = "reduce_all"; }
};
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
...@@ -131,3 +207,5 @@ REGISTER_TRT_OP_CONVERTER(reduce_mean, ReduceMeanOpConverter); ...@@ -131,3 +207,5 @@ REGISTER_TRT_OP_CONVERTER(reduce_mean, ReduceMeanOpConverter);
REGISTER_TRT_OP_CONVERTER(reduce_max, ReduceMaxOpConverter); REGISTER_TRT_OP_CONVERTER(reduce_max, ReduceMaxOpConverter);
REGISTER_TRT_OP_CONVERTER(reduce_min, ReduceMinOpConverter); REGISTER_TRT_OP_CONVERTER(reduce_min, ReduceMinOpConverter);
REGISTER_TRT_OP_CONVERTER(reduce_prod, ReduceProdOpConverter); REGISTER_TRT_OP_CONVERTER(reduce_prod, ReduceProdOpConverter);
REGISTER_TRT_OP_CONVERTER(reduce_any, ReduceAnyOpConverter);
REGISTER_TRT_OP_CONVERTER(reduce_all, ReduceAllOpConverter);
...@@ -2193,7 +2193,8 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -2193,7 +2193,8 @@ struct SimpleOpTypeSetTeller : public Teller {
if (op_type == "reduce_sum" || op_type == "reduce_mean" || if (op_type == "reduce_sum" || op_type == "reduce_mean" ||
op_type == "reduce_max" || op_type == "reduce_min" || op_type == "reduce_max" || op_type == "reduce_min" ||
op_type == "reduce_prod") { op_type == "reduce_prod" || op_type == "reduce_any" ||
op_type == "reduce_all") {
if (!desc.HasAttr("dim", /*with_attr_var=*/false)) { if (!desc.HasAttr("dim", /*with_attr_var=*/false)) {
VLOG(3) << "Skip to convert into TRT while found Attribute('dim') is " VLOG(3) << "Skip to convert into TRT while found Attribute('dim') is "
"Variable type in " "Variable type in "
...@@ -2234,14 +2235,28 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -2234,14 +2235,28 @@ struct SimpleOpTypeSetTeller : public Teller {
return false; return false;
} }
#if IS_TRT_VERSION_LT(7000)
auto dtype = x_var_desc->GetDataType(); auto dtype = x_var_desc->GetDataType();
if (dtype != framework::proto::VarType::FP32) { if (op_type == "reduce_all" || op_type == "reduce_any") {
VLOG(3) << "reduce op input data type must be float32 using TensorRT " if (dtype != framework::proto::VarType::BOOL) {
"< 7.0"; VLOG(3)
return false; << "reduce_all and reduce_any op input data type must be bool";
} return false;
}
} else {
#if IS_TRT_VERSION_GE(7000)
if (dtype != framework::proto::VarType::INT32 &&
dtype != framework::proto::VarType::FP32) {
VLOG(3) << "reduce op input data type must be int32 or float32";
return false;
}
#else
if (dtype != framework::proto::VarType::FP32) {
VLOG(3) << "reduce op input data type must be float32 using TensorRT "
"< 7.0";
return false;
}
#endif #endif
}
} }
#if IS_TRT_VERSION_GE(7000) #if IS_TRT_VERSION_GE(7000)
if (op_type == "tile") { if (op_type == "tile") {
...@@ -2804,8 +2819,12 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -2804,8 +2819,12 @@ struct SimpleOpTypeSetTeller : public Teller {
"nearest_interp", "nearest_interp",
"anchor_generator", "anchor_generator",
"reduce_max", "reduce_max",
"reduce_min",
"reduce_mean", "reduce_mean",
"reduce_sum", "reduce_sum",
"reduce_prod",
"reduce_any",
"reduce_all",
"conv3d", "conv3d",
"conv3d_transpose", "conv3d_transpose",
"mish", "mish",
...@@ -2961,8 +2980,12 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -2961,8 +2980,12 @@ struct SimpleOpTypeSetTeller : public Teller {
"nearest_interp", "nearest_interp",
"anchor_generator", "anchor_generator",
"reduce_max", "reduce_max",
"reduce_min",
"reduce_mean", "reduce_mean",
"reduce_sum", "reduce_sum",
"reduce_prod",
"reduce_any",
"reduce_all",
"conv3d", "conv3d",
"conv3d_transpose", "conv3d_transpose",
"mish", "mish",
......
...@@ -51,6 +51,8 @@ class TrtConvertReduceTest(TrtLayerAutoScanTest): ...@@ -51,6 +51,8 @@ class TrtConvertReduceTest(TrtLayerAutoScanTest):
return np.random.random([1, 3, 64, 64]).astype(np.float32) return np.random.random([1, 3, 64, 64]).astype(np.float32)
elif dtype == 2: elif dtype == 2:
return np.random.random([1, 3, 64, 64]).astype(np.int32) return np.random.random([1, 3, 64, 64]).astype(np.int32)
elif dtype == 0:
return np.random.random([1, 3, 64, 64]).astype(np.bool_)
for keep_dim in [True, False]: for keep_dim in [True, False]:
for dim in [ for dim in [
...@@ -65,15 +67,24 @@ class TrtConvertReduceTest(TrtLayerAutoScanTest): ...@@ -65,15 +67,24 @@ class TrtConvertReduceTest(TrtLayerAutoScanTest):
[3, 4, 5], [3, 4, 5],
]: ]:
for reduce_all in [True, False]: for reduce_all in [True, False]:
for out_dtype in [-1, 2, 5]: for out_dtype in [-1, 0, 2, 5]:
for op_type in [ if out_dtype != 0:
"reduce_max", reduce_type_list = [
"reduce_min", "reduce_max",
"reduce_mean", "reduce_min",
"reduce_sum", "reduce_mean",
"reduce_prod", "reduce_sum",
]: "reduce_prod",
dics1 = [ ]
else:
reduce_type_list = [
"reduce_all",
"reduce_any",
]
for op_type in reduce_type_list:
dics = [
{ {
"keep_dim": keep_dim, "keep_dim": keep_dim,
"dim": dim, "dim": dim,
...@@ -83,46 +94,40 @@ class TrtConvertReduceTest(TrtLayerAutoScanTest): ...@@ -83,46 +94,40 @@ class TrtConvertReduceTest(TrtLayerAutoScanTest):
}, },
{}, {},
] ]
dics2 = [
ops_config = [
{ {
"keep_dim": keep_dim, "op_type": op_type,
"dim": dim, "op_inputs": {"X": ["input_data"]},
"reduce_all": reduce_all, "op_outputs": {
"out_dtype": out_dtype, "Out": ["reduce_output_data"]
"in_dtype": out_dtype, },
}, "op_attrs": dics[0],
{}, }
] ]
for dics in [dics1, dics2]: if op_type in ["reduce_any", "reduce_all"]:
ops_config = [ ops_config[0]["outputs_dtype"] = {
{ "reduce_output_data": np.bool_
"op_type": op_type, }
"op_inputs": {"X": ["input_data"]}, ops = self.generate_op_config(ops_config)
"op_outputs": {
"Out": ["reduce_output_data"] program_config = ProgramConfig(
}, ops=ops,
"op_attrs": dics[0], weights={},
} inputs={
] "input_data": TensorConfig(
ops = self.generate_op_config(ops_config) data_gen=partial(
generate_input1, out_dtype, dics
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data": TensorConfig(
data_gen=partial(
generate_input1, out_dtype, dics
)
) )
}, )
outputs=["reduce_output_data"], },
) outputs=["reduce_output_data"],
)
if not self.is_program_valid(program_config): if not self.is_program_valid(program_config):
continue continue
yield program_config yield program_config
def sample_predictor_configs( def sample_predictor_configs(
self, program_config self, program_config
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册