parameter_server_optimizer.py 12.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle import fluid
from .meta_optimizer_base import MetaOptimizerBase
16 17 18
from paddle.fluid import core
import subprocess
import re
19
import os
20
import platform
21
from ..base.private_helper_function import wait_server_ready
22 23


24
class ParameterServerOptimizer(MetaOptimizerBase):
25
    def __init__(self, optimizer):
        """Initialise the meta optimizer around ``optimizer``.

        Args:
            optimizer: the optimizer being wrapped. It is forwarded to the
                base-class constructor and also kept as ``self.inner_opt``.
        """
        super(ParameterServerOptimizer, self).__init__(optimizer)

        self.inner_opt = optimizer

        # Empty white list: for now no meta optimizer is allowed to serve as
        # the inner optimizer of this one.
        self.meta_optimizers_white_list = []

    def _is_graph_out(self):
        return False

    def _can_apply(self):
        if self.role_maker._is_collective:
            return False
37

38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
        k_steps = self.user_defined_strategy.a_sync_configs["k_steps"]
        return True if k_steps >= 0 else False

    def _get_distributed_strategy(self):
        from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory

        k_steps = self.user_defined_strategy.a_sync_configs["k_steps"]
        strategy = None

        if not self.user_defined_strategy.a_sync and k_steps == 0:
            strategy = StrategyFactory.create_sync_strategy()

        if self.user_defined_strategy.a_sync and k_steps == 0:
            strategy = StrategyFactory.create_async_strategy()

        if self.user_defined_strategy.a_sync and k_steps > 0:
            strategy = StrategyFactory.create_geo_strategy(k_steps)

        if not strategy:
            raise ValueError("k_steps must be invalid value, please check")

        return strategy

    def _build_trainer_programs(self, compiled_config):
        """Produce the main and startup programs used on a trainer.

        Both programs start as clones of the origin programs held by
        ``compiled_config``. After the lr-decay table pass the worker passes
        are applied:

        * geo mode: only ``append_send_ops_pass`` on the main program.
        * otherwise: optimizer deletion, distributed ops and send ops on the
          main program; fake init ops and extra-optimize deletion on the
          startup program. In heter parameter-server mode the heter passes
          are applied afterwards.

        In either mode the programs are registered on ``compiled_config``
        through ``set_origin_ps_main_program`` /
        ``set_origin_ps_startup_program`` (in non-geo mode this happens
        before the heter passes run).

        Args:
            compiled_config: compile-time strategy object providing the
                origin programs and the mode queries.

        Returns:
            tuple: ``(main_program, startup_program)``.
        """
        from paddle.fluid.incubate.fleet.parameter_server.ir import trainer_pass as worker

        trainer_main = compiled_config.origin_main_program.clone()
        trainer_startup = compiled_config.origin_startup_program.clone()

        from paddle.fluid.incubate.fleet.parameter_server.ir.public import _add_lr_decay_table_pass
        lr_decay_steps = self.user_defined_strategy.a_sync_configs[
            "lr_decay_steps"]
        _add_lr_decay_table_pass(trainer_main, compiled_config,
                                 lr_decay_steps)

        if compiled_config.is_geo_mode():
            trainer_main = worker.append_send_ops_pass(trainer_main,
                                                       compiled_config)

            compiled_config.set_origin_ps_main_program(trainer_main)
            compiled_config.set_origin_ps_startup_program(trainer_startup)
        else:
            # main program passes
            for main_pass in (worker.delete_optimizer_pass,
                              worker.distributed_ops_pass,
                              worker.append_send_ops_pass):
                trainer_main = main_pass(trainer_main, compiled_config)

            # startup program passes
            for startup_pass in (worker.fake_init_ops_pass,
                                 worker.delet_extra_optimizes_pass):
                trainer_startup = startup_pass(trainer_startup,
                                               compiled_config)

            compiled_config.set_origin_ps_main_program(trainer_main)
            compiled_config.set_origin_ps_startup_program(trainer_startup)

            # heter program handling
            if self.role_maker._is_heter_parameter_server_mode:
                from paddle.fluid.incubate.fleet.parameter_server.ir import heter_trainer_pass as heter_worker

                if self.role_maker._is_heter_worker():
                    # heter worker keeps its own slice of the ops
                    split_pass = heter_worker.split_heter_worker_ops_pass
                else:
                    # default (non-heter) worker
                    split_pass = heter_worker.split_trainer_ops_pass
                trainer_main = split_pass(trainer_main, compiled_config)

                # startup program has to follow the split main program
                trainer_startup = heter_worker.delete_startup_useless_ops_var_pass(
                    trainer_startup, trainer_main, compiled_config)

        barrier_configured = self.user_defined_strategy.a_sync_configs[
            "launch_barrier"]
        # FLAGS_LAUNCH_BARRIER defaults to "1"; it is parsed even when the
        # configured barrier is off, so a non-integer value always raises.
        barrier_from_env = int(os.getenv("FLAGS_LAUNCH_BARRIER", "1"))
        if barrier_configured and barrier_from_env:
            # trainer waits for the parameter servers
            wait_server_ready(self.role_maker._get_pserver_endpoints())

            # Waiting for heter workers in ps-heter mode is currently
            # disabled:
            # if self.role_maker._is_heter_parameter_server_mode and self.role_maker._is_worker(
            # ):
            #     wait_server_ready(self.role_maker._get_heter_worker_endpoints())

        return trainer_main, trainer_startup

    def _build_pserver_programs(self, compiled_config):
        """Produce the main and startup programs used on a parameter server.

        Both programs start as fresh ``fluid.Program()`` objects.

        * geo mode: listen-and-serv, rpc flags and the geo optimizer pass on
          the main program, then the startup program is built and pruned.
        * otherwise: the optimize ops of the origin main program are
          inspected. The two still-empty programs are returned right away
          when there are no optimize ops, or (after the lr-decay table pass
          has run on the origin main program) when any optimize op is of
          type ``"sgd"`` or ``"adam"``. Otherwise the full set of server
          passes is applied.

        Args:
            compiled_config: compile-time strategy object providing the
                origin main program and the mode queries.

        Returns:
            tuple: ``(main_program, startup_program)``.
        """
        server_main = fluid.Program()
        server_startup = fluid.Program()

        from paddle.fluid.incubate.fleet.parameter_server.ir import pserver_pass as server

        if compiled_config.is_geo_mode():
            server_main = server.add_listen_and_serv_pass(server_main,
                                                          compiled_config)
            server_main = server.add_rpc_global_flags_pass(server_main,
                                                           compiled_config)
            server_main = server.add_geo_optimizer_pass(server_main,
                                                        compiled_config)
            server_startup = server.build_pserver_startup_program_pass(
                server_startup, server_main, compiled_config)
            server_startup = server.delete_unused_in_startup_pass(
                server_startup, server_main, compiled_config)
            return server_main, server_startup

        from paddle.fluid.incubate.fleet.parameter_server.ir.public import _get_optimize_ops

        origin_main = compiled_config.get_origin_main_program()
        optimize_ops = _get_optimize_ops(origin_main)

        if len(optimize_ops) == 0:
            return server_main, server_startup

        from paddle.fluid.incubate.fleet.parameter_server.ir.public import _add_lr_decay_table_pass
        _add_lr_decay_table_pass(
            origin_main, compiled_config,
            self.user_defined_strategy.a_sync_configs["lr_decay_steps"])

        # With an sgd/adam optimize op present the empty programs are
        # returned as they are.
        if any(op.type in ["sgd", "adam"] for op in optimize_ops):
            return server_main, server_startup

        server_main = server.add_listen_and_serv_pass(server_main,
                                                      compiled_config)
        server_main = server.add_rpc_global_flags_pass(server_main,
                                                       compiled_config)
        server_main = server.add_optimizer_pass(server_main, compiled_config)
        server_main = server.large_scale_sparse_pass(
            server_main, server_main, compiled_config, False)
        server_startup = server.build_pserver_startup_program_pass(
            server_startup, server_main, compiled_config)
        server_startup = server.large_scale_sparse_pass(
            server_startup, server_main, compiled_config, True)

        if not compiled_config.is_sync_mode():
            server_main = server.delete_unused_in_main_pass(server_main,
                                                            compiled_config)

        server_startup = server.delete_unused_in_startup_pass(
            server_startup, server_main, compiled_config)

        return server_main, server_startup

177
    def _can_apply_geo(self, dist_strategy, program):
        """Estimate whether ``program`` fits into free memory for geo mode.

        Geo mode is only considered when the inner optimizer is a
        ``fluid.optimizer.SGDOptimizer``. The estimate is the sum of

        * 5.0 times the size of every persistable ``LOD_TENSOR`` variable in
          the global block, and
        * the size of every other ``LOD_TENSOR`` op output in the global
          block, where a single negative dimension contributes its absolute
          value times an assumed batch size of 1024.

        Args:
            dist_strategy: unused by this method; kept for the interface.
            program: program whose global block is inspected.

        Returns:
            bool: ``True`` when the estimate is strictly below the free
            system memory, ``False`` otherwise or when the inner optimizer
            is not SGD.

        Raises:
            ValueError: on a platform other than Darwin/Linux, or when a
                variable has more than one negative dimension.
        """

        def get_sys_free_mem():
            """Return the free system memory in bytes (Darwin/Linux only)."""
            system = platform.system()
            if system == "Darwin":
                raw = subprocess.Popen(
                    ['vm_stat'], stdout=subprocess.PIPE).communicate()[0]
                # communicate() returns bytes on Python 3; splitting bytes
                # with a str separator would raise TypeError.
                if isinstance(raw, bytes):
                    raw = raw.decode("utf-8", "replace")
                vm_lines = raw.split('\n')

                # The header line reports the real page size, e.g.
                # "... (page size of 16384 bytes)"; fall back to 4096.
                page_size = 4096
                header = re.search(r'page size of (\d+) bytes', vm_lines[0])
                if header:
                    page_size = int(header.group(1))

                sep = re.compile(r':[\s]+')
                vm_stats = {}
                for line in vm_lines[1:]:
                    elements = sep.split(line.strip())
                    if len(elements) != 2:
                        continue
                    pages = elements[1].strip(r'\.')
                    if not pages.isdigit():
                        continue
                    vm_stats[elements[0]] = int(pages) * page_size
                return vm_stats["Pages free"]
            elif system == "Linux":
                mems = {}
                # read as bytes, so the keys below are bytes as well
                with open('/proc/meminfo', 'rb') as f:
                    for line in f:
                        fields = line.split()
                        if len(fields) < 2:
                            continue
                        # values are multiplied by 1024 to get bytes
                        mems[fields[0]] = int(fields[1]) * 1024
                return mems[b'MemFree:']
            else:
                raise ValueError(
                    "%s platform is unsupported is parameter server optimizer" %
                    (platform.system()))

        if not isinstance(self.inner_opt, fluid.optimizer.SGDOptimizer):
            return False

        free = get_sys_free_mem()

        from paddle.fluid.incubate.fleet.parameter_server.ir import vars_metatools

        global_block = program.global_block()
        lod_tensor = core.VarDesc.VarType.LOD_TENSOR

        # Persistable LOD_TENSOR variables.
        processed_var_names = set(["@EMPTY@"])
        param_memory_size = 0
        for varname in global_block.vars:
            var = global_block.vars[varname]
            if not var.persistable or var.desc.type() != lod_tensor:
                continue
            param = vars_metatools.create_var_struct(var)
            param_memory_size += param.m_size
            processed_var_names.add(varname)

        # NOTE(review): the 5.0 factor and the batch size of 1024 are
        # heuristics whose rationale is not stated in this file.
        upper_mem_use = param_memory_size * 5.0
        eval_batch_size = 1024

        # Remaining op outputs, each counted once.
        for op in global_block.ops:
            for var_name in op.output_arg_names:
                if var_name in processed_var_names:
                    continue
                processed_var_names.add(var_name)
                var = global_block.vars[var_name]

                if var.desc.type() != lod_tensor:
                    continue

                data_count = 1
                neg_dim_count = 0
                for dim in var.shape:
                    if dim < 0:
                        if neg_dim_count >= 1:
                            raise ValueError(
                                "Var %s has more than one negative dim." %
                                (var_name))
                        neg_dim_count += 1
                        data_count *= (-dim)
                    else:
                        data_count *= dim

                # a negative dim is sized with the assumed batch size
                if neg_dim_count == 1:
                    data_count *= eval_batch_size

                type_size = vars_metatools.dtype_to_size[var.dtype]
                upper_mem_use += data_count * type_size

        return upper_mem_use < free
