/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <ctime>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/lodtensor_printer.h"

#if defined PADDLE_WITH_PSCORE
#include "paddle/fluid/distributed/service/communicator.h"
#endif

namespace paddle {
namespace framework {

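// Pull the worker configuration out of the TrainerDesc: fetch settings, the
// list of ops to skip, the CVM flag, the thread barrier flag, and the names
// of the stat variables.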
void HogwildWorker::Initialize(const TrainerDesc &desc) {
  fetch_config_ = desc.fetch_config();
  param_ = desc.hogwild_param();
  skip_ops_.resize(param_.skip_ops_size());
  for (int i = 0; i < param_.skip_ops_size(); ++i) {
    skip_ops_[i] = param_.skip_ops(i);
  }
  use_cvm_ = desc.use_cvm();
  thread_barrier_ = desc.thread_barrier();

  for (int i = 0; i < param_.stat_var_names_size(); ++i) {
    stat_var_name_map_[param_.stat_var_names(i)] = 1;
  }
}

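// Instantiate a thread-local copy of every operator in block 0 and prepare
// the conditional block ops for safe eager deletion.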
void HogwildWorker::CreateThreadOperators(const ProgramDesc &program) {
  auto &block = program.Block(0);
  op_names_.clear();
  for (auto &op_desc : block.AllOps()) {
    std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
    op_names_.push_back(op_desc->Type());
    OperatorBase *local_op_ptr = local_op.release();
    ops_.push_back(local_op_ptr);
  }
  operators::PrepareSafeEagerDeletionOnConditionalOpAndConditionalGradOp(
      program, 0, ops_);
}

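// Build the thread scope under the root scope. Persistable variables are
// created in the root scope; all other variables are created in the thread
// scope.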
void HogwildWorker::CreateThreadScope(const ProgramDesc &program) {
  auto &block = program.Block(0);

  PADDLE_ENFORCE_NOT_NULL(
      root_scope_,
      platform::errors::NotFound(
          "Root scope should be set before creating thread scope."));

  thread_scope_ = &root_scope_->NewScope();

  for (auto &var : block.AllVars()) {
    all_param_.push_back(var->Name());
    if (var->Persistable()) {
      auto *ptr = root_scope_->Var(var->Name());
      InitializeVariable(ptr, var->GetType());
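      // For worker threads other than 0, stat variables additionally get a
      // zero-filled thread-local copy, presumably so each thread can
      // accumulate its own statistics.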
      if (stat_var_name_map_.find(var->Name()) != stat_var_name_map_.end() &&
          thread_id_ != 0) {
        int tensor_dim =
            root_scope_->FindVar(var->Name())->GetMutable<LoDTensor>()->numel();
        auto *ptr1 = thread_scope_->Var(var->Name());
        InitializeVariable(ptr1, var->GetType());
        LoDTensor *thread_tensor = ptr1->GetMutable<LoDTensor>();
        LoDTensor *root_tensor =
            root_scope_->FindVar(var->Name())->GetMutable<LoDTensor>();
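// Zero the thread-local copy, dispatching on the root tensor's data type.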
#define MemsetCallback(cpp_type, proto_type)                     \
  do {                                                           \
    if (root_tensor->type() == proto_type) {                     \
      SetZero<cpp_type>(thread_tensor, root_tensor, tensor_dim); \
    }                                                            \
  } while (0)
        _ForEachDataType_(MemsetCallback);
      }
    } else {
      auto *ptr = thread_scope_->Var(var->Name());
      InitializeVariable(ptr, var->GetType());
    }
  }
}

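// Zero-fill `tensor`, allocated with the same shape as `root_tensor`.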
template <typename T>
void HogwildWorker::SetZero(LoDTensor *tensor, LoDTensor *root_tensor,
                            int tensor_dim) {
  T *ptr = tensor->mutable_data<T>(root_tensor->dims(), platform::CPUPlace());
  memset(ptr, 0, sizeof(T) * tensor_dim);
}

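// Bind the reader's feed variables to this thread's scope so each batch is
// written directly into thread-local variables.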
void HogwildWorker::BindingDataFeedMemory() {
  const std::vector<std::string> &input_feed =
      device_reader_->GetUseSlotAlias();
  for (auto name : input_feed) {
    device_reader_->AddFeedVar(thread_scope_->FindVar(name), name);
  }
}

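// For the Hogwild worker, creating device resources means building the
// thread scope and the thread-local operators.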
void HogwildWorker::CreateDeviceResource(const ProgramDesc &main_prog) {
  CreateThreadScope(main_prog);
  CreateThreadOperators(main_prog);
}

void HogwildWorker::TrainFilesWithProfiler() {
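  // Same loop as TrainFiles(), but with per-op timing: accumulates read time
  // and per-op run time, and periodically reports the averages on thread 0.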
  platform::SetNumThreads(1);
  device_reader_->Start();
  std::vector<double> op_total_time;
  std::vector<std::string> op_name;
  for (auto &op : ops_) {
    op_name.push_back(op->Type());
  }
  op_total_time.resize(ops_.size());
  for (size_t i = 0; i < op_total_time.size(); ++i) {
    op_total_time[i] = 0.0;
  }
  platform::Timer timeline;
  double total_time = 0.0;
  double read_time = 0.0;
  int cur_batch;
  int batch_cnt = 0;
  timeline.Start();
  uint64_t total_inst = 0;
  while ((cur_batch = device_reader_->Next()) > 0) {
    VLOG(3) << "read a batch in thread " << thread_id_;
    timeline.Pause();
    read_time += timeline.ElapsedSec();
    total_time += timeline.ElapsedSec();
    for (size_t i = 0; i < ops_.size(); ++i) {
      bool need_skip = false;
      for (auto t = 0u; t < skip_ops_.size(); ++t) {
        if (ops_[i]->Type().find(skip_ops_[t]) != std::string::npos) {
          need_skip = true;
          break;
        }
      }
      timeline.Start();
      VLOG(3) << "Going to run op " << op_name[i];
      if (!need_skip) {
        ops_[i]->Run(*thread_scope_, place_);
      }
      VLOG(3) << "Op " << op_name[i] << " Finished";
      timeline.Pause();
      op_total_time[i] += timeline.ElapsedSec();
      total_time += timeline.ElapsedSec();
    }

    if (need_dump_field_) {
      DumpField(*thread_scope_, dump_mode_, dump_interval_);
    }
    if (need_dump_param_ && thread_id_ == 0) {
      DumpParam(*thread_scope_, batch_cnt);
    }

    total_inst += cur_batch;
    ++batch_cnt;
    PrintFetchVars();
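    // Every 100 batches, thread 0 reports mean per-op time, mean read time,
    // the IO share of total time, and throughput.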
    if (thread_id_ == 0) {
      if (batch_cnt > 0 && batch_cnt % 100 == 0) {
        for (size_t i = 0; i < ops_.size(); ++i) {
          fprintf(stderr, "op_name:[%zu][%s], op_mean_time:[%fs]\n", i,
                  op_name[i].c_str(), op_total_time[i] / batch_cnt);
        }
        fprintf(stderr, "mean read time: %fs\n", read_time / batch_cnt);
        fprintf(stderr, "IO percent: %f\n", read_time / total_time * 100);
        fprintf(stderr, "%6.2f instances/s\n", total_inst / total_time);
      }
    }
    thread_scope_->DropKids();
    timeline.Start();
  }

  if (need_dump_field_ || need_dump_param_) {
    writer_.Flush();
  }

#if defined PADDLE_WITH_PSCORE
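  // With the PSCORE distributed runtime, signal that this worker has
  // finished by decrementing the communicator's barrier trigger.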
  if (thread_barrier_) {
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
  }
#endif
}

void HogwildWorker::TrainFiles() {
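  // Plain Hogwild training loop: read a batch, run every op that is not in
  // skip_ops_, print fetch variables, and drop the per-batch child scopes.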
  platform::SetNumThreads(1);

  // how to accumulate fetched values here
  device_reader_->Start();
  int cur_batch;
  while ((cur_batch = device_reader_->Next()) > 0) {
    for (auto &op : ops_) {
      bool need_skip = false;
      for (auto t = 0u; t < skip_ops_.size(); ++t) {
        if (op->Type().find(skip_ops_[t]) != std::string::npos) {
          need_skip = true;
          break;
        }
      }
      if (!need_skip) {
        op->Run(*thread_scope_, place_);
      }
    }

    PrintFetchVars();
    thread_scope_->DropKids();
  }
#if defined PADDLE_WITH_PSCORE
  if (thread_barrier_) {
    paddle::distributed::Communicator::GetInstance()->BarrierTriggerDecrement();
  }
#endif
}

void HogwildWorker::PrintFetchVars() {
  // call count
  batch_num_++;
  int batch_per_print = fetch_config_.print_period();
  int fetch_var_num = fetch_config_.fetch_var_names_size();

  if (fetch_var_num == 0) {
    return;
  }

  if (thread_id_ == 0 && batch_num_ % batch_per_print == 0) {
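    // Thread 0 prints the fetch variables every `print_period` batches,
    // prefixed with a timestamp and the batch count.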
    time_t curtime;
    time(&curtime);
    char mbstr[80];
    std::strftime(mbstr, sizeof(mbstr), "%Y-%m-%d %H:%M:%S",
                  std::localtime(&curtime));

    std::stringstream ss;
    ss << "time: [" << mbstr << "], ";
    ss << "batch: [" << batch_num_ << "], ";

    for (int i = 0; i < fetch_var_num; ++i) {
      platform::PrintVar(thread_scope_, fetch_config_.fetch_var_names(i),
                         fetch_config_.fetch_var_str_format(i), &ss);
      if (i < fetch_var_num - 1) {
        ss << ", ";
      }
    }

    std::cout << ss.str() << std::endl;
  }
}

}  // end namespace framework
}  // end namespace paddle