Unverified commit fa878846, authored by Yuang Liu, committed by GitHub

cherry pick #55651 and #55890 (#56063)

Parent 0d920178
@@ -11,8 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 import os
+import sys
 import time
 import warnings
+from collections import defaultdict

 import paddle
 from paddle import framework

@@ -217,7 +219,7 @@ class PipelineParallel(MetaParallelBase):
             self._dp_comm_overlap and self._sharding_comm_overlap
         ), "Cannot use dp pp overlap and sharding pp overlap at the same time."

-        self._comm_buffers = []
+        self._chunk_2_comm_buffers = defaultdict(list)
         self._comm_overlap = (
             self._dp_comm_overlap or self._sharding_comm_overlap
         )

@@ -291,7 +293,9 @@ class PipelineParallel(MetaParallelBase):
         return fused_allreduce

-    def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
+    def register_allreduce_overlap_hook(
+        self, model, comm_group, acc_steps, dp, group_size=128 * 1024 * 1024
+    ):
         if model.get_num_virtual_stages() > 1:
             models = model.get_model_chunks()
         else:

@@ -308,7 +312,7 @@ class PipelineParallel(MetaParallelBase):
             else HOOK_ACTION.REDUCE
         )

-        for model in models:
+        for chunk_idx, model in enumerate(models):
             # For virtual pipeline. Will separate parameters in different chunk into
             # different groups to get the best performance.

@@ -338,12 +342,12 @@ class PipelineParallel(MetaParallelBase):
                 dst = comm_group.ranks[dst]
             else:
                 dst = -1
-            var_groups = assign_group_by_size(parameter_list)
+            var_groups = assign_group_by_size(parameter_list, group_size)
             for group_idx, parameters in var_groups.items():
                 buffer = FusedCommBuffer(
                     group_idx, parameters, comm_group, acc_steps, act, dst
                 )
-                self._comm_buffers.append(buffer)
+                self._chunk_2_comm_buffers[chunk_idx].append(buffer)
                 for param in parameters:
                     param._register_backward_hook(
                         self.bw_hook_func(buffer, param)

@@ -514,9 +518,12 @@ class PipelineParallel(MetaParallelBase):
         self._flush_records()

         if self._comm_overlap:
-            assert len(self._comm_buffers) > 0
-            for buffer in self._comm_buffers:
-                buffer.scale_grads()
+            assert (
+                len(self._chunk_2_comm_buffers) > 0
+            ), "comm buffers should be created"
+            for _, buffers in self._chunk_2_comm_buffers.items():
+                for buffer in buffers:
+                    buffer.scale_grads()

         if self._enable_timer:
             self.timers("allreduce_shared_weight_gradients").start()

@@ -557,7 +564,7 @@ class PipelineParallel(MetaParallelBase):
         self._layers.train()

-        if self._sharding_comm_overlap and len(self._comm_buffers) == 0:
+        if self._sharding_comm_overlap and len(self._chunk_2_comm_buffers) == 0:
             self.register_allreduce_overlap_hook(
                 self._layers, self.sharding_group, self.accumulate_steps, False
             )
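A note on the grouping above: each virtual-pipeline chunk now gets its own list of fused buffers, and within a chunk the parameters are packed into buckets of at most group_size bytes before a FusedCommBuffer is built for each bucket. The sketch below reproduces that bucketing idea in plain Python; group_by_size here is a hypothetical stand-in for Paddle's assign_group_by_size, and the parameter names and sizes are made up.

# Illustrative sketch only: a hypothetical stand-in for size-based grouping,
# not Paddle's assign_group_by_size. It just shows the per-chunk bucketing idea.
from collections import defaultdict


def group_by_size(params, group_size):
    """Greedily pack (name, nbytes) pairs into groups of at most group_size bytes."""
    groups, current, current_bytes = [], [], 0
    for name, nbytes in params:
        if current and current_bytes + nbytes > group_size:
            groups.append(current)
            current, current_bytes = [], 0
        current.append(name)
        current_bytes += nbytes
    if current:
        groups.append(current)
    return groups


chunk_2_groups = defaultdict(list)  # mirrors _chunk_2_comm_buffers: chunk index -> buffer groups
model_chunks = [
    [("w0", 96 * 2**20), ("w1", 64 * 2**20)],  # chunk 0: (parameter name, size in bytes)
    [("w2", 200 * 2**20)],                     # chunk 1
]
for chunk_idx, params in enumerate(model_chunks):
    for group in group_by_size(params, group_size=128 * 1024 * 1024):
        chunk_2_groups[chunk_idx].append(group)

print(dict(chunk_2_groups))  # {0: [['w0'], ['w1']], 1: [['w2']]}

With the default group_size of 128 * 1024 * 1024 bytes a chunk can end up with several buffers; the interleaved scheduler shown later passes group_size=sys.maxsize so that size never forces a chunk's parameters to split across buffers.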
@@ -932,6 +939,39 @@ class PipelineParallelWithInterleave(PipelineParallel):
         return output_tensor

+    def _overlap_comm_grads(self):
+        if self._comm_overlap:
+            self._backward_step_count += 1
+            sync_step = self._backward_step_count - self.stage_id
+            if sync_step > 0 and sync_step % self.num_stages == 0:
+                chunk_idx = self._virtual_pp_world_size - (
+                    sync_step // self.num_stages
+                )
+                for buffer in self._chunk_2_comm_buffers[chunk_idx]:
+                    buffer.comm_grads()
+
+            if self.stage_id != 0:
+                if (
+                    self._backward_step_count
+                    == self.num_stages * self.num_model_chunks
+                ):
+                    for buffer in self._chunk_2_comm_buffers[0]:
+                        buffer.comm_grads()
+
+    def _sync_overlap_grads(self):
+        if self._comm_overlap:
+            assert (
+                self._backward_step_count
+                == self.num_stages * self.num_model_chunks
+            ), (
+                "backward step count should be equal to accumulate steps * virtual pp world size,"
+                f" but got {self._backward_step_count}, expected result is {self.num_stages * self.num_model_chunks}"
+            )
+
+            for _, buffers in self._chunk_2_comm_buffers.items():
+                for buffer in buffers:
+                    buffer.scale_grads()
+
     def _backward_step_helper(self, micro_step):
         virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
         self.set_virtual_pipeline_rank(virtual_pp_rank)

@@ -955,8 +995,24 @@ class PipelineParallelWithInterleave(PipelineParallel):
             input_tensor, output_tensor, output_tensor_grad
         )

+        self._overlap_comm_grads()
+
         return input_tensor_grad

+    def bw_hook_func(self, buffer, param):
+        # For pipeline with interleave, we need to add grad to buffer without communication.
+        # Use communication where appropriate to avoid dp communication and pp scheduling conflicts.
+        @paddle.autograd.no_grad()
+        def fused_allreduce(*_):
+            buffer.add_grad(param, use_comm=False)
+
+        return fused_allreduce
+
+    def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
+        super().register_allreduce_overlap_hook(
+            model, comm_group, acc_steps, dp, group_size=sys.maxsize
+        )
+
     def forward_backward_pipeline(
         self,
         data,

@@ -995,6 +1051,19 @@ class PipelineParallelWithInterleave(PipelineParallel):
         self.micro_batch_id = 0
         self._forward_only = forward_only

+        # store the number of backward steps
+        assert (
+            self.accumulate_steps % self.num_stages == 0
+        ), "accumulate_steps({}) should be evenly divisible by num_stages({}) for pipeline with interleave".format(
+            self.accumulate_steps, self.num_stages
+        )
+        per_stage_accumulate_steps = self.accumulate_steps // self.num_stages
+        self._backward_step_count = (
+            -(per_stage_accumulate_steps - 1)
+            * self.num_stages
+            * self.num_model_chunks
+        )
+
         # init some data buffers for interleave scheduler
         self.input_tensors = [[] for _ in range(self.num_model_chunks)]
         self.output_tensors = [[] for _ in range(self.num_model_chunks)]

@@ -1254,10 +1323,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
                     )
                 )

-        if self._comm_overlap:
-            assert len(self._comm_buffers) > 0
-            for buffer in self._comm_buffers:
-                buffer.scale_grads()
+        self._sync_overlap_grads()

         if static_scheduler:
             self._reset_counter()

...
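The counter logic added in _overlap_comm_grads and forward_backward_pipeline above can be replayed without any Paddle objects. The toy below uses made-up values for num_stages, num_model_chunks and accumulate_steps, and treats num_model_chunks as a stand-in for _virtual_pp_world_size (the two are assumed equal here); it is a sketch of the counting scheme, not the scheduler itself. The counter starts negative so that only the final accumulation round, where gradients are complete, can trigger communication.

# Illustrative sketch only: replays the backward-step counting for a toy config.
num_stages = 4          # pipeline parallel degree (assumed value)
num_model_chunks = 2    # virtual pipeline chunks per stage (assumed value)
accumulate_steps = 8    # must be evenly divisible by num_stages

per_stage_accumulate_steps = accumulate_steps // num_stages
# Negative start: earlier accumulation rounds only bump the counter and never
# satisfy the trigger condition below.
backward_step_count = (
    -(per_stage_accumulate_steps - 1) * num_stages * num_model_chunks
)


def chunk_ready(stage_id, count):
    """Return the chunk index whose buffers may start communicating, if any."""
    sync_step = count - stage_id
    if sync_step > 0 and sync_step % num_stages == 0:
        return num_model_chunks - sync_step // num_stages
    return None


stage_id = 1
for _ in range(accumulate_steps * num_model_chunks):
    backward_step_count += 1
    ready = chunk_ready(stage_id, backward_step_count)
    if ready is not None:
        print(f"backward step {backward_step_count}: chunk {ready} starts its all-reduce")
    # Tail case from the diff: non-zero stages flush chunk 0 on the last step.
    if stage_id != 0 and backward_step_count == num_stages * num_model_chunks:
        print(f"backward step {backward_step_count}: chunk 0 starts its all-reduce")

# Matches the assertion in _sync_overlap_grads before gradients are scaled.
assert backward_step_count == num_stages * num_model_chunks

For stage 1 in this toy run, chunk 1's buffers start communicating at backward step 5 and chunk 0's at step 8, the latter via the stage_id != 0 tail case shown in the diff.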
@@ -206,7 +206,7 @@ class FusedCommBuffer:
             and len(self._params_step_dict) == 0
         )

-    def add_grad(self, param):
+    def add_grad(self, param, use_comm=True):
         assert param.name in self._params_step_dict
         current_ptr = (
             param.main_grad.data_ptr()

@@ -227,12 +227,17 @@ class FusedCommBuffer:
             self._params_checked_in += 1
             self._params_step_dict.pop(param.name)

-        if self._all_params_checked_in:
-            self._comm_grads()
+        if self._all_params_checked_in and use_comm:
+            self.comm_grads()

     @imperative_base.no_grad
-    def _comm_grads(self):
-        assert self._all_params_checked_in
+    def comm_grads(self):
+        assert self._all_params_checked_in, (
+            "Not all params checked in. "
+            "Parameter number: {}, Check-in number: {}".format(
+                len(self._params), self._params_checked_in
+            )
+        )

         if not self._scale_after_comm:
             scale_factor = 1.0 / self._comm_group.nranks

@@ -255,7 +260,7 @@ class FusedCommBuffer:
     @imperative_base.no_grad
     def scale_grads(self):
-        assert self._task is not None
+        assert self._task is not None, "Task is not initialized."
         self._task.wait()

         if self._scale_after_comm:

...
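Taken together, the FusedCommBuffer changes implement a check-in protocol: every parameter reports its gradient once per accumulation cycle, and the last check-in either launches the collective (use_comm=True) or defers it (use_comm=False) so the pipeline scheduler can call comm_grads() and scale_grads() at points of its own choosing. The class below is a hypothetical toy that mirrors the shape of that protocol; it is not Paddle's FusedCommBuffer and performs no real communication.

# Illustrative sketch only: a minimal, hypothetical buffer mimicking the
# check-in / deferred-communication protocol from the diff above.
class ToyCommBuffer:
    def __init__(self, param_names, nranks):
        self._params = set(param_names)
        self._checked_in = set()
        self._nranks = nranks
        self._task = None

    @property
    def _all_checked_in(self):
        return self._checked_in == self._params

    def add_grad(self, name, use_comm=True):
        # A backward hook calls this once per parameter when its gradient is ready.
        assert name in self._params and name not in self._checked_in
        self._checked_in.add(name)
        if self._all_checked_in and use_comm:
            self.comm_grads()

    def comm_grads(self):
        assert self._all_checked_in, (
            f"Not all params checked in. Parameter number: {len(self._params)}, "
            f"Check-in number: {len(self._checked_in)}"
        )
        self._task = "pending-allreduce"  # stand-in for an async collective handle

    def scale_grads(self):
        assert self._task is not None, "Task is not initialized."
        # A real buffer would wait on the task here, then scale by 1 / nranks.
        self._task = None


buf = ToyCommBuffer(["w0", "w1"], nranks=8)
buf.add_grad("w0", use_comm=False)  # check in without launching communication
buf.add_grad("w1", use_comm=False)
buf.comm_grads()                    # the scheduler decides when this happens
buf.scale_grads()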