pipeline_parallel.py 18.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import copy
import os

from types import MethodType

from numpy import prod

import paddle
import paddle.fluid as fluid
from .meta_parallel_base import MetaParallelBase
25
from .pp_utils.utils import get_tensor_bytes, is_float_tensor
26 27
from .pp_utils import utils
from .parallel_layers.pp_layers import PipelineLayer
28 29 30 31

from ..utils.hybrid_parallel_util import broadcast_mp_parameters
from ..utils.hybrid_parallel_util import broadcast_dp_parameters
from ..utils.hybrid_parallel_util import fused_allreduce_gradients
32
from ..utils.log_util import logger
33

# Nothing is exported by ``from ... import *``.
__all__ = []

# Paddle floating-point dtypes.
FLOAT_TYPES = [
    paddle.float16,
    paddle.float32,
    paddle.float64,
]


class PipelineParallel(MetaParallelBase):
    """Pipeline-parallel training wrapper around ``self._layers``.

    Each rank runs one pipeline stage (``stage_id``). Activations are sent
    to peer ``stage_id + 1`` and gradients back to peer ``stage_id - 1``
    with ``paddle.distributed.send``/``recv`` on the pipeline-parallel
    group. The forward/backward/optimize schedule comes from
    ``utils.TrainGenerator`` and is dispatched through ``_COMMAND_MAP``.
    """

    def __init__(self, layers, hcg, strategy):
        super(PipelineParallel, self).__init__(layers, hcg, strategy)

        self.use_pipe_parallel = self._hcg.get_pipe_parallel_world_size() > 1
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1

        # Per-micro-batch slots, grown on demand by _allocate_caches().
        self.num_caches = 0
        self.caches = {
            'inputs': [],
            'labels': [],
            'outputs': [],
        }

        # Receive buffers for activations / gradients. Both are allocated on
        # first use and then reused for every later micro-batch.
        self.recv_cache = None
        self.grad_tensors = None

        # Shape metadata is sent only before the first activation.
        self.send_meta = True

        self.current_loss = paddle.to_tensor(0.0)
        self.total_loss = None

        self.use_amp = self._strategy.amp
        self.init_loss_scaling = self._strategy.amp_configs['init_loss_scaling']
        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size']
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.prev_stage_id = self.stage_id - 1
        self.next_stage_id = self.stage_id + 1
        self.pp_group = self._hcg.get_pipe_parallel_group()
        logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format(
            self.num_stages, self.stage_id))

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

    def _allocate_caches(self, num_caches):
        """Grow every cache list so it holds at least ``num_caches`` slots.

        Existing entries are kept and new slots are filled with ``None``.
        The caches are never shrunk.
        """
        if self.num_caches >= num_caches:
            return

        extra = num_caches - self.num_caches
        self.num_caches = num_caches
        for slots in self.caches.values():
            slots.extend([None] * extra)

    def train_batch(self, data, optimizer):
        """Run one training step over ``accumulate_steps`` micro-batches.

        Args:
            data: Input tensor(s) on the first stage, label tensor(s) on the
                last stage, and ``None`` on every intermediate stage.
            optimizer: Optimizer whose ``step``/``clear_gradients`` are called
                by the Optimize command (see ``_step``).

        Returns:
            The accumulated, detached loss on the last stage; ``None`` on
            every other stage.
        """
        self.optimizer = optimizer
        assert fluid.framework._dygraph_tracer()._has_grad, (
            'Please enable the generation of gradients.')

        if self.stage_id == 0 or self.stage_id == self.num_stages - 1:
            # A plain ``assert data`` would call Tensor.__bool__, which fails
            # for a multi-element tensor, so test for presence explicitly.
            has_data = data is not None and (isinstance(data, paddle.Tensor)
                                             or bool(data))
            assert has_data, (
                "For the first and the last stage, the data_iter must be set.")
        else:
            assert data is None, (
                "For pipe stages other than the first and the last one, "
                "the data_iter must be None.")
        self.data = data
        self._layers.train()
        self.total_loss = None

        minibatch_cmds = utils.TrainGenerator(self.accumulate_steps,
                                              self.num_stages, self.stage_id)
        self._train(minibatch_cmds)
        return self.total_loss

    def _train(self, minibatch_cmds):
        """Execute the command schedule produced by ``minibatch_cmds``.

        Each command is looked up by its type in ``_COMMAND_MAP`` and called
        with the command's ``kwargs``.
        """
        self._allocate_caches(self.accumulate_steps)
        for micro_cmds in minibatch_cmds:
            for cmd in micro_cmds:
                cmd_type = type(cmd)
                assert cmd_type in self._COMMAND_MAP, "unknown cmd: {}".format(
                    cmd_type)
                # _COMMAND_MAP stores plain functions, so pass self explicitly.
                handler = self._COMMAND_MAP[cmd_type]
                handler(self, **cmd.kwargs)

    def _allreduce_grads(self):
        """Call ``fused_allreduce_gradients`` on all parameters.

        This only happens when data parallelism is enabled.
        """
        if not self.use_data_parallel:
            return
        fused_allreduce_gradients(list(self._layers.parameters()), self._hcg)

    def _forward(self, cache_id):
        """Run the forward pass for micro-batch ``cache_id``.

        Every stage loads its local data. Non-first stages then receive their
        inputs from the previous stage. Non-last stages send their outputs to
        the next stage. The last stage applies ``_loss_fn`` (when set),
        accumulates the detached loss into ``total_loss``, and caches a
        scaled copy of the loss for ``_backward``.
        """
        self._load_micro_batch(cache_id)
        if self.stage_id != 0:
            self._recv_activations(cache_id)

        inputs = self.caches['inputs'][cache_id]

        self._clear_grads(inputs)
        outputs = self._layers.forward(inputs)
        self.caches['outputs'][cache_id] = outputs

        if self.stage_id != self.num_stages - 1:
            self._send_activations(cache_id)
            return

        if self._layers._loss_fn is not None:
            labels = self.caches['labels'][cache_id]
            outputs = self._layers._loss_fn(outputs, labels)

        self.current_loss = outputs
        if isinstance(self.current_loss, paddle.Tensor):
            if self.total_loss is None:
                self.total_loss = paddle.zeros_like(self.current_loss)
            self.total_loss += self.current_loss.detach()
        else:
            if self.total_loss is None:
                self.total_loss = [
                    paddle.zeros_like(v) for v in self.current_loss
                ]
            for idx, v in enumerate(self.current_loss):
                self.total_loss[idx] += v.detach()

        # The loss used for backward is divided by the data-parallel world
        # size and by accumulate_steps.
        # NOTE(review): the scaling and .clone() below assume a Tensor loss;
        # the list-loss branch above is only accumulated, not back-propagated.
        if self.use_data_parallel:
            self.current_loss = (self.current_loss /
                                 self._hcg.get_data_parallel_world_size())
        if self.accumulate_steps > 1:
            self.current_loss = self.current_loss / self.accumulate_steps
        self.caches['outputs'][cache_id] = self.current_loss.clone()

    def _backward(self, cache_id):
        """Back-propagate micro-batch ``cache_id``.

        The last stage back-propagates its cached, scaled loss. Every other
        stage first receives output gradients from the next stage. All stages
        except stage 0 then send their input gradients to the previous stage,
        and the outputs cache slot is released.
        """
        assert self.optimizer is not None
        if self.stage_id == self.num_stages - 1:
            paddle.autograd.backward(self.caches['outputs'][cache_id])
        else:
            self._recv_gradients(cache_id)
            outputs = self.caches['outputs'][cache_id]
            grad_tensors = self.grad_tensors
            if isinstance(outputs, tuple):
                # Only float outputs have gradients (see _recv_gradients).
                out_tensors = [t for t in outputs if is_float_tensor(t)]
                assert len(out_tensors) == len(grad_tensors)
                paddle.autograd.backward(
                    tensors=out_tensors, grad_tensors=grad_tensors)
            else:
                paddle.autograd.backward(
                    tensors=[outputs], grad_tensors=[grad_tensors])

        # Stage 0 has no predecessor. This guard also covers a single-stage
        # pipeline, where the last stage is stage 0 (prev_stage_id == -1).
        if self.stage_id != 0:
            self._send_gradients(cache_id)
        self.caches['outputs'][cache_id] = None

    def _get_data(self):
        """Return ``self.data``, broadcast within the model-parallel group.

        With model parallelism on the first or last stage, each tensor is
        broadcast from the model-parallel source rank so that all
        model-parallel ranks use the same data. A tuple is returned as a
        tuple and a single tensor is returned unchanged. Otherwise
        ``self.data`` is returned as-is.
        """
        data = self.data
        is_edge_stage = (self.stage_id == 0 or
                         self.stage_id == self.num_stages - 1)
        if not (self.use_model_parallel and is_edge_stage):
            return data

        assert isinstance(data, (tuple, paddle.Tensor))
        src_rank = self._hcg.get_model_parallel_group_src_rank()
        mp_group = self._hcg.get_model_parallel_group()
        if isinstance(data, paddle.Tensor):
            # Keep a bare tensor as-is; tuple() would split it along dim 0.
            paddle.distributed.broadcast(data, src=src_rank, group=mp_group)
            return data

        broadcasted = []
        for d in data:
            assert isinstance(d, paddle.Tensor)
            paddle.distributed.broadcast(d, src=src_rank, group=mp_group)
            broadcasted.append(d)
        return tuple(broadcasted)

    def _load_micro_batch(self, cache_id):
        """Fill the ``inputs`` cache (first stage) and ``labels`` cache (last stage).

        Inputs are detached clones with ``stop_gradient`` set to True. Labels
        are used directly for a single tensor and detached for a tuple.

        NOTE(review): the whole of ``self.data`` is used for every
        ``cache_id``; no per-micro-batch slicing happens here.
        """
        inputs = self._get_data()
        # Treat a bare tensor as a one-element tuple. Otherwise len() and
        # indexing below would act on the tensor's first dimension.
        if isinstance(inputs, paddle.Tensor):
            inputs = (inputs, )

        if self.stage_id == 0:
            if len(inputs) == 1:
                assert isinstance(inputs[0], paddle.Tensor)
                data = inputs[0].clone().detach()
                data.stop_gradient = True
            else:
                assert isinstance(inputs, tuple)
                data = []
                for d in inputs:
                    assert isinstance(d, paddle.Tensor)
                    i = d.clone().detach()
                    i.stop_gradient = True
                    data.append(i)
                data = tuple(data)
            self.caches['inputs'][cache_id] = data

        if self.stage_id == self.num_stages - 1:
            labels = None
            if len(inputs) == 1:
                assert isinstance(inputs[0], paddle.Tensor)
                labels = inputs[0]
            elif isinstance(inputs, tuple):
                labels = []
                for label in inputs:
                    assert isinstance(label, paddle.Tensor)
                    labels.append(label.detach())
                labels = tuple(labels)
            self.caches['labels'][cache_id] = labels

    def _send_meta(self, data, peer):
        """Send the structure and shapes of ``data`` to ``peer``.

        Wire format (one send per item, mirrored by ``_recv_meta``)::

            type (0: tensor, 1: tuple)
            num_tensors            -- tuple only
            for each tensor:
                ndims
                shape

        Raises:
            TypeError: if ``data`` is neither a Tensor nor a tuple. In that
                case nothing would be sent, and the peer's ``_recv_meta``
                would wait for data that never arrives.
        """

        def send(tensor):
            paddle.distributed.send(
                tensor, peer, use_calc_stream=True, group=self.pp_group)

        if isinstance(data, paddle.Tensor):
            send(paddle.to_tensor([0]))
            send(paddle.to_tensor(len(data.shape)))
            send(paddle.to_tensor(data.shape))
        elif isinstance(data, tuple):
            send(paddle.to_tensor([1]))
            send(paddle.to_tensor(len(data)))
            for d in data:
                assert isinstance(d, paddle.Tensor)
                send(paddle.to_tensor(len(d.shape)))
                send(paddle.to_tensor(d.shape))
        else:
            raise TypeError(
                "_send_meta expects a paddle.Tensor or a tuple of "
                "paddle.Tensor, got {}".format(type(data)))

    def _recv_meta(self, peer):
        """Receive the metadata sent by ``_send_meta`` and allocate buffers.

        Returns:
            A zero-filled float32 tensor for a single-tensor payload, or a
            tuple of zero-filled float32 tensors for a tuple payload. The
            shapes are taken from the received metadata.

        Raises:
            ValueError: if the received type flag is neither 0 nor 1.

        NOTE(review): buffers are always float32 because ``_send_meta`` does
        not transmit dtypes.
        """

        def recv(tensor):
            paddle.distributed.recv(
                tensor, peer, use_calc_stream=True, group=self.pp_group)
            return tensor

        def recv_shape():
            dims = recv(paddle.to_tensor([0])).numpy()[0]
            return recv(paddle.to_tensor([0] * dims)).numpy().tolist()

        tensor_type = recv(paddle.to_tensor([0])).numpy()[0]

        if tensor_type == 0:
            shape = recv_shape()
            return self._allocate_buffer(
                shape, dtype="float32", num_caches=1)[0]
        if tensor_type == 1:
            num = recv(paddle.to_tensor([0])).numpy()[0]
            shapes = [recv_shape() for _ in range(num)]
            dtypes = ["float32"] * len(shapes)
            caches = self._allocate_buffers(shapes, dtypes, num_caches=1)[0]
            return tuple(caches)
        raise ValueError(
            "unknown tensor_type {} received from peer {}".format(tensor_type,
                                                                  peer))

    def _send_activations(self, cache_id):
        """Send the cached outputs of micro-batch ``cache_id`` to the next stage.

        Shape metadata is sent once, before the first activation.
        """
        outputs = self.caches['outputs'][cache_id]

        if self.send_meta:
            self.send_meta = False
            self._send_meta(outputs, self.next_stage_id)

        if isinstance(outputs, paddle.Tensor):
            paddle.distributed.send(
                outputs,
                self.next_stage_id,
                use_calc_stream=True,
                group=self.pp_group)
        elif isinstance(outputs, tuple):
            for output in outputs:
                paddle.distributed.send(
                    output,
                    self.next_stage_id,
                    use_calc_stream=True,
                    group=self.pp_group)

    def _send_gradients(self, cache_id):
        """Send the input gradients of micro-batch ``cache_id`` to the previous stage.

        For tuple inputs, non-float tensors are skipped (they must have no
        gradient). This matches the float-only buffers that
        ``_recv_gradients`` allocates on the receiving side. The inputs cache
        slot is cleared afterwards.
        """
        inputs = self.caches['inputs'][cache_id]

        if isinstance(inputs, paddle.Tensor):
            assert inputs.grad is not None
            # NOTE(review): only this branch wraps ``.grad`` with
            # paddle.to_tensor; the tuple branch sends ``d.grad`` directly.
            paddle.distributed.send(
                paddle.to_tensor(inputs.grad),
                self.prev_stage_id,
                use_calc_stream=True,
                group=self.pp_group)
        else:
            for d in inputs:
                # Skip tensors that will not produce a grad.
                if not is_float_tensor(d):
                    assert d.grad is None
                    continue
                assert d.grad is not None
                paddle.distributed.send(
                    d.grad,
                    self.prev_stage_id,
                    use_calc_stream=True,
                    group=self.pp_group)
        self.caches['inputs'][cache_id] = None

    def _recv_activations(self, cache_id):
        """Receive activations for micro-batch ``cache_id`` from the previous stage.

        The receive buffer(s) are allocated from the peer's metadata on the
        first call and reused afterwards. The cached inputs are detached
        clones whose ``stop_gradient`` is False only for float tensors.
        """
        if self.recv_cache is None:
            self.recv_cache = self._recv_meta(self.prev_stage_id)

        if isinstance(self.recv_cache, paddle.Tensor):
            paddle.distributed.recv(
                self.recv_cache,
                self.prev_stage_id,
                use_calc_stream=True,
                group=self.pp_group)
            # Clone because recv_cache is overwritten by the next micro-batch.
            inputs = self.recv_cache.clone().detach()
            inputs.stop_gradient = not is_float_tensor(inputs)
        else:
            assert isinstance(self.recv_cache, tuple)
            received = []
            for buf in self.recv_cache:
                assert isinstance(buf, paddle.Tensor)
                paddle.distributed.recv(
                    buf,
                    self.prev_stage_id,
                    use_calc_stream=True,
                    group=self.pp_group)
                copy_ = buf.clone().detach()
                copy_.stop_gradient = not is_float_tensor(copy_)
                received.append(copy_)
            inputs = tuple(received)

        self.caches['inputs'][cache_id] = inputs

    def _recv_gradients(self, cache_id):
        """Receive output gradients for micro-batch ``cache_id`` from the next stage.

        On the first call, buffers are allocated to match the cached outputs
        (float outputs only, for a tuple). They use float16 under AMP and
        float32 otherwise, and are reused afterwards. Received gradients are
        left in ``self.grad_tensors``.
        """
        outputs = self.caches['outputs'][cache_id]
        if self.grad_tensors is None:
            dtype = 'float16' if self.use_amp else 'float32'
            if isinstance(outputs, paddle.Tensor):
                self.grad_tensors = self._allocate_buffer(
                    list(outputs.shape), dtype, num_caches=1)[0]
            else:
                sizes = [list(d.shape) for d in outputs if is_float_tensor(d)]
                self.grad_tensors = self._allocate_buffers(
                    sizes, [dtype] * len(sizes), num_caches=1)[0]

        if isinstance(self.grad_tensors, paddle.Tensor):
            paddle.distributed.recv(
                self.grad_tensors,
                self.next_stage_id,
                use_calc_stream=True,
                group=self.pp_group)
        else:
            assert isinstance(outputs, tuple)
            for d in self.grad_tensors:
                paddle.distributed.recv(
                    d,
                    self.next_stage_id,
                    use_calc_stream=True,
                    group=self.pp_group)

    def _step(self):
        """All-reduce gradients (when data parallel), step the optimizer, and clear gradients."""
        self._allreduce_grads()
        self.optimizer.step()
        self.optimizer.clear_gradients()

    def _clear_grads(self, inputs):
        """Clear the gradient of ``inputs`` (a tensor or an iterable of tensors)."""
        if isinstance(inputs, paddle.Tensor):
            if inputs.grad is not None:
                inputs.clear_gradient()
        else:
            for d in inputs:
                if d.grad is not None:
                    d.clear_gradient()

    def _allocate_zeros(self, shape, dtype):
        """Return a zero tensor of ``shape`` and ``dtype``."""
        return paddle.zeros(shape, dtype)

    def _allocate_buffer(self, shape, dtype, num_caches=-1):
        """Return a list of ``num_caches`` zero tensors of ``shape``/``dtype``.

        ``num_caches=-1`` means ``self.num_caches``.
        """
        if num_caches == -1:
            num_caches = self.num_caches
        return [self._allocate_zeros(shape, dtype) for _ in range(num_caches)]

    def _allocate_buffers(self, shapes, dtypes, num_caches=-1):
        """Return ``num_caches`` lists of zero tensors, one per ``(shape, dtype)`` pair.

        ``num_caches=-1`` means ``self.num_caches``.
        """
        if num_caches == -1:
            num_caches = self.num_caches
        caches = []
        for _ in range(num_caches):
            caches.append([
                self._allocate_zeros(shape, dtype)
                for shape, dtype in zip(shapes, dtypes)
            ])
        return caches

    def save_state_dict(self, model_path):
        """Save ``self._layers.state_dict()`` to ``model_path`` with ``paddle.save``."""
        state_dict = self._layers.state_dict()
        paddle.save(state_dict, model_path)

    def load_state_dict(self, model_path):
        """Load a state dict from ``model_path`` into ``self._layers``."""
        # Use the argument; the instance has no ``model_path`` attribute.
        state_dict = paddle.load(model_path)
        self._layers.set_state_dict(state_dict)

    # Maps each schedule command type to the function that executes it. The
    # map is defined after those methods so the function objects exist.
    _COMMAND_MAP = {
        utils.Optimize: _step,
        utils.Forward: _forward,
        utils.Backward: _backward,
    }

    def forward(self, *inputs, **kwargs):
        """Disallow direct forward calls; pipeline training uses ``train_batch``."""
        raise RuntimeError("Call train_batch for pipeline instead of forward.")