未验证 提交 fa878846 编写于 作者: Y Yuang Liu 提交者: GitHub

cherry pick #55651 and #55890 (#56063)

上级 0d920178
......@@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import os
import sys
import time
import warnings
from collections import defaultdict
import paddle
from paddle import framework
......@@ -217,7 +219,7 @@ class PipelineParallel(MetaParallelBase):
self._dp_comm_overlap and self._sharding_comm_overlap
), "Cannot use dp pp overlap and sharding pp overlap at the same time."
self._comm_buffers = []
self._chunk_2_comm_buffers = defaultdict(list)
self._comm_overlap = (
self._dp_comm_overlap or self._sharding_comm_overlap
)
......@@ -291,7 +293,9 @@ class PipelineParallel(MetaParallelBase):
return fused_allreduce
def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
def register_allreduce_overlap_hook(
self, model, comm_group, acc_steps, dp, group_size=128 * 1024 * 1024
):
if model.get_num_virtual_stages() > 1:
models = model.get_model_chunks()
else:
......@@ -308,7 +312,7 @@ class PipelineParallel(MetaParallelBase):
else HOOK_ACTION.REDUCE
)
for model in models:
for chunk_idx, model in enumerate(models):
# For virtual pipeline. Will separate parameters in different chunk into
# different groups to get the best performance.
......@@ -338,12 +342,12 @@ class PipelineParallel(MetaParallelBase):
dst = comm_group.ranks[dst]
else:
dst = -1
var_groups = assign_group_by_size(parameter_list)
var_groups = assign_group_by_size(parameter_list, group_size)
for group_idx, parameters in var_groups.items():
buffer = FusedCommBuffer(
group_idx, parameters, comm_group, acc_steps, act, dst
)
self._comm_buffers.append(buffer)
self._chunk_2_comm_buffers[chunk_idx].append(buffer)
for param in parameters:
param._register_backward_hook(
self.bw_hook_func(buffer, param)
......@@ -514,8 +518,11 @@ class PipelineParallel(MetaParallelBase):
self._flush_records()
if self._comm_overlap:
assert len(self._comm_buffers) > 0
for buffer in self._comm_buffers:
assert (
len(self._chunk_2_comm_buffers) > 0
), "comm buffers should be created"
for _, buffers in self._chunk_2_comm_buffers.items():
for buffer in buffers:
buffer.scale_grads()
if self._enable_timer:
......@@ -557,7 +564,7 @@ class PipelineParallel(MetaParallelBase):
self._layers.train()
if self._sharding_comm_overlap and len(self._comm_buffers) == 0:
if self._sharding_comm_overlap and len(self._chunk_2_comm_buffers) == 0:
self.register_allreduce_overlap_hook(
self._layers, self.sharding_group, self.accumulate_steps, False
)
......@@ -932,6 +939,39 @@ class PipelineParallelWithInterleave(PipelineParallel):
return output_tensor
def _overlap_comm_grads(self):
if self._comm_overlap:
self._backward_step_count += 1
sync_step = self._backward_step_count - self.stage_id
if sync_step > 0 and sync_step % self.num_stages == 0:
chunk_idx = self._virtual_pp_world_size - (
sync_step // self.num_stages
)
for buffer in self._chunk_2_comm_buffers[chunk_idx]:
buffer.comm_grads()
if self.stage_id != 0:
if (
self._backward_step_count
== self.num_stages * self.num_model_chunks
):
for buffer in self._chunk_2_comm_buffers[0]:
buffer.comm_grads()
def _sync_overlap_grads(self):
if self._comm_overlap:
assert (
self._backward_step_count
== self.num_stages * self.num_model_chunks
), (
"backward step count should be equal to accumulate steps * virtual pp world size,"
f" but get {self._backward_step_count}, excepted result is {self.num_stages * self.num_model_chunks}"
)
for _, buffers in self._chunk_2_comm_buffers.items():
for buffer in buffers:
buffer.scale_grads()
def _backward_step_helper(self, micro_step):
virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=False)
self.set_virtual_pipeline_rank(virtual_pp_rank)
......@@ -955,8 +995,24 @@ class PipelineParallelWithInterleave(PipelineParallel):
input_tensor, output_tensor, output_tensor_grad
)
self._overlap_comm_grads()
return input_tensor_grad
def bw_hook_func(self, buffer, param):
# For pipeline with interleave, we need to add grad to buffer without communication.
# Use communication where appropriate to avoid dp communication and pp scheduling conflicts.
@paddle.autograd.no_grad()
def fused_allreduce(*_):
buffer.add_grad(param, use_comm=False)
return fused_allreduce
def register_allreduce_overlap_hook(self, model, comm_group, acc_steps, dp):
super().register_allreduce_overlap_hook(
model, comm_group, acc_steps, dp, group_size=sys.maxsize
)
def forward_backward_pipeline(
self,
data,
......@@ -995,6 +1051,19 @@ class PipelineParallelWithInterleave(PipelineParallel):
self.micro_batch_id = 0
self._forward_only = forward_only
# store the number of backward steps
assert (
self.accumulate_steps % self.num_stages == 0
), "accumulate_steps({}) should be evenly divisible by num_stages({}) for pipeline with interleave".format(
self.accumulate_steps, self.num_stages
)
per_stage_accumulate_steps = self.accumulate_steps // self.num_stages
self._backward_step_count = (
-(per_stage_accumulate_steps - 1)
* self.num_stages
* self.num_model_chunks
)
# init some data buffers for interleave scheduler
self.input_tensors = [[] for _ in range(self.num_model_chunks)]
self.output_tensors = [[] for _ in range(self.num_model_chunks)]
......@@ -1254,10 +1323,7 @@ class PipelineParallelWithInterleave(PipelineParallel):
)
)
if self._comm_overlap:
assert len(self._comm_buffers) > 0
for buffer in self._comm_buffers:
buffer.scale_grads()
self._sync_overlap_grads()
if static_scheduler:
self._reset_counter()
......
......@@ -206,7 +206,7 @@ class FusedCommBuffer:
and len(self._params_step_dict) == 0
)
def add_grad(self, param):
def add_grad(self, param, use_comm=True):
assert param.name in self._params_step_dict
current_ptr = (
param.main_grad.data_ptr()
......@@ -227,12 +227,17 @@ class FusedCommBuffer:
self._params_checked_in += 1
self._params_step_dict.pop(param.name)
if self._all_params_checked_in:
self._comm_grads()
if self._all_params_checked_in and use_comm:
self.comm_grads()
@imperative_base.no_grad
def _comm_grads(self):
assert self._all_params_checked_in
def comm_grads(self):
assert self._all_params_checked_in, (
"Not all params checked in."
"Parameter number: {}, Check-in number: {}".format(
len(self._params), self._params_checked_in
)
)
if not self._scale_after_comm:
scale_factor = 1.0 / self._comm_group.nranks
......@@ -255,7 +260,7 @@ class FusedCommBuffer:
@imperative_base.no_grad
def scale_grads(self):
assert self._task is not None
assert self._task is not None, "Task is not initialized."
self._task.wait()
if self._scale_after_comm:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册