提交 b59c6845 编写于 作者: Z zlsh80826

fix scale op output layer

上级 6531712a
...@@ -59,14 +59,15 @@ class ScaleOpConverter : public OpConverter { ...@@ -59,14 +59,15 @@ class ScaleOpConverter : public OpConverter {
0}; 0};
nvinfer1::ILayer* layer = nullptr; nvinfer1::ILayer* layer = nullptr;
auto idim = input->getDimensions(); auto input_dim = input->getDimensions();
PADDLE_ENFORCE_GE(idim.nbDims, 3, PADDLE_ENFORCE_GE(input_dim.nbDims, 3,
platform::errors::Fatal( platform::errors::Fatal(
"Paddle-TRT scale mode only support dimension >= 3")); "Paddle-TRT scale mode only support dimension >= 3"));
nvinfer1::IShuffleLayer* expand_layer = nullptr; nvinfer1::IShuffleLayer* expand_layer = nullptr;
nvinfer1::IShuffleLayer* squeeze_layer = nullptr; nvinfer1::IShuffleLayer* squeeze_layer = nullptr;
if (idim.nbDims == 3) {
if (input_dim.nbDims == 3) {
// TensorRT scale layer is not supporting input dims < 4 when using // TensorRT scale layer is not supporting input dims < 4 when using
// explicit batch // explicit batch
expand_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); expand_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
...@@ -93,13 +94,14 @@ class ScaleOpConverter : public OpConverter { ...@@ -93,13 +94,14 @@ class ScaleOpConverter : public OpConverter {
PADDLE_ENFORCE_EQ(layer != nullptr, true, PADDLE_ENFORCE_EQ(layer != nullptr, true,
platform::errors::Fatal("Create scale layer failed.")); platform::errors::Fatal("Create scale layer failed."));
if (idim.nbDims == 3) { if (input_dim.nbDims == 3) {
// TensorRT scale layer is not supporting input dims < 4 when using // TensorRT scale layer is not supporting input dims < 4 when using
// explicit batch // explicit batch
squeeze_layer = squeeze_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(layer->getOutput(0))); TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(layer->getOutput(0)));
nvinfer1::Dims3 target_shape(0, 0, 0); // expand 1 dims nvinfer1::Dims3 target_shape(0, 0, 0); // expand 1 dims
squeeze_layer->setReshapeDimensions(target_shape); squeeze_layer->setReshapeDimensions(target_shape);
layer = static_cast<nvinfer1::ILayer*>(squeeze_layer);
} }
RreplenishLayerAndOutput(layer, "scale", {out_name}, test_mode); RreplenishLayerAndOutput(layer, "scale", {out_name}, test_mode);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册