request_handler.h 8.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <time.h>

#include <condition_variable>  // NOLINT
#include <functional>
#include <memory>
#include <mutex>  // NOLINT
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/macros.h"
35 36 37

namespace paddle {
namespace operators {
38
namespace distributed {
39 40 41

// Request names, one per kind of request.
// NOTE(review): presumably the keys under which RequestHandler
// implementations are registered with the RPC server -- confirm against the
// server code.
constexpr char kRequestSend[] = "RequestSend";
constexpr char kRequestGet[] = "RequestGet";
constexpr char kRequestGetMonomerVariable[] = "RequestGetMonomerVariable";
constexpr char kRequestGetMonomerBarrier[] = "RequestGetMonomerBarrier";
constexpr char kRequestPrefetch[] = "RequestPrefetch";
constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
// NOTE(review): this value does not follow the "Request..." pattern of its
// siblings. It is left as-is because the string is runtime data that other
// code may match on; verify every user before normalizing it.
constexpr char kRequestGetNoBarrier[] = "GetVariableNoBarrier";
constexpr char kRequestNotify[] = "RequestNotify";
constexpr char kRequestSendAndRecv[] = "RequestSendAndRecv";

// RPC method names.
// NOTE(review): presumably the values passed as VarHandle's `method`
// argument -- confirm against the RPC client code.
constexpr char kSendRPC[] = "SendRPC";
constexpr char kGetRPC[] = "GetRPC";
constexpr char kGetNoBarrierRPC[] = "GetNoBarrierRPC";
constexpr char kGetMonomerRPC[] = "GetMonomerRPC";
constexpr char kPrefetchRPC[] = "PrefetchRPC";
constexpr char kBatchBarrierRPC[] = "BatchBarrierRPC";
constexpr char kFetchBarrierRPC[] = "FetchBarrierRPC";
constexpr char kSendMonomerFetchBarrierRPC[] = "SendMonomerFetchBarrierRPC";
constexpr char kSendCompleteRPC[] = "SendCompleteRPC";
constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
constexpr char kSendAndRecvRPC[] = "SendAndRecvRPC";

// Timeout applied to prefetch requests.
// NOTE(review): the unit is not stated in this header; presumably
// milliseconds -- confirm at the point of use.
constexpr int64_t kPrefetchTimeout = 60000;
63

G
gongweibao 已提交
64 65 66
// Special marker strings.
// NOTE(review): presumably sent in place of an ordinary variable name to
// signal control events (terminate, barriers, completion) -- confirm against
// the send/recv code paths.
// These stay as macros so the interface is unchanged: call sites may rely on
// preprocessor-level use such as adjacent string-literal concatenation.
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"

// Names of counters.
// NOTE(review): presumably variable names looked up in a Scope -- confirm.
#define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@"
#define STEP_COUNTER "@PS_STEP_COUNTER@"

// Marker strings for checkpoint notifications (save / load).
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
T
tangwei12 已提交
74

1
123malin 已提交
75 76
// Identifiers for the distributed training mode.
// NOTE(review): presumably the values passed as RequestHandler's
// `int distributed_mode` constructor argument -- confirm against callers.
// The numeric values are written out explicitly so they cannot shift if an
// enumerator is inserted or reordered.
enum DistributedMode {
  kSync = 0,
  kAsync = 1,
  kHalfAsync = 2,
  kGeo = 3
};

77 78
// Forward declaration: this header only stores and passes RPCServer* (see
// RequestHandler::SetRPCServer / rpc_server_), so the full definition is not
// needed here.
class RPCServer;

79 80 81 82 83
class VarHandle {
 public:
  VarHandle(const std::string ep, const std::string& method,
            const std::string& name,
            const platform::DeviceContext* p_ctx = nullptr,
Q
Qiao Longfei 已提交
84
            const framework::Scope* p_scope = nullptr)
G
gongweibao 已提交
85
      : status_(kDefaultState) {
86 87 88 89 90 91 92 93 94 95
    ep_ = ep;
    ctx_ = p_ctx;
    scope_ = p_scope;
    name_ = name;
    method_ = method;
  }

  virtual ~VarHandle() {}

 public:
96 97
  bool should_retry = false;

98
  bool Wait() {
G
gongweibao 已提交
99
    int ret = kDefaultState;
100 101
    {
      std::unique_lock<std::mutex> lk(sync_mutex_);
G
gongweibao 已提交
102 103
      wait_cond_.wait(lk, [this] { return status_ != kDefaultState; });
      ret = status_;
104
    }
M
minqiyang 已提交
105
    VLOG(7) << "VarHandle wait:" << ret;
G
gongweibao 已提交
106
    return ret != kErrorState;
107 108 109 110 111
  }

  void Finish(bool ok) {
    {
      std::unique_lock<std::mutex> lk(sync_mutex_);
G
gongweibao 已提交
112
      status_ = ok ? kFinishState : kErrorState;
113
    }
M
minqiyang 已提交
114
    VLOG(7) << "VarHandle finish:" << ok;
115 116
    wait_cond_.notify_all();
  }
117 118 119

