hogwild_worker.cc 14.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <atomic>
#include <cstring>
#include <ctime>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include "paddle/fluid/framework/barrier.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/lodtensor_printer.h"

#if defined PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/ps/service/communicator/communicator.h"
#endif

#if defined(PADDLE_WITH_GLOO)
#include "paddle/fluid/framework/fleet/gloo_wrapper.h"
#endif

DECLARE_bool(enable_exit_when_partial_worker);

namespace paddle {
namespace framework {

38
std::atomic<bool> HogwildWorker::quit_flag_(false);
L
lxsbupt 已提交
39 40
Barrier g_barrier;

41
void HogwildWorker::Initialize(const TrainerDesc &desc) {
D
dongdaxiang 已提交
42
  fetch_config_ = desc.fetch_config();
43 44
  param_ = desc.hogwild_param();
  skip_ops_.resize(param_.skip_ops_size());
45
  for (int i = 0; i < param_.skip_ops_size(); ++i) {
46 47
    skip_ops_[i] = param_.skip_ops(i);
  }
48
  use_cvm_ = desc.use_cvm();
49
  thread_barrier_ = desc.thread_barrier();
50

51 52 53
  for (int i = 0; i < param_.stat_var_names_size(); ++i) {
    stat_var_name_map_[param_.stat_var_names(i)] = 1;
  }
D
dongdaxiang 已提交
54 55
}

56 57
void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) {
  auto &block = program.Block(0);
58
  op_names_.clear();
59
  for (auto &op_desc : block.AllOps()) {
60 61
    std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
    op_names_.push_back(op_desc->Type());
62
    OperatorBase *local_op_ptr = local_op.release();
63 64 65
    ops_.push_back(local_op_ptr);
    continue;
  }
Z
zhang wenhui 已提交
66 67
  operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
      program, 0, ops_);
68 69
}

70 71
void HogwildWorker::CreateThreadScope(const ProgramDesc &program) {
  auto &block = program.Block(0);
72 73

  PADDLE_ENFORCE_NOT_NULL(
74 75 76
      root_scope_,
      platform::errors::NotFound(
          "Root scope should be set before creating thread scope."));
77 78

  thread_scope_ = &root_scope_->NewScope();
79 80

  for (auto &var : block.AllVars()) {
81
    all_param_.push_back(var->Name());
82
    if (var->Persistable()) {
83
      auto *ptr = root_scope_->Var(var->Name());
84
      InitializeVariable(ptr, var->GetType());
85 86
      if (stat_var_name_map_.find(var->Name()) != stat_var_name_map_.end() &&
          thread_id_ != 0) {
87 88 89
        int tensor_dim = root_scope_->FindVar(var->Name())
                             ->GetMutable<phi::DenseTensor>()
                             ->numel();
90 91
        auto *ptr1 = thread_scope_->Var(var->Name());
        InitializeVariable(ptr1, var->GetType());
92 93 94
        phi::DenseTensor *thread_tensor = ptr1->GetMutable<phi::DenseTensor>();
        phi::DenseTensor *root_tensor =
            root_scope_->FindVar(var->Name())->GetMutable<phi::DenseTensor>();
95 96 97 98 99
#define MemsetCallback(cpp_type, proto_type)                                  \
  do {                                                                        \
    if (framework::TransToProtoVarType(root_tensor->dtype()) == proto_type) { \
      SetZero<cpp_type>(thread_tensor, root_tensor, tensor_dim);              \
    }                                                                         \
100 101 102
  } while (0)
        _ForEachDataType_(MemsetCallback);
      }
103
    } else {
104
      auto *ptr = thread_scope_->Var(var->Name());
105 106 107 108 109
      InitializeVariable(ptr, var->GetType());
    }
  }
}

110
template <typename T>
111 112
void HogwildWorker::SetZero(phi::DenseTensor *tensor,
                            phi::DenseTensor *root_tensor,
113 114 115 116 117
                            int tensor_dim) {
  T *ptr = tensor->mutable_data<T>(root_tensor->dims(), platform::CPUPlace());
  memset(ptr, 0, sizeof(T) * tensor_dim);
}

118
void HogwildWorker::BindingDataFeedMemory() {
119
  const std::vector<std::string> &input_feed =
120
      device_reader_->GetUseSlotAlias();
121
  for (auto name : input_feed) {
122
    device_reader_->AddFeedVar(thread_scope_->FindVar(name), name);
123 124 125
  }
}

