fusion_merge_pass.cc 38.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <cstdint>

#include "paddle/cinn/hlir/pass/fusion_merge_pass_util.h"

DECLARE_bool(enhance_vertical_fusion_with_recompute);

namespace cinn {
namespace hlir {
namespace pass {

using framework::Graph;
using framework::Node;
using framework::NodeData;
using framework::OpPatternKind;
using framework::shape_t;

using common::GraphEdge;
using common::GraphNode;

using Comparator = Graph::Group::SharedGroupComparator;
33
using Hasher = Graph::Group::SharedGroupHasher;
34

35
using GroupPtr = std::shared_ptr<Graph::Group>;
36 37
using GroupList = std::vector<GroupPtr>;

38 39
using ConditionFunction = std::function<bool(
    const FusionHelperBase*, const GroupPtr&, const GroupPtr&)>;
40 41 42 43 44 45 46

// Op Fusion Pass which performs Ops fusion, Ops are fused
// "vertically", meaning producing Ops are fused into their consumers
// with the intent that the loops which compute their values will be fused in
// code generation.
class FusionMergePassHelper : public FusionHelperBase {
 public:
47
  explicit FusionMergePassHelper(const Graph* graph) : FusionHelperBase(graph) {
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
    fusion_groups_ = graph->fusion_groups;
    // init fusion relation.
    InitFusionRelation();
    // init input to consumers.
    InitInputToConsumers();
    // init fusion group index.
    InitFusionGroupsAndIndex();
  }

  GroupList operator()() {
    // Merge groups until no phase reports a change.
    DoFusionMerge();

    // Dump every resulting group together with its relations (VLOG level 3).
    const auto log_group_ids = [](const char* prefix, const auto& groups) {
      for (const auto& member : groups) {
        VLOG(3) << prefix << member->group_id;
      }
    };
    for (const auto& group : fusion_groups_) {
      VLOG(3) << "Fusion Group -> " << group->group_id;
      log_group_ids("  Fused Sub-Group -> ", group->fused_sub_groups);
      log_group_ids("  Producer -> ", group->producer_groups());
      log_group_ids("  Consumer -> ", group->consumer_groups());
    }
    return fusion_groups_;
  }

 private:
  void DoFusionMerge() {
    VLOG(3) << "DoFusionMerge...!";
    while (DoHorizontalFusion()) {
    }
    while (DoVerticalFusion(/* recompute=*/false)) {
    }
    while (DoVerticalFusion(/* recompute=*/true)) {
    }
  }

  bool DoHorizontalFusion() {
    VLOG(3) << "DoHorizontalFusion...!";
    bool updated = false;
    for (int idx = 0; idx < fusion_groups_.size(); ++idx) {
      auto producer = fusion_groups_[idx];
      VLOG(3) << "Fusion Producer Group -> " << producer->group_id;
      // if producer is sub group.
      if (producer->belong_groups.size()) {
        continue;
      }
      // do horizontal fusion.
97
      updated |= HorizontalFusion(producer, producer->consumer_groups());
98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117
    }

    if (updated) {
      UpdateFusionGroup();
    }
    return updated;
  }

  bool DoVerticalFusion(bool recompute) {
    VLOG(3) << "DoVerticalFusion...!";
    bool updated = false;
    for (int idx = 0; idx < fusion_groups_.size(); ++idx) {
      auto producer = fusion_groups_[idx];
      VLOG(3) << "Fusion Producer Group -> " << producer->group_id;
      // if producer is sub group.
      if (producer->belong_groups.size()) {
        continue;
      }
      // do horizontal fusion.
      if (!recompute) {
118
        updated |= HorizontalFusion(producer, producer->consumer_groups());
119
      }
120 121
      updated |=
          VerticalFusion(producer, producer->consumer_groups(), recompute);
122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
    }
    // fuse input consumers
    updated |= FuseInputToConsumers();

    if (updated) {
      UpdateFusionGroup();
    }
    return updated;
  }

