/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <stdio.h>  // for removing the port file
#include <fstream>
#include <ostream>
#include <thread>  // NOLINT
#include <vector>

#include "paddle/fluid/operators/listen_and_serv_op.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace operators {

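// Runs the gRPC service's blocking update loop; intended to run in the
// dedicated server thread started by ListenAndServOp.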
void RunServer(std::shared_ptr<detail::AsyncGRPCServer> service) {
  service->RunSyncUpdate();
  VLOG(4) << "RunServer thread end";
}

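// Splits `str` on `sep` into `pieces`; a trailing empty segment is dropped.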
static void split(const std::string &str, char sep,
                  std::vector<std::string> *pieces) {
  pieces->clear();
  if (str.empty()) {
    return;
  }
  size_t pos = 0;
  size_t next = str.find(sep, pos);
  while (next != std::string::npos) {
    pieces->push_back(str.substr(pos, next - pos));
    pos = next + 1;
    next = str.find(sep, pos);
  }
  if (!str.substr(pos).empty()) {
    pieces->push_back(str.substr(pos));
  }
}

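// Runs the prepared contexts of the given block ids concurrently through the
// framework's async thread pool and waits for all of them to finish.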
static void ParallelExecuteBlocks(
    const std::vector<size_t> &parallel_blkids, framework::Executor *executor,
    const std::vector<std::shared_ptr<framework::ExecutorPrepareContext>>
        &prepared,
    framework::ProgramDesc *program, framework::Scope *scope) {
  std::vector<std::future<void>> fs;
  for (size_t idx : parallel_blkids) {
    fs.push_back(
        framework::Async([&executor, &prepared, &program, &scope, idx]() {
          int run_block = idx;  // thread local
          try {
            executor->RunPreparedContext(prepared[run_block].get(), scope);
          } catch (std::exception &e) {
            LOG(ERROR) << "run sub program error " << e.what();
          }
        }));
  }
  for (size_t i = 0; i < fs.size(); ++i) fs[i].wait();
}

std::atomic_int ListenAndServOp::selected_port_{0};

ListenAndServOp::ListenAndServOp(const std::string &type,
                                 const framework::VariableNameMap &inputs,
                                 const framework::VariableNameMap &outputs,
                                 const framework::AttributeMap &attrs)
    : OperatorBase(type, inputs, outputs, attrs) {}

void ListenAndServOp::Stop() {
  rpc_service_->Push(LISTEN_TERMINATE_MESSAGE);
  server_thread_->join();
  auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
  remove(file_path.c_str());
}

void ListenAndServOp::SavePort() const {
  // NOTE: the selected port is written to /tmp/paddle.<pid>.port by default.
  selected_port_ = rpc_service_->GetSelectedPort();
  auto file_path = string::Sprintf("/tmp/paddle.%d.port", ::getpid());
  std::ofstream port_file;
  port_file.open(file_path);
  port_file << selected_port_.load();
  port_file.close();
  VLOG(4) << "selected port written to " << file_path;
}

void ListenAndServOp::WaitServerReady() {
  while (selected_port_.load() == 0) {
  }
}

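// Synchronous loop: wait for gradients plus a batch barrier from all `Fanin`
// trainers, run the optimize blocks (in parallel where they share a parent
// block), then signal clients that updated parameters can be fetched.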
void ListenAndServOp::RunSyncLoop(framework::Executor *executor,
                                  framework::ProgramDesc *program,
                                  framework::Scope *recv_scope,
                                  framework::BlockDesc *prefetch_block) const {
  auto fan_in = Attr<int>("Fanin");

  size_t num_blocks = program->Size();
  PADDLE_ENFORCE_GE(num_blocks, 2,
                    "server program should have at least 2 blocks");

  std::vector<int> block_list;
  for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
    block_list.push_back(blkid);
  }
  auto optimize_prepared = executor->Prepare(*program, block_list);
  // Insert a placeholder for block 0, which holds the current op itself.
  optimize_prepared.insert(
      optimize_prepared.begin(),
      std::shared_ptr<framework::ExecutorPrepareContext>(nullptr));

  bool exit_flag = false;
  // Record received sparse variables so that we can reset them after
  // the optimize program has been executed.
  std::vector<framework::Variable *> sparse_vars;
  while (!exit_flag) {
    // Receive from multiple trainers; we don't care about the order in which
    // the gradients arrive, just add suffix 0~n and merge the gradients.
    rpc_service_->SetCond(0);
    size_t recv_var_cnt = 0;
    int batch_barrier = 0;
    while (batch_barrier != fan_in) {
      const detail::ReceivedMessage v = rpc_service_->Get();
      auto recv_var_name = v.first;
      if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
        LOG(INFO) << "received terminate message and exit";
        exit_flag = true;
        break;
      } else if (recv_var_name == BATCH_BARRIER_MESSAGE) {
        VLOG(3) << "recv batch barrier message";
        batch_barrier++;
        continue;
      } else {
        VLOG(3) << "received grad: " << recv_var_name;
        recv_var_cnt++;
        auto var = v.second->GetVar();
        if (var == nullptr) {
          LOG(ERROR) << "Can not find server side var: " << recv_var_name;
          PADDLE_THROW("Can not find server side var");
        }
        if (var->IsType<framework::SelectedRows>()) {
          sparse_vars.push_back(var);
        }
      }
    }
    if (exit_flag) {
      rpc_service_->SetCond(1);
      rpc_service_->ShutDown();
      break;
    }

    // NOTE: if is_gpu_place, CUDA kernels are launched by multiple threads
    // and this will still work.

    // Optimize blocks that share the same parent ID run in parallel.
    // TODO(Yancey1989): use ParallelExecutor for this in the future.
    int32_t last_parent_blkid = program->Block(1).Parent();
    std::vector<size_t> parallel_blkids;
    parallel_blkids.push_back(1);
    double ts = detail::GetTimestamp();
    for (size_t blkid = 2; blkid < num_blocks; ++blkid) {
      if (blkid != static_cast<size_t>(prefetch_block->ID())) {
        if (program->Block(blkid).Parent() != last_parent_blkid) {
          ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared,
                                program, recv_scope);
          parallel_blkids.clear();
          last_parent_blkid = program->Block(blkid).Parent();
        }
        parallel_blkids.push_back(blkid);
      }
    }
    ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, program,
                          recv_scope);
    VLOG(2) << "run all blocks spent " << detail::GetTimestamp() - ts << "(ms)";

    // Reset the received sparse variables so that the sum operator will not
    // sum input sparse variables whose rows are empty in the next
    // mini-batch.
    // TODO(Yancey1989): move the reset action into an operator; we shouldn't
    // have any hidden logic inside the operator.
    for (auto &var : sparse_vars) {
      var->GetMutable<framework::SelectedRows>()->mutable_rows()->clear();
    }
    rpc_service_->SetCond(1);
    // FIXME(typhoonzero): use another condition to sync wait clients get.
    rpc_service_->WaitClientGet(fan_in);
    sparse_vars.clear();
  }  // while(true)
}

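// Worker for a single gradient: pops received messages from `queue` and runs
// the corresponding prepared optimize block until `exit_flag` is set.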
static void AsyncUpdateThread(
    const std::string &var_name, const bool &exit_flag,
    const std::shared_ptr<detail::ReceivedQueue> &queue,
    framework::Executor *executor,
    framework::ExecutorPrepareContext *prepared) {
  VLOG(3) << "update thread for " << var_name << " started";
  while (!exit_flag) {
    const detail::ReceivedMessage v = queue->Pop();
    auto recv_var_name = v.first;
    auto var = v.second->GetVar();
    if (var == nullptr) {
      LOG(ERROR) << "Can not find server side var: " << recv_var_name;
      PADDLE_THROW("Can not find server side var");
    }
    auto fs = framework::Async([var_name, &executor, &v, prepared] {
      try {
        executor->RunPreparedContext(prepared,
                                     v.second->GetMutableLocalScope());
      } catch (std::exception &e) {
        LOG(ERROR) << "run sub program error " << e.what();
      }
    });
    fs.wait();
  }
}

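// Asynchronous loop: start one update thread per gradient, then dispatch each
// received gradient to its thread's queue without waiting on a batch barrier.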
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
                                   framework::ProgramDesc *program) const {
  VLOG(3) << "RunAsyncLoop in";
  // grad name to block id
  std::unordered_map<std::string, int32_t> grad_to_block_id;
  std::unordered_map<int32_t, std::string> id_to_grad;
  std::unordered_map<std::string, std::shared_ptr<detail::ReceivedQueue>>
      grad_to_queue;

  auto grad_to_block_id_str =
      Attr<std::vector<std::string>>("grad_to_block_id");
  for (auto &grad_and_id : grad_to_block_id_str) {
    std::vector<std::string> pieces;
    split(grad_and_id, ':', &pieces);
    VLOG(3) << "after split, grad = " << pieces[0] << ", id=" << pieces[1];
    PADDLE_ENFORCE_EQ(pieces.size(), 2);
    PADDLE_ENFORCE_EQ(grad_to_block_id.count(pieces[0]), 0);
    int block_id = std::stoi(pieces[1]);
    grad_to_block_id[pieces[0]] = block_id;
    grad_to_queue[pieces[0]] = std::make_shared<detail::ReceivedQueue>();
    id_to_grad[block_id] = pieces[0];
  }
  size_t num_blocks = program->Size();
  PADDLE_ENFORCE_GE(num_blocks, 2,
                    "server program should have at least 2 blocks");

  std::vector<int> block_list;
  for (size_t blkid = 1; blkid < num_blocks; ++blkid) {
    block_list.push_back(blkid);
  }
  auto optimize_prepared = executor->Prepare(*program, block_list);
  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>
      grad_to_prepared_ctx;
  for (size_t i = 0; i < block_list.size(); ++i) {
    grad_to_prepared_ctx[id_to_grad[block_list[i]]] = optimize_prepared[i];
  }

  bool exit_flag = false;

  VLOG(3) << "start async optimize threads";
  std::vector<std::future<void>> fs;
  for (auto iter = grad_to_queue.begin(); iter != grad_to_queue.end(); iter++) {
    std::string grad_name = iter->first;
    VLOG(3) << "create async update thread for " << grad_name;
    fs.push_back(framework::AsyncIO([grad_name, &exit_flag, &executor,
                                     &grad_to_queue, &grad_to_prepared_ctx]() {
      AsyncUpdateThread(grad_name, exit_flag, grad_to_queue[grad_name],
                        executor, grad_to_prepared_ctx[grad_name].get());
    }));
  }

  VLOG(3) << "RunAsyncLoop into while";
  while (!exit_flag) {
    const detail::ReceivedMessage v = rpc_service_->Get();
    auto recv_var_name = v.first;
    if (recv_var_name == LISTEN_TERMINATE_MESSAGE) {
      LOG(INFO) << "received terminate message and exit";
      exit_flag = true;
      break;
    } else {
      VLOG(3) << "received grad: " << recv_var_name;
      grad_to_queue[recv_var_name]->Push(v);
    }

    if (exit_flag) {
      rpc_service_->ShutDown();
      break;
    }
  }  // while(true)
}

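// Creates the gRPC server, prepares the optimize and prefetch blocks, writes
// the selected port for the Python side, and enters the sync or async serving
// loop depending on the "sync_mode" attribute.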
void ListenAndServOp::RunImpl(const framework::Scope &scope,
                              const platform::Place &dev_place) const {
  // Mark this node as a parameter server: whether to profile is decided by
  // requests received from the trainer.
  platform::SetProfileListener();
  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
  auto &dev_ctx = *pool.Get(dev_place);
  framework::Scope &recv_scope = scope.NewScope();

  bool sync_mode = Attr<bool>("sync_mode");

  PADDLE_ENFORCE(!rpc_service_);
  std::string endpoint = Attr<std::string>("endpoint");

  rpc_service_.reset(new detail::AsyncGRPCServer(endpoint, sync_mode));

  auto *optimize_block = Attr<framework::BlockDesc *>(kOptimizeBlock);
  auto *prefetch_block = Attr<framework::BlockDesc *>(kPrefetchBlock);
  auto *program = optimize_block->Program();
  framework::Executor executor(dev_place);

  // prepare rpc_service
  rpc_service_->SetScope(&recv_scope);
  rpc_service_->SetDevCtx(&dev_ctx);
  rpc_service_->SetProgram(program);
  rpc_service_->SetExecutor(&executor);

  // prepare for prefetch
  VLOG(3) << "prefetch block id is " << prefetch_block->ID();
  auto prefetch_prepared = executor.Prepare(*program, prefetch_block->ID());
  rpc_service_->SetPrefetchPreparedCtx(std::move(prefetch_prepared));

  // Start the server listening only after all members are initialized.
  server_thread_.reset(new std::thread(RunServer, rpc_service_));
  VLOG(3) << "wait server thread to become ready...";
  rpc_service_->WaitServerReady();

  // Write the server's selected port to a file so the Python side can read it.
  SavePort();
  if (sync_mode) {
    RunSyncLoop(&executor, program, &recv_scope, prefetch_block);
  } else {
    RunAsyncLoop(&executor, program);
  }
}

class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "(Tensor) Variables that server recv.").AsDuplicable();
    AddComment(R"DOC(
ListenAndServ operator

This operator will start an RPC server which can receive variables
from send_op and send back variables to recv_op.
)DOC");
    AddAttr<std::string>("endpoint",
                         "(string, default 127.0.0.1:6164)"
                         "IP address to listen on.")
        .SetDefault("127.0.0.1:6164")
        .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
    AddAttr<std::vector<std::string>>(
        "grad_to_block_id",
        "['param1@GRAD.block0:1', 'param2@GRAD.blockn:2'] "
        "a map from grad name to it's optimize block id")
        .SetDefault({});
    AddAttr<bool>("sync_mode", "if works at sync_mode or not").SetDefault(true);
    AddAttr<framework::BlockDesc *>(kOptimizeBlock,
                                    "BlockID to run on server side.");
    AddAttr<framework::BlockDesc *>(kPrefetchBlock,
                                    "prefetch block to run on server side.");
    AddAttr<int>("Fanin", "How many clients send to this server.")
        .SetDefault(1);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(listen_and_serv, ops::ListenAndServOp,
                  ops::ListenAndServOpMaker);