# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from abc import abstractmethod

from ..graph import Graph

_PATTERNS = {}


def register_pattern(cls):
    """Register pattern for rule-based tuner."""
    name = cls.name

    def register(name):
        global _PATTERNS
        _PATTERNS[name] = cls()

    register(name)

    return cls


class BasePattern(Graph):
    name = "base"

    def __init__(self):
        super().__init__()
        self.build()

    @abstractmethod
    def build(self):
        pass


@register_pattern
class QKVPattern(BasePattern):
    name = "qkv"

    def __init__(self):
        super().__init__()

    def build(self):
        query = self.add_node(0, **{"type": "var"})

        q_weight = self.add_node(1, **{"dim": 2, "type": "param"})
        k_weight = self.add_node(2, **{"dim": 2, "type": "param"})
        v_weight = self.add_node(3, **{"dim": 2, "type": "param"})

        q_matmul = self.add_node(4, **{"type": "matmul_v2"})
        k_matmul = self.add_node(5, **{"type": "matmul_v2"})
        v_matmul = self.add_node(6, **{"type": "matmul_v2"})

        q_x = self.add_edge(0, 4, **{"input_name": "X"})
        k_x = self.add_edge(0, 5, **{"input_name": "X"})
        v_x = self.add_edge(0, 6, **{"input_name": "X"})
        q_y = self.add_edge(1, 4, **{"input_name": "Y"})
        k_y = self.add_edge(2, 5, **{"input_name": "Y"})
        v_y = self.add_edge(3, 6, **{"input_name": "Y"})

        q = self.add_node(7, **{"type": "var"})
        k = self.add_node(8, **{"type": "var"})
        v = self.add_node(9, **{"type": "var"})

        q_out = self.add_edge(4, 7, **{"output_name": "Out"})
        k_out = self.add_edge(5, 8, **{"output_name": "Out"})
        v_out = self.add_edge(6, 9, **{"output_name": "Out"})

        # Pattern
        self.attrs["shard_spec"] = [
            [(1, 2, 3), [[-1, 0], [-1, 1]]],
        ]  # 2-tuple list such as [(tensor_id, shard_spec)]
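        # Reading the entry above: pattern nodes 1, 2 and 3 are the Q/K/V
        # weight params created in build(), and [[-1, 0], [-1, 1]] appears
        # to list their candidate shard specs -- keep the first weight dim
        # replicated (-1) and shard the second dim along mesh dim 0 or 1,
        # i.e. the usual column-parallel split of the Q/K/V projections.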


def convert_to_graph(ops, block):
    """Convert ops to graph."""
    graph = Graph()
    graph.attrs["var_to_id"] = {}  # {var_name: node_id}
    graph.attrs["id_to_var"] = {}  # {node_id: var_name}
    graph.attrs["op_to_id"] = {}  # {op_id: node_id}
    graph.attrs["id_to_op"] = {}  # {node_id: op_id}

    node_id = -1
    for op in ops:
        attrs = op.all_attrs()
        attrs["type"] = op.type
        node_id += 1

        # create op node
        op_node = graph.add_node(node_id, **attrs)
        graph.attrs["op_to_id"][op.desc.id()] = op_node.id
        graph.attrs["id_to_op"][op_node.id] = op.desc.id()
        graph._attr_to_nodes[op_node.id] = {}
        for input_name in op.input_names:
            graph._attr_to_nodes[op_node.id][input_name] = []
            for var_name in op.input(input_name):
                if var_name not in graph.attrs["var_to_id"]:
                    # create var node
                    node_id += 1
                    var_node = graph.add_node(node_id)
                    var = block._var_recursive(var_name)
                    if var.is_parameter:
                        var_node.attrs["type"] = "param"
                        var_node.attrs["dim"] = len(var.shape)
                    else:
                        var_node.attrs["type"] = "var"
                    graph.attrs["var_to_id"][var_name] = var_node.id
                    graph.attrs["id_to_var"][var_node.id] = var_name
                else:
                    var_node_id = graph.attrs["var_to_id"][var_name]
                    var_node = graph._nodes[var_node_id]

                # create edge from input var to op
                input_edge = graph.add_edge(var_node.id, op_node.id)
                input_edge.attrs["input_name"] = input_name
                graph._attr_to_nodes[op_node.id][input_name].append(var_node)

        for output_name in op.output_names:
            graph._attr_to_nodes[op_node.id][output_name] = []
            for var_name in op.output(output_name):
                if var_name not in graph.attrs["var_to_id"]:
                    # create var node
                    node_id += 1
                    var_node = graph.add_node(node_id)
                    var = block._var_recursive(var_name)
                    if var.is_parameter:
                        var_node.attrs["type"] = "param"
                    else:
                        var_node.attrs["type"] = "var"
                    graph.attrs["var_to_id"][var_name] = var_node.id
                    graph.attrs["id_to_var"][var_node.id] = var_name
                else:
                    var_node_id = graph.attrs["var_to_id"][var_name]
                    var_node = graph._nodes[var_node_id]

                # create edge from op to output var
                output_edge = graph.add_edge(op_node.id, var_node.id)
                output_edge.attrs["output_name"] = output_name
                graph._attr_to_nodes[op_node.id][output_name].append(var_node)

    return graph
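

# Rough shape of the result, assuming a block that holds a single matmul_v2
# op computing Out = X x Y: node 0 is the op node; nodes 1-3 are var nodes
# for X, Y and Out (type "param" for parameters, "var" otherwise); edges
# 1 -> 0 and 2 -> 0 carry input_name "X"/"Y", and edge 0 -> 3 carries
# output_name "Out".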


def match(pattern, graph):
    def _is_op_node(node):
        """Judge whether node is op node"""
        if node.attrs["type"] not in ["var", "param", "data"]:
            return True

        return False

    def _compare_op_node(src, tgt):
        """Compare whether two op nodes are equal"""
        if src.attrs["type"] != tgt.attrs["type"]:
            return False

        return True

    def _compare_var_node(src, tgt):
        """Compare whether two var nodes are equal"""
        for key in src.attrs:
            if key not in tgt.attrs:
                return False
            if src.attrs[key] != tgt.attrs[key]:
                return False

        return True

    def _match_core(src_node, tgt_node):
        nonlocal not_matched
        # do not support one input name or output name corresponding to multiple vars
        if not_matched:
            return
        if _is_op_node(src_node):
            # compare op node whether equal
            if not _compare_op_node(src_node, tgt_node):
                return

            result[src_node.id] = tgt_node.id

            # input var nodes
            src_input_nodes = src_reverse_adjs[src_node.id]
            for node in src_input_nodes:
                # has visited
                if node.id in result:
                    continue
                edge = src_edges[node.id][src_node.id]
                input_name = edge.attrs["input_name"]

                # NOTE: do not support one input name or output name corresponding to multiple vars
                compare_nodes = tgt_attr_to_nodes[tgt_node.id].get(
                    input_name, None
                )
                if not compare_nodes:
                    not_matched = True
                    return
                _match_core(node, compare_nodes[0])

            # output var nodes
            src_output_node_ids = src_edges[src_node.id].keys()
            for node_id in src_output_node_ids:
                # has visited
                if node_id in result:
                    continue
                node = src_nodes[node_id]
                edge = src_edges[src_node.id][node_id]
                output_name = edge.attrs["output_name"]

                # NOTE: do not support one input name or output name corresponding to multiple vars
                compare_nodes = tgt_attr_to_nodes[tgt_node.id].get(
                    output_name, None
                )
                if not compare_nodes:
                    not_matched = True
                    return
                _match_core(node, compare_nodes[0])
        else:
            # compare var node whether equal
            if not _compare_var_node(src_node, tgt_node):
                not_matched = True
                return

            result[src_node.id] = tgt_node.id

            # as input for op nodes
            src_as_input_node_ids = src_edges[src_node.id].keys()
            for node_id in src_as_input_node_ids:
                if node_id in result:
                    continue
                src_edge = src_edges[src_node.id][node_id]
                input_name = src_edge.attrs["input_name"]
                compare_node_ids = tgt_edges[tgt_node.id].keys()

                compare_node = None
                for compare_node_id in compare_node_ids:
                    edge = tgt_edges[tgt_node.id][compare_node_id]
                    if (
                        edge.attrs["input_name"] == input_name
                        and compare_node_id not in result.values()
                    ):
                        compare_node = tgt_nodes[compare_node_id]
                        break

                if not compare_node:
                    not_matched = True
                    return
                _match_core(src_nodes[node_id], compare_node)

            # as output for op nodes
            src_as_output_nodes = src_reverse_adjs[src_node.id]
            for node in src_as_output_nodes:
                if node.id in result:
                    continue

                src_edge = src_edges[node.id][src_node.id]
                output_name = src_edge.attrs["output_name"]

                compare_node_ids = tgt_reverse_adjs[tgt_node.id]

                compare_node = None
                for node_id in compare_node_ids:
                    edge = tgt_edges[node_id][tgt_node.id]
                    if edge.attrs["output_name"] == output_name:
                        compare_node = tgt_nodes[node_id]
                        break
                if not compare_node:
                    not_matched = True
                    return
                _match_core(src_nodes[node_id], compare_node)

    results = []
    result = {}
    has_matched = set()
    src_nodes = pattern.nodes
    src_edges = pattern._adjs
    src_reverse_adjs = pattern._reverse_adjs

    tgt_nodes = graph.nodes
    tgt_edges = graph._adjs
    tgt_reverse_adjs = graph._reverse_adjs
    tgt_attr_to_nodes = graph._attr_to_nodes
    not_matched = False

    # start matching from an op node
    src_start_node = None
    for node_id in src_nodes:
        node = src_nodes[node_id]
        if node.attrs["type"] not in ["var", "param", "data"]:
            src_start_node = node
            break
    assert src_start_node is not None

    for node_id in tgt_nodes:
        node = tgt_nodes[node_id]
        if node.attrs["type"] == src_start_node.attrs["type"]:
            _match_core(src_start_node, node)
            if not not_matched:
                need_to_append = True
                for value in result.values():
                    if value in has_matched:
                        result = {}
                        need_to_append = False
                        break
                if need_to_append:
                    results.append(result)
                    for value in result.values():
                        has_matched.add(value)
                    result = {}
            else:
                not_matched = False
                result = {}

    return results
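

# Typical use (a sketch): convert a program block to a graph, then run a
# registered pattern over it; each returned dict maps a pattern node id to
# the matched graph node id.
#
#     graph = convert_to_graph(block.ops, block)
#     matched = match(_PATTERNS["qkv"], graph)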


class OperatorClusteringUtil:
    common_starts = ["layer_norm", "matmul_v2", "matmul"]

    @staticmethod
    def get_ranks(seq):
        """Get rank array of the given seq by doubled algorithm."""
        ordered_seq = sorted(list(set(seq)))
        item_to_rank = {item: idx for idx, item in enumerate(ordered_seq)}
        inter_ranks = [item_to_rank[item] for item in seq]

        length = len(inter_ranks)
        power = 0
        interval = 2**power
        while interval < length:
            for idx, item in enumerate(inter_ranks):
                if idx + interval >= length:
                    inter_ranks[idx] = [item, -1]
                else:
                    inter_ranks[idx] = [item, inter_ranks[idx + interval]]

            tmp = []
            for item in inter_ranks:
                if item not in tmp:
                    tmp.append(item)
            tmp.sort(key=lambda x: (x[0], x[1]))
            item_to_rank = {}
            for idx, val in enumerate(tmp):
                key = ",".join(str(item) for item in val)
                item_to_rank[key] = idx

            inter_ranks = [
                item_to_rank[",".join(str(val) for val in item)]
                for item in inter_ranks
            ]
            power += 1
            interval = 2**power

        return inter_ranks
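
    # For example, get_ranks(["matmul", "relu", "matmul", "relu"]) returns
    # [1, 3, 0, 2]: the suffix starting at index 2 is the lexicographically
    # smallest, the one starting at index 1 the largest.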

    @staticmethod
    def get_suffixes(ranks):
        """Get suffix array by the given rank array."""
        suffixes = [0 for idx in range(len(ranks))]
        for idx, item in enumerate(ranks):
            suffixes[item] = idx
        return suffixes
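
    # E.g. get_suffixes([1, 3, 0, 2]) == [2, 0, 3, 1]: suffixes[r] is the
    # start index of the r-th smallest suffix.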

    @staticmethod
    def get_heights(suffixes, seq):
        """Get height array by the suffix array and seq"""
        heights = [-1 for i in range(len(suffixes))]
        for i in range(1, len(seq)):
            x = seq[suffixes[i - 1] :]
            y = seq[suffixes[i] :]
            max_len = len(x) if len(x) > len(y) else len(y)
            same_count = 0
            for j in range(max_len):
                if j >= len(x) or j >= len(y):
                    break
                else:
                    if x[j] == y[j]:
                        same_count += 1
                    else:
                        break
            heights[i] = same_count

        return heights
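
    # Continuing the example, for seq = ["matmul", "relu", "matmul", "relu"]
    # and suffixes [2, 0, 3, 1] the heights are [-1, 2, 0, 1]: heights[i] is
    # the longest common prefix of the suffixes ranked i - 1 and i.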

    @staticmethod
    def get_longest_repeated_sub_seq(suffixes, heights, seq):
        """Get longest repeated sub sequence by suffix array algorithm."""
        length = len(seq)
        if length <= 1:
            return None
        k = length // 2
        height_groups = []
        longest_sub_seq = None
        longest_sub_seqs = []

        while k >= 2:
            height_group = []
            for i in range(1, len(heights)):
                if heights[i] >= k:
                    if i == 1:
                        height_group.append(0)
                    height_group.append(i)
                else:
                    if i == 1:
                        height_groups.append([0])
                        height_group = [i]
                    else:
                        height_groups.append(height_group)
                        height_group = [i]

            if height_group:
                height_groups.append(height_group)

            for height_group in height_groups:
                suffix_group = []
                index_group = []
                for idx in height_group:
                    suffix_group.append(idx)
                    index_group.append(suffixes[idx])

                max_index = max(index_group)
                min_index = min(index_group)
                if max_index - min_index >= k:
                    longest_sub_seq = seq[min_index : min_index + k]
                    if (
                        longest_sub_seq[0]
                        in OperatorClusteringUtil.common_starts
                    ):
                        return longest_sub_seq
            if longest_sub_seq is not None:
                return longest_sub_seq

            k -= 1
            height_groups = []

        return longest_sub_seq
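
    # With the arrays from the example above, k starts at len(seq) // 2 = 2
    # and the height group [0, 1] spans suffix start indices {2, 0}, which
    # are at least k apart, so ["matmul", "relu"] is returned (it also
    # starts with an entry of common_starts).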

    @staticmethod
    def get_decomposed_sub_seq(seq):
        """Get decomposed sub seq s by seq S such as s * R = S."""
        if not seq:
            return seq

        decomposed_sub_seq = seq
        seq_len = len(seq)
        if seq_len == 1:
            return decomposed_sub_seq
        else:
            for interval in range(2, seq_len + 1):
                if seq_len % interval == 0:
                    repeated_times = seq_len // interval
                    decomposed_sub_seq = seq[0:interval]
                    decomposed = True
                    for j in range(1, repeated_times + 1):
                        sub_seq = seq[interval * (j - 1) : interval * j]
                        if sub_seq != decomposed_sub_seq:
                            decomposed = False
                            break
                    if decomposed:
                        return decomposed_sub_seq

        return decomposed_sub_seq
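
    # E.g. get_decomposed_sub_seq(["matmul", "relu", "matmul", "relu"])
    # returns ["matmul", "relu"], since repeating it twice rebuilds the
    # original seq.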

    @staticmethod
    def replace_by_decomposed_seq(sub_seq, seq):
        """Replace seq by sub seq."""
        if not sub_seq:
            return seq

        result = []
        sub_seq_len = len(sub_seq)
        i = 0
        while i < len(seq):
            if seq[i : i + sub_seq_len] == sub_seq:
                result.append(seq[i : i + sub_seq_len])
                i += sub_seq_len
            else:
                result.append(seq[i])
                i += 1

        return result
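
    # E.g. replace_by_decomposed_seq(["matmul", "relu"],
    # ["matmul", "relu", "matmul", "relu", "softmax"]) returns
    # [["matmul", "relu"], ["matmul", "relu"], "softmax"].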

    @staticmethod
    def stop_replace(seq):
        for item in seq:
            if not isinstance(item, list):
                return False
        return True


class ClusterPartitionUtil:
    @staticmethod
    def factorization(num):
        factors = []
        for i in range(1, int(math.floor(math.sqrt(num))) + 1):
            if num % i == 0:
                factors.append([i, int(num / i)])
        return factors
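
    # E.g. factorization(8) == [[1, 8], [2, 4]]: every [i, 8 // i] pair
    # with i <= sqrt(8).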

    @staticmethod
    def complete_meshes(partitions: list, num: int):
        # special cases
        if num == 2:
            return [[1, 2], [2, 1]]
        if num == 3:
            return [[1, 2], [2, 1], [1]]
        if len(partitions) == 1:
            partitions = ClusterPartitionUtil.factorization(num - 1)
            partitions.append([1])
        return partitions

    @staticmethod
    def partition_cluster(
        n: int,
        m: int,
        filter=[
            complete_meshes.__func__,
        ],
    ) -> list:
        """
        Partition cluster into possible device meshes.

        Args:
            n (int): The number of nodes.
            m (int): The number of single devices on each node.
            filter (list): Functions for filtering useful meshes.

        Returns:
            device_meshes (list): The possible device meshes.
        """
        partition_result = ClusterPartitionUtil.factorization(n)
        for func in filter:
            partition_result = func(partition_result, n)
        device_meshes = []
        if n == 1:
            partition_result = ClusterPartitionUtil.factorization(m)
            for partition in partition_result:
                device_mesh = []
                for i in range(partition[0]):
                    device_mesh.append([1, partition[1]])
                device_meshes.append(device_mesh)
        else:
            increment = 1 if partition_result[-1] == [1] else 0
            for partition in partition_result:
                if len(partition) < 2:
                    continue
                device_mesh = []
                for i in range(partition[0]):
                    device_mesh.append([partition[1], m])
                device_mesh[-1][0] += increment
                device_meshes.append(device_mesh)

        return device_meshes
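
    # E.g. partition_cluster(4, 8) returns [[[4, 8]], [[2, 8], [2, 8]]]:
    # one 4 x 8 mesh, or two 2 x 8 meshes.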


def convert_to_process_meshes(device_mesh: list) -> list:
    """
    Transfer device_meshes into possible process meshes.

    Args:
        device meshes (list): [n,m], one device mesh.

    Returns:
        process_meshes (list): Possible process_meshes
    """
    n, m = device_mesh[0], device_mesh[1]
    factors = (
        ClusterPartitionUtil.factorization(m)
        if n == 1
        else ClusterPartitionUtil.factorization(n)
    )
    process_meshes = []
    if n == 1:
        for factor in factors:
            if factor[0] == 1:
                process_meshes.append([factor[1]])
                continue
            process_meshes.append(factor)
    else:
        for factor in factors:
            mul1, mul2 = factor[0], factor[1]
            if mul1 == 1:
                process_meshes.append([m * mul2])
            elif mul1 != mul2:
                process_meshes.append([int(n / mul2), m * mul2])
            process_meshes.append([int(n / mul1), m * mul1])
    return process_meshes
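

# E.g. convert_to_process_meshes([4, 8]) returns [[32], [4, 8], [2, 16]].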


class RuleBasedTuner:
    def __init__(self, dist_context, mode="train"):
        self._dist_context = dist_context
        self._mode = mode

    def cluster_operators(self, ops):
        """
        Cluster operators to layers.

        Args:
            ops (list): An operator list.

        Returns:
            List: A list of operator lists, with each inner list holding
                the operators that belong to the same layer.
        """
        seq = [op.type for op in ops]

        while not OperatorClusteringUtil.stop_replace(seq):
            to_replace_seq = []
            to_replace_idxes = []
            has_append = False
            for idx, item in enumerate(seq):
                if not isinstance(item, list):
                    has_append = True
                    to_replace_seq.append(item)
                    to_replace_idxes.append(idx)
                elif isinstance(item, list) and not has_append:
                    continue
                elif isinstance(item, list) and has_append:
                    break

            ranks = OperatorClusteringUtil.get_ranks(to_replace_seq)
            suffixes = OperatorClusteringUtil.get_suffixes(ranks)
            heights = OperatorClusteringUtil.get_heights(
                suffixes, to_replace_seq
            )
            longest_sub_seq = (
                OperatorClusteringUtil.get_longest_repeated_sub_seq(
                    suffixes, heights, to_replace_seq
                )
            )
            has_merged = False
            if longest_sub_seq is None:
                for i in range(to_replace_idxes[-1] + 1, len(seq)):
                    if isinstance(seq[i], list):
                        seq[i] = to_replace_seq + seq[i]
                        has_merged = True
                        break
                if not has_merged:
                    for i in range(to_replace_idxes[0] - 1, -1, -1):
                        if isinstance(seq[i], list):
                            seq[i].extend(to_replace_seq)
                            has_merged = True
                            break
                if not has_merged:
                    seq = [to_replace_seq]
                    break

            decomposed_sub_seq = OperatorClusteringUtil.get_decomposed_sub_seq(
                longest_sub_seq
            )
            to_replace_seq = OperatorClusteringUtil.replace_by_decomposed_seq(
                decomposed_sub_seq, to_replace_seq
            )
            result = seq[: to_replace_idxes[0]]
            if not has_merged:
                result.extend(to_replace_seq)
            result.extend(seq[to_replace_idxes[-1] + 1 :])
            seq = result

        layers = []
        idx = 0
        for groups in seq:
            layer = []
            for op in groups:
                layer.append(ops[idx])
                idx += 1
            layers.append(layer)

        return layers
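

# A minimal sanity check of the pure-Python helpers above (a sketch; because
# of the relative Graph import, run it with `python -m ...` rather than
# directly):
if __name__ == "__main__":
    example_seq = ["matmul", "relu", "matmul", "relu", "softmax"]
    ranks = OperatorClusteringUtil.get_ranks(example_seq)
    suffixes = OperatorClusteringUtil.get_suffixes(ranks)
    heights = OperatorClusteringUtil.get_heights(suffixes, example_seq)
    # prints ['matmul', 'relu']
    print(
        OperatorClusteringUtil.get_longest_repeated_sub_seq(
            suffixes, heights, example_seq
        )
    )
    # prints [[[4, 8]], [[2, 8], [2, 8]]]
    print(ClusterPartitionUtil.partition_cluster(4, 8))
    # prints [[32], [4, 8], [2, 16]]
    print(convert_to_process_meshes([4, 8]))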