未验证 提交 d90db9bd 编写于 作者: W weishengying 提交者: GitHub

Fix the half precision problem of general plugin (#46580)

上级 f5956bec
...@@ -2326,7 +2326,10 @@ struct GenericPluginTeller : public Teller { ...@@ -2326,7 +2326,10 @@ struct GenericPluginTeller : public Teller {
if (!with_dynamic_shape) { if (!with_dynamic_shape) {
return false; return false;
} }
if (op_type == "yolo_box") {
if (!desc.HasAttr("iou_aware") && !desc.HasAttr("iou_aware_factor"))
return false;
}
if (use_no_calib_int8) { if (use_no_calib_int8) {
return false; return false;
} else { } else {
......
...@@ -287,7 +287,12 @@ bool GenericPlugin::supportsFormatCombination( ...@@ -287,7 +287,12 @@ bool GenericPlugin::supportsFormatCombination(
const nvinfer1::PluginTensorDesc* in_out, const nvinfer1::PluginTensorDesc* in_out,
int nb_inputs, int nb_inputs,
int nb_outputs) TRT_NOEXCEPT { int nb_outputs) TRT_NOEXCEPT {
return true; if (op_desc_.Type() == "gather_nd" || op_desc_.Type() == "yolo_box") {
if (pos == 0) return in_out[pos].type == nvinfer1::DataType::kFLOAT;
if (pos == 1) return in_out[pos].type == nvinfer1::DataType::kINT32;
} else {
return in_out[pos].type == nvinfer1::DataType::kFLOAT;
}
} }
nvinfer1::DataType GenericPlugin::getOutputDataType( nvinfer1::DataType GenericPlugin::getOutputDataType(
......
...@@ -144,17 +144,6 @@ class TrtConvertInstanceNormTest(TrtLayerAutoScanTest): ...@@ -144,17 +144,6 @@ class TrtConvertInstanceNormTest(TrtLayerAutoScanTest):
def add_skip_trt_case(self): def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if len(
self.dynamic_shape.min_input_shape
) != 0 and self.trt_param.precision == paddle_infer.PrecisionType.Half:
return True
return False
self.add_skip_case(
teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
"The output has diff between gpu and trt in dynamic fp16 mode.")
def teller2(program_config, predictor_config): def teller2(program_config, predictor_config):
if len(self.dynamic_shape.min_input_shape) != 0 and os.name == 'nt': if len(self.dynamic_shape.min_input_shape) != 0 and os.name == 'nt':
return True return True
......
...@@ -165,17 +165,6 @@ class TrtConvertYoloBoxTest(TrtLayerAutoScanTest): ...@@ -165,17 +165,6 @@ class TrtConvertYoloBoxTest(TrtLayerAutoScanTest):
def add_skip_trt_case(self): def add_skip_trt_case(self):
def teller1(program_config, predictor_config):
if len(
self.dynamic_shape.min_input_shape
) != 0 and self.trt_param.precision == paddle_infer.PrecisionType.Half:
return True
return False
self.add_skip_case(
teller1, SkipReasons.TRT_NOT_IMPLEMENTED,
"The output has diff between gpu and trt in dynamic fp16 mode.")
def teller2(program_config, predictor_config): def teller2(program_config, predictor_config):
if len(self.dynamic_shape.min_input_shape) != 0 and os.name == 'nt': if len(self.dynamic_shape.min_input_shape) != 0 and os.name == 'nt':
return True return True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册