#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__all__ = [
    "TrainerRuntimeConfig", "DistributedStrategy", "SyncStrategy",
    "AsyncStrategy", "HalfAsyncStrategy", "GeoStrategy", "StrategyFactory"
]

import os
import paddle.fluid as fluid
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig, ServerRuntimeConfig, DistributedMode


class TrainerRuntimeConfig(object):
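    """Runtime configuration for the trainer-side communicator.

    Defaults are read from the FLAGS_communicator_* environment variables
    (with CPU_NUM-based fallbacks) and exported per distributed mode via
    get_communicator_flags().
    """
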
    def __init__(self):
        self.mode = None
        num_threads = os.getenv("CPU_NUM", "1")

        self.runtime_configs = {}
        self.runtime_configs['communicator_max_merge_var_num'] = os.getenv(
            "FLAGS_communicator_max_merge_var_num", num_threads)
        self.runtime_configs['communicator_send_queue_size'] = os.getenv(
            "FLAGS_communicator_send_queue_size", num_threads)
        self.runtime_configs[
            'communicator_independent_recv_thread'] = os.getenv(
                "FLAGS_communicator_independent_recv_thread", "1")
        self.runtime_configs[
            'communicator_min_send_grad_num_before_recv'] = os.getenv(
                "FLAGS_communicator_min_send_grad_num_before_recv", num_threads)
        self.runtime_configs['communicator_thread_pool_size'] = os.getenv(
            "FLAGS_communicator_thread_pool_size", "5")
        self.runtime_configs['communicator_send_wait_times'] = os.getenv(
            "FLAGS_communicator_send_wait_times", "5")
        self.runtime_configs['communicator_is_sgd_optimizer'] = os.getenv(
            "FLAGS_communicator_is_sgd_optimizer", "1")

        # not used 
        self.runtime_configs['rpc_deadline'] = os.getenv("FLAGS_rpc_deadline",
                                                         "180000")
        self.runtime_configs['rpc_retry_times'] = os.getenv(
            "FLAGS_rpc_retry_times", "3")

    def get_communicator_flags(self):
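        """Return the communicator flags relevant to the current mode as a
        dict of str -> str. In SYNC/HALF_ASYNC mode, max_merge_var_num and
        send_queue_size are forced to CPU_NUM."""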
        need_keys = []
        num_threads = os.getenv("CPU_NUM", "1")
        mode_str = ""
        if self.mode is None or self.mode == DistributedMode.ASYNC:
            need_keys = self.runtime_configs.keys()
            mode_str = "async"
        elif self.mode == DistributedMode.SYNC or self.mode == DistributedMode.HALF_ASYNC:
            mode_str = "sync or half_async"
            need_keys = [
                'communicator_max_merge_var_num',
                'communicator_send_wait_times', 'communicator_thread_pool_size',
                'communicator_send_queue_size'
            ]
        elif self.mode == DistributedMode.GEO:
            mode_str = "GEO"
            need_keys = [
                'communicator_thread_pool_size', 'communicator_send_wait_times'
            ]
        else:
            raise ValueError("Unsupported Mode")

        if self.mode == DistributedMode.SYNC or self.mode == DistributedMode.HALF_ASYNC:
            max_merge_var_num = self.runtime_configs[
                'communicator_max_merge_var_num']
            send_queue_size = self.runtime_configs[
                'communicator_send_queue_size']
            if max_merge_var_num != num_threads:
                print('WARNING: In {} mode, communicator_max_merge_var_num '
                      'must be equal to CPU_NUM. But received '
                      'communicator_max_merge_var_num = {}, CPU_NUM = '
                      '{}. communicator_max_merge_var_num will be forced to {}.'
                      .format(mode_str, max_merge_var_num, num_threads,
                              num_threads))
                self.runtime_configs[
                    'communicator_max_merge_var_num'] = num_threads
            if send_queue_size != num_threads:
                print('WARNING: In {} mode, communicator_send_queue_size '
                      'must be equal to CPU_NUM. But received '
                      'communicator_send_queue_size = {}, CPU_NUM = '
                      '{}. communicator_send_queue_size will be forced to {}.'
                      .format(mode_str, send_queue_size, num_threads,
                              num_threads))
                self.runtime_configs[
                    'communicator_send_queue_size'] = num_threads

        return dict((key, str(self.runtime_configs[key])) for key in need_keys)

    def display(self, configs):
        length = 50
        h_format = "{:^45s}{:<5s}\n"
        l_format = "{:<45s}{:<5s}\n"

        border = "=" * length
        line = "-" * length

        draws = ""
        draws += border + "\n"
        draws += h_format.format("TrainerRuntimeConfig Overview", "Value")
        draws += line + "\n"

        for k, v in configs.items():
            draws += l_format.format(k, v)

        draws += border

        _str = "\n{}\n".format(draws)
        return _str

    def __repr__(self):
        return self.display(self.get_communicator_flags())


class PSLibRuntimeConfig(object):
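    """Thin container for PSLib runtime options."""
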
    def __init__(self):
        self.runtime_configs = {}

    def get_runtime_configs(self):
        return self.runtime_configs


class DistributedStrategy(object):
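    """Base class bundling the program (transpiler), trainer/server runtime,
    execution, and build configurations for parameter-server training.

    Do not instantiate it directly; use the Sync/Async/HalfAsync/Geo
    subclasses created by StrategyFactory.
    """
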
    def __init__(self):
        self._program_config = DistributeTranspilerConfig()
        self._trainer_runtime_config = TrainerRuntimeConfig()
        self._pslib_runtime_config = PSLibRuntimeConfig()
        self._server_runtime_config = ServerRuntimeConfig()
        num_threads = int(os.getenv("CPU_NUM", "1"))

        self._execute_strategy = fluid.ExecutionStrategy()
        self._build_strategy = fluid.BuildStrategy()

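        # With more than one CPU thread, use the Reduce strategy so each
        # parameter's gradients are aggregated on a single device.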
        self._execute_strategy.num_threads = num_threads
        if num_threads > 1:
            self._build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
        self.debug_opt = None

    def set_debug_opt(self, opt_info):
        self.debug_opt = opt_info

    def get_debug_opt(self):
        opt_info = dict()
        if self.debug_opt is not None and isinstance(self.debug_opt, dict):
            opt_info["dump_slot"] = bool(self.debug_opt.get("dump_slot", 0))
            opt_info["dump_converter"] = str(
                self.debug_opt.get("dump_converter", ""))
            opt_info["dump_fields"] = self.debug_opt.get("dump_fields", [])
            opt_info["dump_file_num"] = self.debug_opt.get("dump_file_num", 16)
            opt_info["dump_fields_path"] = self.debug_opt.get(
                "dump_fields_path", "")
            opt_info["dump_param"] = self.debug_opt.get("dump_param", [])
        return opt_info

    def get_program_config(self):
        return self._program_config

    def set_program_config(self, config):
        if isinstance(config, DistributeTranspilerConfig):
            self._program_config = config
        elif isinstance(config, dict):
            for key in config:
                if hasattr(self._program_config, key):
                    setattr(self._program_config, key, config[key])
                else:
                    raise ValueError(
                        "DistributeTranspilerConfig doesn't have key: {}".
                        format(key))
        else:
            raise TypeError(
                "program_config only accepts input of type dict or DistributeTranspilerConfig"
            )
        self.check_program_config()

    def check_program_config(self):
        raise NotImplementedError(
            "check_program_config must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
        )

    def get_trainer_runtime_config(self):
        return self._trainer_runtime_config

    def set_trainer_runtime_config(self, config):
        if isinstance(config, TrainerRuntimeConfig):
            self._trainer_runtime_config = config
        elif isinstance(config, dict):
            for key, value in config.items():
                if key in self._trainer_runtime_config.runtime_configs:
                    self._trainer_runtime_config.runtime_configs[key] = value
                else:
                    raise ValueError(
                        "TrainerRuntimeConfig doesn't have key: {}".format(key))
        else:
            raise TypeError(
                "trainer_runtime_config only accepts input of type dict or TrainerRuntimeConfig"
            )
        self.check_trainer_runtime_config()

    def check_trainer_runtime_config(self):
        raise NotImplementedError(
            "check_trainer_runtime_config must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
        )

    def get_pslib_runtime_config(self):
        return self._pslib_runtime_config

    def set_pslib_runtime_config(self, config):
        self._pslib_runtime_config.runtime_configs = config

    def get_server_runtime_config(self):
        return self._server_runtime_config

    def set_server_runtime_config(self, config):
        if isinstance(config, ServerRuntimeConfig):
            self._server_runtime_config = config
        elif isinstance(config, dict):
            for key in config:
                if hasattr(self._server_runtime_config, key):
                    setattr(self._server_runtime_config, key, config[key])
                else:
                    raise ValueError(
                        "ServerRuntimeConfig doesn't have key: {}".format(key))
        else:
            raise TypeError(
                "server_runtime_config only accepts input of type dict or ServerRuntimeConfig"
            )
        self.check_server_runtime_config()

    def check_server_runtime_config(self):
        raise NotImplementedError(
            "check_server_runtime_config must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
        )

    def get_execute_strategy(self):
        return self._execute_strategy

    def set_execute_strategy(self, config):
        if isinstance(config, fluid.ExecutionStrategy):
            self._execute_strategy = config
        elif isinstance(config, dict):
            for key in config:
                if hasattr(self._execute_strategy, key):
                    setattr(self._execute_strategy, key, config[key])
                else:
                    raise ValueError(
                        "ExecutionStrategy doesn't have key: {}".format(key))
        else:
            raise TypeError(
                "execute_strategy only accepts input of type dict or ExecutionStrategy"
            )
        self.check_execute_strategy()

    def check_execute_strategy(self):
        raise NotImplementedError(
            "check_execute_strategy must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
        )

    def get_build_strategy(self):
        return self._build_strategy

    def set_build_strategy(self, config):
        if isinstance(config, fluid.BuildStrategy):
            self._build_strategy = config
        elif isinstance(config, dict):
            for key in config:
                if hasattr(self._build_strategy, key):
                    setattr(self._build_strategy, key, config[key])
                else:
                    raise ValueError(
                        "BuildStrategy doesn't have key: {}".format(key))
        else:
            raise TypeError(
                "build_strategy only accepts input of type dict or BuildStrategy")
        self.check_build_strategy()

    def check_build_strategy(self):
        raise NotImplementedError(
            "check_build_strategy must be implemented by derived class. You should use StrategyFactory to create DistributedStrategy."
        )


class SyncStrategy(DistributedStrategy):
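    """Synchronous training (DistributedMode.SYNC), built on the runtime
    send/recv path with a trainer-side thread barrier."""
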
    def __init__(self):
        super(SyncStrategy, self).__init__()
        self.check_program_config()
        self.check_trainer_runtime_config()
        self.check_server_runtime_config()
        self.check_build_strategy()
        self.check_execute_strategy()

    def check_trainer_runtime_config(self):
        self._trainer_runtime_config.mode = DistributedMode.SYNC

    def check_program_config(self):
        self._program_config.sync_mode = False
        self._program_config.runtime_split_send_recv = True
        self._program_config.half_async = True
        self._program_config.completely_not_async = True

    def check_server_runtime_config(self):
        pass

    def check_execute_strategy(self):
        self._execute_strategy.use_thread_barrier = True

    def check_build_strategy(self):
        self._build_strategy.async_mode = True


class AsyncStrategy(DistributedStrategy):
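    """Fully asynchronous training (DistributedMode.ASYNC)."""
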
    def __init__(self):
        super(AsyncStrategy, self).__init__()
        self.check_program_config()
        self.check_trainer_runtime_config()
        self.check_server_runtime_config()
        self.check_build_strategy()
        self.check_execute_strategy()

    def check_trainer_runtime_config(self):
        self._trainer_runtime_config.mode = DistributedMode.ASYNC

    def check_program_config(self):
        self._program_config.sync_mode = False
        self._program_config.runtime_split_send_recv = True

    def check_server_runtime_config(self):
        pass

    def check_execute_strategy(self):
        pass

    def check_build_strategy(self):
        self._build_strategy.async_mode = True


class HalfAsyncStrategy(DistributedStrategy):
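    """Half-asynchronous training (DistributedMode.HALF_ASYNC)."""
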
    def __init__(self):
        super(HalfAsyncStrategy, self).__init__()
        self.check_program_config()
        self.check_trainer_runtime_config()
        self.check_server_runtime_config()
        self.check_build_strategy()
        self.check_execute_strategy()

    def check_trainer_runtime_config(self):
        self._trainer_runtime_config.mode = DistributedMode.HALF_ASYNC

    def check_program_config(self):
        self._program_config.sync_mode = False
        self._program_config.runtime_split_send_recv = True
        self._program_config.half_async = True

    def check_server_runtime_config(self):
        pass

    def check_execute_strategy(self):
        self._execute_strategy.use_thread_barrier = True

    def check_build_strategy(self):
        self._build_strategy.async_mode = True


class GeoStrategy(DistributedStrategy):
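    """GEO-SGD training (DistributedMode.GEO): each trainer pushes parameter
    deltas to the servers every `update_frequency` local steps."""
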
    def __init__(self, update_frequency=100):
        super(GeoStrategy, self).__init__()
        self._program_config.geo_sgd_need_push_nums = update_frequency
        self.check_program_config()
        self.check_trainer_runtime_config()
        self.check_server_runtime_config()
        self.check_build_strategy()
        self.check_execute_strategy()

    def check_program_config(self):
        self._program_config.sync_mode = False
        self._program_config.runtime_split_send_recv = True
        self._program_config.geo_sgd_mode = True

    def check_trainer_runtime_config(self):
        self._trainer_runtime_config.mode = DistributedMode.GEO

    def check_server_runtime_config(self):
        pass

    def check_execute_strategy(self):
        pass

    def check_build_strategy(self):
        self._build_strategy.async_mode = True


class StrategyFactory(object):
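    """Factory helpers returning fully initialized strategy objects."""
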
    def __init__(self):
        pass

    @staticmethod
    def create_sync_strategy():
        return SyncStrategy()

    @staticmethod
    def create_half_async_strategy():
        return HalfAsyncStrategy()

    @staticmethod
    def create_async_strategy():
        return AsyncStrategy()

    @staticmethod
    def create_geo_strategy(update_frequency=100):
        return GeoStrategy(update_frequency)
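

# Minimal usage sketch (illustrative only; the fleet import path below is an
# assumption and may differ between Paddle releases):
#
#   from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
#
#   strategy = StrategyFactory.create_geo_strategy(update_frequency=400)
#   strategy.set_trainer_runtime_config(
#       {"communicator_thread_pool_size": "10"})
#   optimizer = fleet.distributed_optimizer(optimizer, strategy)
#   optimizer.minimize(loss)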