#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from paddle import fluid
from .meta_optimizer_base import MetaOptimizerBase
from paddle.fluid import core
import subprocess
import re
import platform


class ParameterServerOptimizer(MetaOptimizerBase):
    def __init__(self, optimizer):
        """Wrap ``optimizer`` as the inner optimizer of this meta optimizer.

        Args:
            optimizer: the optimizer being wrapped; it is forwarded to the
                base-class initializer and stored on ``self.inner_opt``.
        """
        super(ParameterServerOptimizer, self).__init__(optimizer)
        # we do not allow meta optimizer to be inner optimizer currently
        self.meta_optimizers_white_list = []
        self.inner_opt = optimizer

    def _is_graph_out(self):
        """Always answer False for this optimizer."""
        return False

    def _can_apply(self):
        if self.role_maker._is_collective:
            return False
35

36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
        k_steps = self.user_defined_strategy.a_sync_configs["k_steps"]
        return True if k_steps >= 0 else False

    def _get_distributed_strategy(self):
        from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory

        k_steps = self.user_defined_strategy.a_sync_configs["k_steps"]
        strategy = None

        if not self.user_defined_strategy.a_sync and k_steps == 0:
            strategy = StrategyFactory.create_sync_strategy()

        if self.user_defined_strategy.a_sync and k_steps == 0:
            strategy = StrategyFactory.create_async_strategy()

        if self.user_defined_strategy.a_sync and k_steps > 0:
            strategy = StrategyFactory.create_geo_strategy(k_steps)

        if not strategy:
            raise ValueError("k_steps must be invalid value, please check")

        return strategy

    def _build_trainer_programs(self, compiled_config):
        """Build the (main, startup) program pair used on a trainer.

        Both programs start as clones of the origin programs held by
        ``compiled_config``. In geo mode only the send-ops pass is applied to
        the main program and the startup clone is returned as-is. Otherwise
        the worker passes rewrite both programs, followed by the heter passes
        when the role maker reports heter parameter-server mode.

        Args:
            compiled_config: compile-time strategy object providing
                ``origin_main_program``, ``origin_startup_program`` and
                ``is_geo_mode()``; it is handed to every pass.

        Returns:
            tuple: ``(main_program, startup_program)``.
        """
        from paddle.fluid.incubate.fleet.parameter_server.ir import trainer_pass as worker

        main_prog = compiled_config.origin_main_program.clone()
        startup_prog = compiled_config.origin_startup_program.clone()

        if compiled_config.is_geo_mode():
            main_prog = worker.append_send_ops_pass(main_prog, compiled_config)
            return main_prog, startup_prog

        # for main program (order matters: each pass consumes the previous
        # pass's output)
        for main_pass in (worker.delete_optimizer_pass,
                          worker.distributed_ops_pass,
                          worker.append_send_ops_pass):
            main_prog = main_pass(main_prog, compiled_config)

        # for startup program
        for startup_pass in (worker.fake_init_ops_pass,
                             worker.init_from_server_pass,
                             worker.delet_extra_optimizes_pass):
            startup_prog = startup_pass(startup_prog, compiled_config)

        # for heter program
        if not self.role_maker._is_heter_parameter_server_mode:
            return main_prog, startup_prog

        from paddle.fluid.incubate.fleet.parameter_server.ir import heter_trainer_pass as heter_worker

        if self.role_maker._is_heter_worker():
            # for heter worker
            split_pass = heter_worker.split_heter_worker_ops_pass
        else:
            # for default worker
            split_pass = heter_worker.split_trainer_ops_pass
        main_prog = split_pass(main_prog, compiled_config)

        # for startup change
        startup_prog = heter_worker.delete_startup_useless_ops_var_pass(
            startup_prog, main_prog, compiled_config)

        return main_prog, startup_prog

    def _build_pserver_programs(self, compiled_config):
        """Build the (main, startup) program pair used on a parameter server.

        Both programs start as empty ``fluid.Program`` objects and are filled
        by the pserver passes. Geo and non-geo modes share the same pass
        sequence with two differences: geo mode uses the geo optimizer pass,
        and only non-geo, non-sync mode prunes unused content from the main
        program.

        Args:
            compiled_config: compile-time strategy object providing
                ``is_geo_mode()`` and ``is_sync_mode()``; it is handed to
                every pass.

        Returns:
            tuple: ``(main_program, startup_program)``.
        """
        from paddle.fluid.incubate.fleet.parameter_server.ir import pserver_pass as server

        main_prog = fluid.Program()
        startup_prog = fluid.Program()
        geo_mode = compiled_config.is_geo_mode()

        main_prog = server.add_listen_and_serv_pass(main_prog, compiled_config)
        main_prog = server.add_rpc_global_flags_pass(main_prog,
                                                     compiled_config)
        if geo_mode:
            main_prog = server.add_geo_optimizer_pass(main_prog,
                                                      compiled_config)
        else:
            main_prog = server.add_optimizer_pass(main_prog, compiled_config)
        main_prog = server.large_scale_sparse_pass(main_prog, main_prog,
                                                   compiled_config, False)

        startup_prog = server.build_pserver_startup_program_pass(
            startup_prog, main_prog, compiled_config)
        startup_prog = server.large_scale_sparse_pass(startup_prog, main_prog,
                                                      compiled_config, True)

        # The sync-mode query is only made outside geo mode.
        if not geo_mode and not compiled_config.is_sync_mode():
            main_prog = server.delete_unused_in_main_pass(main_prog,
                                                          compiled_config)

        startup_prog = server.delete_unused_in_startup_pass(
            startup_prog, main_prog, compiled_config)

        return main_prog, startup_prog

    def _can_apply_geo(self, dist_strategy, program):
        """Estimate whether ``program`` fits in free memory.

        The estimate is the memory of all persistable LOD_TENSOR variables in
        the global block times 5.0, plus the memory of every other LOD_TENSOR
        variable written by an op, where a single negative dimension is
        replaced by a batch size of 1024.
        NOTE(review): the 5x factor and the 1024 batch size are heuristics
        whose rationale is not visible in this file.

        Args:
            dist_strategy: unused by this method; kept for the caller's
                signature.
            program: the program whose global block is inspected.

        Returns:
            bool: False when the inner optimizer is not an ``SGDOptimizer``;
            otherwise True exactly when the estimate is below the free system
            memory.

        Raises:
            ValueError: on platforms other than macOS and Linux, or when a
                variable has more than one negative dimension.
        """

        def get_sys_free_mem():
            """Return the free system memory in bytes (macOS and Linux only)."""
            system = platform.system()
            if system == "Darwin":
                raw = subprocess.Popen(
                    ['vm_stat'], stdout=subprocess.PIPE).communicate()[0]
                # communicate() yields bytes on Python 3; decode before any
                # text processing (splitting bytes on a str raises TypeError).
                if isinstance(raw, bytes):
                    raw = raw.decode("utf-8", "replace")
                vm_lines = raw.split('\n')

                # vm_stat states the page size in its header line; use it when
                # present and fall back to the historical 4096 bytes.
                page_size = 4096
                size_match = re.search(r'page size of (\d+) bytes',
                                       vm_lines[0])
                if size_match:
                    page_size = int(size_match.group(1))

                # Rows look like "Pages free:      12345."; skip anything
                # that does not have that "name: count." shape.
                sep = re.compile(r':\s+')
                vm_stats = {}
                for row_text in vm_lines[1:]:
                    parts = sep.split(row_text.strip())
                    if len(parts) != 2:
                        continue
                    count = parts[1].strip('.')
                    if not count.isdigit():
                        continue
                    vm_stats[parts[0]] = int(count) * page_size
                return vm_stats["Pages free"]

            if system == "Linux":
                mems = {}
                with open('/proc/meminfo', 'rb') as f:
                    for line in f:
                        fields = line.split()
                        if len(fields) < 2:
                            continue
                        # /proc/meminfo reports sizes in kB.
                        mems[fields[0]] = int(fields[1]) * 1024
                return mems[b'MemFree:']

            raise ValueError(
                "%s platform is unsupported in parameter server optimizer" %
                (system))

        if not isinstance(self.inner_opt, fluid.optimizer.SGDOptimizer):
            return False

        free = get_sys_free_mem()

        from paddle.fluid.incubate.fleet.parameter_server.ir import vars_metatools

        block = program.global_block()
        lod_tensor = core.VarDesc.VarType.LOD_TENSOR

        # "@EMPTY@" is pre-seeded so that placeholder output name is never
        # looked up in the block.
        processed_var_names = {"@EMPTY@"}

        # Pass 1: persistable LOD_TENSOR variables.
        param_memory_size = 0
        for varname in block.vars:
            var = block.vars[varname]
            if not var.persistable or var.desc.type() != lod_tensor:
                continue
            param_memory_size += vars_metatools.create_var_struct(var).m_size
            processed_var_names.add(varname)

        upper_mem_use = param_memory_size * 5.0
        eval_batch_size = 1024

        # Pass 2: every remaining LOD_TENSOR variable produced by an op.
        for op in block.ops:
            for var_name in op.output_arg_names:
                if var_name in processed_var_names:
                    continue
                processed_var_names.add(var_name)
                var = block.vars[var_name]

                if var.desc.type() != lod_tensor:
                    continue

                data_count = 1
                neg_dim_count = 0
                for dim in var.shape:
                    if dim < 0:
                        if neg_dim_count >= 1:
                            raise ValueError(
                                "Var %s has more than one negative dim." %
                                (var_name))
                        neg_dim_count += 1
                        data_count *= (-dim)
                    else:
                        data_count *= dim

                # A negative dim is scaled by the evaluation batch size.
                if neg_dim_count == 1:
                    data_count *= eval_batch_size

                type_size = vars_metatools.dtype_to_size[var.dtype]
                upper_mem_use += data_count * type_size

        return upper_mem_use < free
    def minimize_impl(self,
                      loss,
                      startup_program=None,
                      parameter_list=None,
                      no_grad_set=None):
        """Run the inner optimizer, then swap in the role-specific programs.

        After ``self.inner_opt.minimize`` has run, a compile-time strategy is
        built from the program owning ``loss`` and from ``startup_program``.
        The trainer or pserver program pair is then built for the current
        role. The main program is assigned to ``loss.block.program`` and the
        startup program is installed via
        ``fluid.framework.switch_startup_program``.

        Args:
            loss: the loss variable; ``loss.block.program`` is read as the
                origin main program and overwritten with the built one.
            startup_program: passed to the inner optimizer and used as the
                origin startup program.
            parameter_list: forwarded to the inner optimizer.
            no_grad_set: forwarded to the inner optimizer.

        Returns:
            tuple: always ``(None, None)``.

        Raises:
            ValueError: if the role maker reports neither a worker (including
                heter worker) nor a server role.
        """
        self.inner_opt.minimize(loss, startup_program, parameter_list,
                                no_grad_set)
        strategy = self._get_distributed_strategy()

        _origin_main_program = loss.block.program
        _origin_startup_program = startup_program
        from paddle.fluid.incubate.fleet.parameter_server.ir import public as public

        compiled_config = public.CompileTimeStrategy(_origin_main_program,
                                                     _origin_startup_program,
                                                     strategy, self.role_maker)
        compiled_config.strategy = strategy

        if self.role_maker.is_worker() or self.role_maker._is_heter_worker():
            main_program, startup_program = self._build_trainer_programs(
                compiled_config)
        elif self.role_maker.is_server():
            main_program, startup_program = self._build_pserver_programs(
                compiled_config)
        else:
            # Without this branch ``main_program`` would be unbound below and
            # the failure would surface as an opaque UnboundLocalError.
            raise ValueError(
                "parameter server optimizer requires a worker or server "
                "role, but the role maker reports neither")

        loss.block.program = main_program
        fluid.framework.switch_startup_program(startup_program)

        return None, None

    def _disable_strategy(self, dist_strategy):
255 256 257 258 259 260 261 262 263
        dist_strategy.a_sync = False
        a_sync_configs = dist_strategy.a_sync_configs
        a_sync_configs["k_steps"] = -1
        dist_strategy.a_sync_configs = a_sync_configs

    def _enable_strategy(self, dist_strategy, context):
        a_sync_configs = dist_strategy.a_sync_configs
        if a_sync_configs["k_steps"] >= 0:
            return
264 265

        dist_strategy.a_sync = True
266 267 268 269 270 271 272 273 274 275
        a_sync_configs = dist_strategy.a_sync_configs

        is_geo = self._can_apply_geo(dist_strategy,
                                     context["origin_main_program"])

        if is_geo:
            a_sync_configs["k_steps"] = 800
        else:
            a_sync_configs["k_steps"] = 0
        dist_strategy.a_sync_configs = a_sync_configs