/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <stdio.h>  // for removing the port file
#include <csignal>
#include <cstdlib>
#include <fstream>
#include <thread>  // NOLINT
#include <vector>

#include "gflags/gflags.h"

#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/operators/math/math_function.h"

#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed_ops/listen_and_serv_op.h"

#include "paddle/fluid/platform/profiler.h"

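// gflags controlling the default number of threads used to serve the send,
// get and prefetch RPC requests.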
DEFINE_int32(rpc_send_thread_num, 12, "number of threads for rpc send");
DEFINE_int32(rpc_get_thread_num, 12, "number of threads for rpc get");
DEFINE_int32(rpc_prefetch_thread_num, 12, "number of threads for rpc prefetch");

namespace paddle {
namespace operators {

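// Entry point of the server thread: StartServer() returns only after the RPC
// service has been shut down (see ListenAndServOp::Stop).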
void RunServer(std::shared_ptr<distributed::RPCServer> service) {
  service->StartServer();
  VLOG(4) << "RunServer thread end";
}
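
// Splits `str` by `sep` into `pieces`; a trailing empty piece is discarded.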
static void split(const std::string &str, char sep,
                  std::vector<std::string> *pieces) {
  pieces->clear();
  if (str.empty()) {
    return;
  }
  size_t pos = 0;
  size_t next = str.find(sep, pos);
  while (next != std::string::npos) {
    pieces->push_back(str.substr(pos, next - pos));
    pos = next + 1;
    next = str.find(sep, pos);
  }
  if (!str.substr(pos).empty()) {
    pieces->push_back(str.substr(pos));
  }
}

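// Runs the prepared contexts of the given block ids concurrently through
// framework::Async and waits for all of them; a failing block raises a fatal
// error.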
static void ParallelExecuteBlocks(
    const std::vector<size_t> &parallel_blkids, framework::Executor *executor,
    const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
        &prepared,
    framework::ProgramDesc *program, framework::Scope *scope) {
  std::vector<std::future<void>> fs;
  for (size_t idx : parallel_blkids) {
    fs.push_back(framework::Async([&executor, &prepared, &scope, idx]() {
      int run_block = idx;  // thread local
      try {
        VLOG(3) << "running server block: " << run_block
                << "pointer: " << prepared[run_block].get();
        executor->RunPreparedContext(prepared[run_block].get(), scope);
      } catch (const std::exception &e) {
        PADDLE_THROW(platform::errors::Fatal(
            "Run %d-th sub program failed. The exception is:\n%s.", idx,
            e.what()));
      }
    }));
  }
  for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
}

ListenAndServOp::ListenAndServOp(const std::string &type,
                                 const framework::VariableNameMap &inputs,
                                 const framework::VariableNameMap &outputs,
                                 const framework::AttributeMap &attrs)
    : OperatorBase(type, inputs, outputs, attrs) {}

ListenAndServOp::~ListenAndServOp() { Stop(); }

void ListenAndServOp::Stop() {
  rpc_service_->ShutDown();
  server_thread_->join();
  auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
  remove(file_path.c_str());
}

void ListenAndServOp::SavePort() const {
  // NOTE: by default the selected port is written to /tmp/paddle.selected_port
  rpc_service_->SavePort();
}

static int64_t GetTimestamp() {
  struct timeval tp;
  gettimeofday(&tp, NULL);
  return tp.tv_sec * 1000 + tp.tv_usec / 1000;
}

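// Synchronous training loop: after the trainers have fetched the initial
// parameters, each round waits for all trainers to send their gradients, runs
// the optimize blocks (grouped by parent block) in parallel, resets the
// received variables, and then releases the trainers to fetch the updated
// parameters.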
void ListenAndServOp::RunSyncLoop(
    framework::Executor *executor, framework::ProgramDesc *program,
    framework::Scope *recv_scope, platform::DeviceContext *dev_ctx,
    const std::vector<int> &prefetch_block_id_list,
    const int checkpoint_point_block_id) const {
  VLOG(2) << "RunSyncLoop";
  size_t num_blocks = program->Size();
  auto optimize_blocks =
      Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
  PADDLE_ENFORCE_GE(num_blocks, 2,
                    "server program should have at least 2 blocks");

  // Prepare all the server blocks
  std::vector<int> optimize_blocks_list;
  for (size_t i = 1; i < program->Size(); ++i) {
    optimize_blocks_list.push_back(i);
  }
  auto optimize_prepared = executor->Prepare(*program, optimize_blocks_list);
  // Insert placeholder for block0 which holds current op itself,
  // NOTE the first block in `optimize_prepared` should never be run.
  optimize_prepared.insert(
      optimize_prepared.begin(),
      std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));

  // Trainers will get all parameters from the pserver in the
  // startup program, so we wait for RequestGet first
  rpc_service_->SetCond(distributed::kRequestGet);
  rpc_service_->WaitBarrier(distributed::kRequestGet);
  rpc_service_->ResetBarrierCounter();

  while (true) {
    // Get from multiple trainers; we don't care about the order in which
    // the gradients arrive, just add suffix 0~n and merge the gradient.
    VLOG(3) << "wait all clients to send gradient";
    rpc_service_->SetCond(distributed::kRequestSend);
    VLOG(3) << "wait all clients to send send_barrier";
    rpc_service_->WaitBarrier(distributed::kRequestSend);

    if (rpc_service_->IsExit()) {
      LOG(WARNING) << "get exit!rpc_processor break!";
      rpc_service_->SetCond(distributed::kRequestGet);
      break;
    }

    // NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
    // and this will still work.
    // The optimize blocks which have the same parent ID would run in parallel.
    // TODO(Yancey1989): need to use ParallelExecutor in the future
    int32_t last_parent_blkid = optimize_blocks[0]->Parent();
    std::vector<size_t> parallel_blkids;
    parallel_blkids.push_back(optimize_blocks[0]->ID());
    double ts = GetTimestamp();
    for (size_t i = 1; i < optimize_blocks.size(); ++i) {
      // skip the first optimize block because it is already in the
      // parallel_blkids.
      int blkid = optimize_blocks[i]->ID();
      if (program->Block(blkid).Parent() != last_parent_blkid) {
        ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
                              program, recv_scope);
        parallel_blkids.clear();
        last_parent_blkid = program->Block(blkid).Parent();
      }
      parallel_blkids.push_back(blkid);
    }
    ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, program,
                          recv_scope);
    VLOG(3) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";

    VLOG(3) << "ResetReceivedVars";
    ResetReceivedVars(recv_scope, dev_ctx, rpc_service_->NeedResetAllVars());

    VLOG(3) << "wait all clients to get parameters back";
    rpc_service_->SetCond(distributed::kRequestGet);
    VLOG(3) << "wait all clients to send fetch_barrier";
    rpc_service_->WaitBarrier(distributed::kRequestGet);
    VLOG(3) << "ResetBarrierCounter";
    rpc_service_->ResetBarrierCounter();
  }  // while(true)
}

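// Clears the rows of the received SelectedRows (sparse) variables after each
// round; when reset_all is set, the dense variables are zero-filled as well.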
void ListenAndServOp::ResetReceivedVars(framework::Scope *recv_scope,
                                        platform::DeviceContext *dev_ctx,
                                        bool reset_all) const {
  for (auto &varname : sparse_vars_) {
    auto var = recv_scope->FindVar(varname);
    if (var == nullptr) {
      VLOG(2) << "can not find var " << varname << " in received scope";
      continue;
    }
    if (var->IsType<framework::SelectedRows>()) {
      VLOG(3) << "reset sparse var: " << varname;
      var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
    } else {
      PADDLE_THROW("The type of sparse var should be SelectedRows");
    }
  }
  if (UNLIKELY(reset_all)) {
    for (auto &varname : dense_vars_) {
      auto var = recv_scope->FindVar(varname);
      if (var == nullptr) {
        VLOG(2) << "can not find var " << varname << " in received scope";
        continue;
      }
      if (var->IsType<framework::LoDTensor>()) {
        math::set_constant(*dev_ctx, var->GetMutable<framework::LoDTensor>(),
                           static_cast<float>(0));
      } else if (var->IsType<framework::Tensor>()) {
        math::set_constant(*dev_ctx, var->GetMutable<framework::Tensor>(),
                           static_cast<float>(0));
      } else {
        PADDLE_THROW("The type of dense var should be in [LoDTensor, Tensor]");
      }
    }
  }
}

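// Asynchronous mode: builds one prepared execution context per optimize
// block, maps each gradient name to its context and hands the map to the
// send/get/prefetch handlers, which apply updates as requests arrive; this
// loop itself only sleeps until the service exits.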
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
                                   framework::ProgramDesc *program,
                                   framework::Scope *recv_scope) const {
  VLOG(2) << "RunAsyncLoop";
  auto grad_to_block_id_str =
      Attr<std::vector<std::string>>("grad_to_block_id");
  DoubleFindMap<std::string, int32_t> grad_to_block_id;

  auto append_block_maps = [](DoubleFindMap<std::string, int32_t> *out_map,
                              const std::string &grad_and_id) {
    std::vector<std::string> pieces;
    split(grad_and_id, ':', &pieces);
    VLOG(3) << "after split, key = " << pieces[0] << ", id=" << pieces[1];
    PADDLE_ENFORCE_EQ(pieces.size(), 2);
    PADDLE_ENFORCE_EQ(out_map->count(pieces[0]), 0);

    int block_id = std::stoi(pieces[1]);
    (*out_map)[pieces[0]] = block_id;
  };

  for (const auto &grad_and_id : grad_to_block_id_str) {
    append_block_maps(&grad_to_block_id, grad_and_id);
  }

  size_t num_blocks = program->Size();
  PADDLE_ENFORCE_GE(num_blocks, 2,
                    "server program should have at least 2 blocks");

  std::vector<int> block_list;
  for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
    block_list.push_back(blkid);
  }
  auto optimize_prepared = executor->Prepare(*program, block_list);
  // execute the global block if needed; block id 1 in the program is the
  // global block if it is not bound to a grad var for its update.
  if (block_list[0] == 1 &&
      grad_to_block_id.find_value(static_cast<int32_t>(1)) ==
          grad_to_block_id.end()) {
    executor->RunPreparedContext(optimize_prepared[0].get(), recv_scope);
  }
  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>
      grad_to_prepared_ctx, param_to_prepared_ctx;
  for (size_t i = 0; i < block_list.size(); ++i) {
    auto blkid = block_list[i];
    auto it = grad_to_block_id.find_value(blkid);
    if (it != grad_to_block_id.end()) {
      grad_to_prepared_ctx[it->first] = optimize_prepared[i];
    }
  }

  request_send_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
  request_get_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);
  request_prefetch_handler_->SetGradToPreparedCtx(&grad_to_prepared_ctx);

  while (true) {
    if (rpc_service_->IsExit()) {
      VLOG(4) << "get exit!rpc_processor break!";
      break;
    }

    sleep(1);
  }  // while(true)
}

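// Wires the shared server state (scope, device context, executor, program and
// the prepared contexts) into a single RPC request handler.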
static void FillRequestCtx(
    distributed::RequestHandler *h, framework::Scope *scope,
    platform::DeviceContext *dev_ctx, framework::Executor *executor,
    framework::ProgramDesc *program,
    std::unordered_map<std::string,
                       std::shared_ptr<framework::ExecutorPrepareContext>>
        *prefetch_ctx,
    std::unordered_map<std::string, std::string>
        *sparse_grad_name_to_param_name,
    std::shared_ptr<framework::ExecutorPrepareContext> checkpoint_ctx,
    std::shared_ptr<framework::ExecutorPrepareContext> lr_decay_ctx,
    distributed::RPCServer *rpc_server) {
  h->SetScope(scope);
  h->SetDevCtx(dev_ctx);
  h->SetExecutor(executor);
  h->SetProgram(program);
  h->SetPrefetchPreparedCtx(prefetch_ctx);
  h->SetSparseGradToParam(sparse_grad_name_to_param_name);
  h->SetRPCServer(rpc_server);
  h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx);
  h->SetLrDecayPreparedCtx(lr_decay_ctx);
}

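// Classifies the received variables into sparse_vars_ (SelectedRows) and
// dense_vars_ (LoDTensor/Tensor) so that they can be reset between rounds.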
void ListenAndServOp::CacheVarsType(const std::vector<std::string> &varnames,
                                    const framework::Scope &scope) const {
  for (const auto &varname : varnames) {
    auto var = scope.FindVar(varname);
    PADDLE_ENFORCE_NOT_NULL(
        var, platform::errors::PreconditionNotMet(
                 "Received var is not initialized in the received scope."));
    if (var->IsType<framework::SelectedRows>()) {
      sparse_vars_.push_back(varname);
    } else if (var->IsType<framework::LoDTensor>() ||
               var->IsType<framework::Tensor>()) {
      dense_vars_.push_back(varname);
    } else {
      PADDLE_THROW(
          "The type of received var should be in [SelectedRows, LoDTensor, "
          "Tensor].");
    }
  }
}

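// Operator entry point: creates the RPC service, registers the request
// handlers, prepares the checkpoint / lr-decay / prefetch blocks, installs
// the signal handlers, and finally enters either the synchronous or the
// asynchronous serving loop.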
void ListenAndServOp::RunImpl(const framework::Scope &scope,
                              const platform::Place &dev_place) const {
  // Mark this process as a PS so that it decides whether to do profiling by
  // listening to the trainer.
  platform::SetProfileListener();
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto &dev_ctx = *pool.Get(dev_place);
  framework::Scope &recv_scope = scope.NewScope();

  int distributed_mode = Attr<int>("distributed_mode");
  bool dc_sgd = Attr<bool>("dc_asgd");
  auto fan_in = Attr<int>("Fanin");
  auto pserver_id = Attr<int>("pserver_id");
  auto inputs = Inputs("X");

  PADDLE_ENFORCE_EQ(rpc_service_, nullptr,
                    platform::errors::PreconditionNotMet(
                        "RPC service has been created unexpectedly."));
  std::string endpoint = Attr<std::string>("endpoint");
  int checkpoint_block_id = Attr<int>(kCheckpointBlockId);
  int lr_decay_block_id = Attr<int>(kLRDecayBlockId);

  VLOG(4) << "pserver_id: " << pserver_id
          << ", distributed_mode:" << distributed_mode << ", fan_in:" << fan_in
          << ", end_point:" << endpoint
          << ", checkpoint_block_id: " << checkpoint_block_id
          << ", lr_decay_block_id: " << lr_decay_block_id;

  rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));

  auto rpc_get_thread_num = Attr<int>("rpc_get_thread_num");
  auto rpc_send_thread_num = Attr<int>("rpc_send_thread_num");
  auto rpc_prefetch_thread_num = Attr<int>("rpc_prefetch_thread_num");

  request_send_handler_.reset(
      new distributed::RequestSendHandler(distributed_mode, dc_sgd));
  request_get_handler_.reset(
      new distributed::RequestGetHandler(distributed_mode, dc_sgd));
  request_prefetch_handler_.reset(
      new distributed::RequestPrefetchHandler(distributed_mode));
  request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler(
      distributed_mode, checkpoint_block_id));
  request_get_no_barrier_handler_.reset(
      new distributed::RequestGetNoBarrierHandler());
  request_notify_handler_.reset(new distributed::RequestNotifyHandler(
      distributed_mode, lr_decay_block_id));

  rpc_service_->RegisterRPC(distributed::kRequestSend,
                            request_send_handler_.get(), rpc_send_thread_num);
  rpc_service_->RegisterRPC(distributed::kRequestGet,
                            request_get_handler_.get(), rpc_get_thread_num);
  rpc_service_->RegisterRPC(distributed::kRequestPrefetch,
                            request_prefetch_handler_.get(),
                            rpc_prefetch_thread_num);
  rpc_service_->RegisterRPC(distributed::kRequestCheckpoint,
                            request_checkpoint_handler_.get());
  rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier,
                            request_get_no_barrier_handler_.get());
  rpc_service_->RegisterRPC(distributed::kRequestNotify,
                            request_notify_handler_.get(), rpc_send_thread_num);

  auto optimize_blocks =
      Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
  PADDLE_ENFORCE_GE(optimize_blocks.size(), 1,
                    platform::errors::PreconditionNotMet(
                        "optimize blocks is less than 1. Optimize blocks "
                        "should be 1 at least on the pserver side."));
  auto *program = optimize_blocks[0]->Program();
  framework::Executor executor(dev_place);

  std::shared_ptr<framework::ExecutorPrepareContext> ckpt_pre_context = nullptr;
  if (checkpoint_block_id != -1) {
    auto ctx = executor.Prepare(*program, checkpoint_block_id);
    // see: https://stackoverflow.com/a/14856553
    ckpt_pre_context = std::move(ctx);
  }

  std::shared_ptr<framework::ExecutorPrepareContext> lr_decay_context = nullptr;
  if (lr_decay_block_id != -1) {
    auto ctx = executor.Prepare(*program, lr_decay_block_id);
    // see: https://stackoverflow.com/a/14856553
    lr_decay_context = std::move(ctx);
  }

  // prepare for prefetch
  std::vector<int> prefetch_block_id_list;
  std::unordered_map<int, std::string> block_id_to_prefetch_var_name;

  auto prefetch_var_name_to_block_id_str =
      Attr<std::vector<std::string>>(kPrefetchVarNameToBlockId);
  for (const auto &prefetch_var_name_and_id :
       prefetch_var_name_to_block_id_str) {
    std::vector<std::string> pieces;
    split(prefetch_var_name_and_id, ':', &pieces);
    VLOG(3) << "after split, prefetch_var = " << pieces[0]
            << ", id=" << pieces[1];
    PADDLE_ENFORCE_EQ(pieces.size(), 2);

    int block_id = std::stoi(pieces[1]);
    prefetch_block_id_list.push_back(block_id);
    block_id_to_prefetch_var_name[block_id] = pieces[0];
  }

  auto prefetch_prepared = executor.Prepare(*program, prefetch_block_id_list);

  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>
      prefetch_var_name_to_prepared_ctx;
  for (size_t i = 0; i < prefetch_block_id_list.size(); ++i) {
    auto block_id = prefetch_block_id_list[i];
    auto prefetch_var_name = block_id_to_prefetch_var_name[block_id];
    prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i];
  }

  // parse attr of kSparseGradToParam  sparse_grad_name -> param_name
  std::unordered_map<std::string, std::string> sparse_grad_name_to_param_name;
  auto sparse_grad_name_to_param_name_str =
      Attr<std::vector<std::string>>(kSparseGradToParam);
  for (const auto &sparse_grad_name_and_param_name :
       sparse_grad_name_to_param_name_str) {
    std::vector<std::string> pieces;
    split(sparse_grad_name_and_param_name, ':', &pieces);
    PADDLE_ENFORCE_EQ(pieces.size(), 2);
    VLOG(3) << "after split, sparse_grad_name = " << pieces[0]
            << ", param_name = " << pieces[1];
    sparse_grad_name_to_param_name[pieces[0]] = pieces[1];
  }

  auto f =
      std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx,
                &executor, program, &prefetch_var_name_to_prepared_ctx,
                &sparse_grad_name_to_param_name, ckpt_pre_context,
                lr_decay_context, rpc_service_.get());

  f(request_send_handler_.get());
  f(request_get_handler_.get());
  f(request_prefetch_handler_.get());
  f(request_checkpoint_handler_.get());
  f(request_get_no_barrier_handler_.get());
  f(request_notify_handler_.get());

  // register SIGINT(from ctrl+C) and SIGTERM(from kill) signal handlers
  signal(SIGINT, SignalHandler::StopAndExit);
  signal(SIGTERM, SignalHandler::StopAndExit);

  if (distributed_mode == distributed::DistributedMode::kSync) {
    // start the server listening after all member initialized.
    server_thread_.reset(new std::thread(RunServer, rpc_service_));
    VLOG(3) << "wait server thread to become ready...";
    rpc_service_->WaitServerReady();

    CacheVarsType(inputs, recv_scope);

    // Write to a file of server selected port for python use.
    SavePort();

    RunSyncLoop(&executor, program, &recv_scope, &dev_ctx,
                prefetch_block_id_list, checkpoint_block_id);
  } else {
    if (distributed_mode == distributed::DistributedMode::kGeo) {
      distributed::AsyncSparseParamUpdateRecorder::Init(
          fan_in, sparse_grad_name_to_param_name);
    }

    VLOG(2) << "RunAsyncLoop";
    auto grad_to_block_id_str =
        Attr<std::vector<std::string>>("grad_to_block_id");

    if (grad_to_block_id_str.size() == 0) {
      VLOG(0) << "there are no gradients on this parameter server";
    } else {
      std::vector<std::string> pieces;
      split(grad_to_block_id_str[0], ':', &pieces);
      distributed::HeartBeatMonitor::Init(fan_in, pserver_id == 0, pieces[0]);
    }

    // start the server listening after all member initialized.
    server_thread_.reset(new std::thread(RunServer, rpc_service_));
    VLOG(3) << "wait server thread to become ready...";
    rpc_service_->WaitServerReady();

    // Write to a file of server selected port for python use.
    SavePort();

    RunAsyncLoop(&executor, program, &recv_scope);
  }
}

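// Describes the operator's input ("X") and attributes (endpoint, optimize
// blocks, grad/prefetch/sparse-grad mappings, thread numbers, etc.) for the
// op registry.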
class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "(Tensor) Variables that server recv.").AsDuplicable();
    AddComment(R"DOC(" + "ListenAndServ operator" + "\n" + "This operator" +
" will start a RPC server which can receive variables from send_op and send" +
"back variables to recv_op.)DOC");
    AddAttr<std::string>("endpoint",
                         "(string, default 127.0.0.1:6164)"
                         "IP address to listen on.")
        .SetDefault("127.0.0.1:6164")
        .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
    AddAttr<int>("pserver_id",
                 "(int, default -1), the parameter server index id")
        .SetDefault(-1);
    AddAttr<std::vector<std::string>>(
        "grad_to_block_id",
        "['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] "
        "a map from grad name to it's optimize block id")
        .SetDefault({});
    AddAttr<int>("distributed_mode",
                 "indicate distriubte training mode, 0 is sync, 1 is "
                 "fully-async, 2 is half-async, 3 is geo")
        .SetDefault(0);
    AddAttr<bool>("dc_asgd", "set to true will enable DC-ASGD training.")
        .SetDefault(false);
    AddAttr<std::vector<framework::BlockDesc *>>(
        kOptimizeBlocks, "Optimize blocks to run on server side.")
        .SetDefault({});
    AddAttr<std::vector<std::string>>(kPrefetchVarNameToBlockId,
                                      "prefetch blocks to run on server side.")
        .SetDefault({});
    AddAttr<std::vector<std::string>>(
        kSparseGradToParam,
        "sparse grad name to param name. like: 'emb@Grad:emb'")
        .SetDefault({});
    AddAttr<int>("Fanin", "How many clients send to this server.")
        .SetDefault(1);
    AddAttr<int>(kCheckpointBlockId,
                 "BolckID to run save checkpoint on pserer.")
        .SetDefault(-1);
    AddAttr<int>(kLRDecayBlockId, "BlockID to run lr decay on pserver.")
        .SetDefault(-1);
    AddAttr<int>("rpc_get_thread_num", "pserver get thread num.").SetDefault(1);
    AddAttr<int>("rpc_send_thread_num", "pserver send thread num.")
        .SetDefault(1);
    AddAttr<int>("rpc_prefetch_thread_num", "pserver prefetch thread num.")
        .SetDefault(1);
  }
};

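// SIGINT/SIGTERM handler: removes the saved port file and terminates the
// process.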
void SignalHandler::StopAndExit(int signal_num) {
  // Do not use VLOG here, because the device used for printing may already
  // have been released. exit() will release internally allocated resources.
  auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
  remove(file_path.c_str());
  exit(0);
}

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(listen_and_serv, ops::ListenAndServOp,
                  ops::ListenAndServOpMaker);