From b0620a7b87826d9e37f14ade696b59c3c80c7ce3 Mon Sep 17 00:00:00 2001
From: Guanghua Yu <742925032@qq.com>
Date: Wed, 26 Oct 2022 20:36:24 +0800
Subject: [PATCH] correct the use of all_reduce (#7108) (#7199)

---
 ppdet/modeling/heads/gfl_head.py    |  7 +-----
 ppdet/modeling/heads/simota_head.py | 36 +++++------------------------
 2 files changed, 7 insertions(+), 36 deletions(-)

diff --git a/ppdet/modeling/heads/gfl_head.py b/ppdet/modeling/heads/gfl_head.py
index 9c87eecd8..aa6dc478b 100644
--- a/ppdet/modeling/heads/gfl_head.py
+++ b/ppdet/modeling/heads/gfl_head.py
@@ -388,12 +388,7 @@ class GFLHead(nn.Layer):
 
         avg_factor = sum(avg_factor)
         try:
-            avg_factor_clone = avg_factor.clone()
-            tmp_avg_factor = paddle.distributed.all_reduce(avg_factor_clone)
-            if tmp_avg_factor is not None:
-                avg_factor = tmp_avg_factor
-            else:
-                avg_factor = avg_factor_clone
+            paddle.distributed.all_reduce(avg_factor)
             avg_factor = paddle.clip(
                 avg_factor / paddle.distributed.get_world_size(), min=1)
         except:
diff --git a/ppdet/modeling/heads/simota_head.py b/ppdet/modeling/heads/simota_head.py
index 77e515bbc..e9a786ede 100644
--- a/ppdet/modeling/heads/simota_head.py
+++ b/ppdet/modeling/heads/simota_head.py
@@ -179,15 +179,8 @@ class OTAHead(GFLHead):
             num_level_anchors)
         num_total_pos = sum(pos_num_l)
         try:
-            cloned_num_total_pos = num_total_pos.clone()
-            reduced_cloned_num_total_pos = paddle.distributed.all_reduce(
-                cloned_num_total_pos)
-            if reduced_cloned_num_total_pos is not None:
-                num_total_pos = reduced_cloned_num_total_pos / paddle.distributed.get_world_size(
-                )
-            else:
-                num_total_pos = cloned_num_total_pos / paddle.distributed.get_world_size(
-                )
+            paddle.distributed.all_reduce(num_total_pos)
+            num_total_pos = num_total_pos / paddle.distributed.get_world_size()
         except:
             num_total_pos = max(num_total_pos, 1)
 
@@ -262,12 +255,7 @@ class OTAHead(GFLHead):
 
         avg_factor = sum(avg_factor)
         try:
-            avg_factor_clone = avg_factor.clone()
-            tmp_avg_factor = paddle.distributed.all_reduce(avg_factor_clone)
-            if tmp_avg_factor is not None:
-                avg_factor = tmp_avg_factor
-            else:
-                avg_factor = avg_factor_clone
+            paddle.distributed.all_reduce(avg_factor)
             avg_factor = paddle.clip(
                 avg_factor / paddle.distributed.get_world_size(), min=1)
         except:
@@ -408,15 +396,8 @@ class OTAVFLHead(OTAHead):
             num_level_anchors)
         num_total_pos = sum(pos_num_l)
         try:
-            cloned_num_total_pos = num_total_pos.clone()
-            reduced_cloned_num_total_pos = paddle.distributed.all_reduce(
-                cloned_num_total_pos)
-            if reduced_cloned_num_total_pos is not None:
-                num_total_pos = reduced_cloned_num_total_pos / paddle.distributed.get_world_size(
-                )
-            else:
-                num_total_pos = cloned_num_total_pos / paddle.distributed.get_world_size(
-                )
+            paddle.distributed.all_reduce(num_total_pos)
+            num_total_pos = num_total_pos / paddle.distributed.get_world_size()
         except:
             num_total_pos = max(num_total_pos, 1)
 
@@ -494,12 +475,7 @@ class OTAVFLHead(OTAHead):
 
         avg_factor = sum(avg_factor)
         try:
-            avg_factor_clone = avg_factor.clone()
-            tmp_avg_factor = paddle.distributed.all_reduce(avg_factor_clone)
-            if tmp_avg_factor is not None:
-                avg_factor = tmp_avg_factor
-            else:
-                avg_factor = avg_factor_clone
+            paddle.distributed.all_reduce(avg_factor)
             avg_factor = paddle.clip(
                 avg_factor / paddle.distributed.get_world_size(), min=1)
         except:
--
GitLab
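
The patch relies on paddle.distributed.all_reduce applying the reduction in place to the tensor that is passed in; the removed code instead captured and inspected the call's return value, which is not how the reduced result is obtained. Below is a minimal standalone sketch of that in-place usage, not part of the patch: it assumes a multi-GPU job started with paddle.distributed.launch, and the tensor name avg_factor is illustrative only, mirroring the patched loss code.

import paddle
import paddle.distributed as dist

# Assumes one process per device, launched e.g. with:
#   python -m paddle.distributed.launch --gpus 0,1 this_script.py
dist.init_parallel_env()

# Each rank holds its own local normalization factor.
avg_factor = paddle.to_tensor(float(dist.get_rank() + 1))

# all_reduce sums the tensor across ranks in place; the reduced value is
# read back from avg_factor itself, not from the call's return value.
dist.all_reduce(avg_factor)

# Average over ranks and clamp to at least 1, as in the patched heads.
avg_factor = paddle.clip(avg_factor / dist.get_world_size(), min=1)
print(avg_factor.item())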