  std::string String() const {
    std::ostringstream s;
G
gongweibao 已提交
120 121
    s << method_ << " name:[" << name_ << "], ep:[" << ep_ << "], status:["
      << status_ << "]";
122 123
    return s.str();
  }
124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144

  std::string ep() const { return ep_; }
  const platform::DeviceContext* ctx() const { return ctx_; }
  const framework::Scope* scope() const { return scope_; }
  std::string name() const { return name_; }
  std::string method() const { return method_; }

 protected:
  // RPC endpoint.
  std::string ep_;
  const platform::DeviceContext* ctx_;
  const framework::Scope* scope_;
  // Variable name.
  std::string name_;
  // RPC method name.
  std::string method_;

 protected:
  std::mutex sync_mutex_;
  std::condition_variable wait_cond_;

G
gongweibao 已提交
145 146 147 148 149 150
  enum VarHandleStatus {
    kDefaultState = -1,
    kErrorState = 0,
    kFinishState = 1,
  };
  VarHandleStatus status_;
151 152 153

 private:
  DISABLE_COPY_AND_ASSIGN(VarHandle);
154 155
};

156 157
// Shared-ownership pointer to a VarHandle.
using VarHandlePtr = std::shared_ptr<VarHandle>;

158 159
class RequestHandler {
 public:
1
123malin 已提交
160 161
  explicit RequestHandler(int distributed_mode)
      : distributed_mode_(distributed_mode),
162 163 164 165 166 167 168 169 170 171 172 173 174
        dev_ctx_(nullptr),
        executor_(nullptr),
        scope_(nullptr),
        program_(nullptr),
        rpc_server_(nullptr) {}

  virtual ~RequestHandler() {}

  // Set attributes.
  void SetScope(framework::Scope* scope) { scope_ = scope; }
  void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
  void SetProgram(framework::ProgramDesc* program) { program_ = program; }
  void SetExecutor(framework::Executor* executor) { executor_ = executor; }
175 176

  // Used for dist lookup table prefetch
177
  void SetPrefetchPreparedCtx(
178 179 180
      std::unordered_map<
          std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
    prefetch_var_name_to_prepared_ctx_ = g;
181 182
  }

T
tangwei12 已提交
183 184 185 186 187
  void SetCheckpointNotifyPreparedCtx(
      std::shared_ptr<framework::ExecutorPrepareContext> g) {
    checkpoint_prepared_ctx_ = g;
  }

188 189 190 191 192 193 194
  // Used for async.
  void SetGradToPreparedCtx(
      std::unordered_map<
          std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
    grad_to_prepared_ctx_ = g;
  }

195 196 197 198
  void SetSparseGradToParam(std::unordered_map<std::string, std::string>* g) {
    sparse_grad_to_param_ = g;
  }

199 200 201 202 203
  void SetLrDecayPreparedCtx(
      std::shared_ptr<framework::ExecutorPrepareContext> g) {
    lr_decay_prepared_ctx_ = g;
  }

204 205 206
  void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; }

  // Get attributes.
1
123malin 已提交
207
  int distributed_mode() { return distributed_mode_; }
208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227
  framework::Scope* scope() { return scope_; }
  const platform::DeviceContext* dev_ctx() { return dev_ctx_; }
  framework::ProgramDesc* program() { return program_; }
  framework::Executor* executor() { return executor_; }

  // This function processes user's rpc request.
  // The implemention is in request_handler_impl.
  // example:
  //    std::string varname = request_.varname();
  //
  //    auto scope = request_handler_->scope();
  //    auto invar = scope->FindVar(varname);
  //    framework::Variable* outvar = nullptr;
  //
  //    request_handler_->Handle(varname, scope, invar, &outvar);
  //    if (outvar) {
  //        SerializeToByteBuffer(varname, outvar,
  //           *request_handler_->dev_ctx(), &reply_);
  //    }
  virtual bool Handle(const std::string& varname, framework::Scope* scope,
Q
qiaolongfei 已提交
228
                      framework::Variable* var, framework::Variable** outvar,
W
Wu Yi 已提交
229
                      const int trainer_id,
Q
Qiao Longfei 已提交
230 231
                      const std::string& out_var_name = "",
                      const std::string& table_name = "") = 0;
232 233

 protected:
1
123malin 已提交
234
  const int distributed_mode_;
235 236 237 238 239

  const platform::DeviceContext* dev_ctx_;
  framework::Executor* executor_;
  framework::Scope* scope_;
  framework::ProgramDesc* program_;
240 241 242 243 244

  // used for distribute lookup table prefetch
  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>*
      prefetch_var_name_to_prepared_ctx_;
T
tangwei12 已提交
245 246
  // used for checkpoint notify
  std::shared_ptr<framework::ExecutorPrepareContext> checkpoint_prepared_ctx_;
247 248 249 250 251

  // Used for async.
  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>*
      grad_to_prepared_ctx_;
252
  std::unordered_map<std::string, std::string>* sparse_grad_to_param_;
253

254 255
  // used for lr decay
  std::shared_ptr<framework::ExecutorPrepareContext> lr_decay_prepared_ctx_;
256 257 258
  RPCServer* rpc_server_;
};

259
}  // namespace distributed
260 261
}  // namespace operators
}  // namespace paddle