hogwild_worker.cc 11.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

T
tangwei12 已提交
15
#include <ctime>
16

L
lxsbupt 已提交
17
#include "paddle/fluid/framework/barrier.h"
18
#include "paddle/fluid/framework/convert_utils.h"
19
#include "paddle/fluid/framework/data_type.h"
20
#include "paddle/fluid/framework/device_worker.h"
Z
zhang wenhui 已提交
21
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
22
#include "paddle/fluid/platform/cpu_helper.h"
D
dongdaxiang 已提交
23
#include "paddle/fluid/platform/lodtensor_printer.h"
24

T
tangwei12 已提交
25
#if defined PADDLE_WITH_PSCORE
26
#include "paddle/fluid/distributed/ps/service/communicator/communicator.h"
T
tangwei12 已提交
27 28
#endif

L
lxsbupt 已提交
29 30
DECLARE_bool(enable_exit_when_partial_worker);

31 32 33
namespace paddle {
namespace framework {

L
lxsbupt 已提交
34 35 36
// Shared by every HogwildWorker (static member). Each training pass has
// thread 0 reset it to 0. Each worker adds 1 for every step in which Next()
// returned a positive count. The partial-worker exit logic
// (FLAGS_enable_exit_when_partial_worker) reads it.
std::atomic<uint64_t> HogwildWorker::worker_num_stat_{0};

// Process-wide barrier used to line up the hogwild worker threads.
// NOTE(review): its participant count is presumably set elsewhere (e.g. by the
// trainer) to the number of worker threads — confirm.
Barrier g_barrier;

37
void HogwildWorker::Initialize(const TrainerDesc &desc) {
D
dongdaxiang 已提交
38
  fetch_config_ = desc.fetch_config();
39 40
  param_ = desc.hogwild_param();
  skip_ops_.resize(param_.skip_ops_size());
41
  for (int i = 0; i < param_.skip_ops_size(); ++i) {
42 43
    skip_ops_[i] = param_.skip_ops(i);
  }
44
  use_cvm_ = desc.use_cvm();
45
  thread_barrier_ = desc.thread_barrier();
46

47 48 49
  for (int i = 0; i < param_.stat_var_names_size(); ++i) {
    stat_var_name_map_[param_.stat_var_names(i)] = 1;
  }
D
dongdaxiang 已提交
50 51
}

52 53
void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) {
  auto &block = program.Block(0);
54
  op_names_.clear();
55
  for (auto &op_desc : block.AllOps()) {
56 57
    std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
    op_names_.push_back(op_desc->Type());
58
    OperatorBase *local_op_ptr = local_op.release();
59 60 61
    ops_.push_back(local_op_ptr);
    continue;
  }
Z
zhang wenhui 已提交
62 63
  operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
      program, 0, ops_);
64 65
}

66 67
void HogwildWorker::CreateThreadScope(const ProgramDesc &program) {
  auto &block = program.Block(0);
68 69

  PADDLE_ENFORCE_NOT_NULL(
70 71 72
      root_scope_,
      platform::errors::NotFound(
          "Root scope should be set before creating thread scope."));
73 74

  thread_scope_ = &root_scope_->NewScope();
75 76

  for (auto &var : block.AllVars()) {
77
    all_param_.push_back(var->Name());
78
    if (var->Persistable()) {
79
      auto *ptr = root_scope_->Var(var->Name());
80
      InitializeVariable(ptr, var->GetType());
81 82
      if (stat_var_name_map_.find(var->Name()) != stat_var_name_map_.end() &&
          thread_id_ != 0) {
83 84 85
        int tensor_dim = root_scope_->FindVar(var->Name())
                             ->GetMutable<phi::DenseTensor>()
                             ->numel();
86 87
        auto *ptr1 = thread_scope_->Var(var->Name());
        InitializeVariable(ptr1, var->GetType());
88 89 90
        phi::DenseTensor *thread_tensor = ptr1->GetMutable<phi::DenseTensor>();
        phi::DenseTensor *root_tensor =
            root_scope_->FindVar(var->Name())->GetMutable<phi::DenseTensor>();
91 92 93 94 95
#define MemsetCallback(cpp_type, proto_type)                                  \
  do {                                                                        \
    if (framework::TransToProtoVarType(root_tensor->dtype()) == proto_type) { \
      SetZero<cpp_type>(thread_tensor, root_tensor, tensor_dim);              \
    }                                                                         \
96 97 98
  } while (0)
        _ForEachDataType_(MemsetCallback);
      }
99
    } else {
100
      auto *ptr = thread_scope_->Var(var->Name());
101 102 103 104 105
      InitializeVariable(ptr, var->GetType());
    }
  }
}

