request_handler.h 8.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <time.h>

#include <condition_variable>  // NOLINT
#include <cstdint>
#include <functional>
#include <memory>
#include <mutex>  // NOLINT
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace operators {
38
namespace distributed {
39 40 41

// Request-type identifiers. Presumably used as keys when registering and
// looking up RequestHandlers on the RPC server -- confirm against rpc_server.
// NOTE: kRequestGetNoBarrier's value intentionally differs from its name
// pattern ("GetVariableNoBarrier"); do not "fix" it, the string is on the wire
// / used as a lookup key elsewhere.
constexpr char kRequestSend[] = "RequestSend";
constexpr char kRequestGet[] = "RequestGet";
constexpr char kRequestGetMonomerVariable[] = "RequestGetMonomerVariable";
constexpr char kRequestGetMonomerBarrier[] = "RequestGetMonomerBarrier";
constexpr char kRequestPrefetch[] = "RequestPrefetch";
constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
constexpr char kRequestGetNoBarrier[] = "GetVariableNoBarrier";
constexpr char kRequestNotify[] = "RequestNotify";

// Client-side RPC call names (presumably used as method labels, e.g. in
// VarHandle::method() / profiling -- TODO confirm).
constexpr char kSendRPC[] = "SendRPC";
constexpr char kGetRPC[] = "GetRPC";
constexpr char kGetNoBarrierRPC[] = "GetNoBarrierRPC";
constexpr char kGetMonomerRPC[] = "GetMonomerRPC";
constexpr char kPrefetchRPC[] = "PrefetchRPC";
constexpr char kBatchBarrierRPC[] = "BatchBarrierRPC";
constexpr char kFetchBarrierRPC[] = "FetchBarrierRPC";
constexpr char kSendMonomerFetchBarrierRPC[] = "SendMonomerFetchBarrierRPC";
constexpr char kSendCompleteRPC[] = "SendCompleteRPC";
constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";

// Prefetch timeout. Units are not stated here; presumably milliseconds
// (60000 == one minute) -- confirm at the use site.
constexpr int64_t kPrefetchTimeout = 60000;

// Special variable names / name suffixes exchanged over RPC. These are kept
// as macros (not constexpr) because callers may rely on string-literal
// concatenation or on using them in preprocessor contexts.
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"
#define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@"

// Checkpoint-notify actions.
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"

// Training mode of the distributed job. Passed to RequestHandler as a plain
// int (see RequestHandler::distributed_mode()), so this stays an unscoped
// enum with explicit values.
enum DistributedMode { kSync = 0, kAsync = 1, kHalfAsync = 2, kGeo = 3 };

class RPCServer;

class VarHandle {
 public:
  VarHandle(const std::string ep, const std::string& method,
            const std::string& name,
            const platform::DeviceContext* p_ctx = nullptr,
Q
Qiao Longfei 已提交
81
            const framework::Scope* p_scope = nullptr)
G
gongweibao 已提交
82
      : status_(kDefaultState) {
83 84 85 86 87 88 89 90 91 92
    ep_ = ep;
    ctx_ = p_ctx;
    scope_ = p_scope;
    name_ = name;
    method_ = method;
  }

  virtual ~VarHandle() {}

 public:
93 94
  bool should_retry = false;

95
  bool Wait() {
G
gongweibao 已提交
96
    int ret = kDefaultState;
97 98
    {
      std::unique_lock<std::mutex> lk(sync_mutex_);
G
gongweibao 已提交
99 100
      wait_cond_.wait(lk, [this] { return status_ != kDefaultState; });
      ret = status_;
101
    }
M
minqiyang 已提交
102
    VLOG(7) << "VarHandle wait:" << ret;
G
gongweibao 已提交
103
    return ret != kErrorState;
104 105 106 107 108
  }

  void Finish(bool ok) {
    {
      std::unique_lock<std::mutex> lk(sync_mutex_);
G
gongweibao 已提交
109
      status_ = ok ? kFinishState : kErrorState;
110
    }
M
minqiyang 已提交
111
    VLOG(7) << "VarHandle finish:" << ok;
112 113
    wait_cond_.notify_all();
  }
114 115 116

  std::string String() const {
    std::ostringstream s;
G
gongweibao 已提交
117 118
    s << method_ << " name:[" << name_ << "], ep:[" << ep_ << "], status:["
      << status_ << "]";
119 120
    return s.str();
  }
121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141

  std::string ep() const { return ep_; }
  const platform::DeviceContext* ctx() const { return ctx_; }
  const framework::Scope* scope() const { return scope_; }
  std::string name() const { return name_; }
  std::string method() const { return method_; }

 protected:
  // RPC endpoint.
  std::string ep_;
  const platform::DeviceContext* ctx_;
  const framework::Scope* scope_;
  // Variable name.
  std::string name_;
  // RPC method name.
  std::string method_;

 protected:
  std::mutex sync_mutex_;
  std::condition_variable wait_cond_;

G
gongweibao 已提交
142 143 144 145 146 147
  enum VarHandleStatus {
    kDefaultState = -1,
    kErrorState = 0,
    kFinishState = 1,
  };
  VarHandleStatus status_;
148 149 150

 private:
  DISABLE_COPY_AND_ASSIGN(VarHandle);
151 152
};

153 154
typedef std::shared_ptr<VarHandle> VarHandlePtr;

155 156
class RequestHandler {
 public:
1
123malin 已提交
157 158
  explicit RequestHandler(int distributed_mode)
      : distributed_mode_(distributed_mode),
159 160 161 162 163 164 165 166 167 168 169 170 171
        dev_ctx_(nullptr),
        executor_(nullptr),
        scope_(nullptr),
        program_(nullptr),
        rpc_server_(nullptr) {}

  virtual ~RequestHandler() {}

  // Set attributes.
  void SetScope(framework::Scope* scope) { scope_ = scope; }
  void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
  void SetProgram(framework::ProgramDesc* program) { program_ = program; }
  void SetExecutor(framework::Executor* executor) { executor_ = executor; }
172 173

  // Used for dist lookup table prefetch
174
  void SetPrefetchPreparedCtx(
175 176 177
      std::unordered_map<
          std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
    prefetch_var_name_to_prepared_ctx_ = g;
178 179
  }

T
tangwei12 已提交
180 181 182 183 184
  void SetCheckpointNotifyPreparedCtx(
      std::shared_ptr<framework::ExecutorPrepareContext> g) {
    checkpoint_prepared_ctx_ = g;
  }

185 186 187 188 189 190 191
  // Used for async.
  void SetGradToPreparedCtx(
      std::unordered_map<
          std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
    grad_to_prepared_ctx_ = g;
  }

192 193 194 195
  void SetSparseGradToParam(std::unordered_map<std::string, std::string>* g) {
    sparse_grad_to_param_ = g;
  }

196 197 198 199 200
  void SetLrDecayPreparedCtx(
      std::shared_ptr<framework::ExecutorPrepareContext> g) {
    lr_decay_prepared_ctx_ = g;
  }

201 202 203
  void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; }

  // Get attributes.
1
123malin 已提交
204
  int distributed_mode() { return distributed_mode_; }
205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224
  framework::Scope* scope() { return scope_; }
  const platform::DeviceContext* dev_ctx() { return dev_ctx_; }
  framework::ProgramDesc* program() { return program_; }
  framework::Executor* executor() { return executor_; }

  // This function processes user's rpc request.
  // The implemention is in request_handler_impl.
  // example:
  //    std::string varname = request_.varname();
  //
  //    auto scope = request_handler_->scope();
  //    auto invar = scope->FindVar(varname);
  //    framework::Variable* outvar = nullptr;
  //
  //    request_handler_->Handle(varname, scope, invar, &outvar);
  //    if (outvar) {
  //        SerializeToByteBuffer(varname, outvar,
  //           *request_handler_->dev_ctx(), &reply_);
  //    }
  virtual bool Handle(const std::string& varname, framework::Scope* scope,
Q
qiaolongfei 已提交
225
                      framework::Variable* var, framework::Variable** outvar,
W
Wu Yi 已提交
226
                      const int trainer_id,
Q
Qiao Longfei 已提交
227 228
                      const std::string& out_var_name = "",
                      const std::string& table_name = "") = 0;
229 230

 protected:
1
123malin 已提交
231
  const int distributed_mode_;
232 233 234 235 236

  const platform::DeviceContext* dev_ctx_;
  framework::Executor* executor_;
  framework::Scope* scope_;
  framework::ProgramDesc* program_;
237 238 239 240 241

  // used for distribute lookup table prefetch
  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>*
      prefetch_var_name_to_prepared_ctx_;
T
tangwei12 已提交
242 243
  // used for checkpoint notify
  std::shared_ptr<framework::ExecutorPrepareContext> checkpoint_prepared_ctx_;
244 245 246 247 248

  // Used for async.
  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>*
      grad_to_prepared_ctx_;
249
  std::unordered_map<std::string, std::string>* sparse_grad_to_param_;
250

251 252
  // used for lr decay
  std::shared_ptr<framework::ExecutorPrepareContext> lr_decay_prepared_ctx_;
253 254 255
  RPCServer* rpc_server_;
};

256
}  // namespace distributed
257 258
}  // namespace operators
}  // namespace paddle