# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import copy
from . import constants


class BaseConfig:
    """Attribute-style container for one category of strategy options.

    Every field registered for ``category`` (as returned by
    ``constants.get_category_default_config``) becomes an instance attribute
    initialised to its default value. Entries of ``config_dict`` whose key
    names one of those fields then override the default; any other keys in
    ``config_dict`` are ignored.

    Args:
        category: key used to look up this config's default fields via
            ``constants.get_category_default_config``.
        config_dict (dict, optional): overrides for the default values.

    Raises:
        ValueError: if ``config_dict`` is given but is not a ``dict``.
    """

    def __init__(self, category, config_dict=None):
        self._category = category
        self._config_dict = None
        if config_dict is not None:
            if not isinstance(config_dict, dict):
                raise ValueError(
                    "Expected a dictionary. But received: {}".format(
                        config_dict
                    )
                )
            self._config_dict = config_dict

        # Initialize attributes by the default config. Each default is
        # deep-copied so that mutating an attribute in place (e.g. appending
        # to a list-valued default) cannot alter the defaults returned by
        # the shared lookup for other instances.
        config = constants.get_category_default_config(self._category)
        for field, default_value in config.items():
            setattr(self, field, copy.deepcopy(default_value))

        # Override attributes by the config_dict
        if self._config_dict:
            self.from_dict(self._config_dict)

    def from_dict(self, config_dict):
        """Override known fields with the matching entries of ``config_dict``.

        Only keys that are fields of this category's default config are
        applied. Fields absent from ``config_dict`` keep their current value.
        """
        config = constants.get_category_default_config(self._category)
        for field in config:
            # Test key membership instead of comparing the looked-up value
            # against a sentinel with ``!=``: the comparison would silently
            # drop a legitimate value equal to the sentinel and is ambiguous
            # for array-like values.
            if field in config_dict:
                setattr(self, field, config_dict[field])

    def to_dict(self):
        """Return this config's fields as a plain dict.

        Attributes that are themselves ``BaseConfig`` instances are
        converted recursively via their own ``to_dict``.
        """
        config = constants.get_category_default_config(self._category)
        result_dict = {field: getattr(self, field) for field in config}
        for field, value in self.__dict__.items():
            if isinstance(value, BaseConfig):
                result_dict[field] = value.to_dict()
        return result_dict

    def __repr__(self):
        # Produces '{"key":"value",...}' with a trailing comma after every
        # entry (the historical format is kept unchanged).
        body = "".join(
            "\"%s\":\"%s\"," % (k, v) for k, v in self.to_dict().items()
        )
        return "{" + body + "}"

    def __deepcopy__(self, memo):
        # Copy every instance attribute, registering the new object in
        # ``memo`` first so self-references resolve to the copy.
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, copy.deepcopy(v, memo))
        return result


class RecomputeConfig(BaseConfig):
    """Config object for the ``constants.RECOMPUTE`` category.

    Args:
        config_dict (dict, optional): overrides for the default values.
    """

    def __init__(self, config_dict=None):
        super().__init__(constants.RECOMPUTE, config_dict)


class AMPConfig(BaseConfig):
    """Config object for the ``constants.AMP`` category.

    Args:
        config_dict (dict, optional): overrides for the default values.
    """

    def __init__(self, config_dict=None):
        super().__init__(constants.AMP, config_dict)


class ShardingConfig(BaseConfig):
    """Config object for the ``constants.SHARDING`` category.

    Args:
        config_dict (dict, optional): overrides for the default values.
    """

    def __init__(self, config_dict=None):
        super().__init__(constants.SHARDING, config_dict)


class GradientMergeConfig(BaseConfig):
    """Config object for the ``constants.GRADIENT_MERGE`` category.

    Args:
        config_dict (dict, optional): overrides for the default values.
    """

    def __init__(self, config_dict=None):
        super().__init__(constants.GRADIENT_MERGE, config_dict)


class QATConfig(BaseConfig):
    """Config object for the ``constants.QAT`` category.

    Args:
        config_dict (dict, optional): overrides for the default values.
    """

    def __init__(self, config_dict=None):
        super().__init__(constants.QAT, config_dict)


class TuningConfig(BaseConfig):
    """Config object for the ``constants.TUNING`` category.

    Args:
        config_dict (dict, optional): overrides for the default values.
    """

    def __init__(self, config_dict=None):
        super().__init__(constants.TUNING, config_dict)


class DatasetConfig(BaseConfig):
    """Config object for the ``constants.DATASET`` category.

    Args:
        config_dict (dict, optional): overrides for the default values.
    """

    def __init__(self, config_dict=None):
        super().__init__(constants.DATASET, config_dict)


class Strategy(BaseConfig):
    """
    The `Strategy` object is used to configure the parallelization and optimization behaviors.

    Besides its own top-level fields, a `Strategy` owns one nested config object per
    category: `recompute`, `amp`, `sharding`, `gradient_merge`, `qat`, `tuning` and `dataset`.
    Each nested object is built from the corresponding category entry of `config`, if any.

    Args:
        config (dict, optional): If this is None, the default configurations will be used.
        If this is a dictionary, its recognized key-value pairs will be used to override
        the default configurations while other default configurations are left unchanged.
        The dictionary is deep-copied, so later changes to it do not affect the strategy.
        Loading a YAML configuration from a file path is not supported yet: any value that
        is not a dictionary raises ``ValueError``.

    Raises:
        ValueError: if ``config`` is neither None nor a dictionary, or if a category
        entry inside ``config`` is present but is not a dictionary.

    Examples:
        .. code-block:: python

            import paddle
            from paddle.distributed.fleet import auto

            strategy = auto.Strategy()
            sharding = strategy.sharding
            assert sharding.enabled == False
            assert sharding.stage == 1
            assert sharding.degree == 8

            sharding.enabled = True
            sharding.stage = 2
            sharding.degree = 2
            assert sharding.enabled == True
            assert sharding.stage == 2
            assert sharding.degree == 2

    """

    def __init__(self, config=None):
        if config is None:
            self._config_dict = {}
        elif isinstance(config, dict):
            # Deep-copy so later changes to the caller's dict do not leak
            # into this strategy.
            self._config_dict = copy.deepcopy(config)
        else:
            # TODO: support a path to a YAML configuration file. If this is
            # added, parse it with ``yaml.safe_load``; ``yaml.load`` with
            # ``yaml.Loader`` can construct arbitrary Python objects.
            raise ValueError(
                "Expected a dictionary. But received: {}".format(config)
            )

        super().__init__(constants.BASE, self._config_dict)

        # One nested config per category, fed with that category's
        # sub-dictionary (None when the category is absent).
        sub = self._config_dict.get
        self.recompute = RecomputeConfig(sub(constants.RECOMPUTE, None))
        self.amp = AMPConfig(sub(constants.AMP, None))
        self.sharding = ShardingConfig(sub(constants.SHARDING, None))
        self.gradient_merge = GradientMergeConfig(
            sub(constants.GRADIENT_MERGE, None)
        )
        self.qat = QATConfig(sub(constants.QAT, None))
        self.tuning = TuningConfig(sub(constants.TUNING, None))
        self.dataset = DatasetConfig(sub(constants.DATASET, None))