未验证 提交 26cc5c54 编写于 作者: H heliqi 提交者: GitHub

fix onnxruntime bug (#42095) (#42104)

修复ORT在batch变动时,输出shape不对问题
上级 41003161
......@@ -693,10 +693,9 @@ void Tensor::ORTCopyToCpu(T *data) const {
if (place_ == PlaceType::kCPU) {
std::memcpy(static_cast<void *>(data), value.GetTensorData<void *>(), size);
} else {
paddle::memory::Copy(paddle::platform::CPUPlace(),
static_cast<void *>(data),
paddle::platform::CUDAPlace(device_),
value.GetTensorData<void>(), size, nullptr);
PADDLE_THROW(paddle::platform::errors::Unavailable(
"CopyToCpu error.The current ONNXRuntime backend doesn't support "
"GPU."));
}
}
......
......@@ -279,6 +279,12 @@ bool ONNXRuntimePredictor::Run(const std::vector<PaddleTensor> &inputs,
bool ONNXRuntimePredictor::ZeroCopyRun() {
try {
const char *device_name = place_ == PlaceType::kCPU ? "Cpu" : "Cuda";
for (auto output : output_desc_) {
Ort::MemoryInfo out_memory_info(device_name, OrtDeviceAllocator,
place_.GetDeviceId(), OrtMemTypeDefault);
binding_->BindOutput(output.name.c_str(), out_memory_info);
}
session_.Run({}, *(binding_.get()));
} catch (const std::exception &e) {
LOG(ERROR) << e.what();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册