106
template <typename T>
107 108
void HogwildWorker::SetZero(phi::DenseTensor *tensor,
                            phi::DenseTensor *root_tensor,
109 110 111 112 113
                            int tensor_dim) {
  T *ptr = tensor->mutable_data<T>(root_tensor->dims(), platform::CPUPlace());
  memset(ptr, 0, sizeof(T) * tensor_dim);
}

114
void HogwildWorker::BindingDataFeedMemory() {
115
  const std::vector<std::string> &input_feed =
116
      device_reader_->GetUseSlotAlias();
117
  for (auto name : input_feed) {
118
    device_reader_->AddFeedVar(thread_scope_->FindVar(name), name);
119 120 121
  }
}

122
void HogwildWorker::CreateDeviceResource(const ProgramDesc &main_prog) {
123 124 125 126 127 128
  CreateThreadScope(main_prog);
  CreateThreadOperators(main_prog);
}

void HogwildWorker::TrainFilesWithProfiler() {
  platform::SetNumThreads(1);
D
danleifeng 已提交
129 130 131 132 133 134
#if defined(PADDLE_WITH_HETERPS) && \
    (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL))
  platform::SetDeviceId(thread_id_);
#elif defined(PADDLE_WITH_HETERPS) && defined(PADDLE_WITH_XPU_BKCL)
  platform::SetXPUDeviceId(thread_id_);
#endif
135
  device_reader_->Start();
136 137
  std::vector<double> op_total_time;
  std::vector<std::string> op_name;
138
  for (auto &op : ops_) {
139 140 141 142 143 144 145 146 147 148 149
    op_name.push_back(op->Type());
  }
  op_total_time.resize(ops_.size());
  for (size_t i = 0; i < op_total_time.size(); ++i) {
    op_total_time[i] = 0.0;
  }
  platform::Timer timeline;
  double total_time = 0.0;
  double read_time = 0.0;
  int cur_batch;
  int batch_cnt = 0;
L
lxsbupt 已提交
150 151 152 153 154
  if (thread_id_ == 0) {
    worker_num_stat_.store(0);
  }
  g_barrier.wait();
  bool train_mode = device_reader_->IsTrainMode();
155
  timeline.Start();
D
dongdaxiang 已提交
156
  uint64_t total_inst = 0;
L
lxsbupt 已提交
157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
  device_reader_->InitGraphTrainResource();
