/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef PADDLE_FLUID_FRAMEWORK_ASYNC_EXECUTOR_H_
#define PADDLE_FLUID_FRAMEWORK_ASYNC_EXECUTOR_H_

#include <memory>
#include <mutex>    // NOLINT
#include <set>
#include <map>
#include <string>
#include <thread>   // NOLINT
#include <vector>
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/datafeed_creator.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
namespace framework {
void CreateTensor(Variable* var, proto::VarType::Type var_type);

class ExecutorThreadWorker {
 public:
  ExecutorThreadWorker() {}
  virtual ~ExecutorThreadWorker() {}
W
wangguibao 已提交
39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
  void CreateThreadScope(const framework::ProgramDesc& program);
  void SetDataFeed(const DataFeed& datafeed);
  void SetThreadId(int tid);
  void CreateThreadOperators(const framework::ProgramDesc& program);
  void SetRootScope(Scope* g_scope);
  void SetDevice();
  virtual void AddFidSet();
  void SetCommBatch(int comm_batch) { comm_batch_ = comm_batch; }
  void AddTrainFile(const std::string& filename);
  void SetMainProgram(const ProgramDesc& main_program_desc);
  void SetPlace(const paddle::platform::Place& place);
  void SetMaxTrainingEpoch(const int max_epoch);
  void BindingDataFeedMemory();
  void SetModelPrefix(const std::string& prefix) { model_prefix_ = prefix; }
  void SetInspectVarName(const std::string& inspect_var_name);
  void SetModelParamNames(const std::vector<std::string>& param_names);
  void SetSparseCommData(const std::map<std::string, int>& param_names);
  void SetDataFeed(const std::shared_ptr<DataFeed>& datafeed);
  void Train();
  virtual const char* PickOneFile();
  void UpdateEpochNum();

  virtual void SetDenseCommTensor(
W
wangguibao 已提交
62
      const std::vector<std::string>& param_names) {}
W
wangguibao 已提交
63
  virtual void Initialize() {}
W
wangguibao 已提交
64 65

 public:
W
wangguibao 已提交
66 67 68 69 70 71 72
  static std::mutex s_locker_for_pick_file_;
  static unsigned int s_current_file_idx_;
  static size_t s_current_finished_file_cnt_;
  static unsigned int s_current_epoch_;
  static int s_current_save_epoch_;
  static std::vector<std::string> s_thread_filelist_;   // filelist
  static bool s_is_first_worker_;
W
wangguibao 已提交
73 74 75

 protected:
  // thread index
W
wangguibao 已提交
76
  int thread_id_;
W
wangguibao 已提交
77 78

  // max epoch for each thread
W
wangguibao 已提交
79
  unsigned int max_epoch_;
W
wangguibao 已提交
80 81

  // instances learned currently
W
wangguibao 已提交
82 83 84
  int comm_batch_;
  std::string model_prefix_;
  std::vector<std::string> op_names_;
W
wangguibao 已提交
85 86

  // local ops for forward and backward
W
wangguibao 已提交
87
  std::vector<OperatorBase *> ops_;
W
wangguibao 已提交
88 89

  // main program for training
W
wangguibao 已提交
90
  std::unique_ptr<framework::ProgramDesc> main_program_;
W
wangguibao 已提交
91 92

  // binary data reader
W
wangguibao 已提交
93
  std::shared_ptr<DataFeed> local_reader_;
W
wangguibao 已提交
94

W
wangguibao 已提交
95 96 97
  std::string inspect_var_name_;
  std::vector<std::string> model_param_names_;
  std::map<std::string, int> sparse_comm_data_;
W
wangguibao 已提交
98 99

  // execution place
W
wangguibao 已提交
100
  platform::Place place_;
W
wangguibao 已提交
101 102

  // root scope for model parameters
W
wangguibao 已提交
103
  Scope* root_scope_;
W
wangguibao 已提交
104 105

  // a thread scope, father scope is global score which is shared
W
wangguibao 已提交
106
  Scope* thread_scope_;
W
wangguibao 已提交
107 108
};

class AsyncExecutor {
W
wangguibao 已提交
110
 public:
W
wangguibao 已提交
111 112
  explicit AsyncExecutor(const platform::Place& place);
  virtual ~AsyncExecutor() {}
W
wangguibao 已提交
113
  static std::unique_ptr<ProgramDesc> LoadDescFromFile(
W
wangguibao 已提交
114
                                          const std::string& filename);
W
wangguibao 已提交
115 116 117 118 119 120 121 122 123 124 125 126 127
  void InitRootScope(Scope* scope);
  void SetInspectVarName(const std::string& inspect_var_name);
  void SetParamNames(const std::vector<std::string>& param_names);
  void SetMaxTrainingEpoch(const int max_epoch);
  Scope* GetRootScope() { return root_scope_; }
  void SetThreadNum(const int thread_num);
  void SetBatchSize(const int batch_size) { batch_size_ = batch_size; }
  void SetFileList(const char* filelist);
  void SetFileList(const std::vector<std::string> filelist);
  void SetDataFeedName(const char* feedname);

  void SetDataFeedParam(const datafeed::DataFeedParameter& feed_param) {
    data_feed_param_ = feed_param;
W
wangguibao 已提交
128 129
  }

W
wangguibao 已提交
130 131
  void SetCommBatch(int comm_batch) {
    comm_batch_ = comm_batch;
W
wangguibao 已提交
132 133
  }

W
wangguibao 已提交
134 135 136
  void SetModelPrefix(const std::string& model_prefix);
  void SetDenseCommTensor(const std::vector<std::string>& dense_comm_tensor);
  void SetSparseCommTensor(
W
wangguibao 已提交
137
      const std::vector<std::string>& sparse_comm_tensor);
W
wangguibao 已提交
138 139 140
  void SetSparseCommData(const std::map<std::string, int>& sparse_comm_data);
  virtual void PrepareThreads(const framework::ProgramDesc& host_program);
  void RunStartupProgram(const framework::ProgramDesc& program,
W
wangguibao 已提交
141
      framework::Scope* scope);
W
wangguibao 已提交
142
  void RunAsyncExecutor(const ProgramDesc& host_program);
W
wangguibao 已提交
143 144

 public:
W
wangguibao 已提交
145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
  unsigned int thread_num_;
  datafeed::DataFeedParameter data_feed_param_;
  int max_epoch_;
  int batch_size_;
  int comm_batch_;
  std::vector<std::shared_ptr<ExecutorThreadWorker> > workers_;
  std::vector<std::thread> threads_;
  std::vector<std::string> filelist_;
  std::string inspect_var_name_;
  std::vector<std::string> model_param_names_;
  std::vector<std::string> dense_comm_tensor_;
  std::vector<std::string> sparse_comm_tensor_;
  std::map<std::string, int> sparse_comm_data_;
  std::string model_prefix_;
  std::string feed_name_;
  Scope* root_scope_;
  platform::Place place_;
W
wangguibao 已提交
162 163 164 165 166 167
};

}  // namespace framework
}  // namespace paddle
#endif  // PADDLE_FLUID_FRAMEWORK_ASYNC_EXECUTOR_H_
/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */