From b59c68455fcc4706d0b94cee84e0ff0d97f8fff2 Mon Sep 17 00:00:00 2001 From: zlsh80826 Date: Sat, 8 Aug 2020 13:06:37 +0800 Subject: [PATCH] fix scale op output layer --- paddle/fluid/inference/tensorrt/convert/scale_op.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/inference/tensorrt/convert/scale_op.cc b/paddle/fluid/inference/tensorrt/convert/scale_op.cc index 841451a2b31..f9a1fe41ddc 100644 --- a/paddle/fluid/inference/tensorrt/convert/scale_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/scale_op.cc @@ -59,14 +59,15 @@ class ScaleOpConverter : public OpConverter { 0}; nvinfer1::ILayer* layer = nullptr; - auto idim = input->getDimensions(); - PADDLE_ENFORCE_GE(idim.nbDims, 3, + auto input_dim = input->getDimensions(); + PADDLE_ENFORCE_GE(input_dim.nbDims, 3, platform::errors::Fatal( "Paddle-TRT scale mode only support dimension >= 3")); nvinfer1::IShuffleLayer* expand_layer = nullptr; nvinfer1::IShuffleLayer* squeeze_layer = nullptr; - if (idim.nbDims == 3) { + + if (input_dim.nbDims == 3) { // TensorRT scale layer is not supporting input dims < 4 when using // explicit batch expand_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); @@ -93,13 +94,14 @@ class ScaleOpConverter : public OpConverter { PADDLE_ENFORCE_EQ(layer != nullptr, true, platform::errors::Fatal("Create scale layer failed.")); - if (idim.nbDims == 3) { + if (input_dim.nbDims == 3) { // TensorRT scale layer is not supporting input dims < 4 when using // explicit batch squeeze_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(layer->getOutput(0))); nvinfer1::Dims3 target_shape(0, 0, 0); // expand 1 dims squeeze_layer->setReshapeDimensions(target_shape); + layer = static_cast(squeeze_layer); } RreplenishLayerAndOutput(layer, "scale", {out_name}, test_mode); } -- GitLab