未验证 提交 c061c082 编写于 作者: Z Zhang Jun 提交者: GitHub

[inference][trt] bilinear support OutSize input (#47495)

* add bilinear OutSize
上级 05a4be36
...@@ -52,6 +52,7 @@ class BilinearInterpolateV2OpConverter : public OpConverter { ...@@ -52,6 +52,7 @@ class BilinearInterpolateV2OpConverter : public OpConverter {
auto resize_inputs = op_desc.Inputs(); auto resize_inputs = op_desc.Inputs();
auto input_names = op_desc.Input("X"); auto input_names = op_desc.Input("X");
auto out_h = PADDLE_GET_CONST(int, op_desc.GetAttr("out_h")); auto out_h = PADDLE_GET_CONST(int, op_desc.GetAttr("out_h"));
auto out_w = PADDLE_GET_CONST(int, op_desc.GetAttr("out_w")); auto out_w = PADDLE_GET_CONST(int, op_desc.GetAttr("out_w"));
...@@ -94,6 +95,15 @@ class BilinearInterpolateV2OpConverter : public OpConverter { ...@@ -94,6 +95,15 @@ class BilinearInterpolateV2OpConverter : public OpConverter {
out_w = static_cast<int>(in_dim.d[w_axis] * scale_w); out_w = static_cast<int>(in_dim.d[w_axis] * scale_w);
} }
// Priority: Input(OutSize) > attr(out_h/out_w) > attr(scale)
nvinfer1::ITensor* outsize_tensor = nullptr;
if (engine_->with_dynamic_shape() &&
resize_inputs.find("OutSize") != resize_inputs.end()) {
if (op_desc.Input("OutSize").size() >= 1) {
outsize_tensor = engine_->GetITensor(op_desc.Input("OutSize")[0]);
}
}
if (out_h > 0 && out_w > 0) { if (out_h > 0 && out_w > 0) {
scale_h = scale_h =
static_cast<float>(out_h) / static_cast<float>(in_dim.d[h_axis]); static_cast<float>(out_h) / static_cast<float>(in_dim.d[h_axis]);
...@@ -102,11 +112,9 @@ class BilinearInterpolateV2OpConverter : public OpConverter { ...@@ -102,11 +112,9 @@ class BilinearInterpolateV2OpConverter : public OpConverter {
} }
std::vector<float> scales; std::vector<float> scales;
if (engine_->with_dynamic_shape()) { if (engine_->with_dynamic_shape()) {
scales.push_back(1.f); scales.push_back(1.f);
} }
if (data_layout == phi::DataLayout::kNCHW) { if (data_layout == phi::DataLayout::kNCHW) {
scales.push_back(1.f); scales.push_back(1.f);
scales.push_back(scale_h); scales.push_back(scale_h);
...@@ -115,12 +123,29 @@ class BilinearInterpolateV2OpConverter : public OpConverter { ...@@ -115,12 +123,29 @@ class BilinearInterpolateV2OpConverter : public OpConverter {
scales.push_back(scale_h); scales.push_back(scale_h);
scales.push_back(scale_w); scales.push_back(scale_w);
scales.push_back(1.f); scales.push_back(1.f);
}
if (engine_->with_dynamic_shape()) {
if (outsize_tensor != nullptr) {
std::vector<nvinfer1::ITensor*> outsize_itensors;
auto* input_shape = Shape(input);
outsize_itensors.push_back(GetEleTensorOfShape(input_shape, 0));
if (data_layout == phi::DataLayout::kNCHW) {
outsize_itensors.push_back(GetEleTensorOfShape(input_shape, 1));
outsize_itensors.push_back(outsize_tensor);
} else if (data_layout == phi::DataLayout::kNHWC) {
outsize_itensors.push_back(outsize_tensor);
outsize_itensors.push_back(GetEleTensorOfShape(input_shape, 3));
}
layer->setInput(1, *Concat(outsize_itensors));
} else {
layer->setScales(scales.data(), scales.size());
}
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( layer->setScales(scales.data(), scales.size());
"Data layout must be NCHW or NHWC."));
} }
layer->setScales(scales.data(), scales.size());
RreplenishLayerAndOutput( RreplenishLayerAndOutput(
layer, "bilinear_interp_v2", {output_name}, test_mode); layer, "bilinear_interp_v2", {output_name}, test_mode);
} }
......
...@@ -803,8 +803,8 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -803,8 +803,8 @@ struct SimpleOpTypeSetTeller : public Teller {
} }
if (resize_inputs.find("OutSize") != resize_inputs.end()) { if (resize_inputs.find("OutSize") != resize_inputs.end()) {
if (desc.Input("OutSize").size() >= 1) { if (!with_dynamic_shape) {
VLOG(3) << "The Paddle-TRT doesn't support the OutSize for op_type " VLOG(3) << "Static shape don't support the OutSize for op_type "
<< op_type; << op_type;
return false; return false;
} }
......
...@@ -41,8 +41,8 @@ class TrtConvertBilinearInterpV2Test(TrtLayerAutoScanTest): ...@@ -41,8 +41,8 @@ class TrtConvertBilinearInterpV2Test(TrtLayerAutoScanTest):
) )
for data_layout in ["NCHW", "NHWC"]: for data_layout in ["NCHW", "NHWC"]:
for scale_y in [2.0, -1.0, 0.0]: for scale_y in [2.0, 1.0]:
for scale_x in [2.0, -1.0, 0.0]: for scale_x in [2.0, 1.0]:
scale = [scale_y, scale_x] scale = [scale_y, scale_x]
for out_h in [32, 64, 128, 192]: for out_h in [32, 64, 128, 192]:
for out_w in [32, 64]: for out_w in [32, 64]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册