265

266 267 268 269 270 271 272
    def minimize_impl(self,
                      loss,
                      startup_program=None,
                      parameter_list=None,
                      no_grad_set=None):
        """Run the inner optimizer, then swap in role-specific programs.

        After ``self.inner_opt.minimize`` has been applied, a
        ``CompileTimeStrategy`` is built from the loss' program and the
        startup program. Depending on the role reported by
        ``self.role_maker`` the trainer or the parameter-server programs are
        generated. The result replaces ``loss.block.program`` and becomes
        the current startup program through
        ``fluid.framework.switch_startup_program``.

        Args:
            loss: loss variable; its ``block.program`` is the origin main
                program and is overwritten with the generated main program.
            startup_program: origin startup program. When ``None`` the
                current default startup program is used.
            parameter_list: forwarded to the inner optimizer.
            no_grad_set: forwarded to the inner optimizer.

        Returns:
            tuple: always ``(None, None)``.

        Raises:
            ValueError: when the role is neither worker, heter worker nor
                server, or when the strategy settings are invalid.
        """
        self.inner_opt.minimize(loss, startup_program, parameter_list,
                                no_grad_set)
        strategy = self._get_distributed_strategy()

        # The origin startup program is cloned when the trainer programs are
        # built, so it must not stay None. The inner optimizer was called
        # with the same None above, which is presumed to select the default
        # startup program as well.
        if startup_program is None:
            startup_program = fluid.default_startup_program()

        _origin_main_program = loss.block.program
        _origin_startup_program = startup_program
        from paddle.fluid.incubate.fleet.parameter_server.ir import public as public

        compiled_config = public.CompileTimeStrategy(_origin_main_program,
                                                     _origin_startup_program,
                                                     strategy, self.role_maker)
        compiled_config.strategy = strategy

        if self.role_maker._is_worker() or self.role_maker._is_heter_worker():
            main_program, startup_program = self._build_trainer_programs(
                compiled_config)
        elif self.role_maker._is_server():
            main_program, startup_program = self._build_pserver_programs(
                compiled_config)
        else:
            # Without this branch the assignment below would fail with an
            # UnboundLocalError on main_program.
            raise ValueError(
                "ParameterServerOptimizer requires the current role to be a "
                "worker, a heter worker or a server")

        loss.block.program = main_program
        fluid.framework.switch_startup_program(startup_program)

        return None, None

    def _disable_strategy(self, dist_strategy):
297 298 299 300 301 302 303 304 305
        dist_strategy.a_sync = False
        a_sync_configs = dist_strategy.a_sync_configs
        a_sync_configs["k_steps"] = -1
        dist_strategy.a_sync_configs = a_sync_configs

    def _enable_strategy(self, dist_strategy, context):
        a_sync_configs = dist_strategy.a_sync_configs
        if a_sync_configs["k_steps"] >= 0:
            return
306 307

        dist_strategy.a_sync = True
308 309 310 311 312 313 314 315 316 317
        a_sync_configs = dist_strategy.a_sync_configs

        is_geo = self._can_apply_geo(dist_strategy,
                                     context["origin_main_program"])

        if is_geo:
            a_sync_configs["k_steps"] = 800
        else:
            a_sync_configs["k_steps"] = 0
        dist_strategy.a_sync_configs = a_sync_configs