  void UpdateFusionGroup() {
    VLOG(3) << "UpdateFusionGroup...";
    GroupList fusion_groups;
    std::unordered_set<GroupPtr, Hasher, Comparator> fusion_groups_set;
    // update fusion_groups_
    for (auto& group : fusion_groups_) {
      if (!group->belong_groups.size()) {
        fusion_groups.push_back(group);
        fusion_groups_set.insert(group);
      }
    }
    // keep group in order
    fusion_groups_.clear();
    fusion_groups_index_.clear();
    while (!fusion_groups_set.empty()) {
      bool is_ring = true;
      for (int idx = 0; idx < fusion_groups.size(); ++idx) {
        auto& group = fusion_groups[idx];
        if (!group.get()) {
          continue;
        }

        bool exist = false;
155
        for (const auto& producer : group->producer_groups()) {
156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177
          if (fusion_groups_set.count(producer)) {
            VLOG(4) << group->group_id << " " << producer->group_id;
            exist = true;
            break;
          }
        }

        if (!exist) {
          fusion_groups_index_[group] = fusion_groups_.size();
          fusion_groups_.push_back(group);
          fusion_groups_set.erase(group);
          group.reset();
          is_ring = false;
          continue;
        }
      }
      if (is_ring) {
        LOG(FATAL) << "Exists Ring, Please Check!";
      }
    }
  }

178 179
  bool HorizontalFusion(
      GroupPtr producer,
180
      const std::unordered_set<GroupPtr, Hasher, Comparator>& consumers) {
181 182 183 184 185 186
    VLOG(3) << "HorizontalFusion...!";
    if (consumers.size() <= 1) {
      return false;
    }

    std::unordered_set<GroupPtr, Hasher, Comparator> candidates;
187
    for (const auto& consumer : consumers) {
188 189 190 191 192 193 194 195 196 197 198 199 200
      // relation
      auto& relation = fusion_relation_map_[consumer->op_pattern_kind];
      // check horizontal relation exist
      if (!relation.horizontal_relation.size()) {
        continue;
      }
      candidates.insert(consumer);
    }

    std::vector<GroupList> fusionable_consumers;
    for (auto& candidate : candidates) {
      // check dependency
      if (IsDependencySimplify(producer, candidate, candidates)) {
201 202
        VLOG(4) << "IsDependencySimplify, Can't fuse " << candidate->group_id
                << ", As it depency others!";
203 204 205 206
        continue;
      }

      if (IsDependency(producer, candidate, candidates)) {
207 208
        VLOG(4) << "IsDependency, Can't fuse " << candidate->group_id
                << ", As it depency others!";
209 210 211 212 213 214 215 216 217 218
        continue;
      }

      if (!fusionable_consumers.size()) {
        fusionable_consumers.push_back({candidate});
        continue;
      }

      // check each fusionable groups
      bool fusionable = false;
219
      auto& relation = fusion_relation_map_[candidate->op_pattern_kind];
220 221 222 223 224 225
      for (auto& groups : fusionable_consumers) {
        auto& last = groups.back();
        if (!relation.horizontal_relation.count(last->op_pattern_kind)) {
          continue;
        }

226 227
        if (!relation.horizontal_relation[last->op_pattern_kind](
                this, candidate, last)) {
228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252
          continue;
        }

        groups.push_back(candidate);
        fusionable = true;
        break;
      }

      // if can't fuse to othors Groups, new Groups.
      if (!fusionable) {
        fusionable_consumers.push_back({candidate});
      }
    }

    bool updated = false;
    for (auto& groups : fusionable_consumers) {
      if (groups.size() > 1) {
        updated = true;
        HorizontalFuse(groups);
      }
    }

    return updated;
  }

253
  void HorizontalFuse(const GroupList& consumers) {
254 255 256 257 258 259 260 261 262 263 264 265
    VLOG(3) << "HorizontalFuse Groups...";
    // create fusion group
    auto fused_group = std::make_shared<Graph::Group>();
    // As recompute exist which may case sub-group used by more than one time.
    std::vector<GroupPtr> repeat_sub_groups;
    std::unordered_set<GroupPtr, Hasher, Comparator> sub_group_set;
    // find the first consumer.
    GroupPtr first_consumer(nullptr);
    // fuse all group into fusion group.
    for (auto& consumer : consumers) {
      VLOG(3) << "fuse consumer " << consumer->group_id << " into fused_group!";
      // update depth
266 267 268 269
      fused_group->max_depth =
          std::max(fused_group->max_depth, consumer->max_depth);
      fused_group->min_depth =
          std::min(fused_group->min_depth, consumer->min_depth);
270 271 272 273 274 275 276 277
      // update group id
      if (fused_group->group_id.size()) {
        fused_group->group_id += "_" + consumer->group_id;
      } else {
        fused_group->group_id = consumer->group_id;
      }
      // set op pattern kind
      fused_group->op_pattern_kind =
278 279
          static_cast<int>(fused_group->op_pattern_kind) >=
                  static_cast<int>(consumer->op_pattern_kind)
280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327
              ? fused_group->op_pattern_kind
              : consumer->op_pattern_kind;
      // input nodes
      for (auto& node : consumer->input_nodes) {
        if (fused_group->input_nodes.count(node.first)) {
          fused_group->input_nodes[node.first] += node.second;
        } else {
          fused_group->input_nodes.insert(node);
        }
      }
      // output node
      for (auto& node : consumer->output_nodes) {
        fused_group->output_nodes.insert(node);
      }
      // internal node
      if (consumer->fused_sub_groups.size()) {
        for (auto& node : consumer->internal_nodes) {
          fused_group->internal_nodes.insert(node);
        }
      }
      // master node
      for (auto& node : consumer->master_nodes) {
        if (GetOpKind(node) == framework::kReduction) {
          fused_group->master_nodes.insert(node);
        }
      }
      // insert sub group
      if (consumer->fused_sub_groups.size()) {
        for (auto& sub_group : consumer->fused_sub_groups) {
          // check sub group is repeat.
          if (sub_group_set.count(sub_group)) {
            VLOG(3) << sub_group->group_id << " is repeated!";
            repeat_sub_groups.push_back(sub_group);
            continue;
          }
          // record sub group
          sub_group_set.insert(sub_group);

          // insert to fused sub group.
          fused_group->fused_sub_groups.push_back(sub_group);
          // update belongs group
          sub_group->belong_groups.erase(consumer);
          sub_group->belong_groups.insert(fused_group);
        }
      } else {
        fused_group->fused_sub_groups.push_back(consumer);
      }
      // producer group
328 329
      for (auto& producer : *consumer->mut_producer_groups()) {
        fused_group->mut_producer_groups()->insert(producer);
330
        // update producer's consumer
331 332
        producer->mut_consumer_groups()->erase(consumer);
        producer->mut_consumer_groups()->insert(fused_group);
333 334
      }
      // consumer group
335 336
      for (auto& gconsumer : *consumer->mut_consumer_groups()) {
        fused_group->mut_consumer_groups()->insert(gconsumer);
337
        // update consumer's producer
338 339
        gconsumer->mut_producer_groups()->erase(consumer);
        gconsumer->mut_producer_groups()->insert(fused_group);
340 341 342 343 344 345
      }
      // belongs group
      consumer->belong_groups.insert(fused_group);

      // find the first consumer.
      CHECK(fusion_groups_index_.count(consumer))
346 347
          << "Can't find consumer " << consumer->group_id
          << " index in fusion_groups_index_!";
348
      if (first_consumer.get()) {
349 350
        if (fusion_groups_index_[consumer] <
            fusion_groups_index_[first_consumer]) {
351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368
          first_consumer = consumer;
        }
      } else {
        first_consumer = consumer;
      }
    }

    // if node is output nodes of sub_group, check it can't be internal node.
    for (auto& sub_group : repeat_sub_groups) {
      // check each output node in sub_group.
      for (auto& node : sub_group->output_nodes) {
        // if node is not output node of fused_group.
        if (!fused_group->output_nodes.count(node)) {
          fused_group->internal_nodes.insert(node);
        }
      }
    }

369 370
    if (static_cast<int>(framework::kReduction) >
        static_cast<int>((consumers.back())->op_pattern_kind)) {
371 372 373 374 375 376
      auto consumer = consumers.back();

      for (auto& node : consumer->master_nodes) {
        fused_group->master_nodes.insert(node);
      }
    } else {
377 378
      for (auto consumer = consumers.rbegin(); consumer != consumers.rend();
           ++consumer) {
379 380 381 382 383 384 385 386
        Node* master_node = nullptr;
        for (auto& node : (*consumer)->master_nodes) {
          if (GetOpKind(node) != framework::kReduction) {
            master_node = node;
            break;
          }
        }
        if (master_node) {
387 388
          VLOG(3) << "Insert Master node : " << master_node->id()
                  << " into group : " << fused_group->group_id;
389 390 391 392 393 394
          fused_group->master_nodes.insert(master_node);
          break;
        }
      }
    }

395 396
    auto postion = fusion_groups_index_[first_consumer];
    fusion_groups_[postion] = fused_group;
397 398
    fusion_groups_index_[fused_group] = postion;

399 400
    CHECK(fused_group->output_nodes.size())
        << "No output node is found, " << fused_group->group_id;
401 402
  }

403
  bool VerticalFusion(
404 405
      const GroupPtr& producer,
      const std::unordered_set<GroupPtr, Hasher, Comparator>& consumers,
406
      bool recompute) {
407 408 409 410 411 412 413 414 415
    VLOG(3) << "VerticalFusion, Number of Consumers : " << consumers.size();
    auto& relation = fusion_relation_map_[producer->op_pattern_kind];
    // if producer can't fuse others
    if (!relation.vertical_relation.size()) {
      return false;
    }

    std::unordered_set<GroupPtr, Hasher, Comparator> fuse_consumers_unsafe;
    std::unordered_set<GroupPtr, Hasher, Comparator> fuse_consumers;
416
    for (const auto& consumer : consumers) {
417 418
      VLOG(4) << "Check consuemr " << consumer->group_id
              << " can fuse to producer " << producer->group_id;
419 420
      // if can't fuse
      if (!relation.vertical_relation.count(consumer->op_pattern_kind)) {
421 422
        VLOG(4) << "Can't fuse producer " << producer->group_id << " consumer "
                << consumer->group_id;
423 424 425 426
        continue;
      }

      // if condition function is false
427 428 429 430
      if (!relation.vertical_relation[consumer->op_pattern_kind](
              this, producer, consumer)) {
        VLOG(4) << "Can't fuse producer " << producer->group_id << " consumer "
                << consumer->group_id;
431 432 433 434 435 436
        continue;
      }

      fuse_consumers_unsafe.insert(consumer);

      if (IsDependencySimplify(producer, consumer, consumers)) {
437 438
        VLOG(4) << "IsDependencySimplify, Consumer " << consumer->group_id
                << " can't be master fused group!";
439 440 441 442
        continue;
      }

      if (IsDependency(producer, consumer, consumers)) {
443 444
        VLOG(4) << "IsDependency, Consumer " << consumer->group_id
                << " can't be master fused group!";
445 446 447 448 449 450
        continue;
      }

      fuse_consumers.insert(consumer);
    }

451 452 453 454
    VLOG(3) << "VerticalFusion, Number of fuse Consumers : "
            << fuse_consumers.size();
    VLOG(3) << "VerticalFusion, Number of unsafe fuse Consumers : "
            << fuse_consumers.size();
455 456 457 458 459 460 461

    if (fuse_consumers.size() == 0) {
      return false;
    }
    // if can_fuse_consumers == consumers
    // if producer op kind == kElementwise
    // if use recompute
462
    if (fuse_consumers_unsafe.size() == producer->consumer_groups().size() &&
463 464 465 466
        producer->op_pattern_kind == framework::kElementWise) {
      if (!recompute) {
        return false;
      } else {
467
        RecomputeEleGraph(producer, &fuse_consumers_unsafe);
468 469 470 471 472 473
        VerticalFuse(producer, fuse_consumers_unsafe);
        return true;
      }
    }

    if (fuse_consumers.size()) {
474
      SelectConsumerToFuse(producer, &fuse_consumers);
475 476 477 478 479 480 481 482 483 484 485
    }

    // if fusionable consumers exist
    if (fuse_consumers.size()) {
      VerticalFuse(producer, fuse_consumers);
      return true;
    }

    return false;
  }

486 487 488
  void VerticalFuse(const GroupPtr& producer,
                    const std::unordered_set<GroupPtr, Hasher, Comparator>&
                        fusionable_consumers) {
489 490 491 492 493 494
    VLOG(3) << "VerticalFuse...!";
    GroupList fused_groups;
    GroupPtr master_fuesd_group(nullptr);
    for (auto& consumer : fusionable_consumers) {
      auto fused_group = std::make_shared<Graph::Group>();
      // update depth using consumer depth.
495 496 497 498
      fused_group->max_depth =
          std::max(producer->max_depth, consumer->max_depth);
      fused_group->min_depth =
          std::min(producer->min_depth, consumer->min_depth);
499 500
      // update group id
      fused_group->group_id = producer->group_id + "_" + consumer->group_id;
501 502
      VLOG(3) << "fuse producer " << producer->group_id << " into consumer "
              << consumer->group_id;
503 504
      // fuse producer into fusion group
      fused_group->op_pattern_kind =
505 506
          static_cast<int>(producer->op_pattern_kind) >=
                  static_cast<int>(consumer->op_pattern_kind)
507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534
              ? producer->op_pattern_kind
              : consumer->op_pattern_kind;
      // input nodes
      fused_group->input_nodes = producer->input_nodes;

      // internal nodes
      if (producer->fused_sub_groups.size()) {
        for (auto& node : producer->internal_nodes) {
          fused_group->internal_nodes.insert(node);
        }
      }
      // convert producer's output node to internal.
      for (auto node : producer->output_nodes) {
        // if node is used more than 1 time.
        if (consumer->input_nodes.count(node)) {
          if (consumer->input_nodes[node] > 1 && node->inlinks().size() > 0) {
            fused_group->internal_nodes.insert(node);
          }
        }
      }
      // master nodes
      for (auto& node : producer->master_nodes) {
        if (GetOpKind(node) == framework::kReduction) {
          fused_group->master_nodes.insert(node);
        }
      }

      // producer groups
535 536
      for (auto& group : *producer->mut_producer_groups()) {
        fused_group->mut_producer_groups()->insert(group);
537
        // update producer's producer's consumer
538 539
        group->mut_consumer_groups()->erase(producer);
        group->mut_consumer_groups()->insert(fused_group);
540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584
      }

      // sub groups
      if (producer->fused_sub_groups.size()) {
        for (auto& group : producer->fused_sub_groups) {
          fused_group->fused_sub_groups.push_back(group);
          // update belong group
          group->belong_groups.erase(producer);
          group->belong_groups.insert(fused_group);
        }
      } else {
        fused_group->fused_sub_groups.push_back(producer);
      }
      producer->belong_groups.insert(fused_group);

      // input nodes
      for (auto& input_node : consumer->input_nodes) {
        // if input node not in producer output.
        if (!producer->output_nodes.count(input_node.first)) {
          if (fused_group->input_nodes.count(input_node.first)) {
            fused_group->input_nodes[input_node.first] += input_node.second;
          } else {
            fused_group->input_nodes.insert(input_node);
          }
        }
      }

      // output nodes
      for (auto& node : consumer->output_nodes) {
        fused_group->output_nodes.insert(node);
      }

      // internal nodes
      if (consumer->fused_sub_groups.size()) {
        for (auto& node : consumer->internal_nodes) {
          fused_group->internal_nodes.insert(node);
        }
      }

      // master nodes
      for (auto& node : consumer->master_nodes) {
        fused_group->master_nodes.insert(node);
      }

      // producer nodes
585
      for (auto& group : *consumer->mut_producer_groups()) {
586
        if (group.get() != producer.get()) {
587
          fused_group->mut_producer_groups()->insert(group);
588
          // update consumer's producer's consumer
589 590
          group->mut_consumer_groups()->erase(consumer);
          group->mut_consumer_groups()->insert(fused_group);
591 592 593
        }
      }
      // consumer nodes
594 595
      for (auto& group : *consumer->mut_consumer_groups()) {
        fused_group->mut_consumer_groups()->insert(group);
596
        // update consumer's consumer's producer
597 598
        group->mut_producer_groups()->erase(consumer);
        group->mut_producer_groups()->insert(fused_group);
599 600 601 602 603
      }

      // sub group
      if (consumer->fused_sub_groups.size()) {
        for (auto& sub_group : consumer->fused_sub_groups) {
604 605 606
          if (std::find(fused_group->fused_sub_groups.begin(),
                        fused_group->fused_sub_groups.end(),
                        sub_group) == fused_group->fused_sub_groups.end()) {
607 608 609 610 611 612 613 614 615 616 617 618 619
            fused_group->fused_sub_groups.push_back(sub_group);
          }
          // update belong group
          sub_group->belong_groups.erase(consumer);
          sub_group->belong_groups.insert(fused_group);
        }
      } else {
        fused_group->fused_sub_groups.push_back(consumer);
      }
      consumer->belong_groups.insert(fused_group);

      fused_groups.push_back(fused_group);
      CHECK(fusion_groups_index_.count(consumer))
620 621 622 623
          << "Can't find consumer " << consumer->group_id
          << " index in fusion_groups_index_!";
      auto postion = fusion_groups_index_[consumer];
      fusion_groups_[postion] = fused_group;
624 625 626 627 628
      fusion_groups_index_[fused_group] = postion;

      if (!master_fuesd_group.get()) {
        master_fuesd_group = fused_group;
      }
629 630
      CHECK(fused_group->output_nodes.size())
          << "No output node is found, " << fused_group->group_id;
631 632 633 634
    }

    for (auto& node : producer->output_nodes) {
      bool be_output = true;
635
      for (const auto& consumer : producer->consumer_groups()) {
636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655
        // if consumer is in fusionable.
        if (fusionable_consumers.count(consumer)) {
          if (consumer->input_nodes.count(node)) {
            be_output = false;
          }
          continue;
        }
        // if consumer is not in fusionable.
        if (consumer->input_nodes.count(node)) {
          be_output = true;
          break;
        }
        // others node is as graph output.
      }

      if (output_nodes_set_.count(node)) {
        be_output = true;
      }

      if (be_output) {
656 657
        VLOG(4) << "Insert Id " << node->id() << " Into Group "
                << master_fuesd_group->group_id;
658 659 660 661
        master_fuesd_group->output_nodes.insert(node);
      }
    }
    // insert unfusionable consumer groups
662
    for (auto& consumer : *producer->mut_consumer_groups()) {
663 664 665
      if (fusionable_consumers.count(consumer)) {
        continue;
      }
666
      master_fuesd_group->mut_consumer_groups()->insert(consumer);
667
      // update consumer's producer
668 669
      consumer->mut_producer_groups()->erase(producer);
      consumer->mut_producer_groups()->insert(master_fuesd_group);
670 671 672
    }
  }

673 674
  void RecomputeEleGraph(
      const GroupPtr& producer,
675
      std::unordered_set<GroupPtr, Hasher, Comparator>* fusionable_consumers) {
676 677 678 679 680
    if (producer->op_pattern_kind != framework::kElementWise) {
      SelectConsumerToFuse(producer, fusionable_consumers);
    }
  }

681 682
  void SelectConsumerToFuse(
      const GroupPtr& producer,
683
      std::unordered_set<GroupPtr, Hasher, Comparator>* fusionable_consumers) {
684 685 686
    // if is const op
    if (is_const_group(this, producer)) {
      std::unordered_set<GroupPtr, Hasher, Comparator> candidates;
687
      for (auto& consumer : *fusionable_consumers) {
688 689 690 691
        // if can be output node.
        if (is_same_shape(this, producer, consumer)) {
          candidates.insert(consumer);
        } else {
692 693
          VLOG(4) << "Fuse Producer : " << producer->group_id
                  << " into Consumer : " << consumer->group_id;
694 695
          consumer->group_id = producer->group_id + "_" + consumer->group_id;
          // just merge the node into group.
696
          auto& sub_group = consumer->fused_sub_groups.front();
697
          sub_group->group_id = producer->group_id + "_" + sub_group->group_id;
698 699
          sub_group->nodes.insert(sub_group->nodes.begin(),
                                  producer->CollectNodes()[0]);
700 701 702
          sub_group->nodes_set.insert(producer->CollectNodes()[0]);
          // remove depency.
          consumer->input_nodes.erase(producer->CollectNodes()[0]);
703 704
          consumer->mut_producer_groups()->erase(producer);
          producer->mut_consumer_groups()->erase(consumer);
705 706 707
        }
      }

708 709
      CHECK_GE(producer->consumer_groups().size(), candidates.size());
      if (producer->consumer_groups().size() == 0 && candidates.size() == 0 &&
710
          output_nodes_set_.count(producer->CollectNodes()[0]) == 0) {
711
        producer->belong_groups.insert(*fusionable_consumers->begin());
712 713
      }

714
      *fusionable_consumers = candidates;
715 716 717
      return;
    }
    // 1 to 1 fusion.
718
    if (producer->consumer_groups().size() == 1) {
719 720 721 722 723
      return;
    }

    if (FLAGS_enhance_vertical_fusion_with_recompute) {
      std::vector<GroupPtr> candidates;
724
      for (auto& consumer : *fusionable_consumers) {
725 726 727 728 729
        if (consumer->op_pattern_kind == framework::kElementWise) {
          candidates.push_back(consumer);
          continue;
        }

730 731 732 733 734 735
        auto producer_output_shape =
            this->GetNodeDataShape(*producer->output_nodes.begin());
        auto consumer_output_shape =
            this->GetNodeDataShape(*consumer->output_nodes.begin());
        auto consumer_master_input_shape =
            this->GetNodeInputShape(*(consumer->master_nodes.begin()));
736
        int producer_output_numel =
737 738 739 740
            std::accumulate(producer_output_shape.begin(),
                            producer_output_shape.end(),
                            1,
                            std::multiplies<int>());
741
        int consumer_output_numel =
742 743 744 745 746 747 748 749 750
            std::accumulate(consumer_output_shape.begin(),
                            consumer_output_shape.end(),
                            1,
                            std::multiplies<int>());
        int consumer_master_input_numel =
            std::accumulate(consumer_master_input_shape.begin(),
                            consumer_master_input_shape.end(),
                            1,
                            std::multiplies<int>());
751 752 753 754 755
        if (producer_output_numel == consumer_output_numel) {
          candidates.push_back(consumer);
          continue;
        }

756 757
        if (producer->op_pattern_kind != framework::kInjective &&
            consumer->op_pattern_kind == framework::kReduction &&
758 759 760 761
            producer_output_numel == consumer_master_input_numel) {
          candidates.push_back(consumer);
        }
      }
762 763 764 765 766
      sort(candidates.begin(),
           candidates.end(),
           [](const auto& lhs, const auto& rhs) {
             return lhs->op_pattern_kind < rhs->op_pattern_kind;
           });
767

768
      fusionable_consumers->clear();
769
      if (candidates.size()) {
770
        fusionable_consumers->insert(*candidates.begin());
771 772 773
      }
    } else {
      std::unordered_set<GroupPtr, Hasher, Comparator> candidates;
774
      for (auto& consumer : *fusionable_consumers) {
775 776 777 778 779 780 781 782
        if (consumer->op_pattern_kind == framework::kElementWise) {
          candidates.insert(consumer);
          continue;
        }

        auto shape0 = this->GetNodeDataShape(*producer->output_nodes.begin());
        auto shape1 = this->GetNodeDataShape(*consumer->output_nodes.begin());

783 784 785 786
        if (std::accumulate(
                shape0.begin(), shape0.end(), 1, std::multiplies<int>()) ==
            std::accumulate(
                shape1.begin(), shape1.end(), 1, std::multiplies<int>())) {
787 788 789 790
          candidates.insert(consumer);
        }
      }

791
      fusionable_consumers->clear();
792
      if (candidates.size()) {
793
        fusionable_consumers->insert(*candidates.begin());
794 795 796 797
      }
    }
  }

798 799 800 801
  bool IsDependency(
      const GroupPtr& producer_g,
      const GroupPtr& consumer,
      const std::unordered_set<GroupPtr, Hasher, Comparator>& consumers) {
802 803 804 805 806 807 808
    std::queue<GroupPtr> candidates;
    candidates.push(consumer);

    std::unordered_set<GroupPtr, Hasher, Comparator> visited_set;
    while (!candidates.empty()) {
      auto& candidate = candidates.front();
      candidates.pop();
809
      for (const auto& producer : candidate->producer_groups()) {
810 811 812 813 814 815 816 817 818 819 820 821 822 823 824
        if (producer.get() == producer_g.get()) {
          continue;
        }
        if (consumers.count(producer)) {
          return true;
        }
        if (!visited_set.count(producer)) {
          visited_set.insert(producer);
          candidates.push(producer);
        }
      }
    }
    return false;
  }

825 826 827 828
  bool IsDependencySimplify(
      const GroupPtr& producer_g,
      const GroupPtr& consumer,
      const std::unordered_set<GroupPtr, Hasher, Comparator>& consumers) {
829 830 831 832 833 834 835 836
    std::queue<GroupPtr> candidates;
    candidates.push(consumer);
    // check upper.
    int check_upper_depth = producer_g.get() ? producer_g->max_depth : INT_MAX;
    std::unordered_set<GroupPtr, Hasher, Comparator> visited_set;
    while (!candidates.empty()) {
      auto& candidate = candidates.front();
      candidates.pop();
837
      for (auto& producer : candidate->producer_groups()) {
838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927
        if (producer.get() == producer_g.get()) {
          continue;
        }
        if (producer->min_depth > check_upper_depth) {
          continue;
        }
        if (consumers.count(producer)) {
          return true;
        }
        if (!visited_set.count(producer)) {
          visited_set.insert(producer);
          candidates.push(producer);
        }
      }
    }
    return false;
  }

