提交 b59c6845 编写于 作者: Z zlsh80826

fix scale op output layer

上级 6531712a
......@@ -59,14 +59,15 @@ class ScaleOpConverter : public OpConverter {
0};
nvinfer1::ILayer* layer = nullptr;
auto idim = input->getDimensions();
PADDLE_ENFORCE_GE(idim.nbDims, 3,
auto input_dim = input->getDimensions();
PADDLE_ENFORCE_GE(input_dim.nbDims, 3,
platform::errors::Fatal(
"Paddle-TRT scale mode only support dimension >= 3"));
nvinfer1::IShuffleLayer* expand_layer = nullptr;
nvinfer1::IShuffleLayer* squeeze_layer = nullptr;
if (idim.nbDims == 3) {
if (input_dim.nbDims == 3) {
// TensorRT scale layer is not supporting input dims < 4 when using
// explicit batch
expand_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
......@@ -93,13 +94,14 @@ class ScaleOpConverter : public OpConverter {
PADDLE_ENFORCE_EQ(layer != nullptr, true,
platform::errors::Fatal("Create scale layer failed."));
if (idim.nbDims == 3) {
if (input_dim.nbDims == 3) {
// TensorRT scale layer is not supporting input dims < 4 when using
// explicit batch
squeeze_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(layer->getOutput(0)));
nvinfer1::Dims3 target_shape(0, 0, 0); // expand 1 dims
squeeze_layer->setReshapeDimensions(target_shape);
layer = static_cast<nvinfer1::ILayer*>(squeeze_layer);
}
RreplenishLayerAndOutput(layer, "scale", {out_name}, test_mode);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册