/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <ctime>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/lodtensor_printer.h"

#if defined PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/ps/service/communicator/communicator.h"
#endif

namespace paddle {
namespace framework {

void HogwildWorker::Initialize(const TrainerDesc &desc) {
D
dongdaxiang 已提交
32
  fetch_config_ = desc.fetch_config();
33 34
  param_ = desc.hogwild_param();
  skip_ops_.resize(param_.skip_ops_size());
35
  for (int i = 0; i < param_.skip_ops_size(); ++i) {
36 37
    skip_ops_[i] = param_.skip_ops(i);
  }
38
  use_cvm_ = desc.use_cvm();
39
  thread_barrier_ = desc.thread_barrier();
40

41 42 43
  for (int i = 0; i < param_.stat_var_names_size(); ++i) {
    stat_var_name_map_[param_.stat_var_names(i)] = 1;
  }
D
dongdaxiang 已提交
44 45
}

void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) {
  auto &block = program.Block(0);
48
  op_names_.clear();
49
  for (auto &op_desc : block.AllOps()) {
50 51
    std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
    op_names_.push_back(op_desc->Type());
52
    OperatorBase *local_op_ptr = local_op.release();
53 54 55
    ops_.push_back(local_op_ptr);
    continue;
  }
Z
zhang wenhui 已提交
56 57
  operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
      program, 0, ops_);
58 59
}

void HogwildWorker::CreateThreadScope(const ProgramDesc &program) {
  auto &block = program.Block(0);
62 63

  PADDLE_ENFORCE_NOT_NULL(
64 65 66
      root_scope_,
      platform::errors::NotFound(
          "Root scope should be set before creating thread scope."));
67 68

  thread_scope_ = &root_scope_->NewScope();
69 70

  for (auto &var : block.AllVars()) {
71
    all_param_.push_back(var->Name());
72
    if (var->Persistable()) {
73
      auto *ptr = root_scope_->Var(var->Name());
74
      InitializeVariable(ptr, var->GetType());
75 76 77 78 79 80 81 82 83
      if (stat_var_name_map_.find(var->Name()) != stat_var_name_map_.end() &&
          thread_id_ != 0) {
        int tensor_dim =
            root_scope_->FindVar(var->Name())->GetMutable<LoDTensor>()->numel();
        auto *ptr1 = thread_scope_->Var(var->Name());
        InitializeVariable(ptr1, var->GetType());
        LoDTensor *thread_tensor = ptr1->GetMutable<LoDTensor>();
        LoDTensor *root_tensor =
            root_scope_->FindVar(var->Name())->GetMutable<LoDTensor>();
84 85 86 87 88
#define MemsetCallback(cpp_type, proto_type)                                  \
  do {                                                                        \
    if (framework::TransToProtoVarType(root_tensor->dtype()) == proto_type) { \
      SetZero<cpp_type>(thread_tensor, root_tensor, tensor_dim);              \
    }                                                                         \
89 90 91
  } while (0)
        _ForEachDataType_(MemsetCallback);
      }
92
    } else {
93
      auto *ptr = thread_scope_->Var(var->Name());
94 95 96 97 98
      InitializeVariable(ptr, var->GetType());
    }
  }
}

// Allocates `tensor` on the CPU with the same dims as `root_tensor`, as
// element type T. It then sets the bytes of its first `tensor_dim` elements
// to zero.
template <typename T>
void HogwildWorker::SetZero(LoDTensor *tensor, LoDTensor *root_tensor,
                            int tensor_dim) {
  const auto &target_dims = root_tensor->dims();
  T *data = tensor->mutable_data<T>(target_dims, platform::CPUPlace());
  const size_t byte_count = sizeof(T) * tensor_dim;
  memset(data, 0, byte_count);
}

void HogwildWorker::BindingDataFeedMemory() {
107
  const std::vector<std::string> &input_feed =
108
      device_reader_->GetUseSlotAlias();
109
  for (auto name : input_feed) {
110
    device_reader_->AddFeedVar(thread_scope_->FindVar(name), name);
111 112 113
  }
}

