未验证 提交 2a5bd7dc 编写于 作者: C chenjian 提交者: GitHub

optimize device synchronization in profiler (#46089)

* avoid synchronizing all devices

* synchronize custom device
上级 4e8ad06a
...@@ -32,11 +32,32 @@ ...@@ -32,11 +32,32 @@
#include "paddle/fluid/platform/profiler/mlu/mlu_tracer.h" #include "paddle/fluid/platform/profiler/mlu/mlu_tracer.h"
#include "paddle/fluid/platform/profiler/trace_event_collector.h" #include "paddle/fluid/platform/profiler/trace_event_collector.h"
#include "paddle/fluid/platform/profiler/utils.h" #include "paddle/fluid/platform/profiler/utils.h"
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
#endif
namespace paddle { namespace paddle {
namespace platform { namespace platform {
// Synchronizes the device for every accelerator backend this build was
// compiled with. Each backend section is guarded by its own PADDLE_WITH_*
// macro, so a build with none of them defined makes this a no-op.
// Profiler::Start() and Profiler::Stop() call this before touching the
// tracers.
//
// Takes no arguments and returns nothing. The status returned by each vendor
// call is passed to a PADDLE_ENFORCE_*_SUCCESS macro.
// NOTE(review): presumably those macros raise on a non-success status --
// confirm against their definitions.
// NOTE(review): the leading `void SynchronizeAllDevice();` on the next line
// looks like the pre-change declaration carried over from the left column of
// the diff view -- confirm whether it is still needed.
void SynchronizeAllDevice(); void SynchronizeDevice() {
#ifdef PADDLE_WITH_CUDA
// CUDA build: a single device-level synchronize call, with no loop over
// device ids.
PADDLE_ENFORCE_GPU_SUCCESS(cudaDeviceSynchronize());
#endif
#ifdef PADDLE_WITH_HIP
// HIP (ROCm) build: same single-call pattern as the CUDA branch.
PADDLE_ENFORCE_GPU_SUCCESS(hipDeviceSynchronize());
#endif
#ifdef PADDLE_WITH_MLU
// MLU build: synchronize through the CNRT runtime call.
PADDLE_ENFORCE_MLU_SUCCESS(cnrtSyncDevice());
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
// Custom-device build: visit every device type reported by DeviceManager and
// synchronize one device per type.
auto dev_types = phi::DeviceManager::GetAllCustomDeviceTypes();
for (const auto& dev_type : dev_types) {
// NOTE(review): GetDevice(dev_type) presumably returns the currently selected
// device id for this type -- confirm against phi::DeviceManager.
auto i = phi::DeviceManager::GetDevice(dev_type);
// Build the place (device type + id) that the synchronize call expects.
auto place = paddle::platform::CustomPlace(dev_type, i);
phi::DeviceManager::SynchronizeDevice(place);
}
#endif
}
std::atomic<bool> Profiler::alive_{false}; std::atomic<bool> Profiler::alive_{false};
...@@ -99,7 +120,7 @@ void Profiler::Prepare() { ...@@ -99,7 +120,7 @@ void Profiler::Prepare() {
} }
void Profiler::Start() { void Profiler::Start() {
SynchronizeAllDevice(); SynchronizeDevice();
for (auto& tracer : tracers_) { for (auto& tracer : tracers_) {
tracer.Get().StartTracing(); tracer.Get().StartTracing();
} }
...@@ -107,7 +128,7 @@ void Profiler::Start() { ...@@ -107,7 +128,7 @@ void Profiler::Start() {
} }
std::unique_ptr<ProfilerResult> Profiler::Stop() { std::unique_ptr<ProfilerResult> Profiler::Stop() {
SynchronizeAllDevice(); SynchronizeDevice();
TraceEventCollector collector; TraceEventCollector collector;
for (auto& tracer : tracers_) { for (auto& tracer : tracers_) {
tracer.Get().StopTracing(); tracer.Get().StopTracing();
......
...@@ -37,6 +37,8 @@ static constexpr uint32_t kProfileGPUOptionBit = 1; ...@@ -37,6 +37,8 @@ static constexpr uint32_t kProfileGPUOptionBit = 1;
// Option-bit indices for MLU and custom-device profiling.
// NOTE(review): these look like bit positions within
// ProfilerOptions::trace_switch (its comment lists "bit 2: mlu"); bit 3 for
// custom devices is not mentioned there -- confirm and update that comment.
static constexpr uint32_t kProfileMLUOptionBit = 2; static constexpr uint32_t kProfileMLUOptionBit = 2;
static constexpr uint32_t kProfileCustomDeviceOptionBit = 3; static constexpr uint32_t kProfileCustomDeviceOptionBit = 3;
// Declaration of the device-synchronization helper that Profiler::Start() and
// Profiler::Stop() invoke; the definition is guarded per backend by
// PADDLE_WITH_* macros.
void SynchronizeDevice();
struct ProfilerOptions { struct ProfilerOptions {
uint32_t trace_switch = 0; // bit 0: cpu, bit 1: gpu, bit 2: mlu uint32_t trace_switch = 0; // bit 0: cpu, bit 1: gpu, bit 2: mlu
uint32_t trace_level = FLAGS_host_trace_level; uint32_t trace_level = FLAGS_host_trace_level;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册