提交 7643c2cb 编写于 作者: Y Yu Yang

Add flag for use event

上级 ca4b3d25
......@@ -86,8 +86,8 @@ struct OpHandle {
virtual ~OpHandle() {}
void Run() {
if (events_.empty()) {
void Run(bool use_event) {
if (events_.empty() && use_event) {
for (auto &p : dev_ctx_) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
cudaSetDevice(dev_id);
......@@ -97,6 +97,7 @@ struct OpHandle {
RunImpl();
if (use_event) {
for (auto &p : dev_ctx_) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
auto stream =
......@@ -104,9 +105,10 @@ struct OpHandle {
cudaEventRecord(events_.at(dev_id), stream);
}
}
}
virtual void Wait(platform::DeviceContext *waited_dev) {
if (platform::is_cpu_place(waited_dev->GetPlace())) {
if (platform::is_cpu_place(waited_dev->GetPlace()) && events_.empty()) {
for (auto &dev_ctx : dev_ctx_) {
dev_ctx.second->Wait();
}
......@@ -677,7 +679,7 @@ void ParallelExecutor::BuildNCCLCommunicator() const {
void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
const std::string &fetched_var_name) {
VLOG(3) << "Run iter";
bool use_event = false;
auto fetched_data = std::make_shared<FetchedData>(fetch_tensors.size());
// Version --> VarHandle
member_->exception_.reset();
......@@ -748,7 +750,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
}
for (auto *op : to_run) {
RunOp(pending_vars, op);
RunOp(use_event, pending_vars, op);
}
while (!pending_vars.empty()) {
......@@ -776,7 +778,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
}
for (auto *op : to_run) {
pending_ops.erase(op);
RunOp(pending_vars, op);
RunOp(use_event, pending_vars, op);
}
}
......@@ -790,6 +792,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
}
void ParallelExecutor::RunOp(
bool use_event,
std::unordered_map<VarHandleBase *, std::atomic<bool>> &pending_vars,
OpHandle *op) const {
std::vector<std::atomic<bool> *> *ready_buffer =
......@@ -798,10 +801,10 @@ void ParallelExecutor::RunOp(
ready_buffer->emplace_back(&pending_vars[var]);
}
auto op_run = [ready_buffer, op, this] {
auto op_run = [ready_buffer, op, this, use_event] {
try {
VLOG(10) << op->DebugString();
op->Run();
op->Run(use_event);
for (auto *ready : *ready_buffer) {
ready->store(true, std::memory_order_release);
}
......
......@@ -62,6 +62,7 @@ class ParallelExecutor {
void BuildNCCLCommunicator() const;
void RunOp(
bool use_event,
std::unordered_map<VarHandleBase*, std::atomic<bool>>& pending_vars,
OpHandle* op) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册