From 7235fd662b5af2f5999beb266025320e1ebd30ec Mon Sep 17 00:00:00 2001 From: chengduo Date: Fri, 1 Mar 2019 05:41:39 -0600 Subject: [PATCH] Add Event for TensorCopy (#15953) Add Event for TensorCopy --- paddle/fluid/framework/CMakeLists.txt | 4 +- paddle/fluid/framework/tensor_util.cc | 7 +++ paddle/fluid/memory/CMakeLists.txt | 2 +- paddle/fluid/memory/memcpy.cc | 20 ++++++ .../fluid/operators/reader/buffered_reader.cc | 23 ++++--- paddle/fluid/platform/device_tracer.cc | 63 ++++++++++++++++--- paddle/fluid/platform/device_tracer.h | 13 +++- tools/timeline.py | 2 +- 8 files changed, 111 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 7ddf1ab44f..b9491c953f 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -38,10 +38,10 @@ if(WITH_GPU) nv_library(tensor SRCS tensor.cc .tensor_util.cu DEPS place memory data_type device_context) add_dependencies(tensor tensor_util) else() - nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS place memory data_type device_context ) + nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS place memory data_type device_context profiler) endif(WIN32) else() - cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS place memory data_type device_context ) + cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS place memory data_type device_context profiler) endif() cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) diff --git a/paddle/fluid/framework/tensor_util.cc b/paddle/fluid/framework/tensor_util.cc index 85d15c5d3f..a7f09df491 100644 --- a/paddle/fluid/framework/tensor_util.cc +++ b/paddle/fluid/framework/tensor_util.cc @@ -14,8 +14,11 @@ #include "paddle/fluid/framework/tensor_util.h" #include #include +#include +#include #include #include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/platform/profiler.h" namespace paddle { namespace framework { @@ -135,16 +138,19 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, #ifdef PADDLE_WITH_CUDA else if (platform::is_gpu_place(src_place) && // NOLINT platform::is_cpu_place(dst_place)) { + platform::RecordEvent record_event("TensorCopy:GPU->CPU"); auto src_gpu_place = boost::get(src_place); auto dst_cpu_place = boost::get(dst_place); memory::Copy(dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size, nullptr); } else if (platform::is_cpu_place(src_place) && platform::is_gpu_place(dst_place)) { + platform::RecordEvent record_event("TensorCopy:CPU->GPU"); auto src_cpu_place = boost::get(src_place); auto dst_gpu_place = boost::get(dst_place); memory::Copy(dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, nullptr); } else if (platform::is_gpu_place(src_place) && platform::is_gpu_place(dst_place)) { + platform::RecordEvent record_event("TensorCopy:GPU->GPU"); if (src_ptr == dst_ptr && platform::is_same_place(src_place, dst_place)) { VLOG(3) << "Skip copy the same data from " << src_place << " to " << dst_place; @@ -155,6 +161,7 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place, memory::Copy(dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, nullptr); } else if (platform::is_cuda_pinned_place(src_place) && platform::is_gpu_place(dst_place)) { + platform::RecordEvent record_event("TensorCopy:CUDAPinned->GPU"); auto src_pinned_place = boost::get(src_place); auto dst_gpu_place = boost::get(dst_place); memory::Copy(dst_gpu_place, dst_ptr, src_pinned_place, src_ptr, size, diff --git a/paddle/fluid/memory/CMakeLists.txt b/paddle/fluid/memory/CMakeLists.txt index e726807764..7eb663ea28 100644 --- a/paddle/fluid/memory/CMakeLists.txt +++ b/paddle/fluid/memory/CMakeLists.txt @@ -1,6 +1,6 @@ add_subdirectory(detail) add_subdirectory(allocation) -cc_library(malloc SRCS malloc.cc DEPS place enforce allocator_facade) +cc_library(malloc SRCS malloc.cc DEPS place enforce allocator_facade profiler) cc_library(memcpy SRCS memcpy.cc DEPS place) cc_library(memory diff --git a/paddle/fluid/memory/memcpy.cc b/paddle/fluid/memory/memcpy.cc index 2a6f70a01e..1408163e4b 100644 --- a/paddle/fluid/memory/memcpy.cc +++ b/paddle/fluid/memory/memcpy.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/fluid/memory/memcpy.h" #include // for memcpy +#include "paddle/fluid/platform/profiler.h" namespace paddle { namespace memory { @@ -29,14 +30,23 @@ void Copy(platform::CPUPlace, void* dst, #ifdef PADDLE_WITH_CUDA static constexpr size_t kMaxGpuAsyncCopyBytes = 64 * 1024; // 64K +// NOTE(zcd): Do not use GpuMemcpySync as much as possible. +// because GpuMemcpySync issues the copying command to the default stream, +// which will make two commands from different streams cannot run concurrently. +// Reference: +// https://devblogs.nvidia.com/gpu-pro-tip-cuda-7-streams-simplify-concurrency/ + template <> void Copy( platform::CPUPlace dst_place, void* dst, platform::CUDAPlace src_place, const void* src, size_t num, cudaStream_t stream) { platform::SetDeviceId(src_place.device); + if (stream) { + platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CPU"); platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream); } else { + platform::RecordEvent record_event("GpuMemcpySync:GPU->CPU"); platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost); // FIXME(zjl): do we really need it? if (num <= kMaxGpuAsyncCopyBytes) { @@ -51,8 +61,10 @@ void Copy( const void* src, size_t num, cudaStream_t stream) { platform::SetDeviceId(dst_place.device); if (stream) { + platform::RecordEvent record_event("GpuMemcpyAsync:CPU->GPU"); platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream); } else { + platform::RecordEvent record_event("GpuMemcpySync:CPU->GPU"); platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice); // FIXME(zjl): do we really need it? if (num <= kMaxGpuAsyncCopyBytes) { @@ -68,15 +80,19 @@ void Copy( if (dst_place == src_place) { platform::SetDeviceId(src_place.device); if (stream) { + platform::RecordEvent record_event("GpuMemcpyAsync(same_gpu):GPU->GPU"); platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream); } else { + platform::RecordEvent record_event("GpuMemcpySync(same_gpu):GPU->GPU"); platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice); } } else { if (stream) { + platform::RecordEvent record_event("GpuMemcpyPeerAsync:GPU->GPU"); platform::GpuMemcpyPeerAsync(dst, dst_place.device, src, src_place.device, num, stream); } else { + platform::RecordEvent record_event("GpuMemcpyPeerSync:GPU->GPU"); platform::GpuMemcpyPeerSync(dst, dst_place.device, src, src_place.device, num); } @@ -111,8 +127,10 @@ void Copy( cudaStream_t stream) { platform::SetDeviceId(src_place.device); if (stream) { + platform::RecordEvent record_event("GpuMemcpyAsync:GPU->CUDAPinned"); platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream); } else { + platform::RecordEvent record_event("GpuMemcpySync:GPU->CUDAPinned"); platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToHost); } } @@ -124,8 +142,10 @@ void Copy( cudaStream_t stream) { platform::SetDeviceId(dst_place.device); if (stream) { + platform::RecordEvent record_event("GpuMemcpyAsync:CUDAPinned->GPU"); platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream); } else { + platform::RecordEvent record_event("GpuMemcpySync:CUDAPinned->GPU"); platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice); } } diff --git a/paddle/fluid/operators/reader/buffered_reader.cc b/paddle/fluid/operators/reader/buffered_reader.cc index defc29b91f..84322f00da 100644 --- a/paddle/fluid/operators/reader/buffered_reader.cc +++ b/paddle/fluid/operators/reader/buffered_reader.cc @@ -13,9 +13,11 @@ // limitations under the License. #include "paddle/fluid/operators/reader/buffered_reader.h" +#include #include #include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { namespace reader { @@ -49,9 +51,10 @@ BufferedReader::BufferedReader( .Get(place_))) ->stream(); events.resize(buffer_size); - for (auto &event : events) + PADDLE_ENFORCE(cudaStreamCreate(&stream)); + for (auto &event : events) { PADDLE_ENFORCE(cudaEventCreateWithFlags(&event, cudaEventDisableTiming)); - PADDLE_ENFORCE(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + } } #endif cpu_buffer_.resize(buffer_size); @@ -83,12 +86,15 @@ void BufferedReader::ReadAsync(size_t i) { #ifdef PADDLE_WITH_CUDA // NOTE(liangdun): using async copy instead of TensorCopySync - // TensorCopySync would block other stream + // TensorCopySync would block other stream, because TensorCopySync + // issues the copying command to the default stream, it will make two + // commands from different streams cannot run concurrently. if (platform::is_gpu_place(place_)) { platform::SetDeviceId(boost::get(place_).device); PADDLE_ENFORCE(cudaStreamWaitEvent(stream, events[i], 0)); TensorVec &gpu = gpu_buffer_[i]; gpu.resize(cpu.size()); + platform::RecordEvent record_event("BufferedReader:MemoryCopy"); for (size_t i = 0; i < cpu.size(); ++i) { gpu[i].Resize(cpu[i].dims()); gpu[i].set_layout(cpu[i].layout()); @@ -97,20 +103,19 @@ void BufferedReader::ReadAsync(size_t i) { auto gpu_ptr = gpu[i].mutable_data(place_, cpu[i].type()); auto size = cpu[i].numel() * paddle::framework::SizeOfType(cpu[i].type()); - if (platform::is_cuda_pinned_place(cpu_place)) + if (platform::is_cuda_pinned_place(cpu_place)) { memory::Copy(boost::get(place_), gpu_ptr, boost::get(cpu_place), cpu_ptr, size, stream); - else if ((platform::is_gpu_place(cpu_place))) + } else if ((platform::is_gpu_place(cpu_place))) { memory::Copy(boost::get(place_), gpu_ptr, boost::get(cpu_place), cpu_ptr, size, stream); - else - // if cpu place is not pinned, async copy is slower than sync copy, - // so we use sync copy instead. + } else { memory::Copy(boost::get(place_), gpu_ptr, boost::get(cpu_place), cpu_ptr, size, - 0); + stream); + } gpu[i].set_lod(cpu[i].lod()); } PADDLE_ENFORCE(cudaStreamSynchronize(stream)); diff --git a/paddle/fluid/platform/device_tracer.cc b/paddle/fluid/platform/device_tracer.cc index 0179daa557..b084f1a649 100644 --- a/paddle/fluid/platform/device_tracer.cc +++ b/paddle/fluid/platform/device_tracer.cc @@ -30,7 +30,6 @@ limitations under the License. */ #include "glog/logging.h" #include "google/protobuf/text_format.h" #include "paddle/fluid/framework/block_desc.h" -#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/string/printf.h" namespace paddle { @@ -222,19 +221,24 @@ void CUPTIAPI bufferCompleted(CUcontext ctx, uint32_t streamId, uint8_t *buffer, } case CUPTI_ACTIVITY_KIND_DRIVER: { auto *api = reinterpret_cast(record); - if (api->start != 0 && api->end != 0) - // -1 device id represents CUDA api call - tracer->AddCPURecords( + if (api->start != 0 && api->end != 0) { + // -1 device id represents ActiveKind api call + tracer->AddActiveKindRecords( DriverKind(api->cbid), api->start, api->end, -1, - GetThreadIdFromSystemThreadId(api->threadId)); + GetThreadIdFromSystemThreadId(api->threadId), + api->correlationId); + } break; } case CUPTI_ACTIVITY_KIND_RUNTIME: { auto *api = reinterpret_cast(record); - if (api->start != 0 && api->end != 0) - tracer->AddCPURecords( + if (api->start != 0 && api->end != 0) { + // -1 device id represents ActiveKind api call + tracer->AddActiveKindRecords( RuntimeKind(api->cbid), api->start, api->end, -1, - GetThreadIdFromSystemThreadId(api->threadId)); + GetThreadIdFromSystemThreadId(api->threadId), + api->correlationId); + } break; } default: { break; } @@ -313,6 +317,25 @@ class DeviceTracerImpl : public DeviceTracer { stream_id, correlation_id, bytes}); } + void AddActiveKindRecords(const std::string &anno, uint64_t start_ns, + uint64_t end_ns, int64_t device_id, + int64_t thread_id, uint32_t correlation_id) { + if (anno.empty()) { + VLOG(1) << "Empty timeline annotation."; + return; + } + thread_local std::forward_list + *local_active_kind_records = nullptr; + if (local_active_kind_records == nullptr) { + std::lock_guard l(trace_mu_); + active_kind_records_.emplace_front(); + local_active_kind_records = &active_kind_records_.front(); + } + // lock is not needed, only one thread call this function. + local_active_kind_records->push_front(ActiveKindRecord{ + anno, start_ns, end_ns, device_id, thread_id, correlation_id}); + } + void AddKernelRecords(std::string name, uint64_t start, uint64_t end, int64_t device_id, int64_t stream_id, uint32_t correlation_id) { @@ -355,6 +378,7 @@ class DeviceTracerImpl : public DeviceTracer { } const std::vector cbids { CUPTI_RUNTIME_TRACE_CBID_cudaMemcpy_v3020, + CUPTI_RUNTIME_TRACE_CBID_cudaSetupArgument_v3020, CUPTI_RUNTIME_TRACE_CBID_cudaMemcpyAsync_v3020, CUPTI_RUNTIME_TRACE_CBID_cudaMemset_v3020, CUPTI_RUNTIME_TRACE_CBID_cudaMemsetAsync_v3020, @@ -385,6 +409,7 @@ class DeviceTracerImpl : public DeviceTracer { correlations_.clear(); for (auto &tmp : correlations_pairs) tmp.clear(); for (auto &tmp : cpu_records_) tmp.clear(); + for (auto &tmp : active_kind_records_) tmp.clear(); } void GenEventKernelCudaElapsedTime() { @@ -437,7 +462,7 @@ class DeviceTracerImpl : public DeviceTracer { event->set_device_id(r.device_id); } VLOG(1) << "KernelRecord event miss: " << miss << " find: " << find; - for (auto &tmp : cpu_records_) + for (auto &tmp : cpu_records_) { for (const CPURecord &r : tmp) { auto *event = profile_pb.add_events(); event->set_type(proto::Event::CPU); @@ -447,6 +472,24 @@ class DeviceTracerImpl : public DeviceTracer { event->set_sub_device_id(r.thread_id); event->set_device_id(r.device_id); } + } + for (auto &tmp : active_kind_records_) { + for (const ActiveKindRecord &r : tmp) { + auto *event = profile_pb.add_events(); + event->set_type(proto::Event::CPU); + auto c = correlations_.find(r.correlation_id); + if (c != correlations_.end() && c->second != nullptr) { + event->set_name(c->second->name()); + event->set_detail_info(r.name); + } else { + event->set_name(r.name); + } + event->set_start_ns(r.start_ns); + event->set_end_ns(r.end_ns); + event->set_sub_device_id(r.thread_id); + event->set_device_id(r.device_id); + } + } miss = find = 0; for (const MemRecord &r : mem_records_) { auto *event = profile_pb.add_events(); @@ -510,6 +553,7 @@ class DeviceTracerImpl : public DeviceTracer { std::forward_list kernel_records_; std::forward_list mem_records_; std::forward_list> cpu_records_; + std::forward_list> active_kind_records_; std::forward_list>> correlations_pairs; std::unordered_map correlations_; @@ -613,6 +657,7 @@ void initCuptiCbidStr() { REGISTER_RUNTIME_CBID_STR(cudaUnbindTexture_v3020); REGISTER_RUNTIME_CBID_STR(cudaSetupArgument_v3020); REGISTER_RUNTIME_CBID_STR(cudaLaunch_v3020); + REGISTER_RUNTIME_CBID_STR(cudaDeviceGetPCIBusId_v4010); #if CUDA_VERSION >= 9000 REGISTER_RUNTIME_CBID_STR(cudaLaunchCooperativeKernel_v9000); REGISTER_RUNTIME_CBID_STR(cudaLaunchCooperativeKernelMultiDevice_v9000); diff --git a/paddle/fluid/platform/device_tracer.h b/paddle/fluid/platform/device_tracer.h index d4418d836d..a8f1d89383 100644 --- a/paddle/fluid/platform/device_tracer.h +++ b/paddle/fluid/platform/device_tracer.h @@ -63,7 +63,14 @@ class DeviceTracer { uint32_t correlation_id; uint64_t bytes; }; - + struct ActiveKindRecord { + std::string name; + uint64_t start_ns; + uint64_t end_ns; + int64_t device_id; + int64_t thread_id; + uint32_t correlation_id; + }; virtual ~DeviceTracer() {} // Needs to be called once before use. virtual void Enable() = 0; @@ -85,6 +92,10 @@ class DeviceTracer { virtual void AddCPURecords(const std::string& anno, uint64_t start_ns, uint64_t end_ns, int64_t device_id, int64_t thread_id) = 0; + virtual void AddActiveKindRecords(const std::string& anno, uint64_t start_ns, + uint64_t end_ns, int64_t device_id, + int64_t thread_id, + uint32_t correlation_id) = 0; // Add a cuda kernel stats. `correlation_id` will be mapped to annotation // added before for human readability. diff --git a/tools/timeline.py b/tools/timeline.py index ebadb29bdb..7879666417 100644 --- a/tools/timeline.py +++ b/tools/timeline.py @@ -131,7 +131,7 @@ class Timeline(object): if (k, event.device_id, "CPU") not in self._devices: pid = self._allocate_pid() self._devices[(k, event.device_id, "CPU")] = pid - # -1 device id represents CUDA api call + # -1 device id represents CUDA API(RunTime) call.(e.g. cudaLaunch, cudaMemcpy) if event.device_id == -1: self._chrome_trace.emit_pid("%s:cuda_api" % k, pid) else: -- GitLab