126
void HogwildWorker::CreateDeviceResource(const ProgramDesc &main_prog) {
127 128 129
  CreateThreadScope(main_prog);
  CreateThreadOperators(main_prog);

130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_GPU_GRAPH)
  float *stat_ptr = sync_stat_.mutable_data<float>(place_, sizeof(float) * 3);
  float flags[] = {0.0, 1.0, 0.0};
  auto stream = static_cast<phi::GPUContext *>(dev_ctx_)->stream();
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(stat_ptr,  // output
                                             &flags,
                                             sizeof(float) * 3,
                                             cudaMemcpyHostToDevice,
                                             stream));
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
#endif
}
// Cross-rank "does everyone still have data?" check for multi-node GPU-graph
// training. Each rank contributes sync_stat_[1] when `flag` > 0, otherwise
// sync_stat_[0]; these hold 1.0 and 0.0 as seeded by CreateDeviceResource.
// The contributions are multiplied by an NCCL all-reduce into sync_stat_[2]
// and copied back to the host, so the result is true only if every rank passed
// a positive flag. Builds without CUDA + GPU_GRAPH always return false.
bool HogwildWorker::CheckBatchNum(int flag) {
  float ret = 0.0f;
#if defined(PADDLE_WITH_CUDA) && defined(PADDLE_WITH_GPU_GRAPH)
  // Map the batch size to a stat slot: 0 = no data, 1 = has data.
  const int slot = (flag > 0) ? 1 : 0;
  g_barrier.wait();
  float *stat_ptr = sync_stat_.data<float>();
  auto comm =
      platform::NCCLCommContext::Instance().Get(0, place_.GetDeviceId());
  auto stream = static_cast<phi::GPUContext *>(dev_ctx_)->stream();
  PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::ncclAllReduce(&stat_ptr[slot],
                                                              &stat_ptr[2],
                                                              1,
                                                              ncclFloat32,
                                                              ncclProd,
                                                              comm->comm(),
                                                              stream));
  PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpyAsync(&ret,  // output
                                             &stat_ptr[2],
                                             sizeof(float),
                                             cudaMemcpyDeviceToHost,
                                             stream));
  PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamSynchronize(stream));
  g_barrier.wait();
#endif
  return ret > 0.0;
}
173 174
void HogwildWorker::TrainFilesWithProfiler() {
  platform::SetNumThreads(1);
D
danleifeng 已提交
175 176 177 178 179 180
#if defined(PADDLE_WITH_HETERPS) && \
    (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL))
  platform::SetDeviceId(thread_id_);
#elif defined(PADDLE_WITH_HETERPS) && defined(PADDLE_WITH_XPU_BKCL)
  platform::SetXPUDeviceId(thread_id_);
#endif
181
  device_reader_->Start();
182 183
  std::vector<double> op_total_time;
  std::vector<std::string> op_name;
184
  for (auto &op : ops_) {
185 186 187 188 189 190 191 192 193 194 195
    op_name.push_back(op->Type());
  }
  op_total_time.resize(ops_.size());
  for (size_t i = 0; i < op_total_time.size(); ++i) {
    op_total_time[i] = 0.0;
  }
  platform::Timer timeline;
  double total_time = 0.0;
  double read_time = 0.0;
  int cur_batch;
  int batch_cnt = 0;
L
lxsbupt 已提交
196
  if (thread_id_ == 0) {
197
    quit_flag_.store(false);
L
lxsbupt 已提交
198 199
  }
  g_barrier.wait();
200
#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_GPU_GRAPH)
L
lxsbupt 已提交
201
  bool train_mode = device_reader_->IsTrainMode();
202 203 204 205 206 207 208
  bool is_multi_node = false;
  auto gloo = paddle::framework::GlooWrapper::GetInstance();
  if (gloo->Size() > 1) {
    is_multi_node = true;
  }
#endif

209
  timeline.Start();
D
dongdaxiang 已提交
210
  uint64_t total_inst = 0;
L
lxsbupt 已提交
211 212 213 214 215
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
  device_reader_->InitGraphTrainResource();
