Unverified · Commit 7c764060, authored by sneaxiy, committed by GitHub

Align VPP global norm clip with PP (#54820)

Parent cb1a50f5
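
The change makes gradient-norm clipping under the virtual (interleaved) pipeline schedule sum per-layer squared gradients in the same order as plain pipeline parallelism: each rank gathers the per-chunk partial sums from every pipeline stage and keeps only the slices that would sit on its own stage under non-interleaved PP. The behaviour is gated by the FLAGS_force_align_vpp_grad_sum_order environment variable (default '1' in this diff). The snippet below is a minimal, self-contained sketch of the index mapping the diff relies on (layer_id = chunk_id * pp_size + pp_rank, owning stage = layer_id // chunk_num); the sizes are hypothetical and nothing here calls the Paddle API.

pp_size = 2      # hypothetical number of pipeline stages
chunk_num = 2    # hypothetical virtual chunks per stage

for pp_rank in range(pp_size):
    owned = []
    for src_rank in range(pp_size):
        for chunk_id in range(chunk_num):
            # same encoding as the diff: position of this chunk in plain-PP layer order
            layer_id = chunk_id * pp_size + src_rank
            # the stage that would hold this slice under non-interleaved PP
            if layer_id // chunk_num == pp_rank:
                owned.append((layer_id, src_rank))
    print(f"stage {pp_rank} accumulates {sorted(owned)}")

# stage 0 accumulates [(0, 0), (1, 1)]  -> layer_ids 0 and 1
# stage 1 accumulates [(2, 0), (3, 1)]  -> layer_ids 2 and 3
# i.e. each stage sums exactly the layers it would own under plain PP,
# so the floating-point summation order of the global norm matches PP.
# To fall back to the old behaviour: export FLAGS_force_align_vpp_grad_sum_order=0
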
@@ -13,6 +13,11 @@
# limitations under the License.
import distutils.util
import os
import numpy as np
import paddle
from paddle import framework
from paddle.autograd import no_grad
@@ -42,9 +47,131 @@ class HybridParallelClipGrad:
self._clip = clip
self._hcg = hcg
self.not_sharding_stage1 = True
self._vpp_chunk_num = None
        # When enabled (default '1'), align the VPP global-norm summation order with plain PP.
        self._force_align_vpp_grad_sum_order = distutils.util.strtobool(
            os.getenv('FLAGS_force_align_vpp_grad_sum_order', '1')
        )

    def _get_vpp_chunk_num(self, params_grads):
        # Return the chunk count recorded on the parameters, or -1 when no
        # chunk info is present (i.e. virtual pipeline parallelism is off).
chunk_num = -1
for p, g in params_grads:
if g is None:
continue
chunk_info = getattr(p, '_chunk_info', {})
cur_chunk_num = chunk_info.get('chunk_num', -1)
if chunk_num < 0:
chunk_num = cur_chunk_num
else:
                assert (
                    chunk_num == cur_chunk_num
                ), "All parameters must share the same chunk_num"
return chunk_num

    @no_grad()
    def _vpp_dygraph_clip(self, params_grads, chunk_num):
        # VPP-aware path: every pipeline rank computes a squared sum per model
        # chunk, the partial sums are exchanged across the pipeline group, and
        # each rank accumulates only the layers it would own under plain PP,
        # so the global norm is summed in the same order as non-interleaved PP.
        pp_group = self._hcg.get_pipe_parallel_group()
        pp_rank = self._hcg.get_stage_id()
        pp_size = self._hcg.get_pipe_parallel_world_size()

        if self._vpp_chunk_num is None:
            # First call: check that every pipeline rank reports the same
            # number of virtual chunks, then cache it.
            all_chunk_nums = []
            paddle.distributed.all_gather_object(
                all_chunk_nums, chunk_num, group=pp_group
            )
            assert all(
                chunk_num == n for n in all_chunk_nums
            ), "chunk_num must be identical on all pipeline ranks"
            self._vpp_chunk_num = chunk_num
        else:
            assert self._vpp_chunk_num == chunk_num
sum_square_metas = []
for p, g in params_grads:
if g is None:
continue
            # Parameters shared across pipeline stages contribute to the norm
            # only on the rank that firstly holds them.
            not_shared_enable = (not hasattr(p, 'is_firstly_shared')) or (
                hasattr(p, 'is_firstly_shared')
                and getattr(p, 'is_firstly_shared', True)
            )
            chunk_id = p._chunk_info['chunk_id']

            if not_shared_enable:
if g.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = clip.merge_selected_rows(g)
g = clip.get_tensor_from_selected_rows(merge_grad)
square = paddle.square(g)
sum_square = paddle.sum(square)
                layer_id = chunk_id * pp_size + pp_rank
                # layer_id is this chunk's position in plain-PP layer order;
                # ship the squared sum (as numpy) along with that position.
                sum_square_metas.append(
                    [layer_id, p.is_distributed, sum_square.numpy()]
                )

        # Exchange the per-chunk partial sums across all pipeline stages.
        all_sum_square_metas = []
        paddle.distributed.all_gather_object(
            all_sum_square_metas,
            sum_square_metas,
            group=pp_group,
        )
        # Bucket order: FP16, BF16, FP32.
        sum_square_dist = [[], [], []]
        sum_square_not_dist = [[], [], []]
        for metas in all_sum_square_metas:
            for layer_id, is_distributed, sum_square in metas:
                # The stage that would own this layer under non-interleaved PP;
                # only that stage accumulates it, reproducing PP's sum order.
                rank = layer_id // chunk_num
                assert rank < pp_size
                if rank != pp_rank:
                    continue
if sum_square.dtype == np.float32:
idx = 2
elif sum_square.dtype == np.float16:
idx = 0
else:
assert (
sum_square.dtype == np.uint16
), "The data type of grad must be FP32, FP16 or BF16, but got {}".format(
sum_square.dtype
)
idx = 1
if is_distributed:
sum_square_dist[idx].append(sum_square)
else:
sum_square_not_dist[idx].append(sum_square)
global_norm_var_dist = self._add_sum_squares(sum_square_dist)
global_norm_var_not_dist = self._add_sum_squares(sum_square_not_dist)
return self._comm_and_clip(
params_grads, global_norm_var_dist, global_norm_var_not_dist
)

    def _add_sum_squares(self, sum_squares):
        # Reduce the per-dtype buckets of squared sums into one FP32 scalar.
        norm_sum = None
for sq in sum_squares:
if len(sq) == 0:
continue
sq = np.concatenate(sq, axis=0).flatten()
sq = paddle.to_tensor(sq)
sq = paddle.sum(sq)
if sq.dtype != paddle.float32:
sq = sq.astype(paddle.float32)
if norm_sum is None:
norm_sum = sq
else:
norm_sum = norm_sum + sq
if norm_sum is None:
norm_sum = paddle.to_tensor([0.0], dtype=paddle.float32)
return norm_sum

    @no_grad()
    def _dygraph_clip(self, params_grads):
        chunk_num = self._get_vpp_chunk_num(params_grads)
        if chunk_num > 0 and self._force_align_vpp_grad_sum_order:
            # Parameters carry VPP chunk info and alignment is enabled.
            return self._vpp_dygraph_clip(params_grads, chunk_num)

sum_square_dist_fp16 = []
sum_square_dist_bf16 = []
sum_square_dist_fp32 = []
@@ -159,7 +286,13 @@ class HybridParallelClipGrad:
+ global_norm_not_dist_bf16
+ global_norm_not_dist_fp32
)
        return self._comm_and_clip(
            params_grads, global_norm_var_dist, global_norm_var_not_dist
        )

    def _comm_and_clip(
        self, params_grads, global_norm_var_dist, global_norm_var_not_dist
    ):
# add all reduce to get global norm of distributed params_and_grads
if self._hcg.get_model_parallel_world_size() > 1:
sharding_flag = False
......
@@ -638,6 +638,13 @@ class PipelineLayer(nn.Layer):
logger.info(f"loss: {self._loss_fn.__class__.__name__}")

    def _build_layer_with_interleave(self):
        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
            get_rng_state_tracker,
        )

        # Snapshot the RNG state once, build every virtual chunk back to back,
        # then restore the state after all chunks are built.
        orig_rng_state = paddle.get_rng_state()
        orig_rng_tracker = get_rng_state_tracker().get_states_tracker()
for i in range(len(self._start_poss)):
start = self._start_poss[i]
end = self._end_poss[i]
@@ -648,10 +655,21 @@
self._model_chunks.append(chunk)
self.add_sublayer(str(start), chunk)
paddle.set_rng_state(orig_rng_state)
get_rng_state_tracker().set_states_tracker(orig_rng_tracker)
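
    # Note: the RNG snapshot/restore used to live inside _build_layer_impl
    # (the copies still visible in that method below are the lines this diff
    # removes). With interleaving, that meant every chunk was built from an
    # identical RNG snapshot; hoisting the save/restore to the callers lets
    # consecutive chunks consume the RNG stream the way a single
    # non-interleaved stage building the same layers would, which presumably
    # keeps VPP parameter initialization aligned with PP.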

    def _build_layer(self):
        start = self._start_pos
        end = self._end_pos

        from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
            get_rng_state_tracker,
        )

        # Same pattern as above: keep the global RNG state untouched by
        # layer building.
        orig_rng_state = paddle.get_rng_state()
        orig_rng_tracker = get_rng_state_tracker().get_states_tracker()
self.run_function = self._build_layer_impl(start, end)
paddle.set_rng_state(orig_rng_state)
get_rng_state_tracker().set_states_tracker(orig_rng_tracker)

    def _build_layer_impl(self, start, end):
        if self._num_virtual_pipeline_stages > 1:
@@ -661,13 +679,6 @@
# For 1f1b scheduler, just use run_function list
run_function = self.run_function
from paddle.distributed.fleet.meta_parallel.parallel_layers.random import (
get_rng_state_tracker,
)
orig_rng_state = paddle.get_rng_state()
orig_rng_tracker = get_rng_state_tracker().get_states_tracker()
for index, layer in enumerate(self._layers_desc[start:end]):
layer_index = start + index
@@ -722,8 +733,6 @@
else:
run_function.append(layer)
paddle.set_rng_state(orig_rng_state)
get_rng_state_tracker().set_states_tracker(orig_rng_tracker)
        return run_function

    def forward_function(self, start, end):
......
@@ -709,6 +709,13 @@ class PipelineParallelWithInterleave(PipelineParallel):
assert len(self.model_chunks) == self.num_model_chunks
self._virtual_pp_world_size = self.num_model_chunks
self._virtual_pp_rank = 0
        self._assign_vpp_info(self.model_chunks)

    def _assign_vpp_info(self, chunks):
        # Tag every parameter with its chunk id and the total chunk count so
        # the gradient-clip pass can recover the plain-PP layer order.
        chunk_num = len(chunks)
        for i, chunk in enumerate(chunks):
for p in chunk.parameters():
p._chunk_info = {"chunk_id": i, "chunk_num": chunk_num}
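
        # For illustration (hypothetical sizes, not taken from the patch):
        # with two virtual chunks per stage, a parameter in the second chunk
        # ends up tagged as
        #     p._chunk_info == {"chunk_id": 1, "chunk_num": 2}
        # which HybridParallelClipGrad reads back when computing the aligned
        # global norm.
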
def _get_virtual_pp_rank(self, micro_step, forward):
virtual_pp_stage = micro_step % (
......