  bool FuseInputToConsumers() {
    VLOG(3) << "FuseInputToConsumers...!";
    auto updated = false;
    UpdateInputToConsumers();
    GroupPtr producer(nullptr);
    for (auto& input_consumers : input_to_consumers_) {
      // if group set size == 1.
      if (input_consumers.second.size() == 1) {
        continue;
      }
      // do horizontal fusion.
      auto st = HorizontalFusion(producer, input_consumers.second);
      if (st) {
        // fused consumers, update
        UpdateInputToConsumers();
      }
      updated |= st;
    }

    return updated;
  }

  void UpdateInputToConsumers() {
    for (auto& input_consumers : input_to_consumers_) {
      auto& consumers = input_consumers.second;
      std::unordered_set<GroupPtr, Hasher, Comparator> updated_consumers;
      for (auto& consumer : consumers) {
        std::queue<GroupPtr> fused_groups;
        fused_groups.push(consumer);
        while (!fused_groups.empty()) {
          auto& cur = fused_groups.front();
          fused_groups.pop();
          // if group is sub group
          if (cur->belong_groups.empty()) {
            updated_consumers.insert(cur);
          } else {
            for (auto& belong_group : cur->belong_groups) {
              if (belong_group->group_id == cur->group_id) {
                updated_consumers.insert(belong_group);
              } else {
                fused_groups.push(belong_group);
              }
            }
          }
        }
      }
      consumers = updated_consumers;
    }
  }

