未验证 提交 3a16314e 编写于 作者: U ucsk 提交者: GitHub

Fix all_reduce usage: reduce in place and clip the averaged count to a minimum of 1 (#7488)

上级 5171aa23
...@@ -311,8 +311,9 @@ class GFLHead(nn.Layer): ...@@ -311,8 +311,9 @@ class GFLHead(nn.Layer):
num_level_anchors) num_level_anchors)
num_total_pos = sum(gt_meta['pos_num']) num_total_pos = sum(gt_meta['pos_num'])
try: try:
num_total_pos = paddle.distributed.all_reduce(num_total_pos.clone( paddle.distributed.all_reduce(num_total_pos)
)) / paddle.distributed.get_world_size() num_total_pos = paddle.clip(
num_total_pos / paddle.distributed.get_world_size(), min=1)
except: except:
num_total_pos = max(num_total_pos, 1) num_total_pos = max(num_total_pos, 1)
......
...@@ -153,8 +153,9 @@ class LDGFLHead(GFLHead): ...@@ -153,8 +153,9 @@ class LDGFLHead(GFLHead):
num_total_pos = sum(gt_meta['pos_num']) num_total_pos = sum(gt_meta['pos_num'])
try: try:
num_total_pos = paddle.distributed.all_reduce(num_total_pos.clone( paddle.distributed.all_reduce(num_total_pos)
)) / paddle.distributed.get_world_size() num_total_pos = paddle.clip(
num_total_pos / paddle.distributed.get_world_size(), min=1.)
except: except:
num_total_pos = max(num_total_pos, 1) num_total_pos = max(num_total_pos, 1)
...@@ -293,12 +294,7 @@ class LDGFLHead(GFLHead): ...@@ -293,12 +294,7 @@ class LDGFLHead(GFLHead):
avg_factor = sum(avg_factor) # + 1e-6 avg_factor = sum(avg_factor) # + 1e-6
try: try:
avg_factor_clone = avg_factor.clone() paddle.distributed.all_reduce(avg_factor)
tmp_avg_factor = paddle.distributed.all_reduce(avg_factor_clone)
if tmp_avg_factor is not None:
avg_factor = tmp_avg_factor
else:
avg_factor = avg_factor_clone
avg_factor = paddle.clip( avg_factor = paddle.clip(
avg_factor / paddle.distributed.get_world_size(), min=1) avg_factor / paddle.distributed.get_world_size(), min=1)
except: except:
......
...@@ -180,7 +180,8 @@ class OTAHead(GFLHead): ...@@ -180,7 +180,8 @@ class OTAHead(GFLHead):
num_total_pos = sum(pos_num_l) num_total_pos = sum(pos_num_l)
try: try:
paddle.distributed.all_reduce(num_total_pos) paddle.distributed.all_reduce(num_total_pos)
num_total_pos = num_total_pos / paddle.distributed.get_world_size() num_total_pos = paddle.clip(
num_total_pos / paddle.distributed.get_world_size(), min=1.)
except: except:
num_total_pos = max(num_total_pos, 1) num_total_pos = max(num_total_pos, 1)
...@@ -397,7 +398,8 @@ class OTAVFLHead(OTAHead): ...@@ -397,7 +398,8 @@ class OTAVFLHead(OTAHead):
num_total_pos = sum(pos_num_l) num_total_pos = sum(pos_num_l)
try: try:
paddle.distributed.all_reduce(num_total_pos) paddle.distributed.all_reduce(num_total_pos)
num_total_pos = num_total_pos / paddle.distributed.get_world_size() num_total_pos = paddle.clip(
num_total_pos / paddle.distributed.get_world_size(), min=1.)
except: except:
num_total_pos = max(num_total_pos, 1) num_total_pos = max(num_total_pos, 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册