ngraph_engine.h 4.2 KB
Newer Older
B
baojun 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

B
baojun 已提交
15 16
#pragma once

17 18
#include <memory>
#include <set>
B
baojun 已提交
19 20
#include <string>
#include <unordered_map>
21
#include <unordered_set>
B
baojun 已提交
22 23 24 25
#include <vector>

#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
26
#include "paddle/fluid/framework/var_desc.h"
B
baojun 已提交
27 28 29 30 31 32 33 34 35 36 37 38 39 40

#include "ngraph/ngraph.hpp"

namespace paddle {
namespace operators {

// How much of an op list nGraph supports, tracked separately for training
// and for test programs.
enum class OpState {
  // Full list of ops is supported for training.
  FULL_TRAIN,
  // Only part of the ops is supported for training.
  PARTIAL_TRAIN,
  // Full list of ops is supported for test.
  FULL_TEST,
  // Only part of the op list is supported for test.
  PARTIAL_TEST,
  // Output all ops; intended for debugging.
  UNKNOWN
};

41 42 43 44 45 46 47 48 49 50
// cache engine repetitives
struct EngineCache {
  std::shared_ptr<ngraph::Function> ngraph_function;
  std::set<std::string> persistables;
  std::vector<std::string> var_in;
  std::vector<std::string> var_out;
  std::vector<size_t> var_in_updates;
  bool is_test = true;
};

B
baojun 已提交
51 52 53 54 55
// perform graph build through bridge and execute computation
class NgraphEngine {
 public:
  explicit NgraphEngine(const framework::Scope& scope,
                        const platform::Place& place,
56
                        const framework::ExecutionContext& ctx);
B
baojun 已提交
57 58 59

  void Run(const framework::Scope& scope, const platform::Place& place) const;

60
  static bool is_training;
61 62 63 64 65 66
  static const framework::BlockDesc* p_bdesc;
  static std::vector<std::string> feed_vars, fetch_vars;

  static void FuseNgraphOps(
      const framework::BlockDesc& prog,
      std::vector<std::unique_ptr<framework::OperatorBase>>* ops);
B
baojun 已提交
67 68

 private:
69 70 71 72 73 74
  static std::unordered_map<std::string, EngineCache> engine_cache;
  static std::unordered_map<
      std::string, std::vector<std::shared_ptr<ngraph::runtime::Tensor>>>
      t_in_cache_;
  static framework::Variable* pre_var_ptr;

B
baojun 已提交
75 76 77 78
  const framework::Scope& scope_;
  const platform::Place& place_;
  std::vector<std::shared_ptr<framework::OperatorBase>> fused_ops_;
  std::unordered_map<std::string, ngraph::element::Type> var_type_map_;
79
  std::set<std::string> persistables_;
B
baojun 已提交
80
  std::unordered_set<std::string> post_op_inputs_;
81 82
  OpState op_state_ = OpState::UNKNOWN;
  bool is_test_{true};
B
baojun 已提交
83 84 85 86 87 88 89 90 91 92
  std::string func_cache_key_;

  // ngraph backend eg. CPU
  static std::shared_ptr<ngraph::runtime::Backend> backend_;
  // ngraph function to call and execute
  std::shared_ptr<ngraph::Function> ngraph_function_;
  // var_name of inputs
  std::vector<std::string> var_in_;
  // var_name of outputs from  fetch in order
  std::vector<std::string> var_out_;
93 94
  // non-persitable var_in
  std::vector<size_t> var_in_updates_;
B
baojun 已提交
95 96 97 98 99 100 101 102
  // map input vars to nodes
  std::shared_ptr<
      std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
      var_in_node_map_;
  // map each var name with a ngraph node
  std::shared_ptr<
      std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
      var_node_map_;
103
  // prepare info for ngraph engine need
104
  void Prepare(const framework::ExecutionContext& ctx);
105 106 107
  // get ngraph engine input and output list
  void BuildNgIO(const std::vector<framework::OpDesc*>& op_descs,
                 const std::vector<int>& interval);
B
baojun 已提交
108
  // get ngraph input and define ngraph input parameters
109
  void GetNgInputShape();
B
baojun 已提交
110 111 112
  // Call ngraph bridge to map ops
  void BuildNgNodes();
  // build ngraph function call
113
  void BuildNgFunction(const framework::ExecutionContext& ctx);
B
baojun 已提交
114
  // Check cache for ngraph function or otherwise build the function
115
  void GetNgFunction(const framework::ExecutionContext& ctx);
B
baojun 已提交
116 117 118 119
};

}  // namespace operators
}  // namespace paddle