buffer_assign.cc 5.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/optim/buffer_assign.h"

#include "paddle/cinn/common/union_find.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/lang/lower_impl.h"
#include "paddle/cinn/optim/ir_replace.h"

namespace cinn {
namespace optim {

namespace {

struct BufferUFNode : public common::UnionFindNode {
  BufferUFNode(const std::string& x) : tensor_name(x) {}

  const char* type_info() const override { return __type_info__; }

  std::string tensor_name;
  static const char* __type_info__;
};

const char* BufferUFNode::__type_info__ = "BufferUFNode";

// IR mutator that replaces every `_Tensor_` node whose name is a key of
// `tensor_map` with the tensor stored under that key. Tensors whose names are
// not in the map are left as-is. Matched tensor nodes are replaced wholesale
// (this override does not visit their children).
struct IRReplaceTensorMutator : ir::IRMutator<> {
  // Borrowed reference: the caller must keep the map alive for as long as
  // this mutator is used.
  const std::map<std::string, ir::Tensor>& tensor_map;

  explicit IRReplaceTensorMutator(
      const std::map<std::string, ir::Tensor>& tensor_map)
      : tensor_map(tensor_map) {}

  // Runs the replacement over `*expr`, mutating it in place.
  void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

  void Visit(const ir::_Tensor_* op, Expr* expr) override {
    auto it = tensor_map.find(op->name);
    if (it != tensor_map.end()) {
      *expr = Expr(it->second);
    }
  }
};

}  // namespace

55 56 57 58 59 60 61 62
// Assigns buffers to the non-inlined tensors reachable from `expr`, letting the
// tensors listed in a stage's `meta.tensors_to_share_buffer_with` share one
// buffer.
//
// Steps:
//   1. Collect the non-inlined tensors of `all_tensor_map` (the return value).
//   2. Replace every non-inlined tensor occurrence in `*expr` with the instance
//      stored in `all_tensor_map`, so each name refers to one shared object.
//   3. Cluster tensors that share a buffer with a union-find.
//   4. Walk `comp_graph` in topological order and record, on each cluster
//      root, the name of the first tensor of that cluster encountered
//      (`cluster_info`).
//   5. For each cluster pick a center tensor, give it a buffer if it has none
//      and a non-void type, and bind every other member to that buffer while
//      preserving the center's buffer shape.
//
// @param expr              IR expression; tensor occurrences are rewritten in
//                          place.
// @param stages            stage map queried for inlining and buffer sharing.
// @param all_tensor_map    every tensor by name. The map itself is not
//                          modified, but the tensors it references get buffers
//                          created or bound.
// @param comp_graph        computation graph whose nodes are
//                          lang::detail::CompuGraphNode; must be non-null.
// @param temp_tensor_names names of temporary tensors.
// @return the non-inlined tensors, keyed by name.
std::map<std::string, ir::Tensor> InitialAssignBuffer(
    Expr* expr,
    poly::StageMap stages,
    const std::map<std::string, ir::Tensor>& all_tensor_map,
    const common::Graph* comp_graph,
    const std::set<std::string>& temp_tensor_names) {
  CHECK(comp_graph);

  // The tensor map helps to reserve only one tensor instance for a
  // tensor (called the same name).
  std::map<std::string, ir::Tensor> buffer_updated_tensor;

  for (auto& item : all_tensor_map) {
    if (stages[item.second]->inlined()) continue;
    buffer_updated_tensor[item.second->name] = item.second;
  }

  // union-find to cluster the tensors with the same buffer.
  common::UnionFind union_find;

  // unify all the tensor occurrences with a global one, e.g. there are multiple
  // tensor B exists in the expression, replace them with a shared one.
  ir::CollectIRNodes(*expr, [&](const Expr* x) -> bool {
    auto* t = x->as_tensor();
    if (t && !stages[t]->inlined()) {
      Reference(x) = Expr(all_tensor_map.at(t->name));
    }
    return false;
  });

  // One union-find node per tensor (inlined ones included), keyed by name.
  std::map<std::string, BufferUFNode*> uf_map;
  for (auto& item : all_tensor_map) {
    auto* n = union_find.AddNode(new BufferUFNode(item.second->name));
    uf_map[item.second->name] = n->safe_as<BufferUFNode>();
  }

  for (auto& item : buffer_updated_tensor) {
    // Every key of buffer_updated_tensor came from all_tensor_map, so a node
    // exists for it.
    auto* cur_n = uf_map.at(item.first);
    for (auto& other : stages[item.second]->meta.tensors_to_share_buffer_with) {
      // we might initialize the buffer in args, so `other` may have no node.
      // Use find() rather than operator[] so no null entry is inserted into
      // uf_map (a null entry would be dereferenced in the topological pass).
      auto other_it = uf_map.find(other);
      if (other_it == uf_map.end() || !other_it->second) continue;
      auto* other_n = other_it->second;

      VLOG(3) << "share buffer between " << item.first << " "
              << other_n->tensor_name;
      cur_n->Union(other_n);
    }
  }

  // determine which tensor to have the initial buffer, and will share across
  // the cluster, we take a topological order of the computational graph, and
  // find out which tensor comes first in a cluster.
  auto topo_result = comp_graph->topological_order();
  auto& topo_order = std::get<0>(topo_result);
  for (common::GraphNode* n : topo_order) {
    auto nn = n->safe_as<lang::detail::CompuGraphNode>();
    CHECK(nn);
    auto it = uf_map.find(nn->tensor->name);
    CHECK(it != uf_map.end());
    auto& cluster_info = std::get<0>(it->second->GetRoot())->cluster_info;
    // buffer owner (a tensor) of this cluster not set yet.
    if (cluster_info.empty()) {
      cluster_info = nn->tensor->name;
    }
  }

  // Picks the tensor that owns the cluster's buffer. Returns the first member
  // (in cluster order) that is either not a temp tensor (i.e. a tensor arg) or
  // already has a buffer; if no member qualifies, returns the last member.
  auto cluster_get_center_tensor =
      [&](const std::vector<common::UnionFindNode*>& cluster) {
        ir::Tensor some_tensor;
        for (auto* n : cluster) {
          auto* node = n->safe_as<BufferUFNode>();
          bool is_temp = temp_tensor_names.count(node->tensor_name);
          if (!is_temp) return all_tensor_map.at(node->tensor_name);
          if (all_tensor_map.at(node->tensor_name)->buffer.defined()) {
            return all_tensor_map.at(node->tensor_name);
          }
          some_tensor = all_tensor_map.at(node->tensor_name);
        }
        return some_tensor;
      };

  for (auto& cluster : union_find.GetClusters()) {
    auto root_tensor = cluster_get_center_tensor(cluster);
    if (!root_tensor->buffer.defined() && !root_tensor->type().is_void()) {
      root_tensor->WithBuffer();
    }

    for (auto* n : cluster) {
      auto& tensor = all_tensor_map.at(n->safe_as<BufferUFNode>()->tensor_name);
      if (tensor != root_tensor) {
        // The center may still lack a buffer (e.g. a void-typed center); fail
        // with a diagnostic instead of dereferencing an undefined buffer.
        CHECK(root_tensor->buffer.defined())
            << "cannot share the undefined buffer of tensor "
            << root_tensor->name;
        // Bind() may alter the shared buffer's shape (presumably to the bound
        // tensor's shape — confirm in ir::_Tensor_::Bind); restore the
        // center's shape afterwards.
        auto keep_shape = root_tensor->buffer->shape;
        Reference(&tensor)->Bind(root_tensor->buffer);
        root_tensor->buffer->shape = keep_shape;
        Reference(&tensor)->buffer->shape = keep_shape;
        // A scalar buffer has an empty shape; guard the [0] access.
        if (!keep_shape.empty()) {
          VLOG(3) << "keep_shape is : " << utils::GetStreamCnt(keep_shape[0]);
        }
      }
    }
  }

  return buffer_updated_tensor;
}

}  // namespace optim
}  // namespace cinn