#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import paddle
import paddle.fluid as fluid
from .meta_parallel_base import MetaParallelBase
from .pp_utils.utils import is_float_tensor, get_tensor_dtype, paddle_2_number, number_2_dtype
from .pp_utils import utils
from .parallel_layers.pp_layers import PipelineLayer

from ..utils.hybrid_parallel_util import broadcast_mp_parameters
from ..utils.hybrid_parallel_util import broadcast_dp_parameters
from ..utils.log_util import logger
from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
from .pp_utils import p2p_communication as p2p

__all__ = []


class PipelineParallel(MetaParallelBase):
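    """Pipeline-parallel runner for a PipelineLayer.

    train_batch() splits a batch into micro-batches, runs all forward
    micro-steps, then all backward micro-steps, and finally one optimizer
    step (an all-forward-then-all-backward, GPipe-like schedule).

    Illustrative usage (a sketch only; the fleet setup shown here is an
    assumption, not part of this file):

        import paddle.distributed.fleet as fleet
        fleet.init(is_collective=True, strategy=strategy)
        model = fleet.distributed_model(net)   # wraps net in PipelineParallel
        opt = fleet.distributed_optimizer(opt)
        loss = model.train_batch([data, labels], opt, lr_scheduler, scaler)
    """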
    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer.")
        super(PipelineParallel, self).__init__(layers, hcg, strategy)
        self.use_pipe_parallel = self._hcg.get_pipe_parallel_world_size() > 1
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1

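        # With tensor (model) parallelism enabled, p2p activations/gradients
        # may be sent as mp-degree partitions (see _is_valid_send_recv).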
        self.is_pipe_partitioned = self.use_model_parallel

        self.num_caches = 0
        self.caches = {
            'inputs': [],
            'labels': [],
            'outputs': [],
        }

        self.recv_cache = None
        self.grad_tensors = None

        self.send_meta = True

        self.current_loss = paddle.to_tensor(0.0)
        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size']
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.prev_stage_id = self.stage_id - 1
        self.next_stage_id = self.stage_id + 1
        self.pp_group = self._hcg.get_pipe_parallel_group()
        p2p.initialize_p2p_groups(hcg)

        self.is_first_stage = self.stage_id == 0
        self.is_last_stage = (self.stage_id == (self.num_stages - 1))
        self.global_rank = self._hcg.get_global_rank()

        self.mp_degree = self._hcg.get_model_parallel_world_size()
        self.mp_rank = self._hcg.get_model_parallel_rank()

        logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format(
            self.num_stages, self.stage_id))

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

    def _init_caches(self, num_caches):
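        # Lazily extend each per-key cache list so there is one slot per
        # micro-batch.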
        if self.num_caches >= num_caches:
            return
        self.num_caches = num_caches - self.num_caches
        for key in self.caches:
            self.caches[key].extend([None] * self.num_caches)

    def _reduce_final_loss(self):
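        # Average the accumulated loss over accumulate_steps on the last stage
        # and broadcast it so every pipeline rank returns the same value.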
        if self.is_last_stage:
            assert self.total_loss is not None, "train_batch() in last stage should obtain valid loss"
            loss = self.total_loss.clone() / self.accumulate_steps
            paddle.distributed.broadcast(
                loss,
                src=self.global_rank,
                use_calc_stream=True,
                group=self.pp_group)
        else:
            loss = paddle.to_tensor(0.0)
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                use_calc_stream=True,
                group=self.pp_group)
        return loss

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
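        # Run all forward micro-steps, then all backward micro-steps, then a
        # single optimizer step for the whole batch.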
        assert isinstance(optimizer, HybridParallelOptimizer), (
            'optimizer should be HybridParallelOptimizer subclass.')
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.scaler = scaler
        assert fluid.framework._dygraph_tracer()._has_grad, (
            'Please enable the generation of gradients.')

        if self.is_first_stage or self.is_last_stage:
            assert data is not None, (
                "For the first and the last stage, the data must be set.")
        else:
            data = None

        self.data = data
        self._layers.train()

        # store total loss of entire batch
        self.total_loss = None
        self._init_caches(self.accumulate_steps)
        startup_steps = self.num_stages - self.stage_id - 1
        forward_steps = 0
        backward_steps = 0

        # forward
        while (forward_steps < self.accumulate_steps):
            self._forward(cache_id=forward_steps)
            forward_steps += 1

        # backward
        while (backward_steps < self.accumulate_steps):
            self._backward(cache_id=backward_steps)
            backward_steps += 1

        self._layers.allreduce_shared_weight_gradients()

        # optimizer
        self.train_loss = self._reduce_final_loss()
        self._step()
        return self.train_loss

    def _forward(self, cache_id):
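        # One forward micro-step: load the micro-batch, receive activations
        # from the previous stage (if any), run the local layers, then either
        # accumulate the loss (last stage) or send activations downstream.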
        # load data
        self._load_micro_batch(cache_id)
        if self.stage_id != 0:
            self._recv_activations(cache_id)

        if isinstance(self.caches['inputs'][cache_id], tuple):
            inputs = tuple(t for t in self.caches['inputs'][cache_id])
        else:
            inputs = self.caches['inputs'][cache_id]

        self._clear_grads(inputs)
        outputs = self._layers.forward(inputs)

        self.caches['outputs'][cache_id] = outputs

        if self.is_last_stage:
            if self._layers._loss_fn is not None:
                labels = self.caches['labels'][cache_id]
                outputs = self._layers._loss_fn(outputs, labels)

        if self.is_last_stage:
            self.current_loss = outputs
            if isinstance(self.current_loss, paddle.Tensor):
                if self.total_loss is None:
                    self.total_loss = paddle.zeros_like(self.current_loss)
                self.total_loss += self.current_loss.detach()
            else:
                if self.total_loss is None:
                    self.total_loss = [
                        paddle.zeros_like(v) for v in self.current_loss
                    ]
                for idx, v in enumerate(self.current_loss):
                    self.total_loss[idx] += v.detach()

            if self.accumulate_steps > 1:
                self.current_loss = self.current_loss / self.accumulate_steps

            self.caches['outputs'][cache_id] = self.current_loss.clone()

        else:
            self._send_activations(cache_id)

    def _backward(self, cache_id):
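        # One backward micro-step: the last stage backprops from the cached
        # loss; other stages backprop cached outputs against gradients received
        # from the next stage, then send input gradients upstream.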
        if self.is_last_stage:
            if self.scaler:
                paddle.autograd.backward(
                    self.scaler.scale(self.caches['outputs'][cache_id]))
            else:
                paddle.autograd.backward(self.caches['outputs'][cache_id])

            self._send_gradients(cache_id)
            return

        self._recv_gradients(cache_id)

        outputs = self.caches['outputs'][cache_id]

        grad_tensors = self.grad_tensors
        if isinstance(outputs, tuple):
            out_tensors = [t for t in outputs if is_float_tensor(t)]
            assert len(out_tensors) == len(grad_tensors)
            paddle.autograd.backward(
                tensors=out_tensors, grad_tensors=grad_tensors)
        else:
            paddle.autograd.backward(
                tensors=[outputs], grad_tensors=[grad_tensors])

        grad_tensors = None
        if self.stage_id != 0:
            self._send_gradients(cache_id)
        self.caches['outputs'][cache_id] = None

    def _broadcast_data(self, data):
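        # Broadcast a tensor (or each tensor in a tuple/list) from the
        # model-parallel source rank so all mp ranks consume identical data.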
        if isinstance(data, paddle.Tensor):
            paddle.distributed.broadcast(
                data,
                src=self._hcg.get_model_parallel_group_src_rank(),
                group=self._hcg.get_model_parallel_group())
        else:
            for d in data:
                assert isinstance(d, paddle.Tensor)
                paddle.distributed.broadcast(
                    d,
                    src=self._hcg.get_model_parallel_group_src_rank(),
                    group=self._hcg.get_model_parallel_group())
        return data

    def _load_micro_batch(self, cache_id):
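        # Slice the current micro-batch out of the full batch: inputs on the
        # first stage, labels on the last stage, nothing on the middle stages.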
        inputs = self.data
        begin = cache_id * self.micro_batch_size
        end = begin + self.micro_batch_size

        if self.is_first_stage:
            assert len(inputs) == 2, "length of input should be 2"
            if self.use_model_parallel:
                inputs[0] = self._broadcast_data(inputs[0])
            if isinstance(inputs[0], tuple):
                batch_size = inputs[0][0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size, (
                    "batch_size should equal micro_batch_size * accumulate_steps. Currently, "
                    "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
                    %
                    (batch_size, self.micro_batch_size, self.accumulate_steps))
                data = [
                    input[begin:end, :].clone().detach() for input in inputs[0]
                ]
                self.caches['inputs'][cache_id] = tuple(data)
            else:
                batch_size = inputs[0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                self.caches['inputs'][cache_id] = inputs[0][begin:end, :].clone(
                ).detach()
        elif self.is_last_stage:
            assert len(inputs) == 2, "length of input should be 2"
            if self.use_model_parallel:
                inputs[1] = self._broadcast_data(inputs[1])
            if isinstance(inputs[1], tuple):
                batch_size = inputs[1][0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                data = [
                    input[begin:end, :].clone().detach() for input in inputs[1]
                ]
                self.caches['labels'][cache_id] = tuple(data)
            else:
                batch_size = inputs[1].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                self.caches['labels'][cache_id] = inputs[1][begin:end, :].clone(
                ).detach()
        else:
            # No data input is required for other stages
            inputs = None

    def _send_meta(self, data, peer):
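        # One-time handshake: describe the activation layout (single tensor or
        # tuple, plus ndim/shape/dtype of each tensor) so the receiver can
        # pre-allocate matching buffers.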
        if isinstance(data, paddle.Tensor):
            tensor_type = paddle.to_tensor([0])
            # send tensor type
            p2p.send(tensor_type, self.next_stage_id)

            # send len(shape)
            dims = paddle.to_tensor(len(data.shape))
            p2p.send(dims, self.next_stage_id)

            # send shape
            shape = paddle.to_tensor(data.shape)
            p2p.send(shape, self.next_stage_id)

            # send dtype
            dtype = paddle.to_tensor(paddle_2_number(data.dtype))
            p2p.send(dtype, self.next_stage_id)

        elif isinstance(data, tuple):
            tensor_type = paddle.to_tensor([1])
            p2p.send(tensor_type, self.next_stage_id)

            nums = paddle.to_tensor(len(data))
            p2p.send(nums, self.next_stage_id)

            for idx, d in enumerate(data):
                assert isinstance(d, paddle.Tensor)
                # send len(shape)
                dims = paddle.to_tensor(len(d.shape))
                p2p.send(dims, self.next_stage_id)

                # send shape
                shape = paddle.to_tensor(d.shape)
                p2p.send(shape, self.next_stage_id)

                # send dtype
                dtype = paddle.to_tensor(paddle_2_number(d.dtype))
                p2p.send(dtype, self.next_stage_id)

    def _recv_meta(self, peer):
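        # Counterpart of _send_meta: read the layout description and allocate
        # receive buffers of the right shape and dtype.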
        tensor_type = paddle.to_tensor([0])
        p2p.recv(tensor_type, self.prev_stage_id)
        tensor_type = tensor_type.item()

        if tensor_type == 0:
            # recv len(shape)
            dims = paddle.to_tensor([0])
            p2p.recv(dims, self.prev_stage_id)
            dims = dims.item()

            # recv shape
            shape = paddle.to_tensor([0] * dims)
            p2p.recv(shape, self.prev_stage_id)
            shape = shape.numpy().tolist()

            # recv dtype
            dtype = paddle.to_tensor([0])
            p2p.recv(dtype, self.prev_stage_id)

            return self._allocate_cache(
                shape, dtype=number_2_dtype(dtype.item()), num_caches=1)[0]
        elif tensor_type == 1:
            num = paddle.to_tensor([0])
            p2p.recv(num, self.prev_stage_id)
            num = num.item()
            shapes = []
            dtypes = []
            for i in range(num):
                # recv len(shape)
                dims = paddle.to_tensor([0])
                p2p.recv(dims, self.prev_stage_id)
                dims = dims.item()

                # recv shape
                shape = paddle.to_tensor([0] * dims)
                p2p.recv(shape, self.prev_stage_id)
                shapes.append(shape.numpy().tolist())

                # recv dtype
                dtype = paddle.to_tensor([0])
                p2p.recv(dtype, self.prev_stage_id)
                dtypes.append(number_2_dtype(dtype.item()))

            caches = self._allocate_caches(shapes, dtypes, num_caches=1)[0]
            caches = tuple(caches)
            return caches

    def _is_valid_send_recv(self, tensor):
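        # Partial (partitioned) p2p is only used when the element count splits
        # evenly across the model-parallel degree.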
        tensor_numel = np.prod(tensor.shape)
        assert tensor_numel != 0, "can't send/recv zero element"
        return tensor_numel % self.mp_degree == 0

    def _send_activations(self, cache_id):
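        # Send this micro-batch's outputs to the next stage, using partial
        # (mp-sliced) sends when activations are partitioned.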
        outputs = self.caches['outputs'][cache_id]

        if self.send_meta:
            self.send_meta = False
            self._send_meta(outputs, self.next_stage_id)

        if isinstance(outputs, paddle.Tensor):
            if self.is_pipe_partitioned and self._is_valid_send_recv(outputs):
                p2p.send_partial(
                    outputs.detach(),
                    self.next_stage_id,
                    mp_degree=self.mp_degree,
                    mp_rank=self.mp_rank)
            else:
                p2p.send(outputs.detach(), self.next_stage_id)

        elif isinstance(outputs, tuple):
            for output in outputs:
                if self.is_pipe_partitioned and self._is_valid_send_recv(
                        output):
                    p2p.send_partial(
                        output.detach(),
                        self.next_stage_id,
                        mp_degree=self.mp_degree,
                        mp_rank=self.mp_rank)
                else:
                    p2p.send(output.detach(), self.next_stage_id)

    def _send_gradients(self, cache_id):
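        # Send the gradients of this micro-batch's inputs back to the previous
        # stage; non-float tensors carry no grad and are skipped.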
        inputs = self.caches['inputs'][cache_id]
        if isinstance(inputs, paddle.Tensor):
            assert inputs.grad is not None
            if self.is_pipe_partitioned and self._is_valid_send_recv(
                    inputs.grad):
                grad = p2p.send_partial(
                    inputs.grad,
                    self.prev_stage_id,
                    mp_degree=self.mp_degree,
                    mp_rank=self.mp_rank)
            else:
                p2p.send(inputs.grad, self.prev_stage_id)
        else:
            for idx, d in enumerate(inputs):
                # Skip tensors that will not produce a grad
                if not is_float_tensor(d):
                    assert d.grad is None
                    continue

                if self.is_pipe_partitioned and self._is_valid_send_recv(
                        d.grad):
                    grad = p2p.send_partial(
                        d.grad,
                        self.prev_stage_id,
                        mp_degree=self.mp_degree,
                        mp_rank=self.mp_rank)
                else:
                    p2p.send(d.grad, self.prev_stage_id)

        self.caches['inputs'][cache_id] = None

    def _recv_activations(self, cache_id):
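        # Receive activations for this micro-batch from the previous stage,
        # reassembling partitioned tensors with a partial allgather, and cache
        # detached copies as inputs for the local forward pass.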
        inputs = None
        if self.recv_cache is None:
            self.recv_cache = self._recv_meta(self.prev_stage_id)

        if isinstance(self.recv_cache, paddle.Tensor):
            if self.is_pipe_partitioned and self._is_valid_send_recv(
                    self.recv_cache):
                p2p.recv_partial(self.recv_cache, self.prev_stage_id,
                                 self.mp_degree, self.mp_rank)
                p2p.partial_allgather_operator(
                    self.recv_cache,
                    mp_ranks=self.mp_degree,
                    mp_rank_id=self.mp_rank,
                    group=self._hcg.get_model_parallel_group(),
                    use_calc_stream=True)
            else:
                p2p.recv(self.recv_cache, self.prev_stage_id)

            inputs = self.recv_cache.clone().detach()
            inputs.stop_gradient = not is_float_tensor(inputs)

        else:
            assert isinstance(self.recv_cache, tuple)
            inputs = [None] * len(self.recv_cache)
            for idx, d in enumerate(self.recv_cache):
                if self.is_pipe_partitioned and self._is_valid_send_recv(d):
                    assert isinstance(d, paddle.Tensor)
                    p2p.recv_partial(d, self.prev_stage_id, self.mp_degree,
                                     self.mp_rank)
                    p2p.partial_allgather_operator(
                        d,
                        mp_ranks=self.mp_degree,
                        mp_rank_id=self.mp_rank,
                        group=self._hcg.get_model_parallel_group(),
                        use_calc_stream=True)
                else:
                    assert isinstance(d, paddle.Tensor)
                    p2p.recv(d, self.prev_stage_id)
                inputs[idx] = d.clone().detach()

            inputs = tuple(inputs)

            for d in inputs:
                d.stop_gradient = not is_float_tensor(d)

        self.caches['inputs'][cache_id] = inputs

    def _recv_gradients(self, cache_id):
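        # Receive, from the next stage, gradients w.r.t. this micro-batch's
        # outputs, allocating receive buffers on first use.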
        outputs = self.caches['outputs'][cache_id]
        if self.grad_tensors is None:
            if isinstance(outputs, paddle.Tensor):
                s = list(outputs.shape)
                dtype = get_tensor_dtype(outputs.dtype)
                self.grad_tensors = self._allocate_cache(
                    s, dtype, num_caches=1)[0]
            else:
                sizes = [list(d.shape) for d in outputs if is_float_tensor(d)]
                dtypes = [
                    get_tensor_dtype(d.dtype) for d in outputs
                    if is_float_tensor(d)
                ]
                self.grad_tensors = self._allocate_caches(
                    sizes, dtypes, num_caches=1)[0]

        if isinstance(self.grad_tensors, paddle.Tensor):
            if self.is_pipe_partitioned and self._is_valid_send_recv(
                    self.grad_tensors):
                p2p.recv_partial(self.grad_tensors, self.next_stage_id,
                                 self.mp_degree, self.mp_rank)
                p2p.partial_allgather_operator(
                    self.grad_tensors,
                    mp_ranks=self.mp_degree,
                    mp_rank_id=self.mp_rank,
                    group=self._hcg.get_model_parallel_group(),
                    use_calc_stream=True)
            else:
                p2p.recv(self.grad_tensors, self.next_stage_id)

        else:
            assert isinstance(outputs, tuple)
            for d in self.grad_tensors:
                if self.is_pipe_partitioned and self._is_valid_send_recv(d):
                    p2p.recv_partial(d, self.next_stage_id, self.mp_degree,
                                     self.mp_rank)
                    p2p.partial_allgather_operator(
                        d,
                        mp_ranks=self.mp_degree,
                        mp_rank_id=self.mp_rank,
                        group=self._hcg.get_model_parallel_group(),
                        use_calc_stream=True)
                else:
                    p2p.recv(d, self.next_stage_id)

    def _step(self):
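        # Apply the optimizer update for the whole batch (through the AMP
        # GradScaler when one is provided), clear gradients, and step the
        # lr scheduler.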
        if self.scaler:
            self.scaler.minimize(self.optimizer, self.train_loss)
        else:
            self.optimizer.step()
        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()

    def _clear_grads(self, inputs):
        if isinstance(inputs, paddle.Tensor):
            if inputs.grad is not None:
                inputs.clear_gradient()
        else:
            for d in inputs:
                if d.grad is not None:
                    d.clear_gradient()

    def _allocate_zeros(self, shape, dtype):
        return paddle.zeros(shape, dtype)

    def _allocate_cache(self, shape, dtype, num_caches=-1):
        caches = []
        if num_caches == -1:
            num_caches = self.num_caches
        for count in range(num_caches):
            caches.append(self._allocate_zeros(shape, dtype))
        return caches

    def _allocate_caches(self, shapes, dtypes, num_caches=-1):
        caches = []
        if num_caches == -1:
            num_caches = self.num_caches
        for count in range(num_caches):
            cache = []
            for shape, dtype in zip(shapes, dtypes):
                cache.append(self._allocate_zeros(shape, dtype))
            caches.append(cache)
        return caches

    def save_state_dict(self, model_path):
        state_dict = self._layers.state_dict()
        paddle.save(state_dict, model_path)

    def load_state_dict(self, model_path):
        state_dict = paddle.load(model_path)
        self._layers.set_state_dict(state_dict)

    def forward(self, *inputs, **kwargs):
        raise RuntimeError("Call train_batch for pipeline instead of forward.")