未验证 提交 8416465e 编写于 作者: J Jiabin Yang 提交者: GitHub

[Eager] Support eager all_reducer return value (#6140)

* support eager all_reducer return value

* revert file

* fix error logic

* support simota head in eager
上级 cfb3699c
......@@ -388,7 +388,12 @@ class GFLHead(nn.Layer):
avg_factor = sum(avg_factor)
try:
avg_factor = paddle.distributed.all_reduce(avg_factor.clone())
avg_factor_clone = avg_factor.clone()
tmp_avg_factor = paddle.distributed.all_reduce(avg_factor_clone)
if tmp_avg_factor is not None:
avg_factor = tmp_avg_factor
else:
avg_factor = avg_factor_clone
avg_factor = paddle.clip(
avg_factor / paddle.distributed.get_world_size(), min=1)
except:
......
......@@ -179,8 +179,15 @@ class OTAHead(GFLHead):
num_level_anchors)
num_total_pos = sum(pos_num_l)
try:
num_total_pos = paddle.distributed.all_reduce(num_total_pos.clone(
)) / paddle.distributed.get_world_size()
cloned_num_total_pos = num_total_pos.clone()
reduced_cloned_num_total_pos = paddle.distributed.all_reduce(
cloned_num_total_pos)
if reduced_cloned_num_total_pos is not None:
num_total_pos = reduced_cloned_num_total_pos / paddle.distributed.get_world_size(
)
else:
num_total_pos = cloned_num_total_pos / paddle.distributed.get_world_size(
)
except:
num_total_pos = max(num_total_pos, 1)
......@@ -255,7 +262,12 @@ class OTAHead(GFLHead):
avg_factor = sum(avg_factor)
try:
avg_factor = paddle.distributed.all_reduce(avg_factor.clone())
avg_factor_clone = avg_factor.clone()
tmp_avg_factor = paddle.distributed.all_reduce(avg_factor_clone)
if tmp_avg_factor is not None:
avg_factor = tmp_avg_factor
else:
avg_factor = avg_factor_clone
avg_factor = paddle.clip(
avg_factor / paddle.distributed.get_world_size(), min=1)
except:
......@@ -396,8 +408,15 @@ class OTAVFLHead(OTAHead):
num_level_anchors)
num_total_pos = sum(pos_num_l)
try:
num_total_pos = paddle.distributed.all_reduce(num_total_pos.clone(
)) / paddle.distributed.get_world_size()
cloned_num_total_pos = num_total_pos.clone()
reduced_cloned_num_total_pos = paddle.distributed.all_reduce(
cloned_num_total_pos)
if reduced_cloned_num_total_pos is not None:
num_total_pos = reduced_cloned_num_total_pos / paddle.distributed.get_world_size(
)
else:
num_total_pos = cloned_num_total_pos / paddle.distributed.get_world_size(
)
except:
num_total_pos = max(num_total_pos, 1)
......@@ -475,7 +494,12 @@ class OTAVFLHead(OTAHead):
avg_factor = sum(avg_factor)
try:
avg_factor = paddle.distributed.all_reduce(avg_factor.clone())
avg_factor_clone = avg_factor.clone()
tmp_avg_factor = paddle.distributed.all_reduce(avg_factor_clone)
if tmp_avg_factor is not None:
avg_factor = tmp_avg_factor
else:
avg_factor = avg_factor_clone
avg_factor = paddle.clip(
avg_factor / paddle.distributed.get_world_size(), min=1)
except:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册