未验证 提交 55e9087f 编写于 作者: W wenbin 提交者: GitHub

disable unsupported trt dimension (#38962)

* develop test

* throw

* ne

* wrong cnt
上级 944ea436
......@@ -128,7 +128,19 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<T>& shape, std::string input,
dims.d[0] = shape[1];
return dims;
}
return nvinfer1::Dims3(shape[1], 1, 1);
// static shape doesn't support 1D op so far.
PADDLE_ENFORCE_NE(shape.size(), 1UL,
platform::errors::InvalidArgument(
"The input [%s] shape of trt subgraph is %s."
"it's not supported by trt so far",
input, ShapeStr(shape)));
nvinfer1::Dims dims;
dims.nbDims = shape.size() - 1;
for (size_t i = 1; i < shape.size(); i++) {
dims.d[i - 1] = shape[i];
}
return dims;
} else {
if (shape.size() == 4UL) {
return nvinfer1::Dims4(shape[0], shape[1], shape[2], shape[3]);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册