/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <list>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"

#include "ngraph/ngraph.hpp"

namespace paddle {
namespace operators {

// Cached state for a compiled nGraph function, reused across runs.
struct EngineCache {
  std::shared_ptr<ngraph::runtime::Executable> ngraph_handle;
  std::set<std::string> persistables;    // names of persistable variables
  std::vector<std::string> var_in;       // input variable names
  std::vector<std::string> var_out;      // output variable names, in fetch order
  std::vector<size_t> var_in_updates;    // indices of non-persistable inputs in var_in
  bool is_test = true;
};
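
// Illustrative use of a cached entry (a sketch only; t_in/t_out are
// hypothetical names here, and the real lookup/run logic lives in the
// engine implementation):
//
//   EngineCache& entry = /* fetched by function cache key */;
//   std::vector<std::shared_ptr<ngraph::runtime::Tensor>> t_in, t_out;
//   entry.ngraph_handle->call(t_out, t_in);  // run the compiled function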

template <class T, class Engine, int separator = 0>
class NgraphThreadCache {
 public:
  typedef decltype(Engine::getMutex()) mutex_type;
  typedef std::lock_guard<mutex_type> guard_type;
  typedef T& ref_type;
  enum class type_of_thread { unknown, forward, backward };

  template <class S>
  struct MetaInfo {
    std::thread::id owner_tid;   // owner of the cache; reserved for future use
    type_of_thread worker_type;  // reserved for future use
    S real_content;
    MetaInfo()
        : owner_tid{std::this_thread::get_id()},
          worker_type{type_of_thread::unknown} {}
  };

  typedef std::unique_ptr<MetaInfo<T>> content_type;
  typedef std::list<content_type> storage_type;

 protected:
  static storage_type l;
  static mutex_type getMutex() { return Engine::getMutex(); }
  static void remove_from_list(const T* raw_ptr) {
    guard_type guard(getMutex());
    l.remove_if([raw_ptr](const content_type& sh) {
      return &(sh->real_content) == raw_ptr;
    });
  }

  template <class TRaw>
  struct TLSDescriptor {
    TRaw* raw_ptr;
    TLSDescriptor() : raw_ptr{nullptr} {}
    ~TLSDescriptor() {
      // when the owning thread dies, drop its cache entry from the list
      NgraphThreadCache::remove_from_list(raw_ptr);

      /* TODO : Parallel executor swap */
      // FastMultiThreadCache::keep_alive_for_backward_thread(raw_ptr);
    }
  };

 public:
  NgraphThreadCache() = delete;
  NgraphThreadCache(const NgraphThreadCache& copy) = delete;

  static T& fetch() {
    thread_local TLSDescriptor<T> tls;
    if (!tls.raw_ptr) {
      using elem_type = typename content_type::element_type;
      content_type _p(new elem_type());
      if (!_p) PADDLE_THROW("Cannot allocate memory for thread cache");
      guard_type guard(getMutex());
      l.push_back(std::move(_p));
      tls.raw_ptr = &l.back()->real_content;
    }
    return *(tls.raw_ptr);
  }
  auto getSize() -> decltype(l.size()) {
    guard_type guard(getMutex());
    return l.size();
  }

  template <class F>
  void for_each_cache(F f) {
    guard_type guard(getMutex());
    std::for_each(l.begin(), l.end(), f);
  }
};

template <class T, class Engine, int separator>
typename NgraphThreadCache<T, Engine, separator>::storage_type
    NgraphThreadCache<T, Engine, separator>::l;
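
// Usage sketch for the thread-local cache (illustrative only; ValueT is a
// hypothetical payload type). Each thread lazily creates its own T on the
// first fetch(); later fetches on that thread return the same object
// without taking the lock:
//
//   using Cache = NgraphThreadCache<std::unordered_map<std::string, ValueT>,
//                                   NgraphEngine>;
//   auto& my_map = Cache::fetch();  // created on demand for this thread
//   my_map["key"] = ValueT{};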

// perform graph building through the ngraph bridge and execute the computation
class NgraphEngine {
 public:
  explicit NgraphEngine(const framework::Scope& scope,
                        const platform::Place& place,
                        const framework::ExecutionContext& ctx);

  void Run(const framework::Scope& scope, const platform::Place& place) const;

  static bool is_training;
  static const framework::BlockDesc* p_bdesc;
  static std::vector<std::string> feed_vars, fetch_vars;

  static void FuseNgraphOps(
      const framework::BlockDesc& prog,
      std::vector<std::unique_ptr<framework::OperatorBase>>* ops);

  static std::recursive_mutex& getMutex() {
    static std::recursive_mutex mx;
    return mx;
  }

 private:
  template <class T>
  using ThCache =
      NgraphThreadCache<std::unordered_map<std::string, T>, NgraphEngine>;

  using main_engine_cache = ThCache<EngineCache>;
  using main_t_in_cache =
      ThCache<std::vector<std::shared_ptr<ngraph::runtime::Tensor>>>;
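
  // Per-thread lookup sketch (hedged; the concrete logic lives in the
  // implementation file):
  //   auto& engine_cache = main_engine_cache::fetch();  // this thread's map
  //   auto it = engine_cache.find(func_cache_key_);     // keyed per function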

  static framework::Variable* pre_var_ptr;

  const framework::Scope& scope_;
  const platform::Place& place_;
  std::vector<std::shared_ptr<framework::OperatorBase>> fused_ops_;
  std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
  std::set<std::string> persistables_;
  std::unordered_set<std::string> post_op_inputs_;
  bool is_test_{true};
  std::string func_cache_key_;

  // ngraph backend, e.g. CPU
  static std::shared_ptr<ngraph::runtime::Backend> backend_;
  // var_names of inputs
  std::vector<std::string> var_in_;
  // var_names of outputs, in fetch order
  std::vector<std::string> var_out_;
  // indices of non-persistable vars within var_in_
  std::vector<size_t> var_in_updates_;
  // map input vars to nodes
  std::shared_ptr<
      std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
      var_in_node_map_;
  // map each var name to an ngraph node
  std::shared_ptr<
      std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
      var_node_map_;
  // prepare information the ngraph engine needs
  void Prepare(const framework::ExecutionContext& ctx);
  // get the ngraph engine input and output lists
  void BuildNgIO(const std::vector<framework::OpDesc*>& op_descs,
                 const std::vector<int>& interval);
  // get ngraph input shapes and define ngraph input parameters
  void GetNgInputShape();
  // Call ngraph bridge to map ops
  void BuildNgNodes();
  // build the ngraph function
  std::shared_ptr<ngraph::Function> BuildNgFunction(
      const framework::ExecutionContext& ctx);
  // clear ngraph engine cache and t_in cache
  void ClearNgCache();
  // Check cache for ngraph function or otherwise build the function
  void GetNgFunction(const framework::ExecutionContext& ctx);
};
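
// Typical driving sequence (a sketch inferred from this header; the actual
// call site, the ngraph operator, is not shown here):
//
//   NgraphEngine engine(scope, place, ctx);  // builds or fetches the function
//   engine.Run(scope, place);                // executes it on the backend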

}  // namespace operators
}  // namespace paddle