#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from types import MethodType

import paddle
import paddle.fluid as fluid
from .meta_parallel_base import MetaParallelBase
from .pp_utils.utils import is_float_tensor, get_tensor_dtype, paddle_2_number, number_2_dtype
from .pp_utils import utils
from .parallel_layers.pp_layers import PipelineLayer

from ..utils.hybrid_parallel_util import broadcast_mp_parameters
from ..utils.hybrid_parallel_util import broadcast_dp_parameters
from ..utils.log_util import logger
from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer

__all__ = []


class PipelineParallel(MetaParallelBase):
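    # Runs a PipelineLayer model with pipeline parallelism: each rank owns one
    # stage and exchanges activations/gradients with its neighbouring stages.
    # The schedule is a simple flush: all forward micro-steps, then all
    # backward micro-steps, then a single optimizer step per batch.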
    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer.")
        super(PipelineParallel, self).__init__(layers, hcg, strategy)
        self.use_pipe_parallel = self._hcg.get_pipe_parallel_world_size() > 1
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1

        self.num_caches = 0
        self.caches = {
            'inputs': [],
            'labels': [],
            'outputs': [],
        }

        self.recv_cache = None
        self.grad_tensors = None

        self.send_meta = True

        self.current_loss = paddle.to_tensor(0.0)
        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size']
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.prev_stage_id = self.stage_id - 1
        self.next_stage_id = self.stage_id + 1
        self.pp_group = self._hcg.get_pipe_parallel_group()

        self.is_first_stage = self.stage_id == 0
        self.is_last_stage = (self.stage_id == (self.num_stages - 1))
        self.global_rank = self._hcg.get_global_rank()

        logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format(
            self.num_stages, self.stage_id))

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

    def _init_caches(self, num_caches):
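        # Grow the per-key cache lists so that inputs, labels and outputs can
        # be stored per micro-batch (indexed by cache_id).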
        if self.num_caches >= num_caches:
            return
        self.num_caches = num_caches - self.num_caches
        for key in self.caches:
            self.caches[key].extend([None] * self.num_caches)

    def _reduce_final_loss(self):
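        # The last stage holds the accumulated loss: average it over the
        # micro-batches and broadcast it so that every pipeline rank returns
        # the same value from train_batch().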
        if self.is_last_stage:
            assert self.total_loss is not None, "train_batch() in last stage should obtain valid loss"
            loss = self.total_loss.clone() / self.accumulate_steps
            paddle.distributed.broadcast(
                loss,
                src=self.global_rank,
                use_calc_stream=True,
                group=self.pp_group)
        else:
            loss = paddle.to_tensor(0.0)
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                use_calc_stream=True,
                group=self.pp_group)
        return loss

    def train_batch(self, data, optimizer, lr_scheduler=None):
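        # One optimizer step over a full batch: split it into accumulate_steps
        # micro-batches, run all forward passes, then all backward passes, and
        # finally apply a single optimizer update.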
        assert isinstance(optimizer, HybridParallelOptimizer), (
            'optimizer should be HybridParallelOptimizer subclass.')
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        assert fluid.framework._dygraph_tracer()._has_grad, (
            'Please enable the generation of gradients.')

        if self.is_first_stage or self.is_last_stage:
            assert data is not None, (
                "For the first and the last stage, the data must be set.")
        else:
            data = None

        self.data = data
        self._layers.train()

        # store total loss of entire batch
        self.total_loss = None
        self._init_caches(self.accumulate_steps)
        startup_steps = self.num_stages - self.stage_id - 1
        forward_steps = 0
        backward_steps = 0

        # forward
        while (forward_steps < self.accumulate_steps):
            self._forward(cache_id=forward_steps)
            forward_steps += 1

        # backward
        while (backward_steps < self.accumulate_steps):
            self._backward(cache_id=backward_steps)
            backward_steps += 1

        # optimizer
        self._step()
        self.train_loss = self._reduce_final_loss()
        return self.train_loss

    def _forward(self, cache_id):
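        # Forward pass for one micro-batch: load/receive the inputs, run the
        # local layers, cache the outputs, and on the last stage compute and
        # accumulate the loss.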
        # load data
        self._load_micro_batch(cache_id)
        if self.stage_id != 0:
            self._recv_activations(cache_id)

        if isinstance(self.caches['inputs'][cache_id], tuple):
            inputs = tuple(t for t in self.caches['inputs'][cache_id])
        else:
            inputs = self.caches['inputs'][cache_id]

        outputs = self._layers.forward(inputs)
        self._clear_grads(inputs)

        self.caches['outputs'][cache_id] = outputs

        if self.is_last_stage:
            if self._layers._loss_fn is not None:
                labels = self.caches['labels'][cache_id]
                outputs = self._layers._loss_fn(outputs, labels)

        if self.is_last_stage:
            self.current_loss = outputs
            if isinstance(self.current_loss, paddle.Tensor):
                if self.total_loss is None:
                    self.total_loss = paddle.zeros_like(self.current_loss)
                self.total_loss += self.current_loss.detach()
            else:
                if self.total_loss is None:
                    self.total_loss = [
                        paddle.zeros_like(v) for v in self.current_loss
                    ]
                for idx, v in enumerate(self.current_loss):
                    self.total_loss[idx] += v.detach()

            if self.accumulate_steps > 1:
                self.current_loss = self.current_loss / self.accumulate_steps

            self.caches['outputs'][cache_id] = self.current_loss.clone()

        else:
            self._send_activations(cache_id)

    def _backward(self, cache_id):
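        # Backward pass for one micro-batch: the last stage starts from the
        # cached loss; other stages backprop the gradients received from the
        # next stage and then send their input gradients upstream.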
        if self.is_last_stage:
            paddle.autograd.backward(self.caches['outputs'][cache_id])
            self._send_gradients(cache_id)
            return
        self._recv_gradients(cache_id)

        outputs = self.caches['outputs'][cache_id]

        grad_tensors = self.grad_tensors
        if isinstance(outputs, tuple):
            out_tensors = [t for t in outputs if is_float_tensor(t)]
            assert len(out_tensors) == len(grad_tensors)
            paddle.autograd.backward(
                tensors=out_tensors, grad_tensors=grad_tensors)
        else:
            paddle.autograd.backward(
                tensors=[outputs], grad_tensors=[grad_tensors])

        grad_tensors = None
        if self.stage_id != 0:
            self._send_gradients(cache_id)
        self.caches['outputs'][cache_id] = None

    def _broadcast_data(self, data):
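        # Broadcast the data from the source rank of the model parallel group
        # so that all tensor-parallel ranks consume identical micro-batches.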
        if isinstance(data, paddle.Tensor):
            paddle.distributed.broadcast(
                data,
                src=self._hcg.get_model_parallel_group_src_rank(),
                group=self._hcg.get_model_parallel_group())
        else:
            for d in data:
                assert isinstance(d, paddle.Tensor)
                paddle.distributed.broadcast(
                    d,
                    src=self._hcg.get_model_parallel_group_src_rank(),
                    group=self._hcg.get_model_parallel_group())
        return data

    def _load_micro_batch(self, cache_id):
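        # Slice the current micro-batch out of the full batch: the first stage
        # caches the inputs, the last stage caches the labels, and the
        # intermediate stages need no data.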
        inputs = self.data
        begin = cache_id * self.micro_batch_size
        end = begin + self.micro_batch_size

        if self.is_first_stage:
            assert len(inputs) == 2, "length of input should be 2"
            if self.use_model_parallel:
                inputs[0] = self._broadcast_data(inputs[0])
            if isinstance(inputs[0], tuple):
                batch_size = inputs[0][0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size, (
                    "batch_size should equal micro_batch_size * accumulate_steps. Currently, "
                    "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
                    %
                    (batch_size, self.micro_batch_size, self.accumulate_steps))
                data = [
                    input[begin:end, :].clone().detach() for input in inputs[0]
                ]
                self.caches['inputs'][cache_id] = tuple(data)
            else:
                batch_size = inputs[0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                self.caches['inputs'][cache_id] = inputs[0][begin:end, :].clone(
                ).detach()
        elif self.is_last_stage:
            assert len(inputs) == 2, "length of input should be 2"
            if self.use_model_parallel:
                inputs[1] = self._broadcast_data(inputs[1])
            if isinstance(inputs[1], tuple):
                batch_size = inputs[1][0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                data = [
                    input[begin:end, :].clone().detach() for input in inputs[1]
                ]
                self.caches['labels'][cache_id] = tuple(data)
            else:
                batch_size = inputs[1].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                self.caches['labels'][cache_id] = inputs[1][begin:end, :].clone(
                ).detach()
        else:
            # No data input is required for other stages
            inputs = None

    def _send_meta(self, data, peer):
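        # Tell the peer what to expect before sending activations: a type flag
        # (0 = single tensor, 1 = tuple of tensors), then for each tensor its
        # number of dimensions, its shape and its dtype.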
        if isinstance(data, paddle.Tensor):
            tensor_type = paddle.to_tensor([0])
            # send tensor type
            paddle.distributed.send(
                tensor_type, peer, use_calc_stream=True, group=self.pp_group)

            # send len(shape)
            dims = paddle.to_tensor(len(data.shape))
            paddle.distributed.send(
                dims, peer, use_calc_stream=True, group=self.pp_group)

            # send shape
            shape = paddle.to_tensor(data.shape)
            paddle.distributed.send(
                shape, peer, use_calc_stream=True, group=self.pp_group)

            # send dtype
            dtype = paddle.to_tensor(paddle_2_number(data.dtype))
            paddle.distributed.send(
                dtype, peer, use_calc_stream=True, group=self.pp_group)

        elif isinstance(data, tuple):
            tensor_type = paddle.to_tensor([1])
            paddle.distributed.send(
                tensor_type, peer, use_calc_stream=True, group=self.pp_group)
            nums = paddle.to_tensor(len(data))
            paddle.distributed.send(
                nums, peer, use_calc_stream=True, group=self.pp_group)
            for idx, d in enumerate(data):
                assert isinstance(d, paddle.Tensor)
                # send len(shape)
                dims = paddle.to_tensor(len(d.shape))
                paddle.distributed.send(
                    dims, peer, use_calc_stream=True, group=self.pp_group)

                # send shape
                shape = paddle.to_tensor(d.shape)
                paddle.distributed.send(
                    shape, peer, use_calc_stream=True, group=self.pp_group)

                # send dtype
                dtype = paddle.to_tensor(paddle_2_number(d.dtype))
                paddle.distributed.send(
                    dtype, peer, use_calc_stream=True, group=self.pp_group)

    def _recv_meta(self, peer):
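        # Counterpart of _send_meta: read the type flag, shapes and dtypes from
        # the peer and allocate matching receive buffers.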
        tensor_type = paddle.to_tensor([0])
        paddle.distributed.recv(
            tensor_type, peer, use_calc_stream=True, group=self.pp_group)
        tensor_type = tensor_type.item()

        if tensor_type == 0:
            # recv len(shape)
            dims = paddle.to_tensor([0])
            paddle.distributed.recv(
                dims, peer, use_calc_stream=True, group=self.pp_group)
            dims = dims.item()

            # recv shape
            shape = paddle.to_tensor([0] * dims)
            paddle.distributed.recv(
                shape, peer, use_calc_stream=True, group=self.pp_group)
            shape = shape.numpy().tolist()

            # recv dtype
            dtype = paddle.to_tensor([0])
            paddle.distributed.recv(
                dtype, peer, use_calc_stream=True, group=self.pp_group)
            return self._allocate_cache(
                shape, dtype=number_2_dtype(dtype.item()), num_caches=1)[0]
        elif tensor_type == 1:
            num = paddle.to_tensor([0])
            paddle.distributed.recv(
                num, peer, use_calc_stream=True, group=self.pp_group)
            num = num.item()
            shapes = []
            dtypes = []
            for i in range(num):
                # recv len(shape)
                dims = paddle.to_tensor([0])
                paddle.distributed.recv(
                    dims, peer, use_calc_stream=True, group=self.pp_group)

                # recv shape
                dims = dims.item()
                shape = paddle.to_tensor([0] * dims)
                paddle.distributed.recv(
                    shape, peer, use_calc_stream=True, group=self.pp_group)
                shapes.append(shape.numpy().tolist())

                # recv dtype
                dtype = paddle.to_tensor([0])
                paddle.distributed.recv(
                    dtype, peer, use_calc_stream=True, group=self.pp_group)
                dtypes.append(number_2_dtype(dtype.item()))

            caches = self._allocate_caches(shapes, dtypes, num_caches=1)[0]
            caches = tuple(caches)
            return caches

    def _send_activations(self, cache_id):
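        # Send this micro-batch's outputs to the next stage; the tensor meta
        # information is only exchanged once, before the first transfer.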
        outputs = self.caches['outputs'][cache_id]

        if self.send_meta:
            self.send_meta = False
            self._send_meta(outputs, self.next_stage_id)

        if isinstance(outputs, paddle.Tensor):
            paddle.distributed.send(
                outputs,
                self.next_stage_id,
                use_calc_stream=True,
                group=self.pp_group)
        elif isinstance(outputs, tuple):
            for output in outputs:
                paddle.distributed.send(
                    output,
                    self.next_stage_id,
                    use_calc_stream=True,
                    group=self.pp_group)

    def _send_gradients(self, cache_id):
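        # Send the gradients of the cached inputs back to the previous stage,
        # then release the input cache for this micro-batch.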
        inputs = self.caches['inputs'][cache_id]
        if isinstance(inputs, paddle.Tensor):
            assert inputs.grad is not None
            paddle.distributed.send(
                paddle.to_tensor(inputs.grad),
                self.prev_stage_id,
                use_calc_stream=True,
                group=self.pp_group)
        else:
            for idx, d in enumerate(inputs):
                # Skip tensors that will not produce a grad
                if not is_float_tensor(d):
                    assert d.grad is None
                    continue
                paddle.distributed.send(
                    d.grad,
                    self.prev_stage_id,
                    use_calc_stream=True,
                    group=self.pp_group)
        self.caches['inputs'][cache_id] = None

    def _recv_activations(self, cache_id):
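        # Receive activations from the previous stage into pre-allocated
        # buffers and cache detached copies as this stage's inputs.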
        inputs = None
        if self.recv_cache is None:
            self.recv_cache = self._recv_meta(self.prev_stage_id)

        if isinstance(self.recv_cache, paddle.Tensor):
            paddle.distributed.recv(
                self.recv_cache,
                self.prev_stage_id,
                use_calc_stream=True,
                group=self.pp_group)
            inputs = self.recv_cache.clone().detach()
            inputs.stop_gradient = not is_float_tensor(inputs)
        else:
            assert isinstance(self.recv_cache, tuple)
            inputs = [None] * len(self.recv_cache)
            for idx, d in enumerate(self.recv_cache):
                assert isinstance(d, paddle.Tensor)

                paddle.distributed.recv(
                    d,
                    self.prev_stage_id,
                    use_calc_stream=True,
                    group=self.pp_group)
                inputs[idx] = d.clone().detach()

            inputs = tuple(inputs)

            for d in inputs:
                d.stop_gradient = not is_float_tensor(d)

        self.caches['inputs'][cache_id] = inputs

    def _recv_gradients(self, cache_id):
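        # Receive the output gradients for this micro-batch from the next
        # stage, allocating the receive buffers on first use.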
        outputs = self.caches['outputs'][cache_id]
        if self.grad_tensors is None:
            if isinstance(outputs, paddle.Tensor):
                s = list(outputs.shape)
                dtype = get_tensor_dtype(outputs.dtype)
                self.grad_tensors = self._allocate_cache(
                    s, dtype, num_caches=1)[0]
            else:
                sizes = [list(d.shape) for d in outputs if is_float_tensor(d)]
                dtypes = [
                    get_tensor_dtype(d.dtype) for d in outputs
                    if is_float_tensor(d)
                ]
                self.grad_tensors = self._allocate_caches(
                    sizes, dtypes, num_caches=1)[0]

        if isinstance(self.grad_tensors, paddle.Tensor):
            paddle.distributed.recv(
                self.grad_tensors,
                self.next_stage_id,
                use_calc_stream=True,
                group=self.pp_group)
        else:
            assert isinstance(outputs, tuple)
            for d in self.grad_tensors:
                paddle.distributed.recv(
                    d,
                    self.next_stage_id,
                    use_calc_stream=True,
                    group=self.pp_group)

    def _step(self):
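        # Apply the accumulated gradients once per batch, clear them, and
        # advance the learning rate scheduler if one was provided.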
        self.optimizer.step()
        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()

    def _clear_grads(self, inputs):
        if isinstance(inputs, paddle.Tensor):
            if inputs.grad is not None:
                inputs.clear_gradient()
        else:
            for d in inputs:
                if d.grad is not None:
                    d.clear_gradient()

    def _allocate_zeros(self, shape, dtype):
        return paddle.zeros(shape, dtype)

    def _allocate_cache(self, shape, dtype, num_caches=-1):
        caches = []
        if num_caches == -1:
            num_caches = self.num_caches
        for count in range(num_caches):
            caches.append(self._allocate_zeros(shape, dtype))
        return caches

    def _allocate_caches(self, shapes, dtypes, num_caches=-1):
        caches = []
        if num_caches == -1:
            num_caches = self.num_caches
        for count in range(num_caches):
            cache = []
            for shape, dtype in zip(shapes, dtypes):
                cache.append(self._allocate_zeros(shape, dtype))
            caches.append(cache)
        return caches

    def save_state_dict(self, model_path):
        state_dict = self._layers.state_dict()
        paddle.save(state_dict, model_path)

    def load_state_dict(self, model_path):
        state_dict = paddle.load(model_path)
        self._layers.set_state_dict(state_dict)

    def forward(self, *inputs, **kwargs):
        raise RuntimeError("Call train_batch for pipeline instead of forward.")