hogwild_worker.cc 7.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

15
#include "paddle/fluid/framework/data_type.h"
16
#include "paddle/fluid/framework/device_worker.h"
Z
zhang wenhui 已提交
17
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
18
#include "paddle/fluid/platform/cpu_helper.h"
D
dongdaxiang 已提交
19
#include "paddle/fluid/platform/lodtensor_printer.h"
20

T
tangwei12 已提交
21
#if defined PADDLE_WITH_PSCORE
T
tangwei12 已提交
22 23 24
#include "paddle/fluid/distributed/service/communicator.h"
#endif

25 26 27
namespace paddle {
namespace framework {

28
void HogwildWorker::Initialize(const TrainerDesc &desc) {
D
dongdaxiang 已提交
29
  fetch_config_ = desc.fetch_config();
30 31
  param_ = desc.hogwild_param();
  skip_ops_.resize(param_.skip_ops_size());
32
  for (int i = 0; i < param_.skip_ops_size(); ++i) {
33 34
    skip_ops_[i] = param_.skip_ops(i);
  }
35
  use_cvm_ = desc.use_cvm();
36
  thread_barrier_ = desc.thread_barrier();
37

38 39 40
  for (int i = 0; i < param_.stat_var_names_size(); ++i) {
    stat_var_name_map_[param_.stat_var_names(i)] = 1;
  }
D
dongdaxiang 已提交
41 42
}

43 44
void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) {
  auto &block = program.Block(0);
45
  op_names_.clear();
46
  for (auto &op_desc : block.AllOps()) {
47 48
    std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
    op_names_.push_back(op_desc->Type());
49
    OperatorBase *local_op_ptr = local_op.release();
50 51 52
    ops_.push_back(local_op_ptr);
    continue;
  }
Z
zhang wenhui 已提交
53 54
  operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
      program, 0, ops_);
55 56
}

57 58
void HogwildWorker::CreateThreadScope(const ProgramDesc &program) {
  auto &block = program.Block(0);
59 60

  PADDLE_ENFORCE_NOT_NULL(
61 62 63
      root_scope_,
      platform::errors::NotFound(
          "Root scope should be set before creating thread scope."));
64 65

  thread_scope_ = &root_scope_->NewScope();
66 67

  for (auto &var : block.AllVars()) {
68
    all_param_.push_back(var->Name());
69
    if (var->Persistable()) {
70
      auto *ptr = root_scope_->Var(var->Name());
71
      InitializeVariable(ptr, var->GetType());
72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
      if (stat_var_name_map_.find(var->Name()) != stat_var_name_map_.end() &&
          thread_id_ != 0) {
        int tensor_dim =
            root_scope_->FindVar(var->Name())->GetMutable<LoDTensor>()->numel();
        auto *ptr1 = thread_scope_->Var(var->Name());
        InitializeVariable(ptr1, var->GetType());
        LoDTensor *thread_tensor = ptr1->GetMutable<LoDTensor>();
        LoDTensor *root_tensor =
            root_scope_->FindVar(var->Name())->GetMutable<LoDTensor>();
#define MemsetCallback(cpp_type, proto_type)                     \
  do {                                                           \
    if (root_tensor->type() == proto_type) {                     \
      SetZero<cpp_type>(thread_tensor, root_tensor, tensor_dim); \
    }                                                            \
  } while (0)
        _ForEachDataType_(MemsetCallback);
      }
89
    } else {
90
      auto *ptr = thread_scope_->Var(var->Name());
91 92 93 94 95
      InitializeVariable(ptr, var->GetType());
    }
  }
}

96 97 98 99 100 101 102
template <typename T>
void HogwildWorker::SetZero(LoDTensor *tensor, LoDTensor *root_tensor,
                            int tensor_dim) {
  // Allocates `tensor` on CPU with root_tensor's dims and clears the first
  // `tensor_dim` elements to all-zero bytes.
  // NOTE(review): assumes tensor_dim == root_tensor->numel() (which is what
  // CreateThreadScope passes) — confirm for any other caller.
  const auto byte_count = sizeof(T) * tensor_dim;
  T *data = tensor->mutable_data<T>(root_tensor->dims(), platform::CPUPlace());
  memset(data, 0, byte_count);
}

103
void HogwildWorker::BindingDataFeedMemory() {
104
  const std::vector<std::string> &input_feed =
105
      device_reader_->GetUseSlotAlias();
106
  for (auto name : input_feed) {
107
    device_reader_->AddFeedVar(thread_scope_->FindVar(name), name);
108 109 110
  }
}

111
void HogwildWorker::CreateDeviceResource(const ProgramDesc &main_prog) {
112 113 114 115 116 117
  CreateThreadScope(main_prog);
  CreateThreadOperators(main_prog);
}

