/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/lodtensor_printer.h"

#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/distributed/service/communicator.h"
#endif

namespace paddle {
namespace framework {

29
void HogwildWorker::Initialize(const TrainerDesc &desc) {
D
dongdaxiang 已提交
30
  fetch_config_ = desc.fetch_config();
31 32
  param_ = desc.hogwild_param();
  skip_ops_.resize(param_.skip_ops_size());
33
  for (int i = 0; i < param_.skip_ops_size(); ++i) {
34 35
    skip_ops_[i] = param_.skip_ops(i);
  }
36
  use_cvm_ = desc.use_cvm();
37
  thread_barrier_ = desc.thread_barrier();
38

39 40 41
  for (int i = 0; i < param_.stat_var_names_size(); ++i) {
    stat_var_name_map_[param_.stat_var_names(i)] = 1;
  }
D
dongdaxiang 已提交
42 43
}

44 45
void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) {
  auto &block = program.Block(0);
46
  op_names_.clear();
47
  for (auto &op_desc : block.AllOps()) {
48 49
    std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
    op_names_.push_back(op_desc->Type());
50
    OperatorBase *local_op_ptr = local_op.release();
51 52 53
    ops_.push_back(local_op_ptr);
    continue;
  }
Z
zhang wenhui 已提交
54 55
  operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
      program, 0, ops_);
56 57
}

58 59
void HogwildWorker::CreateThreadScope(const ProgramDesc &program) {
  auto &block = program.Block(0);
60 61

  PADDLE_ENFORCE_NOT_NULL(
62 63 64
      root_scope_,
      platform::errors::NotFound(
          "Root scope should be set before creating thread scope."));
65 66

  thread_scope_ = &root_scope_->NewScope();
67 68

  for (auto &var : block.AllVars()) {
69
    all_param_.push_back(var->Name());
70
    if (var->Persistable()) {
71
      auto *ptr = root_scope_->Var(var->Name());
72
      InitializeVariable(ptr, var->GetType());
73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89
      if (stat_var_name_map_.find(var->Name()) != stat_var_name_map_.end() &&
          thread_id_ != 0) {
        int tensor_dim =
            root_scope_->FindVar(var->Name())->GetMutable<LoDTensor>()->numel();
        auto *ptr1 = thread_scope_->Var(var->Name());
        InitializeVariable(ptr1, var->GetType());
        LoDTensor *thread_tensor = ptr1->GetMutable<LoDTensor>();
        LoDTensor *root_tensor =
            root_scope_->FindVar(var->Name())->GetMutable<LoDTensor>();
#define MemsetCallback(cpp_type, proto_type)                     \
  do {                                                           \
    if (root_tensor->type() == proto_type) {                     \
      SetZero<cpp_type>(thread_tensor, root_tensor, tensor_dim); \
    }                                                            \
  } while (0)
        _ForEachDataType_(MemsetCallback);
      }
90
    } else {
91
      auto *ptr = thread_scope_->Var(var->Name());
92 93 94 95 96
      InitializeVariable(ptr, var->GetType());
    }
  }
}

97 98 99 100 101 102 103
template <typename T>
void HogwildWorker::SetZero(LoDTensor *tensor, LoDTensor *root_tensor,
                            int tensor_dim) {
  T *ptr = tensor->mutable_data<T>(root_tensor->dims(), platform::CPUPlace());
  memset(ptr, 0, sizeof(T) * tensor_dim);
}

104
void HogwildWorker::BindingDataFeedMemory() {
105
  const std::vector<std::string> &input_feed =
106
      device_reader_->GetUseSlotAlias();
107
  for (auto name : input_feed) {
108
    device_reader_->AddFeedVar(thread_scope_->FindVar(name), name);
109 110 111
  }
}

// Set up everything this worker needs before training: its thread
// scope, then its private copies of the program's operators.
void HogwildWorker::CreateDeviceResource(const ProgramDesc &main_prog) {
  CreateThreadScope(main_prog);
  CreateThreadOperators(main_prog);
}

