未验证 提交 01ccfbcd 编写于 作者: W Wilber 提交者: GitHub

update trt error message when input height or width is -1 (#31019)

上级 cf8b8f9c
......@@ -81,10 +81,35 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<T>& shape, std::string input,
"TensorRT's tensor input requires at most 4 "
"dimensions, but input %s has %d dims.",
input, shape.size()));
auto ShapeStr = [](const std::vector<T>& shape) {
std::ostringstream os;
os << "[";
for (size_t i = 0; i < shape.size(); ++i) {
if (i == shape.size() - 1) {
os << shape[i];
} else {
os << shape[i] << ",";
}
}
os << "]";
return os.str();
};
if (!with_dynamic_shape) {
if (shape.size() == 4UL) {
if (shape[2] == -1 || shape[3] == -1) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The input [%s] shape of trt subgraph is %s, please enable "
"trt dynamic_shape mode by SetTRTDynamicShapeInfo.",
input, ShapeStr(shape)));
}
return nvinfer1::DimsCHW(shape[1], shape[2], shape[3]);
} else if (shape.size() == 3UL) {
if (shape[1] == -1 || shape[2] == -1) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The input [%s] shape of trt subgraph is %s, please enable "
"trt dynamic_shape mode by SetTRTDynamicShapeInfo.",
input, ShapeStr(shape)));
}
return nvinfer1::Dims2(shape[1], shape[2]);
}
return nvinfer1::DimsCHW(shape[1], 1, 1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册