/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cstring>   // memset
#include <ctime>     // time_t, std::strftime, std::localtime
#include <iostream>  // std::cout
#include <sstream>   // std::stringstream
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/lodtensor_printer.h"

#if defined PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/service/communicator.h"
#endif

namespace paddle {
namespace framework {

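// Read worker settings from the TrainerDesc: fetch configuration, the list of
// ops to skip, the CVM flag, the thread-barrier flag and the stat var names.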
void HogwildWorker::Initialize(const TrainerDesc &desc) {
  fetch_config_ = desc.fetch_config();
  param_ = desc.hogwild_param();
  skip_ops_.resize(param_.skip_ops_size());
  for (int i = 0; i < param_.skip_ops_size(); ++i) {
    skip_ops_[i] = param_.skip_ops(i);
  }
  use_cvm_ = desc.use_cvm();
  thread_barrier_ = desc.thread_barrier();

  for (int i = 0; i < param_.stat_var_names_size(); ++i) {
    stat_var_name_map_[param_.stat_var_names(i)] = 1;
  }
}

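// Create this thread's private copies of every operator in block 0 so the
// worker can run them lock-free.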
void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) {
  auto &block = program.Block(0);
  op_names_.clear();
  for (auto &op_desc : block.AllOps()) {
    std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
    op_names_.push_back(op_desc->Type());
    OperatorBase *local_op_ptr = local_op.release();
    ops_.push_back(local_op_ptr);
  }
  operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
      program, 0, ops_);
}

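// Build the per-thread scope. Persistable variables are created in the root
// scope (stat vars additionally get a zero-filled thread-local copy on
// threads other than 0); all other variables live in the thread scope.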
void HogwildWorker::CreateThreadScope(const ProgramDesc &program) {
  auto &block = program.Block(0);

  PADDLE_ENFORCE_NOT_NULL(
      root_scope_,
      platform::errors::NotFound(
          "Root scope should be set before creating thread scope."));

  thread_scope_ = &root_scope_->NewScope();

  for (auto &var : block.AllVars()) {
    all_param_.push_back(var->Name());
    if (var->Persistable()) {
      auto *ptr = root_scope_->Var(var->Name());
      InitializeVariable(ptr, var->GetType());
      if (stat_var_name_map_.find(var->Name()) != stat_var_name_map_.end() &&
          thread_id_ != 0) {
        int tensor_dim =
            root_scope_->FindVar(var->Name())->GetMutable<LoDTensor>()->numel();
        auto *ptr1 = thread_scope_->Var(var->Name());
        InitializeVariable(ptr1, var->GetType());
        LoDTensor *thread_tensor = ptr1->GetMutable<LoDTensor>();
        LoDTensor *root_tensor =
            root_scope_->FindVar(var->Name())->GetMutable<LoDTensor>();
#define MemsetCallback(cpp_type, proto_type)                     \
  do {                                                           \
    if (root_tensor->type() == proto_type) {                     \
      SetZero<cpp_type>(thread_tensor, root_tensor, tensor_dim); \
    }                                                            \
  } while (0)
        _ForEachDataType_(MemsetCallback);
      }
    } else {
      auto *ptr = thread_scope_->Var(var->Name());
      InitializeVariable(ptr, var->GetType());
    }
  }
}

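// Zero-fill a thread-local tensor allocated with the same shape as the
// corresponding root-scope tensor.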
template <typename T>
void HogwildWorker::SetZero(LoDTensor *tensor, LoDTensor *root_tensor,
                            int tensor_dim) {
  T *ptr = tensor->mutable_data<T>(root_tensor->dims(), platform::CPUPlace());
  memset(ptr, 0, sizeof(T) * tensor_dim);
}

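// Point every input slot of the data feed at the matching variable in the
// thread scope, so the reader writes batches directly into this worker's vars.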
void HogwildWorker::BindingDataFeedMemory() {
  const std::vector<std::string> &input_feed =
      device_reader_->GetUseSlotAlias();
  for (const auto &name : input_feed) {
    device_reader_->AddFeedVar(thread_scope_->FindVar(name), name);
  }
}

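// Per-thread setup: create the thread scope and the thread-local operators.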
void HogwildWorker::CreateDeviceResource(const ProgramDesc &main_prog) {
  CreateThreadScope(main_prog);
  CreateThreadOperators(main_prog);
}

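// Training loop with per-op profiling: read batches, run every non-skipped op,
// and periodically report per-op mean time, read time and throughput.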
void HogwildWorker::TrainFilesWithProfiler() {
  platform::SetNumThreads(1);
  device_reader_->Start();
  std::vector<double> op_total_time;
  std::vector<std::string> op_name;
  for (auto &op : ops_) {
    op_name.push_back(op->Type());
  }
  op_total_time.resize(ops_.size());
  for (size_t i = 0; i < op_total_time.size(); ++i) {
    op_total_time[i] = 0.0;
  }
  platform::Timer timeline;
  double total_time = 0.0;
  double read_time = 0.0;
  int cur_batch;
  int batch_cnt = 0;
  timeline.Start();
  uint64_t total_inst = 0;
  while ((cur_batch = device_reader_->Next()) > 0) {
    VLOG(3) << "read a batch in thread " << thread_id_;
    timeline.Pause();
    read_time += timeline.ElapsedSec();
    total_time += timeline.ElapsedSec();
    for (size_t i = 0; i < ops_.size(); ++i) {
      bool need_skip = false;
      for (auto t = 0u; t < skip_ops_.size(); ++t) {
        if (ops_[i]->Type().find(skip_ops_[t]) != std::string::npos) {
          need_skip = true;
          break;
        }
      }
      timeline.Start();
      VLOG(3) << "Going to run op " << op_name[i];
      if (!need_skip) {
        ops_[i]->Run(*thread_scope_, place_);
#ifdef PADDLE_WITH_HETERPS
        dev_ctx_->Wait();
#endif
      }
      VLOG(3) << "Op " << op_name[i] << " Finished";
      timeline.Pause();
      op_total_time[i] += timeline.ElapsedSec();
      total_time += timeline.ElapsedSec();
    }

    if (need_dump_field_) {
      DumpField(*thread_scope_, dump_mode_, dump_interval_);
    }
    if (need_dump_param_ && thread_id_ == 0) {
      DumpParam(*thread_scope_, batch_cnt);
    }

    total_inst += cur_batch;
    ++batch_cnt;
    PrintFetchVars();
#ifdef PADDLE_WITH_HETERPS
    dev_ctx_->Wait();
    VLOG(1) << "GpuPs worker " << thread_id_ << " train cost " << total_time
            << " seconds, ins_num: " << total_inst;
    for (size_t i = 0; i < op_name.size(); ++i) {
      VLOG(1) << "card:" << thread_id_ << ", op: " << op_name[i]
              << ", mean time: " << op_total_time[i] / total_inst
              << "s, total time:" << op_total_time[i] << "sec";
    }
#else
    if (thread_id_ == 0) {
      if (batch_cnt > 0 && batch_cnt % 100 == 0) {
        for (size_t i = 0; i < ops_.size(); ++i) {
          fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i,
                  op_name[i].c_str(), op_total_time[i] / batch_cnt);
        }
        fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
        fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
        fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time);
      }
    }
#endif
    thread_scope_->DropKids();
    timeline.Start();
  }

  if (need_dump_field_ || need_dump_param_) {
    writer_.Flush();
  }

#if defined PADDLE_WITH_PSCORE
  if (thread_barrier_) {
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
  }
#endif
}

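// Plain training loop: run every non-skipped op on each batch, no profiling.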
void HogwildWorker::TrainFiles() {
  platform::SetNumThreads(1);
  platform::Timer timeline;
  timeline.Start();

  int total_ins_num = 0;
  // how to accumulate fetched values here
  device_reader_->Start();
  int cur_batch;
  while ((cur_batch = device_reader_->Next()) > 0) {
    for (auto &op : ops_) {
      bool need_skip = false;
      for (auto t = 0u; t < skip_ops_.size(); ++t) {
        if (op->Type().find(skip_ops_[t]) != std::string::npos) {
          need_skip = true;
          break;
        }
      }
      if (!need_skip) {
        op->Run(*thread_scope_, place_);
      }
    }

    total_ins_num += cur_batch;
    PrintFetchVars();
    thread_scope_->DropKids();
  }
  timeline.Pause();
  VLOG(3) << "worker " << thread_id_ << " train cost " << timeline.ElapsedSec()
          << " seconds, ins_num: " << total_ins_num;
#if defined PADDLE_WITH_PSCORE
  if (thread_barrier_) {
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
  }
#endif
}

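// On thread 0, every print_period batches, print the configured fetch
// variables together with a timestamp and the current batch count.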
void HogwildWorker::PrintFetchVars() {
  // call count
  batch_num_++;
  int batch_per_print = fetch_config_.print_period();
  int fetch_var_num = fetch_config_.fetch_var_names_size();

  if (fetch_var_num == 0) {
    return;
  }

  if (thread_id_ == 0 && batch_num_ % batch_per_print == 0) {
    time_t curtime;
    time(&curtime);
    char mbstr[80];
    std::strftime(mbstr, sizeof(mbstr), "%Y-%m-%d %H:%M:%S",
                  std::localtime(&curtime));

    std::stringstream ss;
    ss << "time: [" << mbstr << "], ";
    ss << "batch: [" << batch_num_ << "], ";

    for (int i = 0; i < fetch_var_num; ++i) {
      platform::PrintVar(thread_scope_, fetch_config_.fetch_var_names(i),
                         fetch_config_.fetch_var_str_format(i), &ss);
      if (i < fetch_var_num - 1) {
        ss << ", ";
      }
    }

    std::cout << ss.str() << std::endl;
  }
}

}  // end namespace framework
}  // end namespace paddle