void HogwildWorker::CreateDeviceResource(const ProgramDesc &main_prog) {
115 116 117 118 119 120
  CreateThreadScope(main_prog);
  CreateThreadOperators(main_prog);
}

void HogwildWorker::TrainFilesWithProfiler() {
  platform::SetNumThreads(1);
121
  device_reader_->Start();
122 123
  std::vector<double> op_total_time;
  std::vector<std::string> op_name;
124
  for (auto &op : ops_) {
125 126 127 128 129 130 131 132 133 134 135 136
    op_name.push_back(op->Type());
  }
  op_total_time.resize(ops_.size());
  for (size_t i = 0; i < op_total_time.size(); ++i) {
    op_total_time[i] = 0.0;
  }
  platform::Timer timeline;
  double total_time = 0.0;
  double read_time = 0.0;
  int cur_batch;
  int batch_cnt = 0;
  timeline.Start();
D
dongdaxiang 已提交
137
  uint64_t total_inst = 0;
138
  while ((cur_batch = device_reader_->Next()) > 0) {
139
    VLOG(3) << "read a batch in thread " << thread_id_;
140 141 142 143
    timeline.Pause();
    read_time += timeline.ElapsedSec();
    total_time += timeline.ElapsedSec();
    for (size_t i = 0; i < ops_.size(); ++i) {
144 145 146 147 148 149 150
      bool need_skip = false;
      for (auto t = 0u; t < skip_ops_.size(); ++t) {
        if (ops_[i]->Type().find(skip_ops_[t]) != std::string::npos) {
          need_skip = true;
          break;
        }
      }
151
      timeline.Start();
152
      VLOG(3) << "Going to run op " << op_name[i];
153 154
      if (!need_skip) {
        ops_[i]->Run(*thread_scope_, place_);
155 156 157
#ifdef PADDLE_WITH_HETERPS
        dev_ctx_->Wait();
#endif
158
      }
159
      VLOG(3) << "Op " << op_name[i] << " Finished";
160 161 162 163
      timeline.Pause();
      op_total_time[i] += timeline.ElapsedSec();
      total_time += timeline.ElapsedSec();
    }
164 165

    if (need_dump_field_) {
H
hutuxian 已提交
166 167 168 169
      DumpField(*thread_scope_, dump_mode_, dump_interval_);
    }
    if (need_dump_param_ && thread_id_ == 0) {
      DumpParam(*thread_scope_, batch_cnt);
170 171
    }

D
dongdaxiang 已提交
172
    total_inst += cur_batch;
173
    ++batch_cnt;
D
dongdaxiang 已提交
174
    PrintFetchVars();
175 176 177 178 179 180 181 182 183 184
#ifdef PADDLE_WITH_HETERPS
    dev_ctx_->Wait();
    VLOG(1) << "GpuPs worker " << thread_id_ << " train cost " << total_time
            << " seconds, ins_num: " << total_inst;
    for (size_t i = 0; i < op_name.size(); ++i) {
      VLOG(1) << "card:" << thread_id_ << ", op: " << op_name[i]
              << ", mean time: " << op_total_time[i] / total_inst
              << "s, totol time:" << op_total_time[i] << "sec";
    }
#else
185 186 187 188 189 190 191
    if (thread_id_ == 0) {
      if (batch_cnt > 0 && batch_cnt % 100 == 0) {
        for (size_t i = 0; i < ops_.size(); ++i) {
          fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i,
                  op_name[i].c_str(), op_total_time[i] / batch_cnt);
        }
        fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
D
dongdaxiang 已提交
192
        fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
D
dongdaxiang 已提交
193
        fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time);
194 195
      }
    }
196
#endif
D
dongdaxiang 已提交
197
    thread_scope_->DropKids();
198 199
    timeline.Start();
  }
200

