#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from types import MethodType

import paddle
import paddle.fluid as fluid
from .meta_parallel_base import MetaParallelBase
from .pp_utils.utils import is_float_tensor, get_tensor_dtype, paddle_2_number, number_2_dtype
from .pp_utils import utils
from .parallel_layers.pp_layers import PipelineLayer

from ..utils.hybrid_parallel_util import broadcast_mp_parameters
from ..utils.hybrid_parallel_util import broadcast_dp_parameters
from ..utils.log_util import logger
from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer

# No public names are exported via `from ... import *`.
__all__ = []


class PipelineParallel(MetaParallelBase):
    """Pipeline-parallel training wrapper around a ``PipelineLayer``.

    Each rank runs one stage of the pipeline. ``train_batch`` splits the
    batch into ``accumulate_steps`` micro-batches of ``micro_batch_size``
    samples and proceeds in three phases:

    1. It runs the forward pass of every micro-batch. Each stage receives
       activations from the previous stage and sends its outputs to the next.
    2. It runs the backward pass of every micro-batch. Each stage receives
       output gradients from the next stage and sends input gradients to the
       previous one.
    3. It steps the optimizer once.

    All point-to-point traffic uses ``paddle.distributed.send``/``recv`` on
    the pipeline-parallel group with ``use_calc_stream=True``.
    """

    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer.")
        super(PipelineParallel, self).__init__(layers, hcg, strategy)
        self.use_pipe_parallel = self._hcg.get_pipe_parallel_world_size() > 1
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1

        # Per-micro-batch slots. `num_caches` is the current length of every
        # list below; `_init_caches` grows them lazily.
        self.num_caches = 0
        self.caches = {
            'inputs': [],
            'labels': [],
            'outputs': [],
        }

        # Reusable receive buffers. `recv_cache` is allocated from the first
        # meta received from the previous stage. `grad_tensors` is allocated
        # from the local output shapes. Neither is ever reset afterwards.
        self.recv_cache = None
        self.grad_tensors = None

        # Meta (type/shape/dtype) is sent only before the very first
        # activation; see `_send_activations`.
        self.send_meta = True

        self.current_loss = paddle.to_tensor(0.0)
        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size']
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.prev_stage_id = self.stage_id - 1
        self.next_stage_id = self.stage_id + 1
        self.pp_group = self._hcg.get_pipe_parallel_group()

        self.is_first_stage = self.stage_id == 0
        self.is_last_stage = (self.stage_id == (self.num_stages - 1))
        self.global_rank = self._hcg.get_global_rank()

        logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format(
            self.num_stages, self.stage_id))

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

    def _init_caches(self, num_caches):
        """Grow every per-micro-batch cache list to `num_caches` slots.

        Existing slots are kept and new slots are filled with ``None``.
        Does nothing if the caches already hold at least that many slots.
        """
        if self.num_caches >= num_caches:
            return
        missing = num_caches - self.num_caches
        for key in self.caches:
            self.caches[key].extend([None] * missing)
        # Record the new total (not the delta) so that `self.num_caches`
        # always equals the length of each cache list. `_allocate_cache` and
        # `_allocate_caches` rely on it when called with num_caches=-1.
        self.num_caches = num_caches

    def _reduce_final_loss(self):
        """Return the accumulated loss divided by ``accumulate_steps``.

        The last stage computes the value and broadcasts it over the
        pipeline group. Every other stage receives it into a
        ``paddle.to_tensor(0.0)`` placeholder, so all stages return the
        same loss.

        NOTE(review): assumes `self.total_loss` is a single Tensor. The
        list-of-losses form that `_forward` can accumulate has no
        ``.clone()``.
        """
        if self.is_last_stage:
            assert self.total_loss is not None, "train_batch() in last stage should obtain vaild loss"
            loss = self.total_loss.clone() / self.accumulate_steps
            paddle.distributed.broadcast(
                loss,
                src=self.global_rank,
                use_calc_stream=True,
                group=self.pp_group)
        else:
            loss = paddle.to_tensor(0.0)
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                use_calc_stream=True,
                group=self.pp_group)
        return loss

    def train_batch(self, data, optimizer, lr_scheduler=None):
        """Run one pipelined training step and return the reduced loss.

        Args:
            data: a length-2 sequence ``[inputs, labels]``. It is required
                on the first stage (which uses ``inputs``) and the last stage
                (which uses ``labels``); other stages ignore it. Each entry is
                a Tensor or a tuple of Tensors. Dimension 0 is the batch and
                must equal ``micro_batch_size * accumulate_steps``.
            optimizer (HybridParallelOptimizer): stepped once after all
                micro-batches.
            lr_scheduler: optional; stepped after the optimizer when truthy.

        Returns:
            The loss from `_reduce_final_loss`, identical on every stage.
        """
        assert isinstance(optimizer, HybridParallelOptimizer), (
            'optimizer should be HybridParallelOptimizer subclass.')
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        assert fluid.framework._dygraph_tracer()._has_grad, (
            'Please enable the generation of gradients.')

        if self.is_first_stage or self.is_last_stage:
            assert data is not None, (
                "For the first and the last stage, the data_iter must be set.")
        else:
            data = None
        self.data = data

        self._layers.train()

        # store total loss of entire batch
        self.total_loss = None
        self._init_caches(self.accumulate_steps)

        # All forward passes first, then all backward passes.
        for micro_step in range(self.accumulate_steps):
            self._forward(cache_id=micro_step)
        for micro_step in range(self.accumulate_steps):
            self._backward(cache_id=micro_step)

        self._layers.allreduce_shared_weight_gradients()

        # optimizer
        self._step()
        self.train_loss = self._reduce_final_loss()
        return self.train_loss

    def _forward(self, cache_id):
        """Run the forward pass of micro-batch `cache_id` on this stage.

        A non-last stage sends its outputs to the next stage. The last
        stage works as follows:

        * It applies the layer's loss function, if one is set.
        * It adds the detached loss into ``self.total_loss``.
        * It caches the loss, divided by ``accumulate_steps``, for
          `_backward`.
        """
        # Fills caches['inputs'] (first stage) / caches['labels'] (last stage).
        self._load_micro_batch(cache_id)
        if not self.is_first_stage:
            self._recv_activations(cache_id)

        inputs = self.caches['inputs'][cache_id]
        outputs = self._layers.forward(inputs)
        self._clear_grads(inputs)
        self.caches['outputs'][cache_id] = outputs

        if not self.is_last_stage:
            self._send_activations(cache_id)
            return

        if self._layers._loss_fn is not None:
            labels = self.caches['labels'][cache_id]
            outputs = self._layers._loss_fn(outputs, labels)

        self.current_loss = outputs
        if isinstance(self.current_loss, paddle.Tensor):
            if self.total_loss is None:
                self.total_loss = paddle.zeros_like(self.current_loss)
            self.total_loss += self.current_loss.detach()
        else:
            if self.total_loss is None:
                self.total_loss = [
                    paddle.zeros_like(v) for v in self.current_loss
                ]
            for idx, v in enumerate(self.current_loss):
                self.total_loss[idx] += v.detach()

        # Divide by the number of micro-batches. The gradients accumulated
        # over all of them then correspond to the average of the
        # micro-batch losses.
        if self.accumulate_steps > 1:
            self.current_loss = self.current_loss / self.accumulate_steps

        self.caches['outputs'][cache_id] = self.current_loss.clone()

    def _backward(self, cache_id):
        """Run the backward pass of micro-batch `cache_id` on this stage.

        The last stage back-propagates from its cached loss. Other stages
        first receive the gradients of their outputs from the next stage.
        Every stage except the first then sends its input gradients to the
        previous stage.
        """
        outputs = self.caches['outputs'][cache_id]
        if self.is_last_stage:
            paddle.autograd.backward(outputs)
        else:
            self._recv_gradients(cache_id)
            grad_tensors = self.grad_tensors
            if isinstance(outputs, tuple):
                # `_recv_gradients` only allocates buffers for the
                # floating-point outputs.
                out_tensors = [t for t in outputs if is_float_tensor(t)]
                assert len(out_tensors) == len(grad_tensors)
                paddle.autograd.backward(
                    tensors=out_tensors, grad_tensors=grad_tensors)
            else:
                paddle.autograd.backward(
                    tensors=[outputs], grad_tensors=[grad_tensors])

        # The first stage has no predecessor. Checking this for the last
        # stage too stops a single-stage pipeline (first == last) from
        # sending to stage -1.
        if not self.is_first_stage:
            self._send_gradients(cache_id)
        self.caches['outputs'][cache_id] = None

    def _broadcast_data(self, data):
        """Broadcast `data` across the model-parallel group.

        `data` is a Tensor or an iterable of Tensors. The tensors themselves
        are passed to ``paddle.distributed.broadcast`` and the same object
        is returned.
        """
        tensors = [data] if isinstance(data, paddle.Tensor) else data
        for tensor in tensors:
            assert isinstance(tensor, paddle.Tensor)
            paddle.distributed.broadcast(
                tensor,
                src=self._hcg.get_model_parallel_group_src_rank(),
                group=self._hcg.get_model_parallel_group())
        return data

    def _load_micro_batch(self, cache_id):
        """Cache the slice of ``self.data`` belonging to micro-batch `cache_id`.

        The first stage caches ``data[0]`` under ``'inputs'``. The last stage
        caches ``data[1]`` under ``'labels'``. A single-stage pipeline is both
        and caches both; other stages need no data.
        """
        begin = cache_id * self.micro_batch_size
        end = begin + self.micro_batch_size

        # Two independent checks (not if/elif) so a stage that is both
        # first and last loads its labels too.
        if self.is_first_stage:
            self.caches['inputs'][cache_id] = self._slice_micro_batch(
                self.data, 0, begin, end)
        if self.is_last_stage:
            self.caches['labels'][cache_id] = self._slice_micro_batch(
                self.data, 1, begin, end)

    def _slice_micro_batch(self, data, index, begin, end):
        """Return rows ``[begin, end)`` of ``data[index]`` as detached copies.

        ``data[index]`` is a Tensor or a tuple of Tensors. The result has the
        same form.
        """
        assert len(data) == 2, "length of input should be 2"
        if self.use_model_parallel:
            # `_broadcast_data` returns the very object it was given, so no
            # item assignment into `data` is needed (this also allows tuples).
            self._broadcast_data(data[index])

        tensors = data[index]
        is_tuple = isinstance(tensors, tuple)
        batch_size = (tensors[0] if is_tuple else tensors).shape[0]
        assert self.micro_batch_size * self.accumulate_steps == batch_size, (
            "batch_size needs to be divisible by micro_batch_size. Currently, "
            "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d." %
            (batch_size, self.micro_batch_size, self.accumulate_steps))

        # Slicing only dim 0 works for tensors of any rank, including 1-D.
        if is_tuple:
            return tuple(t[begin:end].clone().detach() for t in tensors)
        return tensors[begin:end].clone().detach()

    def _send_meta(self, data, peer):
        """Send the structure, shapes and dtypes of `data` to `peer`.

        `_recv_meta` mirrors this protocol. Each item is a separate send:

        * Tensor: ``[0]``, then ndims, shape, dtype code.
        * tuple: ``[1]``, then the element count, then ndims, shape and
          dtype code for each element.

        Raises:
            TypeError: if `data` is neither a Tensor nor a tuple. The peer's
                `_recv_meta` would otherwise wait for meta that never comes.
        """
        if isinstance(data, paddle.Tensor):
            tensor_type = paddle.to_tensor([0])
            paddle.distributed.send(
                tensor_type, peer, use_calc_stream=True, group=self.pp_group)
            self._send_tensor_meta(data, peer)
        elif isinstance(data, tuple):
            tensor_type = paddle.to_tensor([1])
            paddle.distributed.send(
                tensor_type, peer, use_calc_stream=True, group=self.pp_group)
            nums = paddle.to_tensor(len(data))
            paddle.distributed.send(
                nums, peer, use_calc_stream=True, group=self.pp_group)
            for d in data:
                assert isinstance(d, paddle.Tensor)
                self._send_tensor_meta(d, peer)
        else:
            raise TypeError(
                "Pipeline stage outputs must be a paddle.Tensor or a tuple of "
                "paddle.Tensor, but got {}.".format(type(data)))

    def _send_tensor_meta(self, tensor, peer):
        """Send len(shape), shape and dtype code of one tensor to `peer`."""
        dims = paddle.to_tensor(len(tensor.shape))
        paddle.distributed.send(
            dims, peer, use_calc_stream=True, group=self.pp_group)

        shape = paddle.to_tensor(tensor.shape)
        paddle.distributed.send(
            shape, peer, use_calc_stream=True, group=self.pp_group)

        dtype = paddle.to_tensor(paddle_2_number(tensor.dtype))
        paddle.distributed.send(
            dtype, peer, use_calc_stream=True, group=self.pp_group)

    def _recv_meta(self, peer):
        """Receive meta sent by `_send_meta` and allocate matching buffers.

        Returns a zero Tensor, or a tuple of zero Tensors, with the received
        shapes and dtypes. Returns ``None`` for an unknown type code.
        """
        tensor_type = paddle.to_tensor([0])
        paddle.distributed.recv(
            tensor_type, peer, use_calc_stream=True, group=self.pp_group)
        tensor_type = tensor_type.item()

        if tensor_type == 0:
            shape, dtype = self._recv_tensor_meta(peer)
            return self._allocate_cache(shape, dtype=dtype, num_caches=1)[0]

        if tensor_type == 1:
            num = paddle.to_tensor([0])
            paddle.distributed.recv(
                num, peer, use_calc_stream=True, group=self.pp_group)
            # Sequential comprehension preserves the per-element recv order.
            metas = [self._recv_tensor_meta(peer) for _ in range(num.item())]
            shapes = [shape for shape, _ in metas]
            dtypes = [dtype for _, dtype in metas]
            caches = self._allocate_caches(shapes, dtypes, num_caches=1)[0]
            return tuple(caches)

        return None

    def _recv_tensor_meta(self, peer):
        """Receive ``(shape, dtype)`` of one tensor sent by `_send_tensor_meta`."""
        dims = paddle.to_tensor([0])
        paddle.distributed.recv(
            dims, peer, use_calc_stream=True, group=self.pp_group)

        shape = paddle.to_tensor([0] * dims.item())
        paddle.distributed.recv(
            shape, peer, use_calc_stream=True, group=self.pp_group)

        dtype = paddle.to_tensor([0])
        paddle.distributed.recv(
            dtype, peer, use_calc_stream=True, group=self.pp_group)
        return shape.numpy().tolist(), number_2_dtype(dtype.item())

    def _send_activations(self, cache_id):
        """Send the cached outputs of micro-batch `cache_id` to the next stage.

        Meta is sent only once, before the first activation. The receiver
        reuses the buffers it allocated from it, so every micro-batch is
        assumed to have the same structure, shapes and dtypes.
        """
        outputs = self.caches['outputs'][cache_id]

        if self.send_meta:
            self.send_meta = False
            self._send_meta(outputs, self.next_stage_id)

        if isinstance(outputs, paddle.Tensor):
            paddle.distributed.send(
                outputs,
                self.next_stage_id,
                use_calc_stream=True,
                group=self.pp_group)
        elif isinstance(outputs, tuple):
            for output in outputs:
                paddle.distributed.send(
                    output,
                    self.next_stage_id,
                    use_calc_stream=True,
                    group=self.pp_group)

    def _send_gradients(self, cache_id):
        """Send the input gradients of micro-batch `cache_id` to the previous
        stage, then drop the cached inputs.
        """
        inputs = self.caches['inputs'][cache_id]
        if isinstance(inputs, paddle.Tensor):
            assert inputs.grad is not None
            # NOTE(review): this path wraps `.grad` in `paddle.to_tensor` but
            # the tuple path below sends `d.grad` directly. Confirm which form
            # `.grad` has in the targeted Paddle version.
            paddle.distributed.send(
                paddle.to_tensor(inputs.grad),
                self.prev_stage_id,
                use_calc_stream=True,
                group=self.pp_group)
        else:
            for d in inputs:
                # Skip tensors that will not produce a grad
                if not is_float_tensor(d):
                    assert d.grad is None
                    continue
                paddle.distributed.send(
                    d.grad,
                    self.prev_stage_id,
                    use_calc_stream=True,
                    group=self.pp_group)
        self.caches['inputs'][cache_id] = None

    def _recv_activations(self, cache_id):
        """Receive the inputs of micro-batch `cache_id` from the previous stage.

        Data is received into the reusable ``self.recv_cache`` buffer and then
        cloned, so each cached input is an independent tensor.
        Floating-point inputs get ``stop_gradient=False`` so their gradients
        can be sent back in `_send_gradients`.
        """
        if self.recv_cache is None:
            self.recv_cache = self._recv_meta(self.prev_stage_id)

        if isinstance(self.recv_cache, paddle.Tensor):
            paddle.distributed.recv(
                self.recv_cache,
                self.prev_stage_id,
                use_calc_stream=True,
                group=self.pp_group)
            inputs = self.recv_cache.clone().detach()
            inputs.stop_gradient = not is_float_tensor(inputs)
        else:
            assert isinstance(self.recv_cache, tuple)
            received = []
            for buf in self.recv_cache:
                assert isinstance(buf, paddle.Tensor)
                paddle.distributed.recv(
                    buf,
                    self.prev_stage_id,
                    use_calc_stream=True,
                    group=self.pp_group)
                copy = buf.clone().detach()
                copy.stop_gradient = not is_float_tensor(copy)
                received.append(copy)
            inputs = tuple(received)

        self.caches['inputs'][cache_id] = inputs

    def _recv_gradients(self, cache_id):
        """Receive the output gradients of micro-batch `cache_id` from the
        next stage into ``self.grad_tensors``.

        The buffer is allocated on first use, shaped like this stage's
        (floating-point) outputs, and reused afterwards.
        """
        outputs = self.caches['outputs'][cache_id]
        if self.grad_tensors is None:
            if isinstance(outputs, paddle.Tensor):
                shape = list(outputs.shape)
                dtype = get_tensor_dtype(outputs.dtype)
                self.grad_tensors = self._allocate_cache(
                    shape, dtype, num_caches=1)[0]
            else:
                float_outputs = [d for d in outputs if is_float_tensor(d)]
                sizes = [list(d.shape) for d in float_outputs]
                dtypes = [get_tensor_dtype(d.dtype) for d in float_outputs]
                self.grad_tensors = self._allocate_caches(
                    sizes, dtypes, num_caches=1)[0]

        if isinstance(self.grad_tensors, paddle.Tensor):
            paddle.distributed.recv(
                self.grad_tensors,
                self.next_stage_id,
                use_calc_stream=True,
                group=self.pp_group)
        else:
            assert isinstance(outputs, tuple)
            for d in self.grad_tensors:
                paddle.distributed.recv(
                    d,
                    self.next_stage_id,
                    use_calc_stream=True,
                    group=self.pp_group)

    def _step(self):
        """Apply the optimizer, clear gradients, and advance the LR scheduler."""
        self.optimizer.step()
        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()

    def _clear_grads(self, inputs):
        """Clear any existing gradient on `inputs` (a Tensor or iterable of Tensors)."""
        tensors = [inputs] if isinstance(inputs, paddle.Tensor) else inputs
        for tensor in tensors:
            if tensor.grad is not None:
                tensor.clear_gradient()

    def _allocate_zeros(self, shape, dtype):
        """Return a zero tensor of `shape` and `dtype`."""
        return paddle.zeros(shape, dtype)

    def _allocate_cache(self, shape, dtype, num_caches=-1):
        """Return a list of `num_caches` zero tensors of `shape`/`dtype`.

        ``num_caches=-1`` means ``self.num_caches``.
        """
        if num_caches == -1:
            num_caches = self.num_caches
        return [self._allocate_zeros(shape, dtype) for _ in range(num_caches)]

    def _allocate_caches(self, shapes, dtypes, num_caches=-1):
        """Return `num_caches` lists, each with one zero tensor per
        ``(shape, dtype)`` pair.

        ``num_caches=-1`` means ``self.num_caches``.
        """
        if num_caches == -1:
            num_caches = self.num_caches
        return [[
            self._allocate_zeros(shape, dtype)
            for shape, dtype in zip(shapes, dtypes)
        ] for _ in range(num_caches)]

    def save_state_dict(self, model_path):
        """Save the wrapped layers' state dict to `model_path`."""
        state_dict = self._layers.state_dict()
        paddle.save(state_dict, model_path)

    def load_state_dict(self, model_path):
        """Load a state dict from `model_path` into the wrapped layers."""
        # Use the argument; `self.model_path` is never defined.
        state_dict = paddle.load(model_path)
        self._layers.set_state_dict(state_dict)

    def forward(self, *inputs, **kwargs):
        raise RuntimeError("Call train_batch for pipeline instead of forward.")