提交 f1b1c753 编写于 作者: R ReeseWang

add dim=3 support for scale op

上级 bd0f1abd
......@@ -58,6 +58,23 @@ class ScaleOpConverter : public OpConverter {
TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr,
0};
nvinfer1::ILayer* layer = nullptr;
auto idim = input->getDimensions();
PADDLE_ENFORCE_GE(idim.nbDims, 3,
platform::errors::Fatal(
"Paddle-TRT scale mode only support dimension >= 3"));
nvinfer1::IShuffleLayer* expand_layer = nullptr;
nvinfer1::IShuffleLayer* squeeze_layer = nullptr;
if (idim.nbDims == 3) {
// TensorRT scale layer is not supporting input dims < 4 when using
// explicit batch
expand_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
nvinfer1::Dims4 target_shape(0, 0, 0, 1); // expand 1 dims
expand_layer->setReshapeDimensions(target_shape);
input = expand_layer->getOutput(0);
}
if (bias_after_scale) {
layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *input, nvinfer1::ScaleMode::kUNIFORM,
......@@ -73,6 +90,16 @@ class ScaleOpConverter : public OpConverter {
power_weights.get(), scale_weights.get(), power_weights.get());
}
PADDLE_ENFORCE_EQ(layer != nullptr, true);
if (idim.nbDims == 3) {
// TensorRT scale layer is not supporting input dims < 4 when using
// explicit batch
squeeze_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(layer->getOutput(0)));
nvinfer1::Dims3 target_shape(0, 0, 0); // expand 1 dims
squeeze_layer->setReshapeDimensions(target_shape);
}
RreplenishLayerAndOutput(layer, "scale", {out_name}, test_mode);
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册