未验证 提交 bc150edc 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle Inference] Fix_nearest: align_corners != true (#37368)

* fix_nearest

* fix_nearest

* fix_nearest

* fix_nearest
上级 ccad31f5
...@@ -38,8 +38,10 @@ class NearestInterpolateOpConverter : public OpConverter { ...@@ -38,8 +38,10 @@ class NearestInterpolateOpConverter : public OpConverter {
auto input = engine_->GetITensor(input_name); auto input = engine_->GetITensor(input_name);
auto data_layout = framework::StringToDataLayout( auto data_layout = !op_desc.HasAttr("data_layout")
BOOST_GET_CONST(std::string, op_desc.GetAttr("data_layout"))); ? framework::DataLayout::kNCHW
: framework::StringToDataLayout(BOOST_GET_CONST(
std::string, op_desc.GetAttr("data_layout")));
auto interp_method = auto interp_method =
BOOST_GET_CONST(std::string, op_desc.GetAttr("interp_method")); BOOST_GET_CONST(std::string, op_desc.GetAttr("interp_method"));
bool align_corners = bool align_corners =
......
...@@ -667,28 +667,25 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -667,28 +667,25 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
} }
if (op_type == "nearest_interp") { if (op_type == "nearest_interp") {
std::vector<std::string> attrs{"data_layout", "interp_method", std::vector<std::string> attrs{"interp_method", "align_corners", "scale",
"align_corners", "scale",
"out_h", "out_w"}; "out_h", "out_w"};
for (auto const attr : attrs) { for (auto const attr : attrs) {
if (!desc.HasAttr(attr)) return false; if (!desc.HasAttr(attr)) return false;
} }
if (desc.HasAttr("data_layout")) {
auto data_layout = framework::StringToDataLayout( auto data_layout = framework::StringToDataLayout(
BOOST_GET_CONST(std::string, desc.GetAttr("data_layout"))); BOOST_GET_CONST(std::string, desc.GetAttr("data_layout")));
if (data_layout != framework::DataLayout::kNCHW && if (data_layout != framework::DataLayout::kNCHW &&
data_layout != framework::DataLayout::kNHWC) data_layout != framework::DataLayout::kNHWC)
return false; return false;
}
auto interp_method = auto interp_method =
BOOST_GET_CONST(std::string, desc.GetAttr("interp_method")); BOOST_GET_CONST(std::string, desc.GetAttr("interp_method"));
if (interp_method != "nearest") return false; if (interp_method != "nearest") return false;
if (!desc.HasAttr("scale") || !desc.HasAttr("out_h") ||
!desc.HasAttr("out_w")) {
return false;
} else {
auto scale = BOOST_GET_CONST(float, desc.GetAttr("scale")); auto scale = BOOST_GET_CONST(float, desc.GetAttr("scale"));
auto out_h = BOOST_GET_CONST(int, desc.GetAttr("out_h")); auto out_h = BOOST_GET_CONST(int, desc.GetAttr("out_h"));
auto out_w = BOOST_GET_CONST(int, desc.GetAttr("out_w")); auto out_w = BOOST_GET_CONST(int, desc.GetAttr("out_w"));
auto align_corners = BOOST_GET_CONST(bool, desc.GetAttr("align_corners"));
if (!(scale > 0.f && (out_h <= 0 && out_w <= 0))) { if (!(scale > 0.f && (out_h <= 0 && out_w <= 0))) {
if (out_h <= 0) { if (out_h <= 0) {
VLOG(3) << "out_h must be greater than 0 if scale is not set."; VLOG(3) << "out_h must be greater than 0 if scale is not set.";
...@@ -703,6 +700,10 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -703,6 +700,10 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
VLOG(3) << "dynamic shape not support scale not set."; VLOG(3) << "dynamic shape not support scale not set.";
return false; return false;
} }
// When align_corners = true, the paddle's and trt_layer's results has
// diff
if (align_corners && scale != 1) {
return false;
} }
} }
......
...@@ -124,6 +124,8 @@ class TrtConvertNearestInterpTest(TrtLayerAutoScanTest): ...@@ -124,6 +124,8 @@ class TrtConvertNearestInterpTest(TrtLayerAutoScanTest):
if program_config.ops[0].attrs[ if program_config.ops[0].attrs[
'scale'] <= 0 and self.dynamic_shape.min_input_shape: 'scale'] <= 0 and self.dynamic_shape.min_input_shape:
return True return True
if program_config.ops[0].attrs['align_corners'] == True:
return True
return False return False
self.add_skip_case( self.add_skip_case(
......
...@@ -61,9 +61,9 @@ class TRTNearestInterpTest(InferencePassTest): ...@@ -61,9 +61,9 @@ class TRTNearestInterpTest(InferencePassTest):
def set_params(self): def set_params(self):
self.bs = 4 self.bs = 4
self.scale = 1 self.scale = 0
self.channels = 3 self.channels = 3
self.origin_shape = (32, 32) # HW self.origin_shape = (4, 4) # HW
self.resize_shape = (64, 64) # HW self.resize_shape = (64, 64) # HW
self.align_corners = True self.align_corners = True
self.data_layout = 'NCHW' self.data_layout = 'NCHW'
...@@ -114,7 +114,7 @@ class TRTNearestInterpTest2(TRTNearestInterpTest): ...@@ -114,7 +114,7 @@ class TRTNearestInterpTest2(TRTNearestInterpTest):
class TRTNearestInterpTest3(TRTNearestInterpTest): class TRTNearestInterpTest3(TRTNearestInterpTest):
def set_params(self): def set_params(self):
self.bs = 4 self.bs = 4
self.scale = -1 self.scale = 0
self.channels = 3 self.channels = 3
self.origin_shape = (32, 32) # HW self.origin_shape = (32, 32) # HW
self.resize_shape = (64, 64) # HW self.resize_shape = (64, 64) # HW
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册