# pipeline_optimizer.py
#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

14
from __future__ import print_function
15
from __future__ import division
16 17 18 19

import paddle.fluid as fluid
from paddle.fluid import core, unique_name
from ..base.private_helper_function import wait_server_ready
20 21
from paddle.fluid.optimizer import PipelineOptimizer as PO
from .meta_optimizer_base import MetaOptimizerBase
22
from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper, is_update_op, is_loss_grad_op, is_backward_op, is_optimizer_op
23 24


25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
def _get_node_num(endpoints):
    ss = set()
    for ep in endpoints:
        ip = ep.split(":")[0].strip()
        if ip not in ss:
            ss.add(ip)
    return len(ss)


class PipelineHelper(object):
    """Append NCCL ring-setup ops for pipeline training to a startup program.

    Ring 0 links the ranks of one pipeline; ring ``pipeline_rank + 1`` links
    the ranks that sit at the same position in every pipeline (used for
    data parallelism across pipelines).
    """

    def __init__(self, role_maker, wait_port='6174'):
        # ``wait_port`` is only tested for truthiness in _init_communicator:
        # when truthy, rank 0 of each ring calls ``wait_server_ready`` first.
        self.wait_port = wait_port
        self.role_maker = role_maker

    def update_startup_program(self,
                               startup_program=None,
                               inner_parallelism=None):
        """Append ring-initialization (and param broadcast) ops.

        Args:
            startup_program: program whose global block receives the
                ``c_gen_nccl_id`` / ``c_comm_init`` / ``c_broadcast`` ops.
            inner_parallelism: number of consecutive ranks forming one
                pipeline. Must be a positive integer; the ``None`` default
                is kept only for signature compatibility and is rejected.
        """
        assert inner_parallelism is not None and inner_parallelism >= 1, \
            "inner_parallelism must be a positive integer, got {}".format(
                inner_parallelism)
        self.startup_program = startup_program

        nranks = self.role_maker._worker_num()
        rank = self.role_maker._worker_index()
        endpoints = self.role_maker._get_trainer_endpoints()
        current_endpoint = endpoints[rank]
        node_num = _get_node_num(endpoints)
        assert nranks % node_num == 0

        # Ranks are laid out pipeline after pipeline:
        #   rank == pipeline_id * inner_parallelism + pipeline_rank
        pipeline_rank = rank % inner_parallelism
        pipeline_id = rank // inner_parallelism

        # Create ring 0 for all gpus in the same pipeline.
        if inner_parallelism > 1:
            start_index = pipeline_id * inner_parallelism
            pipeline_endpoints = endpoints[start_index:start_index +
                                           inner_parallelism]
            self._init_communicator(self.startup_program, current_endpoint,
                                    pipeline_endpoints, pipeline_rank, 0,
                                    self.wait_port)

        pipeline_num = len(endpoints) // inner_parallelism
        if pipeline_num == 1:
            return

        # Create a ring for the gpus holding the same position in every
        # pipeline (data parallelism across pipelines).
        ring_id = pipeline_rank + 1
        eps = [
            endpoints[i * inner_parallelism + pipeline_rank]
            for i in range(pipeline_num)
        ]
        # Inside that ring, a gpu's rank is the id of its pipeline.
        dp_rank = pipeline_id
        self._init_communicator(self.startup_program, current_endpoint, eps,
                                dp_rank, ring_id, self.wait_port)
        self._broadcast_params(ring_id)

    def _init_communicator(self, program, current_endpoint, endpoints, rank,
                           ring_id, wait_port):
        """Append ``c_gen_nccl_id`` + ``c_comm_init`` for ring ``ring_id``.

        Args:
            program: program whose global block receives the ops.
            current_endpoint: this process's endpoint; must be in
                ``endpoints`` (``list.remove`` raises ValueError otherwise).
            endpoints: all endpoints of the ring, in rank order.
            rank: index of this process inside ``endpoints``.
            ring_id: id of the ring to create.
            wait_port: when truthy, rank 0 calls ``wait_server_ready`` on
                the other endpoints before appending the ops.
        """
        nranks = len(endpoints)
        other_endpoints = endpoints[:]
        other_endpoints.remove(current_endpoint)
        if rank == 0 and wait_port:
            wait_server_ready(other_endpoints)

        block = program.global_block()
        nccl_id_var = block.create_var(
            name=unique_name.generate('nccl_id'),
            persistable=True,
            type=core.VarDesc.VarType.RAW)
        block.append_op(
            type='c_gen_nccl_id',
            inputs={},
            outputs={'Out': nccl_id_var},
            attrs={
                'rank': rank,
                'endpoint': current_endpoint,
                'other_endpoints': other_endpoints,
                OP_ROLE_KEY: OpRole.Forward,
            })
        block.append_op(
            type='c_comm_init',
            inputs={'X': nccl_id_var},
            outputs={},
            attrs={
                'nranks': nranks,
                'rank': rank,
                'ring_id': ring_id,
                OP_ROLE_KEY: OpRole.Forward,
            })

    def _broadcast_params(self, ring_id):
        """Broadcast persistable startup vars from ring rank 0.

        Vars whose name contains ``"nccl_id"`` are skipped. One
        ``c_sync_comm_stream`` on the same ring is appended after the
        broadcasts; it is omitted when nothing was broadcast.
        """
        block = self.startup_program.global_block()
        last_param = None
        for var_name in block.vars:
            if "nccl_id" in var_name:
                continue
            param = block.var(var_name)
            if not param.persistable:
                continue

            block.append_op(
                type='c_broadcast',
                inputs={'X': param},
                outputs={'Out': param},
                attrs={
                    'ring_id': ring_id,
                    'root': 0,
                    OP_ROLE_KEY: OpRole.Forward
                })
            last_param = param

        # Previously the sync referenced the loop variable: unbound when no
        # var qualified, or an un-broadcast non-persistable var otherwise.
        if last_param is None:
            return

        block.append_op(
            type='c_sync_comm_stream',
            inputs={'X': last_param},
            outputs={'Out': last_param},
            attrs={'ring_id': ring_id,
                   OP_ROLE_KEY: OpRole.Forward})
134 135


136 137 138 139 140 141
class PipelineOptimizer(MetaOptimizerBase):
    """Meta optimizer wrapping ``paddle.fluid.optimizer.PipelineOptimizer``.

    After the wrapped optimizer splits the program into section programs,
    this class sets up the NCCL rings in the startup program and, when more
    than one pipeline runs, inserts data-parallel allreduce ops.
    """

    def __init__(self, optimizer):
        super(PipelineOptimizer, self).__init__(optimizer)
        self.inner_opt = optimizer
        # we do not allow meta optimizer to be inner optimizer currently
        self.meta_optimizers_white_list = []
        self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]

    def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
                        user_defined_strategy):
        super(PipelineOptimizer, self)._set_basic_info(
            loss, role_maker, user_defined_optimizer, user_defined_strategy)
        configs = user_defined_strategy.pipeline_configs
        self.micro_batch_size = configs['micro_batch_size']
        # 'accumulate_steps' is passed to the wrapped optimizer as the
        # number of micro-batches.
        self.num_microbatches = configs['accumulate_steps']

    def _can_apply(self):
        """Applicable only to collective jobs with ``strategy.pipeline`` set."""
        if not self.role_maker._is_collective:
            return False
        if self.user_defined_strategy.pipeline == True:
            return True
        return False

    def _disable_strategy(self, dist_strategy):
        dist_strategy.pipeline = False
        dist_strategy.pipeline_configs = {}

    def _enable_strategy(self, dist_strategy, context):
        dist_strategy.pipeline = True
        dist_strategy.pipeline_configs = {
            "micro_batch_size": 1,
            "accumulate_steps": 1,
        }

    def minimize_impl(self,
                      loss,
                      startup_program=None,
                      parameter_list=None,
                      no_grad_set=None):
        """Split the program into pipeline sections and add comm ops.

        Returns:
            ``(optimize_ops, params_grads)`` from the wrapped optimizer.
        """
        endpoints = self.role_maker._get_trainer_endpoints()
        node_num = _get_node_num(endpoints)

        self.wrapped_opt = PO(self.inner_opt,
                              num_microbatches=self.num_microbatches)

        self.startup_program = startup_program
        if startup_program is None:
            self.startup_program = fluid.default_startup_program()

        self.rank = self.role_maker._worker_index()
        self.nranks = self.role_maker._worker_num()
        assert self.nranks % node_num == 0

        # Pipeline options attached to the main program. The wrapped
        # minimize() is expected to add 'inner_parallelism' (read below).
        loss.block.program._pipeline_opt = dict()
        loss.block.program._pipeline_opt['local_rank'] = self.rank
        loss.block.program._pipeline_opt[
            'micro_batch_size'] = self.micro_batch_size
        optimize_ops, params_grads, prog_list = self.wrapped_opt.minimize(
            loss, startup_program, parameter_list, no_grad_set)
        assert prog_list

        self.main_program_list = prog_list
        self.main_program = loss.block.program
        self.inner_parallelism = loss.block.program._pipeline_opt[
            'inner_parallelism']
        assert self.nranks % self.inner_parallelism == 0

        pipeline_helper = PipelineHelper(self.role_maker)
        pipeline_helper.update_startup_program(
            self.startup_program._pipeline_opt["startup_program"],
            self.inner_parallelism)

        pipeline_num = self.nranks // self.inner_parallelism
        self._transpile_main_program(loss, pipeline_num, self.inner_parallelism)
        return optimize_ops, params_grads

    def _transpile_main_program(self, loss, pipeline_num, inner_parallelism):
        """Add data-parallel ops; a no-op when there is a single pipeline."""
        if pipeline_num <= 1:
            return
        self._insert_loss_grad_ops(loss, pipeline_num)
        # Section program ``ring_id - 1`` is allreduced over ring ``ring_id``.
        for ring_id in range(1, inner_parallelism + 1):
            self._insert_allreduce_ops(ring_id)

    def _insert_loss_grad_ops(self, loss, pipeline_num):
        """
        Scale the loss grad by ``1 / pipeline_num`` in the last section
        program, so that summing gradients over the ``pipeline_num``
        data-parallel pipelines keeps the learning rate consistent.
        """
        block = self.main_program_list[-1]['program'].global_block()
        # Reverse order keeps earlier indices valid across insertions.
        for idx, op in reversed(list(enumerate(block.ops))):
            if is_loss_grad_op(op):
                loss_grad_var = block.vars[op.output_arg_names[0]]
                block._insert_op(
                    idx + 1,
                    type='scale',
                    inputs={'X': loss_grad_var},
                    outputs={'Out': loss_grad_var},
                    attrs={
                        'scale': 1.0 / pipeline_num,
                        OP_ROLE_KEY: OpRole.Backward
                    })

    def _insert_allreduce_ops(self, ring_id):
        """Allreduce grads of section program ``ring_id - 1`` on ``ring_id``.

        After each backward op carrying ``OP_ROLE_VAR_KEY`` (param/grad
        name pairs), insert one ``c_sync_calc_stream`` and a
        ``c_allreduce_sum`` per param not seen before whose original param
        is not distributed. Then insert one ``c_sync_comm_stream`` before
        the first optimizer op.
        """
        block = self.main_program_list[ring_id - 1]['program'].global_block()
        origin_block = self.main_program.global_block()
        grad = None
        processed_param_name = set()
        # Reverse order: inserting after ``idx`` never shifts ops that are
        # still to be visited.
        for idx, op in reversed(list(enumerate(block.ops))):
            if not (is_backward_op(op) and OP_ROLE_VAR_KEY in op.attr_names):
                continue
            op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
            if len(op_role_var) == 0:
                continue
            assert len(op_role_var) % 2 == 0
            offset = idx
            for i in range(0, len(op_role_var), 2):
                param_name = op_role_var[i]
                if param_name in processed_param_name:
                    continue
                processed_param_name.add(param_name)
                grad = block.vars[op_role_var[i + 1]]
                origin_param = origin_block.vars[param_name]
                if origin_param.is_distributed:
                    continue
                if offset == idx:
                    # First allreduce after this op: sync the calc stream
                    # once, right after the op.
                    offset += 1
                    block._insert_op(
                        offset,
                        type='c_sync_calc_stream',
                        inputs={'X': grad},
                        outputs={'Out': grad},
                        attrs={OP_ROLE_KEY: OpRole.Backward})
                    offset += 1

                block._insert_op(
                    offset,
                    type='c_allreduce_sum',
                    inputs={'X': grad},
                    outputs={'Out': grad},
                    attrs={
                        'ring_id': ring_id,
                        OP_ROLE_KEY: OpRole.Backward
                    })

        if grad is None:
            return

        for idx, op in enumerate(block.ops):
            if is_optimizer_op(op):
                block._insert_op(
                    idx,
                    type='c_sync_comm_stream',
                    inputs={'X': grad},
                    outputs={'Out': grad},
                    attrs={'ring_id': ring_id,
                           OP_ROLE_KEY: OpRole.Backward})
                # Stop only once the sync has been placed before the first
                # optimizer op (previously this broke after op 0 regardless).
                break