Commit 7643c2cb authored by Yu Yang

Add flag for use event

Parent ca4b3d25
@@ -86,8 +86,8 @@ struct OpHandle {
 
   virtual ~OpHandle() {}
 
-  void Run() {
-    if (events_.empty()) {
+  void Run(bool use_event) {
+    if (events_.empty() && use_event) {
       for (auto &p : dev_ctx_) {
         int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
         cudaSetDevice(dev_id);
@@ -97,16 +97,18 @@ struct OpHandle {
 
     RunImpl();
 
-    for (auto &p : dev_ctx_) {
-      int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
-      auto stream =
-          static_cast<platform::CUDADeviceContext *>(p.second)->stream();
-      cudaEventRecord(events_.at(dev_id), stream);
+    if (use_event) {
+      for (auto &p : dev_ctx_) {
+        int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
+        auto stream =
+            static_cast<platform::CUDADeviceContext *>(p.second)->stream();
+        cudaEventRecord(events_.at(dev_id), stream);
+      }
     }
   }
 
   virtual void Wait(platform::DeviceContext *waited_dev) {
-    if (platform::is_cpu_place(waited_dev->GetPlace())) {
+    if (platform::is_cpu_place(waited_dev->GetPlace()) && events_.empty()) {
       for (auto &dev_ctx : dev_ctx_) {
         dev_ctx.second->Wait();
       }
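The net effect of these hunks: OpHandle::Run only creates and records CUDA events when the caller asks for them, and OpHandle::Wait falls back to a blocking wait on the device contexts when no events exist. Below is a minimal standalone sketch of that gating pattern; the CUDA event and stream calls are replaced with stubs (StubDeviceContext, plain prints) so it compiles and runs anywhere, and it is an illustration of the pattern, not the PaddlePaddle implementation.

```cpp
// Minimal sketch of the use_event gating pattern (not the real PaddlePaddle
// code). CUDA events and streams are stubbed with prints so this compiles
// and runs without CUDA.
#include <iostream>
#include <map>

struct StubDeviceContext {
  void Wait() { std::cout << "blocking wait on device context\n"; }
};

struct OpHandleSketch {
  std::map<int, StubDeviceContext *> dev_ctx_;  // device id -> context
  std::map<int, int> events_;                   // device id -> stub event

  void Run(bool use_event) {
    // Create one event per device lazily, but only if events are requested.
    if (events_.empty() && use_event) {
      for (auto &p : dev_ctx_) events_[p.first] = 0;
    }

    RunImpl();

    // Record an event on every device stream, only when requested.
    if (use_event) {
      for (auto &p : dev_ctx_) {
        std::cout << "record event on device " << p.first << "\n";
      }
    }
  }

  void Wait() {
    // Without recorded events there is nothing to synchronize on
    // asynchronously, so block on every device context instead.
    if (events_.empty()) {
      for (auto &p : dev_ctx_) p.second->Wait();
    } else {
      std::cout << "make waiting stream wait on recorded events\n";
    }
  }

  virtual void RunImpl() { std::cout << "run op kernel\n"; }
  virtual ~OpHandleSketch() = default;
};

int main() {
  StubDeviceContext ctx0, ctx1;
  OpHandleSketch op;
  op.dev_ctx_ = {{0, &ctx0}, {1, &ctx1}};

  op.Run(/*use_event=*/false);  // no events created or recorded
  op.Wait();                    // falls back to blocking waits

  op.Run(/*use_event=*/true);   // events created lazily, then recorded
  op.Wait();                    // event-based wait path
}
```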
@@ -677,7 +679,7 @@ void ParallelExecutor::BuildNCCLCommunicator() const {
 
 void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
                            const std::string &fetched_var_name) {
-  VLOG(3) << "Run iter";
+  bool use_event = false;
   auto fetched_data = std::make_shared<FetchedData>(fetch_tensors.size());
   // Version --> VarHandle
   member_->exception_.reset();
@@ -748,7 +750,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
   }
 
   for (auto *op : to_run) {
-    RunOp(pending_vars, op);
+    RunOp(use_event, pending_vars, op);
   }
 
   while (!pending_vars.empty()) {
@@ -776,7 +778,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
    }
 
    for (auto *op : to_run) {
      pending_ops.erase(op);
-      RunOp(pending_vars, op);
+      RunOp(use_event, pending_vars, op);
    }
  }
@@ -790,6 +792,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
 }
 
 void ParallelExecutor::RunOp(
+    bool use_event,
     std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars,
     OpHandle *op) const {
   std::vector<std::atomic<bool> *> *ready_buffer =
@@ -798,10 +801,10 @@ void ParallelExecutor::RunOp(
     ready_buffer->emplace_back(&pending_vars[var]);
   }
 
-  auto op_run = [ready_buffer, op, this] {
+  auto op_run = [ready_buffer, op, this, use_event] {
     try {
       VLOG(10) << op->DebugString();
-      op->Run();
+      op->Run(use_event);
       for (auto *ready : *ready_buffer) {
         ready->store(true, std::memory_order_release);
       }
......
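In parallel_executor.cc the flag starts out hard-coded to false in ParallelExecutor::Run and is then passed through RunOp into the lambda that actually executes the op, so the worker that runs the op sees the value chosen at the top. A rough sketch of that plumbing is below, with the dependency tracking, pending_vars bookkeeping, and thread pool stripped out; every name other than use_event is made up for illustration.

```cpp
// Standalone sketch of how use_event is threaded from the executor entry
// point into the per-op worker lambda. Scheduling machinery is elided.
#include <iostream>
#include <vector>

struct FakeOpHandle {
  void Run(bool use_event) {
    std::cout << "op run, use_event=" << std::boolalpha << use_event << "\n";
  }
};

struct ExecutorSketch {
  std::vector<FakeOpHandle> ops_;

  // Mirrors ParallelExecutor::RunOp: the flag is captured by the lambda so
  // the code that eventually executes the op sees the caller's choice.
  void RunOp(bool use_event, FakeOpHandle *op) const {
    auto op_run = [op, use_event] { op->Run(use_event); };
    op_run();  // the real code hands this lambda to a thread pool
  }

  void Run() {
    bool use_event = false;  // hard-coded off in this commit
    for (auto &op : ops_) {
      RunOp(use_event, &op);
    }
  }
};

int main() {
  ExecutorSketch exec;
  exec.ops_.resize(3);
  exec.Run();
}
```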
@@ -62,6 +62,7 @@ class ParallelExecutor {
 
   void BuildNCCLCommunicator() const;
 
   void RunOp(
+      bool use_event,
       std::unordered_map<VarHandleBase*, std::atomic<bool>>& pending_vars,
       OpHandle* op) const;
......