// hogwild_worker.cc
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include <array>
T
tangwei12 已提交
16
#include <ctime>
17

L
lxsbupt 已提交
18
#include "paddle/fluid/framework/barrier.h"
19
#include "paddle/fluid/framework/convert_utils.h"
20
#include "paddle/fluid/framework/data_type.h"
21
#include "paddle/fluid/framework/device_worker.h"
Z
zhang wenhui 已提交
22
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
23
#include "paddle/fluid/platform/cpu_helper.h"
D
dongdaxiang 已提交
24
#include "paddle/fluid/platform/lodtensor_printer.h"
25

T
tangwei12 已提交
26
#if defined PADDLE_WITH_PSCORE
27
#include "paddle/fluid/distributed/ps/service/communicator/communicator.h"
T
tangwei12 已提交
28 29
#endif

30 31 32
#if defined(PADDLE_WITH_GLOO)
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
33
#include "paddle/phi/core/flags.h"
34

35
PHI_DECLARE_bool(enable_exit_when_partial_worker);
L
lxsbupt 已提交
36

37 38 39
namespace paddle {
namespace framework {

40
std::atomic<bool> HogwildWorker::quit_flag_(false);
L
lxsbupt 已提交
41 42
Barrier g_barrier;

43
void HogwildWorker::Initialize(const TrainerDesc &desc) {
D
dongdaxiang 已提交
44
  fetch_config_ = desc.fetch_config();
45 46
  param_ = desc.hogwild_param();
  skip_ops_.resize(param_.skip_ops_size());
47
  for (int i = 0; i < param_.skip_ops_size(); ++i) {
48 49
    skip_ops_[i] = param_.skip_ops(i);
  }
50
  use_cvm_ = desc.use_cvm();
51
  thread_barrier_ = desc.thread_barrier();
52

53 54 55
  for (int i = 0; i < param_.stat_var_names_size(); ++i) {
    stat_var_name_map_[param_.stat_var_names(i)] = 1;
  }
D
dongdaxiang 已提交
56 57
}

58 59
void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) {
  auto &block = program.Block(0);
60
  op_names_.clear();
61
  for (auto &op_desc : block.AllOps()) {
62 63
    std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
    op_names_.push_back(op_desc->Type());
64
    OperatorBase *local_op_ptr = local_op.release();
65 66 67
    ops_.push_back(local_op_ptr);
    continue;
  }
Z
zhang wenhui 已提交
68 69
  operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
      program, 0, ops_);
70 71
}

72 73
void HogwildWorker::CreateThreadScope(const ProgramDesc &program) {
  auto &block = program.Block(0);
74 75

  PADDLE_ENFORCE_NOT_NULL(
76 77 78
      root_scope_,
      platform::errors::NotFound(
          "Root scope should be set before creating thread scope."));
79 80

  thread_scope_ = &root_scope_->NewScope();
81 82

  for (auto &var : block.AllVars()) {
83
    all_param_.push_back(var->Name());
84
    if (var->Persistable()) {
85
      auto *ptr = root_scope_->Var(var->Name());
86
      InitializeVariable(ptr, var->GetType());
87 88
      if (stat_var_name_map_.find(var->Name()) != stat_var_name_map_.end() &&
          thread_id_ != 0) {
89 90 91
        int tensor_dim = root_scope_->FindVar(var->Name())
                             ->GetMutable<phi::DenseTensor>()
                             ->numel();
92 93
        auto *ptr1 = thread_scope_->Var(var->Name());
        InitializeVariable(ptr1, var->GetType());
94 95 96
        phi::DenseTensor *thread_tensor = ptr1->GetMutable<phi::DenseTensor>();
        phi::DenseTensor *root_tensor =
            root_scope_->FindVar(var->Name())->GetMutable<phi::DenseTensor>();
97 98 99 100 101
#define MemsetCallback(cpp_type, proto_type)                                  \
  do {                                                                        \
    if (framework::TransToProtoVarType(root_tensor->dtype()) == proto_type) { \
      SetZero<cpp_type>(thread_tensor, root_tensor, tensor_dim);              \
    }                                                                         \
102 103 104
  } while (0)
        _ForEachDataType_(MemsetCallback);
      }
105
    } else {
106
      auto *ptr = thread_scope_->Var(var->Name());
107 108 109 110 111
      InitializeVariable(ptr, var->GetType());
    }
  }
}

112
template <typename T>
113 114
void HogwildWorker::SetZero(phi::DenseTensor *tensor,
                            phi::DenseTensor *root_tensor,
115 116 117 118 119
                            int tensor_dim) {
  T *ptr = tensor->mutable_data<T>(root_tensor->dims(), platform::CPUPlace());
  memset(ptr, 0, sizeof(T) * tensor_dim);
}

120
void HogwildWorker::BindingDataFeedMemory() {
121
  const std::vector<std::string> &input_feed =
122
      device_reader_->GetUseSlotAlias();
123
  for (auto name : input_feed) {
124
    device_reader_->AddFeedVar(thread_scope_->FindVar(name), name);
125 126 127
  }
}

128
void HogwildWorker::CreateDeviceResource(const ProgramDesc &main_prog) {
129 130 131
  CreateThreadScope(main_prog);
  CreateThreadOperators(main_prog);

132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_GPU_GRAPH)
  float *stat_ptr = sync_stat_.mutable_data<float>(place_, sizeof(float) * 3);
  float flags[] = {0.0, 1.0, 0.0};
  auto stream = static_cast<phi::GPUContext *>(dev_ctx_)->stream();
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(stat_ptr,  // output
                                             &flags,
                                             sizeof(float) * 3,
                                             cudaMemcpyHostToDevice,
                                             stream));
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
#endif
}
// Collective check that every rank still has a batch.
//
// `flag` (the local batch count) is clamped to a slot index of sync_stat_:
// slot 0 holds 0.0 and slot 1 holds 1.0 (as seeded in CreateDeviceResource).
// The NCCL product of that slot over all ranks lands in slot 2, is copied back
// to the host, and the function returns true only if it is positive. In builds
// without CUDA + GPU_GRAPH nothing is exchanged and the result is always false.
bool HogwildWorker::CheckBatchNum(int flag) {
  float reduced = 0.0;
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_GPU_GRAPH)
  const int slot = (flag > 1) ? 1 : ((flag < 0) ? 0 : flag);
  g_barrier.wait();
  float *stat_ptr = sync_stat_.data<float>();
  auto comm =
      platform::NCCLCommContext::Instance().Get(0, place_.GetDeviceId());
  auto stream = static_cast<phi::GPUContext *>(dev_ctx_)->stream();
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(&stat_ptr[slot],
                                                              &stat_ptr[2],
                                                              1,
                                                              ncclFloat32,
                                                              ncclProd,
                                                              comm->comm(),
                                                              stream));
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(&reduced,  // output
                                             &stat_ptr[2],
                                             sizeof(float),
                                             cudaMemcpyDeviceToHost,
                                             stream));
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
  g_barrier.wait();
#endif
  return reduced > 0.0;
}
175 176
void HogwildWorker::TrainFilesWithProfiler() {
  platform::SetNumThreads(1);
D
danleifeng 已提交
177 178 179 180 181 182
#if defined(PADDLE_WITH_HETERPS) && \
    (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL))
  platform::SetDeviceId(thread_id_);
#elif defined(PADDLE_WITH_HETERPS) && defined(PADDLE_WITH_XPU_BKCL)
  platform::SetXPUDeviceId(thread_id_);
#endif
183
  device_reader_->Start();
184 185
  std::vector<double> op_total_time;
  std::vector<std::string> op_name;
186
  for (auto &op : ops_) {
187 188 189
    op_name.push_back(op->Type());
  }
  op_total_time.resize(ops_.size());
190 191
  for (double &op_time : op_total_time) {
    op_time = 0.0;
192 193 194 195 196 197
  }
  platform::Timer timeline;
  double total_time = 0.0;
  double read_time = 0.0;
  int cur_batch;
  int batch_cnt = 0;
L
lxsbupt 已提交
198
  if (thread_id_ == 0) {
199
    quit_flag_.store(false);
L
lxsbupt 已提交
200 201
  }
  g_barrier.wait();
202
#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_GPU_GRAPH)
L
lxsbupt 已提交
203
  bool train_mode = device_reader_->IsTrainMode();
204 205 206 207 208 209 210
  bool is_multi_node = false;
  auto gloo = paddle::framework::GlooWrapper::GetInstance();
  if (gloo->Size() > 1) {
    is_multi_node = true;
  }
#endif

211
  timeline.Start();
D
dongdaxiang 已提交
212
  uint64_t total_inst = 0;
L
lxsbupt 已提交
213 214 215
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
  device_reader_->InitGraphTrainResource();
