Unverified · Commit b0620a7b, authored by Guanghua Yu, committed via GitHub

correct the use of all_reduce (#7108) (#7199)

Parent: e066d8d1
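The removed code cloned the tensor, captured the return value of `paddle.distributed.all_reduce`, and branched on whether that value was None; the replacement calls `all_reduce` directly on the tensor, relying on the collective reducing it in place (SUM by default), and then divides by the world size. Below is a minimal sketch of the corrected pattern, assuming a multi-GPU job launched with `paddle.distributed.launch`; the tensor value is illustrative, not taken from the repo.

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()

    # Per-rank normalization factor, analogous to avg_factor in the heads patched below.
    avg_factor = paddle.to_tensor(3.0)

    # all_reduce sums the tensor across ranks in place; the return value is ignored.
    dist.all_reduce(avg_factor)

    # Average over ranks and keep the factor at least 1, mirroring the new code.
    avg_factor = paddle.clip(avg_factor / dist.get_world_size(), min=1)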
@@ -388,12 +388,7 @@ class GFLHead(nn.Layer):
         avg_factor = sum(avg_factor)
         try:
-            avg_factor_clone = avg_factor.clone()
-            tmp_avg_factor = paddle.distributed.all_reduce(avg_factor_clone)
-            if tmp_avg_factor is not None:
-                avg_factor = tmp_avg_factor
-            else:
-                avg_factor = avg_factor_clone
+            paddle.distributed.all_reduce(avg_factor)
             avg_factor = paddle.clip(
                 avg_factor / paddle.distributed.get_world_size(), min=1)
         except:
@@ -179,15 +179,8 @@ class OTAHead(GFLHead):
             num_level_anchors)
         num_total_pos = sum(pos_num_l)
         try:
-            cloned_num_total_pos = num_total_pos.clone()
-            reduced_cloned_num_total_pos = paddle.distributed.all_reduce(
-                cloned_num_total_pos)
-            if reduced_cloned_num_total_pos is not None:
-                num_total_pos = reduced_cloned_num_total_pos / paddle.distributed.get_world_size(
-                )
-            else:
-                num_total_pos = cloned_num_total_pos / paddle.distributed.get_world_size(
-                )
+            paddle.distributed.all_reduce(num_total_pos)
+            num_total_pos = num_total_pos / paddle.distributed.get_world_size()
         except:
             num_total_pos = max(num_total_pos, 1)
@@ -262,12 +255,7 @@ class OTAHead(GFLHead):
         avg_factor = sum(avg_factor)
         try:
-            avg_factor_clone = avg_factor.clone()
-            tmp_avg_factor = paddle.distributed.all_reduce(avg_factor_clone)
-            if tmp_avg_factor is not None:
-                avg_factor = tmp_avg_factor
-            else:
-                avg_factor = avg_factor_clone
+            paddle.distributed.all_reduce(avg_factor)
             avg_factor = paddle.clip(
                 avg_factor / paddle.distributed.get_world_size(), min=1)
         except:
@@ -408,15 +396,8 @@ class OTAVFLHead(OTAHead):
             num_level_anchors)
         num_total_pos = sum(pos_num_l)
         try:
-            cloned_num_total_pos = num_total_pos.clone()
-            reduced_cloned_num_total_pos = paddle.distributed.all_reduce(
-                cloned_num_total_pos)
-            if reduced_cloned_num_total_pos is not None:
-                num_total_pos = reduced_cloned_num_total_pos / paddle.distributed.get_world_size(
-                )
-            else:
-                num_total_pos = cloned_num_total_pos / paddle.distributed.get_world_size(
-                )
+            paddle.distributed.all_reduce(num_total_pos)
+            num_total_pos = num_total_pos / paddle.distributed.get_world_size()
         except:
             num_total_pos = max(num_total_pos, 1)
@@ -494,12 +475,7 @@ class OTAVFLHead(OTAHead):
         avg_factor = sum(avg_factor)
         try:
-            avg_factor_clone = avg_factor.clone()
-            tmp_avg_factor = paddle.distributed.all_reduce(avg_factor_clone)
-            if tmp_avg_factor is not None:
-                avg_factor = tmp_avg_factor
-            else:
-                avg_factor = avg_factor_clone
+            paddle.distributed.all_reduce(avg_factor)
             avg_factor = paddle.clip(
                 avg_factor / paddle.distributed.get_world_size(), min=1)
         except: