// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <time.h>

#include <condition_variable>  // NOLINT
#include <functional>
#include <memory>
#include <mutex>  // NOLINT
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace operators {
namespace distributed {

// Request type names.
constexpr char kRequestSend[] = "RequestSend";
constexpr char kRequestGet[] = "RequestGet";
constexpr char kRequestGetMonomerVariable[] = "RequestGetMonomerVariable";
constexpr char kRequestGetMonomerBarrier[] = "RequestGetMonomerBarrier";
constexpr char kRequestPrefetch[] = "RequestPrefetch";
constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
// NOTE: unlike its siblings this value has no "Request" prefix; other code may
// depend on the exact string, so it is intentionally left unchanged.
constexpr char kRequestGetNoBarrier[] = "GetVariableNoBarrier";
constexpr char kRequestNotify[] = "RequestNotify";

// RPC method names. Presumably these are the strings carried in
// VarHandle::method() — confirm at the sites that construct VarHandle.
constexpr char kSendRPC[] = "SendRPC";
constexpr char kGetRPC[] = "GetRPC";
constexpr char kGetNoBarrierRPC[] = "GetNoBarrierRPC";
constexpr char kGetMonomerRPC[] = "GetMonomerRPC";
constexpr char kPrefetchRPC[] = "PrefetchRPC";
constexpr char kBatchBarrierRPC[] = "BatchBarrierRPC";
constexpr char kFetchBarrierRPC[] = "FetchBarrierRPC";
constexpr char kSendMonomerFetchBarrierRPC[] = "SendMonomerFetchBarrierRPC";
constexpr char kSendCompleteRPC[] = "SendCompleteRPC";
constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";

// Reserved control-message strings. Presumably they are sent in place of a
// regular variable name to signal terminate / barrier / completion events —
// confirm against the RPC client and server. They are kept as macros (not
// constexpr arrays) so they still work in string-literal concatenation.
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
// Begins with '@', so it looks like a suffix appended to a variable name
// rather than a standalone message — confirm at call sites.
#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"
#define LEARNING_RATE_DECAY_MESSAGE "LRDECAY@RECV"

// Checkpoint notification messages (save / load).
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"

// Forward declaration: this header only needs RPCServer as an incomplete type,
// because RequestHandler stores and sets an RPCServer* without dereferencing it.
class RPCServer;

class VarHandle {
 public:
  VarHandle(const std::string ep, const std::string& method,
            const std::string& name,
            const platform::DeviceContext* p_ctx = nullptr,
Q
Qiao Longfei 已提交
78
            const framework::Scope* p_scope = nullptr)
G
gongweibao 已提交
79
      : status_(kDefaultState) {
80 81 82 83 84 85 86 87 88 89
    ep_ = ep;
    ctx_ = p_ctx;
    scope_ = p_scope;
    name_ = name;
    method_ = method;
  }

  virtual ~VarHandle() {}

 public:
90 91
  bool should_retry = false;

92
  bool Wait() {
G
gongweibao 已提交
93
    int ret = kDefaultState;
94 95
    {
      std::unique_lock<std::mutex> lk(sync_mutex_);
G
gongweibao 已提交
96 97
      wait_cond_.wait(lk, [this] { return status_ != kDefaultState; });
      ret = status_;
98
    }
M
minqiyang 已提交
99
    VLOG(7) << "VarHandle wait:" << ret;
G
gongweibao 已提交
100
    return ret != kErrorState;
101 102 103 104 105
  }

  void Finish(bool ok) {
    {
      std::unique_lock<std::mutex> lk(sync_mutex_);
G
gongweibao 已提交
106
      status_ = ok ? kFinishState : kErrorState;
107
    }
M
minqiyang 已提交
108
    VLOG(7) << "VarHandle finish:" << ok;
109 110
    wait_cond_.notify_all();
  }
111 112 113

  std::string String() const {
    std::ostringstream s;
G
gongweibao 已提交
114 115
    s << method_ << " name:[" << name_ << "], ep:[" << ep_ << "], status:["
      << status_ << "]";
116 117
    return s.str();
  }
118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138

  std::string ep() const { return ep_; }
  const platform::DeviceContext* ctx() const { return ctx_; }
  const framework::Scope* scope() const { return scope_; }
  std::string name() const { return name_; }
  std::string method() const { return method_; }

 protected:
  // RPC endpoint.
  std::string ep_;
  const platform::DeviceContext* ctx_;
  const framework::Scope* scope_;
  // Variable name.
  std::string name_;
  // RPC method name.
  std::string method_;

 protected:
  std::mutex sync_mutex_;
  std::condition_variable wait_cond_;

G
gongweibao 已提交
139 140 141 142 143 144
  enum VarHandleStatus {
    kDefaultState = -1,
    kErrorState = 0,
    kFinishState = 1,
  };
  VarHandleStatus status_;
145 146 147

 private:
  DISABLE_COPY_AND_ASSIGN(VarHandle);
148 149
};

150 151
// Shared-ownership handle to a VarHandle.
using VarHandlePtr = std::shared_ptr<VarHandle>;

class RequestHandler {
 public:
  explicit RequestHandler(bool sync_mode)
      : sync_mode_(sync_mode),
        dev_ctx_(nullptr),
        executor_(nullptr),
        scope_(nullptr),
        program_(nullptr),
        rpc_server_(nullptr) {}

  virtual ~RequestHandler() {}

  // Set attributes.
  void SetScope(framework::Scope* scope) { scope_ = scope; }
  void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
  void SetProgram(framework::ProgramDesc* program) { program_ = program; }
  void SetExecutor(framework::Executor* executor) { executor_ = executor; }
169 170

  // Used for dist lookup table prefetch
171
  void SetPrefetchPreparedCtx(
172 173 174
      std::unordered_map<
          std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
    prefetch_var_name_to_prepared_ctx_ = g;
175 176
  }

T
tangwei12 已提交
177 178 179 180 181
  void SetCheckpointNotifyPreparedCtx(
      std::shared_ptr<framework::ExecutorPrepareContext> g) {
    checkpoint_prepared_ctx_ = g;
  }

182 183 184 185 186 187 188
  // Used for async.
  void SetGradToPreparedCtx(
      std::unordered_map<
          std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
    grad_to_prepared_ctx_ = g;
  }

189 190 191 192
  void SetSparseGradToParam(std::unordered_map<std::string, std::string>* g) {
    sparse_grad_to_param_ = g;
  }

193 194 195 196 197
  void SetLrDecayPreparedCtx(
      std::shared_ptr<framework::ExecutorPrepareContext> g) {
    lr_decay_prepared_ctx_ = g;
  }

198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221
  void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; }

  // Get attributes.
  bool sync_mode() { return sync_mode_; }
  framework::Scope* scope() { return scope_; }
  const platform::DeviceContext* dev_ctx() { return dev_ctx_; }
  framework::ProgramDesc* program() { return program_; }
  framework::Executor* executor() { return executor_; }

  // This function processes user's rpc request.
  // The implemention is in request_handler_impl.
  // example:
  //    std::string varname = request_.varname();
  //
  //    auto scope = request_handler_->scope();
  //    auto invar = scope->FindVar(varname);
  //    framework::Variable* outvar = nullptr;
  //
  //    request_handler_->Handle(varname, scope, invar, &outvar);
  //    if (outvar) {
  //        SerializeToByteBuffer(varname, outvar,
  //           *request_handler_->dev_ctx(), &reply_);
  //    }
  virtual bool Handle(const std::string& varname, framework::Scope* scope,
Q
qiaolongfei 已提交
222
                      framework::Variable* var, framework::Variable** outvar,
W
Wu Yi 已提交
223
                      const int trainer_id,
Q
Qiao Longfei 已提交
224 225
                      const std::string& out_var_name = "",
                      const std::string& table_name = "") = 0;
226 227 228 229 230 231 232 233

 protected:
  const bool sync_mode_;

  const platform::DeviceContext* dev_ctx_;
  framework::Executor* executor_;
  framework::Scope* scope_;
  framework::ProgramDesc* program_;
234 235 236 237 238

  // used for distribute lookup table prefetch
  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>*
      prefetch_var_name_to_prepared_ctx_;
T
tangwei12 已提交
239 240
  // used for checkpoint notify
  std::shared_ptr<framework::ExecutorPrepareContext> checkpoint_prepared_ctx_;
241 242 243 244 245

  // Used for async.
  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>*
      grad_to_prepared_ctx_;
246
  std::unordered_map<std::string, std::string>* sparse_grad_to_param_;
247

248 249
  // used for lr decay
  std::shared_ptr<framework::ExecutorPrepareContext> lr_decay_prepared_ctx_;
250 251 252
  RPCServer* rpc_server_;
};

253
}  // namespace distributed
}  // namespace operators
}  // namespace paddle