From 35af5818afe85daf16c02fc4e79749236cce0fb8 Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Thu, 27 Apr 2023 11:01:52 +0800 Subject: [PATCH] refine SynchronizeAllDevice (#53370) --- paddle/fluid/platform/profiler_helper.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/paddle/fluid/platform/profiler_helper.h b/paddle/fluid/platform/profiler_helper.h index 0f058ad0f8d..1d34d5fd27b 100644 --- a/paddle/fluid/platform/profiler_helper.h +++ b/paddle/fluid/platform/profiler_helper.h @@ -95,28 +95,34 @@ std::vector> GetMemEvents() { void SynchronizeAllDevice() { #ifdef PADDLE_WITH_CUDA + int pre_device_id = GetCurrentDeviceId(); int count = GetGPUDeviceCount(); for (int i = 0; i < count; i++) { SetDeviceId(i); PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize()); } + SetDeviceId(pre_device_id); #endif #ifdef PADDLE_WITH_HIP + int pre_device_id = GetCurrentDeviceId(); int count = GetGPUDeviceCount(); for (int i = 0; i < count; i++) { SetDeviceId(i); PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize()); } + SetDeviceId(pre_device_id); #endif #ifdef PADDLE_WITH_CUSTOM_DEVICE auto dev_types = phi::DeviceManager::GetAllCustomDeviceTypes(); for (const auto &dev_type : dev_types) { + int pre_device_id = phi::DeviceManager::GetDevice(dev_type); auto dev_cnt = phi::DeviceManager::GetDeviceCount(dev_type); for (size_t i = 0; i < dev_cnt; i++) { auto place = paddle::platform::CustomPlace(dev_type, i); phi::DeviceManager::SetDevice(place); phi::DeviceManager::SynchronizeDevice(place); } + phi::DeviceManager::SetDevice(dev_type, pre_device_id); } #endif } -- GitLab