H
hutuxian 已提交
201
  if (need_dump_field_ || need_dump_param_) {
202 203 204
    writer_.Flush();
  }

T
tangwei12 已提交
205
#if defined PADDLE_WITH_PSCORE
206
  if (thread_barrier_) {
T
tangwei12 已提交
207
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
208 209
  }
#endif
210 211 212 213
}

void HogwildWorker::TrainFiles() {
  platform::SetNumThreads(1);
214 215
  platform::Timer timeline;
  timeline.Start();
216

217
  int total_ins_num = 0;
218
  // how to accumulate fetched values here
219
  device_reader_->Start();
220
  int cur_batch;
W
wangguanqun 已提交
221
  int batch_cnt = 0;
222
  while ((cur_batch = device_reader_->Next()) > 0) {
223
    for (auto &op : ops_) {
224 225 226 227 228 229 230 231 232 233
      bool need_skip = false;
      for (auto t = 0u; t < skip_ops_.size(); ++t) {
        if (op->Type().find(skip_ops_[t]) != std::string::npos) {
          need_skip = true;
          break;
        }
      }
      if (!need_skip) {
        op->Run(*thread_scope_, place_);
      }
234 235
    }

W
wangguanqun 已提交
236 237 238 239 240 241 242
    if (need_dump_field_) {
      DumpField(*thread_scope_, dump_mode_, dump_interval_);
    }
    if (need_dump_param_ && thread_id_ == 0) {
      DumpParam(*thread_scope_, batch_cnt);
    }

243
    total_ins_num += cur_batch;
W
wangguanqun 已提交
244
    ++batch_cnt;
D
dongdaxiang 已提交
245
    PrintFetchVars();
D
dongdaxiang 已提交
246
    thread_scope_->DropKids();
247
  }
248 249 250
  timeline.Pause();
  VLOG(3) << "worker " << thread_id_ << " train cost " << timeline.ElapsedSec()
          << " seconds, ins_num: " << total_ins_num;
W
wangguanqun 已提交
251 252 253 254 255

  if (need_dump_field_ || need_dump_param_) {
    writer_.Flush();
  }

T
tangwei12 已提交
256
#if defined PADDLE_WITH_PSCORE
257
  if (thread_barrier_) {
T
tangwei12 已提交
258
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
259 260
  }
#endif
261 262
}

// Prints the configured fetch variables from thread 0's scope.
//
// Every call increments batch_num_. Nothing is printed if fetch_config_ lists
// no fetch variables or its print_period is 0. Otherwise, on every
// print_period-th call, thread 0 writes one line to stdout. The line holds
// the local time, the batch counter, and each fetch variable rendered by
// platform::PrintVar with its configured format string.
void HogwildWorker::PrintFetchVars() {
  // call count
  batch_num_++;
  const int batch_per_print = fetch_config_.print_period();
  const int fetch_var_num = fetch_config_.fetch_var_names_size();

  if (fetch_var_num == 0) {
    return;
  }
  // `batch_num_ % 0` is a division by zero (undefined behavior, typically
  // SIGFPE). Treat a zero print period as "printing disabled" instead.
  if (batch_per_print == 0) {
    return;
  }

  if (thread_id_ == 0 && batch_num_ % batch_per_print == 0) {
    time_t curtime;
    time(&curtime);
    char mbstr[80];
    std::strftime(mbstr, sizeof(mbstr), "%Y-%m-%d %H:%M:%S",
                  std::localtime(&curtime));

    std::stringstream ss;
    ss << "time: [" << mbstr << "], ";
    ss << "batch: [" << batch_num_ << "], ";

    for (int i = 0; i < fetch_var_num; ++i) {
      platform::PrintVar(thread_scope_, fetch_config_.fetch_var_names(i),
                         fetch_config_.fetch_var_str_format(i), &ss);
      if (i < fetch_var_num - 1) {
        ss << ", ";
      }
    }

    std::cout << ss.str() << std::endl;
  }
}

}  // end namespace framework
}  // end namespace paddle