未验证 提交 bc150edc 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle Inference] Fix_nearest: align_corners != true (#37368)

* fix_nearest

* fix_nearest

* fix_nearest

* fix_nearest
上级 ccad31f5
......@@ -38,8 +38,10 @@ class NearestInterpolateOpConverter : public OpConverter {
auto input = engine_->GetITensor(input_name);
auto data_layout = framework::StringToDataLayout(
BOOST_GET_CONST(std::string, op_desc.GetAttr("data_layout")));
auto data_layout = !op_desc.HasAttr("data_layout")
? framework::DataLayout::kNCHW
: framework::StringToDataLayout(BOOST_GET_CONST(
std::string, op_desc.GetAttr("data_layout")));
auto interp_method =
BOOST_GET_CONST(std::string, op_desc.GetAttr("interp_method"));
bool align_corners =
......
......@@ -667,43 +667,44 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
if (op_type == "nearest_interp") {
std::vector<std::string> attrs{"data_layout", "interp_method",
"align_corners", "scale",
"out_h", "out_w"};
std::vector<std::string> attrs{"interp_method", "align_corners", "scale",
"out_h", "out_w"};
for (auto const attr : attrs) {
if (!desc.HasAttr(attr)) return false;
}
auto data_layout = framework::StringToDataLayout(
BOOST_GET_CONST(std::string, desc.GetAttr("data_layout")));
if (data_layout != framework::DataLayout::kNCHW &&
data_layout != framework::DataLayout::kNHWC)
return false;
if (desc.HasAttr("data_layout")) {
auto data_layout = framework::StringToDataLayout(
BOOST_GET_CONST(std::string, desc.GetAttr("data_layout")));
if (data_layout != framework::DataLayout::kNCHW &&
data_layout != framework::DataLayout::kNHWC)
return false;
}
auto interp_method =
BOOST_GET_CONST(std::string, desc.GetAttr("interp_method"));
if (interp_method != "nearest") return false;
if (!desc.HasAttr("scale") || !desc.HasAttr("out_h") ||
!desc.HasAttr("out_w")) {
return false;
} else {
auto scale = BOOST_GET_CONST(float, desc.GetAttr("scale"));
auto out_h = BOOST_GET_CONST(int, desc.GetAttr("out_h"));
auto out_w = BOOST_GET_CONST(int, desc.GetAttr("out_w"));
if (!(scale > 0.f && (out_h <= 0 && out_w <= 0))) {
if (out_h <= 0) {
VLOG(3) << "out_h must be greater than 0 if scale is not set.";
return false;
}
if (out_w <= 0) {
VLOG(3) << "out_w must be greater than 0 if scale is not set.";
return false;
}
auto scale = BOOST_GET_CONST(float, desc.GetAttr("scale"));
auto out_h = BOOST_GET_CONST(int, desc.GetAttr("out_h"));
auto out_w = BOOST_GET_CONST(int, desc.GetAttr("out_w"));
auto align_corners = BOOST_GET_CONST(bool, desc.GetAttr("align_corners"));
if (!(scale > 0.f && (out_h <= 0 && out_w <= 0))) {
if (out_h <= 0) {
VLOG(3) << "out_h must be greater than 0 if scale is not set.";
return false;
}
if ((scale <= 0.f) && with_dynamic_shape) {
VLOG(3) << "dynamic shape not support scale not set.";
if (out_w <= 0) {
VLOG(3) << "out_w must be greater than 0 if scale is not set.";
return false;
}
}
if ((scale <= 0.f) && with_dynamic_shape) {
VLOG(3) << "dynamic shape not support scale not set.";
return false;
}
// When align_corners = true, the paddle's and trt_layer's results has
// diff
if (align_corners && scale != 1) {
return false;
}
}
if (op_type == "nearest_interp_v2") {
......
......@@ -124,6 +124,8 @@ class TrtConvertNearestInterpTest(TrtLayerAutoScanTest):
if program_config.ops[0].attrs[
'scale'] <= 0 and self.dynamic_shape.min_input_shape:
return True
if program_config.ops[0].attrs['align_corners'] == True:
return True
return False
self.add_skip_case(
......
......@@ -61,9 +61,9 @@ class TRTNearestInterpTest(InferencePassTest):
def set_params(self):
self.bs = 4
self.scale = 1
self.scale = 0
self.channels = 3
self.origin_shape = (32, 32) # HW
self.origin_shape = (4, 4) # HW
self.resize_shape = (64, 64) # HW
self.align_corners = True
self.data_layout = 'NCHW'
......@@ -114,7 +114,7 @@ class TRTNearestInterpTest2(TRTNearestInterpTest):
class TRTNearestInterpTest3(TRTNearestInterpTest):
def set_params(self):
self.bs = 4
self.scale = -1
self.scale = 0
self.channels = 3
self.origin_shape = (32, 32) # HW
self.resize_shape = (64, 64) # HW
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册