未验证 提交 5d49e3e9 编写于 作者: W weishengying 提交者: GitHub

Enable Generic-Plugin support FP16 (#48807)

上级 cb7f736f
......@@ -300,27 +300,38 @@ bool GenericPlugin::supportsFormatCombination(
int nb_outputs) TRT_NOEXCEPT {
if (op_desc_.Type() == "gather_nd" || op_desc_.Type() == "yolo_box") {
if (pos == 0)
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
return (in_out[pos].type == nvinfer1::DataType::kFLOAT ||
(isFp16Supported() &&
in_out[pos].type == nvinfer1::DataType::kHALF)) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
if (pos == 1)
return (in_out[pos].type == nvinfer1::DataType::kINT32) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
// output
if (pos == 2)
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
return in_out[0].type == in_out[pos].type &&
in_out[0].format == in_out[pos].format;
} else if (op_desc_.Type() == "scatter_nd_add") {
// input X
if (pos == 0)
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
return (in_out[pos].type == nvinfer1::DataType::kFLOAT ||
(isFp16Supported() &&
in_out[pos].type == nvinfer1::DataType::kHALF)) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
// input Index
if (pos == 1)
return (in_out[pos].type == nvinfer1::DataType::kINT32) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
// input Updates
if (pos == 2)
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
return (in_out[pos].type == nvinfer1::DataType::kFLOAT ||
(isFp16Supported() &&
in_out[pos].type == nvinfer1::DataType::kHALF)) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
// output
if (pos == 3)
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
return in_out[0].type == in_out[pos].type &&
in_out[0].format == in_out[pos].format;
} else if (op_desc_.Type() == "pad3d") {
return (in_out[pos].type == nvinfer1::DataType::kFLOAT ||
(isFp16Supported() &&
......@@ -328,7 +339,9 @@ bool GenericPlugin::supportsFormatCombination(
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR) &&
(in_out[0].type == in_out[pos].type);
} else {
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
return (in_out[pos].type == nvinfer1::DataType::kFLOAT ||
(isFp16Supported() &&
in_out[pos].type == nvinfer1::DataType::kHALF)) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册