void HogwildWorker::TrainFilesWithProfiler() {
  platform::SetNumThreads(1);
118
  device_reader_->Start();
119 120
  std::vector<double> op_total_time;
  std::vector<std::string> op_name;
121
  for (auto &op : ops_) {
122 123 124 125 126 127 128 129 130 131 132 133
    op_name.push_back(op->Type());
  }
  op_total_time.resize(ops_.size());
  for (size_t i = 0; i < op_total_time.size(); ++i) {
    op_total_time[i] = 0.0;
  }
  platform::Timer timeline;
  double total_time = 0.0;
  double read_time = 0.0;
  int cur_batch;
  int batch_cnt = 0;
  timeline.Start();
D
dongdaxiang 已提交
134
  uint64_t total_inst = 0;
135
  while ((cur_batch = device_reader_->Next()) > 0) {
136
    VLOG(3) << "read a batch in thread " << thread_id_;
137 138 139 140
    timeline.Pause();
    read_time += timeline.ElapsedSec();
    total_time += timeline.ElapsedSec();
    for (size_t i = 0; i < ops_.size(); ++i) {
141 142 143 144 145 146 147
      bool need_skip = false;
      for (auto t = 0u; t < skip_ops_.size(); ++t) {
        if (ops_[i]->Type().find(skip_ops_[t]) != std::string::npos) {
          need_skip = true;
          break;
        }
      }
148
      timeline.Start();
149
      VLOG(3) << "Going to run op " << op_name[i];
150 151 152
      if (!need_skip) {
        ops_[i]->Run(*thread_scope_, place_);
      }
153
      VLOG(3) << "Op " << op_name[i] << " Finished";
154 155 156 157
      timeline.Pause();
      op_total_time[i] += timeline.ElapsedSec();
      total_time += timeline.ElapsedSec();
    }
158 159

    if (need_dump_field_) {
H
hutuxian 已提交
160 161 162 163
      DumpField(*thread_scope_, dump_mode_, dump_interval_);
    }
    if (need_dump_param_ && thread_id_ == 0) {
      DumpParam(*thread_scope_, batch_cnt);
164 165
    }

D
dongdaxiang 已提交
166
    total_inst += cur_batch;
167
    ++batch_cnt;
D
dongdaxiang 已提交
168
    PrintFetchVars();
169 170 171 172 173 174 175
    if (thread_id_ == 0) {
      if (batch_cnt > 0 && batch_cnt % 100 == 0) {
        for (size_t i = 0; i < ops_.size(); ++i) {
          fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i,
                  op_name[i].c_str(), op_total_time[i] / batch_cnt);
        }
        fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
D
dongdaxiang 已提交
176
        fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
D
dongdaxiang 已提交
177
        fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time);
178 179
      }
    }
D
dongdaxiang 已提交
180
    thread_scope_->DropKids();
181 182
    timeline.Start();
  }
183

H
hutuxian 已提交
184
  if (need_dump_field_ || need_dump_param_) {
185 186 187
    writer_.Flush();
  }

T
tangwei12 已提交
188
#if defined PADDLE_WITH_PSCORE
189
  if (thread_barrier_) {
T
tangwei12 已提交
190
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
191 192
  }
#endif
193 194 195 196
}

void HogwildWorker::TrainFiles() {
  platform::SetNumThreads(1);
P
phlrain 已提交
197 198
    
  std::cerr << "1!!!!!" << std::endl;
199
  // how to accumulate fetched values here
200
  device_reader_->Start();
201
  int cur_batch;
P
phlrain 已提交
202
  int i = 0;
203
  while ((cur_batch = device_reader_->Next()) > 0) {
P
phlrain 已提交
204
    i++;
205
    for (auto &op : ops_) {
206 207 208 209 210 211 212 213 214 215
      bool need_skip = false;
      for (auto t = 0u; t < skip_ops_.size(); ++t) {
        if (op->Type().find(skip_ops_[t]) != std::string::npos) {
          need_skip = true;
          break;
        }
      }
      if (!need_skip) {
        op->Run(*thread_scope_, place_);
      }
216 217
    }

D
dongdaxiang 已提交
218
    PrintFetchVars();
D
dongdaxiang 已提交
219
    thread_scope_->DropKids();
220
  }
P
phlrain 已提交
221
  std::cerr << "total bacth " << i << std::endl;
T
tangwei12 已提交
222
#if defined PADDLE_WITH_PSCORE
223
  if (thread_barrier_) {
T
tangwei12 已提交
224
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
225 226
  }
#endif
227 228
}

D
dongdaxiang 已提交
229 230 231 232
void HogwildWorker::PrintFetchVars() {
  // call count
  batch_num_++;
  int batch_per_print = fetch_config_.print_period();
D
dongdaxiang 已提交
233
  if (thread_id_ == 0) {
D
dongdaxiang 已提交
234 235
    if (batch_num_ % batch_per_print == 0) {
      int fetch_var_num = fetch_config_.fetch_var_names_size();
D
dongdaxiang 已提交
236
      for (int i = 0; i < fetch_var_num; ++i) {
D
dongdaxiang 已提交
237
        platform::PrintVar(thread_scope_, fetch_config_.fetch_var_names(i),
D
dongdaxiang 已提交
238
                           fetch_config_.fetch_var_str_format(i));
D
dongdaxiang 已提交
239 240 241 242 243
      }
    }
  }
}

244 245
}  // end namespace framework
}  // end namespace paddle