#endif
  while (1) {
    cur_batch = device_reader_->Next();
216 217 218
#if defined(PADDLE_WITH_GPU_GRAPH)
    if (is_multi_node) {
      if (!CheckBatchNum(cur_batch)) {
L
lxsbupt 已提交
219 220
        break;
      }
221 222 223 224 225 226 227 228 229 230
    } else {
      if (FLAGS_enable_exit_when_partial_worker && train_mode) {
        if (cur_batch <= 0) {
          quit_flag_.store(true, std::memory_order_relaxed);
        }
        g_barrier.wait();
        if (quit_flag_.load(std::memory_order_relaxed) == true) {
          break;
        }
      }
L
lxsbupt 已提交
231
    }
232
#endif
L
lxsbupt 已提交
233 234 235
    if (cur_batch <= 0) {
      break;
    }
236
    VLOG(3) << "read a batch in thread " << thread_id_;
237 238 239 240
    timeline.Pause();
    read_time += timeline.ElapsedSec();
    total_time += timeline.ElapsedSec();
    for (size_t i = 0; i < ops_.size(); ++i) {
241 242 243 244 245 246 247
      bool need_skip = false;
      for (auto t = 0u; t < skip_ops_.size(); ++t) {
        if (ops_[i]->Type().find(skip_ops_[t]) != std::string::npos) {
          need_skip = true;
          break;
        }
      }
248
      timeline.Start();
249
      VLOG(3) << "Going to run op " << op_name[i];
250 251
      if (!need_skip) {
        ops_[i]->Run(*thread_scope_, place_);
252 253 254
#ifdef PADDLE_WITH_HETERPS
        dev_ctx_->Wait();
#endif
255
      }
256
      VLOG(3) << "Op " << op_name[i] << " Finished";
257 258 259 260
      timeline.Pause();
      op_total_time[i] += timeline.ElapsedSec();
      total_time += timeline.ElapsedSec();
    }
261 262

    if (need_dump_field_) {
H
hutuxian 已提交
263 264 265 266
      DumpField(*thread_scope_, dump_mode_, dump_interval_);
    }
    if (need_dump_param_ && thread_id_ == 0) {
      DumpParam(*thread_scope_, batch_cnt);
267 268
    }

D
dongdaxiang 已提交
269
    total_inst += cur_batch;
270
    ++batch_cnt;
D
dongdaxiang 已提交
271
    PrintFetchVars();
272 273 274 275 276 277 278 279
#ifdef PADDLE_WITH_HETERPS
    dev_ctx_->Wait();
    for (size_t i = 0; i < op_name.size(); ++i) {
      VLOG(1) << "card:" << thread_id_ << ", op: " << op_name[i]
              << ", mean time: " << op_total_time[i] / total_inst
              << "s, totol time:" << op_total_time[i] << "sec";
    }
#else
280 281 282
    if (thread_id_ == 0) {
      if (batch_cnt > 0 && batch_cnt % 100 == 0) {
        for (size_t i = 0; i < ops_.size(); ++i) {
283 284 285 286 287
          fprintf(stderr,
                  "op_name:[%zu][%s], op_mean_time:[%fs]\n",
                  i,
                  op_name[i].c_str(),
                  op_total_time[i] / batch_cnt);
288 289
        }
        fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
D
dongdaxiang 已提交
290
        fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
D
dongdaxiang 已提交
291
        fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time);
292 293
      }
    }
294
#endif
D
dongdaxiang 已提交
295
    thread_scope_->DropKids();
296 297
    timeline.Start();
  }
D
danleifeng 已提交
298 299 300
  VLOG(0) << "GpuPs worker " << thread_id_ << " train cost " << total_time
          << " seconds, ins_num: " << total_inst << " read time: " << read_time
          << "seconds ";
301

H
hutuxian 已提交
302
  if (need_dump_field_ || need_dump_param_) {
303 304 305
    writer_.Flush();
  }

T
tangwei12 已提交
306
#if defined PADDLE_WITH_PSCORE
307
  if (thread_barrier_) {
T
tangwei12 已提交
308
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
309 310
  }
#endif
311 312 313
}
void HogwildWorker::TrainFiles() {
  platform::SetNumThreads(1);
314 315
  platform::Timer timeline;
  timeline.Start();
D
danleifeng 已提交
316 317 318 319 320 321
#if defined(PADDLE_WITH_HETERPS) && \
    (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL))
  platform::SetDeviceId(thread_id_);
#elif defined(PADDLE_WITH_HETERPS) && defined(PADDLE_WITH_XPU_BKCL)
  platform::SetXPUDeviceId(thread_id_);
#endif
322

D
danleifeng 已提交
323
  int total_batch_num = 0;
324
  // how to accumulate fetched values here
325
  device_reader_->Start();
326
  int cur_batch;
W
wangguanqun 已提交
327
  int batch_cnt = 0;
L
lxsbupt 已提交
328
  if (thread_id_ == 0) {
329 330
    quit_flag_.store(false);
    // quit_flag_2 = false;
L
lxsbupt 已提交
331 332
  }
  g_barrier.wait();
D
danleifeng 已提交
333

L
lxsbupt 已提交
334 335 336 337
#if defined(PADDLE_WITH_HETERPS) && defined(PADDLE_WITH_CUDA)
  platform::SetDeviceId(thread_id_);
#endif
  // while ((cur_batch = device_reader_->Next()) > 0) {
