/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <ctime>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/lodtensor_printer.h"

#if defined PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/service/communicator.h"
#endif

namespace paddle {
namespace framework {

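// Read per-worker settings from the TrainerDesc: the fetch/print config, the
// list of op types to skip, the CVM flag, the thread-barrier flag, and the
// stat var names that get thread-local zeroed copies in CreateThreadScope().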
void HogwildWorker::Initialize(const TrainerDesc &desc) {
  fetch_config_ = desc.fetch_config();
  param_ = desc.hogwild_param();
  skip_ops_.resize(param_.skip_ops_size());
  for (int i = 0; i < param_.skip_ops_size(); ++i) {
    skip_ops_[i] = param_.skip_ops(i);
  }
  use_cvm_ = desc.use_cvm();
  thread_barrier_ = desc.thread_barrier();

  for (int i = 0; i < param_.stat_var_names_size(); ++i) {
    stat_var_name_map_[param_.stat_var_names(i)] = 1;
  }
}

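// Instantiate a thread-local OperatorBase for every op in block 0 of the
// program, then prepare those ops for safe eager deletion inside
// conditional_block ops.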
void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) {
  auto &block = program.Block(0);
  op_names_.clear();
  for (auto &op_desc : block.AllOps()) {
    std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
    op_names_.push_back(op_desc->Type());
    OperatorBase *local_op_ptr = local_op.release();
    ops_.push_back(local_op_ptr);
  }
  operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
      program, 0, ops_);
}

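// Create this worker's scope as a child of root_scope_. Persistable variables
// are created in the root scope; other variables are created per thread. Stat
// variables additionally get a zero-filled thread-local copy on every thread
// except thread 0.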
void HogwildWorker::CreateThreadScope(const ProgramDesc &program) {
  auto &block = program.Block(0);

  PADDLE_ENFORCE_NOT_NULL(
      root_scope_,
      platform::errors::NotFound(
          "Root scope should be set before creating thread scope."));

  thread_scope_ = &root_scope_->NewScope();

  for (auto &var : block.AllVars()) {
    all_param_.push_back(var->Name());
    if (var->Persistable()) {
      auto *ptr = root_scope_->Var(var->Name());
      InitializeVariable(ptr, var->GetType());
      if (stat_var_name_map_.find(var->Name()) != stat_var_name_map_.end() &&
          thread_id_ != 0) {
        int tensor_dim =
            root_scope_->FindVar(var->Name())->GetMutable<LoDTensor>()->numel();
        auto *ptr1 = thread_scope_->Var(var->Name());
        InitializeVariable(ptr1, var->GetType());
        LoDTensor *thread_tensor = ptr1->GetMutable<LoDTensor>();
        LoDTensor *root_tensor =
            root_scope_->FindVar(var->Name())->GetMutable<LoDTensor>();
#define MemsetCallback(cpp_type, proto_type)                     \
  do {                                                           \
    if (root_tensor->type() == proto_type) {                     \
      SetZero<cpp_type>(thread_tensor, root_tensor, tensor_dim); \
    }                                                            \
  } while (0)
        _ForEachDataType_(MemsetCallback);
      }
    } else {
      auto *ptr = thread_scope_->Var(var->Name());
      InitializeVariable(ptr, var->GetType());
    }
  }
}

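// Allocate `tensor` on CPU with the same shape as `root_tensor` and zero it.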
template <typename T>
void HogwildWorker::SetZero(LoDTensor *tensor, LoDTensor *root_tensor,
                            int tensor_dim) {
  T *ptr = tensor->mutable_data<T>(root_tensor->dims(), platform::CPUPlace());
  memset(ptr, 0, sizeof(T) * tensor_dim);
}

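// Bind the data feed's input slots to the matching variables in this
// thread's scope.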
void HogwildWorker::BindingDataFeedMemory() {
  const std::vector<std::string> &input_feed =
      device_reader_->GetUseSlotAlias();
  for (const auto &name : input_feed) {
    device_reader_->AddFeedVar(thread_scope_->FindVar(name), name);
  }
}

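// Per-thread setup: build the thread scope and the thread-local operators.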
void HogwildWorker::CreateDeviceResource(const ProgramDesc &main_prog) {
  CreateThreadScope(main_prog);
  CreateThreadOperators(main_prog);
}

// Training loop with profiling: same work as TrainFiles(), but it also times
// the reader and every operator, and periodically reports per-op mean time,
// read time, IO percentage, and throughput.
void HogwildWorker::TrainFilesWithProfiler() {
  platform::SetNumThreads(1);
  device_reader_->Start();
  std::vector<double> op_total_time;
  std::vector<std::string> op_name;
  for (auto &op : ops_) {
    op_name.push_back(op->Type());
  }
  op_total_time.resize(ops_.size());
  for (size_t i = 0; i < op_total_time.size(); ++i) {
    op_total_time[i] = 0.0;
  }
  platform::Timer timeline;
  double total_time = 0.0;
  double read_time = 0.0;
  int cur_batch;
  int batch_cnt = 0;
  timeline.Start();
  uint64_t total_inst = 0;
  while ((cur_batch = device_reader_->Next()) > 0) {
    VLOG(3) << "read a batch in thread " << thread_id_;
    timeline.Pause();
    read_time += timeline.ElapsedSec();
    total_time += timeline.ElapsedSec();
    for (size_t i = 0; i < ops_.size(); ++i) {
      bool need_skip = false;
      for (auto t = 0u; t < skip_ops_.size(); ++t) {
        if (ops_[i]->Type().find(skip_ops_[t]) != std::string::npos) {
          need_skip = true;
          break;
        }
      }
      timeline.Start();
      VLOG(3) << "Going to run op " << op_name[i];
      if (!need_skip) {
        ops_[i]->Run(*thread_scope_, place_);
#ifdef PADDLE_WITH_HETERPS
        dev_ctx_->Wait();
#endif
      }
      VLOG(3) << "Op " << op_name[i] << " Finished";
      timeline.Pause();
      op_total_time[i] += timeline.ElapsedSec();
      total_time += timeline.ElapsedSec();
    }

    if (need_dump_field_) {
      DumpField(*thread_scope_, dump_mode_, dump_interval_);
    }
    if (need_dump_param_ && thread_id_ == 0) {
      DumpParam(*thread_scope_, batch_cnt);
    }

    total_inst += cur_batch;
    ++batch_cnt;
    PrintFetchVars();
#ifdef PADDLE_WITH_HETERPS
    dev_ctx_->Wait();
    VLOG(1) << "GpuPs worker " << thread_id_ << " train cost " << total_time
            << " seconds, ins_num: " << total_inst;
    for (size_t i = 0; i < op_name.size(); ++i) {
      VLOG(1) << "card:" << thread_id_ << ", op: " << op_name[i]
              << ", mean time: " << op_total_time[i] / total_inst
              << "s, total time:" << op_total_time[i] << "sec";
    }
#else
    if (thread_id_ == 0) {
      if (batch_cnt > 0 && batch_cnt % 100 == 0) {
        for (size_t i = 0; i < ops_.size(); ++i) {
          fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i,
                  op_name[i].c_str(), op_total_time[i] / batch_cnt);
        }
        fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
        fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
        fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time);
      }
    }
#endif
    thread_scope_->DropKids();
    timeline.Start();
  }

  if (need_dump_field_ || need_dump_param_) {
    writer_.Flush();
  }

#if defined PADDLE_WITH_PSCORE
  if (thread_barrier_) {
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
  }
#endif
}

// Plain Hogwild training loop: fetch batches from the reader, run every
// non-skipped operator on this thread's scope, optionally dump fields and
// parameters, and print fetch variables.
void HogwildWorker::TrainFiles() {
  platform::SetNumThreads(1);
  platform::Timer timeline;
  timeline.Start();

  int total_ins_num = 0;
  // how to accumulate fetched values here
  device_reader_->Start();
  int cur_batch;
  int batch_cnt = 0;
  while ((cur_batch = device_reader_->Next()) > 0) {
    for (auto &op : ops_) {
      bool need_skip = false;
      for (auto t = 0u; t < skip_ops_.size(); ++t) {
        if (op->Type().find(skip_ops_[t]) != std::string::npos) {
          need_skip = true;
          break;
        }
      }
      if (!need_skip) {
        op->Run(*thread_scope_, place_);
      }
    }

    if (need_dump_field_) {
      DumpField(*thread_scope_, dump_mode_, dump_interval_);
    }
    if (need_dump_param_ && thread_id_ == 0) {
      DumpParam(*thread_scope_, batch_cnt);
    }

    total_ins_num += cur_batch;
    ++batch_cnt;
    PrintFetchVars();
    thread_scope_->DropKids();
  }
  timeline.Pause();
  VLOG(3) << "worker " << thread_id_ << " train cost " << timeline.ElapsedSec()
          << " seconds, ins_num: " << total_ins_num;

  if (need_dump_field_ || need_dump_param_) {
    writer_.Flush();
  }

#if defined PADDLE_WITH_PSCORE
  if (thread_barrier_) {
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
  }
#endif
}

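// Called once per batch; on thread 0, print the configured fetch variables
// every print_period batches, prefixed with a timestamp and the batch count.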
void HogwildWorker::PrintFetchVars() {
  // call count
  batch_num_++;
  int batch_per_print = fetch_config_.print_period();
  int fetch_var_num = fetch_config_.fetch_var_names_size();

  if (fetch_var_num == 0) {
    return;
  }

  if (thread_id_ == 0 && batch_num_ % batch_per_print == 0) {
    time_t curtime;
    time(&curtime);
    char mbstr[80];
    std::strftime(mbstr, sizeof(mbstr), "%Y-%m-%d %H:%M:%S",
                  std::localtime(&curtime));

    std::stringstream ss;
    ss << "time: [" << mbstr << "], ";
    ss << "batch: [" << batch_num_ << "], ";

    for (int i = 0; i < fetch_var_num; ++i) {
      platform::PrintVar(thread_scope_, fetch_config_.fetch_var_names(i),
                         fetch_config_.fetch_var_str_format(i), &ss);
      if (i < fetch_var_num - 1) {
        ss << ", ";
      }
    }

    std::cout << ss.str() << std::endl;
  }
}

}  // end namespace framework
}  // end namespace paddle