未验证 提交 75b16781 编写于 作者: W weishengying 提交者: GitHub

Fix bugs in the General Plugin Mechanism (#47072)

上级 c7d2e82c
......@@ -288,14 +288,31 @@ bool GenericPlugin::supportsFormatCombination(
int nb_inputs,
int nb_outputs) TRT_NOEXCEPT {
if (op_desc_.Type() == "gather_nd" || op_desc_.Type() == "yolo_box") {
if (pos == 0) return in_out[pos].type == nvinfer1::DataType::kFLOAT;
if (pos == 1) return in_out[pos].type == nvinfer1::DataType::kINT32;
if (pos == 0)
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
if (pos == 1)
return (in_out[pos].type == nvinfer1::DataType::kINT32) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
if (pos == 2)
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
} else if (op_desc_.Type() == "scatter_nd_add") {
if (pos == 0) return in_out[pos].type == nvinfer1::DataType::kFLOAT;
if (pos == 1) return in_out[pos].type == nvinfer1::DataType::kINT32;
if (pos == 2) return in_out[pos].type == nvinfer1::DataType::kFLOAT;
if (pos == 0)
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
if (pos == 1)
return (in_out[pos].type == nvinfer1::DataType::kINT32) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
if (pos == 2)
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
if (pos == 3)
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
} else {
return in_out[pos].type == nvinfer1::DataType::kFLOAT;
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册