338 339
#if defined(PADDLE_WITH_GLOO) && defined(PADDLE_WITH_GPU_GRAPH)
  bool is_multi_node = false;
L
lxsbupt 已提交
340
  bool train_mode = device_reader_->IsTrainMode();
341 342 343 344 345
  auto gloo = paddle::framework::GlooWrapper::GetInstance();
  if (gloo->Size() > 1) {
    is_multi_node = true;
  }
#endif
L
lxsbupt 已提交
346 347 348 349 350
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
  device_reader_->InitGraphTrainResource();
#endif
  while (1) {
    cur_batch = device_reader_->Next();
351 352 353
#if defined(PADDLE_WITH_GPU_GRAPH)
    if (is_multi_node) {
      if (!CheckBatchNum(cur_batch)) {
L
lxsbupt 已提交
354 355
        break;
      }
356 357 358 359 360 361 362 363 364 365
    } else {
      if (FLAGS_enable_exit_when_partial_worker && train_mode) {
        if (cur_batch <= 0) {
          quit_flag_.store(true, std::memory_order_relaxed);
        }
        g_barrier.wait();
        if (quit_flag_.load(std::memory_order_relaxed) == true) {
          break;
        }
      }
L
lxsbupt 已提交
366
    }
367
#endif
L
lxsbupt 已提交
368 369 370
    if (cur_batch <= 0) {
      break;
    }
371
    for (auto &op : ops_) {
372 373 374 375 376 377 378 379 380 381
      bool need_skip = false;
      for (auto t = 0u; t < skip_ops_.size(); ++t) {
        if (op->Type().find(skip_ops_[t]) != std::string::npos) {
          need_skip = true;
          break;
        }
      }
      if (!need_skip) {
        op->Run(*thread_scope_, place_);
      }
382 383
    }

W
wangguanqun 已提交
384 385 386 387 388 389 390
    if (need_dump_field_) {
      DumpField(*thread_scope_, dump_mode_, dump_interval_);
    }
    if (need_dump_param_ && thread_id_ == 0) {
      DumpParam(*thread_scope_, batch_cnt);
    }

D
danleifeng 已提交
391
    total_batch_num += cur_batch;
W
wangguanqun 已提交
392
    ++batch_cnt;
D
dongdaxiang 已提交
393
    PrintFetchVars();
D
dongdaxiang 已提交
394
    thread_scope_->DropKids();
D
danleifeng 已提交
395 396 397
#ifdef PADDLE_WITH_HETERPS
    dev_ctx_->Wait();
#endif
398
  }
399
  timeline.Pause();
400
  VLOG(1) << "worker " << thread_id_ << " train cost " << timeline.ElapsedSec()
D
danleifeng 已提交
401
          << " seconds, batch_num: " << total_batch_num;
W
wangguanqun 已提交
402 403 404 405 406

  if (need_dump_field_ || need_dump_param_) {
    writer_.Flush();
  }

T
tangwei12 已提交
407
#if defined PADDLE_WITH_PSCORE
408
  if (thread_barrier_) {
T
tangwei12 已提交
409
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
410 411
  }
#endif
412 413
}

D
dongdaxiang 已提交
414 415 416 417
void HogwildWorker::PrintFetchVars() {
  // call count
  batch_num_++;
  int batch_per_print = fetch_config_.print_period();
T
tangwei12 已提交
418 419 420 421 422 423 424 425 426 427
  int fetch_var_num = fetch_config_.fetch_var_names_size();

  if (fetch_var_num == 0) {
    return;
  }

  if (thread_id_ == 0 && batch_num_ % batch_per_print == 0) {
    time_t curtime;
    time(&curtime);
    char mbstr[80];
428 429
    std::strftime(
        mbstr, sizeof(mbstr), "%Y-%m-%d %H:%M:%S", std::localtime(&curtime));
T
tangwei12 已提交
430 431 432 433 434 435

    std::stringstream ss;
    ss << "time: [" << mbstr << "], ";
    ss << "batch: [" << batch_num_ << "], ";

    for (int i = 0; i < fetch_var_num; ++i) {
436 437 438 439
      platform::PrintVar(thread_scope_,
                         fetch_config_.fetch_var_names(i),
                         fetch_config_.fetch_var_str_format(i),
                         &ss);
T
tangwei12 已提交
440 441
      if (i < fetch_var_num - 1) {
        ss << ", ";
D
dongdaxiang 已提交
442 443
      }
    }
T
tangwei12 已提交
444 445

    std::cout << ss.str() << std::endl;
D
dongdaxiang 已提交
446 447 448
  }
}

449 450
}  // end namespace framework
}  // end namespace paddle