// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include <iostream>
#include <string>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/string/piece.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/split.h"

#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/heart_beat_monitor.h"

namespace paddle {
namespace operators {
namespace distributed {

// LOOKUP_TABLE_PATH names the scope variable through which checkpoint notify
// passes the directory where lookup table variables should be saved.
constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";

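// RequestSendHandler serves variables pushed from trainers. In sync mode it
// coordinates batch barriers; in async mode it immediately runs the prepared
// optimization block for the received gradient.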
bool RequestSendHandler::Handle(const std::string& varname,
                                framework::Scope* scope,
                                framework::Variable* invar,
                                framework::Variable** outvar,
                                const int trainer_id,
                                const std::string& out_var_name,
                                const std::string& table_name) {
  VLOG(4) << "RequestSendHandler:" << varname;

  // Sync
  if (varname == BATCH_BARRIER_MESSAGE) {
    VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE";
    rpc_server_->IncreaseBatchBarrier(kRequestSend);
  } else if (varname == COMPLETE_MESSAGE) {
    VLOG(3) << "sync: recv complete message";

    if (HeartBeatMonitor::GetInstance() != nullptr) {
      HeartBeatMonitor::GetInstance()->Update(trainer_id, "", COMPLETED);
    }

    rpc_server_->Complete();
  } else {
    // Async
    if (!sync_mode_) {
      VLOG(3) << "async process var: " << varname;
      if (varname == BATCH_BARRIER_MESSAGE) {
        PADDLE_THROW(
            "async mode should not receive BATCH_BARRIER_MESSAGE or "
            "COMPLETE_MESSAGE");
      }
      HeartBeatMonitor::GetInstance()->Update(trainer_id, varname, RUNNING);

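      // A gradient sent under a sliced name of the form "<var>@PIECE@<id>"
      // (e.g. the hypothetical "grad@PIECE@0") is renamed back to "<var>" so
      // the prepared optimize context can find it in the scope.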
      std::string run_varname = varname;

      string::Piece part_piece("@PIECE");
      string::Piece var_name_piece = string::Piece(varname);

      if (string::Contains(var_name_piece, part_piece)) {
        auto varname_splits = paddle::string::Split(varname, '@');
        PADDLE_ENFORCE_EQ(varname_splits.size(), 3);
        run_varname = varname_splits[0];
        scope->Rename(varname, run_varname);
      }

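      // Record which rows of this sparse gradient changed, so a later Get for
      // the parameter can return just the updated rows.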
      if (AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(run_varname)) {
        auto& grad_slr =
            scope->FindVar(run_varname)->Get<framework::SelectedRows>();
        AsyncSparseParamUpdateRecorder::GetInstance()->Update(run_varname,
                                                              grad_slr.rows());
      }
      executor_->RunPreparedContext((*grad_to_prepared_ctx_)[run_varname].get(),
                                    scope);

      return true;
    } else {  // sync
      rpc_server_->WaitCond(kRequestSend);
      VLOG(3) << "sync: processing received var: " << varname;

      if (invar == nullptr) {
        LOG(FATAL) << "sync: Cannot find server-side var: " << varname;
        return false;
      }
    }
  }
  return true;
}

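// RequestGetHandler serves variables pulled by trainers. In sync mode it
// coordinates fetch barriers; in async mode it can reply with only the sparse
// rows recorded as updated, and optionally snapshots parameters for DC-ASGD.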
bool RequestGetHandler::Handle(const std::string& varname,
                               framework::Scope* scope,
                               framework::Variable* invar,
                               framework::Variable** outvar,
                               const int trainer_id,
                               const std::string& out_var_name,
                               const std::string& table_name) {
  VLOG(3) << "RequestGetHandler:" << varname
          << " out_var_name: " << out_var_name << " trainer_id: " << trainer_id
          << " table_name: " << table_name;

  if (sync_mode_) {
    if (varname == FETCH_BARRIER_MESSAGE) {
      VLOG(3) << "sync: recv fetch barrier message";
      rpc_server_->IncreaseBatchBarrier(kRequestGet);
    } else {
      rpc_server_->WaitCond(kRequestGet);
      *outvar = scope_->FindVar(varname);
    }
  } else {
    if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) {
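      // Delay-compensated ASGD: keep a per-trainer snapshot of the parameter
      // (named "<var>.trainer_<id>_bak") so staleness can be compensated for
      // when the trainer's gradient arrives.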
      if (enable_dc_asgd_) {
        // NOTE: the format is determined by distribute_transpiler.py
        std::string param_bak_name =
            string::Sprintf("%s.trainer_%d_bak", varname, trainer_id);
        VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id;
        auto var = scope_->FindVar(varname);
        auto t_orig = var->Get<framework::LoDTensor>();
        auto param_bak = scope_->Var(param_bak_name);
        auto t = param_bak->GetMutable<framework::LoDTensor>();
        t->mutable_data(dev_ctx_->GetPlace(), t_orig.type());
        VLOG(3) << "copying " << varname << " to " << param_bak_name;
        framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
      }
      VLOG(1) << "Table name empty? " << table_name.empty();
      VLOG(1) << "AsyncSparseParamUpdateRecorder " << varname << " exist "
              << AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(
                     varname);
      if (AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) &&
          !table_name.empty()) {
        std::vector<int64_t> updated_rows;
        AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear(
            varname, trainer_id, &updated_rows);
        if (VLOG_IS_ON(3)) {
          std::ostringstream sstream;
          sstream << "[";
          for (auto& row_id : updated_rows) {
            sstream << row_id << ", ";
          }
          sstream << "]";
          VLOG(3) << "updated_rows size: " << updated_rows.size() << " "
                  << sstream.str();
        }
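        // Gather only the updated rows of the dense parameter into a
        // SelectedRows output, so the trainer receives a sparse delta.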
        auto& origin_tensor =
            scope_->FindVar(varname)->Get<framework::LoDTensor>();
        auto* origin_tensor_data = origin_tensor.data<float>();
        auto& dims = origin_tensor.dims();
        *outvar = scope->Var();
        auto* out_slr = (*outvar)->GetMutable<framework::SelectedRows>();
        out_slr->set_rows(updated_rows);
        out_slr->set_height(dims[0]);
        auto out_dims = framework::make_ddim(
            {static_cast<int64_t>(updated_rows.size()), dims[1]});
        auto* data = out_slr->mutable_value()->mutable_data<float>(
            out_dims, origin_tensor.place());
        auto width = dims[1];
        for (size_t i = 0; i < updated_rows.size(); ++i) {
          PADDLE_ENFORCE_LT(updated_rows[i], dims[0]);
          memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width,
                 sizeof(float) * width);
        }
      } else {
        *outvar = scope_->FindVar(varname);
      }
    }
  }
  return true;
}

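// RequestGetNoBarrierHandler returns a variable immediately, bypassing batch
// barriers; the requested name must carry the WITHOUT_BARRIER_MESSAGE suffix.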
bool RequestGetNoBarrierHandler::Handle(const std::string& varname,
                                        framework::Scope* scope,
                                        framework::Variable* invar,
                                        framework::Variable** outvar,
                                        const int trainer_id,
                                        const std::string& out_var_name,
                                        const std::string& table_name) {
  VLOG(4) << "RequestGetNoBarrierHandler:" << varname
          << " out_var_name: " << out_var_name;

  // get var from pserver immediately without barriers
  string::Piece without_barrier_piece(WITHOUT_BARRIER_MESSAGE);
  string::Piece var_name_piece = string::Piece(varname);

  if (string::Contains(var_name_piece, without_barrier_piece)) {
    var_name_piece = string::TrimSuffix(var_name_piece, without_barrier_piece);
    VLOG(4) << "Get var " << var_name_piece << " with "
            << WITHOUT_BARRIER_MESSAGE;
    *outvar = scope_->FindVar(var_name_piece.ToString());
    return true;
  } else {
    PADDLE_THROW("GetNoBarrier must contain %s", WITHOUT_BARRIER_MESSAGE);
  }
  return true;
}

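// RequestPrefetchHandler serves sparse lookups: without a table_name it runs
// the prepared prefetch sub-program; otherwise it builds a lookup_table op and
// runs it on CPU to gather the requested rows.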
bool RequestPrefetchHandler::Handle(const std::string& varname,
                                    framework::Scope* scope,
                                    framework::Variable* invar,
                                    framework::Variable** outvar,
                                    const int trainer_id,
                                    const std::string& out_var_name,
                                    const std::string& table_name) {
  VLOG(4) << "RequestPrefetchHandler " << varname;

  if (table_name.empty()) {
    auto var_desc = program_->Block(0).FindVar(out_var_name);
    InitializeVariable(*outvar, var_desc->GetType());
    executor_->RunPreparedContext(
        (*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope);
  } else {
    (*outvar)->GetMutable<framework::LoDTensor>();
    auto lookup_table_op =
        BuildLookupTableOp(table_name, varname, out_var_name);
    paddle::platform::CPUPlace cpu_place;
    lookup_table_op->Run(*scope, cpu_place);
  }
  return true;
}

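// RequestCheckpointHandler saves lookup table variables: it writes the target
// path into the LOOKUP_TABLE_PATH scope variable, then runs the prepared
// checkpoint block.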
bool RequestCheckpointHandler::Handle(const std::string& varname,
                                      framework::Scope* scope,
                                      framework::Variable* invar,
                                      framework::Variable** outvar,
                                      const int trainer_id,
                                      const std::string& out_var_name,
                                      const std::string& table_name) {
  PADDLE_ENFORCE(
      checkpoint_notify_id != -1,
      "when checkpoint_notify_id == -1, no checkpoint RPC should be invoked.");

  // TODO(tangwei12): find out why using the request-scoped `scope` here causes
  // errors; the server-side scope_ is used as a workaround.
  auto* lt_var = scope_->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
  lt_var->clear();
  lt_var->append(out_var_name);
  VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: "
          << out_var_name;
  executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope_);
  return true;
}

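// RequestNotifyHandler currently handles learning rate decay notifications:
// the trainer-side counter is accumulated into the pserver-side counter
// before the prepared lr-decay block runs.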
bool RequestNotifyHandler::Handle(const std::string& varname,
                                  framework::Scope* scope,
                                  framework::Variable* invar,
                                  framework::Variable** outvar,
                                  const int trainer_id,
                                  const std::string& out_var_name,
                                  const std::string& table_name) {
  VLOG(4) << "RequestNotifyHandler: " << varname;
  VLOG(3) << "async process var: " << varname << ", trainer_id: " << trainer_id;

  string::Piece decay_piece(LEARNING_RATE_DECAY_COUNTER);
  string::Piece var_name_piece = string::Piece(varname);
  if (string::Contains(var_name_piece, decay_piece)) {
    VLOG(3) << "LearningRate Decay Counter Update";
    PADDLE_ENFORCE_NE(
        lr_decay_block_id, -1,
        "when lr_decay_block_id == -1, no RPC should be invoked.");
    auto* origin_var = scope_->FindVar(varname);
    auto origin_var_tensor = origin_var->Get<framework::LoDTensor>();
    auto* send_var = scope->FindVar(varname);
    auto send_var_tensor = send_var->Get<framework::LoDTensor>();
    int64_t* origin_value =
        origin_var_tensor.mutable_data<int64_t>(origin_var_tensor.place());
    int64_t* send_value =
        send_var_tensor.mutable_data<int64_t>(send_var_tensor.place());
    origin_value[0] += send_value[0];
    executor_->RunPreparedContext(lr_decay_prepared_ctx_.get(), scope_);
  }
  return true;
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle