/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <ctime>

#include "paddle/fluid/framework/barrier.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/lodtensor_printer.h"

#if defined PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/ps/service/communicator/communicator.h"
#endif

#if defined(PADDLE_WITH_GLOO)
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif
#include "paddle/phi/core/flags.h"

PHI_DECLARE_bool(enable_exit_when_partial_worker);

namespace paddle {
namespace framework {

// Shared quit signal: when FLAGS_enable_exit_when_partial_worker is set,
// the first worker that runs out of batches flips this so every thread
// stops at the same point in the stream.
std::atomic<bool> HogwildWorker::quit_flag_(false);
// Process-wide barrier that keeps all hogwild threads in lock step around
// the quit-flag handshake (see TrainFiles / TrainFilesWithProfiler).
Barrier g_barrier;
42
void HogwildWorker::Initialize(const TrainerDesc &desc) {
D
dongdaxiang 已提交
43
  fetch_config_ = desc.fetch_config();
44 45
  param_ = desc.hogwild_param();
  skip_ops_.resize(param_.skip_ops_size());
46
  for (int i = 0; i < param_.skip_ops_size(); ++i) {
47 48
    skip_ops_[i] = param_.skip_ops(i);
  }
49
  use_cvm_ = desc.use_cvm();
50
  thread_barrier_ = desc.thread_barrier();
51

52 53 54
  for (int i = 0; i < param_.stat_var_names_size(); ++i) {
    stat_var_name_map_[param_.stat_var_names(i)] = 1;
  }
D
dongdaxiang 已提交
55 56
}

57 58
void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) {
  auto &block = program.Block(0);
59
  op_names_.clear();
60
  for (auto &op_desc : block.AllOps()) {
61 62
    std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
    op_names_.push_back(op_desc->Type());
63
    OperatorBase *local_op_ptr = local_op.release();
64 65 66
    ops_.push_back(local_op_ptr);
    continue;
  }
Z
zhang wenhui 已提交
67 68
  operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
      program, 0, ops_);
69 70
}

71 72
// Build this worker's scope tree: persistable variables live in the root
// scope (shared by all threads); everything else is created thread-local.
// Stat variables additionally get a zero-filled thread-local shadow copy
// on every thread except thread 0, so per-thread statistics do not race
// on the shared root tensor.
void HogwildWorker::CreateThreadScope(const ProgramDesc &program) {
  auto &block = program.Block(0);

  PADDLE_ENFORCE_NOT_NULL(
      root_scope_,
      platform::errors::NotFound(
          "Root scope should be set before creating thread scope."));

  thread_scope_ = &root_scope_->NewScope();

  for (auto &var : block.AllVars()) {
    // Record every variable name, persistable or not.
    all_param_.push_back(var->Name());
    if (var->Persistable()) {
      auto *ptr = root_scope_->Var(var->Name());
      InitializeVariable(ptr, var->GetType());
      if (stat_var_name_map_.find(var->Name()) != stat_var_name_map_.end() &&
          thread_id_ != 0) {
        // Size of the shared tensor; used to zero the shadow copy below.
        int tensor_dim = root_scope_->FindVar(var->Name())
                             ->GetMutable<phi::DenseTensor>()
                             ->numel();
        auto *ptr1 = thread_scope_->Var(var->Name());
        InitializeVariable(ptr1, var->GetType());
        phi::DenseTensor *thread_tensor = ptr1->GetMutable<phi::DenseTensor>();
        phi::DenseTensor *root_tensor =
            root_scope_->FindVar(var->Name())->GetMutable<phi::DenseTensor>();
// Dispatch on the root tensor's dtype and zero the thread-local copy
// with the matching C++ type.
#define MemsetCallback(cpp_type, proto_type)                                  \
  do {                                                                        \
    if (framework::TransToProtoVarType(root_tensor->dtype()) == proto_type) { \
      SetZero<cpp_type>(thread_tensor, root_tensor, tensor_dim);              \
    }                                                                         \
  } while (0)
        _ForEachDataType_(MemsetCallback);
      }
    } else {
      // Non-persistable: plain thread-local variable.
      auto *ptr = thread_scope_->Var(var->Name());
      InitializeVariable(ptr, var->GetType());
    }
  }
}

111
// Zero-fill `tensor` on CPU memory, shaped like `root_tensor`.
// `tensor_dim` is the element count of the root tensor (numel), so the
// memset covers the whole allocation.
template <typename T>
void HogwildWorker::SetZero(phi::DenseTensor *tensor,
                            phi::DenseTensor *root_tensor,
                            int tensor_dim) {
  T *ptr = tensor->mutable_data<T>(root_tensor->dims(), platform::CPUPlace());
  memset(ptr, 0, sizeof(T) * tensor_dim);
}

119
void HogwildWorker::BindingDataFeedMemory() {
120
  const std::vector<std::string> &input_feed =
121
      device_reader_->GetUseSlotAlias();
122
  for (auto name : input_feed) {
123
    device_reader_->AddFeedVar(thread_scope_->FindVar(name), name);
124 125 126
  }
}

127
// Prepare everything this worker needs before training: its scope tree,
// its private operator copies, and — in CUDA graph mode — a small device
// buffer used by CheckBatchNum for the cross-rank batch handshake.
void HogwildWorker::CreateDeviceResource(const ProgramDesc &main_prog) {
  CreateThreadScope(main_prog);
  CreateThreadOperators(main_prog);

#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_GPU_GRAPH)
  // sync_stat_ holds {0.0, 1.0, scratch}: indices 0/1 are the constants
  // selected by CheckBatchNum's flag, index 2 receives the all-reduce
  // result. Copied to device once, up front.
  float *stat_ptr = sync_stat_.mutable_data<float>(place_, sizeof(float) * 3);
  float flags[] = {0.0, 1.0, 0.0};
  auto stream = static_cast<phi::GPUContext *>(dev_ctx_)->stream();
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(stat_ptr,  // output
                                             &flags,
                                             sizeof(float) * 3,
                                             cudaMemcpyHostToDevice,
                                             stream));
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
#endif
}
// Cross-rank handshake for multi-node graph training. Each rank
// contributes 1.0 while it still has a batch (`flag` clamped to {0, 1}
// selects sync_stat_[0]=0.0 or sync_stat_[1]=1.0) and the values are
// all-reduced with ncclProd, so the result is > 0 only when EVERY rank
// still has data. Returns true when all ranks can continue; without
// CUDA graph support it always returns false.
bool HogwildWorker::CheckBatchNum(int flag) {
  float ret = 0.0;
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_GPU_GRAPH)
  // Clamp to {0, 1}: flag is the raw batch size from the reader.
  if (flag > 1) {
    flag = 1;
  } else if (flag < 0) {
    flag = 0;
  }
  g_barrier.wait();
  float *stat_ptr = sync_stat_.data<float>();
  auto comm =
      platform::NCCLCommContext::Instance().Get(0, place_.GetDeviceId());
  auto stream = static_cast<phi::GPUContext *>(dev_ctx_)->stream();
  // Product over ranks of stat_ptr[flag] -> stat_ptr[2].
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(&stat_ptr[flag],
                                                              &stat_ptr[2],
                                                              1,
                                                              ncclFloat32,
                                                              ncclProd,
                                                              comm->comm(),
                                                              stream));
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(&ret,  // output
                                             &stat_ptr[2],
                                             sizeof(float),
                                             cudaMemcpyDeviceToHost,
                                             stream));
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
  g_barrier.wait();
#endif
  return (ret > 0.0);
}
174 175
void HogwildWorker::TrainFilesWithProfiler() {
  platform::SetNumThreads(1);
D
danleifeng 已提交
176 177 178 179 180 181
#if defined(PADDLE_WITH_HETERPS) && \
    (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL))
  platform::SetDeviceId(thread_id_);
#elif defined(PADDLE_WITH_HETERPS) && defined(PADDLE_WITH_XPU_BKCL)
  platform::SetXPUDeviceId(thread_id_);
#endif
182
  device_reader_->Start();
183 184
  std::vector<double> op_total_time;
  std::vector<std::string> op_name;
185
  for (auto &op : ops_) {
186 187 188
    op_name.push_back(op->Type());
  }
  op_total_time.resize(ops_.size());
189 190
  for (double &op_time : op_total_time) {
    op_time = 0.0;
191 192 193 194 195 196
  }
  platform::Timer timeline;
  double total_time = 0.0;
  double read_time = 0.0;
  int cur_batch;
  int batch_cnt = 0;
L
lxsbupt 已提交
197
  if (thread_id_ == 0) {
198
    quit_flag_.store(false);
L
lxsbupt 已提交
199 200
  }
  g_barrier.wait();
201
#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_GPU_GRAPH)
L
lxsbupt 已提交
202
  bool train_mode = device_reader_->IsTrainMode();
203 204 205 206 207 208 209
  bool is_multi_node = false;
  auto gloo = paddle::framework::GlooWrapper::GetInstance();
  if (gloo->Size() > 1) {
    is_multi_node = true;
  }
#endif

210
  timeline.Start();
D
dongdaxiang 已提交
211
  uint64_t total_inst = 0;
L
lxsbupt 已提交
212 213 214 215 216
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
  device_reader_->InitGraphTrainResource();
#endif
  while (1) {
    cur_batch = device_reader_->Next();
217 218 219
#if defined(PADDLE_WITH_GPU_GRAPH)
    if (is_multi_node) {
      if (!CheckBatchNum(cur_batch)) {
L
lxsbupt 已提交
220 221
        break;
      }
222 223 224 225 226 227 228 229 230 231
    } else {
      if (FLAGS_enable_exit_when_partial_worker && train_mode) {
        if (cur_batch <= 0) {
          quit_flag_.store(true, std::memory_order_relaxed);
        }
        g_barrier.wait();
        if (quit_flag_.load(std::memory_order_relaxed) == true) {
          break;
        }
      }
L
lxsbupt 已提交
232
    }
233
#endif
L
lxsbupt 已提交
234 235 236
    if (cur_batch <= 0) {
      break;
    }
237
    VLOG(3) << "read a batch in thread " << thread_id_;
238 239 240 241
    timeline.Pause();
    read_time += timeline.ElapsedSec();
    total_time += timeline.ElapsedSec();
    for (size_t i = 0; i < ops_.size(); ++i) {
242
      bool need_skip = false;
243 244
      for (auto &skip_op : skip_ops_) {
        if (ops_[i]->Type().find(skip_op) != std::string::npos) {
245 246 247 248
          need_skip = true;
          break;
        }
      }
249
      timeline.Start();
250
      VLOG(3) << "Going to run op " << op_name[i];
251 252
      if (!need_skip) {
        ops_[i]->Run(*thread_scope_, place_);
253 254 255
#ifdef PADDLE_WITH_HETERPS
        dev_ctx_->Wait();
#endif
256
      }
257
      VLOG(3) << "Op " << op_name[i] << " Finished";
258 259 260 261
      timeline.Pause();
      op_total_time[i] += timeline.ElapsedSec();
      total_time += timeline.ElapsedSec();
    }
262 263

    if (need_dump_field_) {
H
hutuxian 已提交
264 265 266 267
      DumpField(*thread_scope_, dump_mode_, dump_interval_);
    }
    if (need_dump_param_ && thread_id_ == 0) {
      DumpParam(*thread_scope_, batch_cnt);
268 269
    }

D
dongdaxiang 已提交
270
    total_inst += cur_batch;
271
    ++batch_cnt;
D
dongdaxiang 已提交
272
    PrintFetchVars();
273 274 275 276 277 278 279 280
#ifdef PADDLE_WITH_HETERPS
    dev_ctx_->Wait();
    for (size_t i = 0; i < op_name.size(); ++i) {
      VLOG(1) << "card:" << thread_id_ << ", op: " << op_name[i]
              << ", mean time: " << op_total_time[i] / total_inst
              << "s, totol time:" << op_total_time[i] << "sec";
    }
#else
281 282 283
    if (thread_id_ == 0) {
      if (batch_cnt > 0 && batch_cnt % 100 == 0) {
        for (size_t i = 0; i < ops_.size(); ++i) {
284 285 286 287 288
          fprintf(stderr,
                  "op_name:[%zu][%s], op_mean_time:[%fs]\n",
                  i,
                  op_name[i].c_str(),
                  op_total_time[i] / batch_cnt);
289 290
        }
        fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
D
dongdaxiang 已提交
291
        fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
D
dongdaxiang 已提交
292
        fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time);
293 294
      }
    }
295
#endif
D
dongdaxiang 已提交
296
    thread_scope_->DropKids();
297 298
    timeline.Start();
  }
D
danleifeng 已提交
299 300 301
  VLOG(0) << "GpuPs worker " << thread_id_ << " train cost " << total_time
          << " seconds, ins_num: " << total_inst << " read time: " << read_time
          << "seconds ";
302

H
hutuxian 已提交
303
  if (need_dump_field_ || need_dump_param_) {
304 305 306
    writer_.Flush();
  }

T
tangwei12 已提交
307
#if defined PADDLE_WITH_PSCORE
308
  if (thread_barrier_) {
T
tangwei12 已提交
309
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
310 311
  }
#endif
312 313 314
}
void HogwildWorker::TrainFiles() {
  platform::SetNumThreads(1);
315 316
  platform::Timer timeline;
  timeline.Start();
D
danleifeng 已提交
317 318 319 320 321 322
#if defined(PADDLE_WITH_HETERPS) && \
    (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL))
  platform::SetDeviceId(thread_id_);
#elif defined(PADDLE_WITH_HETERPS) && defined(PADDLE_WITH_XPU_BKCL)
  platform::SetXPUDeviceId(thread_id_);
#endif
323

D
danleifeng 已提交
324
  int total_batch_num = 0;
325
  // how to accumulate fetched values here
326
  device_reader_->Start();
327
  int cur_batch;
W
wangguanqun 已提交
328
  int batch_cnt = 0;
L
lxsbupt 已提交
329
  if (thread_id_ == 0) {
330 331
    quit_flag_.store(false);
    // quit_flag_2 = false;
L
lxsbupt 已提交
332 333
  }
  g_barrier.wait();
D
danleifeng 已提交
334

