async_executor.h 3.0 KB
Newer Older
W
wangguibao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef PADDLE_FLUID_FRAMEWORK_ASYNC_EXECUTOR_H_
#define PADDLE_FLUID_FRAMEWORK_ASYNC_EXECUTOR_H_

#include <memory>
#include <mutex>    // NOLINT
#include <set>
#include <map>
#include <string>
#include <thread>   // NOLINT
#include <vector>
25
#include <typeinfo>
W
wangguibao 已提交
26 27
#include "paddle/fluid/framework/data_feed.pb.h"
#include "paddle/fluid/framework/executor_thread_worker.h"
W
wangguibao 已提交
28 29 30 31 32 33 34
#include "paddle/fluid/framework/datafeed_creator.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
namespace framework {
W
wangguibao 已提交
35
class AsyncExecutor {
W
wangguibao 已提交
36
 public:
W
wangguibao 已提交
37
  explicit AsyncExecutor(Scope& scope, const platform::Place& place);   // NOLINT
W
wangguibao 已提交
38
  virtual ~AsyncExecutor() {}
W
wangguibao 已提交
39
  static std::unique_ptr<ProgramDesc> LoadDescFromFile(
W
wangguibao 已提交
40
                                          const std::string& filename);
W
wangguibao 已提交
41
  Scope* GetRootScope() { return &root_scope_; }
W
wangguibao 已提交
42

W
wangguibao 已提交
43 44
  void SetModelPath(const std::string& model_path) {
    model_path_ = model_path;
W
wangguibao 已提交
45 46
  }

W
wangguibao 已提交
47 48 49 50 51 52
  void SetInitProgFile(const std::string& init_prog_file) {
    init_prog_file_ = init_prog_file;
  }

  void SetInitModelFile(const std::string& init_model_file) {
    init_model_file_ = init_model_file;
W
wangguibao 已提交
53 54
  }

W
wangguibao 已提交
55
  void SetModelPrefix(const std::string& model_prefix);
56
  void RunStartupProgram(const ProgramDesc& program, Scope* scope);
W
wangguibao 已提交
57
  std::vector<float> RunFromFile(const ProgramDesc& main_program,
58
                                  const std::string& data_feed_desc_str,
W
wangguibao 已提交
59 60 61
                                  const std::vector<std::string>& filelist,
                                  const int thread_num,
                                  const std::vector<std::string>& fetch_names);
W
wangguibao 已提交
62

W
wangguibao 已提交
63
  void CheckFiles(const std::vector<std::string>& files);
W
wangguibao 已提交
64 65
  void LoadInitModel();

66
 private:
W
wangguibao 已提交
67 68 69 70 71 72 73
  void CreateThreads(ExecutorThreadWorker* worker,
                     const ProgramDesc& main_program,
                     const std::shared_ptr<DataFeed>& reader,
                     const std::vector<std::string>& fetch_var_names,
                     Scope& root_scope,   // NOLINT
                     const int thread_index);

74

W
wangguibao 已提交
75
 public:
W
wangguibao 已提交
76
  std::string model_prefix_;
W
wangguibao 已提交
77 78 79
  std::string model_path_;
  std::string init_prog_file_;
  std::string init_model_file_;
W
wangguibao 已提交
80
  Scope& root_scope_;
W
wangguibao 已提交
81
  platform::Place place_;
W
wangguibao 已提交
82 83 84 85 86 87
};

}  // namespace framework
}  // namespace paddle
#endif  // PADDLE_FLUID_FRAMEWORK_ASYNC_EXECUTOR_H_
/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */