# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from functools import reduce

import paddle
from paddle.distributed.fleet.meta_optimizers.common import OpRole

from ..dist_tensor import DistributedTensor
from ..operators.common import get_distributed_operator_impl_container
from .base_cost import Cost


class CostEstimator:
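    """Estimate the time, memory and flops cost of a distributed program.

    Supported modes are "modeling" and "profiling".
    """
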
    _special_op_type = ["fused_attention", "fused_feedforward"]

    def __init__(
        self, program, cluster, mode="modeling", rank=None, loop_count=10
    ):
        self._program = program
        self._cluster = cluster
        self._check_mode(mode)
        self._mode = mode
        self._rank = rank if rank is not None else paddle.distributed.get_rank()
        self._loop_count = loop_count
        self._global_cost = Cost()
        self._local_cost_mapping = {}
        self._detailed_cost = (
            OrderedDict()
        )  # {op_id: {"reshard_cost": OrderedDict, "dist_op_cost": list}}
        self._bubble_time_mapping = {}
        self._ordered_ops = []
        self.max_memories = {}
        self.max_memory = None

    @property
    def loop_count(self):
        return self._loop_count

    @property
    def detailed_cost(self):
        return self._detailed_cost

    @property
    def program(self):
        return self._program

    @property
    def rank(self):
        return self._rank

    @property
    def dist_context(self):
        return self._dist_context

    @property
    def cluster(self):
        return self._cluster

    @property
    def mode(self):
        return self._mode

    @property
    def global_cost(self):
        max_time = 0
        memory = 0
        flops = 0
        for rank in self._local_cost_mapping:
            cost = self._local_cost_mapping[rank]
            if cost.time > max_time:
                max_time = cost.time
            memory += cost.memory
            flops += cost.flops
        self._global_cost.time = max_time
        self._global_cost.memory = memory
        self._global_cost.flops = flops
        return self._global_cost

    def local_cost(self, rank=None):
        rank = self.rank if rank is None else rank
        if rank not in self._local_cost_mapping:
            self._local_cost_mapping[rank] = Cost()

        return self._local_cost_mapping[rank]

    def local_bubble_time(self, rank=None):
        rank = self.rank if rank is None else rank
        return self._bubble_time_mapping[rank]

    def _check_mode(self, mode):
        if mode not in ["modeling", "profiling"]:
            raise ValueError(
                f"Only 'modeling' and 'profiling' modes are supported, but got {mode}"
            )

    def _is_special_var_name(self, var_name):
        special_var_name = ["lod_tensor_blocking_queue_0"]
        if var_name in special_var_name:
            return True
        return False

    def _estimate_core(self, dist_context, resharder, block):
        from ..reshard import get_var_with_recursion

        ops = block.ops
        loop_count = None
        if block.desc.id != self.program.global_block().desc.id:
            loop_count = self.loop_count
        else:
            loop_count = 1
        for i in range(loop_count):
            for op in ops:
                self._detailed_cost[op.desc.id()] = OrderedDict()
                # If the op is in a while sub-block, only the detail of the last iteration is kept
                detail = self._detailed_cost[op.desc.id()]
                detail["reshard_cost"] = OrderedDict()
                detail["dist_op_cost"] = []
                if int(op.attr('op_role')) == int(OpRole.Optimize):
                    continue
                if op.type in [
                    "create_py_reader",
                    "create_double_buffer_reader",
                    "read",
                ]:
                    continue

                # NOTE: Nested loops are not supported; only the while op is handled when an op has a sub-block.
                if op.type == "while":
                    while_block = self.program.blocks[op.attr("sub_block").id]
                    self._estimate_core(dist_context, resharder, while_block)
                    continue

                for var_name in op.input_arg_names:
                    if self._is_special_var_name(var_name):
                        continue
                    var = get_var_with_recursion(var_name, block, self.program)
                    reshard_cost = resharder.get_cost(op, var, self.cluster)

                    # Calc reshard cost
                    if reshard_cost is not None:
                        detail["reshard_cost"][var_name] = reshard_cost

                        comm_costs = reshard_cost[0]
                        local_comp_cost = reshard_cost[1]
                        for comm_cost in comm_costs:
                            # Time accumulates in both the global cost and each local cost, while memory and flops only accumulate in the global cost.
                            # Comm sync
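                            # Each rank in the group is advanced to the slowest rank's
                            # time plus the comm time; e.g. ranks at 10us and 6us both
                            # end at 10us + comm time, and the faster rank accrues 4us
                            # of bubble (idle) time.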
                            for item in comm_cost:
                                group_ranks, cost = item
                                max_time = None
                                cost_time = {}
                                for rank in group_ranks:
                                    rank_cost = self.local_cost(rank)
                                    cost_time[rank] = rank_cost.time
                                    if max_time is None:
                                        max_time = rank_cost.time
                                    else:
                                        if max_time < rank_cost.time:
                                            max_time = rank_cost.time

                                for rank in group_ranks:
                                    self.local_cost(rank).time = (
                                        max_time + cost.time
                                    )

                                    if rank not in self._bubble_time_mapping:
                                        self._bubble_time_mapping[rank] = 0

                                    self._bubble_time_mapping[rank] += (
                                        max_time - cost_time[rank]
                                    )

                        for rank in local_comp_cost:
                            for comp_cost in local_comp_cost[rank]:
                                self.local_cost(rank).time += comp_cost.time

                # Calc dist op cost
                dist_op = dist_context.get_dist_op_for_program(op)
                if not dist_op:
                    continue

                op_dist_attr = dist_op.dist_attr
                processes = op_dist_attr.process_mesh.process_ids

                container = get_distributed_operator_impl_container(
                    op_dist_attr.impl_type
                )
                dist_impl = container.impls[op_dist_attr.impl_idx]

                dist_op_cost = dist_impl.calc_cost(
                    op.attr('op_role'), dist_op, dist_context, self.cluster
                )
                detail["dist_op_cost"] = dist_op_cost

                if dist_op_cost is None:
                    assert (
                        dist_op.serial_op.type in CostEstimator._special_op_type
                    )
                    continue
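                # Each entry of dist_op_cost is either a list of communication
                # op costs (synchronized across each group's ranks) or a dict
                # mapping rank to that rank's computation cost.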
                for item in dist_op_cost:
                    if isinstance(item, list):
                        # Comm sync
                        for comm_op_cost in item:
                            max_time = None
                            cost_time = {}
                            group_ranks = comm_op_cost.group_ranks
                            for rank in comm_op_cost.group_ranks:
                                rank_cost = self.local_cost(rank)
                                cost_time[rank] = rank_cost.time
                                if max_time is None:
                                    max_time = rank_cost.time
                                else:
                                    if max_time < rank_cost.time:
                                        max_time = rank_cost.time
                            for rank in group_ranks:
                                self.local_cost(rank).time = (
                                    max_time + comm_op_cost.time
                                    if op.attr('op_role') != OpRole.Backward
                                    else max_time + 0.9 * comm_op_cost.time
                                )
                                if rank not in self._bubble_time_mapping:
                                    self._bubble_time_mapping[rank] = 0
                                self._bubble_time_mapping[rank] += (
                                    max_time - cost_time[rank]
                                )
                    elif isinstance(item, dict):
                        # A single op: add its cost to each rank that executes it
                        for rank in processes:
                            # Under hybrid DP+PP+MP parallelism, a rank may have no cost entry for this op
                            if rank not in item:
                                continue
                            self.local_cost(rank).time += item[rank].time

    def prepare(self):
        self._global_cost = Cost()
        self._local_cost_mapping = {}
        self._detailed_cost = OrderedDict()
        self._bubble_time_mapping = {}

    def _calculate_bytes(self, sizes, dtype):
        if sizes:
            total_count = reduce(lambda x, y: x * y, sizes, 1)
        else:
            total_count = 0

        if dtype == paddle.float64 or dtype == paddle.int64:
            dtype_factor = 8
        elif dtype == paddle.float32 or dtype == paddle.int32:
            dtype_factor = 4
        elif (
            dtype == paddle.float16
            or dtype == paddle.bfloat16
            or dtype == paddle.int16
        ):
            dtype_factor = 2
        elif dtype == paddle.int8 or dtype == paddle.uint8:
            dtype_factor = 1
        else:
            dtype_factor = 8

        memory = total_count * dtype_factor
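        # e.g. sizes=[8, 1024, 1024] with paddle.float16 gives
        # 8 * 1024 * 1024 * 2 = 16,777,216 bytes (~16 MiB).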
        return memory

    def _estimate_max_memory_by_dist_op(self, dist_context):
        # This estimation will be improved; reshard and inplace ops are not considered yet.
        # Persistable vars are never freed.
        def _convert_pm_and_dm_to_str(process_mesh, dims_mapping):
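            # The key is a plain concatenation used only for uniqueness, e.g. a
            # 2x2 mesh of ranks [0, 1, 2, 3] with dims_mapping [0, -1] yields
            # "0,1,2,32,20,-1".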
            processes = ",".join([str(x) for x in process_mesh.process_ids])
            topology = ",".join([str(x) for x in process_mesh.shape])
            dims_mapping = ",".join([str(x) for x in dims_mapping])
            result = processes + topology + dims_mapping
            return result

        memories = {}
        self.max_memories = {}
        # {var_name: {pm_dm_key: {"position": [op_id, ...], "memory": bytes}}}
        var_info = {}

        for block in self.program.blocks:
            for op in block.ops:
                self._ordered_ops.append([op.desc.id(), op])
        self._ordered_ops.sort(key=lambda x: x[0])

        parameters = set()
        for op_id, op in self._ordered_ops:
            if op.type in [
                "create_py_reader",
                "create_double_buffer_reader",
                "read",
            ]:
                continue
            dist_op = dist_context.get_dist_op_for_program(op)
            if not dist_op:
                continue
            process_mesh = dist_op.dist_attr.process_mesh
            for var_name in op.input_arg_names:
                input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
                    var_name
                )

                if var_name not in var_info:
                    var_info[var_name] = {}
                key = _convert_pm_and_dm_to_str(
                    process_mesh, input_dims_mapping
                )
                if key not in var_info[var_name]:
                    var_info[var_name][key] = {}
                # Partitioning is assumed to be even for now
                if "position" not in var_info[var_name][key]:
                    var_info[var_name][key]["position"] = []
                var_info[var_name][key]["position"].append(op_id)

                if "memory" not in var_info[var_name][key]:
                    var = dist_op.get_serial_input(var_name)
                    global_sizes = var.shape
                    dtype = var.dtype
                    sizes = DistributedTensor.get_local_sizes(
                        global_sizes,
                        input_dims_mapping,
                        process_mesh.shape,
                        process_mesh.process_ids,
                    )
                    var_info[var_name][key]["memory"] = self._calculate_bytes(
                        sizes, dtype
                    )
                    if var.persistable:
                        name = var_name + key
                        if name not in parameters:
                            parameters.add(name)
                            for process in process_mesh.process_ids:
                                if process not in memories:
                                    memories[process] = 0
                                memories[process] += var_info[var_name][key][
                                    "memory"
                                ]

            for var_name in op.output_arg_names:
                output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(
                    var_name
                )
                if var_name not in var_info:
                    var_info[var_name] = {}
                key = _convert_pm_and_dm_to_str(
                    process_mesh, output_dims_mapping
                )
                if key not in var_info[var_name]:
                    var_info[var_name][key] = {}
                if "position" not in var_info[var_name][key]:
                    var_info[var_name][key]["position"] = []
                var_info[var_name][key]["position"].append(op_id)

                if "memory" not in var_info[var_name][key]:
                    var = dist_op.get_serial_output(var_name)
                    global_sizes = var.shape
                    dtype = var.dtype
                    sizes = DistributedTensor.get_local_sizes(
                        global_sizes,
                        output_dims_mapping,
                        process_mesh.shape,
                        process_mesh.process_ids,
                    )
                    var_info[var_name][key]["memory"] = self._calculate_bytes(
                        sizes, dtype
                    )
                    if var.persistable:
                        name = var_name + key
                        if name not in parameters:
                            parameters.add(name)
                            for process in process_mesh.process_ids:
                                if process not in memories:
                                    memories[process] = 0
                                memories[process] += var_info[var_name][key][
                                    "memory"
                                ]

        has_used_vars = set()
        not_calc_vars = set()
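        # Second pass: walk the ops in order, adding a var's memory at its
        # first use and freeing it after the op recorded as its last use;
        # the running per-process totals yield the peak memory estimate.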
        for op_id, op in self._ordered_ops:
            if op.type in [
                "create_py_reader",
                "create_double_buffer_reader",
                "read",
            ]:
                continue
            can_free_memories = {}
            can_free_vars = set()
            dist_op = dist_context.get_dist_op_for_program(op)
            if not dist_op:
                continue
            process_mesh = dist_op.dist_attr.process_mesh
            for var_name in op.input_arg_names:
                input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
                    var_name
                )
                key = _convert_pm_and_dm_to_str(
                    process_mesh, input_dims_mapping
                )
                has_used_var = var_name + key
                var = dist_op.get_serial_input(var_name)
                # First use of this var: count its memory on every rank in the mesh
                if (
                    has_used_var not in has_used_vars
                    and has_used_var not in parameters
                ):
                    if has_used_var in not_calc_vars:
                        continue
                    has_used_vars.add(has_used_var)
                    for process in process_mesh.process_ids:
                        if process not in memories:
                            memories[process] = 0
                        memories[process] += var_info[var_name][key]["memory"]
                # Last use of this var: its memory can be freed after this op
                if op_id == var_info[var_name][key]["position"][-1]:
                    if (
                        has_used_var not in can_free_vars
                        and not var.persistable
                    ):
                        can_free_vars.add(has_used_var)
                        for process in process_mesh.process_ids:
                            if process not in can_free_memories:
                                can_free_memories[process] = 0
                            can_free_memories[process] += var_info[var_name][
                                key
                            ]["memory"]

            for var_name in op.output_arg_names:
                output_dims_mapping = dist_op.dist_attr.get_output_dims_mapping(
                    var_name
                )
                key = _convert_pm_and_dm_to_str(
                    process_mesh, output_dims_mapping
                )
                has_used_var = var_name + key
                var = dist_op.get_serial_output(var_name)
                if (
                    op.type == "reshape2"
                    or op.type == "transpose2"
                    or op.type == "elementwise_add"
                ):
                    not_calc_vars.add(has_used_var)
                    continue
                # First use of this var: count its memory on every rank in the mesh
                if (
                    has_used_var not in has_used_vars
                    and has_used_var not in parameters
                ):
                    has_used_vars.add(has_used_var)
                    for process in process_mesh.process_ids:
                        if process not in memories:
                            memories[process] = 0
                        memories[process] += var_info[var_name][key]["memory"]
                # Last use of this var: its memory can be freed after this op
                if op_id == var_info[var_name][key]["position"][-1]:
                    if (
                        has_used_var not in can_free_vars
                        and not var.persistable
                    ):
                        can_free_vars.add(has_used_var)
                        for process in process_mesh.process_ids:
                            if process not in can_free_memories:
                                can_free_memories[process] = 0
                            can_free_memories[process] += var_info[var_name][
                                key
                            ]["memory"]

            # Calc peak memory
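            # The peak is the largest live-memory total seen on each process so far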
            for process in memories:
                if process not in self.max_memories:
                    self.max_memories[process] = memories[process]
                else:
                    if memories[process] > self.max_memories[process]:
                        self.max_memories[process] = memories[process]
            # Free memory
            for process in can_free_memories:
                if process in memories:
                    memories[process] -= can_free_memories[process]

        # Calculate the max memory in all ranks
        max_memory = max(self.max_memories.values())
        self.max_memory = max_memory

        return max_memory

    def estimate(self, dist_context, resharder=None):
        self.prepare()
        from ..reshard import Resharder

        resharder = (
            Resharder(self.program, None, self.rank, dist_context, [])
            if resharder is None
            else resharder
        )

        block = self.program.global_block()
        self._estimate_core(dist_context, resharder, block)

        return self.global_cost

    def _print_tag(self, max_len, length):
        tag = "+" + "-" * max_len
        for i in range(length):
            print(tag, end="")
            if i == length - 1:
                print("+")

    def _print_vals(self, vals, max_len):
        for idx, val in enumerate(vals):
            s = "|" + str(val).center(max_len)
            print(s, end="")
            if idx == len(vals) - 1:
                print("|")

    def _pretty_print_memory_cost(self):
        """Print memory of every rank prettily."""
        if not self.max_memories or not self.max_memory:
            raise ValueError("Please calculate memory cost before printing.")

        # Padding automatically
        max_len = 0
        header = ["Rank", "Memory(MiB)"]
        memories = [
            int(item // 1e6) for item in list(self.max_memories.values())
        ]
        for memory in memories + header:
            if len(str(memory)) > max_len:
                max_len = len(str(memory))
        max_len += 4  # for pretty print of center

        # Print tag
        self._print_tag(max_len, len(header))

        # Print header
        self._print_vals(header, max_len)

        # Print tag
        self._print_tag(max_len, len(header))

        # Print rank and its memory
        for i in range(len(self.max_memories)):
            memory = memories[i]
            vals = [i, memory]
            self._print_vals(vals, max_len)
            self._print_tag(max_len, len(header))

    def _pretty_print_global(self):
        """Print global execution time and max memory prettily."""
        if not self.max_memories or not self.max_memory:
            raise ValueError("Please calculate cost before printing.")

        # Padding automatically
        max_len = 0
        header = ["Execution Time(us)", "Max Memory(MiB)"]
        vals = [round(self.global_cost.time, 3), int(self.max_memory // 1e6)]
        for memory in vals + header:
            if len(str(memory)) > max_len:
                max_len = len(str(memory))
        max_len += 4  # for pretty print of center

        # Print tag
        self._print_tag(max_len, len(header))

        # Print header
        self._print_vals(header, max_len)

        # Print tag
        self._print_tag(max_len, len(header))

        # Print exec time and max memory
        self._print_vals(vals, max_len)

        # Print tag
        self._print_tag(max_len, len(header))

    def pretty_print_cost(self):
        """Print cost prettily."""
        print("The global execution time and max memory are as follows:")
        self._pretty_print_global()
        print("The memory of every rank is as follows:")
        self._pretty_print_memory_cost()


def get_cost_from_engine(engine, mode):
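    """Build a CostEstimator for ``mode`` from an auto-parallel engine and
    return the estimated (global_cost, max_memory)."""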
    import copy

    from ..utils import to_list

    # Construct the cost estimator from the original main program
    serial_main_prog = (
        engine._fwd_main_progs[mode].clone()
        if mode in engine._fwd_main_progs
        else engine._orig_main_prog.clone()
    )

    serial_startup_prog = (
        engine._fwd_dist_contexts[mode]._original_serial_main_program.clone()
        if mode in engine._fwd_dist_contexts
        else engine._orig_startup_prog.clone()
    )
    losses = (
        to_list(engine._loss)
        if (
            not isinstance(engine._loss, paddle.nn.Layer)
            and not callable(engine._loss)
        )
        else engine._losses
    )
    serial_optimizer = copy.deepcopy(engine._orig_optimizer)
    if mode in engine._fwd_dist_contexts:
        dist_context = copy.deepcopy(engine._fwd_dist_contexts[mode])
    else:
        from ..dist_context import DistributedContext

        dist_context = DistributedContext(
            serial_main_prog,
            serial_startup_prog,
            serial_optimizer,
            losses,
            {},
            {"loss": losses},
            engine._cluster,
            engine._strategy,
        )
    from ..completion import Completer

    completer = Completer(dist_context)
    completer.complete_forward_annotation()
    dist_context.block_state.parse_forward_blocks(
        dist_context.serial_main_program
    )

    if mode == "eval" or mode == "predict":
        cost_estimator = CostEstimator(serial_main_prog, engine._cluster)
    elif mode == "train":
        from ..parallelizer_v2 import Parallelizer

        # Get serial main program with backward
        parallelizer = Parallelizer(mode, completer, dist_context)
        # Generate backward
        loss_name = dist_context.serial_loss.name
        serial_loss = serial_main_prog.global_block()._var_recursive(loss_name)
        params_grads = parallelizer._generate_backward(
            serial_main_prog, serial_startup_prog, serial_loss
        )

        # Generate optimizer
        optimizer_ops = parallelizer._generate_optimizer(
            serial_main_prog,
            serial_startup_prog,
            serial_optimizer,
            params_grads,
        )
        cost_estimator = CostEstimator(serial_main_prog, engine._cluster)

    # Estimate global_cost and max memory
    global_cost = cost_estimator.estimate(dist_context)
    max_memory = cost_estimator._estimate_max_memory_by_dist_op(dist_context)

    # Print the cost
    cost_estimator.pretty_print_cost()

    return global_cost, max_memory
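
# Illustrative usage (assumes an auto-parallel Engine that has already been
# prepared for the given mode; not executed here):
#
#     global_cost, max_memory = get_cost_from_engine(engine, "train")
#     print(global_cost.time, max_memory)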