/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/async_executor.h"
#include <stdio.h>
#include <string.h>
#include <fcntl.h>
#include <sched.h>  // sched_setaffinity / CPU_SET used in SetDevice()
#include <unistd.h>
#include <algorithm>
#include <fstream>
#include <iostream>
#include <map>
#include "google/protobuf/message.h"
#include "google/protobuf/text_format.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"

#include "gflags/gflags.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h"

namespace paddle {
namespace framework {
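
// Static state shared by every ExecutorThreadWorker: the global training
// filelist, the file cursor, and epoch counters. PickOneFile() serializes
// access to the filelist through s_locker_for_pick_file_.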
std::mutex ExecutorThreadWorker::s_locker_for_pick_file_;
unsigned int ExecutorThreadWorker::s_current_file_idx_ = 0;
size_t ExecutorThreadWorker::s_current_finished_file_cnt_ = 0;
unsigned int ExecutorThreadWorker::s_current_epoch_ = 0;
int ExecutorThreadWorker::s_current_save_epoch_ = 0;
bool ExecutorThreadWorker::s_is_first_worker_ = false;
std::vector<std::string> ExecutorThreadWorker::s_thread_filelist_;

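// Allocate the payload of a freshly created Variable according to its
// declared proto::VarType so that operators can GetMutable<> it later.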
void CreateTensor(Variable* var, proto::VarType::Type var_type) {
  if (var_type == proto::VarType::LOD_TENSOR) {
    var->GetMutable<LoDTensor>();
  } else if (var_type == proto::VarType::SELECTED_ROWS) {
    var->GetMutable<SelectedRows>();
  } else if (var_type == proto::VarType::FEED_MINIBATCH) {
    var->GetMutable<FeedFetchList>();
  } else if (var_type == proto::VarType::FETCH_LIST) {
    var->GetMutable<FeedFetchList>();
  } else if (var_type == proto::VarType::STEP_SCOPES) {
    var->GetMutable<std::vector<Scope>>();
  } else if (var_type == proto::VarType::LOD_RANK_TABLE) {
    var->GetMutable<LoDRankTable>();
  } else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {
    var->GetMutable<LoDTensorArray>();
  } else if (var_type == proto::VarType::PLACE_LIST) {
    var->GetMutable<platform::PlaceList>();
  } else if (var_type == proto::VarType::READER) {
    var->GetMutable<ReaderHolder>();
  } else if (var_type == proto::VarType::RAW) {
    // GetMutable will be called in operator
  } else {
    PADDLE_THROW(
        "Variable type %d is not in "
        "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
        "STEP_SCOPES, LOD_RANK_TABLE, LOD_TENSOR_ARRAY, PLACE_LIST, "
        "READER, RAW]",
        var_type);
  }
}

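// Read an entire binary file (e.g. a serialized ProgramDesc) into *content.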
static void read_binary_file(const std::string& filename,
                             std::string* content) {
  std::string& contents = *content;
  std::ifstream fin(filename, std::ios::in | std::ios::binary);
  if (!fin.good()) {
    LOG(ERROR) << "Cannot open file " << filename;
    return;
  }
  fin.seekg(0, std::ios::end);
  contents.clear();
  contents.resize(fin.tellg());
  fin.seekg(0, std::ios::beg);
  fin.read(&contents[0], contents.size());
  fin.close();
}

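// Persist the listed model parameters from *scope, either one file per
// parameter via the "save" op or a single combined file via "save_combine".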
static void save_model(
    const std::unique_ptr<ProgramDesc> & main_program,
    Scope* scope,
    const std::vector<std::string> & param_names,
    const std::string & model_name,
    bool save_combine) {
  auto place = platform::CPUPlace();
  const BlockDesc& global_block = main_program->Block(0);
  std::vector<std::string> paralist;

  for (auto* var : global_block.AllVars()) {
    bool is_model_param = false;
    for (auto param_name : param_names) {
      if (var->Name() == param_name) {
        is_model_param = true;
        break;
      }
    }

    if (!is_model_param)  continue;

    if (!save_combine) {
      LOG(ERROR) << "model var name: " << var->Name().c_str();

      paddle::framework::AttributeMap attrs;
      attrs.insert({"file_path", model_name + "/" + var->Name()});
      auto save_op = paddle::framework::OpRegistry::CreateOp(
                                                      "save",
                                                      {{"X", {var->Name()}}},
                                                      {},
                                                      attrs);

      save_op->Run(*scope, place);
    } else {
      paralist.push_back(var->Name());
    }
  }

  if (save_combine) {
    std::sort(paralist.begin(), paralist.end());
    paddle::framework::AttributeMap attrs;
    attrs.insert({"file_path", model_name});
    auto save_op = paddle::framework::OpRegistry::CreateOp(
                                                      "save_combine",
                                                      {{"X", paralist}},
                                                      {},
                                                      attrs);
    save_op->Run(*scope, place);
  }
}   // end save_model


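// Append one training file to the filelist shared by all worker threads.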
void ExecutorThreadWorker::AddTrainFile(const std::string& file) {
  s_thread_filelist_.push_back(file);
}

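// Instantiate a thread-local copy of every operator in block 0 of the
// program; the raw pointers are stored in ops_ and run in order by Train().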
void ExecutorThreadWorker::CreateThreadOperators(const ProgramDesc& program) {
  auto& block = program.Block(0);
  op_names_.clear();
  for (auto& op_desc : block.AllOps()) {
    std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
    op_names_.push_back(op_desc->Type());
    OperatorBase* local_op_ptr = local_op.release();
    ops_.push_back(local_op_ptr);
  }
}

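// Persistable variables are created in the shared root scope, everything
// else in this thread's private child scope.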
void ExecutorThreadWorker::CreateThreadScope(const ProgramDesc& program) {
  auto& block = program.Block(0);
  thread_scope_ = &root_scope_->NewScope();
  for (auto& var : block.AllVars()) {
    if (var->Persistable()) {
      auto* ptr = root_scope_->Var(var->Name());
      CreateTensor(ptr, var->GetType());
      // LOGERR("create Persistable var[%s] finished",
      //      var->Name().c_str());
    } else {
      auto* ptr = thread_scope_->Var(var->Name());
      CreateTensor(ptr, var->GetType());
      // LOGERR("create unpersistable var[%s] finished",
      //      var->Name().c_str());
    }
  }
}

void ExecutorThreadWorker::SetDataFeed(
    const std::shared_ptr<DataFeed>& datafeed) {
  local_reader_ = datafeed;
}

void ExecutorThreadWorker::BindingDataFeedMemory() {
  const std::vector<std::string>& input_feed = local_reader_->GetUseSlotAlias();
  for (auto name : input_feed) {
    local_reader_->AddFeedVar(thread_scope_->Var(name), name);
  }
}

void ExecutorThreadWorker::SetInspectVarName(
    const std::string& inspect_var_name) {
  inspect_var_name_ = inspect_var_name;
}

void ExecutorThreadWorker::SetModelParamNames(
    const std::vector<std::string>& param_names) {
  model_param_names_ = param_names;
}

void ExecutorThreadWorker::SetSparseCommData(
    const std::map<std::string, int>& sparse_comm_data) {
  sparse_comm_data_ = sparse_comm_data;
}

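// Bind this worker to a CPU core chosen by its thread id via
// sched_setaffinity.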
void ExecutorThreadWorker::SetDevice() {
  static unsigned priority[] = {
    0, 1, 2, 3, 4, 5,
    6, 7, 8, 9, 10, 11,
    12, 13, 14, 15, 16, 17,
    18, 19, 20, 21, 22, 23,
    24, 25, 26, 27, 28, 29,
    30, 31, 32, 33, 34, 35,
    36, 37, 38, 39, 40, 41,
    42, 43, 44, 45, 46, 47
  };

  unsigned int i = this->thread_id_;

  if (i < sizeof(priority) / sizeof(unsigned)) {
    unsigned proc = priority[i];

    cpu_set_t mask;
    CPU_ZERO(&mask);
    CPU_SET(proc, &mask);

    if (-1 == sched_setaffinity(0, sizeof(mask), &mask)) {
      LOG(ERROR) << "WARNING: Failed to set thread affinity for thread " << i;
    } else {
      CPU_ZERO(&mask);
      if ((0 == sched_getaffinity(0, sizeof(mask), &mask))
          && CPU_ISSET(proc, &mask)) {
        LOG(ERROR) << "TRACE: Thread " << i
                   << " is running on processor " << proc << "...";
      }
    }
  }
}

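// Called once per finished file; when every file in the list has been
// consumed, reset the counter and advance the shared epoch number.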
void ExecutorThreadWorker::UpdateEpochNum() {
  s_current_finished_file_cnt_++;

  if (s_current_finished_file_cnt_ >= s_thread_filelist_.size()) {
    s_current_finished_file_cnt_ = 0;
    s_current_epoch_++;
  }
}

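// Pick the next training file under the global lock, reshuffling the
// filelist once every file has been handed out.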
const char* ExecutorThreadWorker::PickOneFile() {
  std::lock_guard<std::mutex> lock(s_locker_for_pick_file_);

  if (s_current_file_idx_ >= s_thread_filelist_.size()) {
    std::random_shuffle(s_thread_filelist_.begin(),
                        s_thread_filelist_.end());
    s_current_file_idx_ = 0;
    // s_current_epoch_++;  // do not advance the epoch here: with one file
    //                      // and one thread this would be wrong
    LOG(ERROR) << "thread " << thread_id_
               << ": finish training for epoch " << s_current_epoch_ + 1;
  }
  // Return a pointer into the static filelist rather than into a local
  // string, so the caller does not receive a dangling pointer.
  const std::string& file_to_be_processed =
      s_thread_filelist_[s_current_file_idx_];

  s_current_file_idx_++;
  return file_to_be_processed.c_str();
}

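// Per-thread training loop: for each epoch, keep picking files, reading
// mini-batches and running all operators on them; thread 0 additionally
// saves a model snapshot at the end of every epoch.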
void ExecutorThreadWorker::Train() {
  LOG(ERROR) << "begin to train";
  SetDevice();
#ifdef LOCAL_PROF
  std::vector<double> op_total_time;
  std::vector<std::string> op_name;
  // int total_batch = 0;
  for (auto& op : ops_) {
    op_name.push_back(op->Type());
  }
  op_total_time.resize(ops_.size());
  for (size_t i = 0; i < op_total_time.size(); ++i) {
    op_total_time[i] = 0.0;
  }
#endif
  std::string inspect_key = "inspect";
  if (!inspect_var_name_.empty()) {
    inspect_key = inspect_var_name_.substr(0,
                                          inspect_var_name_.find_first_of('_'));
  }

  for (unsigned i = 0; i < max_epoch_; ++i) {
    LOG(ERROR) << "epoch: " << i;
#ifdef LOCAL_PROF
    Timer timeline;
    double total_time = 0.0;
    double read_time = 0.0;
#endif
    float total_inspect = 0;
    int batch_num = 1;
    while (i == s_current_epoch_) {
      const char* filename = PickOneFile();
      local_reader_->SetFile(filename);
      while (true) {
#ifdef LOCAL_PROF
        timeline.start();
#endif
        bool flag = local_reader_->ReadBatch();
#ifdef LOCAL_PROF
        timeline.pause();
        read_time += timeline.elapsed_sec();
        total_time += timeline.elapsed_sec();
#endif
        if (!flag) {
          break;
        }

        for (unsigned int op_idx = 0; op_idx < ops_.size(); ++op_idx) {
#ifdef LOCAL_PROF
          timeline.start();
#endif
          ops_[op_idx]->Run(*thread_scope_, place_);
#ifdef LOCAL_PROF
          timeline.pause();
          op_total_time[op_idx] += timeline.elapsed_sec();
          total_time += timeline.elapsed_sec();
#endif
        }
        batch_num++;
        float avg_inspect = 0.0;
        if (!inspect_var_name_.empty()) {
          avg_inspect = thread_scope_->FindVar(inspect_var_name_)
                                     ->GetMutable<LoDTensor>()
                                     ->data<float>()[0];
        }
        total_inspect += avg_inspect;
        thread_scope_->DropKids();
      }
      UpdateEpochNum();
      LOG(ERROR) << "memory used after epoch " << i + 1
                 << " called: " << memory::memory_usage(place_);

#ifdef LOCAL_PROF
      for (size_t j = 0; j < op_total_time.size(); ++j) {
        std::cerr << "op_name:[" << j << "][" << op_name[j] << "]"
                  << " op_mean_time:[" << op_total_time[j] << "s]"
                  << std::endl;
      }
      std::cerr << "read time: " << read_time << "s" << std::endl;
#endif
    }
#ifdef LOCAL_PROF
    LOG(ERROR) << "mean " << inspect_key.c_str()
               << " of epoch " << i + 1 << ": " << total_inspect / batch_num
               << ", total_time: " << total_time;
#else
    LOG(ERROR) << "mean " << inspect_key.c_str()
               << " of epoch " << i + 1 << ": " << total_inspect / batch_num;
#endif
    if (thread_id_ == 0) {
      char modelfile[1024];
      snprintf(&modelfile[0],
               sizeof(modelfile),
               "%s_epoch%u.model",
               model_prefix_.c_str(),
               i);
      std::string model_filename = std::string(modelfile);
      // this save_inference_model can only save imdbtask, should make this
      // general
      //
      // currently comment it
      LOG(ERROR) << "Going to save model " << modelfile;
      save_model(main_program_,
          thread_scope_,
          model_param_names_,
          model_filename,
          true);
    }
  }
}

void ExecutorThreadWorker::SetThreadId(int tid) {
  thread_id_ = tid;
}

void ExecutorThreadWorker::SetPlace(const platform::Place& place) {
  place_ = place;
}

void ExecutorThreadWorker::SetMainProgram(
    const ProgramDesc& main_program_desc) {
  main_program_.reset(new ProgramDesc(main_program_desc));
}

void ExecutorThreadWorker::SetRootScope(Scope* g_scope) {
  root_scope_ = g_scope;
}

void ExecutorThreadWorker::SetMaxTrainingEpoch(int max_epoch) {
  max_epoch_ = max_epoch;
}

MultiExecutor::MultiExecutor(const platform::Place& place) : place_(place) {}

void MultiExecutor::InitRootScope(Scope* scope) {
  root_scope_ = scope;
}

void MultiExecutor::SetMaxTrainingEpoch(int max_epoch) {
  max_epoch_ = max_epoch;
}

void MultiExecutor::SetDataFeedName(const char* feedname) {
  feed_name_ = std::string(feedname);
}

void MultiExecutor::SetModelPrefix(const std::string& model_prefix) {
  model_prefix_ = model_prefix;
}

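// Run the startup program on the given scope: create persistable variables
// and execute each initialization op once per newly seen output variable.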
void MultiExecutor::RunStartupProgram(const ProgramDesc& program,
                                      Scope* scope) {
  auto& block = program.Block(0);
  for (auto& var : block.AllVars()) {
    if (var->Persistable()) {
      auto* ptr = scope->Var(var->Name());
      CreateTensor(ptr, var->GetType());
      // LOGERR("Persistable Var Name:%s", var->Name().c_str());
    }
  }

  std::map<std::string, int> param_dict;
  std::vector<OperatorBase *> ops;
  for (auto& op_desc : block.AllOps()) {
    std::vector<std::string> param_name_vec = op_desc->OutputArgumentNames();
    bool need_to_run = false;
    for (auto& name : param_name_vec) {
      if (param_dict.find(name) == param_dict.end()) {
        param_dict[name] = 1;
        need_to_run = true;
      }
    }
    if (need_to_run) {
      std::unique_ptr<OperatorBase> local_op = OpRegistry::CreateOp(*op_desc);
      OperatorBase* local_op_ptr = local_op.release();
      ops.push_back(local_op_ptr);
    }
  }
  // LOGERR("There are %d parameters in startup program, %d op needs to run",
  //        param_dict.size(), ops.size());

  for (auto& op : ops) {
    op->Run(*scope, place_);
  }
  // LOGERR("total time for startup program: %fs", timeline.elapsed_sec());
  for (auto& op : ops) {
    delete op;
  }
  // LOGERR("run startup program done.");
}

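// Deserialize a ProgramDesc from a binary protobuf file on disk.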
std::unique_ptr<ProgramDesc> MultiExecutor::LoadDescFromFile(
    const std::string& f) {
  std::string program_desc_str;
  read_binary_file(f, &program_desc_str);
  std::unique_ptr<ProgramDesc> program(new ProgramDesc(program_desc_str));
  return program;
}

void MultiExecutor::SetDenseCommTensor(
    const std::vector<std::string>& dense_comm_tensor) {
  dense_comm_tensor_.resize(dense_comm_tensor.size());
  for (unsigned int i = 0; i < dense_comm_tensor.size(); ++i) {
    dense_comm_tensor_[i] = dense_comm_tensor[i];
  }
}

void MultiExecutor::SetSparseCommTensor(
    const std::vector<std::string>& sparse_comm_tensor) {
  sparse_comm_tensor_.resize(sparse_comm_tensor.size());
  for (unsigned int i = 0; i < sparse_comm_tensor.size(); ++i) {
    sparse_comm_tensor_[i] = sparse_comm_tensor[i];
  }
}

void MultiExecutor::SetSparseCommData(
    const std::map<std::string, int>& sparse_comm_data) {
  sparse_comm_data_ = sparse_comm_data;
  LOG(INFO) << "Sparse comm data: " << sparse_comm_data_.size();
}

void MultiExecutor::SetFileList(const char* filelist) {
  filelist_.clear();
  std::ifstream fin(filelist);
  std::string filename;
  while (fin >> filename) {
    LOG(ERROR) << "add " << filename << " to filelist";
    filelist_.push_back(filename);
  }
  fin.close();
}

void MultiExecutor::SetFileList(std::vector<std::string> tfiles) {
  filelist_.clear();
  filelist_.insert(filelist_.end(), tfiles.begin(), tfiles.end());
}

void MultiExecutor::SetInspectVarName(const std::string& inspect_var_name) {
  inspect_var_name_ = inspect_var_name;
}

void MultiExecutor::SetParamNames(const std::vector<std::string>& param_names) {
  model_param_names_ = param_names;
}

void MultiExecutor::SetThreadNum(const int thread_num) {
  thread_num_ = thread_num;
}

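// Build thread_num_ workers, wire each one to the shared scope, program and
// filelist, and give every worker its own DataFeed instance.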
void MultiExecutor::PrepareThreads(const ProgramDesc& host_program) {
  workers_.resize(thread_num_);
  for (unsigned i = 0; i < thread_num_; ++i) {
    workers_[i].reset(new ExecutorThreadWorker);
    workers_[i]->SetThreadId(i);
    workers_[i]->CreateThreadOperators(host_program);
    workers_[i]->SetRootScope(root_scope_);
    workers_[i]->SetPlace(place_);
    workers_[i]->SetMaxTrainingEpoch(max_epoch_);
    workers_[i]->CreateThreadScope(host_program);
    workers_[i]->SetInspectVarName(inspect_var_name_);
    workers_[i]->SetModelParamNames(model_param_names_);
    workers_[i]->SetSparseCommData(sparse_comm_data_);
    workers_[i]->SetMainProgram(host_program);
    workers_[i]->SetModelPrefix(model_prefix_);
  }

  for (unsigned i = 0; i < filelist_.size(); ++i) {
    // assume at least one trainer thread; the filelist is a static member,
    // so it only needs to be filled once, through worker 0
    workers_[0]->AddTrainFile(filelist_[i]);
  }
  // mpi_wrapper::ModelParam model_param(true);
  // workers_[0]->register_parallel_training_param(model_param);

  for (unsigned i = 0; i < thread_num_; ++i) {
    // create a DataFeed instance for each worker thread
    std::shared_ptr<DataFeed> local_feed = CreateDataFeed(feed_name_.c_str());
    local_feed->Init(data_feed_param_);
    local_feed->SetBatchSize(batch_size_);
    workers_[i]->SetDataFeed(local_feed);
    workers_[i]->BindingDataFeedMemory();
    workers_[i]->SetThreadId(i);
  }
}

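// Launch one training thread per worker and block until all of them finish.
//
// A minimal driver sketch (assumed usage, not part of this file; names and
// paths are hypothetical, and members such as data_feed_param_ / batch_size_
// are configured elsewhere):
//
//   MultiExecutor exe(platform::CPUPlace());
//   exe.InitRootScope(&scope);
//   exe.SetThreadNum(4);
//   exe.SetMaxTrainingEpoch(10);
//   exe.SetDataFeedName("MyDataFeed");
//   exe.SetFileList("train_filelist.txt");
//   exe.SetModelPrefix("output/model");
//   exe.SetParamNames(param_names);
//   auto startup_prog = exe.LoadDescFromFile(startup_desc_path);
//   exe.RunStartupProgram(*startup_prog, &scope);
//   auto main_prog = exe.LoadDescFromFile(main_desc_path);
//   exe.RunMultiExecutor(*main_prog);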
void MultiExecutor::RunMultiExecutor(const ProgramDesc& host_program) {
  // thread binding here?
  PrepareThreads(host_program);
  for (unsigned i = 0; i < thread_num_; ++i) {
    threads_.push_back(std::thread(&ExecutorThreadWorker::Train,
                                   workers_[i].get()));
  }

  for (auto& th : threads_) {
    th.join();
  }
}

}   // end namespace framework
}   // end namespace paddle

/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */