// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <time.h>

#include <condition_variable>  // NOLINT
#include <functional>
#include <memory>
#include <mutex>  // NOLINT
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/macros.h"

namespace paddle {
namespace operators {
namespace distributed {

// ---------------------------------------------------------------------------
// Request-type identifiers.
// NOTE(review): these look like the keys the RPC server uses to pick a
// RequestHandler per kind of request -- confirm against the server code.
//
// Namespace-scope const objects have internal linkage, so every translation
// unit that includes this header gets its own copy of each array: compare the
// string contents, never the pointers.
// ---------------------------------------------------------------------------
constexpr char kRequestSend[] = "RequestSend";
constexpr char kRequestGet[] = "RequestGet";
constexpr char kRequestGetMonomerVariable[] = "RequestGetMonomerVariable";
constexpr char kRequestGetMonomerBarrier[] = "RequestGetMonomerBarrier";
constexpr char kRequestPrefetch[] = "RequestPrefetch";
constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
// Unlike its siblings, this value does not follow the "Request<Kind>" naming
// pattern. It is a runtime string, so do not rename it without checking every
// peer that compares against it.
constexpr char kRequestGetNoBarrier[] = "GetVariableNoBarrier";

// ---------------------------------------------------------------------------
// RPC call labels: each value is its identifier without the leading 'k'.
// NOTE(review): presumably used to tag/log outgoing client calls -- confirm.
// ---------------------------------------------------------------------------
constexpr char kSendRPC[] = "SendRPC";
constexpr char kGetRPC[] = "GetRPC";
constexpr char kGetNoBarrierRPC[] = "GetNoBarrierRPC";
constexpr char kGetMonomerRPC[] = "GetMonomerRPC";
constexpr char kPrefetchRPC[] = "PrefetchRPC";
constexpr char kBatchBarrierRPC[] = "BatchBarrierRPC";
constexpr char kFetchBarrierRPC[] = "FetchBarrierRPC";
constexpr char kSendMonomerFetchBarrierRPC[] = "SendMonomerFetchBarrierRPC";
constexpr char kSendCompleteRPC[] = "SendCompleteRPC";
constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";

// ---------------------------------------------------------------------------
// Control-message strings.
// NOTE(review): these appear to be sentinel "variable names" sent in place of
// a real variable to signal barriers/termination -- confirm with the server.
// They stay as macros: call sites may rely on string-literal concatenation,
// which a constexpr array would not support.
// ---------------------------------------------------------------------------
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
// Note the leading '@': unlike the others, this one reads like a suffix.
#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"

// Checkpoint-notify actions.
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"

// Forward declaration only: RequestHandler keeps just a pointer to it.
class RPCServer;

class VarHandle {
 public:
  VarHandle(const std::string ep, const std::string& method,
            const std::string& name,
            const platform::DeviceContext* p_ctx = nullptr,
Q
Qiao Longfei 已提交
76
            const framework::Scope* p_scope = nullptr)
G
gongweibao 已提交
77
      : status_(kDefaultState) {
78 79 80 81 82 83 84 85 86 87
    ep_ = ep;
    ctx_ = p_ctx;
    scope_ = p_scope;
    name_ = name;
    method_ = method;
  }

  virtual ~VarHandle() {}

 public:
88 89
  bool should_retry = false;

90
  bool Wait() {
G
gongweibao 已提交
91
    int ret = kDefaultState;
92 93
    {
      std::unique_lock<std::mutex> lk(sync_mutex_);
G
gongweibao 已提交
94 95
      wait_cond_.wait(lk, [this] { return status_ != kDefaultState; });
      ret = status_;
96
    }
M
minqiyang 已提交
97
    VLOG(7) << "VarHandle wait:" << ret;
G
gongweibao 已提交
98
    return ret != kErrorState;
99 100 101 102 103
  }

  void Finish(bool ok) {
    {
      std::unique_lock<std::mutex> lk(sync_mutex_);
G
gongweibao 已提交
104
      status_ = ok ? kFinishState : kErrorState;
105
    }
M
minqiyang 已提交
106
    VLOG(7) << "VarHandle finish:" << ok;
107 108
    wait_cond_.notify_all();
  }
109 110 111

  std::string String() const {
    std::ostringstream s;
G
gongweibao 已提交
112 113
    s << method_ << " name:[" << name_ << "], ep:[" << ep_ << "], status:["
      << status_ << "]";
114 115
    return s.str();
  }
116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136

  std::string ep() const { return ep_; }
  const platform::DeviceContext* ctx() const { return ctx_; }
  const framework::Scope* scope() const { return scope_; }
  std::string name() const { return name_; }
  std::string method() const { return method_; }

