未验证 提交 5d49e3e9 编写于 作者: W weishengying 提交者: GitHub

Enable Generic-Plugin support FP16 (#48807)

上级 cb7f736f
...@@ -300,27 +300,38 @@ bool GenericPlugin::supportsFormatCombination( ...@@ -300,27 +300,38 @@ bool GenericPlugin::supportsFormatCombination(
int nb_outputs) TRT_NOEXCEPT { int nb_outputs) TRT_NOEXCEPT {
if (op_desc_.Type() == "gather_nd" || op_desc_.Type() == "yolo_box") { if (op_desc_.Type() == "gather_nd" || op_desc_.Type() == "yolo_box") {
if (pos == 0) if (pos == 0)
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && return (in_out[pos].type == nvinfer1::DataType::kFLOAT ||
(isFp16Supported() &&
in_out[pos].type == nvinfer1::DataType::kHALF)) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
if (pos == 1) if (pos == 1)
return (in_out[pos].type == nvinfer1::DataType::kINT32) && return (in_out[pos].type == nvinfer1::DataType::kINT32) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
// output
if (pos == 2) if (pos == 2)
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && return in_out[0].type == in_out[pos].type &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); in_out[0].format == in_out[pos].format;
} else if (op_desc_.Type() == "scatter_nd_add") { } else if (op_desc_.Type() == "scatter_nd_add") {
// input X
if (pos == 0) if (pos == 0)
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && return (in_out[pos].type == nvinfer1::DataType::kFLOAT ||
(isFp16Supported() &&
in_out[pos].type == nvinfer1::DataType::kHALF)) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
// input Index
if (pos == 1) if (pos == 1)
return (in_out[pos].type == nvinfer1::DataType::kINT32) && return (in_out[pos].type == nvinfer1::DataType::kINT32) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
// input Updates
if (pos == 2) if (pos == 2)
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && return (in_out[pos].type == nvinfer1::DataType::kFLOAT ||
(isFp16Supported() &&
in_out[pos].type == nvinfer1::DataType::kHALF)) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
// output
if (pos == 3) if (pos == 3)
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && return in_out[0].type == in_out[pos].type &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); in_out[0].format == in_out[pos].format;
} else if (op_desc_.Type() == "pad3d") { } else if (op_desc_.Type() == "pad3d") {
return (in_out[pos].type == nvinfer1::DataType::kFLOAT || return (in_out[pos].type == nvinfer1::DataType::kFLOAT ||
(isFp16Supported() && (isFp16Supported() &&
...@@ -328,7 +339,9 @@ bool GenericPlugin::supportsFormatCombination( ...@@ -328,7 +339,9 @@ bool GenericPlugin::supportsFormatCombination(
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR) && (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR) &&
(in_out[0].type == in_out[pos].type); (in_out[0].type == in_out[pos].type);
} else { } else {
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) && return (in_out[pos].type == nvinfer1::DataType::kFLOAT ||
(isFp16Supported() &&
in_out[pos].type == nvinfer1::DataType::kHALF)) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR); (in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册