#endif
  while (1) {
    cur_batch = device_reader_->Next();
    if (FLAGS_enable_exit_when_partial_worker && train_mode) {
      if (cur_batch > 0) {
        worker_num_stat_.fetch_add(1, std::memory_order_relaxed);
      }
      g_barrier.wait();
      if (worker_num_stat_.load(std::memory_order_relaxed) % thread_num_ != 0) {
        break;
      }
    }
    if (cur_batch <= 0) {
      break;
    }
174
    VLOG(3) << "read a batch in thread " << thread_id_;
175 176 177 178
    timeline.Pause();
    read_time += timeline.ElapsedSec();
    total_time += timeline.ElapsedSec();
    for (size_t i = 0; i < ops_.size(); ++i) {
179 180 181 182 183 184 185
      bool need_skip = false;
      for (auto t = 0u; t < skip_ops_.size(); ++t) {
        if (ops_[i]->Type().find(skip_ops_[t]) != std::string::npos) {
          need_skip = true;
          break;
        }
      }
186
      timeline.Start();
187
      VLOG(3) << "Going to run op " << op_name[i];
188 189
      if (!need_skip) {
        ops_[i]->Run(*thread_scope_, place_);
190 191 192
#ifdef PADDLE_WITH_HETERPS
        dev_ctx_->Wait();
#endif
193
      }
194
      VLOG(3) << "Op " << op_name[i] << " Finished";
195 196 197 198
      timeline.Pause();
      op_total_time[i] += timeline.ElapsedSec();
      total_time += timeline.ElapsedSec();
    }
199 200

    if (need_dump_field_) {
H
hutuxian 已提交
201 202 203 204
      DumpField(*thread_scope_, dump_mode_, dump_interval_);
    }
    if (need_dump_param_ && thread_id_ == 0) {
      DumpParam(*thread_scope_, batch_cnt);
205 206
    }

D
dongdaxiang 已提交
207
    total_inst += cur_batch;
208
    ++batch_cnt;
D
dongdaxiang 已提交
209
    PrintFetchVars();
210 211 212 213 214 215 216 217
#ifdef PADDLE_WITH_HETERPS
    dev_ctx_->Wait();
    for (size_t i = 0; i < op_name.size(); ++i) {
      VLOG(1) << "card:" << thread_id_ << ", op: " << op_name[i]
              << ", mean time: " << op_total_time[i] / total_inst
              << "s, totol time:" << op_total_time[i] << "sec";
    }
#else
218 219 220
    if (thread_id_ == 0) {
      if (batch_cnt > 0 && batch_cnt % 100 == 0) {
        for (size_t i = 0; i < ops_.size(); ++i) {
221 222 223 224 225
          fprintf(stderr,
                  "op_name:[%zu][%s], op_mean_time:[%fs]\n",
                  i,
                  op_name[i].c_str(),
                  op_total_time[i] / batch_cnt);
226 227
        }
        fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
D
dongdaxiang 已提交
228
        fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
D
dongdaxiang 已提交
229
        fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time);
230 231
      }
    }
232
#endif
D
dongdaxiang 已提交
233
    thread_scope_->DropKids();
234 235
    timeline.Start();
  }
D
danleifeng 已提交
236 237 238
  VLOG(0) << "GpuPs worker " << thread_id_ << " train cost " << total_time
          << " seconds, ins_num: " << total_inst << " read time: " << read_time
          << "seconds ";
239

H
hutuxian 已提交
240
  if (need_dump_field_ || need_dump_param_) {
241 242 243
    writer_.Flush();
  }

T
tangwei12 已提交
244
#if defined PADDLE_WITH_PSCORE
245
  if (thread_barrier_) {
T
tangwei12 已提交
246
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
247 248
  }
#endif
249 250 251 252
}

void HogwildWorker::TrainFiles() {
  platform::SetNumThreads(1);
253 254
  platform::Timer timeline;
  timeline.Start();
D
danleifeng 已提交
255 256 257 258 259 260
#if defined(PADDLE_WITH_HETERPS) && \
    (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL))
  platform::SetDeviceId(thread_id_);
#elif defined(PADDLE_WITH_HETERPS) && defined(PADDLE_WITH_XPU_BKCL)
  platform::SetXPUDeviceId(thread_id_);
#endif
261

D
danleifeng 已提交
262
  int total_batch_num = 0;
263
  // how to accumulate fetched values here
264
  device_reader_->Start();
265
  int cur_batch;
W
wangguanqun 已提交
266
  int batch_cnt = 0;
L
lxsbupt 已提交
267 268 269 270
  if (thread_id_ == 0) {
    worker_num_stat_.store(0);
  }
  g_barrier.wait();
D
danleifeng 已提交
271

L
lxsbupt 已提交
272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293
#if defined(PADDLE_WITH_HETERPS) && defined(PADDLE_WITH_CUDA)
  platform::SetDeviceId(thread_id_);
#endif
  // while ((cur_batch = device_reader_->Next()) > 0) {
  bool train_mode = device_reader_->IsTrainMode();
#if defined(PADDLE_WITH_GPU_GRAPH) && defined(PADDLE_WITH_HETERPS)
  device_reader_->InitGraphTrainResource();
#endif
  while (1) {
    cur_batch = device_reader_->Next();
    if (FLAGS_enable_exit_when_partial_worker && train_mode) {
      if (cur_batch > 0) {
        worker_num_stat_.fetch_add(1, std::memory_order_relaxed);
      }
      g_barrier.wait();
      if (worker_num_stat_.load(std::memory_order_relaxed) % thread_num_ != 0) {
        break;
      }
    }
    if (cur_batch <= 0) {
      break;
    }
294
    for (auto &op : ops_) {
295 296 297 298 299 300 301 302 303 304
      bool need_skip = false;
      for (auto t = 0u; t < skip_ops_.size(); ++t) {
        if (op->Type().find(skip_ops_[t]) != std::string::npos) {
          need_skip = true;
          break;
        }
      }
      if (!need_skip) {
        op->Run(*thread_scope_, place_);
      }
305 306
    }

W
wangguanqun 已提交
307 308 309 310 311 312 313
    if (need_dump_field_) {
      DumpField(*thread_scope_, dump_mode_, dump_interval_);
    }
    if (need_dump_param_ && thread_id_ == 0) {
      DumpParam(*thread_scope_, batch_cnt);
    }

D
danleifeng 已提交
314
    total_batch_num += cur_batch;
W
wangguanqun 已提交
315
    ++batch_cnt;
D
dongdaxiang 已提交
316
    PrintFetchVars();
D
dongdaxiang 已提交
317
    thread_scope_->DropKids();
D
danleifeng 已提交
318 319 320
#ifdef PADDLE_WITH_HETERPS
    dev_ctx_->Wait();
#endif
321
  }
322
  timeline.Pause();
D
danleifeng 已提交
323 324
  VLOG(0) << "worker " << thread_id_ << " train cost " << timeline.ElapsedSec()
          << " seconds, batch_num: " << total_batch_num;
W
wangguanqun 已提交
325 326 327 328 329

  if (need_dump_field_ || need_dump_param_) {
    writer_.Flush();
  }

T
tangwei12 已提交
330
#if defined PADDLE_WITH_PSCORE
331
  if (thread_barrier_) {
T
tangwei12 已提交
332
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
333 334
  }
#endif
335 336
}

D
dongdaxiang 已提交
337 338 339 340
void HogwildWorker::PrintFetchVars() {
  // call count
  batch_num_++;
  int batch_per_print = fetch_config_.print_period();
T
tangwei12 已提交
341 342 343 344 345 346 347 348 349 350
  int fetch_var_num = fetch_config_.fetch_var_names_size();

  if (fetch_var_num == 0) {
    return;
  }

  if (thread_id_ == 0 && batch_num_ % batch_per_print == 0) {
    time_t curtime;
    time(&curtime);
    char mbstr[80];
351 352
    std::strftime(
        mbstr, sizeof(mbstr), "%Y-%m-%d %H:%M:%S", std::localtime(&curtime));
T
tangwei12 已提交
353 354 355 356 357 358

    std::stringstream ss;
    ss << "time: [" << mbstr << "], ";
    ss << "batch: [" << batch_num_ << "], ";

    for (int i = 0; i < fetch_var_num; ++i) {
359 360 361 362
      platform::PrintVar(thread_scope_,
                         fetch_config_.fetch_var_names(i),
                         fetch_config_.fetch_var_str_format(i),
                         &ss);
T
tangwei12 已提交
363 364
      if (i < fetch_var_num - 1) {
        ss << ", ";
D
dongdaxiang 已提交
365 366
      }
    }
T
tangwei12 已提交
367 368

    std::cout << ss.str() << std::endl;
D
dongdaxiang 已提交
369 370 371
  }
}

372 373
}  // end namespace framework
}  // end namespace paddle