From 55e9087f5d0d9cc85f01a6b17ddfe724968f2be6 Mon Sep 17 00:00:00 2001 From: wenbin Date: Mon, 17 Jan 2022 20:14:23 +0800 Subject: [PATCH] disable unsupported trt dimension (#38962) * develop test * throw * ne * wrong cnt --- paddle/fluid/inference/tensorrt/engine.h | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h index 663534feda1..849ec07d07e 100644 --- a/paddle/fluid/inference/tensorrt/engine.h +++ b/paddle/fluid/inference/tensorrt/engine.h @@ -128,7 +128,19 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector& shape, std::string input, dims.d[0] = shape[1]; return dims; } - return nvinfer1::Dims3(shape[1], 1, 1); + // static shape doesn't support 1D op so far. + PADDLE_ENFORCE_NE(shape.size(), 1UL, + platform::errors::InvalidArgument( + "The input [%s] shape of trt subgraph is %s." + "it's not supported by trt so far", + input, ShapeStr(shape))); + + nvinfer1::Dims dims; + dims.nbDims = shape.size() - 1; + for (size_t i = 1; i < shape.size(); i++) { + dims.d[i - 1] = shape[i]; + } + return dims; } else { if (shape.size() == 4UL) { return nvinfer1::Dims4(shape[0], shape[1], shape[2], shape[3]); -- GitLab