#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import platform
import re
import subprocess

import paddle
from paddle import fluid
from paddle.framework import core

from ..base.private_helper_function import wait_server_ready
from .meta_optimizer_base import MetaOptimizerBase

__all__ = []


class ParameterServerOptimizer(MetaOptimizerBase):
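    """Meta optimizer that rewrites the origin main/startup programs for
    parameter server training (sync, async and geo modes). It is not meant
    to be constructed directly; fleet picks it based on the user-defined
    DistributedStrategy.

    A rough usage sketch of how this optimizer is reached through the public
    fleet API (for illustration only; details such as fleet.init arguments
    depend on the launch environment):

        import paddle
        import paddle.distributed.fleet as fleet

        fleet.init()  # parameter server roles come from the environment
        strategy = fleet.DistributedStrategy()
        strategy.a_sync = True
        optimizer = paddle.optimizer.SGD(learning_rate=0.01)
        optimizer = fleet.distributed_optimizer(optimizer, strategy)
    """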
    def __init__(self, optimizer):
        super().__init__(optimizer)
        self.inner_opt = optimizer
        # we do not allow meta optimizer to be inner optimizer currently
        self.meta_optimizers_white_list = []

    def _set_basic_info(
        self, loss, role_maker, user_defined_optimizer, user_defined_strategy
    ):
        super()._set_basic_info(
            loss, role_maker, user_defined_optimizer, user_defined_strategy
        )

        # self.micro_batch_size = user_defined_strategy.pipeline_configs[
        #    'micro_batch_size']
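        # 'accumulate_steps' doubles as the number of micro-batches handed to
        # the heter pipeline trainer configured in minimize_impl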
        self.num_microbatches = user_defined_strategy.pipeline_configs[
            'accumulate_steps'
        ]

    def _is_graph_out(self):
        return False

    def _can_apply(self):
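        """Applicable only for non-collective (parameter server) roles, and
        only once a_sync_configs["k_steps"] has been set (>= 0)."""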
        if self.role_maker._is_collective:
            return False

        k_steps = self.user_defined_strategy.a_sync_configs["k_steps"]
        return k_steps >= 0

    def get_dist_env(self):
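        """Collect the trainer topology (id, count, endpoints) from the
        PADDLE_TRAINER_ID / PADDLE_TRAINER_ENDPOINTS environment variables."""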
        trainer_id = int(os.getenv('PADDLE_TRAINER_ID', '0'))
        trainer_endpoints = ''
        current_endpoint = ''
        num_trainers = 0
        if os.getenv('PADDLE_TRAINER_ENDPOINTS'):
            trainer_endpoints = os.getenv('PADDLE_TRAINER_ENDPOINTS')
            current_endpoint = trainer_endpoints.split(',')[trainer_id]
            num_trainers = len(trainer_endpoints.split(','))

        return {
            'trainer_id': trainer_id,
            'num_trainers': num_trainers,
            'current_endpoint': current_endpoint,
            'trainer_endpoints': trainer_endpoints,
        }

    def _get_distributed_strategy(self):
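        """Map the user-defined strategy to a parameter server strategy:
        sync (a_sync off), async (a_sync on, k_steps == 0) or geo
        (a_sync on, k_steps > 0)."""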
        from paddle.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import (
            StrategyFactory,
        )

        k_steps = self.user_defined_strategy.a_sync_configs["k_steps"]
        strategy = None

        if not self.user_defined_strategy.a_sync and k_steps == 0:
            strategy = StrategyFactory.create_sync_strategy()

        if self.user_defined_strategy.a_sync and k_steps == 0:
            strategy = StrategyFactory.create_async_strategy()

        if self.user_defined_strategy.a_sync and k_steps > 0:
            strategy = StrategyFactory.create_geo_strategy(k_steps)

        if not strategy:
            raise ValueError("k_steps must be invalid value, please check")

        return strategy

    def _build_trainer_programs(self, compiled_config):
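        """Clone the origin programs and rewrite them for the worker side:
        distributed-ops and send-ops passes, optimizer deletion, plus the
        optional ps-gpu and heter-worker passes."""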
        from paddle.incubate.fleet.parameter_server.ir import (
            trainer_pass as worker,
        )

        _main = compiled_config.origin_main_program.clone()
        _startup = compiled_config.origin_startup_program.clone()

        use_ps_gpu = self.user_defined_strategy.a_sync_configs["use_ps_gpu"]

        if not compiled_config.is_geo_mode():
            from paddle.incubate.fleet.parameter_server.ir.public import (
                _add_lr_decay_table_pass,
            )

            _add_lr_decay_table_pass(
                _main,
                compiled_config,
                self.user_defined_strategy.a_sync_configs["lr_decay_steps"],
            )

            # for main program
            _main = worker.distributed_ops_pass(
                _main, compiled_config, use_ps_gpu
            )
            if not use_ps_gpu:
                _main = worker.delete_optimizer_pass(_main, compiled_config)
                _main = worker.append_send_ops_pass(_main, compiled_config)
                _startup = worker.delete_extra_optimizes_pass(
                    _startup, compiled_config
                )

                # for startup program
            _startup = worker.fake_init_ops_pass(_startup, compiled_config)
            if use_ps_gpu:
                _main = worker.ps_gpu_pass(_main)
                from paddle.distributed.transpiler.collective import (
                    SingleProcessMultiThread,
                )

                t = SingleProcessMultiThread()
                env = self.get_dist_env()
                t.transpile(
                    startup_program=_startup,
                    main_program=_main,
                    rank=env["trainer_id"],
                    endpoints=env["trainer_endpoints"],
                    current_endpoint=env['current_endpoint'],
                    wait_port=False,
                )

            compiled_config.set_origin_ps_main_program(_main)
            compiled_config.set_origin_ps_startup_program(_startup)
            # for heter program
            if self.role_maker._is_heter_parameter_server_mode:
                from paddle.incubate.fleet.parameter_server.ir import (
                    heter_trainer_pass as heter_worker,
                )

                if self.role_maker._is_heter_worker():
                    # for heter worker
                    stage_id = self.role_maker._get_stage_id()
                    device = self.role_maker._heter_device_type().lower()
                    _main = heter_worker.split_heter_worker_ops_pass(
                        _main, compiled_config, stage_id, device
                    )
                else:
                    # for default worker
                    _main = heter_worker.split_trainer_ops_pass(
                        _main, compiled_config
                    )
        else:
            _main = worker.append_send_ops_pass(_main, compiled_config)
            _startup = _startup
            compiled_config.set_origin_ps_main_program(_main)
            compiled_config.set_origin_ps_startup_program(_startup)

        launch_barrier = self.user_defined_strategy.a_sync_configs[
            "launch_barrier"
        ]
        launch_barrier_flag = int(os.getenv("FLAGS_LAUNCH_BARRIER", "1"))
        if launch_barrier and launch_barrier_flag:
            # trainers wait for the parameter servers to be ready
            wait_server_ready(self.role_maker._get_pserver_endpoints())

            # for ps-heter mode, wait heter worker ready
            # if self.role_maker._is_heter_parameter_server_mode and self.role_maker._is_worker(
            # ):
            #     wait_server_ready(self.role_maker._get_heter_worker_endpoints())

        return _main, _startup

    def _build_pserver_programs(self, compiled_config):
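        """Build fresh main/startup programs for the parameter server side:
        listen_and_serv, RPC flags, (geo) optimizer passes and large scale
        sparse handling."""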
        _main = paddle.static.Program()
        _startup = paddle.static.Program()

        from paddle.incubate.fleet.parameter_server.ir import (
            pserver_pass as server,
        )

        if not compiled_config.is_geo_mode():

            from paddle.incubate.fleet.parameter_server.ir.public import (
                _get_optimize_ops,
            )

            is_sgd_adam = False

            main_program = compiled_config.get_origin_main_program()
            ops = _get_optimize_ops(main_program)

            if len(ops) == 0:
                return _main, _startup

            from paddle.incubate.fleet.parameter_server.ir.public import (
                _add_lr_decay_table_pass,
            )

            lr_decay_steps = self.user_defined_strategy.a_sync_configs[
                "lr_decay_steps"
            ]
            _add_lr_decay_table_pass(
                main_program, compiled_config, lr_decay_steps
            )

            for op in ops:
                if op.type in ["sgd", "adam"]:
                    is_sgd_adam = True
                    break

            if is_sgd_adam:
                return _main, _startup

            _main = server.add_listen_and_serv_pass(_main, compiled_config)
            _main = server.add_rpc_global_flags_pass(_main, compiled_config)
            _main = server.add_optimizer_pass(_main, compiled_config)
            _main = server.large_scale_sparse_pass(
                _main, _main, compiled_config, False
            )
            _startup = server.build_pserver_startup_program_pass(
                _startup, _main, compiled_config
            )
            _startup = server.large_scale_sparse_pass(
                _startup, _main, compiled_config, True
            )

            if not compiled_config.is_sync_mode():
                _main = server.delete_unused_in_main_pass(
                    _main, compiled_config
                )

            _startup = server.delete_unused_in_startup_pass(
                _startup, _main, compiled_config
            )
        else:
            _main = server.add_listen_and_serv_pass(_main, compiled_config)
            _main = server.add_rpc_global_flags_pass(_main, compiled_config)
            _main = server.add_geo_optimizer_pass(_main, compiled_config)
            _startup = server.build_pserver_startup_program_pass(
                _startup, _main, compiled_config
            )
            _startup = server.delete_unused_in_startup_pass(
                _startup, _main, compiled_config
            )

        return _main, _startup

    def _can_apply_geo(self, dist_strategy, program):
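        """Heuristic for switching to geo mode: only with plain SGD, and only
        when a rough estimate of memory use (about 5x the parameter size plus
        temporary variables at an assumed batch size of 1024) fits into the
        free system memory."""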
        def get_sys_free_mem():
            plat = platform.system()
            if platform.system() == "Darwin":
                # communicate() returns bytes on Python 3, so decode before
                # splitting the vm_stat output into lines
                vm = subprocess.Popen(
                    ['vm_stat'], stdout=subprocess.PIPE
                ).communicate()[0].decode()
                # Process vm_stat
                vmLines = vm.split('\n')
                sep = re.compile(r':[\s]+')
                vmStats = {}
                for row in range(1, len(vmLines) - 2):
                    rowText = vmLines[row].strip()
                    rowElements = sep.split(rowText)
                    vmStats[(rowElements[0])] = (
                        int(rowElements[1].strip(r'\.')) * 4096
                    )
                return vmStats["Pages free"]
            elif platform.system() == "Linux":
                mems = {}
                with open('/proc/meminfo', 'rb') as f:
                    for line in f:
                        fields = line.split()
                        mems[fields[0]] = int(fields[1]) * 1024
                free = mems[b'MemFree:']
                return free
            else:
                raise ValueError(
                    "%s platform is unsupported is parameter server optimizer"
                    % (platform.system())
                )

        if not isinstance(self.inner_opt, fluid.optimizer.SGDOptimizer):
            return False

        free = get_sys_free_mem()

        from paddle.incubate.fleet.parameter_server.ir import vars_metatools

        processed_var_names = set(["@EMPTY@"])
        param_memory_size = 0
        for varname in program.global_block().vars:
            var = program.global_block().vars[varname]
            if (
                not var.persistable
                or var.desc.type() != core.VarDesc.VarType.LOD_TENSOR
            ):
                continue
            param = vars_metatools.create_var_struct(var)
            param_memory_size += param.m_size
            processed_var_names.add(varname)

        upper_mem_use = param_memory_size * 5.0

        program_tmp_vars = dict()
        eval_batch_size = 1024
        for op in program.global_block().ops:
            for var_name in op.output_arg_names:
                if var_name in processed_var_names:
                    continue
                processed_var_names.add(var_name)
                var = program.global_block().vars[var_name]

                if var.desc.type() != core.VarDesc.VarType.LOD_TENSOR:
                    continue

                data_count = 1
                neg_dim_count = 0
                for x in var.shape:
                    if x < 0:
                        if neg_dim_count >= 1:
                            raise ValueError(
                                "Var %s has more than one negative dim."
                                % (var_name)
                            )
                        neg_dim_count += 1
                        data_count *= -x
                    else:
                        data_count *= x
                program_tmp_vars[var_name] = (
                    data_count,
                    neg_dim_count,
                    vars_metatools.dtype_to_size[var.dtype],
                )

        for varname in program_tmp_vars:
            data_count, neg_dim_count, type_size = program_tmp_vars[varname]
            if neg_dim_count == 1:
                data_count *= eval_batch_size
            var_memory = data_count * type_size
            upper_mem_use += var_memory

        return upper_mem_use < free

    def minimize_impl(
        self, loss, startup_program=None, parameter_list=None, no_grad_set=None
    ):
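        """Run the inner optimizer first, then rewrite the programs according
        to the current role: worker, heter worker or parameter server."""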
        self.inner_opt.minimize(
            loss, startup_program, parameter_list, no_grad_set
        )
        strategy = self._get_distributed_strategy()

        _origin_main_program = loss.block.program
        _origin_startup_program = startup_program
        from paddle.incubate.fleet.parameter_server.ir import public as public

        compiled_config = public.CompileTimeStrategy(
            _origin_main_program,
            _origin_startup_program,
            strategy,
            self.role_maker,
        )
        compiled_config.strategy = strategy

        if self.role_maker._is_worker() or self.role_maker._is_heter_worker():
            main_program, startup_program = self._build_trainer_programs(
                compiled_config
            )
            if self.role_maker._is_heter_parameter_server_mode:
                _origin_startup_program._heter_pipeline_opt = {
                    "startup_program": startup_program,
                    "pipeline_stage": int(self.role_maker._get_stage_id()) - 1,
                    "heter_place": self.role_maker._heter_device(),
                }

                loss.block.program._heter_pipeline_opt = {
                    "trainer": "HeterPipelineTrainer",
                    "device_worker": "HeterSection",
                    "trainers": self.role_maker._get_stage_trainers(),  # trainer num in each stage
                    "trainer_id": int(self.role_maker._role_id()),
                    "pipeline_stage": int(self.role_maker._get_stage_id()) - 1,
                    "num_pipeline_stages": int(
                        self.role_maker._get_num_stage()
                    ),
                    "section_program": main_program,
                    "num_microbatches": self.num_microbatches,
                    "heter_place": self.role_maker._heter_device(),
                }
            else:
                loss.block.program = main_program
                paddle.framework.switch_startup_program(startup_program)

        elif self.role_maker._is_server():
            main_program, startup_program = self._build_pserver_programs(
                compiled_config
            )
            loss.block.program = main_program
            paddle.framework.switch_startup_program(startup_program)
        return None, None

    def _disable_strategy(self, dist_strategy):
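        """Switch parameter server training off: disable a_sync and reset
        k_steps to -1."""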
        # if self.role_maker._is_heter_parameter_server_mode:
        #    dist_strategy.pipeline = False
        #    dist_strategy.pipeline_configs = {
        #        "micro_batch_size": 1,
        #        "accumulate_steps": 1,
        #    }
        dist_strategy.a_sync = False
        a_sync_configs = dist_strategy.a_sync_configs
        a_sync_configs["k_steps"] = -1
        dist_strategy.a_sync_configs = a_sync_configs

    def _enable_strategy(self, dist_strategy, context):
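        """Pick a default a_sync configuration: geo with k_steps = 800 when
        _can_apply_geo judges that the model fits into memory, otherwise pure
        async (k_steps = 0)."""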
        # if self.role_maker._is_heter_parameter_server_mode:
        #    dist_strategy.pipeline = True
        #    dist_strategy.pipeline_configs = {
        #        "micro_batch_size": 1,
        #        "accumulate_steps": 1,
        #    }
        a_sync_configs = dist_strategy.a_sync_configs
        if a_sync_configs["k_steps"] >= 0:
            return

        dist_strategy.a_sync = True
        a_sync_configs = dist_strategy.a_sync_configs

        is_geo = self._can_apply_geo(
            dist_strategy, context["origin_main_program"]
        )

        if is_geo:
            a_sync_configs["k_steps"] = 800
        else:
            a_sync_configs["k_steps"] = 0
        dist_strategy.a_sync_configs = a_sync_configs