提交 e008600b 编写于 作者: Q qiaolongfei

optimize code

上级 7c649e06
......@@ -189,6 +189,8 @@ void CUPTIAPI bufferCompleted(CUcontext ctx, uint32_t streamId, uint8_t *buffer,
}
} // namespace
#endif // PADDLE_WITH_CUPTI
class DeviceTracerImpl : public DeviceTracer {
public:
DeviceTracerImpl() : enabled_(false) {}
......@@ -244,6 +246,8 @@ class DeviceTracerImpl : public DeviceTracer {
if (enabled_) {
return;
}
#ifdef PADDLE_WITH_CUPTI
EnableActivity();
// Register callbacks for buffer requests and completed by CUPTI.
......@@ -262,6 +266,7 @@ class DeviceTracerImpl : public DeviceTracer {
dynload::cuptiEnableCallback(1, subscriber_, CUPTI_CB_DOMAIN_DRIVER_API,
CUPTI_DRIVER_TRACE_CBID_cuLaunchKernel));
CUPTI_CALL(dynload::cuptiGetTimestamp(&start_ns_));
#endif // PADDLE_WITH_CUPTI
enabled_ = true;
}
......@@ -313,16 +318,20 @@ class DeviceTracerImpl : public DeviceTracer {
}
void Disable() {
std::lock_guard<std::mutex> l(trace_mu_);
#ifdef PADDLE_WITH_CUPTI
// flush might cause additional calls to DeviceTracker.
dynload::cuptiActivityFlushAll(CUPTI_ACTIVITY_FLAG_FLUSH_FORCED);
std::lock_guard<std::mutex> l(trace_mu_);
DisableActivity();
dynload::cuptiUnsubscribe(subscriber_);
CUPTI_CALL(dynload::cuptiGetTimestamp(&end_ns_));
#endif // PADDLE_WITH_CUPTI
enabled_ = false;
}
private:
#ifdef PADDLE_WITH_CUPTI
static void CUPTIAPI ApiCallback(void *userdata, CUpti_CallbackDomain domain,
CUpti_CallbackId cbid, const void *cbdata) {
auto *cbInfo = reinterpret_cast<const CUpti_CallbackData *>(cbdata);
......@@ -340,126 +349,8 @@ class DeviceTracerImpl : public DeviceTracer {
VLOG(1) << "Unhandled API Callback for " << domain << " " << cbid;
}
}
std::mutex trace_mu_;
bool enabled_;
uint64_t start_ns_;
uint64_t end_ns_;
std::vector<KernelRecord> kernel_records_;
std::vector<MemRecord> mem_records_;
std::vector<CPURecord> cpu_records_;
std::unordered_map<uint32_t, std::string> correlations_;
CUpti_SubscriberHandle subscriber_;
};
#endif // PADDLE_WITH_CUPTI
class DeviceTracerDummy : public DeviceTracer {
public:
DeviceTracerDummy() {}
void AddAnnotation(uint64_t id, const std::string &anno) {
std::lock_guard<std::mutex> l(trace_mu_);
correlations_[id] = anno;
}
void AddCPURecords(const std::string &anno, uint64_t start_ns,
uint64_t end_ns, int64_t device_id, int64_t thread_id) {
if (anno.empty()) {
VLOG(1) << "Empty timeline annotation.";
return;
}
std::lock_guard<std::mutex> l(trace_mu_);
cpu_records_.push_back(
CPURecord{anno, start_ns, end_ns, device_id, thread_id});
}
void AddMemRecords(const std::string &name, uint64_t start_ns,
uint64_t end_ns, int64_t device_id, int64_t stream_id,
uint32_t correlation_id, uint64_t bytes) {
// 0 means timestamp information could not be collected for the kernel.
if (start_ns == 0 || end_ns == 0) {
VLOG(3) << name << " cannot be traced";
return;
}
std::lock_guard<std::mutex> l(trace_mu_);
mem_records_.push_back(MemRecord{name, start_ns, end_ns, device_id,
stream_id, correlation_id, bytes});
}
void AddKernelRecords(uint64_t start, uint64_t end, int64_t device_id,
int64_t stream_id, uint32_t correlation_id) {}
bool IsEnabled() {
std::lock_guard<std::mutex> l(trace_mu_);
return enabled_;
}
void Enable() {
std::lock_guard<std::mutex> l(trace_mu_);
if (enabled_) {
return;
}
int64_t start_ns_ = PosixInNsec();
VLOG(3) << "start_ns_ = " << start_ns_;
enabled_ = true;
}
void Disable() {
std::lock_guard<std::mutex> l(trace_mu_);
uint64_t end_ns_ = PosixInNsec();
VLOG(3) << "end_ns_ = " << end_ns_;
enabled_ = false;
}
proto::Profile GenProfile(const std::string &profile_path) {
std::lock_guard<std::mutex> l(trace_mu_);
proto::Profile profile_pb;
profile_pb.set_start_ns(start_ns_);
profile_pb.set_end_ns(end_ns_);
for (const KernelRecord &r : kernel_records_) {
if (correlations_.find(r.correlation_id) == correlations_.end()) {
fprintf(stderr, "cannot relate a kernel activity\n");
continue;
}
auto *event = profile_pb.add_events();
event->set_type(proto::Event::GPUKernel);
event->set_name(correlations_.at(r.correlation_id));
event->set_start_ns(r.start_ns);
event->set_end_ns(r.end_ns);
event->set_sub_device_id(r.stream_id);
event->set_device_id(r.device_id);
}
for (const CPURecord &r : cpu_records_) {
auto *event = profile_pb.add_events();
event->set_type(proto::Event::CPU);
event->set_name(r.name);
event->set_start_ns(r.start_ns);
event->set_end_ns(r.end_ns);
event->set_sub_device_id(r.thread_id);
event->set_device_id(r.device_id);
}
for (const MemRecord &r : mem_records_) {
auto *event = profile_pb.add_events();
event->set_type(proto::Event::GPUKernel);
event->set_name(r.name);
event->set_start_ns(r.start_ns);
event->set_end_ns(r.end_ns);
event->set_sub_device_id(r.stream_id);
event->set_device_id(r.device_id);
event->mutable_memcopy()->set_bytes(r.bytes);
}
std::ofstream profile_f;
profile_f.open(profile_path, std::ios::out | std::ios::trunc);
std::string profile_str;
profile_pb.SerializeToString(&profile_str);
profile_f << profile_str;
profile_f.close();
return profile_pb;
}
private:
std::mutex trace_mu_;
bool enabled_;
uint64_t start_ns_;
......@@ -470,13 +361,7 @@ class DeviceTracerDummy : public DeviceTracer {
std::unordered_map<uint32_t, std::string> correlations_;
};
void CreateTracer(DeviceTracer **t) {
#ifdef PADDLE_WITH_CUPTI
*t = new DeviceTracerImpl();
#else
*t = new DeviceTracerDummy();
#endif // PADDLE_WITH_CUPTI
}
void CreateTracer(DeviceTracer **t) { *t = new DeviceTracerImpl(); }
DeviceTracer *GetDeviceTracer() {
std::call_once(tracer_once_flag, CreateTracer, &tracer);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册