// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <queue>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "paddle/cinn/hlir/framework/op_lowering.h"

namespace cinn {
namespace hlir {
namespace framework {

std::vector<NodeData*> GetInputNodeData(const Node* node);

27 28 29 30 31 32 33 34
ir::Tensor GetTensor(
    const NodeData* node_data,
    const absl::flat_hash_map<std::string, Type>& type_dict,
    const absl::flat_hash_map<std::string, shape_t>& shape_dict);

std::vector<ir::Tensor> CollectInputTensor(
    const Node* node,
    const absl::flat_hash_map<std::string, Type>& type_dict,
35 36 37
    const absl::flat_hash_map<std::string, shape_t>& shape_dict,
    std::vector<ir::Tensor>* func_args,
    std::unordered_map<std::string, ir::Tensor>* tensor_map);
38 39 40 41

std::unordered_map<Node*, Node*> BuildVirtualConsumer(
    const GroupPtr& group,
    const absl::flat_hash_map<std::string, shape_t>& shape_dict);
42 43 44 45 46 47 48 49 50

NodeData* GetNodeData(const Node* node);

std::vector<NodeData*> GetAllNodeData(const Node* node);

std::vector<Node*> GetConsumers(const Node* node);

bool IsConstOp(const framework::Node* node);

51 52
std::vector<Node*> GetConsumersInSet(const Node* node,
                                     const std::unordered_set<Node*>& node_set);
53

54 55 56
std::vector<Node*> TopologicalOrder(
    const GroupPtr& group,
    const std::unordered_map<Node*, Node*>& virtual_consumers);
57

58 59 60 61
std::vector<Node*> BFSTopologicalOrderWithPriority(
    const GroupPtr& group,
    const std::unordered_map<Node*, Node*>& virtual_consumers,
    const absl::flat_hash_map<std::string, shape_t>& shape_dict);
62 63 64

Node* FindGlobalReducer(const std::vector<Node*>& nodes_in_order);

65 66
Node* FindNearestReducer(const Node* node,
                         const std::unordered_set<Node*>& nodes_set);
67 68 69 70 71 72 73 74 75

bool CanbeInline(Node* node,
                 const std::vector<Node*> consumers,
                 const Node* reducer,
                 const std::unordered_set<Node*> masters,
                 const GroupPtr& group,
                 const std::unordered_set<Node*>& nodes_set,
                 const absl::flat_hash_map<std::string, shape_t>& shape_dict);

76 77 78 79 80 81 82 83 84 85 86 87 88 89
Node* GetMasterToComputeAt(
    Node* node,
    const std::vector<Node*>& nodes_in_order,
    const std::unordered_set<Node*>& nodes_inline,
    const std::unordered_set<Node*>& nodes_set,
    const std::unordered_map<Node*, Node*>& virtual_consumers,
    const absl::flat_hash_map<std::string, shape_t>& shape_dict);

std::unordered_set<Node*> GetMasters(
    Node* node,
    const std::unordered_set<Node*>& nodes_inline,
    const std::unordered_set<Node*>& nodes_set);

void LoopAssignReduce(
90
    ir::IRSchedule& ir_sch,  // NOLINT
91 92 93 94 95 96 97
    const Node* node,
    const Node* reducer,
    const Target& target,
    const std::unordered_map<std::string, ir::Tensor>& tensor_map,
    const absl::flat_hash_map<std::string, shape_t>& shape_dict);

void LoopComputeAt(
98
    ir::IRSchedule& ir_sch,  // NOLINT
99 100 101 102 103 104 105
    Node* node,
    const Node* master,
    const GroupPtr& group,
    const absl::flat_hash_map<std::string, shape_t>& shape_dict,
    const std::unordered_map<std::string, ir::Tensor>& tensor_map);

void SyncThreadWithShared(
106
    ir::IRSchedule& ir_sch,  // NOLINT
107 108 109 110 111
    const GroupPtr& group,
    const std::unordered_set<Node*>& nodes_inline,
    const std::unordered_set<Node*>& nodes_set,
    const absl::flat_hash_map<std::string, shape_t>& shape_dict,
    const std::unordered_map<std::string, ir::Tensor>& tensor_map);
112 113 114 115

}  // namespace framework
}  // namespace hlir
}  // namespace cinn