// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <time.h>

#include <condition_variable>  // NOLINT
#include <functional>
#include <memory>
#include <mutex>  // NOLINT
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace operators {
namespace distributed {

// ---------------------------------------------------------------------------
// Request-kind identifiers (server side).
// ---------------------------------------------------------------------------
constexpr char kRequestSend[] = "RequestSend";
constexpr char kRequestGet[] = "RequestGet";
constexpr char kRequestGetMonomerVariable[] = "RequestGetMonomerVariable";
constexpr char kRequestGetMonomerBarrier[] = "RequestGetMonomerBarrier";
constexpr char kRequestPrefetch[] = "RequestPrefetch";
constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
// Note: the string value intentionally differs from the constant's name.
constexpr char kRequestGetNoBarrier[] = "GetVariableNoBarrier";
constexpr char kRequestNotify[] = "RequestNotify";

// ---------------------------------------------------------------------------
// RPC method names (e.g. usable as VarHandle's `method` argument).
// ---------------------------------------------------------------------------
constexpr char kSendRPC[] = "SendRPC";
constexpr char kGetRPC[] = "GetRPC";
constexpr char kGetNoBarrierRPC[] = "GetNoBarrierRPC";
constexpr char kGetMonomerRPC[] = "GetMonomerRPC";
constexpr char kPrefetchRPC[] = "PrefetchRPC";
constexpr char kBatchBarrierRPC[] = "BatchBarrierRPC";
constexpr char kFetchBarrierRPC[] = "FetchBarrierRPC";
constexpr char kSendMonomerFetchBarrierRPC[] = "SendMonomerFetchBarrierRPC";
constexpr char kSendCompleteRPC[] = "SendCompleteRPC";
constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";

// ---------------------------------------------------------------------------
// Control-message strings. These stay macros (not constexpr) so that they
// remain usable in string-literal concatenation at existing call sites.
// ---------------------------------------------------------------------------
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"
#define LEARNING_RATE_DECAY_COUNTER "@LR_DECAY_COUNTER@"

// Checkpoint-notify payloads.
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"

// Distributed training mode. RequestHandler's constructor takes the mode as a
// plain int, presumably one of these values (kept as an unscoped enum so it
// converts implicitly).
enum DistributedMode {
  kSync = 0,
  kAsync = 1,
  kHalfAsync = 2,
  kGeo = 3,
};

// Defined elsewhere; only referenced through pointers in this header.
class RPCServer;
class VarHandle {
 public:
  VarHandle(const std::string ep, const std::string& method,
            const std::string& name,
            const platform::DeviceContext* p_ctx = nullptr,
Q
Qiao Longfei 已提交
80
            const framework::Scope* p_scope = nullptr)
G
gongweibao 已提交
81
      : status_(kDefaultState) {
82 83 84 85 86 87 88 89 90 91
    ep_ = ep;
    ctx_ = p_ctx;
    scope_ = p_scope;
    name_ = name;
    method_ = method;
  }

  virtual ~VarHandle() {}

 public:
92 93
  bool should_retry = false;

94
  bool Wait() {
G
gongweibao 已提交
95
    int ret = kDefaultState;
96 97
    {
      std::unique_lock<std::mutex> lk(sync_mutex_);
G
gongweibao 已提交
98 99
      wait_cond_.wait(lk, [this] { return status_ != kDefaultState; });
      ret = status_;
100
    }
M
minqiyang 已提交
101
    VLOG(7) << "VarHandle wait:" << ret;
G
gongweibao 已提交
102
    return ret != kErrorState;
103 104 105 106 107
  }

  void Finish(bool ok) {
    {
      std::unique_lock<std::mutex> lk(sync_mutex_);
G
gongweibao 已提交
108
      status_ = ok ? kFinishState : kErrorState;
109
    }
M
minqiyang 已提交
110
    VLOG(7) << "VarHandle finish:" << ok;
111 112
    wait_cond_.notify_all();
  }
113 114 115

  std::string String() const {
    std::ostringstream s;
G
gongweibao 已提交
116 117
    s << method_ << " name:[" << name_ << "], ep:[" << ep_ << "], status:["
      << status_ << "]";
118 119
    return s.str();
  }
120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140

  std::string ep() const { return ep_; }
  const platform::DeviceContext* ctx() const { return ctx_; }
  const framework::Scope* scope() const { return scope_; }
  std::string name() const { return name_; }
  std::string method() const { return method_; }

 protected:
  // RPC endpoint.
  std::string ep_;
  const platform::DeviceContext* ctx_;
  const framework::Scope* scope_;
  // Variable name.
  std::string name_;
  // RPC method name.
  std::string method_;

 protected:
  std::mutex sync_mutex_;
  std::condition_variable wait_cond_;

G
gongweibao 已提交
141 142 143 144 145 146
  enum VarHandleStatus {
    kDefaultState = -1,
    kErrorState = 0,
    kFinishState = 1,
  };
  VarHandleStatus status_;
147 148 149

 private:
  DISABLE_COPY_AND_ASSIGN(VarHandle);
150 151
};

152 153
typedef std::shared_ptr<VarHandle> VarHandlePtr;

154 155
class RequestHandler {
 public:
156 157
  explicit RequestHandler(int distributed_mode)
      : distributed_mode_(distributed_mode),
158 159 160 161 162 163 164 165 166 167 168 169 170
        dev_ctx_(nullptr),
        executor_(nullptr),
        scope_(nullptr),
        program_(nullptr),
        rpc_server_(nullptr) {}

  virtual ~RequestHandler() {}

  // Set attributes.
  void SetScope(framework::Scope* scope) { scope_ = scope; }
  void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
  void SetProgram(framework::ProgramDesc* program) { program_ = program; }
  void SetExecutor(framework::Executor* executor) { executor_ = executor; }
171 172

  // Used for dist lookup table prefetch
173
  void SetPrefetchPreparedCtx(
174 175 176
      std::unordered_map<
          std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
    prefetch_var_name_to_prepared_ctx_ = g;
177 178
  }

T
tangwei12 已提交
179 180 181 182 183
  void SetCheckpointNotifyPreparedCtx(
      std::shared_ptr<framework::ExecutorPrepareContext> g) {
    checkpoint_prepared_ctx_ = g;
  }

184 185 186 187 188 189 190
  // Used for async.
  void SetGradToPreparedCtx(
      std::unordered_map<
          std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
    grad_to_prepared_ctx_ = g;
  }

191 192 193 194
  void SetSparseGradToParam(std::unordered_map<std::string, std::string>* g) {
    sparse_grad_to_param_ = g;
  }

195 196 197 198 199
  void SetLrDecayPreparedCtx(
      std::shared_ptr<framework::ExecutorPrepareContext> g) {
    lr_decay_prepared_ctx_ = g;
  }

200 201 202
  void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; }

  // Get attributes.
203
  int distributed_mode() { return distributed_mode_; }
204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223
  framework::Scope* scope() { return scope_; }
  const platform::DeviceContext* dev_ctx() { return dev_ctx_; }
  framework::ProgramDesc* program() { return program_; }
  framework::Executor* executor() { return executor_; }

  // This function processes user's rpc request.
  // The implemention is in request_handler_impl.
  // example:
  //    std::string varname = request_.varname();
  //
  //    auto scope = request_handler_->scope();
  //    auto invar = scope->FindVar(varname);
  //    framework::Variable* outvar = nullptr;
  //
  //    request_handler_->Handle(varname, scope, invar, &outvar);
  //    if (outvar) {
  //        SerializeToByteBuffer(varname, outvar,
  //           *request_handler_->dev_ctx(), &reply_);
  //    }
  virtual bool Handle(const std::string& varname, framework::Scope* scope,
Q
qiaolongfei 已提交
224
                      framework::Variable* var, framework::Variable** outvar,
W
Wu Yi 已提交
225
                      const int trainer_id,
Q
Qiao Longfei 已提交
226 227
                      const std::string& out_var_name = "",
                      const std::string& table_name = "") = 0;
228 229

 protected:
230
  const int distributed_mode_;
231 232 233 234 235

  const platform::DeviceContext* dev_ctx_;
  framework::Executor* executor_;
  framework::Scope* scope_;
  framework::ProgramDesc* program_;
236 237 238 239 240

  // used for distribute lookup table prefetch
  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>*
      prefetch_var_name_to_prepared_ctx_;
T
tangwei12 已提交
241 242
  // used for checkpoint notify
  std::shared_ptr<framework::ExecutorPrepareContext> checkpoint_prepared_ctx_;
243 244 245 246 247

  // Used for async.
  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>*
      grad_to_prepared_ctx_;
248
  std::unordered_map<std::string, std::string>* sparse_grad_to_param_;
249

250 251
  // used for lr decay
  std::shared_ptr<framework::ExecutorPrepareContext> lr_decay_prepared_ctx_;
252 253 254
  RPCServer* rpc_server_;
};

255
}  // namespace distributed
}  // namespace operators
}  // namespace paddle