optimization_tuner.py 21.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
# import yaml
16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
import os
import sys
import copy
import shlex
import pathlib
import time
import shutil
import pickle
import json
import logging
import subprocess

import paddle
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.distributed.passes import new_pass, PassContext

33
from paddle.distributed.auto_parallel.dist_context import DistributedContext
34 35 36
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.reshard import Resharder
from paddle.distributed.auto_parallel.partitioner import Partitioner
37 38 39 40
from paddle.distributed.auto_parallel.process_group import (
    clear_all_process_groups,
    get_all_process_groups,
)
41
from paddle.distributed.auto_parallel.utils import debug_program
42
from paddle.distributed.auto_parallel.utils import set_grad_var_shape
43

44
from ..utils import get_logger
45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
from .config import TuningConfig
from .algorithms import new_algorithm
from .trial import TrialStatus


def _get_new_params_grads(target_program, ref_program, ref_params_grads):
    ref_block = ref_program.global_block()
    target_block = target_program.global_block()
    target_params_grads = []

    for p, g in ref_params_grads:
        # NOTE grad var might not be generated
        assert ref_block.has_var(p.name)
        assert target_block.has_var(p.name)
        new_p = target_block.var(p.name)
        if g:
            new_g = target_block.var(g.name)
        else:
            new_g = None

        target_params_grads.append((new_p, new_g))

    return target_params_grads


def _get_new_loss(target_program, ref_program, loss):
    ref_block = ref_program.global_block()
    target_block = target_program.global_block()
    assert ref_block.has_var(loss.name)

    return target_block.var(loss.name)


def parse_process_groups():
    """
    Return ``{group.id: group.ranks}`` for every group returned by
    ``get_all_process_groups()``.
    """
    return {group.id: group.ranks for group in get_all_process_groups()}


def get_metric(results):
    """
    Return the throughput stored in a trial's ``results`` dict.

    Args:
        results (dict): trial results. The throughput is read from the
            ``'Throughtput'`` key; that spelling is what the rest of this
            module writes and reads, so it is kept as-is.

    Returns:
        float: the throughput when it is a real number (``int`` or
        ``float``, but not ``bool``), otherwise ``-1.0``.

    Raises:
        AssertionError: if ``results`` is not a dict.
    """
    assert isinstance(
        results, dict
    ), "results should be type of dictionary, but got {}.".format(type(results))
    throughput = results.get('Throughtput')
    # JSON decodes an integral number as ``int``; accept it as well, but
    # reject ``bool`` (a subclass of ``int``) and any non-numeric value.
    if isinstance(throughput, (int, float)) and not isinstance(
        throughput, bool
    ):
        return float(throughput)
    return -1.0


def parse_results(results):
    """
    Format a trial's ``results`` dict as a one-line, human-readable status.

    A positive numeric ``'Throughtput'`` is reported as throughput. Anything
    else (missing, non-positive or non-numeric) is reported as a failure:
    an OOM when ``'ErrorType'`` is ``"ResourceExhaustedError"``, otherwise an
    unknown error.

    Args:
        results (dict): trial results.

    Returns:
        str: the status message.
    """
    throughput = results.get('Throughtput', -1)
    # Guard the comparison: a missing key or a non-numeric value previously
    # raised KeyError / TypeError instead of being reported as a failure.
    is_number = isinstance(throughput, (int, float)) and not isinstance(
        throughput, bool
    )
    if is_number and throughput > 0:
        return "Throughtput: {} step / s.".format(throughput)
    if results.get("ErrorType") == "ResourceExhaustedError":
        return "Fail with OOM"
    return "Fail with UNKNOWN ERROR"


# TODO: depend only on the dist context; everything needed to start a new
# pass should be a member of the dist context.
def _copy_context(ref_dist_context):
    """
    Create an independent copy of ``ref_dist_context`` for one tuning run.

    NOTE: this first clears all global process groups
    (``clear_all_process_groups``).

    The serial main/startup programs are cloned. Every variable the context
    references (params/grads, loss, feed vars, fetch vars) is then looked up
    again by name in the cloned main program, so the copy does not point into
    the reference programs. The optimizer, distributed tensor/op attributes,
    dist op context and block state are deep-copied, and the process meshes
    are re-registered on the new context.

    Args:
        ref_dist_context: the DistributedContext to copy; its attributes are
            only read.

    Returns:
        DistributedContext: the new, independent context.
    """
    clear_all_process_groups()

    new_dist_context = DistributedContext()
    new_dist_context._serial_main_program = (
        ref_dist_context.serial_main_program.clone(for_test=False)
    )
    new_dist_context._serial_startup_program = (
        ref_dist_context.serial_startup_program.clone(for_test=False)
    )

    def _remap_var(var):
        # Find the variable with the same name in the block with the same
        # index of the cloned main program.
        block = new_dist_context._serial_main_program.blocks[var.block.idx]
        return block._var_recursive(var.name)

    # mapping variable into new dist context
    if getattr(ref_dist_context, '_params_grads', None):
        new_dist_context._params_grads = _get_new_params_grads(
            new_dist_context.serial_main_program,
            ref_dist_context.serial_main_program,
            ref_dist_context._params_grads,
        )
    new_dist_context._serial_loss = _get_new_loss(
        new_dist_context.serial_main_program,
        ref_dist_context.serial_main_program,
        ref_dist_context.serial_loss,
    )

    for key, var_list in ref_dist_context._serial_feed_vars.items():
        new_dist_context._serial_feed_vars[key] = [
            _remap_var(var) for var in var_list
        ]

    for key, var_list in ref_dist_context._serial_fetch_vars.items():
        if key == "metrics":
            # metrics is a list of list
            new_var_list = [
                [_remap_var(var) for var in inner_var_list]
                for inner_var_list in var_list
            ]
        else:
            new_var_list = [_remap_var(var) for var in var_list]
        new_dist_context._serial_fetch_vars[key] = new_var_list

    # copy information in forward and backward
    new_dist_context._serial_optimizer = copy.deepcopy(
        ref_dist_context.serial_optimizer
    )
    new_dist_context._dist_tensors_for_program = copy.deepcopy(
        ref_dist_context._dist_tensors_for_program
    )
    new_dist_context._dist_ops_for_program = copy.deepcopy(
        ref_dist_context._dist_ops_for_program
    )
    for pm in ref_dist_context.process_meshes:
        new_dist_context.add_process_mesh(pm)
    new_dist_context._dist_op_context = copy.deepcopy(
        ref_dist_context._dist_op_context
    )
    new_dist_context._block_state = copy.deepcopy(ref_dist_context.block_state)

    return new_dist_context


class OptimizationTuner:
    """
    OptimizationTuner is used to manage the tuning procedure of hyper-parameters (configs)
    of Optimization Pass in AutoParallel.

    Each trial proposed by the tuning algorithm is applied to a fresh copy of
    the baseline distributed context (``_apply_optimization``). In PROFILE
    mode it is then profiled in a child process (``_launch_profile``), and the
    best trial is selected by the throughput it reports.
    """

    def __init__(
        self,
        user_configs,
        dist_context,
        dataset,
        inputs_spec,
        labels_spec,
        batch_size,
        rank,
    ):
        self._config = TuningConfig(user_configs, dist_context._strategy)
        # should not modify dist context from calling function
        self._baseline_dist_context = _copy_context(dist_context)
        self._baseline_completer = Completer(self._baseline_dist_context)

        self._rank = rank
        self._inputs_spec = inputs_spec
        self._labels_spec = labels_spec
        self._dataset = dataset
        self._batch_size = batch_size

        self._finished_trials = []
        self._best_metric = None
        # Index into self._finished_trials of the best trial so far; -1 until
        # the first trial has been recorded by _update().
        self._best_iter = -1

        self._logger = get_logger(logging.INFO)

        self._build_programs_without_optimization()
        self._select_tuning_algorithm()

    @property
    def project_dir(self):
        """
        Directory for tuning artifacts; rank 0 creates it if it is missing.
        """
        dirname = self._config.project_dir
        if not os.path.exists(dirname) and self.rank == 0:
            pathlib.Path(dirname).mkdir(parents=True, exist_ok=True)
        return dirname

    @property
    def rank(self):
        """
        Rank of the current process, as passed to the constructor.
        """
        return self._rank

    @property
    def device_id(self):
        """
        Device id reported by ``paddle.distributed.ParallelEnv()``.
        """
        return paddle.distributed.ParallelEnv().device_id

    # TODO Generate complete program with all parts like forward, backward, update
    # as well as parallelism transformation.
    def _build_programs_without_optimization(self):
        """
        Append the backward pass to the baseline serial program and annotate it.

        The resulting ``params_grads`` are stored on the baseline dist context,
        which every trial later copies. When ``verbose`` is set, the baseline
        main/startup programs are dumped under ``<project_dir>/baseline``.
        """
        serial_main_program = self._baseline_dist_context.serial_main_program
        serial_startup_program = (
            self._baseline_dist_context.serial_startup_program
        )
        serial_loss = self._baseline_dist_context.serial_loss

        with program_guard(serial_main_program, serial_startup_program):
            params_grads = append_backward(
                serial_loss,
                distop_context=self._baseline_dist_context.dist_op_context,
            )

        self._baseline_completer.complete_backward_annotation(
            serial_main_program
        )
        self._baseline_dist_context.block_state.parse_backward_blocks(
            serial_main_program
        )
        self._baseline_dist_context._params_grads = params_grads

        if self._config.verbose:
            baseline_dir = os.path.join(self.project_dir, "baseline")
            if not os.path.exists(baseline_dir):
                pathlib.Path(baseline_dir).mkdir(parents=True, exist_ok=True)
            debug_program(
                self._baseline_dist_context._serial_main_program,
                baseline_dir,
                "main",
            )
            debug_program(
                self._baseline_dist_context._serial_startup_program,
                baseline_dir,
                "startup",
            )

    def _select_tuning_algorithm(self):
        """
        Instantiate the algorithm named after the sorted tuning pass names.
        """
        selected_passes_set = self._config.tuning_passes_name
        algorithm_name = "_".join(sorted(selected_passes_set))
        self._algorithm = new_algorithm(algorithm_name, self._config)

    def _apply_optimization(self, trial):
        """
        Build the distributed programs for ``trial`` and attach them to it.

        The passes enabled in ``trial.space`` are applied to a fresh copy of
        the baseline dist context in this order:
        1. AMP/FP16 and recompute on the serial programs;
        2. logical partition, optimizer generation and reshard;
        3. sharding and gradient merge on the distributed programs.

        Args:
            trial: the trial to apply; ``trial.space`` is the strategy.

        Returns:
            The same ``trial`` with ``main_program`` / ``startup_program`` set
            to the resulting distributed programs.
        """
        new_strategy = trial.space
        dist_context = _copy_context(self._baseline_dist_context)
        pass_context = PassContext()
        completer = Completer(dist_context)

        main_program = dist_context.serial_main_program
        startup_program = dist_context.serial_startup_program

        # applying optimization pass
        if new_strategy.amp.enable:
            config = copy.deepcopy(new_strategy.amp.to_dict())
            config["dist_context"] = dist_context
            config["params_grads"] = dist_context._params_grads

            # TODO AMP Pass should not use loss var
            config["loss"] = dist_context.serial_loss
            # NOTE(review): input_data is taken from the baseline context, not
            # from the copied one; confirm the pass only relies on var names.
            config["input_data"] = (
                self._baseline_dist_context.serial_feed_vars["inputs"]
                + self._baseline_dist_context.serial_feed_vars["labels"]
            )
            if config["use_pure_fp16"]:
                config["base_opt"] = dist_context.serial_optimizer
                amp_pass = new_pass("auto_parallel_fp16", config)
            else:
                amp_pass = new_pass("auto_parallel_amp", config)
            amp_pass.apply([main_program], [startup_program], pass_context)
            # The pass may produce a new loss var. Assign the private
            # attribute, as _copy_context does; ``serial_loss`` is presumably
            # a getter-only property, so assigning it could raise
            # AttributeError (TODO confirm).
            dist_context._serial_loss = amp_pass.get_loss()

        if new_strategy.recompute.enable:
            config = copy.deepcopy(new_strategy.recompute.to_dict())
            config["dist_context"] = dist_context
            config["no_grad_set"] = None
            config["loss"] = dist_context.serial_loss
            auto_parallel_recompute_pass = new_pass(
                "auto_parallel_recompute", config
            )
            auto_parallel_recompute_pass.apply(
                [main_program], [startup_program], pass_context
            )

        # Do logical partition
        partitioner = Partitioner(dist_context, self.rank)
        (
            dist_main_prog,
            dist_startup_prog,
            dist_params_grads,
        ) = partitioner.partition(
            main_program, startup_program, dist_context._params_grads
        )

        # Generate optimizer
        # FIXME should be remove from apply pass after pass support optimizers
        with program_guard(dist_main_prog, dist_startup_prog):
            dist_context.serial_optimizer.apply_gradients(dist_params_grads)
        completer.complete_update_annotation(dist_main_prog)

        # Do reshard process
        set_grad_var_shape(dist_main_prog, dist_context)
        resharder = Resharder(
            dist_main_prog,
            dist_startup_prog,
            self.rank,
            dist_context,
            dist_params_grads,
        )
        resharder.reshard()

        if new_strategy.sharding.enable:
            config = copy.deepcopy(new_strategy.sharding.to_dict())
            config["dist_context"] = dist_context
            config["params_grads"] = dist_params_grads
            config["global_rank"] = self.rank
            auto_parallel_sharding_pass = new_pass(
                "auto_parallel_sharding", config
            )
            auto_parallel_sharding_pass.apply(
                [dist_main_prog], [dist_startup_prog], pass_context
            )

        if new_strategy.gradient_merge.enable:
            config = copy.deepcopy(new_strategy.gradient_merge.to_dict())
            config["dist_context"] = dist_context
            config["params_grads"] = dist_params_grads
            auto_parallel_gradient_merge_pass = new_pass(
                "auto_parallel_gradient_merge_pass", config
            )
            auto_parallel_gradient_merge_pass.apply(
                [dist_main_prog], [dist_startup_prog], pass_context
            )

        trial.main_program, trial.startup_program = (
            dist_main_prog,
            dist_startup_prog,
        )
        return trial

    def _get_profile_context(self, trial, result_path):
        """
        Collect everything the profiler child process needs for ``trial``.

        Side effect: sets ``batch_size`` and ``input_names`` on the dataset.

        Returns:
            dict: the profile context that ``_profile_trial`` pickles.
        """
        profile_ctx = {}

        profile_ctx['distributed_env'] = copy.deepcopy(
            paddle.distributed.ParallelEnv()
        )
        profile_ctx['group_map'] = parse_process_groups()
        profile_ctx[
            "loss_var_name"
        ] = self._baseline_dist_context.serial_loss.name
        profile_ctx[
            "main_program_decs"
        ] = trial.main_program.desc.serialize_to_string()
        profile_ctx[
            "startup_program_decs"
        ] = trial.startup_program.desc.serialize_to_string()

        self._dataset.batch_size = self._batch_size
        self._dataset.input_names = self._get_input_names()

        profile_ctx["dataset"] = self._dataset
        profile_ctx["result_filename"] = result_path

        return profile_ctx

    def _get_input_names(self):
        """
        Return the names of the input specs followed by the label specs.
        """
        return [spec.name for spec in list(self._inputs_spec) + list(self._labels_spec)]

    def _launch_profile(self, ctx_path, trial_dir):
        """
        Run the profiler for one trial in a child Python process and wait for it.

        The child's stdout/stderr are written to ``stdout.log<rank>`` /
        ``stderr.log<rank>`` in ``trial_dir``. When ``WITH_COVERAGE=ON`` the
        child runs under ``coverage``.
        """
        if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
            coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
        else:
            coverage_args = []

        # Build argv directly as a list. Joining it into one string and
        # shlex-splitting it again broke paths containing whitespace or quotes.
        profile_args = [
            "--rank",
            str(self.rank),
            "--device_id",
            str(self.device_id),
            "--ctx_filename",
            ctx_path,
            "--profile_start_step",
            str(self._config.profile_start_step),
            "--profile_end_step",
            str(self._config.profile_end_step),
        ]
        cmd = (
            [sys.executable, "-u"]
            + coverage_args
            + ["-m", "paddle.distributed.auto_parallel.tuner.profiler"]
            + profile_args
        )

        # env flags need for profile; values already present in the parent
        # environment take precedence over these defaults.
        new_env = {
            "FLAGS_USE_STANDALONE_EXECUTOR": "False",
        }
        new_env.update(os.environ)

        # TODO if any rank hang or fail, kill all processes
        self._logger.debug("Executing cmd:\n{} .".format(" ".join(cmd)))
        stdout_path = os.path.join(trial_dir, "stdout.log" + str(self.rank))
        stderr_path = os.path.join(trial_dir, "stderr.log" + str(self.rank))
        with open(stdout_path, "wb") as out, open(stderr_path, "wb") as err:
            process = subprocess.Popen(cmd, stdout=out, stderr=err, env=new_env)
            process.wait()
            # Persist the logs to disk before they are inspected.
            out.flush()
            err.flush()
            os.fsync(out)
            os.fsync(err)

    def _profile_trial(self, trial):
        """
        Profile ``trial`` in a child process and return its results dict.

        Rank 0 creates the trial directory and the other ranks wait for it.
        The profile context is pickled to ``profile_ctx.<rank>`` and the
        profiler is launched. Results are then read from ``result.json``.

        Returns:
            dict: the decoded results, or
            ``{"Throughtput": -1, "ErrorType": 'FatalError'}`` when the result
            file is missing or is not valid JSON.
        """
        # Making working directory
        trial_dir = self._get_trial_dir(trial)
        if not os.path.exists(trial_dir):
            if self.rank == 0:
                pathlib.Path(trial_dir).mkdir(parents=True, exist_ok=True)
            else:
                # Wait for rank 0 to create the directory. Sleep between
                # checks instead of spinning so this does not burn a CPU core.
                while not os.path.exists(trial_dir):
                    time.sleep(0.1)
        ctx_filename = "profile_ctx." + str(self.rank)
        ctx_path = os.path.join(trial_dir, ctx_filename)
        result_path = os.path.join(trial_dir, "result.json")

        # Prepare Profile Context
        profile_ctx = self._get_profile_context(trial, result_path)
        with open(ctx_path, 'wb') as f:
            pickle.dump(profile_ctx, f, protocol=4)

        if self._config.verbose:
            debug_program(trial.main_program, trial_dir, "main_program")
            debug_program(trial.startup_program, trial_dir, "startup_program")

        # Run
        self._launch_profile(ctx_path, trial_dir)

        # Load results
        try:
            with open(result_path, 'r') as fp:
                return json.load(fp)
        except (FileNotFoundError, json.JSONDecodeError):
            # The profiler produced no result file, or an unreadable one.
            return {"Throughtput": -1, "ErrorType": 'FatalError'}

    def _evaluate_trial(self, trial):
        """
        Apply ``trial`` and evaluate it according to ``config.mode``.

        Raises:
            NotImplementedError: for COSTMODEL mode or an unknown mode.
        """
        self._logger.info("Trial {} evaluation start.".format(trial.name))
        self._apply_optimization(trial)

        if self._config.mode == "PROFILE":
            results = self._profile_trial(trial)
        elif self._config.mode == "COSTMODEL":
            raise NotImplementedError(
                "COSTMODEL mode for optimization tuning is not supported yet!"
            )
        else:
            raise NotImplementedError(
                "invalid evaluation mode: {}".format(self._config.mode)
            )

        self._logger.info(
            "Trial {} evaluation finish with {}.".format(
                trial.name, parse_results(results)
            )
        )
        return results

    def _update(self, i, trial, results):
        """
        Record a finished trial and track the best one by ``get_metric``.
        """
        self._finished_trials.append(trial)

        cur_metric = get_metric(results)
        if self._best_metric is None or cur_metric > self._best_metric:
            self._best_metric = cur_metric
            self._best_iter = i

    def _get_trial_dir(self, trial):
        """
        Return ``<project_dir>/<trial.name>``.
        """
        return os.path.join(self.project_dir, trial.name)

    def get_best_config(self):
        """
        Return the best optimization configuration found in the tuning.

        Returns:
            A object of fleet.DistributedStrategy with best configuration.

        Raises:
            AssertionError: if no trial has been recorded yet.
        """
        assert self._best_iter >= 0, "The best configuration is not found yet !"
        best_trial = self._finished_trials[self._best_iter]
        return self._algorithm.get_config_from_trial(best_trial)

    def summary(self):
        """
        Display tuning result summary and write it to ``<project_dir>/summary.txt``.

        If no trial has finished, only a short notice is logged.
        """
        # TODO summary with the trial_name with metric_of_trial
        if not self._finished_trials:
            # Indexing an empty trial list used to crash here.
            self._logger.info("Tuning finished without any evaluated trial.")
            return

        best_trial = self._finished_trials[self._best_iter]
        summary_ = """
Tuning Result Summary
Run total {} trials with {} min.
The best trial is: [{}], whose configuration is following:
        """.format(
            len(self._finished_trials),
            (time.time() - self._tuning_start_time) / 60,
            best_trial.name,
        )
        summary_ += "\n" + best_trial.summary() + "\n"
        self._logger.info(summary_)
        with open(os.path.join(self.project_dir, "summary.txt"), "w") as fw:
            # Same content as writing every line followed by "\n".
            fw.write(summary_ + "\n")

        # full_strategy = self.get_best_config()
        # path = os.path.join(self.project_dir, "tuned_dist_strategy.yaml")
        # with open(path, 'w') as outfile:
        #     yaml.dump(full_strategy, outfile, default_flow_style=False)

    def clear(self):
        """
        Clear the temporary file generated in tuning procedure.

        Trial directories are removed only when ``verbose`` is off.
        """
        # TODO clear up zombie process created by tuning
        if not self._config.verbose:
            for trial in self._finished_trials:
                trial_dir = self._get_trial_dir(trial)
                shutil.rmtree(trial_dir, ignore_errors=True)

    def tune(self):
        """
        Performs the search for best hyperparameter configurations
        for the selected optimization pass(es).

        The search stops after ``max_num_trial`` trials, when the algorithm
        returns a STOPPED trial, or when ``early_stop`` consecutive trials did
        not improve on the best one. It then writes the summary and cleans up.
        """
        # step1: collect model info which might be used for
        # pruning the search space of the algorithm
        self._tuning_start_time = time.time()
        self._algorithm.collect_model_info(
            self._baseline_dist_context.serial_main_program,
            self._baseline_dist_context.serial_startup_program,
        )

        # main search loop
        i = 0
        while i < self._config.max_num_trial:
            # step2: create a new trial
            trial = self._algorithm.next_trial()

            if trial.status == TrialStatus.STOPPED:
                break

            # step3: evaluate the trial
            results = self._evaluate_trial(trial)

            # step4: update the algorithm with last result,
            # which could be used by algorithm to pruning the
            # remaining search space.
            self._algorithm.update(results)
            self._update(i, trial, results)

            # early stop
            i += 1
            if (
                self._config.early_stop
                and self._config.early_stop <= i - self._best_iter
            ):
                self._logger.info(
                    "Early stop the Tuning since there is no better trial found within [{}] trials".format(
                        self._config.early_stop
                    )
                )
                break

        # step5: summary the best config and return
        self.summary()

        self.clear()