#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle import fluid
from .meta_optimizer_base import MetaOptimizerBase
from paddle.framework import core
import subprocess
import re
import os
import platform
from ..base.private_helper_function import wait_server_ready
__all__ = []

class ParameterServerOptimizer(MetaOptimizerBase):
    """Meta optimizer for parameter-server (PS) distributed training.

    It does not change the inner optimizer's update rule; instead it runs a
    series of IR passes over the origin main/startup programs to produce
    role-specific programs (trainer, heter worker, or pserver) for sync,
    async, or GEO-SGD modes, selected via ``a_sync`` / ``a_sync_configs``.
    """

    def __init__(self, optimizer):
        super().__init__(optimizer)
        self.inner_opt = optimizer
        # we do not allow meta optimizer to be inner optimizer currently
        self.meta_optimizers_white_list = []

    def _set_basic_info(
        self, loss, role_maker, user_defined_optimizer, user_defined_strategy
    ):
        super()._set_basic_info(
            loss, role_maker, user_defined_optimizer, user_defined_strategy
        )

        # Only accumulate_steps is consumed here (for heter pipeline);
        # micro_batch_size from pipeline_configs is intentionally ignored.
        self.num_microbatches = user_defined_strategy.pipeline_configs[
            'accumulate_steps'
        ]

    def _is_graph_out(self):
        return False

    def _can_apply(self):
        # PS mode is mutually exclusive with collective training.
        if self.role_maker._is_collective:
            return False

        # k_steps >= 0 means the user (or _enable_strategy) has configured
        # one of sync/async/geo; a negative value disables this optimizer.
        k_steps = self.user_defined_strategy.a_sync_configs["k_steps"]
        return k_steps >= 0

    def get_dist_env(self):
        """Collect the trainer's distributed environment from env vars.

        Returns:
            dict with ``trainer_id``, ``num_trainers``, ``current_endpoint``
            and the comma-separated ``trainer_endpoints`` string (empty
            string / 0 when PADDLE_TRAINER_ENDPOINTS is not set).
        """
        trainer_id = int(os.getenv('PADDLE_TRAINER_ID', '0'))
        trainer_endpoints = ''
        current_endpoint = ''
        num_trainers = 0
        if os.getenv('PADDLE_TRAINER_ENDPOINTS'):
            trainer_endpoints = os.getenv('PADDLE_TRAINER_ENDPOINTS')
            current_endpoint = trainer_endpoints.split(',')[trainer_id]
            num_trainers = len(trainer_endpoints.split(','))

        return {
            'trainer_id': trainer_id,
            'num_trainers': num_trainers,
            'current_endpoint': current_endpoint,
            'trainer_endpoints': trainer_endpoints,
        }

    def _get_distributed_strategy(self):
        """Map (a_sync, k_steps) onto a concrete PS strategy.

        sync:  a_sync is False and k_steps == 0
        async: a_sync is True  and k_steps == 0
        geo:   a_sync is True  and k_steps > 0
        """
        from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import (
            StrategyFactory,
        )

        k_steps = self.user_defined_strategy.a_sync_configs["k_steps"]
        strategy = None

        if not self.user_defined_strategy.a_sync and k_steps == 0:
            strategy = StrategyFactory.create_sync_strategy()

        if self.user_defined_strategy.a_sync and k_steps == 0:
            strategy = StrategyFactory.create_async_strategy()

        if self.user_defined_strategy.a_sync and k_steps > 0:
            strategy = StrategyFactory.create_geo_strategy(k_steps)

        if not strategy:
            raise ValueError(
                "k_steps is an invalid value: %s, please check "
                "a_sync/a_sync_configs" % k_steps
            )

        return strategy

    def _build_trainer_programs(self, compiled_config):
        """Apply trainer-side IR passes; return (main, startup) programs."""
        from paddle.fluid.incubate.fleet.parameter_server.ir import (
            trainer_pass as worker,
        )

        _main = compiled_config.origin_main_program.clone()
        _startup = compiled_config.origin_startup_program.clone()

        use_ps_gpu = self.user_defined_strategy.a_sync_configs["use_ps_gpu"]

        if not compiled_config.is_geo_mode():
            from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
                _add_lr_decay_table_pass,
            )

            _add_lr_decay_table_pass(
                _main,
                compiled_config,
                self.user_defined_strategy.a_sync_configs["lr_decay_steps"],
            )

            # for main program
            _main = worker.distributed_ops_pass(
                _main, compiled_config, use_ps_gpu
            )
            if not use_ps_gpu:
                # CPU PS: strip optimizer ops, send grads via RPC instead.
                _main = worker.delete_optimizer_pass(_main, compiled_config)
                _main = worker.append_send_ops_pass(_main, compiled_config)
                _startup = worker.delete_extra_optimizes_pass(
                    _startup, compiled_config
                )

            # for startup program
            _startup = worker.fake_init_ops_pass(_startup, compiled_config)
            if use_ps_gpu:
                _main = worker.ps_gpu_pass(_main)
                from paddle.fluid.transpiler.collective import (
                    SingleProcessMultiThread,
                )

                t = SingleProcessMultiThread()
                env = self.get_dist_env()
                t.transpile(
                    startup_program=_startup,
                    main_program=_main,
                    rank=env["trainer_id"],
                    endpoints=env["trainer_endpoints"],
                    current_endpoint=env['current_endpoint'],
                    wait_port=False,
                )

            compiled_config.set_origin_ps_main_program(_main)
            compiled_config.set_origin_ps_startup_program(_startup)
            # for heter program
            if self.role_maker._is_heter_parameter_server_mode:
                from paddle.fluid.incubate.fleet.parameter_server.ir import (
                    heter_trainer_pass as heter_worker,
                )

                if self.role_maker._is_heter_worker():
                    # for heter worker: keep only this stage's ops
                    stage_id = self.role_maker._get_stage_id()
                    device = self.role_maker._heter_device_type().lower()
                    _main = heter_worker.split_heter_worker_ops_pass(
                        _main, compiled_config, stage_id, device
                    )
                else:
                    # for default worker
                    _main = heter_worker.split_trainer_ops_pass(
                        _main, compiled_config
                    )
        else:
            # GEO mode keeps local optimization; only sends deltas.
            _main = worker.append_send_ops_pass(_main, compiled_config)
            compiled_config.set_origin_ps_main_program(_main)
            compiled_config.set_origin_ps_startup_program(_startup)

        launch_barrier = self.user_defined_strategy.a_sync_configs[
            "launch_barrier"
        ]
        launch_barrier_flag = int(os.getenv("FLAGS_LAUNCH_BARRIER", "1"))
        if launch_barrier and launch_barrier_flag:
            # for trainer wait server ready
            wait_server_ready(self.role_maker._get_pserver_endpoints())

        return _main, _startup

    def _build_pserver_programs(self, compiled_config):
        """Apply server-side IR passes; return (main, startup) programs."""
        _main = paddle.static.Program()
        _startup = paddle.static.Program()

        from paddle.fluid.incubate.fleet.parameter_server.ir import (
            pserver_pass as server,
        )

        if not compiled_config.is_geo_mode():
            from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
                _get_optimize_ops,
            )

            is_sgd_adam = False

            main_program = compiled_config.get_origin_main_program()
            ops = _get_optimize_ops(main_program)

            # No optimizer ops: nothing to serve, return empty programs.
            if len(ops) == 0:
                return _main, _startup

            from paddle.fluid.incubate.fleet.parameter_server.ir.public import (
                _add_lr_decay_table_pass,
            )

            lr_decay_steps = self.user_defined_strategy.a_sync_configs[
                "lr_decay_steps"
            ]
            _add_lr_decay_table_pass(
                main_program, compiled_config, lr_decay_steps
            )

            for op in ops:
                if op.type in ["sgd", "adam"]:
                    is_sgd_adam = True
                    break

            # NOTE(review): sgd/adam optimizers skip the listen-and-serv
            # passes below and return empty programs — presumably handled
            # by another service path; confirm against callers.
            if is_sgd_adam:
                return _main, _startup

            _main = server.add_listen_and_serv_pass(_main, compiled_config)
            _main = server.add_rpc_global_flags_pass(_main, compiled_config)
            _main = server.add_optimizer_pass(_main, compiled_config)
            _main = server.large_scale_sparse_pass(
                _main, _main, compiled_config, False
            )
            _startup = server.build_pserver_startup_program_pass(
                _startup, _main, compiled_config
            )
            _startup = server.large_scale_sparse_pass(
                _startup, _main, compiled_config, True
            )

            if not compiled_config.is_sync_mode():
                _main = server.delete_unused_in_main_pass(
                    _main, compiled_config
                )

            _startup = server.delete_unused_in_startup_pass(
                _startup, _main, compiled_config
            )
        else:
            _main = server.add_listen_and_serv_pass(_main, compiled_config)
            _main = server.add_rpc_global_flags_pass(_main, compiled_config)
            _main = server.add_geo_optimizer_pass(_main, compiled_config)
            _startup = server.build_pserver_startup_program_pass(
                _startup, _main, compiled_config
            )
            _startup = server.delete_unused_in_startup_pass(
                _startup, _main, compiled_config
            )

        return _main, _startup

    def _can_apply_geo(self, dist_strategy, program):
        """Heuristically decide whether GEO-SGD can be enabled.

        GEO is only considered when the inner optimizer is plain SGD and
        the estimated trainer-side memory footprint fits into the free
        system memory.
        """

        def get_sys_free_mem():
            # Free physical memory in bytes; macOS and Linux only.
            plat = platform.system()
            if plat == "Darwin":
                vm = subprocess.Popen(
                    ['vm_stat'], stdout=subprocess.PIPE
                ).communicate()[0]
                # communicate() returns bytes: decode before text
                # processing (bytes.split('\n') raises TypeError on Py3).
                vmLines = vm.decode('utf-8').split('\n')
                sep = re.compile(r':[\s]+')
                vmStats = {}
                for row in range(1, len(vmLines) - 2):
                    rowText = vmLines[row].strip()
                    rowElements = sep.split(rowText)
                    # vm_stat reports 4096-byte page counts with a
                    # trailing '.'
                    vmStats[(rowElements[0])] = (
                        int(rowElements[1].strip(r'\.')) * 4096
                    )
                return vmStats["Pages free"]
            elif plat == "Linux":
                mems = {}
                with open('/proc/meminfo', 'rb') as f:
                    for line in f:
                        fields = line.split()
                        mems[fields[0]] = int(fields[1]) * 1024
                free = mems[b'MemFree:']
                return free
            else:
                raise ValueError(
                    "%s platform is unsupported in parameter server "
                    "optimizer" % plat
                )

        if not isinstance(self.inner_opt, fluid.optimizer.SGDOptimizer):
            return False

        free = get_sys_free_mem()

        from paddle.fluid.incubate.fleet.parameter_server.ir import (
            vars_metatools,
        )

        processed_var_names = {"@EMPTY@"}
        param_memory_size = 0
        for varname in program.global_block().vars:
            var = program.global_block().vars[varname]
            if (
                not var.persistable
                or var.desc.type() != core.VarDesc.VarType.LOD_TENSOR
            ):
                continue
            param = vars_metatools.create_var_struct(var)
            param_memory_size += param.m_size
            processed_var_names.add(varname)

        # 5x parameter size as a conservative upper bound for GEO's
        # extra parameter copies.
        upper_mem_use = param_memory_size * 5.0

        program_tmp_vars = {}
        eval_batch_size = 1024
        for op in program.global_block().ops:
            for var_name in op.output_arg_names:
                if var_name in processed_var_names:
                    continue
                processed_var_names.add(var_name)
                var = program.global_block().vars[var_name]

                if var.desc.type() != core.VarDesc.VarType.LOD_TENSOR:
                    continue

                data_count = 1
                neg_dim_count = 0
                for x in var.shape:
                    if x < 0:
                        if neg_dim_count >= 1:
                            raise ValueError(
                                "Var %s has more than one negative dim."
                                % (var_name)
                            )
                        neg_dim_count += 1
                        data_count *= -x
                    else:
                        data_count *= x
                program_tmp_vars[var_name] = (
                    data_count,
                    neg_dim_count,
                    vars_metatools.dtype_to_size[var.dtype],
                )

        for varname in program_tmp_vars:
            data_count, neg_dim_count, type_size = program_tmp_vars[varname]
            if neg_dim_count == 1:
                # a single negative dim is the unknown batch dim; assume
                # a batch size of 1024 for the estimate
                data_count *= eval_batch_size
            var_memory = data_count * type_size
            upper_mem_use += var_memory

        return upper_mem_use < free

    def minimize_impl(
        self, loss, startup_program=None, parameter_list=None, no_grad_set=None
    ):
        """Run the inner optimizer, then rewrite the programs per PS role.

        Always returns (None, None): the rewritten programs are installed
        in place (on ``loss.block.program`` / the global startup program).
        """
        self.inner_opt.minimize(
            loss, startup_program, parameter_list, no_grad_set
        )
        strategy = self._get_distributed_strategy()

        _origin_main_program = loss.block.program
        _origin_startup_program = startup_program
        from paddle.fluid.incubate.fleet.parameter_server.ir import (
            public as public,
        )

        compiled_config = public.CompileTimeStrategy(
            _origin_main_program,
            _origin_startup_program,
            strategy,
            self.role_maker,
        )
        compiled_config.strategy = strategy

        if self.role_maker._is_worker() or self.role_maker._is_heter_worker():
            main_program, startup_program = self._build_trainer_programs(
                compiled_config
            )
            if self.role_maker._is_heter_parameter_server_mode:
                # Heter pipeline: attach the rewritten programs and the
                # pipeline topology to the origin programs.
                _origin_startup_program._heter_pipeline_opt = {
                    "startup_program": startup_program,
                    "pipeline_stage": int(self.role_maker._get_stage_id()) - 1,
                    "heter_place": self.role_maker._heter_device(),
                }

                loss.block.program._heter_pipeline_opt = {
                    "trainer": "HeterPipelineTrainer",
                    "device_worker": "HeterSection",
                    "trainers": self.role_maker._get_stage_trainers(),  # trainer num in each stage
                    "trainer_id": int(self.role_maker._role_id()),
                    "pipeline_stage": int(self.role_maker._get_stage_id()) - 1,
                    "num_pipeline_stages": int(
                        self.role_maker._get_num_stage()
                    ),
                    "section_program": main_program,
                    "num_microbatches": self.num_microbatches,
                    "heter_place": self.role_maker._heter_device(),
                }
            else:
                loss.block.program = main_program
                fluid.framework.switch_startup_program(startup_program)

        elif self.role_maker._is_server():
            main_program, startup_program = self._build_pserver_programs(
                compiled_config
            )
            loss.block.program = main_program
            fluid.framework.switch_startup_program(startup_program)
        return None, None

    def _disable_strategy(self, dist_strategy):
        # k_steps == -1 makes _can_apply() reject this optimizer.
        dist_strategy.a_sync = False
        a_sync_configs = dist_strategy.a_sync_configs
        a_sync_configs["k_steps"] = -1
        dist_strategy.a_sync_configs = a_sync_configs

    def _enable_strategy(self, dist_strategy, context):
        a_sync_configs = dist_strategy.a_sync_configs
        # Respect an explicit user configuration.
        if a_sync_configs["k_steps"] >= 0:
            return

        dist_strategy.a_sync = True
        # Re-read after toggling a_sync in case the getter reflects it.
        a_sync_configs = dist_strategy.a_sync_configs

        is_geo = self._can_apply_geo(
            dist_strategy, context["origin_main_program"]
        )

        # 800 is the empirical default GEO step count; 0 selects
        # plain async mode.
        if is_geo:
            a_sync_configs["k_steps"] = 800
        else:
            a_sync_configs["k_steps"] = 0
        dist_strategy.a_sync_configs = a_sync_configs