/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/lodtensor_printer.h"

namespace paddle {
namespace framework {

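// Read Hogwild-specific settings from the trainer descriptor: fetch config,
// ops to skip, CVM flag, thread barrier, and the stat variable names.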
void HogwildWorker::Initialize(const TrainerDesc &desc) {
  fetch_config_ = desc.fetch_config();
  param_ = desc.hogwild_param();
  skip_ops_.resize(param_.skip_ops_size());
  for (int i = 0; i < param_.skip_ops_size(); ++i) {
    skip_ops_[i] = param_.skip_ops(i);
  }
  use_cvm_ = desc.use_cvm();
  thread_barrier_ = desc.thread_barrier();

  for (int i = 0; i < param_.stat_var_names_size(); ++i) {
    stat_var_name_map_[param_.stat_var_names(i)] = 1;
  }
}

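// Create a thread-local copy of every operator in block 0 of the program;
// the unique_ptrs are released and the raw pointers stored in ops_.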
void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) {
  auto &block = program.Block(0);
  op_names_.clear();
  for (auto &op_desc : block.AllOps()) {
    std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
    op_names_.push_back(op_desc->Type());
    OperatorBase *local_op_ptr = local_op.release();
    ops_.push_back(local_op_ptr);
  }
  operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
      program, 0, ops_);
}

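// Build the per-thread scope: persistable variables live in the root scope,
// everything else in the thread scope. For threads other than 0, stat
// variables additionally get a zero-filled thread-local copy.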
void HogwildWorker::CreateThreadScope(const ProgramDesc &program) {
  auto &block = program.Block(0);

  PADDLE_ENFORCE_NOT_NULL(
      root_scope_,
      platform::errors::NotFound(
          "Root scope should be set before creating thread scope."));

  thread_scope_ = &root_scope_->NewScope();

  for (auto &var : block.AllVars()) {
    all_param_.push_back(var->Name());
    if (var->Persistable()) {
      auto *ptr = root_scope_->Var(var->Name());
      InitializeVariable(ptr, var->GetType());
      if (stat_var_name_map_.find(var->Name()) != stat_var_name_map_.end() &&
          thread_id_ != 0) {
        int tensor_dim =
            root_scope_->FindVar(var->Name())->GetMutable<LoDTensor>()->numel();
        auto *ptr1 = thread_scope_->Var(var->Name());
        InitializeVariable(ptr1, var->GetType());
        LoDTensor *thread_tensor = ptr1->GetMutable<LoDTensor>();
        LoDTensor *root_tensor =
            root_scope_->FindVar(var->Name())->GetMutable<LoDTensor>();
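        // Dispatch on the root tensor's data type and zero the thread-local
        // copy so each worker accumulates its statistics from scratch.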
#define MemsetCallback(cpp_type, proto_type)                     \
  do {                                                           \
    if (root_tensor->type() == proto_type) {                     \
      SetZero<cpp_type>(thread_tensor, root_tensor, tensor_dim); \
    }                                                            \
  } while (0)
        _ForEachDataType_(MemsetCallback);
      }
    } else {
      auto *ptr = thread_scope_->Var(var->Name());
      InitializeVariable(ptr, var->GetType());
    }
  }
}

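// Zero-fill `tensor` on CPU, shaped like `root_tensor`; `tensor_dim` is the
// number of elements.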
template <typename T>
void HogwildWorker::SetZero(LoDTensor *tensor, LoDTensor *root_tensor,
                            int tensor_dim) {
  T *ptr = tensor->mutable_data<T>(root_tensor->dims(), platform::CPUPlace());
  memset(ptr, 0, sizeof(T) * tensor_dim);
}

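// Bind each used input slot of the data feed to its variable in the thread
// scope.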
void HogwildWorker::BindingDataFeedMemory() {
  const std::vector<std::string> &input_feed =
      device_reader_->GetUseSlotAlias();
  for (const auto &name : input_feed) {
    device_reader_->AddFeedVar(thread_scope_->FindVar(name), name);
  }
}

void HogwildWorker::CreateDeviceResource(const ProgramDesc &main_prog) {
  CreateThreadScope(main_prog);
  CreateThreadOperators(main_prog);
}

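// Training loop with profiling: times the reader and every operator, and
// thread 0 periodically reports per-op mean time, read time, IO percentage,
// and throughput.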
void HogwildWorker::TrainFilesWithProfiler() {
  platform::SetNumThreads(1);
  device_reader_->Start();
  std::vector<double> op_total_time;
  std::vector<std::string> op_name;
  for (auto &op : ops_) {
    op_name.push_back(op->Type());
  }
  op_total_time.resize(ops_.size());
  for (size_t i = 0; i < op_total_time.size(); ++i) {
    op_total_time[i] = 0.0;
  }
  platform::Timer timeline;
  double total_time = 0.0;
  double read_time = 0.0;
  int cur_batch;
  int batch_cnt = 0;
  timeline.Start();
  uint64_t total_inst = 0;
  while ((cur_batch = device_reader_->Next()) > 0) {
    VLOG(3) << "read a batch in thread " << thread_id_;
    timeline.Pause();
    read_time += timeline.ElapsedSec();
    total_time += timeline.ElapsedSec();
    for (size_t i = 0; i < ops_.size(); ++i) {
      bool need_skip = false;
      for (auto t = 0u; t < skip_ops_.size(); ++t) {
        if (ops_[i]->Type().find(skip_ops_[t]) != std::string::npos) {
          need_skip = true;
          break;
        }
      }
      timeline.Start();
      VLOG(3) << "Going to run op " << op_name[i];
      if (!need_skip) {
        ops_[i]->Run(*thread_scope_, place_);
      }
      VLOG(3) << "Op " << op_name[i] << " Finished";
      timeline.Pause();
      op_total_time[i] += timeline.ElapsedSec();
      total_time += timeline.ElapsedSec();
    }

    if (need_dump_field_) {
      DumpField(*thread_scope_, dump_mode_, dump_interval_);
    }
    if (need_dump_param_ && thread_id_ == 0) {
      DumpParam(*thread_scope_, batch_cnt);
    }

    total_inst += cur_batch;
    ++batch_cnt;
    PrintFetchVars();
    if (thread_id_ == 0) {
      if (batch_cnt > 0 && batch_cnt % 100 == 0) {
        for (size_t i = 0; i < ops_.size(); ++i) {
          fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i,
                  op_name[i].c_str(), op_total_time[i] / batch_cnt);
        }
        fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
        fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
        fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time);
      }
    }
    thread_scope_->DropKids();
    timeline.Start();
  }

  if (need_dump_field_ || need_dump_param_) {
    writer_.Flush();
  }

#ifdef PADDLE_WITH_DISTRIBUTE
  if (thread_barrier_) {
    operators::distributed::Communicator::GetInstance()
        ->BarrierTriggerDecrement();
  }
#endif
}

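// Plain training loop: run every non-skipped operator on each batch, print
// fetch variables, and drop the thread scope's child scopes.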
void HogwildWorker::TrainFiles() {
  platform::SetNumThreads(1);

  // how to accumulate fetched values here
  device_reader_->Start();
  int cur_batch;
  while ((cur_batch = device_reader_->Next()) > 0) {
    for (auto &op : ops_) {
      bool need_skip = false;
      for (auto t = 0u; t < skip_ops_.size(); ++t) {
        if (op->Type().find(skip_ops_[t]) != std::string::npos) {
          need_skip = true;
          break;
        }
      }
      if (!need_skip) {
        op->Run(*thread_scope_, place_);
      }
    }

    PrintFetchVars();
    thread_scope_->DropKids();
  }
#ifdef PADDLE_WITH_DISTRIBUTE
  if (thread_barrier_) {
    operators::distributed::Communicator::GetInstance()
        ->BarrierTriggerDecrement();
  }
#endif
}

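// From thread 0, print the configured fetch variables every
// fetch_config_.print_period() batches.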
void HogwildWorker::PrintFetchVars() {
  // call count
  batch_num_++;
  int batch_per_print = fetch_config_.print_period();
  if (thread_id_ == 0) {
    if (batch_num_ % batch_per_print == 0) {
      int fetch_var_num = fetch_config_.fetch_var_names_size();
      for (int i = 0; i < fetch_var_num; ++i) {
        platform::PrintVar(thread_scope_, fetch_config_.fetch_var_names(i),
                           fetch_config_.fetch_var_str_format(i));
      }
    }
  }
}

}  // end namespace framework
}  // end namespace paddle