communicator.cc 8.6 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/distributed/communicator.h"

Q
Qiao Longfei 已提交
17
#include <gflags/gflags.h>
Q
Qiao Longfei 已提交
18 19 20
#include <chrono>  // NOLINT
#include <thread>  // NOLINT

Q
Qiao Longfei 已提交
21
#include "paddle/fluid/framework/eigen.h"
Q
Qiao Longfei 已提交
22 23 24 25 26 27
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/parameter_recv.h"
#include "paddle/fluid/operators/distributed/parameter_send.h"

Q
Qiao Longfei 已提交
28 29 30 31
// Runtime configuration for the Communicator. These flags are read by the
// constructor (queue and thread-pool sizes) and by the worker loops below.

// When true, Start() launches a dedicated RecvThread(); when false,
// SendThread() calls RecvAll() itself after every send round.
DEFINE_bool(communicator_independent_recv_thread, true,
            "use an independent thread to recv vars from parameter server");
// Constructor argument of each per-variable send BlockingQueue
// (presumably its capacity — confirm against BlockingQueue).
DEFINE_int32(communicator_send_queue_size, 20,
             "queue size to recv gradient before send");
// RecvThread() calls RecvAll() once the counted number of sent gradients
// exceeds this value.
DEFINE_int32(communicator_min_send_grad_num_before_recv, 20,
             "min grad num to send before recv parameters");
// Size of both the send and the recv thread pool.
DEFINE_int32(communicator_thread_pool_size, 5, "thread num to do send or recv");
// Number of consecutive 10ms waits on an empty queue before a send task stops
// merging and sends what it already has.
DEFINE_int32(communicator_send_wait_times, 5,
             "times that send thread will wait if merge num does not reach "
             "max_merge_var_num");
// Upper bound on the number of queued vars popped and merged per send task.
DEFINE_int32(communicator_max_merge_var_num, 20,
             "max var num to merge and send");
// When true, the ParameterSend / ParameterRecv calls are skipped.
DEFINE_bool(communicator_fake_rpc, false,
            "fake mode does not really send anything");
Q
Qiao Longfei 已提交
42

