#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from types import MethodType

import paddle
import paddle.fluid as fluid
from .meta_parallel_base import MetaParallelBase
from .pp_utils.utils import is_float_tensor, get_tensor_dtype, paddle_2_number, number_2_dtype
from .pp_utils import utils
from .parallel_layers.pp_layers import PipelineLayer

from ..utils.hybrid_parallel_util import broadcast_mp_parameters
from ..utils.hybrid_parallel_util import broadcast_dp_parameters
from ..utils.log_util import logger
from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
from .pp_utils import p2p_communication as p2p

__all__ = []


class PipelineParallel(MetaParallelBase):
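    # Dygraph pipeline-parallel runner: splits each mini-batch into
    # micro-batches, runs all forward micro-steps followed by all backward
    # micro-steps across the pipeline stages, and exchanges activations and
    # gradients between adjacent stages through p2p communication.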
    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer.")
        super(PipelineParallel, self).__init__(layers, hcg, strategy)
        self.use_pipe_parallel = self._hcg.get_pipe_parallel_world_size() > 1
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1

        self.num_caches = 0
        self.caches = {
            'inputs': [],
            'labels': [],
            'outputs': [],
        }

        self.recv_cache = None
        self.grad_tensors = None

        self.send_meta = True

        self.current_loss = paddle.to_tensor(0.0)
        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size']
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.prev_stage_id = self.stage_id - 1
        self.next_stage_id = self.stage_id + 1
        self.pp_group = self._hcg.get_pipe_parallel_group()
        p2p.initialize_p2p_groups(hcg)

        self.is_first_stage = self.stage_id == 0
        self.is_last_stage = (self.stage_id == (self.num_stages - 1))
        self.global_rank = self._hcg.get_global_rank()

        logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format(
            self.num_stages, self.stage_id))

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

    def _init_caches(self, num_caches):
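        # Grow the per-micro-batch caches (inputs/labels/outputs) so each of
        # them can hold `num_caches` entries.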
        if self.num_caches >= num_caches:
            return
        num_new_caches = num_caches - self.num_caches
        self.num_caches = num_caches
        for key in self.caches:
            self.caches[key].extend([None] * num_new_caches)

    def _reduce_final_loss(self):
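        # The last stage averages the accumulated loss over accumulate_steps
        # and broadcasts it, so every pipeline rank returns the same value.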
        if self.is_last_stage:
            assert self.total_loss is not None, "train_batch() in last stage should obtain valid loss"
            loss = self.total_loss.clone() / self.accumulate_steps
            paddle.distributed.broadcast(
                loss,
                src=self.global_rank,
                use_calc_stream=True,
                group=self.pp_group)
        else:
            loss = paddle.to_tensor(0.0)
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                use_calc_stream=True,
                group=self.pp_group)
        return loss

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
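        # Run one training step: all forward micro-steps, then all backward
        # micro-steps, then the all-reduce of shared-weight gradients and a
        # single optimizer step.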
        assert isinstance(optimizer, HybridParallelOptimizer), (
            'optimizer should be HybridParallelOptimizer subclass.')
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.scaler = scaler
        assert fluid.framework._dygraph_tracer()._has_grad, (
            'Please enable the generation of gradients.')

        if self.is_first_stage or self.is_last_stage:
            assert data is not None, (
                "For the first and the last stage, data must be set.")
        else:
            data = None

        self.data = data
        self._layers.train()

        # store total loss of entire batch
        self.total_loss = None
        self._init_caches(self.accumulate_steps)
        startup_steps = self.num_stages - self.stage_id - 1
        forward_steps = 0
        backward_steps = 0

        # forward
        while (forward_steps < self.accumulate_steps):
            self._forward(cache_id=forward_steps)
            forward_steps += 1

        # backward
        while (backward_steps < self.accumulate_steps):
            self._backward(cache_id=backward_steps)
            backward_steps += 1

        self._layers.allreduce_shared_weight_gradients()

        # optimizer
        self.train_loss = self._reduce_final_loss()
        self._step()
        return self.train_loss

    def _forward(self, cache_id):
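        # Forward one micro-batch: take inputs from the local data slice or
        # from the previous stage, run the local sub-layers, then either
        # compute the loss (last stage) or send activations downstream.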
        # load data
        self._load_micro_batch(cache_id)
        if self.stage_id != 0:
            self._recv_activations(cache_id)

        if isinstance(self.caches['inputs'][cache_id], tuple):
            inputs = tuple(t for t in self.caches['inputs'][cache_id])
        else:
            inputs = self.caches['inputs'][cache_id]

        outputs = self._layers.forward(inputs)
        self._clear_grads(inputs)

        self.caches['outputs'][cache_id] = outputs

        if self.is_last_stage:
            if self._layers._loss_fn is not None:
                labels = self.caches['labels'][cache_id]
                outputs = self._layers._loss_fn(outputs, labels)

            self.current_loss = outputs
            if isinstance(self.current_loss, paddle.Tensor):
                if self.total_loss is None:
                    self.total_loss = paddle.zeros_like(self.current_loss)
                self.total_loss += self.current_loss.detach()
            else:
                if self.total_loss is None:
                    self.total_loss = [
                        paddle.zeros_like(v) for v in self.current_loss
                    ]
                for idx, v in enumerate(self.current_loss):
                    self.total_loss[idx] += v.detach()

            if self.accumulate_steps > 1:
                self.current_loss = self.current_loss / self.accumulate_steps

            self.caches['outputs'][cache_id] = self.current_loss.clone()
        else:
            self._send_activations(cache_id)

    def _backward(self, cache_id):
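        # Backward one micro-batch: the last stage starts from the (scaled)
        # loss, the other stages start from the gradients received from the
        # next stage; the resulting input gradients are sent upstream.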
        if self.is_last_stage:
            if self.scaler:
                paddle.autograd.backward(
                    self.scaler.scale(self.caches['outputs'][cache_id]))
            else:
                paddle.autograd.backward(self.caches['outputs'][cache_id])

            self._send_gradients(cache_id)
            return
        self._recv_gradients(cache_id)

        outputs = self.caches['outputs'][cache_id]

        grad_tensors = self.grad_tensors
        if isinstance(outputs, tuple):
            out_tensors = [t for t in outputs if is_float_tensor(t)]
            assert len(out_tensors) == len(grad_tensors)
            paddle.autograd.backward(
                tensors=out_tensors, grad_tensors=grad_tensors)
        else:
            paddle.autograd.backward(
                tensors=[outputs], grad_tensors=[grad_tensors])

        grad_tensors = None
        if self.stage_id != 0:
            self._send_gradients(cache_id)
        self.caches['outputs'][cache_id] = None

    def _broadcast_data(self, data):
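        # Broadcast the data from the model-parallel source rank so that all
        # model-parallel ranks consume identical inputs.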
        if isinstance(data, paddle.Tensor):
            paddle.distributed.broadcast(
                data,
                src=self._hcg.get_model_parallel_group_src_rank(),
                group=self._hcg.get_model_parallel_group())
        else:
            for d in data:
                assert isinstance(d, paddle.Tensor)
                paddle.distributed.broadcast(
                    d,
                    src=self._hcg.get_model_parallel_group_src_rank(),
                    group=self._hcg.get_model_parallel_group())
        return data

    def _load_micro_batch(self, cache_id):
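        # Slice micro-batch `cache_id` out of the full batch: the first stage
        # caches the inputs, the last stage caches the labels, and the other
        # stages need no data.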
        inputs = self.data
        begin = cache_id * self.micro_batch_size
        end = begin + self.micro_batch_size

        if self.is_first_stage:
            assert len(inputs) == 2, "length of input should be 2"
            if self.use_model_parallel:
                inputs[0] = self._broadcast_data(inputs[0])
            if isinstance(inputs[0], tuple):
                batch_size = inputs[0][0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size, (
                    "batch_size must equal micro_batch_size * accumulate_steps. Currently, "
                    "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
                    %
                    (batch_size, self.micro_batch_size, self.accumulate_steps))
                data = [
                    input[begin:end, :].clone().detach() for input in inputs[0]
                ]
                self.caches['inputs'][cache_id] = tuple(data)
            else:
                batch_size = inputs[0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                self.caches['inputs'][cache_id] = inputs[0][begin:end, :].clone(
                ).detach()
        elif self.is_last_stage:
            assert len(inputs) == 2, "length of input should be 2"
            if self.use_model_parallel:
                inputs[1] = self._broadcast_data(inputs[1])
            if isinstance(inputs[1], tuple):
                batch_size = inputs[1][0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                data = [
                    input[begin:end, :].clone().detach() for input in inputs[1]
                ]
                self.caches['labels'][cache_id] = tuple(data)
            else:
                batch_size = inputs[1].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                self.caches['labels'][cache_id] = inputs[1][begin:end, :].clone(
                ).detach()
        else:
            # No data input is required for other stages
            inputs = None

    def _send_meta(self, data, peer):
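        # Tell the next stage what to expect before the first activation is
        # sent: a type flag (0 = single tensor, 1 = tuple of tensors), the
        # tensor count for tuples, and then each tensor's rank, shape and
        # dtype.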
        if isinstance(data, paddle.Tensor):
            tensor_type = paddle.to_tensor([0])
            # send tensor type
            p2p.send(tensor_type, self.next_stage_id)

            # send len(shape)
            dims = paddle.to_tensor(len(data.shape))
            p2p.send(dims, self.next_stage_id)

            # send shape
            shape = paddle.to_tensor(data.shape)
            p2p.send(shape, self.next_stage_id)

            # send dtype
            dtype = paddle.to_tensor(paddle_2_number(data.dtype))
            p2p.send(dtype, self.next_stage_id)

        elif isinstance(data, tuple):
            tensor_type = paddle.to_tensor([1])
            p2p.send(tensor_type, self.next_stage_id)

            nums = paddle.to_tensor(len(data))
            p2p.send(nums, self.next_stage_id)

            for idx, d in enumerate(data):
                assert isinstance(d, paddle.Tensor)
                # send len(shape)
                dims = paddle.to_tensor(len(d.shape))
                p2p.send(dims, self.next_stage_id)

                # send shape
                shape = paddle.to_tensor(d.shape)
                p2p.send(shape, self.next_stage_id)

                # send dtype
                dtype = paddle.to_tensor(paddle_2_number(d.dtype))
                p2p.send(dtype, self.next_stage_id)

    def _recv_meta(self, peer):
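        # Counterpart of _send_meta: read the type flag, ranks, shapes and
        # dtypes from the previous stage and allocate matching receive caches.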
        tensor_type = paddle.to_tensor([0])
        p2p.recv(tensor_type, self.prev_stage_id)

        tensor_type = tensor_type.item()

        if tensor_type == 0:
            # recv len(shape)
            dims = paddle.to_tensor([0])
            p2p.recv(dims, self.prev_stage_id)

            dims = dims.item()

            # recv shape
            shape = paddle.to_tensor([0] * dims)
            p2p.recv(shape, self.prev_stage_id)

            shape = shape.numpy().tolist()

            # recv dtype
            dtype = paddle.to_tensor([0])
            p2p.recv(dtype, self.prev_stage_id)

            return self._allocate_cache(
                shape, dtype=number_2_dtype(dtype.item()), num_caches=1)[0]
        elif tensor_type == 1:
            num = paddle.to_tensor([0])
            p2p.recv(num, self.prev_stage_id)
            num = num.item()
            shapes = []
            dtypes = []
            for i in range(num):
                # recv len(shape)
                dims = paddle.to_tensor([0])
                p2p.recv(dims, self.prev_stage_id)

                # recv shape
                dims = dims.item()
                shape = paddle.to_tensor([0] * dims)
                p2p.recv(shape, self.prev_stage_id)
                shapes.append(shape.numpy().tolist())

                # recv dtype
                dtype = paddle.to_tensor([0])
                p2p.recv(dtype, self.prev_stage_id)
                dtypes.append(number_2_dtype(dtype.item()))

            caches = self._allocate_caches(shapes, dtypes, num_caches=1)[0]
            caches = tuple(caches)
            return caches

    def _send_activations(self, cache_id):
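        # Send this micro-batch's outputs to the next stage; the tensor meta
        # description is only sent before the first activation.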
        outputs = self.caches['outputs'][cache_id]

        if self.send_meta:
            self.send_meta = False
            self._send_meta(outputs, self.next_stage_id)

        if isinstance(outputs, paddle.Tensor):
            p2p.send(outputs, self.next_stage_id)

        elif isinstance(outputs, tuple):
            for output in outputs:
                p2p.send(output, self.next_stage_id)

    def _send_gradients(self, cache_id):
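        # Send the gradients of this micro-batch's inputs back to the
        # previous stage; non-float tensors carry no gradient and are skipped.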
        inputs = self.caches['inputs'][cache_id]
        if isinstance(inputs, paddle.Tensor):
            assert inputs.grad is not None
            p2p.send(inputs.grad, self.prev_stage_id)
        else:
            for idx, d in enumerate(inputs):
                # Skip tensors that will not produce a grad
                if not is_float_tensor(d):
                    assert d.grad is None
                    continue
                p2p.send(d.grad, self.prev_stage_id)

        self.caches['inputs'][cache_id] = None

    def _recv_activations(self, cache_id):
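        # Receive activations from the previous stage into pre-allocated
        # buffers and cache detached copies as this stage's inputs.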
        inputs = None
        if self.recv_cache is None:
            self.recv_cache = self._recv_meta(self.prev_stage_id)

        if isinstance(self.recv_cache, paddle.Tensor):
            p2p.recv(self.recv_cache, self.prev_stage_id)
            inputs = self.recv_cache.clone().detach()
            inputs.stop_gradient = not is_float_tensor(inputs)
        else:
            assert isinstance(self.recv_cache, tuple)
            inputs = [None] * len(self.recv_cache)
            for idx, d in enumerate(self.recv_cache):
                assert isinstance(d, paddle.Tensor)
                p2p.recv(d, self.prev_stage_id)
                inputs[idx] = d.clone().detach()

            inputs = tuple(inputs)

            for d in inputs:
                d.stop_gradient = not is_float_tensor(d)

        self.caches['inputs'][cache_id] = inputs

    def _recv_gradients(self, cache_id):
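        # Receive the gradients of this micro-batch's outputs from the next
        # stage; the buffers are allocated once, matching the float outputs.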
        outputs = self.caches['outputs'][cache_id]
        if self.grad_tensors is None:
            if isinstance(outputs, paddle.Tensor):
                s = list(outputs.shape)
                dtype = get_tensor_dtype(outputs.dtype)
                self.grad_tensors = self._allocate_cache(
                    s, dtype, num_caches=1)[0]
            else:
                sizes = [list(d.shape) for d in outputs if is_float_tensor(d)]
                dtypes = [
                    get_tensor_dtype(d.dtype) for d in outputs
                    if is_float_tensor(d)
                ]
                self.grad_tensors = self._allocate_caches(
                    sizes, dtypes, num_caches=1)[0]

        if isinstance(self.grad_tensors, paddle.Tensor):
            p2p.recv(self.grad_tensors, self.next_stage_id)
        else:
            assert isinstance(outputs, tuple)
            for d in self.grad_tensors:
                p2p.recv(d, self.next_stage_id)

    def _step(self):
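        # Apply the accumulated gradients: go through the loss scaler when
        # one is provided, otherwise step the optimizer directly, then clear
        # gradients and advance the lr scheduler.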
        if self.scaler:
            self.scaler.minimize(self.optimizer, self.train_loss)
        else:
            self.optimizer.step()
        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()

    def _clear_grads(self, inputs):
        if isinstance(inputs, paddle.Tensor):
            if inputs.grad is not None:
                inputs.clear_gradient()
        else:
            for d in inputs:
                if d.grad is not None:
                    d.clear_gradient()

    def _allocate_zeros(self, shape, dtype):
        return paddle.zeros(shape, dtype)

    def _allocate_cache(self, shape, dtype, num_caches=-1):
        caches = []
        if num_caches == -1:
            num_caches = self.num_caches
        for count in range(num_caches):
            caches.append(self._allocate_zeros(shape, dtype))
        return caches

    def _allocate_caches(self, shapes, dtypes, num_caches=-1):
        caches = []
        if num_caches == -1:
            num_caches = self.num_caches
        for count in range(num_caches):
            cache = []
            for shape, dtype in zip(shapes, dtypes):
                cache.append(self._allocate_zeros(shape, dtype))
            caches.append(cache)
        return caches

    def save_state_dict(self, model_path):
        state_dict = self._layers.state_dict()
        paddle.save(state_dict, model_path)

    def load_state_dict(self, model_path):
        state_dict = paddle.load(model_path)
        self._layers.set_state_dict(state_dict)

    def forward(self, *inputs, **kwargs):
        raise RuntimeError("Call train_batch for pipeline instead of forward.")