提交 b3f79966 编写于 作者: M Megvii Engine Team

fix(mgb): fix "TRT_ERROR: INVALID_ARGUMENT: Get binding data type failed."

GitOrigin-RevId: d9601cb15b67f9dd3d73e6ad8d4069119783ae75
上级 99a85c40
...@@ -194,21 +194,25 @@ void TensorRTRuntimeOpr::init_output_dtype() { ...@@ -194,21 +194,25 @@ void TensorRTRuntimeOpr::init_output_dtype() {
idx++; idx++;
} }
for (size_t i = 0; i < output().size(); ++i) { size_t out = 0;
for (; out < output().size() - 1; ++out) {
dt_trt = get_dtype_from_trt(m_engine->getBindingDataType(idx)); dt_trt = get_dtype_from_trt(m_engine->getBindingDataType(idx));
mgb_assert( mgb_assert(
dt_trt.valid(), dt_trt.valid(),
"output dtype checking failed: invalid dtype returned."); "output dtype checking failed: invalid dtype returned.");
if (dt_trt.enumv() == DTypeEnum::QuantizedS8) { if (dt_trt.enumv() == DTypeEnum::QuantizedS8) {
mgb_assert( mgb_assert(
output(i)->dtype().valid(), output(out)->dtype().valid(),
"user should specify scale of output tensor of " "user should specify scale of output tensor of "
"TensorRTRuntimeOpr."); "TensorRTRuntimeOpr.");
} }
if (!output(i)->dtype().valid()) if (!output(out)->dtype().valid())
output(i)->dtype(dt_trt); output(out)->dtype(dt_trt);
idx++; idx++;
} }
//! workspace
if (!output(out)->dtype().valid())
output(out)->dtype(dtype::Byte());
} }
SymbolVarArray TensorRTRuntimeOpr::make( SymbolVarArray TensorRTRuntimeOpr::make(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册