Q
Qiao Longfei 已提交
43 44 45 46
namespace paddle {
namespace operators {
namespace distributed {

Q
Qiao Longfei 已提交
47 48 49 50 51 52
// Returns the current wall-clock time in whole microseconds since the system
// clock's epoch, as a double. In this file it is only used to measure elapsed
// time between two calls.
// Implemented with <chrono> (already included above) rather than
// gettimeofday(), whose <sys/time.h> header this file never includes directly.
inline double GetCurrentUS() {
  using std::chrono::duration_cast;
  using std::chrono::microseconds;
  using std::chrono::system_clock;
  const auto since_epoch = system_clock::now().time_since_epoch();
  return static_cast<double>(duration_cast<microseconds>(since_epoch).count());
}

Q
can run  
Qiao Longfei 已提交
53 54 55
// Static storage for the process-wide Communicator returned by GetInstance().
// It starts out empty; init_flag_ is presumably used by an initializer
// declared elsewhere (e.g. the header) to create it exactly once — TODO confirm.
std::unique_ptr<Communicator> Communicator::communicator_{nullptr};
std::once_flag Communicator::init_flag_;

Q
Qiao Longfei 已提交
56 57 58 59 60 61 62 63 64 65 66
// Creates a communicator that sends the variables listed in
// send_varname_to_ctx and receives the variables listed in
// recv_varname_to_ctx into recv_scope.
// The effective flag values are logged first. The constructor then allocates
// a private send scope, one BlockingQueue per send variable, and the send and
// recv thread pools.
Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
                           const RpcCtxMap &recv_varname_to_ctx,
                           Scope *recv_scope)
    : send_varname_to_ctx_(send_varname_to_ctx),
      recv_varname_to_ctx_(recv_varname_to_ctx),
      recv_scope_(recv_scope) {
  // Echo the configuration this instance will run with.
  VLOG(0) << "communicator_independent_recv_thread: "
          << FLAGS_communicator_independent_recv_thread;
  VLOG(0) << "communicator_send_queue_size: "
          << FLAGS_communicator_send_queue_size;
  VLOG(0) << "communicator_min_send_grad_num_before_recv: "
          << FLAGS_communicator_min_send_grad_num_before_recv;
  VLOG(0) << "communicator_thread_pool_size: "
          << FLAGS_communicator_thread_pool_size;
  VLOG(0) << "communicator_send_wait_times: "
          << FLAGS_communicator_send_wait_times;
  VLOG(0) << "communicator_max_merge_var_num: "
          << FLAGS_communicator_max_merge_var_num;
  VLOG(0) << "communicator_fake_rpc: " << FLAGS_communicator_fake_rpc;

  send_scope_.reset(new Scope());

  // One queue per send variable: Send() pushes copies into it and the tasks
  // launched by SendThread() pop from it.
  const auto queue_size = FLAGS_communicator_send_queue_size;
  for (const auto &name_and_ctx : send_varname_to_ctx_) {
    const auto &name = name_and_ctx.first;
    send_varname_to_queue_[name] =
        std::make_shared<BlockingQueue<std::shared_ptr<Variable>>>(queue_size);
  }

  const auto pool_size = FLAGS_communicator_thread_pool_size;
  send_threadpool_.reset(new ::ThreadPool(pool_size));
  recv_threadpool_.reset(new ::ThreadPool(pool_size));
}

// Stops the worker loops and waits for their threads to exit.
// Clearing running_ makes both SendThread() and RecvThread() leave their
// while loops; each thread is joined only if Start() created it.
Communicator::~Communicator() {
  VLOG(3) << "~Communicator";
  running_ = false;
  if (send_thread_ != nullptr) {
    send_thread_->join();
  }
  if (recv_thread_ != nullptr) {
    recv_thread_->join();
  }
  VLOG(3) << "~Communicator done";
}

Q
Qiao Longfei 已提交
94
void Communicator::SendThread() {
Q
Qiao Longfei 已提交
95
  VLOG(3) << "SendThread start!";
Q
Qiao Longfei 已提交
96 97 98
  while (running_) {
    std::vector<std::future<void>> task_futures;
    task_futures.reserve(send_varname_to_ctx_.size());
Q
Qiao Longfei 已提交
99
    VLOG(3) << "run send graph";
Q
Qiao Longfei 已提交
100
    auto before_run_send_graph = GetCurrentUS();
Q
Qiao Longfei 已提交
101
    for (auto &iter : send_varname_to_queue_) {
Q
Qiao Longfei 已提交
102 103
      auto &var_name = iter.first;
      auto &var_queue = iter.second;
Q
Qiao Longfei 已提交
104
      if (var_queue->Size() > 0) {
Q
Qiao Longfei 已提交
105
        auto send_task = [this, &var_name, &var_queue] {
Q
Qiao Longfei 已提交
106
          VLOG(3) << var_name << " merge and send";
Q
Qiao Longfei 已提交
107 108
          std::vector<std::shared_ptr<Variable>> vars;
          size_t merged_var_num = 0;
Q
Qiao Longfei 已提交
109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
          size_t wait_times = 0;
          while (merged_var_num < FLAGS_communicator_max_merge_var_num) {
            if (var_queue->Size() == 0) {
              VLOG(3) << "wait_times -> " << wait_times;
              if (wait_times >= FLAGS_communicator_send_wait_times) {
                break;
              }
              std::this_thread::sleep_for(std::chrono::milliseconds(10));
              wait_times++;
              continue;
            } else {
              wait_times = 0;

              vars.push_back(var_queue->Pop());
              // only count the send number of the first var
              if (var_name == send_varname_to_queue_.begin()->first) {
                grad_num_.fetch_add(1, std::memory_order_relaxed);
              }
              merged_var_num++;
128
            }
Q
Qiao Longfei 已提交
129
          }
Q
Qiao Longfei 已提交
130
          auto before_merge = GetCurrentUS();
Q
Qiao Longfei 已提交
131
          MergeVars(var_name, vars, send_scope_.get());
Q
Qiao Longfei 已提交
132
          auto after_merge = GetCurrentUS();
Q
Qiao Longfei 已提交
133 134
          VLOG(3) << "merge " << merged_var_num << " " << var_name
                  << " use time " << after_merge - before_merge;
Q
Qiao Longfei 已提交
135 136
          auto send_functor = distributed::ParameterSend<float>();
          auto &ctx = send_varname_to_ctx_.at(var_name);
137 138 139
          if (!FLAGS_communicator_fake_rpc) {
            send_functor(ctx, *send_scope_, true);
          }
Q
Qiao Longfei 已提交
140 141 142
          auto after_send = GetCurrentUS();
          VLOG(3) << "send " << var_name << " use time "
                  << after_send - after_merge;
Q
Qiao Longfei 已提交
143 144 145
        };
        task_futures.emplace_back(
            send_threadpool_->enqueue(std::move(send_task)));
Q
Qiao Longfei 已提交
146 147
      } else {
        VLOG(3) << var_name << " queue empty";
Q
Qiao Longfei 已提交
148
      }
Q
Qiao Longfei 已提交
149 150 151
    }
    for (auto &task_f : task_futures) {
      task_f.wait();
Q
Qiao Longfei 已提交
152
    }
Q
Qiao Longfei 已提交
153
    auto after_run_send_graph = GetCurrentUS();
Q
Qiao Longfei 已提交
154
    auto send_graph_use_time = after_run_send_graph - before_run_send_graph;
155
    if (send_graph_use_time > 100) {
Q
Qiao Longfei 已提交
156 157 158
      VLOG(1) << "run send graph use time "
              << after_run_send_graph - before_run_send_graph;
    }
Q
Qiao Longfei 已提交
159 160 161
    if (!FLAGS_communicator_independent_recv_thread) {
      RecvAll();
    }
Q
Qiao Longfei 已提交
162 163 164
  }
}

Q
Qiao Longfei 已提交
165 166
void Communicator::RecvAll() {
  VLOG(3) << "parallel run recv graph";
Q
Qiao Longfei 已提交
167
  auto before_send = GetCurrentUS();
Q
Qiao Longfei 已提交
168 169 170 171 172 173 174
  std::vector<std::future<void>> task_futures;
  task_futures.reserve(recv_varname_to_ctx_.size());
  for (auto &iter : recv_varname_to_ctx_) {
    auto recv_task = [this, &iter] {
      auto &var_name = iter.first;
      VLOG(3) << "recv var " << var_name;
      auto recv_functor = distributed::ParameterRecv<float>();
175 176 177
      if (!FLAGS_communicator_fake_rpc) {
        recv_functor(iter.second, *recv_scope_);
      }
Q
Qiao Longfei 已提交
178 179 180 181 182 183
    };
    task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
  }
  for (auto &task : task_futures) {
    task.wait();
  }
Q
Qiao Longfei 已提交
184
  auto after_recv = GetCurrentUS();
Q
Qiao Longfei 已提交
185
  VLOG(1) << "run recv graph use time " << after_recv - before_send;
Q
Qiao Longfei 已提交
186 187
}

Q
Qiao Longfei 已提交
188
void Communicator::RecvThread() {
Q
Qiao Longfei 已提交
189
  VLOG(3) << "RecvThread start!";
Q
Qiao Longfei 已提交
190
  while (running_) {
191
    auto grad_num = grad_num_.load();
192
    if (grad_num > FLAGS_communicator_min_send_grad_num_before_recv) {
193 194 195 196 197 198
      VLOG(1) << "current grad num " << grad_num;
      RecvAll();
      grad_num_.store(0);
    } else {
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
Q
Qiao Longfei 已提交
199 200 201 202 203
  }
}

// Queues a snapshot of variable `var_name` from `scope` for asynchronous
// sending.
// The variable is copied into a fresh Variable with framework::CopyVariable
// (presumably a deep copy — confirm). The copy is pushed onto that variable's
// send queue, which the tasks launched by SendThread() drain, merge and send.
// Fails via PADDLE_ENFORCE if the variable is not found in `scope` or is not
// initialized. send_varname_to_queue_.at() throws std::out_of_range if
// `var_name` was not registered as a send variable.
void Communicator::Send(const std::string &var_name,
                        const framework::Scope &scope) {
  VLOG(3) << "communicator send " << var_name;
  auto *grad_var = scope.FindVar(var_name);
  // FindVar can return nullptr (presumably for an unknown name); check it
  // before dereferencing instead of crashing on IsInitialized().
  PADDLE_ENFORCE(grad_var != nullptr, "can not find var %s in scope",
                 var_name);
  PADDLE_ENFORCE(grad_var->IsInitialized(), "grad var should be inited");
  auto tmp_grad_var = std::make_shared<Variable>();
  framework::CopyVariable(*grad_var, tmp_grad_var.get());
  // Push the copy into the queue for this variable.
  auto &queue = send_varname_to_queue_.at(var_name);
  VLOG(3) << "send " << var_name << " queue size " << queue->Size();
  queue->Push(tmp_grad_var);
}

Q
can run  
Qiao Longfei 已提交
215 216
// Returns the process-wide Communicator, or nullptr while none has been
// created. Ownership stays with the static communicator_ holder.
Communicator *Communicator::GetInstance() {
  Communicator *instance = communicator_.get();
  return instance;
}

Q
Qiao Longfei 已提交
217
void Communicator::Start() {
Q
Qiao Longfei 已提交
218
  running_ = true;
Q
Qiao Longfei 已提交
219 220 221
  // start send and recv thread
  send_thread_.reset(
      new std::thread(std::bind(&Communicator::SendThread, this)));
Q
Qiao Longfei 已提交
222 223 224 225
  if (FLAGS_communicator_independent_recv_thread) {
    recv_thread_.reset(
        new std::thread(std::bind(&Communicator::RecvThread, this)));
  }
Q
Qiao Longfei 已提交
226 227 228 229 230
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle