strategy.py 6.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import os
import copy
import argparse
from . import constants


class BaseConfig(object):
    """Attribute-style container for the configuration of one category.

    Every field of the category's default configuration (as returned by
    ``constants.get_category_default_config``) becomes an instance attribute.
    Each attribute starts at its default value and can be overridden by a
    user-supplied dictionary.
    """

    def __init__(self, category, config_dict=None):
        """Initialize the fields from the defaults, then apply the overrides.

        Args:
            category: Key handed to ``constants.get_category_default_config``
                to obtain the field names and their default values.
            config_dict (dict, optional): Values overriding the defaults.
                Keys that are not fields of the category are ignored.

        Raises:
            ValueError: If ``config_dict`` is neither ``None`` nor a dict.
        """
        self._category = category
        self._config_dict = None
        if config_dict is not None:
            if not isinstance(config_dict, dict):
                raise ValueError(
                    "Expected a dictionary. But received: {}".format(
                        config_dict))
            self._config_dict = config_dict

        # Initialize attributes by the default config. Each default is
        # deep-copied so that a mutable default (e.g. a list) is never aliased
        # between instances, should the provider hand back shared objects.
        config = constants.get_category_default_config(self._category)
        for field, default_value in config.items():
            setattr(self, field, copy.deepcopy(default_value))

        # Override attributes by the config_dict
        if self._config_dict:
            self.from_dict(self._config_dict)

    def from_dict(self, config_dict):
        """Override the known fields with the values found in ``config_dict``.

        Only keys that are fields of this category are read; a field missing
        from ``config_dict`` keeps its current value.

        Args:
            config_dict (dict): Mapping from field name to the new value.
        """
        config = constants.get_category_default_config(self._category)
        for field in config.keys():
            value = config_dict.get(field, constants.NOT_FOUND)
            # Keep the current value if the field is not given
            if value != constants.NOT_FOUND:
                setattr(self, field, value)

    def to_dict(self):
        """Return the configuration as a plain dictionary.

        Returns:
            dict: One entry per field of the category, plus one entry per
            attribute that is itself a ``BaseConfig`` (converted recursively).
        """
        config = constants.get_category_default_config(self._category)
        result_dict = {field: getattr(self, field) for field in config}
        for field, value in self.__dict__.items():
            if isinstance(value, BaseConfig):
                result_dict[field] = value.to_dict()
        return result_dict

    def __repr__(self):
        """Return a sorted, 4-space indented JSON dump of ``to_dict()``."""
        # The module does not import ``yaml``, so the previous ``yaml.dump``
        # call raised NameError; the standard-library JSON encoder is used
        # instead.
        import json

        # ``default=repr`` keeps __repr__ from raising on values the JSON
        # encoder cannot serialize.
        return json.dumps(self.to_dict(),
                          indent=4,
                          sort_keys=True,
                          default=repr)

    def __deepcopy__(self, memo):
        """Deep-copy every instance attribute without calling ``__init__``."""
        cls = self.__class__
        result = cls.__new__(cls)
        # Register the copy first so self-references resolve to it
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, copy.deepcopy(v, memo))
        return result


class RecomputeConfig(BaseConfig):
    """Configuration whose fields come from the ``constants.RECOMPUTE`` category."""

    def __init__(self, config_dict=None):
        """Create the config; ``config_dict`` optionally overrides the defaults."""
        super().__init__(constants.RECOMPUTE, config_dict)


class AMPConfig(BaseConfig):
    """Configuration whose fields come from the ``constants.AMP`` category."""

    def __init__(self, config_dict=None):
        """Create the config; ``config_dict`` optionally overrides the defaults."""
        super().__init__(constants.AMP, config_dict)


class ShardingConfig(BaseConfig):
    """Configuration whose fields come from the ``constants.SHARDING`` category."""

    def __init__(self, config_dict=None):
        """Create the config; ``config_dict`` optionally overrides the defaults."""
        super().__init__(constants.SHARDING, config_dict)


class GradientMergeConfig(BaseConfig):
    """Configuration whose fields come from the ``constants.GRADIENT_MERGE`` category."""

    def __init__(self, config_dict=None):
        """Create the config; ``config_dict`` optionally overrides the defaults."""
        super().__init__(constants.GRADIENT_MERGE, config_dict)


class QATConfig(BaseConfig):
    """Configuration whose fields come from the ``constants.QAT`` category."""

    def __init__(self, config_dict=None):
        """Create the config; ``config_dict`` optionally overrides the defaults."""
        super().__init__(constants.QAT, config_dict)


class TuningConfig(BaseConfig):
    """Configuration whose fields come from the ``constants.TUNING`` category."""

    def __init__(self, config_dict=None):
        """Create the config; ``config_dict`` optionally overrides the defaults."""
        super().__init__(constants.TUNING, config_dict)


class Strategy(BaseConfig):
    """
    The `Strategy` object is used to configure the parallelization and optimization behaviors.

    Args:
        config (dict, optional): If this is None, the default configurations will be used.
            If this is a dictionary, its recognized key-value pairs override the default
            configurations while the other default configurations are left unchanged.
            The dictionary is deep-copied, so the caller's object is never modified.
            Any other type, including a path string, raises a ValueError.

    Raises:
        ValueError: If ``config`` is neither None nor a dictionary.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.distributed.auto_parallel as auto

            strategy = auto.Strategy()
            sharding = strategy.sharding
            assert sharding.enabled == False
            assert sharding.stage == 1
            assert sharding.degree == 8
            sharding.enabled = True
            sharding.stage = 2
            sharding.degree = 2
            assert sharding.enabled == True
            assert sharding.stage == 2
            assert sharding.degree == 2

    """

    def __init__(self, config=None):
        # NOTE: only dictionaries are accepted; loading the configuration
        # from a YAML file path is not implemented here.
        if config is None:
            self._config_dict = {}
        elif isinstance(config, dict):
            self._config_dict = copy.deepcopy(config)
        else:
            raise ValueError(
                "Expected a dictionary. But received: {}".format(config))

        # Top-level fields come from the BASE category.
        super(Strategy, self).__init__(constants.BASE, self._config_dict)

        # Each nested config is built from the sub-dictionary stored under its
        # category key; a missing key yields None, i.e. pure defaults.
        nested_configs = (
            ("recompute", constants.RECOMPUTE, RecomputeConfig),
            ("amp", constants.AMP, AMPConfig),
            ("sharding", constants.SHARDING, ShardingConfig),
            ("gradient_merge", constants.GRADIENT_MERGE, GradientMergeConfig),
            ("qat", constants.QAT, QATConfig),
            ("tuning", constants.TUNING, TuningConfig),
        )
        for attr_name, category_key, config_cls in nested_configs:
            overrides = self._config_dict.get(category_key, None)
            setattr(self, attr_name, config_cls(overrides))