Unverified commit 589d13c5, authored by ShenLiang, committed by GitHub

[HybridParallel]Add Recompute for PipeLineParallel (#34607)

* add recompute for pp

* add recompute offload

* add recompute partition
Parent cfa69133
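Context for the diff below: the new recompute options are exposed through the PipelineLayer constructor. A condensed, hedged usage sketch assembled from the unit test added in this commit (it reuses the EmbeddingPipe, TransformerNetPipe, and CriterionPipe layers defined in the new test file and assumes a multi-GPU launch via paddle.distributed.launch):

import paddle.distributed.fleet as fleet
from paddle.distributed.fleet.meta_parallel import PipelineLayer, LayerDesc

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 1, "pp_degree": 2}
strategy.pipeline_configs = {"accumulate_steps": 4, "micro_batch_size": 2}
fleet.init(is_collective=True, strategy=strategy)

topology = fleet.get_hybrid_communicate_group().topology()
descs = [LayerDesc(EmbeddingPipe)] + [LayerDesc(TransformerNetPipe)] * 2
model = PipelineLayer(
    layers=descs,
    loss_fn=CriterionPipe(),
    topology=topology,
    seg_method="layer:TransformerNetPipe",
    recompute_interval=1,       # recompute every segment of one layer; 0 (default) disables recompute
    recompute_offload=False,    # True offloads the saved activations to CPU
    recompute_partition=False)  # True shards the saved activations across mp ranks
model = fleet.distributed_model(model)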
......@@ -136,7 +136,6 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"c_reduce_min", {"Out"}},
{"c_reduce_prod", {"Out"}},
{"c_reduce", {"Out"}},
{"c_allgather", {"Out"}},
{"c_scatter", {"Out"}},
{"barrier", {"Out"}},
{"fake_quantize_dequantize_moving_average_abs_max",
......
......@@ -632,14 +632,13 @@ def all_gather(tensor_list, tensor, group=None, use_calc_stream=True):
ring_id = 0 if group is None else group.id
nranks = _get_global_group().nranks if group is None else group.nranks
op_type = 'c_allgather'
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
if in_dygraph_mode():
_C_ops.c_allgather(tensor, out, 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'nranks', nranks)
out = _C_ops.c_allgather(tensor, 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'nranks', nranks)
else:
op_type = 'c_allgather'
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
if not isinstance(tensor_list, list):
raise ValueError("The type of 'tensor_list' for all_gather "
"should be list.")
......
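The hunk above only changes the internal dygraph path: `out` is now taken from the value returned by `_C_ops.c_allgather` instead of being pre-created by the LayerHelper. The public API is unchanged; a minimal usage reminder (assuming a multi-GPU dygraph launch):

import paddle
import paddle.distributed as dist

dist.init_parallel_env()
t = paddle.ones([2, 3]) * dist.get_rank()
gathered = []
dist.all_gather(gathered, t)   # gathered holds one tensor per rank, in rank order
assert len(gathered) == dist.get_world_size()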
......@@ -23,6 +23,7 @@ from functools import partial
import paddle
from paddle.fluid.dygraph.layers import Layer
from ...utils.log_util import logger, layer_to_str
from ..pp_utils.utils import _hp_recompute, _initialize_recompute_setting
__all__ = []
......@@ -134,7 +135,10 @@ class PipelineLayer(Layer):
num_stages=None,
topology=None,
loss_fn=None,
seg_method="uniform"):
seg_method="uniform",
recompute_interval=0,
recompute_offload=False,
recompute_partition=False):
super(PipelineLayer, self).__init__()
if num_stages is None and topology is None:
raise ValueError("should provide num_stages or topology")
......@@ -147,6 +151,16 @@ class PipelineLayer(Layer):
self.layers = layers
self._loss_fn = loss_fn
self._topo = topology
self._recompute_interval = recompute_interval
self._recompute_offload = recompute_offload
self._recompute_partition = recompute_partition
if recompute_interval > 0:
logger.info(
"Start Recompute for PipeLineParallel. recompute_offload: {}, recompute_partition: {}".
format(recompute_offload, recompute_partition))
_initialize_recompute_setting(recompute_offload, recompute_partition)
world_size = dist.get_world_size()
self.global_rank = dist.get_rank()
......@@ -312,11 +326,44 @@ class PipelineLayer(Layer):
else:
self.run_function.append(layer)
def forward_function(self, start, end):
def execute_func(*x):
if len(x) == 1:
x = x[0]
for idx, layer in enumerate(self.run_function[start:end]):
x = layer(x)
return x
return execute_func
def forward(self, input):
for layer in self.run_function:
input = layer(input)
if self._recompute_interval == 0:
input = self.forward_function(0, len(self.run_function))(input)
else:
num_layers = len(self.run_function)
for start_idx in range(0, num_layers, self._recompute_interval):
end_idx = min(start_idx + self._recompute_interval, num_layers)
funcs = self.run_function[start_idx:end_idx]
if not isinstance(input, tuple):
input = (input, )
if self._need_recompute(funcs, input):
input = _hp_recompute(
self.forward_function(start_idx, end_idx), *input)
else:
input = self.forward_function(start_idx, end_idx)(*input)
return input
def _need_recompute(self, funcs, inputs):
if not any(input_.stop_gradient == False for input_ in inputs
if isinstance(input_, paddle.Tensor)):
return False
params = [f.parameters() for f in funcs if isinstance(f, Layer)]
return any(len(list(p)) > 0 for p in params)
def save_state_dict(self, path):
if self._topo.get_coord(self.global_rank).data != 0:
return
......
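A small worked example of the segmentation logic added to PipelineLayer.forward above (an illustration, not part of the diff): with 5 run functions and recompute_interval=2, the layers are executed in chunks [0, 2), [2, 4), [4, 5), and each chunk is wrapped in _hp_recompute only when _need_recompute reports trainable parameters and grad-requiring inputs.

# Segmentation arithmetic only; mirrors the start_idx/end_idx loop in forward().
num_layers, recompute_interval = 5, 2
segments = [(start, min(start + recompute_interval, num_layers))
            for start in range(0, num_layers, recompute_interval)]
assert segments == [(0, 2), (2, 4), (4, 5)]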
......@@ -20,6 +20,9 @@ __all__ = []
MODEL_PARALLEL_RNG = 'model_parallel_rng'
# This file is inspired by Megatron to control random states for MP:
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/mpu/random.py
class RNGStatesTracker:
"""
......@@ -46,6 +49,15 @@ class RNGStatesTracker:
self.states_[name] = paddle.get_cuda_rng_state()
paddle.set_cuda_rng_state(orig_rng_state)
def get_states_tracker(self):
states = {}
for name in self.states_:
states[name] = self.states_[name]
return states
def set_states_tracker(self, states):
self.states_ = states
@contextlib.contextmanager
def rng_state(self, name=MODEL_PARALLEL_RNG):
if name not in self.states_:
......
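The new get_states_tracker/set_states_tracker pair exists so that recompute can snapshot the tracked model-parallel RNG states during the forward pass and restore them before replaying it, keeping dropout masks identical; the _switch_rng_state_tracker helper later in this diff does exactly that. A rough sketch (the import path is an assumption and is not shown in this hunk):

from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
    get_rng_state_tracker)

tracker = get_rng_state_tracker()
snapshot = tracker.get_states_tracker()   # dict: tracker name -> CUDA RNG state
# ... the original forward pass runs here and advances the tracked states ...
tracker.set_states_tracker(snapshot)      # the recomputed forward reuses the same states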
......@@ -14,7 +14,7 @@
import paddle
import paddle.fluid as fluid
from .meta_parallel_base import MetaParallelBase
from .pp_utils.utils import is_float_tensor
from .pp_utils.utils import is_float_tensor, _initialize_recompute_hcg
from .parallel_layers.pp_layers import PipelineLayer
from ..utils.hybrid_parallel_util import broadcast_mp_parameters
......@@ -48,6 +48,8 @@ class PipelineParallel(MetaParallelBase):
p2p.initialize_p2p_groups(hcg)
_initialize_recompute_hcg(hcg)
self.is_first_stage = self.stage_id == 0
self.is_last_stage = (self.stage_id == (self.num_stages - 1))
self.global_rank = self._hcg.get_global_rank()
......@@ -213,6 +215,9 @@ class PipelineParallel(MetaParallelBase):
if self.is_first_stage:
assert len(inputs) == 2, "length of input should be 2"
if isinstance(inputs[0], tuple):
assert len(
inputs[0]
) > 1, "If you use tuple for input data, it should have at least two inputs."
batch_size = inputs[0][0].shape[0]
assert self.micro_batch_size * self.accumulate_steps == batch_size, (
"batch_size needs to be divisible by micro_batch_size. Currently, "
......
......@@ -12,9 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import contextlib
import paddle
from ...utils import log_util as hp_util
from paddle.fluid import core
from paddle import _C_ops
import paddle.distributed as dist
from paddle.autograd import PyLayer
from paddle.fluid import framework
from paddle.distributed.fleet.utils.recompute import check_recompute_necessary, detach_variable
from ..parallel_layers.random import get_rng_state_tracker
__all__ = []
......@@ -79,3 +86,222 @@ def get_tensor_bytes(tensor):
else:
raise ValueError("unknown data type: {}".format(tensor.dtype))
return tensor.numel() * elem_size
_hcg = None
_recompute_offload = False
_recompute_partition = False
def _initialize_recompute_setting(is_offload, is_partition):
global _recompute_offload, _recompute_partition
_recompute_offload = is_offload
_recompute_partition = is_partition
def _initialize_recompute_hcg(hcg):
global _hcg
_hcg = hcg
def _all_gather(tensor, group=None, use_calc_stream=True):
"""
The main difference from paddle.distributed.all_gather:
there is no need to pass in a tensor_list; the gathered result is returned as one concatenated tensor
"""
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
nranks = paddle.distributed.collective._get_global_group(
).nranks if group is None else group.nranks
return _C_ops.c_allgather(tensor, 'use_calc_stream', use_calc_stream,
'ring_id', ring_id, 'nranks', nranks)
def _split_activation(tensor):
global _hcg
mp_degree = _hcg.get_model_parallel_world_size()
mp_rank = _hcg.get_model_parallel_rank()
if mp_degree < 2:
return tensor
tensor_numel = paddle.numel(tensor)
assert tensor_numel != 0, "can't recompute a tensor with zero elements"
assert tensor_numel % mp_degree == 0, "The number of elements in the activation ({}) must be divisible by mp_degree ({})".format(
tensor_numel, mp_degree)
# use inplace operation to save memory
data = tensor.flatten_()
part_size = tensor_numel // mp_degree
start = part_size * mp_rank
end = start + part_size
return data[start:end]
def _merge_activation(tensor):
global _hcg
mp_degree = _hcg.get_model_parallel_world_size()
mp_rank = _hcg.get_model_parallel_rank()
mp_group = _hcg.get_model_parallel_group()
if mp_degree < 2:
return tensor
return _all_gather(tensor, group=mp_group)
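# Illustration only (not part of the committed file): the partitioning arithmetic
# used by _split_activation/_merge_activation above. split_indices is a
# hypothetical stand-in for the slicing inside _split_activation. With 8 elements
# and mp_degree=2, rank 0 keeps [0:4] and rank 1 keeps [4:8]; concatenating the
# shards in rank order (what the all_gather in _merge_activation does)
# reproduces the flattened activation.
import numpy as np

def split_indices(numel, mp_degree, mp_rank):
    part_size = numel // mp_degree
    start = part_size * mp_rank
    return start, start + part_size

data = np.arange(8)  # stand-in for a flattened activation, with mp_degree = 2
shards = [data[slice(*split_indices(data.size, 2, rank))] for rank in range(2)]
assert np.array_equal(np.concatenate(shards), data)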
@contextlib.contextmanager
def _switch_rng_state_tracker(rng_state, tracker):
orig_cuda_rng_state = paddle.get_cuda_rng_state()
orig_cuda_rng_tracker = get_rng_state_tracker().get_states_tracker()
paddle.set_cuda_rng_state(rng_state)
get_rng_state_tracker().set_states_tracker(tracker)
try:
yield
finally:
paddle.set_cuda_rng_state(orig_cuda_rng_state)
get_rng_state_tracker().set_states_tracker(orig_cuda_rng_tracker)
class _HPRecomputeFunction(PyLayer):
"""
Compared with paddle.distributed.fleet.utils.recompute, the differences are:
1. To support PipelineParallel, the input handling is modified so that the input can be a tuple.
2. Offload support for activations.
3. Supports partitioning activations across model-parallel ranks to further reduce CUDA memory.
4. Adapts to the model-parallel random states.
"""
@staticmethod
def forward(ctx, run_function, all_outputs, *args):
check_recompute_necessary(args)
# store for recomputing
ctx.run_function = run_function
# store the rng states
ctx.fwd_cuda_rng_state = paddle.get_cuda_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_rng_state_tracker(
).get_states_tracker()
# save input for backward
ctx.inputs = []
ctx.tensor_indices = []
ctx.tensor_shapes = []
tensor_inputs = []
cur_device = paddle.get_device()
assert 'gpu:' in paddle.get_device(
), "Recompute with RNG does not support the current device: {}.".format(
cur_device)
# TODO support AMP
tracer = framework._dygraph_tracer()
ctx.is_fw_autocast = tracer._enable_autocast
ctx.amp_white_list, ctx.amp_black_list = tracer._get_amp_op_list()
with paddle.no_grad():
outputs = run_function(*args)
for i, arg in enumerate(args):
if paddle.is_tensor(arg):
state = arg.stop_gradient
if _recompute_partition:
ctx.tensor_shapes.append(arg.shape)
partition = _split_activation(arg.detach()).clone()
# TODO(shenliang03): avoid using the calc stream for the D2H copy, to improve speed
arg = partition.cpu() if _recompute_offload else partition
else:
arg = arg.cpu() if _recompute_offload else arg
arg.stop_gradient = state
tensor_inputs.append(arg)
ctx.tensor_indices.append(i)
ctx.inputs.append(None)
else:
ctx.inputs.append(arg)
ctx.save_for_backward(*tensor_inputs)
if paddle.is_tensor(outputs):
all_outputs += [outputs]
return outputs
else:
all_outputs += outputs
return tuple(outputs)
@staticmethod
def backward(ctx, *args):
with paddle.fluid.dygraph.guard():
# Restore inputs
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
tensor_shapes = ctx.tensor_shapes
tensors = list(ctx.saved_tensor())
device_id = dist.ParallelEnv().device_id
for i, idx in enumerate(tensor_indices):
if _recompute_partition:
state = tensors[i].stop_gradient
tensors[i] = _merge_activation(tensors[i]).detach(
).reshape_(tensor_shapes[i])
tensors[i].stop_gradient = state
inputs[idx] = tensors[i].cuda(
device_id) if _recompute_offload else tensors[i]
tracer = framework._dygraph_tracer()
tracer._has_grad = True
# need to restore the auto_cast state as well as the amp white/black lists
with _switch_rng_state_tracker(ctx.fwd_cuda_rng_state,
ctx.fwd_cuda_rng_state_tracker):
with paddle.amp.auto_cast(
enable=ctx.is_fw_autocast,
custom_white_list=ctx.amp_white_list,
custom_black_list=ctx.amp_black_list):
detached_inputs = detach_variable(tuple(inputs))
outputs = ctx.run_function(*detached_inputs)
if isinstance(outputs, core.VarBase):
outputs = (outputs, )
assert len(outputs) == len(args)
forward_outputs_with_grad = []
backward_inputs = []
for i in range(len(outputs)):
if isinstance(outputs[i],
core.VarBase) and not outputs[i].stop_gradient:
forward_outputs_with_grad.append(outputs[i])
backward_inputs.append(args[i])
if len(forward_outputs_with_grad) == 0:
raise RuntimeError(
"none of output has stop_gradient=False, this recompute() is not necessary"
)
# actually backward
paddle.autograd.backward(forward_outputs_with_grad, backward_inputs)
grads = list(inp._grad_ivar() for inp in detached_inputs
if isinstance(inp, core.VarBase))
return grads
def _hp_recompute(function, *args):
# NOTE(shenliang03): The current hybrid parallel recompute has limitations.
# It cannot handle the following situations:
# 1. The recomputed output contains tensors that do not require gradients.
# 2. The forward output tensor has no gradient; as a temporary workaround, call detach() on it.
# 3. Only the float dtype is used to decide whether an output tensor needs a gradient.
all_outputs = []
_HPRecomputeFunction.apply(function, all_outputs, *args)
if len(all_outputs) == 1:
return all_outputs[0]
else:
for output in all_outputs:
if paddle.is_tensor(output) and not is_float_tensor(output):
output.stop_gradient = True
return tuple(all_outputs)
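A minimal sketch of the dtype filtering described in the NOTE above: in _hp_recompute, output tensors that are not floating point are marked stop_gradient=True, since float dtype is the only signal used to decide whether an output needs a gradient. The sketch uses an explicit float-dtype tuple in place of is_float_tensor, whose definition is not shown in this diff.

import paddle

FLOAT_TYPES = (paddle.float16, paddle.float32, paddle.float64)
outputs = (paddle.randn([2, 2]),                      # float32: gradient kept
           paddle.to_tensor([1, 2], dtype='int64'))   # int64: no gradient needed
for out in outputs:
    if paddle.is_tensor(out) and out.dtype not in FLOAT_TYPES:
        out.stop_gradient = True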
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import unittest
import paddle
import numpy as np
import random
import paddle.distributed as dist
import paddle.distributed.fleet as fleet
from paddle.fluid import layers
import paddle.nn.functional as F
from paddle.distributed.fleet.meta_parallel import PipelineLayer, LayerDesc
from paddle.fluid.dygraph.layers import Layer
import paddle.nn as nn
def set_random_seed(seed, dp_id, rank_id):
"""Set random seed for reproducability."""
random.seed(seed)
np.random.seed(seed + dp_id)
paddle.seed(seed + dp_id)
batch_size = 8
length = 8
micro_batch_size = 2
vocab_size = 128
hidden_size = 16
d_model = hidden_size
dim_feedforward = 4 * d_model
class EmbeddingNet(Layer):
def __init__(self):
super(EmbeddingNet, self).__init__()
self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
self.position_embeddings = nn.Embedding(vocab_size, hidden_size)
def forward(self, x):
w_emb = self.word_embeddings(x)
p_emb = self.position_embeddings(x)
w_emb = w_emb + p_emb
return w_emb
class TransformerNet(Layer):
def __init__(self):
super(TransformerNet, self).__init__()
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.q_proj = nn.Linear(d_model, d_model)
self.k_proj = nn.Linear(d_model, d_model)
self.v_proj = nn.Linear(d_model, d_model)
self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
def forward(self, x):
q = self.q_proj(x)
k = self.k_proj(x)
v = self.v_proj(x)
product = layers.matmul(x=q, y=k, transpose_y=True, alpha=d_model**-0.5)
weights = F.softmax(product)
weights = F.dropout(weights, 0.2)
tgt = layers.matmul(weights, v)
residual = tgt
tgt = self.norm1(tgt)
tgt = residual + tgt
out = self.linear2(F.gelu(self.linear1(tgt), approximate=True))
return out
class EmbeddingPipe(EmbeddingNet):
def forward(self, x):
return super().forward(x)
class TransformerNetPipe(TransformerNet):
def forward(self, x):
output = super().forward(x)
return output
class CriterionPipe(Layer):
def __init__(self):
super(CriterionPipe, self).__init__()
def forward(self, out, label):
loss = out.mean()
return loss
class ModelPipe(PipelineLayer):
def __init__(self, topology):
self.descs = []
self.descs.append(LayerDesc(EmbeddingPipe))
for x in range(2):
self.descs.append(LayerDesc(TransformerNetPipe))
super().__init__(
layers=self.descs,
loss_fn=CriterionPipe(),
topology=topology,
seg_method="layer:TransformerNetPipe",
recompute_interval=1,
recompute_partition=False,
recompute_offload=False)
class TestDistPPTraining(unittest.TestCase):
def setUp(self):
strategy = fleet.DistributedStrategy()
self.model_parallel_size = 1
self.data_parallel_size = 1
self.pipeline_parallel_size = 2
strategy.hybrid_configs = {
"dp_degree": self.data_parallel_size,
"mp_degree": self.model_parallel_size,
"pp_degree": self.pipeline_parallel_size,
}
strategy.pipeline_configs = {
"accumulate_steps": batch_size // micro_batch_size,
"micro_batch_size": micro_batch_size
}
fleet.init(is_collective=True, strategy=strategy)
def test_pp_model(self):
hcg = fleet.get_hybrid_communicate_group()
world_size = hcg.get_model_parallel_world_size()
dp_id = hcg.get_data_parallel_rank()
pp_id = hcg.get_stage_id()
rank_id = dist.get_rank()
topology = hcg.topology()
set_random_seed(1024, dp_id, rank_id)
model = ModelPipe(topology)
scheduler = paddle.optimizer.lr.PiecewiseDecay(
boundaries=[2], values=[0.001, 0.002], verbose=True)
optimizer = paddle.optimizer.SGD(learning_rate=scheduler,
parameters=model.parameters())
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)
for step_id in range(5):
x_data = np.random.randint(0, vocab_size, size=[batch_size, length])
x = paddle.to_tensor(x_data)
x.stop_gradient = True
loss = model.train_batch([x, x], optimizer, scheduler)
# TODO(shenliang03): add a unit-test check for the loss value
print("loss: ", loss)
if __name__ == "__main__":
unittest.main()
......@@ -39,6 +39,9 @@ class TestHybridPipeParallel(TestMultipleGpus):
def test_hybrid_parallel_transformer(self):
self.run_mnist_2gpu('hybrid_parallel_pp_save_load.py')
def test_hybrid_parallel_pp_recompute(self):
self.run_mnist_2gpu('hybrid_parallel_pp_recompute.py')
if __name__ == "__main__":
unittest.main()