  // Populates input_to_consumers_.
  //
  // For every node in every fusion group, each producer NodeData (as returned
  // by GetProducerNodeData) whose source_node is null is mapped to the groups
  // that consume it.
  void InitInputToConsumers() {
    VLOG(3) << "InitInputToConsumers...!";
    for (auto& group : fusion_groups_) {
      for (auto& node : group->nodes_set) {
        for (auto& node_data : GetProducerNodeData(node)) {
          // Only data with no source node is recorded.
          if (node_data->source_node.get() == nullptr) {
            input_to_consumers_[node_data].insert(group);
          }
        }
      }
    }
  }

  void InitFusionGroupsAndIndex() {
    VLOG(3) << "InitFusionGroupsAndIndex...!";
    // init the postion of groups in fusion groups.
    for (int idx = 0; idx < fusion_groups_.size(); ++idx) {
928
      auto group = fusion_groups_[idx];
929 930
      auto belong_group = std::make_shared<Graph::Group>();
      // copy from group.
931 932 933 934 935
      belong_group->max_depth = group->depth;
      belong_group->min_depth = group->depth;
      belong_group->group_id = group->group_id;
      belong_group->input_nodes = group->input_nodes;
      belong_group->output_nodes = group->output_nodes;
936
      belong_group->op_pattern_kind = group->op_pattern_kind;
937
      belong_group->master_nodes = group->master_nodes;
938 939
      (*belong_group->mut_producer_groups()) = group->producer_groups();
      (*belong_group->mut_consumer_groups()) = group->consumer_groups();
940 941 942 943 944 945 946 947 948 949 950 951 952
      belong_group->fused_sub_groups.push_back(group);
      group->belong_groups.insert(belong_group);
      // replace group to fused_group
      fusion_groups_[idx] = belong_group;
      // record idx
      fusion_groups_index_[belong_group] = idx;
    }

    // update producer and consumer.
    for (auto& group : fusion_groups_) {
      std::unordered_set<GroupPtr, Hasher, Comparator> producers;
      std::unordered_set<GroupPtr, Hasher, Comparator> consumers;

953
      for (const auto& producer : group->producer_groups()) {
954 955 956
        CHECK(producer->belong_groups.size());
        producers.insert(*producer->belong_groups.begin());
      }
957 958

      for (auto& consumer : *group->mut_consumer_groups()) {
959 960 961
        CHECK(consumer->belong_groups.size());
        consumers.insert(*consumer->belong_groups.begin());
      }
962 963 964 965
      CHECK_EQ(group->producer_groups().size(), producers.size());
      CHECK_EQ(group->consumer_groups().size(), consumers.size());
      (*group->mut_producer_groups()) = producers;
      (*group->mut_consumer_groups()) = consumers;
966 967 968 969 970 971 972 973 974
    }
  }

  void InitFusionRelation() {
    VLOG(3) << "InitFusionRelation...!";
    // kElementWise
    {
      auto& relation = fusion_relation_map_[OpPatternKind::kElementWise];
      // horizontal
975 976 977 978 979 980 981 982
      relation.horizontal_relation = {
          {framework::kElementWise, is_same_size},
          // element-wise and broadcast op must be horizontal relation.
          {OpPatternKind::kBroadcast, is_same_size},
          // element-wise and injective op must be horizontal relation.
          {OpPatternKind::kInjective, is_same_size},
          // element-wise and reduce op must be horizontal relation.
          {OpPatternKind::kReduction, honrizontal_elementwise_fuse_reduce}};
983
      // vertical
984 985 986 987 988 989 990 991
      relation.vertical_relation = {
          {OpPatternKind::kElementWise, is_same_size},
          // element-wise and broadcast can be vertical/horizontal relation.
          {OpPatternKind::kBroadcast, elementwise_fuse_broadcast},
          // element-wise and injective op must be horizontal relation.
          {OpPatternKind::kInjective, horizontal_with_injective},
          // element-wise and reduce can be vertical/horizontal relation.
          {OpPatternKind::kReduction, elementwise_fuse_reduce}};
992 993 994 995 996
    }
    // kBroadcast
    {
      auto& relation = fusion_relation_map_[OpPatternKind::kBroadcast];
      // horizontal
997 998 999 1000 1001 1002 1003 1004 1005
      relation.horizontal_relation = {
          // broadcast and element-wise op must be horizontal relation.
          {framework::kElementWise, is_same_size},
          // broadcast and broadcast op must be horizontal relation.
          {framework::kBroadcast, is_same_size},
          // broadcast and injective op must be horizontal relation.
          {OpPatternKind::kInjective, is_same_size},
          // broadcast and reduce op must be horizontal relation.
          {OpPatternKind::kReduction, is_same_size}};
1006
      // vertical
1007 1008 1009 1010 1011 1012 1013 1014 1015
      relation.vertical_relation = {
          // broadcast and element-wise op must be vertical relation.
          {OpPatternKind::kElementWise, is_same_size},
          // broadcast and broadcast op must be horizontal relation.
          {OpPatternKind::kBroadcast, is_same_size},
          // broadcast and injective op must be horizontal relation.
          {OpPatternKind::kInjective, horizontal_with_injective},
          // broadcast and reduce must be vertical relation.
          {OpPatternKind::kReduction, broadcast_fuse_reduce}};
1016 1017 1018 1019 1020
    }
    // kInjective
    {
      auto& relation = fusion_relation_map_[OpPatternKind::kInjective];
      // horizontal
1021 1022 1023 1024 1025 1026 1027 1028 1029
      relation.horizontal_relation = {
          // injective and element-wise op must be horizontal relation.
          {OpPatternKind::kElementWise, is_same_size},
          // injective and broadcast op must be horizontal relation.
          {OpPatternKind::kBroadcast, is_same_size},
          // injective and injective op must be horizontal relation.
          {OpPatternKind::kInjective, is_same_size},
          // injective and reduce must be horizontal relation.
          {OpPatternKind::kReduction, is_same_size}};
1030
      // vertical
1031 1032 1033 1034 1035 1036 1037 1038 1039
      relation.vertical_relation = {
          // injective and element-wise op must be horizontal relation.
          {OpPatternKind::kElementWise, is_same_size},
          // injective and broadcast op must be horizontal relation.
          {OpPatternKind::kBroadcast, is_same_size},
          // injective and injective op must be horizontal relation.
          {OpPatternKind::kInjective, horizontal_with_injective},
          // injective and reduce can be horizontal/vertical relation.
          {OpPatternKind::kReduction, injective_horizontal_with_reduce}};
1040 1041 1042 1043 1044
    }
    // kReduction
    {
      auto& relation = fusion_relation_map_[OpPatternKind::kReduction];
      // horizontal
1045 1046 1047 1048 1049 1050 1051 1052 1053
      relation.horizontal_relation = {
          // reduce and element-wise op must be horizontal relation.
          {OpPatternKind::kElementWise, honrizontal_elementwise_fuse_reduce},
          // reduce and broadcast op must be horizontal relation.
          {OpPatternKind::kBroadcast, is_same_size},
          // reduce and injective op must be horizontal relation.
          {OpPatternKind::kInjective, is_same_size},
          // reduce and reduce must be horizontal relation.
          {OpPatternKind::kReduction, reduce_fuse_reduce}};
1054
      // vertical
1055 1056 1057 1058 1059 1060 1061 1062 1063
      relation.vertical_relation = {
          // reduce and elementwise can be horizontal/vertical relation.
          {OpPatternKind::kElementWise, reduce_fuse_elementwise},
          // reduce and broadcast op must be horizontal relation.
          {OpPatternKind::kBroadcast, reduce_fuse_broadcast},
          // reduce and injective op must be horizontal relation.
          {OpPatternKind::kInjective, horizontal_with_injective},
          // reduce and reduce must be horizontal relation.
          {OpPatternKind::kReduction, reduce_fuse_reduce}};
1064 1065 1066 1067 1068
    }
  }

  GroupList fusion_groups_;
  std::unordered_map<GroupPtr, int, Hasher, Comparator> fusion_groups_index_;
1069 1070 1071
  std::unordered_map<NodeData*,
                     std::unordered_set<GroupPtr, Hasher, Comparator>>
      input_to_consumers_;
1072 1073

  struct Relation {
1074 1075 1076 1077
    std::unordered_map<framework::OpPatternKind, ConditionFunction>
        vertical_relation;
    std::unordered_map<framework::OpPatternKind, ConditionFunction>
        horizontal_relation;
1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098
  };
  std::unordered_map<framework::OpPatternKind, Relation> fusion_relation_map_;
};

// Body of the FusionMergePass.
//
// Merges graph->fusion_groups with FusionMergePassHelper and stores the result
// back into graph->fusion_groups. A graph with at most one fusion group is
// left untouched.
void FusionMergePassInternal(Graph* graph) {
  if (graph->fusion_groups.size() <= 1) {
    VLOG(3) << "Don't do Fusion Merge Pass...!";
    return;
  }

  FusionMergePassHelper fusion_merge_pass_helper(graph);
  graph->fusion_groups = fusion_merge_pass_helper();
}

}  // namespace pass
}  // namespace hlir
}  // namespace cinn

// Registers cinn::hlir::pass::FusionMergePassInternal as the body of the pass
// named "FusionMergePass", with set_change_structure(false).
CINN_REGISTER_HELPER(FusionMergePass) {
  CINN_REGISTER_PASS(FusionMergePass)
      .describe(
          "Fusion Merge Pass which performs Fusion-Ops fusion, "
          "Producer Fusion-Ops are fused into Consumer Fusion-Ops with certain "
          "conditions.")
      .set_change_structure(false)
      .set_body(cinn::hlir::pass::FusionMergePassInternal);
  return true;
}