From f1b1c7535424e07e5532b469ecbc248f0730321e Mon Sep 17 00:00:00 2001 From: ReeseWang Date: Sun, 19 Jul 2020 22:54:30 +0800 Subject: [PATCH] add dim=3 support for scale op --- .../inference/tensorrt/convert/scale_op.cc | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/paddle/fluid/inference/tensorrt/convert/scale_op.cc b/paddle/fluid/inference/tensorrt/convert/scale_op.cc index 19e1895635a..abf41cd7259 100644 --- a/paddle/fluid/inference/tensorrt/convert/scale_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/scale_op.cc @@ -58,6 +58,23 @@ class ScaleOpConverter : public OpConverter { TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr, 0}; nvinfer1::ILayer* layer = nullptr; + + auto idim = input->getDimensions(); + PADDLE_ENFORCE_GE(idim.nbDims, 3, + platform::errors::Fatal( + "Paddle-TRT scale mode only support dimension >= 3")); + + nvinfer1::IShuffleLayer* expand_layer = nullptr; + nvinfer1::IShuffleLayer* squeeze_layer = nullptr; + if (idim.nbDims == 3) { + // TensorRT scale layer is not supporting input dims < 4 when using + // explicit batch + expand_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); + nvinfer1::Dims4 target_shape(0, 0, 0, 1); // expand 1 dims + expand_layer->setReshapeDimensions(target_shape); + input = expand_layer->getOutput(0); + } + if (bias_after_scale) { layer = TRT_ENGINE_ADD_LAYER( engine_, Scale, *input, nvinfer1::ScaleMode::kUNIFORM, @@ -73,6 +90,16 @@ class ScaleOpConverter : public OpConverter { power_weights.get(), scale_weights.get(), power_weights.get()); } + PADDLE_ENFORCE_EQ(layer != nullptr, true); + + if (idim.nbDims == 3) { + // TensorRT scale layer is not supporting input dims < 4 when using + // explicit batch + squeeze_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(layer->getOutput(0))); + nvinfer1::Dims3 target_shape(0, 0, 0); // expand 1 dims + squeeze_layer->setReshapeDimensions(target_shape); + } RreplenishLayerAndOutput(layer, "scale", {out_name}, test_mode); } }; -- GitLab