#endif
216
  while (true) {
L
lxsbupt 已提交
217
    cur_batch = device_reader_->Next();
218 219 220
#if defined(PADDLE_WITH_GPU_GRAPH)
    if (is_multi_node) {
      if (!CheckBatchNum(cur_batch)) {
L
lxsbupt 已提交
221 222
        break;
      }
223 224 225 226 227 228 229 230 231 232
    } else {
      if (FLAGS_enable_exit_when_partial_worker && train_mode) {
        if (cur_batch <= 0) {
          quit_flag_.store(true, std::memory_order_relaxed);
        }
        g_barrier.wait();
        if (quit_flag_.load(std::memory_order_relaxed) == true) {
          break;
        }
      }
L
lxsbupt 已提交
233
    }
234
#endif
L
lxsbupt 已提交
235 236 237
    if (cur_batch <= 0) {
      break;
    }
238
    VLOG(3) << "read a batch in thread " << thread_id_;
239 240 241 242
    timeline.Pause();
    read_time += timeline.ElapsedSec();
    total_time += timeline.ElapsedSec();
    for (size_t i = 0; i < ops_.size(); ++i) {
243
      bool need_skip = false;
244 245
      for (auto &skip_op : skip_ops_) {
        if (ops_[i]->Type().find(skip_op) != std::string::npos) {
246 247 248 249
          need_skip = true;
          break;
        }
      }
250
      timeline.Start();
251
      VLOG(3) << "Going to run op " << op_name[i];
252 253
      if (!need_skip) {
        ops_[i]->Run(*thread_scope_, place_);
254 255 256
#ifdef PADDLE_WITH_HETERPS
        dev_ctx_->Wait();
#endif
257
      }
258
      VLOG(3) << "Op " << op_name[i] << " Finished";
259 260 261 262
      timeline.Pause();
      op_total_time[i] += timeline.ElapsedSec();
      total_time += timeline.ElapsedSec();
    }
263 264

    if (need_dump_field_) {
H
hutuxian 已提交
265 266 267 268
      DumpField(*thread_scope_, dump_mode_, dump_interval_);
    }
    if (need_dump_param_ && thread_id_ == 0) {
      DumpParam(*thread_scope_, batch_cnt);
269 270
    }

D
dongdaxiang 已提交
271
    total_inst += cur_batch;
272
    ++batch_cnt;
D
dongdaxiang 已提交
273
    PrintFetchVars();
274 275 276 277 278 279 280 281
#ifdef PADDLE_WITH_HETERPS
    dev_ctx_->Wait();
    for (size_t i = 0; i < op_name.size(); ++i) {
      VLOG(1) << "card:" << thread_id_ << ", op: " << op_name[i]
              << ", mean time: " << op_total_time[i] / total_inst
              << "s, totol time:" << op_total_time[i] << "sec";
    }
#else
282 283 284
    if (thread_id_ == 0) {
      if (batch_cnt > 0 && batch_cnt % 100 == 0) {
        for (size_t i = 0; i < ops_.size(); ++i) {
285 286 287 288 289
          fprintf(stderr,
                  "op_name:[%zu][%s], op_mean_time:[%fs]\n",
                  i,
                  op_name[i].c_str(),
                  op_total_time[i] / batch_cnt);
290 291
        }
        fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
D
dongdaxiang 已提交
292
        fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
D
dongdaxiang 已提交
293
        fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time);
294 295
      }
    }
296
#endif
D
dongdaxiang 已提交
297
    thread_scope_->DropKids();
298 299
    timeline.Start();
  }
D
danleifeng 已提交
300 301 302
  VLOG(0) << "GpuPs worker " << thread_id_ << " train cost " << total_time
          << " seconds, ins_num: " << total_inst << " read time: " << read_time
          << "seconds ";
303

H
hutuxian 已提交
304
  if (need_dump_field_ || need_dump_param_) {
305 306 307
    writer_.Flush();
  }

T
tangwei12 已提交
308
#if defined PADDLE_WITH_PSCORE
309
  if (thread_barrier_) {
T
tangwei12 已提交
310
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
311 312
  }
#endif
313 314 315
}
void HogwildWorker::TrainFiles() {
  platform::SetNumThreads(1);
316 317
  platform::Timer timeline;
  timeline.Start();
D
danleifeng 已提交
318 319 320 321 322 323
#if defined(PADDLE_WITH_HETERPS) && \
    (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL))
  platform::SetDeviceId(thread_id_);
#elif defined(PADDLE_WITH_HETERPS) && defined(PADDLE_WITH_XPU_BKCL)
  platform::SetXPUDeviceId(thread_id_);
#endif
324

D
danleifeng 已提交
325
  int total_batch_num = 0;
326
  // how to accumulate fetched values here
327
  device_reader_->Start();
328
  int cur_batch;
W
wangguanqun 已提交
329
  int batch_cnt = 0;
L
lxsbupt 已提交
330
  if (thread_id_ == 0) {
331 332
    quit_flag_.store(false);
    // quit_flag_2 = false;
L
lxsbupt 已提交
333 334
  }
  g_barrier.wait();
D
danleifeng 已提交
335

L
lxsbupt 已提交
336 337 338 339
#if defined(PADDLE_WITH_HETERPS) && defined(PADDLE_WITH_CUDA)
  platform::SetDeviceId(thread_id_);
