// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include <iostream>
#include <string>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/async_sparse_param_update_recorder.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/string/piece.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/fluid/string/split.h"

namespace paddle {
namespace operators {
namespace distributed {

// LOOKUP_TABLE_PATH names the scope variable through which a checkpoint
// notify request passes the directory where lookup table variables are saved.
constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";

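// Handles SEND requests from trainers. Barrier and completion control
// messages update the RPC server's bookkeeping; for actual gradient
// variables, async mode runs the prepared optimize block immediately, while
// sync mode admits the variable only after the send barrier.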
bool RequestSendHandler::Handle(const std::string& varname,
                                framework::Scope* scope,
                                framework::Variable* invar,
                                framework::Variable** outvar,
                                const int trainer_id,
                                const std::string& out_var_name,
                                const std::string& table_name) {
  VLOG(4) << "RequestSendHandler:" << varname;

  // Sync
  if (varname == BATCH_BARRIER_MESSAGE) {
    VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE";
    rpc_server_->IncreaseBatchBarrier(kRequestSend);
  } else if (varname == COMPLETE_MESSAGE) {
    VLOG(3) << "sync: recv complete message";
    rpc_server_->Complete();
  } else {
    // Async
    if (!sync_mode_) {
      VLOG(3) << "async process var: " << varname;
      if (varname == BATCH_BARRIER_MESSAGE) {
        PADDLE_THROW(
            "async mode should not recv BATCH_BARRIER_MESSAGE or "
            "COMPLETE_MESSAGE");
      }

      std::string run_varname = varname;

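      // A large gradient may arrive split into shards whose names carry an
      // "@PIECE" tag (e.g. "grad@PIECE@0"; the exact format is inferred from
      // the three-way '@' split below). Recover the base variable name so the
      // prepared context registered under it can be looked up.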
      string::Piece part_piece("@PIECE");
      string::Piece var_name_piece = string::Piece(varname);

      if (string::Contains(var_name_piece, part_piece)) {
        auto varname_splits = paddle::string::Split(varname, '@');
        PADDLE_ENFORCE_EQ(varname_splits.size(), 3);
        run_varname = varname_splits[0];
        scope->Rename(varname, run_varname);
      }

      if (AsyncSparseParamUpdateRecorder::GetInstance()->HasGrad(run_varname)) {
        auto& grad_slr =
            scope->FindVar(run_varname)->Get<framework::SelectedRows>();
        AsyncSparseParamUpdateRecorder::GetInstance()->Update(run_varname,
                                                              grad_slr.rows());
      }
      executor_->RunPreparedContext((*grad_to_prepared_ctx_)[run_varname].get(),
                                    scope);
      return true;
    } else {  // sync
      rpc_server_->WaitCond(kRequestSend);
      VLOG(3) << "sync: processing received var: " << varname;

      if (invar == nullptr) {
        LOG(FATAL) << "sync: cannot find server side var: " << varname;
        return false;
      }
    }
  }
  return true;
}

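// Handles GET requests from trainers. Sync mode serves a variable only after
// the fetch barrier; async mode may additionally snapshot parameters for
// DC-ASGD and, for sparse parameters requested with a table name, return only
// the rows this trainer has updated since its last fetch.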
bool RequestGetHandler::Handle(const std::string& varname,
                               framework::Scope* scope,
                               framework::Variable* invar,
                               framework::Variable** outvar,
                               const int trainer_id,
                               const std::string& out_var_name,
                               const std::string& table_name) {
  VLOG(3) << "RequestGetHandler:" << varname
          << " out_var_name: " << out_var_name << " trainer_id: " << trainer_id
          << " table_name: " << table_name;

  if (sync_mode_) {
    if (varname == FETCH_BARRIER_MESSAGE) {
      VLOG(3) << "sync: recv fetch barrier message";
      rpc_server_->IncreaseBatchBarrier(kRequestGet);
    } else {
      rpc_server_->WaitCond(kRequestGet);
      *outvar = scope_->FindVar(varname);
    }
  } else {
    if (varname != FETCH_BARRIER_MESSAGE && varname != COMPLETE_MESSAGE) {
      if (enable_dc_asgd_) {
        // NOTE: the format is determined by distribute_transpiler.py
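        // DC-ASGD (delay-compensated async SGD) keeps a per-trainer backup of
        // the parameter as it was fetched, so the later update can compensate
        // for the staleness of this trainer's gradient.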
        std::string param_bak_name =
            string::Sprintf("%s.trainer_%d_bak", varname, trainer_id);
        VLOG(3) << "getting " << param_bak_name << " trainer_id " << trainer_id;
        auto var = scope_->FindVar(varname);
        auto t_orig = var->Get<framework::LoDTensor>();
        auto param_bak = scope_->Var(param_bak_name);
        auto t = param_bak->GetMutable<framework::LoDTensor>();
        t->mutable_data(dev_ctx_->GetPlace(), t_orig.type());
        VLOG(3) << "copying " << varname << " to " << param_bak_name;
        framework::TensorCopy(t_orig, dev_ctx_->GetPlace(), t);
      }
      VLOG(1) << "Table name empty? " << table_name.empty();
      VLOG(1) << "AsyncSparseParamUpdateRecorder " << varname << " exist "
              << AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(
                     varname);
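      // For a recorded sparse parameter, pack only the rows updated since the
      // trainer's last fetch into a SelectedRows instead of returning the
      // full dense tensor.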
      if (AsyncSparseParamUpdateRecorder::GetInstance()->HasParam(varname) &&
          !table_name.empty()) {
        std::vector<int64_t> updated_rows;
        AsyncSparseParamUpdateRecorder::GetInstance()->GetAndClear(
            varname, trainer_id, &updated_rows);
        if (VLOG_IS_ON(3)) {
          std::ostringstream sstream;
          sstream << "[";
          for (auto& row_id : updated_rows) {
            sstream << row_id << ", ";
          }
          sstream << "]";
          VLOG(3) << "updated_rows size: " << updated_rows.size() << " "
                  << sstream.str();
        }
        auto& origin_tensor =
            scope_->FindVar(varname)->Get<framework::LoDTensor>();
        auto* origin_tensor_data = origin_tensor.data<float>();
        auto& dims = origin_tensor.dims();
        *outvar = scope->Var();
        auto* out_slr = (*outvar)->GetMutable<framework::SelectedRows>();
        out_slr->set_rows(updated_rows);
        out_slr->set_height(dims[0]);
        auto out_dims = framework::make_ddim(
            {static_cast<int64_t>(updated_rows.size()), dims[1]});
        auto* data = out_slr->mutable_value()->mutable_data<float>(
            out_dims, origin_tensor.place());
        auto width = dims[1];
        for (size_t i = 0; i < updated_rows.size(); ++i) {
          PADDLE_ENFORCE_LT(updated_rows[i], dims[0]);
          memcpy(data + i * width, origin_tensor_data + updated_rows[i] * width,
                 sizeof(float) * width);
        }
      } else {
        *outvar = scope_->FindVar(varname);
      }
    }
  }
  return true;
}

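// Handles GET requests that bypass barriers entirely: the variable name must
// carry the WITHOUT_BARRIER_MESSAGE suffix, which is stripped before the
// variable is looked up in the server-side scope.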
bool RequestGetNoBarrierHandler::Handle(const std::string& varname,
                                        framework::Scope* scope,
                                        framework::Variable* invar,
                                        framework::Variable** outvar,
                                        const int trainer_id,
                                        const std::string& out_var_name,
                                        const std::string& table_name) {
  VLOG(4) << "RequestGetNoBarrierHandler:" << varname
          << " out_var_name: " << out_var_name;

  // get var from pserver immediately without barriers
  string::Piece without_barrier_piece(WITHOUT_BARRIER_MESSAGE);
  string::Piece var_name_piece = string::Piece(varname);

  if (string::Contains(var_name_piece, without_barrier_piece)) {
    var_name_piece = string::TrimSuffix(var_name_piece, without_barrier_piece);
    VLOG(4) << "Get var " << var_name_piece << " with "
            << WITHOUT_BARRIER_MESSAGE;
    *outvar = scope_->FindVar(var_name_piece.ToString());
    return true;
  } else {
    PADDLE_THROW("the varname of a GetNoBarrier request must contain %s",
                 WITHOUT_BARRIER_MESSAGE);
  }
  return true;
}

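// Handles PREFETCH requests, which back distributed embedding lookup: run the
// prepared prefetch block registered for the variable, or, when a table name
// is given, build and run a one-off lookup_table op.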
bool RequestPrefetchHandler::Handle(const std::string& varname,
                                    framework::Scope* scope,
                                    framework::Variable* invar,
                                    framework::Variable** outvar,
                                    const int trainer_id,
                                    const std::string& out_var_name,
                                    const std::string& table_name) {
  VLOG(4) << "RequestPrefetchHandler " << varname;

  if (table_name.empty()) {
    auto var_desc = program_->Block(0).FindVar(out_var_name);
    InitializeVariable(*outvar, var_desc->GetType());
    executor_->RunPreparedContext(
        (*prefetch_var_name_to_prepared_ctx_)[varname].get(), scope);
  } else {
    (*outvar)->GetMutable<framework::LoDTensor>();
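    // A concrete table name was given: gather the requested ids with an
    // ad-hoc lookup_table op executed on CPU.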
    auto lookup_table_op =
        BuildLookupTableOp(table_name, varname, out_var_name);
    paddle::platform::CPUPlace cpu_place;
    lookup_table_op->Run(*scope, cpu_place);
  }
  return true;
}

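// Handles checkpoint-notify requests: records the save path in the
// LOOKUP_TABLE_PATH scope variable, then runs the prepared checkpoint block
// that saves the lookup table to that path.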
bool RequestCheckpointHandler::Handle(const std::string& varname,
                                      framework::Scope* scope,
                                      framework::Variable* invar,
                                      framework::Variable** outvar,
                                      const int trainer_id,
                                      const std::string& out_var_name,
                                      const std::string& table_name) {
  PADDLE_ENFORCE(
      checkpoint_notify_id != -1,
      "no checkpoint notify RPC should be invoked when checkpoint_notify_id "
      "is -1.");

  // TODO(tangwei12): find out why the request-local scope errors here;
  // the server-side scope_ is used as a workaround.
  auto* lt_var = scope_->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
  lt_var->clear();
  lt_var->append(out_var_name);
  VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: "
          << out_var_name;
  executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope_);
  return true;
}

}  // namespace distributed
}  // namespace operators
}  // namespace paddle