未验证 提交 264ff9ef 编写于 作者: S ShenLiang 提交者: GitHub

[HybridParallel]Support finetinue model for PipelineParallel (#35287)

* add cache for send_recv

* add eval_batch for pipeline

* add eval batch for pipelineparallel

* add style code
上级 bee511d5
......@@ -158,6 +158,7 @@ message PipelineConfig {
optional int32 micro_batch_size = 1 [ default = 1 ];
optional int32 accumulate_steps = 2 [ default = 1 ];
optional string schedule_mode = 3 [ default = '1F1B' ];
optional bool p2p_cache_shape = 4 [ default = true ];
}
message TensorParallelConfig {
......
......@@ -42,11 +42,13 @@ class PipelineParallel(MetaParallelBase):
self.accumulate_steps = self._strategy.pipeline_configs[
'accumulate_steps']
self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']
self.num_stages = self._hcg.get_pipe_parallel_world_size()
self.stage_id = self._hcg.get_stage_id()
self.pp_group = self._hcg.get_pipe_parallel_group()
p2p.initialize_p2p_groups(hcg)
p2p.initialize_p2p_groups(hcg, self._using_cache)
_initialize_recompute_hcg(hcg)
......@@ -55,6 +57,8 @@ class PipelineParallel(MetaParallelBase):
self.global_rank = self._hcg.get_global_rank()
self.micro_batch_id = 0
self._compute_loss = True
logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format(
self.num_stages, self.stage_id))
......@@ -85,6 +89,7 @@ class PipelineParallel(MetaParallelBase):
self.lr_scheduler = lr_scheduler
self.scaler = scaler
self.data = data
self._compute_loss = True
self._layers.train()
......@@ -151,12 +156,57 @@ class PipelineParallel(MetaParallelBase):
self._layers.allreduce_shared_weight_gradients()
self.train_loss = self._reduce_final_loss()
self.train_loss = self._broadcast_final_loss()
# optimizer
self._optimizer_step()
return self.train_loss
def eval_batch(self, data, compute_loss=False):
self._layers.eval()
self._compute_loss = compute_loss
# save data for eval
self.data = data
# store data id for micro_batch
self.micro_batch_id = 0
# store total loss of entire batch
self.total_loss = None
startup_steps = (self.num_stages - self.stage_id - 1)
startup_steps = min(startup_steps, self.accumulate_steps)
steady_steps = self.accumulate_steps - startup_steps
input_buffers = []
output_buffers = []
for step_id in range(startup_steps):
input_tensor = p2p.recv_forward()
output_tensor = self._forward_step(input_tensor)
p2p.send_forward(output_tensor)
input_buffers.append(input_tensor)
output_buffers.append(output_tensor)
if steady_steps > 0:
input_tensor = p2p.recv_forward()
for i in range(steady_steps):
last_iter = (i == (steady_steps - 1))
output_tensor = self._forward_step(input_tensor)
p2p.send_forward(output_tensor)
input_buffers.append(input_tensor)
output_buffers.append(output_tensor)
if not last_iter:
input_tensor = p2p.recv_forward()
return self.total_loss if self._compute_loss else output_buffers
def _forward_step(self, input_tensor):
if self.stage_id == 0:
input_tensor = self._load_micro_batch(self.micro_batch_id)
......@@ -164,11 +214,14 @@ class PipelineParallel(MetaParallelBase):
output_tensor = self._layers.forward(input_tensor)
if self.is_last_stage:
# train calculate loss for train
if self._compute_loss:
assert self._layers._loss_fn is not None, "loss function should exist to compute loss"
labels = self._load_micro_batch(self.micro_batch_id)
output_tensor = self._layers._loss_fn(output_tensor, labels)
assert isinstance(
output_tensor, paddle.
Tensor), "Currently, loss_fn should obtain Paddle.Tensor dtype"
output_tensor, paddle.Tensor
), "Currently, loss_fn should obtain Paddle.Tensor dtype"
if self.accumulate_steps > 1:
output_tensor = output_tensor / self.accumulate_steps
......@@ -245,7 +298,7 @@ class PipelineParallel(MetaParallelBase):
# No data input is required for other stages
inputs = None
def _reduce_final_loss(self):
def _broadcast_final_loss(self):
if self.is_last_stage:
assert self.total_loss is not None, "train_batch() in last stage should obtain vaild loss"
loss = self.total_loss.detach()
......
......@@ -19,11 +19,13 @@ import numpy as np
from paddle import _C_ops
_hcg = None
_use_cache = False
def initialize_p2p_groups(hcg):
global _hcg
def initialize_p2p_groups(hcg, use_cache=True):
global _hcg, _use_cache
_hcg = hcg
_use_cache = use_cache
send_next_group, send_prev_group, recv_next_group, recv_prev_group = _hcg.get_p2p_groups(
)
......@@ -372,7 +374,7 @@ def recv_forward():
else:
if not _send_recv_meta.has_recv_meta:
_send_recv_meta.recv_meta(_hcg.recv_prev_group)
_send_recv_meta.has_recv_meta = True
_send_recv_meta.has_recv_meta = _use_cache
input_tensor, _ = _p2p_helper(
tensor_send_next=None,
......@@ -399,7 +401,7 @@ def send_forward(output_tensor):
if not _send_recv_meta.has_send_meta:
_send_recv_meta.set_send_message(output_tensor)
_send_recv_meta.send_meta(output_tensor, _hcg.send_next_group)
_send_recv_meta.has_send_meta = True
_send_recv_meta.has_send_meta = _use_cache
_p2p_helper(
tensor_send_next=output_tensor,
......
......@@ -177,10 +177,13 @@ class TestDistPPTraning(unittest.TestCase):
x_data = np.random.randint(0, vocab_size, size=[batch_size, length])
x = paddle.to_tensor(x_data)
x.stop_gradient = True
e_loss = model.eval_batch([x, x], True)
loss = model.train_batch([x, x], optimizer, scheduler)
# TODO(shenliang03) add utest for loss
print("loss: ", loss)
# TODO(shenliang03) add utest for loss
if pp_id != 0:
np.testing.assert_allclose(loss.numpy(), e_loss.numpy())
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册