#endif
  // while ((cur_batch = device_reader_->Next()) > 0) {
340 341
#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_GPU_GRAPH)
  bool is_multi_node = false;
L
lxsbupt 已提交
342
  bool train_mode = device_reader_->IsTrainMode();
343 344 345 346 347
  auto gloo = paddle::framework::GlooWrapper::GetInstance();
  if (gloo->Size() > 1) {
    is_multi_node = true;
  }
#endif
L
lxsbupt 已提交
348 349 350
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
  device_reader_->InitGraphTrainResource();
#endif
351
  while (true) {
L
lxsbupt 已提交
352
    cur_batch = device_reader_->Next();
353 354 355
#if defined(PADDLE_WITH_GPU_GRAPH)
    if (is_multi_node) {
      if (!CheckBatchNum(cur_batch)) {
L
lxsbupt 已提交
356 357
        break;
      }
358 359 360 361 362 363 364 365 366 367
    } else {
      if (FLAGS_enable_exit_when_partial_worker && train_mode) {
        if (cur_batch <= 0) {
          quit_flag_.store(true, std::memory_order_relaxed);
        }
        g_barrier.wait();
        if (quit_flag_.load(std::memory_order_relaxed) == true) {
          break;
        }
      }
L
lxsbupt 已提交
368
    }
369
#endif
L
lxsbupt 已提交
370 371 372
    if (cur_batch <= 0) {
      break;
    }
373
    for (auto &op : ops_) {
374
      bool need_skip = false;
375 376
      for (auto &skip_op : skip_ops_) {
        if (op->Type().find(skip_op) != std::string::npos) {
377 378 379 380 381 382 383
          need_skip = true;
          break;
        }
      }
      if (!need_skip) {
        op->Run(*thread_scope_, place_);
      }
384 385
    }

W
wangguanqun 已提交
386 387 388 389 390 391 392
    if (need_dump_field_) {
      DumpField(*thread_scope_, dump_mode_, dump_interval_);
    }
    if (need_dump_param_ && thread_id_ == 0) {
      DumpParam(*thread_scope_, batch_cnt);
    }

D
danleifeng 已提交
393
    total_batch_num += cur_batch;
W
wangguanqun 已提交
394
    ++batch_cnt;
D
dongdaxiang 已提交
395
    PrintFetchVars();
D
dongdaxiang 已提交
396
    thread_scope_->DropKids();
D
danleifeng 已提交
397 398 399
#ifdef PADDLE_WITH_HETERPS
    dev_ctx_->Wait();
#endif
400
  }
401
  timeline.Pause();
402
  VLOG(1) << "worker " << thread_id_ << " train cost " << timeline.ElapsedSec()
D
danleifeng 已提交
403
          << " seconds, batch_num: " << total_batch_num;
W
wangguanqun 已提交
404 405 406 407 408

  if (need_dump_field_ || need_dump_param_) {
    writer_.Flush();
  }

T
tangwei12 已提交
409
#if defined PADDLE_WITH_PSCORE
410
  if (thread_barrier_) {
T
tangwei12 已提交
411
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
412 413
  }
#endif
414 415
}

D
dongdaxiang 已提交
416 417 418 419
void HogwildWorker::PrintFetchVars() {
  // call count
  batch_num_++;
  int batch_per_print = fetch_config_.print_period();
T
tangwei12 已提交
420 421 422 423 424 425 426 427 428
  int fetch_var_num = fetch_config_.fetch_var_names_size();

  if (fetch_var_num == 0) {
    return;
  }

  if (thread_id_ == 0 && batch_num_ % batch_per_print == 0) {
    time_t curtime;
    time(&curtime);
429 430 431 432 433
    std::array<char, 80> mbstr;
    std::strftime(mbstr.data(),
                  sizeof(mbstr),
                  "%Y-%m-%d %H:%M:%S",
                  std::localtime(&curtime));
T
tangwei12 已提交
434 435

    std::stringstream ss;
436
    ss << "time: [" << mbstr.data() << "], ";
T
tangwei12 已提交
437 438 439
    ss << "batch: [" << batch_num_ << "], ";

    for (int i = 0; i < fetch_var_num; ++i) {
440 441 442 443
      platform::PrintVar(thread_scope_,
                         fetch_config_.fetch_var_names(i),
                         fetch_config_.fetch_var_str_format(i),
                         &ss);
T
tangwei12 已提交
444 445
      if (i < fetch_var_num - 1) {
        ss << ", ";
D
dongdaxiang 已提交
446 447
      }
    }
T
tangwei12 已提交
448 449

    std::cout << ss.str() << std::endl;
D
dongdaxiang 已提交
450 451 452
  }
}

453 454
}  // end namespace framework
}  // end namespace paddle