// File: cinn_graph_symbolization.h
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"

#include "cinn/frontend/net_builder.h"
#include "cinn/frontend/op_mapper_registry.h"

namespace paddle {
namespace framework {
namespace paddle2cinn {

// An executor accept subgraph which is generated by BuildCinnPass,
// run each op's CINN Op Mapper, finally return a frontend::Program object
// corresponding to the subgraph.
//
// Parameter:
// 1. graph_id:
//       the unique graph id, used for generating unique NetBuilder name.
// 2. graph:
//       the CINN subgraph whose op are all supported by CINN, and the
//       graph is independently of other graph.
// 3. input_tensors:
//       all input var nodes of CINN subgraph, they are necessary for
//       we need pass the shape and data type into CINN, otherwise the
//       NetBuilder may error for the shape not meet the precondition.
//
// Describe:
// The main function is operator(), it will run all op function by CINN
// OpMapper and finally return a program object.
// The executor operator() consisted by the following step:
// 1. create a NetBuilder, it's name is unique for each graph;
// 2. create OpMapperContext, contain scope, target, local var_map and
//    local var_model_to_program_map;
// 3. add all feed var into OpMapperContext to pass the shape and type
//    into CINN;
// 4. topological sorting graph op nodes;
// 5. transform all op from paddle opdesc format to cinn opdesc format;
// 5. run the CINN op in graph one by one. Note that the graph have been
//    topo sorted;
// 6. return the NetBuilder.Build() after all op run.
class CinnGraphSymbolization {
 public:
  CinnGraphSymbolization(
      int64_t graph_id, const ir::Graph& graph,
      const ::cinn::common::Target& target,
      const std::map<std::string, const LoDTensor*>& input_tensors)
      : graph_id_(graph_id),
        graph_(graph),
        target_(target),
        input_tensors_(input_tensors) {}

  // run all CINN op in graph by topo sorting then return its NetBuilder
  ::cinn::frontend::Program operator()();

  // return the internal variable map
  const std::unordered_map<std::string, ::cinn::frontend::Variable>& var_map()
      const {
    return var_map_;
  }

  // return the map from the variable name in paddle model to cinn program.
  const std::unordered_map<std::string, std::string>& var_model_to_program_map()
      const {
    return var_model_to_program_map_;
  }

89 90 91
  // get fetch var ids used in CINN
  std::unordered_set<std::string> GetFetchIds() const;

J
jiangcheng 已提交
92 93 94 95 96 97 98 99 100 101 102
  using OpMapperContext = ::cinn::frontend::OpMapperContext;
  using FeedInfoMap =
      std::unordered_map<std::string, OpMapperContext::FeedInfo>;
  using CinnOpDesc = ::cinn::frontend::paddle::cpp::OpDesc;

 private:
  const int64_t graph_id_;
  const ir::Graph& graph_;
  const ::cinn::common::Target& target_;
  const std::map<std::string, const LoDTensor*>& input_tensors_;

103
  // preserve cinn variable map
J
jiangcheng 已提交
104 105 106
  std::unordered_map<std::string, ::cinn::frontend::Variable> var_map_;
  std::unordered_map<std::string, std::string> var_model_to_program_map_;

107 108 109
  // fetch var names used in paddle
  std::unordered_set<std::string> fetch_var_names_;

J
jiangcheng 已提交
110 111 112
  // transform all paddle var desc in feed list into cinn_var_descs_
  FeedInfoMap GetFeedInfoMapFromInput() const;

113 114 115
  // get the topological sort of the graph_
  std::vector<ir::Node*> TopologicalSort() const;

J
jiangcheng 已提交
116 117 118 119 120 121 122 123 124 125 126 127 128
  // transform all paddle op desc in graph into cinn op desc
  std::vector<std::unique_ptr<CinnOpDesc>> TransformAllGraphOpToCinn() const;

  // RunOp accept OpDesc and global run context then run
  // it's kernel registered in OpMapper.
  // called in RunGraph.
  void RunOp(const CinnOpDesc& op_desc, const OpMapperContext& ctx) const;

  // preserve var desc, run the op one by one.
  void RunGraph(const OpMapperContext& ctx) const;

  // create cinn scope and add parameter's feed info into scope
  std::shared_ptr<::cinn::hlir::framework::Scope> CreateCinnScope(
129
      const FeedInfoMap& feed_map);
J
jiangcheng 已提交
130 131 132 133 134 135 136 137 138 139

  // get the graph op's input persistable var name set
  std::unordered_set<std::string> GetGraphInputParameterNames() const;

  friend class CinnGraphSymbolizationForTest;
};

}  // namespace paddle2cinn
}  // namespace framework
}  // namespace paddle