/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <ctime>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/lodtensor_printer.h"

#if defined PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/ps/service/communicator/communicator.h"
#endif

namespace paddle {
namespace framework {

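// Copy the worker configuration out of the TrainerDesc: fetch settings,
// the hogwild parameters (ops to skip, stat variable names), the CVM flag
// and the thread-barrier switch.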
void HogwildWorker::Initialize(const TrainerDesc &desc) {
  fetch_config_ = desc.fetch_config();
  param_ = desc.hogwild_param();
  skip_ops_.resize(param_.skip_ops_size());
  for (int i = 0; i < param_.skip_ops_size(); ++i) {
    skip_ops_[i] = param_.skip_ops(i);
  }
  use_cvm_ = desc.use_cvm();
  thread_barrier_ = desc.thread_barrier();

  for (int i = 0; i < param_.stat_var_names_size(); ++i) {
    stat_var_name_map_[param_.stat_var_names(i)] = 1;
  }
}

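// Build a thread-private copy of every operator in block 0 so this worker
// can run the program without sharing operator state, then prepare the
// conditional_block ops for safe eager deletion.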
void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) {
  auto &block = program.Block(0);
  op_names_.clear();
  for (auto &op_desc : block.AllOps()) {
    std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
    op_names_.push_back(op_desc->Type());
    OperatorBase *local_op_ptr = local_op.release();
    ops_.push_back(local_op_ptr);
  }
  operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
      program, 0, ops_);
}

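// Create this worker's scope as a child of the root scope. Persistable
// variables live in the root scope; stat variables additionally get a
// zero-initialized thread-local copy on non-zero threads; everything else
// is created directly in the thread scope.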
void HogwildWorker::CreateThreadScope(const ProgramDesc &program) {
  auto &block = program.Block(0);

  PADDLE_ENFORCE_NOT_NULL(
      root_scope_,
      platform::errors::NotFound(
          "Root scope should be set before creating thread scope."));

  thread_scope_ = &root_scope_->NewScope();

  for (auto &var : block.AllVars()) {
    all_param_.push_back(var->Name());
    if (var->Persistable()) {
      auto *ptr = root_scope_->Var(var->Name());
      InitializeVariable(ptr, var->GetType());
      if (stat_var_name_map_.find(var->Name()) != stat_var_name_map_.end() &&
          thread_id_ != 0) {
        int tensor_dim =
            root_scope_->FindVar(var->Name())->GetMutable<LoDTensor>()->numel();
        auto *ptr1 = thread_scope_->Var(var->Name());
        InitializeVariable(ptr1, var->GetType());
        LoDTensor *thread_tensor = ptr1->GetMutable<LoDTensor>();
        LoDTensor *root_tensor =
            root_scope_->FindVar(var->Name())->GetMutable<LoDTensor>();
#define MemsetCallback(cpp_type, proto_type)                                  \
  do {                                                                        \
    if (framework::TransToProtoVarType(root_tensor->dtype()) == proto_type) { \
      SetZero<cpp_type>(thread_tensor, root_tensor, tensor_dim);              \
    }                                                                         \
  } while (0)
        _ForEachDataType_(MemsetCallback);
      }
    } else {
      auto *ptr = thread_scope_->Var(var->Name());
      InitializeVariable(ptr, var->GetType());
    }
  }
}

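// Allocate the thread-local tensor with the root tensor's shape on CPU and
// fill it with zeros.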
template <typename T>
void HogwildWorker::SetZero(LoDTensor *tensor,
                            LoDTensor *root_tensor,
                            int tensor_dim) {
  T *ptr = tensor->mutable_data<T>(root_tensor->dims(), platform::CPUPlace());
  memset(ptr, 0, sizeof(T) * tensor_dim);
}

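// Bind each feed slot of the data reader to the matching variable in this
// worker's thread scope.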
void HogwildWorker::BindingDataFeedMemory() {
  const std::vector<std::string> &input_feed =
      device_reader_->GetUseSlotAlias();
  for (const auto &name : input_feed) {
    device_reader_->AddFeedVar(thread_scope_->FindVar(name), name);
  }
}

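// Set up everything the worker needs before training: the thread scope and
// the thread-private operator copies.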
void HogwildWorker::CreateDeviceResource(const ProgramDesc &main_prog) {
  CreateThreadScope(main_prog);
  CreateThreadOperators(main_prog);
}

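// Training loop with per-op profiling: for every batch, run each non-skipped
// op under a timer, optionally dump fields/params, and periodically report
// per-op mean time, read time, IO percentage and throughput.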
void HogwildWorker::TrainFilesWithProfiler() {
  platform::SetNumThreads(1);
#if defined(PADDLE_WITH_HETERPS) && \
    (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL))
  platform::SetDeviceId(thread_id_);
#elif defined(PADDLE_WITH_HETERPS) && defined(PADDLE_WITH_XPU_BKCL)
  platform::SetXPUDeviceId(thread_id_);
#endif
  device_reader_->Start();
  std::vector<double> op_total_time;
  std::vector<std::string> op_name;
  for (auto &op : ops_) {
    op_name.push_back(op->Type());
  }
  op_total_time.resize(ops_.size());
  for (size_t i = 0; i < op_total_time.size(); ++i) {
    op_total_time[i] = 0.0;
  }
  platform::Timer timeline;
  double total_time = 0.0;
  double read_time = 0.0;
  int cur_batch;
  int batch_cnt = 0;
  timeline.Start();
  uint64_t total_inst = 0;
  while ((cur_batch = device_reader_->Next()) > 0) {
    VLOG(3) << "read a batch in thread " << thread_id_;
    timeline.Pause();
    read_time += timeline.ElapsedSec();
    total_time += timeline.ElapsedSec();
    for (size_t i = 0; i < ops_.size(); ++i) {
      bool need_skip = false;
      for (auto t = 0u; t < skip_ops_.size(); ++t) {
        if (ops_[i]->Type().find(skip_ops_[t]) != std::string::npos) {
          need_skip = true;
          break;
        }
      }
      timeline.Start();
      VLOG(3) << "Going to run op " << op_name[i];
      if (!need_skip) {
        ops_[i]->Run(*thread_scope_, place_);
#ifdef PADDLE_WITH_HETERPS
        dev_ctx_->Wait();
#endif
      }
      VLOG(3) << "Op " << op_name[i] << " Finished";
      timeline.Pause();
      op_total_time[i] += timeline.ElapsedSec();
      total_time += timeline.ElapsedSec();
    }

    if (need_dump_field_) {
      DumpField(*thread_scope_, dump_mode_, dump_interval_);
    }
    if (need_dump_param_ && thread_id_ == 0) {
      DumpParam(*thread_scope_, batch_cnt);
    }

    total_inst += cur_batch;
    ++batch_cnt;
    PrintFetchVars();
#ifdef PADDLE_WITH_HETERPS
    dev_ctx_->Wait();
    for (size_t i = 0; i < op_name.size(); ++i) {
      VLOG(1) << "card:" << thread_id_ << ", op: " << op_name[i]
              << ", mean time: " << op_total_time[i] / total_inst
              << "s, total time:" << op_total_time[i] << "sec";
    }
#else
    if (thread_id_ == 0) {
      if (batch_cnt > 0 && batch_cnt % 100 == 0) {
        for (size_t i = 0; i < ops_.size(); ++i) {
          fprintf(stderr,
                  "op_name:[%zu][%s], op_mean_time:[%fs]\n",
                  i,
                  op_name[i].c_str(),
                  op_total_time[i] / batch_cnt);
        }
        fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
        fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
        fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time);
      }
    }
#endif
    thread_scope_->DropKids();
    timeline.Start();
  }
  VLOG(0) << "GpuPs worker " << thread_id_ << " train cost " << total_time
          << " seconds, ins_num: " << total_inst << " read time: " << read_time
          << " seconds";

  if (need_dump_field_ || need_dump_param_) {
    writer_.Flush();
  }

#if defined PADDLE_WITH_PSCORE
  if (thread_barrier_) {
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
  }
#endif
}

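// Plain training loop: run each non-skipped op on every batch, optionally
// dump fields/params, print fetch variables and drop child scopes between
// batches.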
void HogwildWorker::TrainFiles() {
  platform::SetNumThreads(1);
  platform::Timer timeline;
  timeline.Start();
#if defined(PADDLE_WITH_HETERPS) && \
    (defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL))
  platform::SetDeviceId(thread_id_);
#elif defined(PADDLE_WITH_HETERPS) && defined(PADDLE_WITH_XPU_BKCL)
  platform::SetXPUDeviceId(thread_id_);
#endif

  int total_batch_num = 0;
  // how to accumulate fetched values here
  device_reader_->Start();
  int cur_batch;
  int batch_cnt = 0;

  while ((cur_batch = device_reader_->Next()) > 0) {
    for (auto &op : ops_) {
      bool need_skip = false;
      for (auto t = 0u; t < skip_ops_.size(); ++t) {
        if (op->Type().find(skip_ops_[t]) != std::string::npos) {
          need_skip = true;
          break;
        }
      }
      if (!need_skip) {
        op->Run(*thread_scope_, place_);
      }
    }

    if (need_dump_field_) {
      DumpField(*thread_scope_, dump_mode_, dump_interval_);
    }
    if (need_dump_param_ && thread_id_ == 0) {
      DumpParam(*thread_scope_, batch_cnt);
    }

    total_batch_num += cur_batch;
    ++batch_cnt;
    PrintFetchVars();
    thread_scope_->DropKids();
#ifdef PADDLE_WITH_HETERPS
    dev_ctx_->Wait();
#endif
  }
  timeline.Pause();
  VLOG(0) << "worker " << thread_id_ << " train cost " << timeline.ElapsedSec()
          << " seconds, batch_num: " << total_batch_num;

  if (need_dump_field_ || need_dump_param_) {
    writer_.Flush();
  }

#if defined PADDLE_WITH_PSCORE
  if (thread_barrier_) {
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
  }
#endif
}

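// On thread 0, print the configured fetch variables together with a
// timestamp and the batch count every print_period batches.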
void HogwildWorker::PrintFetchVars() {
  // call count
  batch_num_++;
  int batch_per_print = fetch_config_.print_period();
  int fetch_var_num = fetch_config_.fetch_var_names_size();

  if (fetch_var_num == 0) {
    return;
  }

  if (thread_id_ == 0 && batch_num_ % batch_per_print == 0) {
    time_t curtime;
    time(&curtime);
    char mbstr[80];
    std::strftime(
        mbstr, sizeof(mbstr), "%Y-%m-%d %H:%M:%S", std::localtime(&curtime));

    std::stringstream ss;
    ss << "time: [" << mbstr << "], ";
    ss << "batch: [" << batch_num_ << "], ";

    for (int i = 0; i < fetch_var_num; ++i) {
      platform::PrintVar(thread_scope_,
                         fetch_config_.fetch_var_names(i),
                         fetch_config_.fetch_var_str_format(i),
                         &ss);
      if (i < fetch_var_num - 1) {
        ss << ", ";
      }
    }

    std::cout << ss.str() << std::endl;
  }
}

}  // end namespace framework
}  // end namespace paddle