transop_without_reshape_fusion_pass.h 6.2 KB
Newer Older
L
lujiale 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef GE_GRAPH_PASSES_TRANSOP_WITHOUT_RESHAPE_FUSION_PASS_H_
#define GE_GRAPH_PASSES_TRANSOP_WITHOUT_RESHAPE_FUSION_PASS_H_

#include <utility>
#include <vector>

#include "inc/graph_pass.h"

namespace ge {
///
/// Transform operators depth fusion
///
class TransOpWithoutReshapeFusionPass : public GraphPass {
 public:
  TransOpWithoutReshapeFusionPass() {}
  virtual ~TransOpWithoutReshapeFusionPass() {}

  graphStatus Run(ge::ComputeGraphPtr graph) override;

 private:
  void SetRemainNode(const vector<pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_anchor);
  bool FormatContinuousCheck(const OutDataAnchorPtr &out_anchor, const InDataAnchorPtr &in_anchor);
  void RemoveNousedNodes(const ComputeGraphPtr &graph);
  void GetBeginOutDescAndEndInDesc(const int index, GeTensorDesc &out_desc, GeTensorDesc &in_desc);

  void GetFormatTransferDesc(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc,
                             GeTensorDesc &format_transfer_input, GeTensorDesc &format_transfer_output);

  void GetCastOpDesc(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc, GeTensorDesc &cast_input,
                     GeTensorDesc &cast_output);

  graphStatus FormatFusion(const int index, OpDescPtr &format_transfer_op, int32_t &fusion_op_count,
                           bool &fusion_continue);

  graphStatus DataTypeFusion(const int index, OpDescPtr &cast_op, int32_t &fusion_op_count);

  void GetOutDataPeerInControlAnchors(const size_t index,
                                      vector<vector<InControlAnchorPtr>> &out_data_peer_in_control_anchors);

  void GetInControlPeerOutControlAnchors(const size_t index,
                                         vector<vector<OutControlAnchorPtr>> &in_control_peer_out_control_anchors);

  void GetOutControlPeerAnchors(const size_t index,
                                vector<vector<InControlAnchorPtr>> &out_control_peer_in_control_anchors,
                                vector<vector<InDataAnchorPtr>> &out_control_peer_in_data_anchors);

  graphStatus TransOpFuse(const ComputeGraphPtr &graph);

  bool OpAccuracyAbilityCheck(const OpDescPtr &op_desc);

  graphStatus GetSubGraphsBetweenNormalNode(
      const OutDataAnchorPtr &out_anchor, vector<vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>>> &sub_graphs_out,
      vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_list);

  graphStatus GetSubGraphNodesInfo();

  void GetControlAnchors();

  graphStatus InsertNewTransOp(const ComputeGraphPtr &graph, const OpDescPtr &cast_op,
                               const OpDescPtr &format_transfer_op, const int index, const bool insert_cast_first);

  void EraseInvalidAnchorsPair();

  graphStatus RelinkNodesWhenDescNotChanged(const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
                                            const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair,
                                            const int index);

  OpDescPtr GetFormatTransferOp(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc);

  OpDescPtr GetCastOp(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc);

  graphStatus TransOpFuseHandle(const ge::ComputeGraphPtr &graph, const int index);

  graphStatus AddTransNode(const ComputeGraphPtr &graph, const OpDescPtr &transop, NodePtr &trans_node);

  bool DescEqualCheck(ConstGeTensorDescPtr &desc_src, ConstGeTensorDescPtr &desc_dst) const;

  bool ShapeEqualCheck(const GeShape &src, const GeShape &dst) const;

  bool InsertCastFirstCheck(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc) const;

  graphStatus RelinkControlEdge(const int index, const OutDataAnchorPtr &out_anchor,
                                const vector<NodePtr> &new_trans_nodes);

  graphStatus GetTransNode(const ComputeGraphPtr &graph, const OpDescPtr &cast_op, const OpDescPtr &format_transfer_op,
                           const bool insert_cast_first, std::vector<NodePtr> &new_trans_nodes);

  void UpdateOutputName(const OutDataAnchorPtr &out_anchor, const InDataAnchorPtr &old_peer_in_anchor,
                        const NodePtr &in_owner_node);
  void UpdateInputName(const OutDataAnchorPtr &old_peer_out_anchor, const InDataAnchorPtr &in_anchor,
                       const NodePtr &out_owner_node);

  graphStatus RelinkControlEdgesWhenDescNotChanged(const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
                                                   const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair,
                                                   const int index);

  graphStatus RelinkSubGraphControlEdges(const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
                                         const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair,
                                         const int index);
  ///
  /// judge whether an operator is a transform op or not
  /// @param node
  /// @return True or False
  ///
  static bool IsTransOp(const NodePtr &node);

  static bool FusionFormatSupport(Format format);

  vector<vector<pair<OutDataAnchorPtr, InDataAnchorPtr>>> sub_graph_anchors_;
  vector<vector<NodePtr>> sub_graph_nodes_;
  vector<int> transop_num_count_;
  vector<bool> sub_graph_has_reshape_node_;
  vector<vector<OutControlAnchorPtr>> in_control_peer_out_control_anchors_;
  vector<vector<InControlAnchorPtr>> out_control_peer_in_control_anchors_;
  vector<vector<InDataAnchorPtr>> out_control_peer_in_data_anchors_;
  vector<vector<InControlAnchorPtr>> out_data_peer_in_control_anchors_;
  vector<bool> sub_graph_has_control_edge_;
  vector<bool> sub_graph_has_out_data_peer_in_control_edge_;
};
}  // namespace ge

#endif  // GE_GRAPH_PASSES_TRANSOP_WITHOUT_RESHAPE_FUSION_PASS_H_