ngraph_engine.h 3.7 KB
Newer Older
B
baojun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

B
baojun 已提交
15 16
#pragma once

17 18
#include <memory>
#include <set>
B
baojun 已提交
19 20
#include <string>
#include <unordered_map>
21
#include <unordered_set>
B
baojun 已提交
22 23 24 25
#include <vector>

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
26
#include "paddle/fluid/framework/var_desc.h"
B
baojun 已提交
27 28 29 30 31 32

#include "ngraph/ngraph.hpp"

namespace paddle {
namespace operators {

33 34
// cache engine repetitives
struct EngineCache {
35
  std::shared_ptr<ngraph::runtime::Executable> ngraph_handle;
36 37 38 39 40 41 42
  std::set<std::string> persistables;
  std::vector<std::string> var_in;
  std::vector<std::string> var_out;
  std::vector<size_t> var_in_updates;
  bool is_test = true;
};

B
baojun 已提交
43 44 45 46 47
// perform graph build through bridge and execute computation
class NgraphEngine {
 public:
  explicit NgraphEngine(const framework::Scope& scope,
                        const platform::Place& place,
48
                        const framework::ExecutionContext& ctx);
B
baojun 已提交
49 50 51

  void Run(const framework::Scope& scope, const platform::Place& place) const;

52
  static bool is_training;
53 54 55 56 57 58
  static const framework::BlockDesc* p_bdesc;
  static std::vector<std::string> feed_vars, fetch_vars;

  static void FuseNgraphOps(
      const framework::BlockDesc& prog,
      std::vector<std::unique_ptr<framework::OperatorBase>>* ops);
B
baojun 已提交
59 60

 private:
61 62 63 64 65 66
  static std::unordered_map<std::string, EngineCache> engine_cache;
  static std::unordered_map<
      std::string, std::vector<std::shared_ptr<ngraph::runtime::Tensor>>>
      t_in_cache_;
  static framework::Variable* pre_var_ptr;

B
baojun 已提交
67 68 69 70
  const framework::Scope& scope_;
  const platform::Place& place_;
  std::vector<std::shared_ptr<framework::OperatorBase>> fused_ops_;
  std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
71
  std::set<std::string> persistables_;
B
baojun 已提交
72
  std::unordered_set<std::string> post_op_inputs_;
73
  bool is_test_{true};
B
baojun 已提交
74 75 76 77 78 79 80 81
  std::string func_cache_key_;

  // ngraph backend eg. CPU
  static std::shared_ptr<ngraph::runtime::Backend> backend_;
  // var_name of inputs
  std::vector<std::string> var_in_;
  // var_name of outputs from  fetch in order
  std::vector<std::string> var_out_;
82 83
  // non-persitable var_in
  std::vector<size_t> var_in_updates_;
B
baojun 已提交
84 85 86 87 88 89 90 91
  // map input vars to nodes
  std::shared_ptr<
      std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
      var_in_node_map_;
  // map each var name with a ngraph node
  std::shared_ptr<
      std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
      var_node_map_;
92
  // prepare info for ngraph engine need
93
  void Prepare(const framework::ExecutionContext& ctx);
94 95 96
  // get ngraph engine input and output list
  void BuildNgIO(const std::vector<framework::OpDesc*>& op_descs,
                 const std::vector<int>& interval);
B
baojun 已提交
97
  // get ngraph input and define ngraph input parameters
98
  void GetNgInputShape();
B
baojun 已提交
99 100 101
  // Call ngraph bridge to map ops
  void BuildNgNodes();
  // build ngraph function call
102 103 104 105
  std::shared_ptr<ngraph::Function> BuildNgFunction(
      const framework::ExecutionContext& ctx);
  // clear ngraph engine cache and t_in cache
  void ClearNgCache();
B
baojun 已提交
106
  // Check cache for ngraph function or otherwise build the function
107
  void GetNgFunction(const framework::ExecutionContext& ctx);
B
baojun 已提交
108 109 110 111
};

}  // namespace operators
}  // namespace paddle