pass_utils.h 2.7 KB
Newer Older
L
lujiale 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef GE_GRAPH_PASSES_PASS_UTILS_H_
#define GE_GRAPH_PASSES_PASS_UTILS_H_

#include <cstdint>
#include <vector>

#include "framework/common/debug/ge_log.h"
#include "common/ge_inner_error_codes.h"
#include "graph/compute_graph.h"

namespace ge {
class PassUtils {
 public:
  // Static-only helper collection: construction is disabled.
  PassUtils() = delete;

  ///
  /// Get the node feeding the in-data anchor `index` of `node`.
  /// NOTE(review): the result for an out-of-range index or an unconnected anchor
  /// (presumably nullptr) is decided by the implementation, not this header — confirm.
  /// @param node  node whose input is queried
  /// @param index in-data anchor index
  /// @return presumably the peer (source) node of that anchor
  ///
  static NodePtr GetInDataNode(const ConstNodePtr &node, int index);

  ///
  /// Check whether `node` is a constant.
  /// NOTE(review): which op types qualify as "constant" is defined by the implementation.
  /// @param node node to inspect
  /// @return true if the node is considered constant, false otherwise
  ///
  static bool IsConstant(const ConstNodePtr &node);

  ///
  /// Set the weight of the node reached through `out_data_anchor`, based on `src_node`.
  /// NOTE(review): description inferred from the name; confirm the exact semantics
  /// against the implementation.
  /// @param out_data_anchor output anchor identifying the target
  /// @param src_node        node providing the weight
  /// @return a Status code
  ///
  static Status SetOutNodeWeight(const OutDataAnchorPtr &out_data_anchor, const NodePtr &src_node);

  ///
  /// Remove the branch starting at `node`.
  /// @param node         first node of the branch
  /// @param delete_nodes in/out: presumably collects the nodes to be deleted — confirm
  /// @param end_nodes    in/out: presumably collects the nodes where the walk stopped — confirm
  /// @return a Status code
  ///
  static Status RemoveBranch(const NodePtr &node, std::vector<NodePtr> &delete_nodes, std::vector<NodePtr> &end_nodes);

  ///
  /// Remove the inactive branch hanging off `inactive_output_anchor`, up to a Merge node
  /// (inferred from the name — confirm against the implementation).
  /// @param inactive_output_anchor output anchor of the inactive branch
  /// @param delete_nodes           in/out: presumably collects the nodes to be deleted
  /// @param end_nodes              in/out: presumably collects the nodes where the walk stopped
  /// @return a Status code
  ///
  static Status RemoveInactiveBranchToMerge(const OutDataAnchorPtr &inactive_output_anchor,
      std::vector<NodePtr> &delete_nodes, std::vector<NodePtr> &end_nodes);

  ///
  /// Check whether iteration flow control is needed for training.
  /// @param compute_graph graph to check
  /// @return true:  iteration flow control is needed
  ///         false: it is not needed
  ///
  static bool IsNeedTrainIteFlowCtrl(const ComputeGraphPtr &compute_graph);

  ///
  /// Construct a tensor from `out_desc` and `data` and append it to `v_output`.
  /// The resulting shape is a list; if the data length is 1, its shape is [].
  /// NOTE(review): how `scalar_output` interacts with that rule is decided by the
  /// implementation — confirm.
  /// (Top-level `const` on the by-value `scalar_output` was dropped from these
  /// declarations: it is not part of the function type and has no effect here.)
  /// @param out_desc      descriptor for the output tensor
  /// @param data          values to store
  /// @param v_output      output: receives the constructed tensor(s)
  /// @param scalar_output request scalar output; defaults to false
  /// @return a Status code
  ///
  static Status ConstructTensorDescWithData(const GeTensorDesc &out_desc, std::vector<int64_t> &data,
                                            std::vector<GeTensorPtr> &v_output, bool scalar_output = false);

  ///
  /// Raw-buffer overload of ConstructTensorDescWithData for element type T.
  /// NOTE(review): this header declares the template but has no body. Callers in other
  /// translation units therefore rely on the definition being visible or explicitly
  /// instantiated. `len` is presumably an element count (not bytes) — confirm.
  /// @param out_desc      descriptor for the output tensor
  /// @param buf           pointer to the source values
  /// @param len           length of `buf`
  /// @param v_output      output: receives the constructed tensor(s)
  /// @param scalar_output request scalar output; defaults to false
  /// @return a Status code
  ///
  template <typename T>
  static Status ConstructTensorDescWithData(const GeTensorDesc &out_desc, T *buf, uint32_t len,
                                            std::vector<GeTensorPtr> &v_output, bool scalar_output = false);

  ///
  /// Find the in-data anchor index that has a valid peer out node.
  /// NOTE(review): the value returned when no such anchor exists (or when several do)
  /// is decided by the implementation — check before relying on it.
  /// @param node_ptr node to inspect
  /// @return the in-data anchor index
  ///
  static int GetUniqueInDataAnchorIndex(const NodePtr &node_ptr);

  ///
  /// Unlink the father node attached to the node's in-data anchor [index] from the node itself.
  /// Then link all of the father node's in-control nodes to the node, for any that are not
  /// already connected.
  /// @param node  node to detach from its father (taken by non-const reference)
  /// @param index in-data anchor index
  /// @return a Status code
  ///
  static Status UnlinkNodeWithControlCopy(NodePtr &node, int index);
};
}  // namespace ge

#endif  // GE_GRAPH_PASSES_PASS_UTILS_H_