#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from types import MethodType

import paddle
import paddle.fluid as fluid
from .meta_parallel_base import MetaParallelBase
from .pp_utils.utils import is_float_tensor, get_tensor_dtype, paddle_2_number, number_2_dtype
from .pp_utils import utils
from .parallel_layers.pp_layers import PipelineLayer

from ..utils.hybrid_parallel_util import broadcast_mp_parameters
from ..utils.hybrid_parallel_util import broadcast_dp_parameters
from ..utils.log_util import logger
from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer
from .pp_utils import p2p_communication as p2p

__all__ = []


class PipelineParallel(MetaParallelBase):
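    # Drives pipeline-parallel training of a PipelineLayer: each pipeline rank
    # owns one stage and exchanges activations and gradients with its
    # neighbouring stages through the p2p communication helpers.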
    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer.")
        super(PipelineParallel, self).__init__(layers, hcg, strategy)
        self.use_pipe_parallel = self._hcg.get_pipe_parallel_world_size() > 1
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1

        self.num_caches = 0
        self.caches = {
            'inputs': [],
            'labels': [],
            'outputs': [],
        }

        self.recv_cache = None
        self.grad_tensors = None

        self.send_meta = True

        self.current_loss = paddle.to_tensor(0.0)
        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size']
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.prev_stage_id = self.stage_id - 1
        self.next_stage_id = self.stage_id + 1
        self.pp_group = self._hcg.get_pipe_parallel_group()
        p2p.initialize_p2p_groups(hcg)

        self.is_first_stage = self.stage_id == 0
        self.is_last_stage = (self.stage_id == (self.num_stages - 1))
        self.global_rank = self._hcg.get_global_rank()

        logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format(
            self.num_stages, self.stage_id))

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

    def _init_caches(self, num_caches):
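        # Make sure there is at least one cache slot (inputs/labels/outputs)
        # per micro-batch of the current batch.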
        if self.num_caches >= num_caches:
            return
        num_new_caches = num_caches - self.num_caches
        self.num_caches = num_caches
        for key in self.caches:
            self.caches[key].extend([None] * num_new_caches)

    def _reduce_final_loss(self):
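        # The last stage averages the accumulated loss over accumulate_steps
        # and broadcasts it so every pipeline rank returns the same batch loss.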
        if self.is_last_stage:
            assert self.total_loss is not None, "train_batch() in last stage should obtain valid loss"
            loss = self.total_loss.clone() / self.accumulate_steps
            paddle.distributed.broadcast(
                loss,
                src=self.global_rank,
                use_calc_stream=True,
                group=self.pp_group)
        else:
            loss = paddle.to_tensor(0.0)
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                use_calc_stream=True,
                group=self.pp_group)
        return loss

    def train_batch(self, data, optimizer, lr_scheduler=None):
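        # One training step: run accumulate_steps forward micro-steps, then the
        # matching backward micro-steps, all-reduce gradients of shared weights,
        # apply a single optimizer step and return the reduced batch loss.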
        assert isinstance(optimizer, HybridParallelOptimizer), (
            'optimizer should be HybridParallelOptimizer subclass.')
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        assert fluid.framework._dygraph_tracer()._has_grad, (
            'Please enable the generation of gradients.')

        if self.is_first_stage or self.is_last_stage:
            assert data is not None, (
                "For the first and the last stage, the data must be set.")
        else:
            data = None

        self.data = data
        self._layers.train()

        # store total loss of entire batch
        self.total_loss = None
        self._init_caches(self.accumulate_steps)
        startup_steps = self.num_stages - self.stage_id - 1
        forward_steps = 0
        backward_steps = 0

        # forward
        while (forward_steps < self.accumulate_steps):
            self._forward(cache_id=forward_steps)
            forward_steps += 1

        # backward
        while (backward_steps < self.accumulate_steps):
            self._backward(cache_id=backward_steps)
            backward_steps += 1

        self._layers.allreduce_shared_weight_gradients()

        # optimizer
        self._step()
        self.train_loss = self._reduce_final_loss()
        return self.train_loss

    def _forward(self, cache_id):
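        # Forward pass of one micro-batch: receive activations from the
        # previous stage (except on stage 0), run the local sub-layers, then
        # either compute and accumulate the loss (last stage) or send the
        # outputs to the next stage.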
        # load data
        self._load_micro_batch(cache_id)
        if self.stage_id != 0:
            self._recv_activations(cache_id)

        if isinstance(self.caches['inputs'][cache_id], tuple):
            inputs = tuple(t for t in self.caches['inputs'][cache_id])
        else:
            inputs = self.caches['inputs'][cache_id]

        outputs = self._layers.forward(inputs)
        self._clear_grads(inputs)

        self.caches['outputs'][cache_id] = outputs

        if self.is_last_stage:
            if self._layers._loss_fn is not None:
                labels = self.caches['labels'][cache_id]
                outputs = self._layers._loss_fn(outputs, labels)

        if self.is_last_stage:
            self.current_loss = outputs
            if isinstance(self.current_loss, paddle.Tensor):
                if self.total_loss is None:
                    self.total_loss = paddle.zeros_like(self.current_loss)
                self.total_loss += self.current_loss.detach()
            else:
                if self.total_loss is None:
                    self.total_loss = [
                        paddle.zeros_like(v) for v in self.current_loss
                    ]
                for idx, v in enumerate(self.current_loss):
                    self.total_loss[idx] += v.detach()

            if self.accumulate_steps > 1:
                self.current_loss = self.current_loss / self.accumulate_steps

            self.caches['outputs'][cache_id] = self.current_loss.clone()

        else:
            self._send_activations(cache_id)

    def _backward(self, cache_id):
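        # Backward pass of one micro-batch: the last stage starts autograd from
        # the cached loss; other stages receive output gradients from the next
        # stage and propagate them back to their cached inputs.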
        if self.is_last_stage:
            paddle.autograd.backward(self.caches['outputs'][cache_id])
            self._send_gradients(cache_id)
            return
        self._recv_gradients(cache_id)

        outputs = self.caches['outputs'][cache_id]

        grad_tensors = self.grad_tensors
        if isinstance(outputs, tuple):
            out_tensors = [t for t in outputs if is_float_tensor(t)]
            assert len(out_tensors) == len(grad_tensors)
            paddle.autograd.backward(
                tensors=out_tensors, grad_tensors=grad_tensors)
        else:
            paddle.autograd.backward(
                tensors=[outputs], grad_tensors=[grad_tensors])

        grad_tensors = None
        if self.stage_id != 0:
            self._send_gradients(cache_id)
        self.caches['outputs'][cache_id] = None

    def _broadcast_data(self, data):
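        # Broadcast the data from the source rank of the model-parallel group
        # so that all model-parallel ranks consume identical micro-batches.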
        if isinstance(data, paddle.Tensor):
            paddle.distributed.broadcast(
                data,
                src=self._hcg.get_model_parallel_group_src_rank(),
                group=self._hcg.get_model_parallel_group())
        else:
            for d in data:
                assert isinstance(d, paddle.Tensor)
                paddle.distributed.broadcast(
                    d,
                    src=self._hcg.get_model_parallel_group_src_rank(),
                    group=self._hcg.get_model_parallel_group())
        return data

    def _load_micro_batch(self, cache_id):
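        # Slice micro-batch `cache_id` out of the full batch: the first stage
        # caches inputs, the last stage caches labels, other stages need no data.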
        inputs = self.data
        begin = cache_id * self.micro_batch_size
        end = begin + self.micro_batch_size

        if self.is_first_stage:
            assert len(inputs) == 2, "length of input should be 2"
            if self.use_model_parallel:
                inputs[0] = self._broadcast_data(inputs[0])
            if isinstance(inputs[0], tuple):
                batch_size = inputs[0][0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size, (
                    "batch_size needs to be divisible by micro_batch_size. Currently, "
                    "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
                    %
                    (batch_size, self.micro_batch_size, self.accumulate_steps))
                data = [
                    input[begin:end, :].clone().detach() for input in inputs[0]
                ]
                self.caches['inputs'][cache_id] = tuple(data)
            else:
                batch_size = inputs[0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                self.caches['inputs'][cache_id] = inputs[0][begin:end, :].clone(
                ).detach()
        elif self.is_last_stage:
            assert len(inputs) == 2, "length of input should be 2"
            if self.use_model_parallel:
                inputs[1] = self._broadcast_data(inputs[1])
            if isinstance(inputs[1], tuple):
                batch_size = inputs[1][0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                data = [
                    input[begin:end, :].clone().detach() for input in inputs[1]
                ]
                self.caches['labels'][cache_id] = tuple(data)
            else:
                batch_size = inputs[1].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                self.caches['labels'][cache_id] = inputs[1][begin:end, :].clone(
                ).detach()
        else:
            # No data input is required for other stages
            inputs = None

    def _send_meta(self, data, peer):
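        # Describe the upcoming activation transfer to the next stage before
        # the first send: a type flag (0: single tensor, 1: tuple), the tensor
        # count for tuples, then ndim, shape and dtype of every tensor.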
        if isinstance(data, paddle.Tensor):
            tensor_type = paddle.to_tensor([0])
            # send tensor type
            p2p.send(tensor_type, self.next_stage_id)

            # send len(shape)
            dims = paddle.to_tensor(len(data.shape))
            p2p.send(dims, self.next_stage_id)

            # send shape
            shape = paddle.to_tensor(data.shape)
            p2p.send(shape, self.next_stage_id)

            # send dtype
            dtype = paddle.to_tensor(paddle_2_number(data.dtype))
            p2p.send(dtype, self.next_stage_id)

        elif isinstance(data, tuple):
            tensor_type = paddle.to_tensor([1])
            p2p.send(tensor_type, self.next_stage_id)

            nums = paddle.to_tensor(len(data))
            p2p.send(nums, self.next_stage_id)

            for idx, d in enumerate(data):
                assert isinstance(d, paddle.Tensor)
                # send len(shape)
                dims = paddle.to_tensor(len(d.shape))
                p2p.send(dims, self.next_stage_id)

                # send shape
                shape = paddle.to_tensor(d.shape)
                p2p.send(shape, self.next_stage_id)

                # send dtype
                dtype = paddle.to_tensor(paddle_2_number(d.dtype))
                p2p.send(dtype, self.next_stage_id)

    def _recv_meta(self, peer):
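        # Counterpart of _send_meta: read the type flag, shapes and dtypes from
        # the previous stage and allocate matching receive buffers.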
        tensor_type = paddle.to_tensor([0])
        p2p.recv(tensor_type, self.prev_stage_id)

        tensor_type = tensor_type.item()

        if tensor_type == 0:
            # recv len(shape)
            dims = paddle.to_tensor([0])
            p2p.recv(dims, self.prev_stage_id)

            dims = dims.item()

            # recv shape
            shape = paddle.to_tensor([0] * dims)
            p2p.recv(shape, self.prev_stage_id)

            shape = shape.numpy().tolist()

            # recv dtype
            dtype = paddle.to_tensor([0])
            p2p.recv(dtype, self.prev_stage_id)

            return self._allocate_cache(
                shape, dtype=number_2_dtype(dtype.item()), num_caches=1)[0]
        elif tensor_type == 1:
            num = paddle.to_tensor([0])
            p2p.recv(num, self.prev_stage_id)
            num = num.item()
            shapes = []
            dtypes = []
            for i in range(num):
                # recv len(shape)
                dims = paddle.to_tensor([0])
                p2p.recv(dims, self.prev_stage_id)

                # recv shape
                dims = dims.item()
                shape = paddle.to_tensor([0] * dims)
                p2p.recv(shape, self.prev_stage_id)
                shapes.append(shape.numpy().tolist())

                # recv dtype
                dtype = paddle.to_tensor([0])
                p2p.recv(dtype, self.prev_stage_id)
                dtypes.append(number_2_dtype(dtype.item()))

            caches = self._allocate_caches(shapes, dtypes, num_caches=1)[0]
            caches = tuple(caches)
            return caches

    def _send_activations(self, cache_id):
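        # Send the cached outputs of this micro-batch to the next stage,
        # preceded by their meta description on the very first send.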
        outputs = self.caches['outputs'][cache_id]

        if self.send_meta:
            self.send_meta = False
            self._send_meta(outputs, self.next_stage_id)

        if isinstance(outputs, paddle.Tensor):
            p2p.send(outputs, self.next_stage_id)

        elif isinstance(outputs, tuple):
            for output in outputs:
                p2p.send(output, self.next_stage_id)

    def _send_gradients(self, cache_id):
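        # Send the gradients of this micro-batch's inputs back to the previous
        # stage; non-float tensors carry no gradient and are skipped.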
        inputs = self.caches['inputs'][cache_id]
        if isinstance(inputs, paddle.Tensor):
            assert inputs.grad is not None
            p2p.send(inputs.grad, self.prev_stage_id)
        else:
            for idx, d in enumerate(inputs):
                # Skip tensors that will not produce a grad
                if not is_float_tensor(d):
                    assert d.grad is None
                    continue
                p2p.send(d.grad, self.prev_stage_id)

        self.caches['inputs'][cache_id] = None

    def _recv_activations(self, cache_id):
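        # Receive activations of this micro-batch from the previous stage into
        # pre-allocated buffers and cache detached copies as the local inputs.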
        inputs = None
        if self.recv_cache is None:
            self.recv_cache = self._recv_meta(self.prev_stage_id)

        if isinstance(self.recv_cache, paddle.Tensor):
            p2p.recv(self.recv_cache, self.prev_stage_id)
            inputs = self.recv_cache.clone().detach()
            inputs.stop_gradient = not is_float_tensor(inputs)
        else:
            assert isinstance(self.recv_cache, tuple)
            inputs = [None] * len(self.recv_cache)
            for idx, d in enumerate(self.recv_cache):
                assert isinstance(d, paddle.Tensor)
                p2p.recv(d, self.prev_stage_id)
                inputs[idx] = d.clone().detach()

            inputs = tuple(inputs)

            for d in inputs:
                d.stop_gradient = not is_float_tensor(d)

        self.caches['inputs'][cache_id] = inputs

    def _recv_gradients(self, cache_id):
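        # Receive the gradients of this micro-batch's outputs from the next
        # stage, allocating the receive buffers on first use.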
        outputs = self.caches['outputs'][cache_id]
        if self.grad_tensors is None:
            if isinstance(outputs, paddle.Tensor):
                s = list(outputs.shape)
                dtype = get_tensor_dtype(outputs.dtype)
                self.grad_tensors = self._allocate_cache(
                    s, dtype, num_caches=1)[0]
            else:
                sizes = [list(d.shape) for d in outputs if is_float_tensor(d)]
                dtypes = [
                    get_tensor_dtype(d.dtype) for d in outputs
                    if is_float_tensor(d)
                ]
                self.grad_tensors = self._allocate_caches(
                    sizes, dtypes, num_caches=1)[0]

        if isinstance(self.grad_tensors, paddle.Tensor):
            p2p.recv(self.grad_tensors, self.next_stage_id)
        else:
            assert isinstance(outputs, tuple)
            for d in self.grad_tensors:
                p2p.recv(d, self.next_stage_id)

    def _step(self):
        self.optimizer.step()
        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()

    def _clear_grads(self, inputs):
        if isinstance(inputs, paddle.Tensor):
            if inputs.grad is not None:
                inputs.clear_gradient()
        else:
            for d in inputs:
                if d.grad is not None:
                    d.clear_gradient()

    def _allocate_zeros(self, shape, dtype):
        return paddle.zeros(shape, dtype)

    def _allocate_cache(self, shape, dtype, num_caches=-1):
        caches = []
        if num_caches == -1:
            num_caches = self.num_caches
        for count in range(num_caches):
            caches.append(self._allocate_zeros(shape, dtype))
        return caches

    def _allocate_caches(self, shapes, dtypes, num_caches=-1):
        caches = []
        if num_caches == -1:
            num_caches = self.num_caches
        for count in range(num_caches):
            cache = []
            for shape, dtype in zip(shapes, dtypes):
                cache.append(self._allocate_zeros(shape, dtype))
            caches.append(cache)
        return caches

    def save_state_dict(self, model_path):
        state_dict = self._layers.state_dict()
        paddle.save(state_dict, model_path)

    def load_state_dict(self, model_path):
        state_dict = paddle.load(model_path)
        self._layers.set_state_dict(state_dict)

    def forward(self, *inputs, **kwargs):
        raise RuntimeError("Call train_batch for pipeline instead of forward.")