Unverified · Commit 8416465e authored by Jiabin Yang, committed by GitHub

[Eager] Support eager all_reduce return value (#6140)

* support eager all_reduce return value

* revert file

* fix erroneous logic

* support simota head in eager
Parent cfb3699c
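Context for the diff below: in Paddle's eager dygraph mode, paddle.distributed.all_reduce returns the reduced tensor, while in the older mode it reduces its argument in place and returns None. Every hunk in this commit therefore applies the same pattern: clone the tensor, all-reduce the clone, use the return value when it is not None, and otherwise fall back to the in-place-reduced clone. A minimal sketch of that pattern as a helper (the name _all_reduce_compat is hypothetical; it is not part of Paddle or of this commit, which instead inlines the check at every call site):

    import paddle
    import paddle.distributed as dist

    def _all_reduce_compat(tensor):
        # Clone first so the caller's tensor is never mutated by the
        # in-place code path.
        clone = tensor.clone()
        # Eager mode returns the reduced tensor; the legacy mode reduces
        # `clone` in place and returns None.
        reduced = dist.all_reduce(clone)
        return reduced if reduced is not None else clone

With such a helper, each call site would collapse to, e.g., avg_factor = paddle.clip(_all_reduce_compat(avg_factor) / dist.get_world_size(), min=1); the commit keeps the check inline, which makes the diff mechanical but repetitive.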
@@ -388,7 +388,12 @@ class GFLHead(nn.Layer):
         avg_factor = sum(avg_factor)
         try:
-            avg_factor = paddle.distributed.all_reduce(avg_factor.clone())
+            avg_factor_clone = avg_factor.clone()
+            tmp_avg_factor = paddle.distributed.all_reduce(avg_factor_clone)
+            if tmp_avg_factor is not None:
+                avg_factor = tmp_avg_factor
+            else:
+                avg_factor = avg_factor_clone
             avg_factor = paddle.clip(
                 avg_factor / paddle.distributed.get_world_size(), min=1)
         except:
...
@@ -179,8 +179,15 @@ class OTAHead(GFLHead):
             num_level_anchors)
         num_total_pos = sum(pos_num_l)
         try:
-            num_total_pos = paddle.distributed.all_reduce(num_total_pos.clone(
-            )) / paddle.distributed.get_world_size()
+            cloned_num_total_pos = num_total_pos.clone()
+            reduced_cloned_num_total_pos = paddle.distributed.all_reduce(
+                cloned_num_total_pos)
+            if reduced_cloned_num_total_pos is not None:
+                num_total_pos = reduced_cloned_num_total_pos / paddle.distributed.get_world_size(
+                )
+            else:
+                num_total_pos = cloned_num_total_pos / paddle.distributed.get_world_size(
+                )
         except:
             num_total_pos = max(num_total_pos, 1)
@@ -255,7 +262,12 @@ class OTAHead(GFLHead):
         avg_factor = sum(avg_factor)
         try:
-            avg_factor = paddle.distributed.all_reduce(avg_factor.clone())
+            avg_factor_clone = avg_factor.clone()
+            tmp_avg_factor = paddle.distributed.all_reduce(avg_factor_clone)
+            if tmp_avg_factor is not None:
+                avg_factor = tmp_avg_factor
+            else:
+                avg_factor = avg_factor_clone
             avg_factor = paddle.clip(
                 avg_factor / paddle.distributed.get_world_size(), min=1)
         except:
@@ -396,8 +408,15 @@ class OTAVFLHead(OTAHead):
             num_level_anchors)
         num_total_pos = sum(pos_num_l)
         try:
-            num_total_pos = paddle.distributed.all_reduce(num_total_pos.clone(
-            )) / paddle.distributed.get_world_size()
+            cloned_num_total_pos = num_total_pos.clone()
+            reduced_cloned_num_total_pos = paddle.distributed.all_reduce(
+                cloned_num_total_pos)
+            if reduced_cloned_num_total_pos is not None:
+                num_total_pos = reduced_cloned_num_total_pos / paddle.distributed.get_world_size(
+                )
+            else:
+                num_total_pos = cloned_num_total_pos / paddle.distributed.get_world_size(
+                )
         except:
             num_total_pos = max(num_total_pos, 1)
@@ -475,7 +494,12 @@ class OTAVFLHead(OTAHead):
         avg_factor = sum(avg_factor)
         try:
-            avg_factor = paddle.distributed.all_reduce(avg_factor.clone())
+            avg_factor_clone = avg_factor.clone()
+            tmp_avg_factor = paddle.distributed.all_reduce(avg_factor_clone)
+            if tmp_avg_factor is not None:
+                avg_factor = tmp_avg_factor
+            else:
+                avg_factor = avg_factor_clone
             avg_factor = paddle.clip(
                 avg_factor / paddle.distributed.get_world_size(), min=1)
         except:
...