config.py 3.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
16
import os
17

18
from ..strategy import Strategy
19 20 21 22 23

_tuning_supported_passes = ["sharding", "recompute"]


def _get_pass_config(strategy, pass_name):
24
    config = getattr(strategy, pass_name)
25 26 27
    return config


28
class TuningConfig:
    """
    A uniform config wrapper for the optimization tuner. It combines:

    - distributed strategy: the user-defined configuration for the
      optimization passes (deep-copied, so the caller's object is not mutated);
    - tuning config: configuration for the tuning process itself, i.e. the
      mode (profile or cost model), the profiling step window, the project
      (log) directory, trial limits, early stopping, and extra tuning settings
      for specific passes (e.g. the search range of a pass).

    Attributes not defined on this class are looked up on the wrapped
    distributed strategy (see ``__getattr__``).
    """

    def __init__(self, strategy):
        if not isinstance(strategy, Strategy):
            raise TypeError("'strategy' must be object of class `Strategy`.")

        # Names of the passes selected for tuning; filled in by _initialize().
        self._tuning_passes_name = set()
        # Private copy so tuning never mutates the caller's strategy.
        self._dist_strategy = copy.deepcopy(strategy)
        self._mode = None
        self._profile_start_step = None
        self._profile_end_step = None
        self._project_dir = None
        self._max_num_trial = None
        self._early_stop = None
        self._debug = None

        self._initialize()

    @property
    def mode(self):
        return self._mode

    @property
    def profile_start_step(self):
        return self._profile_start_step

    @property
    def profile_end_step(self):
        return self._profile_end_step

    @property
    def project_dir(self):
        return self._project_dir

    @property
    def tuning_passes_name(self):
        return self._tuning_passes_name

    @property
    def max_num_trial(self):
        return self._max_num_trial

    @property
    def early_stop(self):
        return self._early_stop

    @property
    def debug(self):
        return self._debug

    @property
    def dist_strategy(self):
        return self._dist_strategy

    def _initialize(self):
        """Fill the config from the strategy's ``tuning`` section.

        Any value the user did not set falls back to the default given below.
        """
        tuning_strategy = self._dist_strategy.tuning

        self._mode = tuning_strategy.get("mode", "PROFILE")
        self._profile_start_step = tuning_strategy.get("profile_start_step", 10)
        self._profile_end_step = tuning_strategy.get("profile_end_step", 30)
        self._max_num_trial = tuning_strategy.get("max_num_trial", 50)
        self._early_stop = tuning_strategy.get("early_stop", None)
        self._debug = tuning_strategy.get("debug", False)

        project_dir = tuning_strategy.get("project_dir", None)
        # An unset or empty project dir falls back to ./OptimizationTuning.
        if not project_dir:
            project_dir = os.path.join(os.getcwd(), "OptimizationTuning")
        self._project_dir = project_dir

        for p in _tuning_supported_passes:
            pass_config = _get_pass_config(self._dist_strategy, p)
            # NOTE(review): this relies on the truthiness of the pass config
            # object plus its ``enable_tuning`` flag; confirm whether an
            # explicit ``enable`` check was intended here.
            if pass_config and pass_config.enable_tuning:
                # TODO distinguish different args of each passes
                self._tuning_passes_name.add(p)
                # Expose the tuned pass config directly on this object,
                # e.g. ``config.sharding``.
                self.__dict__[p] = pass_config
                # TODO: verify and apply user-defined per-pass tuning configs
                # read from the ``tuning`` section (tuning_strategy.get(p)).

    # (NOTE) The tuning config ONLY wraps the dist strategy for the pass
    # configs that are tuned; any other attribute is looked up on the wrapped
    # strategy.
    def __getattr__(self, item):
        # Do not forward dunder probes. ``copy``/``pickle`` look up names such
        # as ``__deepcopy__`` and ``__setstate__`` on the instance. Forwarding
        # them would make ``copy.deepcopy(config)`` call the strategy's own
        # ``__deepcopy__`` (if it defines one) and return a copy of the
        # strategy instead of a TuningConfig.
        if item.startswith("__") and item.endswith("__"):
            raise AttributeError(item)
        # Read ``_dist_strategy`` from the instance dict directly. When it is
        # not set yet (bare ``__new__`` during copy/unpickling, or a partially
        # constructed object), ``self._dist_strategy`` would re-enter
        # ``__getattr__`` and recurse without bound.
        try:
            dist_strategy = self.__dict__["_dist_strategy"]
        except KeyError:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{item}'"
            ) from None
        return getattr(dist_strategy, item)