未验证 提交 c061c082 编写于 作者: Z Zhang Jun 提交者: GitHub

[inference][trt] bilinear support OutSize input (#47495)

* add bilinear OutSize
上级 05a4be36
......@@ -52,6 +52,7 @@ class BilinearInterpolateV2OpConverter : public OpConverter {
auto resize_inputs = op_desc.Inputs();
auto input_names = op_desc.Input("X");
auto out_h = PADDLE_GET_CONST(int, op_desc.GetAttr("out_h"));
auto out_w = PADDLE_GET_CONST(int, op_desc.GetAttr("out_w"));
......@@ -94,6 +95,15 @@ class BilinearInterpolateV2OpConverter : public OpConverter {
out_w = static_cast<int>(in_dim.d[w_axis] * scale_w);
}
// Priority: Input(OutSize) > attr(out_h/out_w) > attr(scale)
nvinfer1::ITensor* outsize_tensor = nullptr;
if (engine_->with_dynamic_shape() &&
resize_inputs.find("OutSize") != resize_inputs.end()) {
if (op_desc.Input("OutSize").size() >= 1) {
outsize_tensor = engine_->GetITensor(op_desc.Input("OutSize")[0]);
}
}
if (out_h > 0 && out_w > 0) {
scale_h =
static_cast<float>(out_h) / static_cast<float>(in_dim.d[h_axis]);
......@@ -102,11 +112,9 @@ class BilinearInterpolateV2OpConverter : public OpConverter {
}
std::vector<float> scales;
if (engine_->with_dynamic_shape()) {
scales.push_back(1.f);
}
if (data_layout == phi::DataLayout::kNCHW) {
scales.push_back(1.f);
scales.push_back(scale_h);
......@@ -115,12 +123,29 @@ class BilinearInterpolateV2OpConverter : public OpConverter {
scales.push_back(scale_h);
scales.push_back(scale_w);
scales.push_back(1.f);
}
if (engine_->with_dynamic_shape()) {
if (outsize_tensor != nullptr) {
std::vector<nvinfer1::ITensor*> outsize_itensors;
auto* input_shape = Shape(input);
outsize_itensors.push_back(GetEleTensorOfShape(input_shape, 0));
if (data_layout == phi::DataLayout::kNCHW) {
outsize_itensors.push_back(GetEleTensorOfShape(input_shape, 1));
outsize_itensors.push_back(outsize_tensor);
} else if (data_layout == phi::DataLayout::kNHWC) {
outsize_itensors.push_back(outsize_tensor);
outsize_itensors.push_back(GetEleTensorOfShape(input_shape, 3));
}
layer->setInput(1, *Concat(outsize_itensors));
} else {
layer->setScales(scales.data(), scales.size());
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Data layout must be NCHW or NHWC."));
layer->setScales(scales.data(), scales.size());
}
layer->setScales(scales.data(), scales.size());
RreplenishLayerAndOutput(
layer, "bilinear_interp_v2", {output_name}, test_mode);
}
......
......@@ -803,8 +803,8 @@ struct SimpleOpTypeSetTeller : public Teller {
}
if (resize_inputs.find("OutSize") != resize_inputs.end()) {
if (desc.Input("OutSize").size() >= 1) {
VLOG(3) << "The Paddle-TRT doesn't support the OutSize for op_type "
if (!with_dynamic_shape) {
VLOG(3) << "Static shape don't support the OutSize for op_type "
<< op_type;
return false;
}
......
......@@ -41,8 +41,8 @@ class TrtConvertBilinearInterpV2Test(TrtLayerAutoScanTest):
)
for data_layout in ["NCHW", "NHWC"]:
for scale_y in [2.0, -1.0, 0.0]:
for scale_x in [2.0, -1.0, 0.0]:
for scale_y in [2.0, 1.0]:
for scale_x in [2.0, 1.0]:
scale = [scale_y, scale_x]
for out_h in [32, 64, 128, 192]:
for out_w in [32, 64]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册