L
lxsbupt 已提交
335 336 337 338
#if defined(PADDLE_WITH_HETERPS) && defined(PADDLE_WITH_CUDA)
  platform::SetDeviceId(thread_id_);
#endif
  // while ((cur_batch = device_reader_->Next()) > 0) {
339 340
#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_GPU_GRAPH)
  bool is_multi_node = false;
L
lxsbupt 已提交
341
  bool train_mode = device_reader_->IsTrainMode();
342 343 344 345 346
  auto gloo = paddle::framework::GlooWrapper::GetInstance();
  if (gloo->Size() > 1) {
    is_multi_node = true;
  }
#endif
L
lxsbupt 已提交
347 348 349 350 351
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
  device_reader_->InitGraphTrainResource();
#endif
  while (1) {
    cur_batch = device_reader_->Next();
352 353 354
#if defined(PADDLE_WITH_GPU_GRAPH)
    if (is_multi_node) {
      if (!CheckBatchNum(cur_batch)) {
L
lxsbupt 已提交
355 356
        break;
      }
357 358 359 360 361 362 363 364 365 366
    } else {
      if (FLAGS_enable_exit_when_partial_worker && train_mode) {
        if (cur_batch <= 0) {
          quit_flag_.store(true, std::memory_order_relaxed);
        }
        g_barrier.wait();
        if (quit_flag_.load(std::memory_order_relaxed) == true) {
          break;
        }
      }
L
lxsbupt 已提交
367
    }
368
#endif
L
lxsbupt 已提交
369 370 371
    if (cur_batch <= 0) {
      break;
    }
372
    for (auto &op : ops_) {
373
      bool need_skip = false;
374 375
      for (auto &skip_op : skip_ops_) {
        if (op->Type().find(skip_op) != std::string::npos) {
376 377 378 379 380 381 382
          need_skip = true;
          break;
        }
      }
      if (!need_skip) {
        op->Run(*thread_scope_, place_);
      }
383 384
    }

W
wangguanqun 已提交
385 386 387 388 389 390 391
    if (need_dump_field_) {
      DumpField(*thread_scope_, dump_mode_, dump_interval_);
    }
    if (need_dump_param_ && thread_id_ == 0) {
      DumpParam(*thread_scope_, batch_cnt);
    }

D
danleifeng 已提交
392
    total_batch_num += cur_batch;
W
wangguanqun 已提交
393
    ++batch_cnt;
D
dongdaxiang 已提交
394
    PrintFetchVars();
D
dongdaxiang 已提交
395
    thread_scope_->DropKids();
D
danleifeng 已提交
396 397 398
#ifdef PADDLE_WITH_HETERPS
    dev_ctx_->Wait();
#endif
399
  }
400
  timeline.Pause();
401
  VLOG(1) << "worker " << thread_id_ << " train cost " << timeline.ElapsedSec()
D
danleifeng 已提交
402
          << " seconds, batch_num: " << total_batch_num;
W
wangguanqun 已提交
403 404 405 406 407

  if (need_dump_field_ || need_dump_param_) {
    writer_.Flush();
  }

T
tangwei12 已提交
408
#if defined PADDLE_WITH_PSCORE
409
  if (thread_barrier_) {
T
tangwei12 已提交
410
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
411 412
  }
#endif
413 414
}

D
dongdaxiang 已提交
415 416 417 418
void HogwildWorker::PrintFetchVars() {
  // call count
  batch_num_++;
  int batch_per_print = fetch_config_.print_period();
T
tangwei12 已提交
419 420 421 422 423 424 425 426 427 428
  int fetch_var_num = fetch_config_.fetch_var_names_size();

  if (fetch_var_num == 0) {
    return;
  }

  if (thread_id_ == 0 && batch_num_ % batch_per_print == 0) {
    time_t curtime;
    time(&curtime);
    char mbstr[80];
429 430
    std::strftime(
        mbstr, sizeof(mbstr), "%Y-%m-%d %H:%M:%S", std::localtime(&curtime));
T
tangwei12 已提交
431 432 433 434 435 436

    std::stringstream ss;
    ss << "time: [" << mbstr << "], ";
    ss << "batch: [" << batch_num_ << "], ";

    for (int i = 0; i < fetch_var_num; ++i) {
437 438 439 440
      platform::PrintVar(thread_scope_,
                         fetch_config_.fetch_var_names(i),
                         fetch_config_.fetch_var_str_format(i),
                         &ss);
T
tangwei12 已提交
441 442
      if (i < fetch_var_num - 1) {
        ss << ", ";
D
dongdaxiang 已提交
443 444
      }
    }
T
tangwei12 已提交
445 446

    std::cout << ss.str() << std::endl;
D
dongdaxiang 已提交
447 448 449
  }
}

450 451
}  // end namespace framework
}  // end namespace paddle