// Training loop with per-op wall-clock profiling. Functionally mirrors
// TrainFiles() but additionally times the reader and every op, dumps
// fields/params when configured, and (on thread 0) prints mean op
// times, read time, IO share and throughput every 100 batches.
void HogwildWorker::TrainFilesWithProfiler() {
  platform::SetNumThreads(1);
  device_reader_->Start();
  // Cumulative per-op run time, indexed identically to ops_.
  std::vector<double> op_total_time;
  std::vector<std::string> op_name;
  for (auto &op : ops_) {
    op_name.push_back(op->Type());
  }
  op_total_time.resize(ops_.size());
  for (size_t i = 0; i < op_total_time.size(); ++i) {
    op_total_time[i] = 0.0;
  }
  platform::Timer timeline;
  double total_time = 0.0;
  double read_time = 0.0;
  int cur_batch;
  int batch_cnt = 0;
  timeline.Start();
  uint64_t total_inst = 0;
  while ((cur_batch = device_reader_->Next()) > 0) {
    VLOG(3) << "read a batch in thread " << thread_id_;
    // The timer was started just before Next() (initially above, then at
    // the bottom of the loop); the elapsed span is attributed to both
    // read_time and total_time.
    timeline.Pause();
    read_time += timeline.ElapsedSec();
    total_time += timeline.ElapsedSec();
    for (size_t i = 0; i < ops_.size(); ++i) {
      // Ops whose type contains any skip_ops_ entry are not run; note
      // the timer still brackets the (near-zero) skipped iteration.
      bool need_skip = false;
      for (auto t = 0u; t < skip_ops_.size(); ++t) {
        if (ops_[i]->Type().find(skip_ops_[t]) != std::string::npos) {
          need_skip = true;
          break;
        }
      }
      timeline.Start();
      VLOG(3) << "Going to run op " << op_name[i];
      if (!need_skip) {
        ops_[i]->Run(*thread_scope_, place_);
      }
      VLOG(3) << "Op " << op_name[i] << " Finished";
      timeline.Pause();
      op_total_time[i] += timeline.ElapsedSec();
      total_time += timeline.ElapsedSec();
    }

    if (need_dump_field_) {
      DumpField(*thread_scope_, dump_mode_, dump_interval_);
    }
    // Parameters are identical across threads, so one dumper suffices.
    if (need_dump_param_ && thread_id_ == 0) {
      DumpParam(*thread_scope_, batch_cnt);
    }

    total_inst += cur_batch;
    ++batch_cnt;
    PrintFetchVars();
    // Thread 0 reports aggregated statistics every 100 batches.
    if (thread_id_ == 0) {
      if (batch_cnt > 0 && batch_cnt % 100 == 0) {
        for (size_t i = 0; i < ops_.size(); ++i) {
          fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i,
                  op_name[i].c_str(), op_total_time[i] / batch_cnt);
        }
        fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
        fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
        fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time);
      }
    }
    // Drop child scopes created while running this batch's ops.
    thread_scope_->DropKids();
    timeline.Start();
  }

  if (need_dump_field_ || need_dump_param_) {
    writer_.Flush();
  }

#ifdef PADDLE_WITH_DISTRIBUTE
  // Signal the communicator that this worker has finished its files so
  // barrier waiters can proceed.
  if (thread_barrier_) {
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
  }
#endif
}

void HogwildWorker::TrainFiles() {
  platform::SetNumThreads(1);

  // how to accumulate fetched values here
200
  device_reader_->Start();
201
  int cur_batch;
202
  while ((cur_batch = device_reader_->Next()) > 0) {
203
    for (auto &op : ops_) {
204 205 206 207 208 209 210 211 212 213
      bool need_skip = false;
      for (auto t = 0u; t < skip_ops_.size(); ++t) {
        if (op->Type().find(skip_ops_[t]) != std::string::npos) {
          need_skip = true;
          break;
        }
      }
      if (!need_skip) {
        op->Run(*thread_scope_, place_);
      }
214 215
    }

D
dongdaxiang 已提交
216
    PrintFetchVars();
D
dongdaxiang 已提交
217
    thread_scope_->DropKids();
218
  }
219 220
#ifdef PADDLE_WITH_DISTRIBUTE
  if (thread_barrier_) {
T
tangwei12 已提交
221
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
222 223
  }
#endif
224 225
}

D
dongdaxiang 已提交
226 227 228 229
void HogwildWorker::PrintFetchVars() {
  // call count
  batch_num_++;
  int batch_per_print = fetch_config_.print_period();
D
dongdaxiang 已提交
230
  if (thread_id_ == 0) {
D
dongdaxiang 已提交
231 232
    if (batch_num_ % batch_per_print == 0) {
      int fetch_var_num = fetch_config_.fetch_var_names_size();
D
dongdaxiang 已提交
233
      for (int i = 0; i < fetch_var_num; ++i) {
D
dongdaxiang 已提交
234
        platform::PrintVar(thread_scope_, fetch_config_.fetch_var_names(i),
D
dongdaxiang 已提交
235
                           fetch_config_.fetch_var_str_format(i));
D
dongdaxiang 已提交
236 237 238 239 240
      }
    }
  }
}

}  // end namespace framework
}  // end namespace paddle