未验证 提交 26cc5c54 编写于 作者: H heliqi 提交者: GitHub

fix onnxruntime bug (#42095) (#42104)

修复 ONNX Runtime(ORT)在 batch 大小变动时输出 shape 不正确的问题
上级 41003161
...@@ -693,10 +693,9 @@ void Tensor::ORTCopyToCpu(T *data) const { ...@@ -693,10 +693,9 @@ void Tensor::ORTCopyToCpu(T *data) const {
if (place_ == PlaceType::kCPU) { if (place_ == PlaceType::kCPU) {
std::memcpy(static_cast<void *>(data), value.GetTensorData<void *>(), size); std::memcpy(static_cast<void *>(data), value.GetTensorData<void *>(), size);
} else { } else {
paddle::memory::Copy(paddle::platform::CPUPlace(), PADDLE_THROW(paddle::platform::errors::Unavailable(
static_cast<void *>(data), "CopyToCpu error.The current ONNXRuntime backend doesn't support "
paddle::platform::CUDAPlace(device_), "GPU."));
value.GetTensorData<void>(), size, nullptr);
} }
} }
......
...@@ -279,6 +279,12 @@ bool ONNXRuntimePredictor::Run(const std::vector<PaddleTensor> &inputs, ...@@ -279,6 +279,12 @@ bool ONNXRuntimePredictor::Run(const std::vector<PaddleTensor> &inputs,
bool ONNXRuntimePredictor::ZeroCopyRun() { bool ONNXRuntimePredictor::ZeroCopyRun() {
try { try {
const char *device_name = place_ == PlaceType::kCPU ? "Cpu" : "Cuda";
for (auto output : output_desc_) {
Ort::MemoryInfo out_memory_info(device_name, OrtDeviceAllocator,
place_.GetDeviceId(), OrtMemTypeDefault);
binding_->BindOutput(output.name.c_str(), out_memory_info);
}
session_.Run({}, *(binding_.get())); session_.Run({}, *(binding_.get()));
} catch (const std::exception &e) { } catch (const std::exception &e) {
LOG(ERROR) << e.what(); LOG(ERROR) << e.what();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册