async_executor.h 5.2 KB
Newer Older
W
wangguibao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef PADDLE_FLUID_FRAMEWORK_ASYNC_EXECUTOR_H_
#define PADDLE_FLUID_FRAMEWORK_ASYNC_EXECUTOR_H_

#include <memory>
#include <mutex>    // NOLINT
#include <set>
#include <map>
#include <string>
#include <thread>   // NOLINT
#include <vector>
25
#include <typeinfo>
W
wangguibao 已提交
26 27 28 29 30 31 32 33 34 35 36 37 38
#include "paddle/fluid/framework/data_feed.h"
#include "paddle/fluid/framework/datafeed_creator.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
namespace framework {
// Declared here and defined elsewhere (definition not visible in this header).
// Takes the variable to operate on and the proto var type that selects what is
// created. NOTE(review): presumably initializes the tensor held by `var`
// according to `var_type` -- confirm against the definition.
void CreateTensor(Variable* var, proto::VarType::Type var_type);

class ExecutorThreadWorker {
 public:
  ExecutorThreadWorker() {}
W
wangguibao 已提交
39
  ~ExecutorThreadWorker() {}
40
  void CreateThreadScope(const ProgramDesc& program);
W
wangguibao 已提交
41
  void SetThreadId(int tid);
42
  void CreateThreadOperators(const ProgramDesc& program);
W
wangguibao 已提交
43 44
  void SetRootScope(Scope* g_scope);
  void SetDevice();
W
wangguibao 已提交
45
  void AddFidSet();
W
wangguibao 已提交
46 47 48 49 50 51
  void SetCommBatch(int comm_batch) { comm_batch_ = comm_batch; }
  void AddTrainFile(const std::string& filename);
  void SetMainProgram(const ProgramDesc& main_program_desc);
  void SetPlace(const paddle::platform::Place& place);
  void SetMaxTrainingEpoch(const int max_epoch);
  void BindingDataFeedMemory();
W
wangguibao 已提交
52

W
wangguibao 已提交
53
  void SetModelPrefix(const std::string& prefix) { model_prefix_ = prefix; }
W
wangguibao 已提交
54

55
  void SetInspectVarNames(const std::vector<std::string>& inspect_var_names);
W
wangguibao 已提交
56
  void SetModelParamNames(const std::vector<std::string>& param_names);
57
  void SetDataFeed(DataFeed& datafeed); // NOLINT
W
wangguibao 已提交
58
  void Train();
W
wangguibao 已提交
59
  const char* PickOneFile();
W
wangguibao 已提交
60
  void UpdateEpochNum();
61
  void Reset();
W
wangguibao 已提交
62

W
wangguibao 已提交
63
  void Initialize() {}
64
  std::vector<float>& GetInspectValues() {return inspect_values_;}
W
wangguibao 已提交
65 66 67

 protected:
  // thread index
W
wangguibao 已提交
68
  int thread_id_;
W
wangguibao 已提交
69 70

  // max epoch for each thread
W
wangguibao 已提交
71
  unsigned int max_epoch_;
W
wangguibao 已提交
72 73

  // instances learned currently
W
wangguibao 已提交
74 75 76
  int comm_batch_;
  std::string model_prefix_;
  std::vector<std::string> op_names_;
W
wangguibao 已提交
77 78

  // local ops for forward and backward
W
wangguibao 已提交
79
  std::vector<OperatorBase *> ops_;
W
wangguibao 已提交
80 81

  // main program for training
82
  std::unique_ptr<ProgramDesc> main_program_;
W
wangguibao 已提交
83 84

  // binary data reader
85
  std::unique_ptr<DataFeed> local_reader_;
W
wangguibao 已提交
86

87
  std::vector<std::string> inspect_var_names_;
W
wangguibao 已提交
88
  std::vector<std::string> model_param_names_;
W
wangguibao 已提交
89 90

  // execution place
W
wangguibao 已提交
91
  platform::Place place_;
W
wangguibao 已提交
92 93

  // root scope for model parameters
W
wangguibao 已提交
94
  Scope* root_scope_;
W
wangguibao 已提交
95 96

  // a thread scope, father scope is global score which is shared
W
wangguibao 已提交
97
  Scope* thread_scope_;
98 99 100

 private:
  std::vector<float> inspect_values_;
W
wangguibao 已提交
101 102
};

W
wangguibao 已提交
103
class AsyncExecutor {
W
wangguibao 已提交
104
 public:
105 106 107 108 109
  explicit AsyncExecutor(ProgramDesc& main_program,     // NOLINT
                         const std::vector<std::string>& param_names,
                         TextClassDataFeed& data_feed,  // NOLINT
                         unsigned int thread_num,
                         const platform::Place& place);
W
wangguibao 已提交
110
  virtual ~AsyncExecutor() {}
W
wangguibao 已提交
111
  static std::unique_ptr<ProgramDesc> LoadDescFromFile(
W
wangguibao 已提交
112
                                          const std::string& filename);
W
wangguibao 已提交
113 114 115 116
  void InitRootScope(Scope* scope);
  void SetMaxTrainingEpoch(const int max_epoch);
  Scope* GetRootScope() { return root_scope_; }
  void SetBatchSize(const int batch_size) { batch_size_ = batch_size; }
117

W
wangguibao 已提交
118 119 120
  void SetCommBatch(int comm_batch) {
    comm_batch_ = comm_batch;
  }
W
wangguibao 已提交
121

W
wangguibao 已提交
122 123
  void SetModelPath(const std::string& model_path) {
    model_path_ = model_path;
W
wangguibao 已提交
124 125
  }

W
wangguibao 已提交
126 127 128 129 130 131
  void SetInitProgFile(const std::string& init_prog_file) {
    init_prog_file_ = init_prog_file;
  }

  void SetInitModelFile(const std::string& init_model_file) {
    init_model_file_ = init_model_file;
W
wangguibao 已提交
132 133
  }

W
wangguibao 已提交
134
  void SetModelPrefix(const std::string& model_prefix);
135 136 137
  virtual void PrepareThreads(const ProgramDesc& host_program);
  void RunStartupProgram(const ProgramDesc& program, Scope* scope);
  std::vector<float>& Run(const std::vector<std::string>& inspect_var_names);
W
wangguibao 已提交
138

W
wangguibao 已提交
139 140
  void LoadInitModel();

141 142 143
 private:
  void SetInspectVarNames(const std::vector<std::string>& inspect_var_names);

W
wangguibao 已提交
144
 public:
145
  int thread_num_;
W
wangguibao 已提交
146 147 148 149 150
  int max_epoch_;
  int batch_size_;
  int comm_batch_;
  std::vector<std::shared_ptr<ExecutorThreadWorker> > workers_;
  std::vector<std::thread> threads_;
151
  std::vector<std::string> inspect_var_names_;
W
wangguibao 已提交
152 153
  std::vector<std::string> model_param_names_;
  std::string model_prefix_;
W
wangguibao 已提交
154 155 156
  std::string model_path_;
  std::string init_prog_file_;
  std::string init_model_file_;
W
wangguibao 已提交
157 158
  Scope* root_scope_;
  platform::Place place_;
159 160 161 162 163 164 165 166

 private:
  ProgramDesc& main_program_;
  TextClassDataFeed& data_feed_;
  std::vector<float> inspect_values_;

 private:
  static bool workers_initialized_;
W
wangguibao 已提交
167 168 169 170 171 172
};

}  // namespace framework
}  // namespace paddle
#endif  // PADDLE_FLUID_FRAMEWORK_ASYNC_EXECUTOR_H_
/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */