check_fusion_accuracy_pass.cc 20.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <absl/container/flat_hash_map.h>

#include <atomic>
#include <cstdint>
#include <deque>
#include <functional>
#include <memory>
#include <numeric>
#include <set>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/node.h"
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/visualize_helper.h"
#include "paddle/cinn/hlir/pass/fusion_helper_base.h"
#include "paddle/cinn/runtime/custom_function.h"

namespace cinn::hlir::pass {

using framework::Graph;
using framework::Node;
using framework::NodeData;
using framework::NodePtr;
using framework::Operator;
using framework::OpPatternKind;

using cinn::hlir::framework::GenerateAccCheckNodeId;
using common::GraphEdge;
using common::GraphNode;

44
using GroupPtr = std::shared_ptr<Graph::Group>;
45 46 47 48 49 50 51 52
using GroupList = std::vector<GroupPtr>;

using ShapeDict = absl::flat_hash_map<std::string, framework::shape_t>;
using DtypeDict = absl::flat_hash_map<std::string, common::Type>;

namespace utils {
class AssertMsg {
 public:
53
  explicit AssertMsg(int group_id) : group_id_(group_id) {}
54

55 56 57
  void SetMsg(const std::string& title, const std::string& msg) {
    msg_info_[title] = msg;
  }
58 59

  const std::string& GetMsg(const std::string& title) const {
60 61
    CHECK(msg_info_.count(title))
        << "Msg of group " << group_id_ << " not has title: " << title;
62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
    return msg_info_.at(title);
  }

  void CleasMsg(const std::string& title) { msg_info_.erase(title); }

  std::string str() const {
    std::stringstream ss;
    for (const auto& msg_pair : msg_info_) {
      ss << "  -- " << msg_pair.first << ": " << msg_pair.second << "\n";
    }
    return ss.str();
  }

 private:
  std::unordered_map<std::string, std::string> msg_info_;
  int group_id_;
};
}  // namespace utils

class CheckFusionAccuracyPass {
 public:
83
  explicit CheckFusionAccuracyPass(Graph* graph)
84
      : graph_(graph),
85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
        shape_dict_(graph_->GetMutableAttrs<ShapeDict>("infershape")),
        dtype_dict_(graph_->GetMutableAttrs<DtypeDict>("inferdtype")) {}

  GroupList Apply();

 protected:
  // a helper find to get node data's debug information
  std::string DebugNodeData(NodeData* node);

  // create node's output node, whose name is output_id
  NodeData* CreateOutputNode(NodePtr node, const std::string& output_id = "");

  // get group's op kind
  OpPatternKind GetOpKind(const framework::Node* node);

  // create a group, the group only has one node
  GroupPtr CreateSingleNodeGroup(NodePtr node_ptr);

  // topological order nodes list
  std::vector<Node*> TopologicalOrder(const std::vector<Node*>& nodes);

  // copy and create new output from old_node, and link to new_node
  void CreateCheckNodeOutputs(Node* old_node, NodePtr new_node);

  // relink new_node's input node
  void RelinkNodeInputs(Node* old_node, NodePtr new_node);

  // create check fusion accuracy pass node
  NodePtr CreateCheckNode(Node* node);

  std::pair<NodePtr, NodeData*> CreateIsCloseNode(const std::string& node_id);

  std::pair<NodePtr, NodeData*> CreateAllNode(const std::string& node_id);

119 120
  std::pair<NodePtr, NodeData*> CreateAssertNode(const std::string& node_id,
                                                 utils::AssertMsg* assert_msg);
121 122

  // the AssertAllClose operator are composed of isclose+all+assert
123 124 125 126
  std::vector<NodePtr> CreateAssertAllClose(
      const std::string& node_id,
      utils::AssertMsg* assert_msg,
      const std::vector<NodeData*>& inputs);
127

128 129 130 131 132
  // link origin group's output and pass group's output to the AssertAllClose
  // nodes
  GroupList LinkToAssertAllClose(
      const std::unordered_set<NodeData*>& group_outputs,
      utils::AssertMsg* msg);
133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149

  // skip check some op and var, now only support check float dtype
  bool IsSkipVar(const NodeData* var);

 private:
  Graph* graph_;
  std::unordered_map<NodeData*, NodeData*> old2new_nodedata_map_;
  std::unordered_map<NodeData*, NodeData*> constnode_old2new_map_;

  ShapeDict& shape_dict_;
  DtypeDict& dtype_dict_;

  static std::atomic_int key_count_;
};

// Counter shared by all pass instances; CreateAssertNode takes each assert
// node's integer message key from it with fetch_add.
std::atomic_int CheckFusionAccuracyPass::key_count_{0};

150 151 152
// A variable is skipped unless its recorded dtype is a floating-point type.
bool CheckFusionAccuracyPass::IsSkipVar(const NodeData* var) {
  const bool is_float_var = dtype_dict_.at(var->id()).is_float();
  return !is_float_var;
}
153 154 155

// Renders `node` as "<id>{shape=[d0, d1, ...], dtype=<dtype>}". The id must
// be present in both the shape and the dtype dictionary.
std::string CheckFusionAccuracyPass::DebugNodeData(NodeData* node) {
  const auto& var_id = node->id();

  std::stringstream out;
  out << var_id << "{shape=[";
  out << cinn::utils::Join(shape_dict_.at(var_id), ", ");
  out << "], dtype=" << dtype_dict_.at(var_id) << "}";
  return out.str();
}

162 163
// Creates the output variable of `node`, links it to the node and registers
// it in the graph. An empty `output_id` selects a generated unique name
// based on "var_<node id>". Fails a CHECK if the name is already registered.
NodeData* CheckFusionAccuracyPass::CreateOutputNode(
    NodePtr node, const std::string& output_id) {
  const std::string var_id =
      output_id.empty() ? cinn::common::UniqName("var_" + node->id())
                        : output_id;

  CHECK(graph_->RetrieveNode(var_id) == nullptr)
      << "The node " << node->op()->name << "'s output " << var_id
      << " had been registered in graph! Please check.";

  auto* out_var = new NodeData(node, 0, 0, var_id);
  node->LinkTo(out_var);
  graph_->RegisterNode(var_id, out_var);
  return out_var;
}

181 182
// For every output variable of `old_node`, creates a matching output
// variable on `new_node`, copies the original's output index, shape and
// dtype to it, and records the original -> copy mapping in
// old2new_nodedata_map_.
void CheckFusionAccuracyPass::CreateCheckNodeOutputs(Node* old_node,
                                                     NodePtr new_node) {
  for (const auto& out_edge : old_node->outlinks_in_order()) {
    auto out_node = out_edge->sink()->safe_as<NodeData>();
    CHECK(out_node) << "Node " << old_node->id()
                    << "'s output node is nullptr! Please check.";

    // Note: there is deliberately no check that `out_node` has not been
    // mapped before. A const op is recomputed inside groups, so the same op
    // may appear in several groups.
    const auto& out_node_id = out_node->id();
    const auto& check_out_node_id = GenerateAccCheckNodeId(out_node_id);

    auto check_out_node = CreateOutputNode(new_node, check_out_node_id);
    check_out_node->output_index = out_node->output_index;

    // Copy the values before emplacing: inserting into a flat_hash_map may
    // rehash and invalidate references into the same map.
    auto check_out_shape = shape_dict_.at(out_node_id);
    shape_dict_.emplace(check_out_node_id, std::move(check_out_shape));

    auto check_out_dtype = dtype_dict_.at(out_node_id);
    dtype_dict_.emplace(check_out_node_id, std::move(check_out_dtype));

    VLOG(4) << "Create the check fusion accuracy node of node "
            << old_node->id() << "'s output node " << DebugNodeData(out_node)
            << " success, which is " << DebugNodeData(check_out_node);

    old2new_nodedata_map_[out_node] = check_out_node;
  }
}

217 218
// Gives `new_node` the same inputs, in the same order, as `old_node`. An
// input that already has a check copy is replaced by that copy; any other
// input is linked directly.
void CheckFusionAccuracyPass::RelinkNodeInputs(Node* old_node,
                                               NodePtr new_node) {
  for (const auto& in_edge : old_node->inlinks_in_order()) {
    auto in_node = in_edge->source()->safe_as<NodeData>();
    CHECK(in_node) << "Node " << old_node->id()
                   << "'s input node is nullptr! Please check.";

    const auto copy_it = old2new_nodedata_map_.find(in_node);
    NodeData* source =
        copy_it == old2new_nodedata_map_.end() ? in_node : copy_it->second;
    source->LinkTo(new_node.get());
  }
}

// Creates and registers the check copy of the operator node `node`: same op
// and attributes, ids derived through GenerateAccCheckNodeId, fresh output
// variables, and inputs relinked to check copies where they exist.
NodePtr CheckFusionAccuracyPass::CreateCheckNode(Node* node) {
  CHECK(node->op()) << "Node " << node->id()
                    << " is not operator! Please check.";

  const auto check_node_id = GenerateAccCheckNodeId(node->id());
  CHECK(graph_->RetrieveNode(check_node_id) == nullptr)
      << "The node " << node->id() << "'s check fusion accuracy node"
      << check_node_id << " had been registered in graph! Please check.";

  const auto check_node_name = GenerateAccCheckNodeId(node->attrs.node_name);
  auto check_node = Node::Create(node->op(), check_node_name, check_node_id);
  check_node->attrs.attr_store = node->attrs.attr_store;
  graph_->RegisterNode(check_node_id, check_node.get());

  // Outputs first, then inputs, as the two helpers are called in that order.
  CreateCheckNodeOutputs(node, check_node);
  RelinkNodeInputs(node, check_node);

  VLOG(4) << "Create node " << framework::DebugString(node)
          << "'s check fusion accuracy node success, which is "
          << framework::DebugString(check_node.get());

  return check_node;
}

// Looks up the "OpPattern" attribute of `node`'s op. A kBroadcast op other
// than "broadcast_to" is reported as kElementWise.
OpPatternKind CheckFusionAccuracyPass::GetOpKind(const framework::Node* node) {
  auto& pattern_of = framework::Operator::GetAttrs<OpPatternKind>("OpPattern");
  CHECK(pattern_of.Find(node->op()))
      << "Don't find the pattern of op : " << node->id();

  const auto kind = pattern_of[node->op()];

  // As binary op was defined as broadcast, actually it should be
  // element-wise.
  const bool binary_as_broadcast =
      kind == framework::kBroadcast && node->op()->name != "broadcast_to";
  if (binary_as_broadcast) {
    return framework::kElementWise;
  }
  return kind;
}

// Builds a group whose only member is `node_ptr`. The node is its own output
// node and master node, and the group id is the node id.
GroupPtr CheckFusionAccuracyPass::CreateSingleNodeGroup(NodePtr node_ptr) {
  auto single = node_ptr.get();
  auto new_group = std::make_shared<Graph::Group>();

  new_group->nodes.push_back(single);
  new_group->nodes_set.insert(single);
  new_group->output_nodes.insert(single);

  // Record the producer of every input variable; a variable without a
  // source node has no producer and is not recorded.
  for (auto& in_edge : single->inlinks()) {
    auto in_var = in_edge->source()->safe_as<NodeData>();
    CHECK(in_var);
    auto producer = in_var->source_node.get();
    if (producer) {
      new_group->input_nodes[producer] = 1;
    }
  }

  new_group->op_pattern_kind = GetOpKind(single);
  // use current node as master node for schedule
  new_group->master_nodes.insert(single);
  new_group->group_id = single->id();

  return new_group;
}

304 305
std::pair<NodePtr, NodeData*> CheckFusionAccuracyPass::CreateIsCloseNode(
    const std::string& node_id) {
306 307
  const auto& is_close_node_id = "isclose_" + node_id;

308 309 310
  auto is_close_node = Node::Create(Operator::Get("isclose"),
                                    GenerateAccCheckNodeId("isclose"),
                                    is_close_node_id);
311
  is_close_node->attrs.attr_store["rtol"] =
312 313
      cinn::runtime::utils::AssertTrueMsgTool::GetInstance()
          ->GetFlagValue<float>("rtol");
314
  is_close_node->attrs.attr_store["atol"] =
315 316
      cinn::runtime::utils::AssertTrueMsgTool::GetInstance()
          ->GetFlagValue<float>("atol");
317
  is_close_node->attrs.attr_store["equal_nan"] =
318 319
      cinn::runtime::utils::AssertTrueMsgTool::GetInstance()
          ->GetFlagValue<bool>("equal_nan");
320 321 322 323 324 325 326 327 328 329

  graph_->RegisterNode(is_close_node_id, is_close_node.get());

  // create node's output data node
  auto output_data = CreateOutputNode(is_close_node);

  auto check_out_shape = shape_dict_.at(node_id);
  shape_dict_.emplace(output_data->id(), std::move(check_out_shape));
  dtype_dict_.emplace(output_data->id(), common::Bool());

330 331
  VLOG(4) << "Create node " << node_id
          << "'s isclose node success, whose id is " << is_close_node_id
332 333 334 335 336
          << ", whose output is " << DebugNodeData(output_data);

  return {is_close_node, output_data};
}

337 338
std::pair<NodePtr, NodeData*> CheckFusionAccuracyPass::CreateAllNode(
    const std::string& node_id) {
339 340
  const auto& all_node_id = "all_" + node_id;

341 342 343
  auto all_node = Node::Create(Operator::Get("reduce_all"),
                               GenerateAccCheckNodeId("reduce_all"),
                               all_node_id);
344 345 346 347 348 349

  int shape_size = shape_dict_[node_id].size();
  std::vector<int> axes(shape_size);
  for (int i = 0; i < shape_size; ++i) {
    axes[i] = i;
  }
350
  all_node->attrs.attr_store["dim"] = axes;
351 352 353 354 355 356 357 358 359 360
  all_node->attrs.attr_store["keep_dim"] = false;

  graph_->RegisterNode(all_node_id, all_node.get());

  // create node's output data node
  auto output_data = CreateOutputNode(all_node);

  shape_dict_.emplace(output_data->id(), framework::shape_t{1});
  dtype_dict_.emplace(output_data->id(), common::Bool());

361 362
  VLOG(4) << "Create node " << node_id << "'s all node success, whose id is "
          << all_node_id << ", whose output is " << DebugNodeData(output_data);
363 364 365 366

  return {all_node, output_data};
}

367 368
std::pair<NodePtr, NodeData*> CheckFusionAccuracyPass::CreateAssertNode(
    const std::string& node_id, utils::AssertMsg* assert_msg) {
369 370
  const auto& assert_node_id = "assert_" + node_id;

371 372 373 374 375 376
  auto assert_node = Node::Create(Operator::Get("assert_true"),
                                  GenerateAccCheckNodeId("assert_true"),
                                  assert_node_id);
  // TODO(thisjiang): change type from 'int' to 'std::string' when custom call
  // support 'std::string' type
  int msg_key = key_count_.fetch_add(1);
377
  assert_node->attrs.attr_store["msg"] = msg_key;
378 379
  cinn::runtime::utils::AssertTrueMsgTool::GetInstance()->SetMsg(
      msg_key, assert_msg->str());
380
  assert_node->attrs.attr_store["only_warning"] =
381 382
      cinn::runtime::utils::AssertTrueMsgTool::GetInstance()
          ->GetFlagValue<bool>("only_warning");
383 384 385 386 387 388 389 390 391

  graph_->RegisterNode(assert_node_id, assert_node.get());

  // create node's output data node
  auto output_data = CreateOutputNode(assert_node);

  shape_dict_.emplace(output_data->id(), framework::shape_t{1});
  dtype_dict_.emplace(output_data->id(), common::Bool());

392 393 394
  VLOG(4) << "Create node " << node_id << "'s assert node success, whose id is "
          << assert_node_id << ", whose output is "
          << DebugNodeData(output_data);
395 396 397 398

  return {assert_node, output_data};
}

399 400 401 402
std::vector<NodePtr> CheckFusionAccuracyPass::CreateAssertAllClose(
    const std::string& node_id,
    utils::AssertMsg* assert_msg,
    const std::vector<NodeData*>& inputs) {
403 404 405 406 407 408 409 410 411 412 413 414 415 416
  std::vector<NodePtr> group_nodes;
  // create isclose + all + assert nodes
  // create isclose node and link inputs to the node
  const auto& is_close_node = CreateIsCloseNode(node_id);
  for (auto in_data : inputs) {
    in_data->LinkTo(is_close_node.first.get());
  }
  group_nodes.emplace_back(is_close_node.first);

  // create assert node
  const auto& assert_node = CreateAssertNode(node_id, assert_msg);

  // check and create all node
  auto in_shape = shape_dict_[node_id];
417 418
  int prod_size = std::accumulate(
      in_shape.begin(), in_shape.end(), 1, std::multiplies<int>());
419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435
  if (prod_size > 1) {
    // need reduce
    const auto& all_node = CreateAllNode(node_id);

    is_close_node.second->LinkTo(all_node.first.get());
    all_node.second->LinkTo(assert_node.first.get());

    group_nodes.emplace_back(all_node.first);
  } else {
    // do not need reduce
    is_close_node.second->LinkTo(assert_node.first.get());
  }
  group_nodes.emplace_back(assert_node.first);

  return group_nodes;
}

436 437
// For each output variable of the original group, feeds the variable and its
// check copy into an "assert all close" chain and wraps every node of the
// chain into a single-node group. Non-float outputs are skipped with a
// warning. `msg` has its "Var Name" and "Suggestion" entries overwritten for
// each checked output.
GroupList CheckFusionAccuracyPass::LinkToAssertAllClose(
    const std::unordered_set<NodeData*>& group_outputs, utils::AssertMsg* msg) {
  GroupList chain_groups;

  for (auto* group_out : group_outputs) {
    const auto& out_node_id = group_out->id();

    if (IsSkipVar(group_out)) {
      LOG(WARNING) << "The CheckFusionAccuracyPass only support check float "
                      "point dtype data now, skip check node \""
                   << out_node_id
                   << "\", who's dtype=" << dtype_dict_.at(out_node_id);
      continue;
    }

    const auto copy_it = old2new_nodedata_map_.find(group_out);
    CHECK(copy_it != old2new_nodedata_map_.end())
        << "The check fusion accuracy's node corresponding to " << out_node_id
        << " had not been created! Please check.";
    NodeData* pass_out = copy_it->second;
    const auto& acc_check_out_id = pass_out->id();

    msg->SetMsg("Var Name", out_node_id);
    const auto suggestion =
        cinn::utils::StringFormat("You can check the value by set "
                                  "FLAGS_cinn_self_check_accuracy and compare "
                                  "the result between \"%s\" and \"%s\"",
                                  out_node_id.c_str(),
                                  acc_check_out_id.c_str());
    msg->SetMsg("Suggestion", suggestion);

    const auto chain_nodes =
        CreateAssertAllClose(acc_check_out_id, msg, {group_out, pass_out});
    for (const auto& chain_node : chain_nodes) {
      chain_groups.emplace_back(CreateSingleNodeGroup(chain_node));
    }
  }

  return chain_groups;
}

473 474
// Returns the operator nodes of a group in topological order, so that every
// node comes after the group nodes whose outputs it consumes. Fails a CHECK
// if not every node could be ordered.
std::vector<Node*> CheckFusionAccuracyPass::TopologicalOrder(
    const std::vector<Node*>& nodes) {
  // The group's nodes are visited in ascending id() order.
  struct ByNodeId {
    bool operator()(Node* lhs, Node* rhs) const {
      return lhs->id() < rhs->id();
    }
  };
  const std::set<Node*, ByNodeId> group_nodes(nodes.begin(), nodes.end());

  // Every variable produced by some node of the group.
  std::unordered_set<NodeData*> produced_vars;
  for (Node* op : group_nodes) {
    for (auto& out_edge : op->outlinks_in_order()) {
      produced_vars.insert(out_edge->sink()->safe_as<NodeData>());
    }
  }

  // Count, per node, the input edges fed from inside the group. A node with
  // no such edge is a start node and can be emitted immediately; only the
  // other nodes get an entry in `pending_inputs`.
  std::deque<Node*> ready;
  std::unordered_map<Node*, int> pending_inputs;
  for (Node* op : group_nodes) {
    int internal_inputs = 0;
    for (auto& in_edge : op->inlinks_in_order()) {
      if (produced_vars.count(in_edge->source()->safe_as<NodeData>())) {
        ++internal_inputs;
      }
    }
    if (internal_inputs == 0) {
      ready.emplace_back(op);
    } else {
      pending_inputs[op] = internal_inputs;
    }
  }

  std::vector<Node*> ordered_nodes;
  while (!ready.empty()) {
    Node* current = ready.front();
    ready.pop_front();
    ordered_nodes.push_back(current);

    for (auto& out_edge : current->outlinks_in_order()) {
      // the output of node is a variable node, not op node
      auto out_var = out_edge->sink()->safe_as<NodeData>();

      for (const auto& consumer_edge : out_var->outlinks()) {
        // the consumers of the variable are operator nodes
        auto consumer = consumer_edge->sink()->safe_as<Node>();

        // Consumers outside the group have no entry and are ignored. A
        // consumer becomes ready once its last internal input is produced.
        const auto pending = pending_inputs.find(consumer);
        if (pending != pending_inputs.end() && --(pending->second) == 0) {
          ready.push_back(consumer);
        }
      }
    }
  }

  CHECK_EQ(ordered_nodes.size(), nodes.size())
      << "There has circle in group! Please check.";

  return ordered_nodes;
}

// Runs the pass over graph_->fusion_groups and returns the new group list.
// Every original group is kept. For each group with more than one node the
// list additionally receives one single-node group per check copy of the
// group's operator nodes, followed by the groups of the assert chains that
// compare the group's outputs with the copies' outputs.
GroupList CheckFusionAccuracyPass::Apply() {
  std::unordered_set<std::string> fetch_ids;
  for (auto* fetch_var : graph_->outputs) {
    fetch_ids.emplace(fetch_var->id());
  }

  GroupList result_groups;
  // Counts only the groups that are actually checked; skipped single-node
  // groups do not advance it.
  int checked_group_idx = 0;

  for (auto& group : graph_->fusion_groups) {
    result_groups.emplace_back(group);

    const auto& group_nodes = group->CollectNodes();
    if (group_nodes.size() <= 1) {
      // nothing was fused, so there is nothing to compare against
      VLOG(4) << "The Group " << group->GetFuncName()
              << " just has one node, skip.";
      continue;
    }

    // split the original group and create one group for each node
    const auto& ordered_nodes = TopologicalOrder(group_nodes);
    VLOG(4) << "Check the accuracy of group "
            << graph_->DebugGroupedGraph(ordered_nodes);

    for (auto* node : ordered_nodes) {
      if (node->is_variable()) {
        VLOG(4) << "The node " << node->id()
                << " is variable, skip check fusion accuracy.";
        continue;
      }
      result_groups.push_back(CreateSingleNodeGroup(CreateCheckNode(node)));
    }

    // debug information shown by the assert nodes of this group
    utils::AssertMsg msg(checked_group_idx);
    msg.SetMsg("Kernel name", group->GetFuncName());
    msg.SetMsg("Group id", std::to_string(checked_group_idx));
    const auto group_structure = cinn::utils::StringFormat(
        "\nGroup %d {\n%s}",
        checked_group_idx,
        graph_->DebugGroupedGraph(ordered_nodes, fetch_ids).c_str());
    msg.SetMsg("Group structure", group_structure);

    // link the group's output data node to assert all close node
    const auto& assert_groups =
        LinkToAssertAllClose(group->GetOutputNodeDatas(), &msg);
    result_groups.insert(
        result_groups.end(), assert_groups.begin(), assert_groups.end());

    ++checked_group_idx;
  }

  return result_groups;
}

597 598 599
// Pass entry point: replaces the graph's fusion groups with the list
// produced by CheckFusionAccuracyPass::Apply.
void CheckFusionAccuracyPassImpl(Graph* graph) {
  CheckFusionAccuracyPass checker(graph);
  graph->fusion_groups = checker.Apply();
}
600 601 602 603 604 605 606 607 608 609 610

}  // namespace cinn::hlir::pass

// Registers the pass under the name CheckFusionAccuracyPass, with
// CheckFusionAccuracyPassImpl as its body.
CINN_REGISTER_HELPER(CheckFusionAccuracyPass) {
  CINN_REGISTER_PASS(CheckFusionAccuracyPass)
      .describe("Check Fusion Accuracy Pass.")
      // the pass adds nodes and groups to the graph
      .set_change_structure(true)
      .set_body(cinn::hlir::pass::CheckFusionAccuracyPassImpl);

  return true;
}