 protected:
  // RPC endpoint.
  std::string ep_;
  const platform::DeviceContext* ctx_;
  const framework::Scope* scope_;
  // Variable name.
  std::string name_;
  // RPC method name.
  std::string method_;

 protected:
  std::mutex sync_mutex_;
  std::condition_variable wait_cond_;

G
gongweibao 已提交
137 138 139 140 141 142
  enum VarHandleStatus {
    kDefaultState = -1,
    kErrorState = 0,
    kFinishState = 1,
  };
  VarHandleStatus status_;
143 144 145

 private:
  DISABLE_COPY_AND_ASSIGN(VarHandle);
146 147
};

148 149
typedef std::shared_ptr<VarHandle> VarHandlePtr;

150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166
class RequestHandler {
 public:
  explicit RequestHandler(bool sync_mode)
      : sync_mode_(sync_mode),
        dev_ctx_(nullptr),
        executor_(nullptr),
        scope_(nullptr),
        program_(nullptr),
        rpc_server_(nullptr) {}

  virtual ~RequestHandler() {}

  // Set attributes.
  void SetScope(framework::Scope* scope) { scope_ = scope; }
  void SetDevCtx(const platform::DeviceContext* dev_ctx) { dev_ctx_ = dev_ctx; }
  void SetProgram(framework::ProgramDesc* program) { program_ = program; }
  void SetExecutor(framework::Executor* executor) { executor_ = executor; }
167 168

  // Used for dist lookup table prefetch
169
  void SetPrefetchPreparedCtx(
170 171 172
      std::unordered_map<
          std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
    prefetch_var_name_to_prepared_ctx_ = g;
173 174
  }

T
tangwei12 已提交
175 176 177 178 179
  void SetCheckpointNotifyPreparedCtx(
      std::shared_ptr<framework::ExecutorPrepareContext> g) {
    checkpoint_prepared_ctx_ = g;
  }

180 181 182 183 184 185 186
  // Used for async.
  void SetGradToPreparedCtx(
      std::unordered_map<
          std::string, std::shared_ptr<framework::ExecutorPrepareContext>>* g) {
    grad_to_prepared_ctx_ = g;
  }

187 188 189 190
  void SetSparseGradToParam(std::unordered_map<std::string, std::string>* g) {
    sparse_grad_to_param_ = g;
  }

191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214
  void SetRPCServer(RPCServer* rpc_server) { rpc_server_ = rpc_server; }

  // Get attributes.
  bool sync_mode() { return sync_mode_; }
  framework::Scope* scope() { return scope_; }
  const platform::DeviceContext* dev_ctx() { return dev_ctx_; }
  framework::ProgramDesc* program() { return program_; }
  framework::Executor* executor() { return executor_; }

  // This function processes user's rpc request.
  // The implemention is in request_handler_impl.
  // example:
  //    std::string varname = request_.varname();
  //
  //    auto scope = request_handler_->scope();
  //    auto invar = scope->FindVar(varname);
  //    framework::Variable* outvar = nullptr;
  //
  //    request_handler_->Handle(varname, scope, invar, &outvar);
  //    if (outvar) {
  //        SerializeToByteBuffer(varname, outvar,
  //           *request_handler_->dev_ctx(), &reply_);
  //    }
  virtual bool Handle(const std::string& varname, framework::Scope* scope,
Q
qiaolongfei 已提交
215
                      framework::Variable* var, framework::Variable** outvar,
W
Wu Yi 已提交
216
                      const int trainer_id,
Q
Qiao Longfei 已提交
217 218
                      const std::string& out_var_name = "",
                      const std::string& table_name = "") = 0;
219 220 221 222 223 224 225 226

 protected:
  const bool sync_mode_;

  const platform::DeviceContext* dev_ctx_;
  framework::Executor* executor_;
  framework::Scope* scope_;
  framework::ProgramDesc* program_;
227 228 229 230 231

  // used for distribute lookup table prefetch
  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>*
      prefetch_var_name_to_prepared_ctx_;
T
tangwei12 已提交
232 233
  // used for checkpoint notify
  std::shared_ptr<framework::ExecutorPrepareContext> checkpoint_prepared_ctx_;
234 235 236 237 238

  // Used for async.
  std::unordered_map<std::string,
                     std::shared_ptr<framework::ExecutorPrepareContext>>*
      grad_to_prepared_ctx_;
239
  std::unordered_map<std::string, std::string>* sparse_grad_to_param_;
240

241 242 243
  RPCServer* rpc_server_;
};

244
}  // namespace distributed
}  // namespace operators
}  // namespace paddle