提交 b3f79966 编写于 作者: M Megvii Engine Team

fix(mgb): fix "TRT_ERROR: INVALID_ARGUMENT: Get binding data type failed."

GitOrigin-RevId: d9601cb15b67f9dd3d73e6ad8d4069119783ae75
上级 99a85c40
......@@ -194,21 +194,25 @@ void TensorRTRuntimeOpr::init_output_dtype() {
idx++;
}
for (size_t i = 0; i < output().size(); ++i) {
size_t out = 0;
for (; out < output().size() - 1; ++out) {
dt_trt = get_dtype_from_trt(m_engine->getBindingDataType(idx));
mgb_assert(
dt_trt.valid(),
"output dtype checking failed: invalid dtype returned.");
if (dt_trt.enumv() == DTypeEnum::QuantizedS8) {
mgb_assert(
output(i)->dtype().valid(),
output(out)->dtype().valid(),
"user should specify scale of output tensor of "
"TensorRTRuntimeOpr.");
}
if (!output(i)->dtype().valid())
output(i)->dtype(dt_trt);
if (!output(out)->dtype().valid())
output(out)->dtype(dt_trt);
idx++;
}
//! workspace
if (!output(out)->dtype().valid())
output(out)->dtype(dtype::Byte());
}
SymbolVarArray TensorRTRuntimeOpr::make(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册