From 8416465eb49db5c71e8073840556429f2e493057 Mon Sep 17 00:00:00 2001
From: Jiabin Yang <360788950@qq.com>
Date: Tue, 7 Jun 2022 20:26:36 +0800
Subject: [PATCH] [Eager] Support eager all_reduce return value (#6140)

* support eager all_reduce return value

* revert file

* fix error logic

* support simota head in eager
---
 ppdet/modeling/heads/gfl_head.py    |  7 +++++-
 ppdet/modeling/heads/simota_head.py | 36 ++++++++++++++++++++++++-----
 2 files changed, 36 insertions(+), 7 deletions(-)

diff --git a/ppdet/modeling/heads/gfl_head.py b/ppdet/modeling/heads/gfl_head.py
index 654c84fce..9c87eecd8 100644
--- a/ppdet/modeling/heads/gfl_head.py
+++ b/ppdet/modeling/heads/gfl_head.py
@@ -388,7 +388,12 @@ class GFLHead(nn.Layer):
         avg_factor = sum(avg_factor)
         try:
-            avg_factor = paddle.distributed.all_reduce(avg_factor.clone())
+            avg_factor_clone = avg_factor.clone()
+            tmp_avg_factor = paddle.distributed.all_reduce(avg_factor_clone)
+            if tmp_avg_factor is not None:
+                avg_factor = tmp_avg_factor
+            else:
+                avg_factor = avg_factor_clone
             avg_factor = paddle.clip(
                 avg_factor / paddle.distributed.get_world_size(), min=1)
         except:
diff --git a/ppdet/modeling/heads/simota_head.py b/ppdet/modeling/heads/simota_head.py
index a1485f390..77e515bbc 100644
--- a/ppdet/modeling/heads/simota_head.py
+++ b/ppdet/modeling/heads/simota_head.py
@@ -179,8 +179,15 @@ class OTAHead(GFLHead):
             num_level_anchors)
         num_total_pos = sum(pos_num_l)
         try:
-            num_total_pos = paddle.distributed.all_reduce(num_total_pos.clone(
-            )) / paddle.distributed.get_world_size()
+            cloned_num_total_pos = num_total_pos.clone()
+            reduced_cloned_num_total_pos = paddle.distributed.all_reduce(
+                cloned_num_total_pos)
+            if reduced_cloned_num_total_pos is not None:
+                num_total_pos = reduced_cloned_num_total_pos / paddle.distributed.get_world_size(
+                )
+            else:
+                num_total_pos = cloned_num_total_pos / paddle.distributed.get_world_size(
+                )
         except:
             num_total_pos = max(num_total_pos, 1)
@@ -255,7 +262,12 @@ class OTAHead(GFLHead):
         avg_factor = sum(avg_factor)
         try:
-            avg_factor = paddle.distributed.all_reduce(avg_factor.clone())
+            avg_factor_clone = avg_factor.clone()
+            tmp_avg_factor = paddle.distributed.all_reduce(avg_factor_clone)
+            if tmp_avg_factor is not None:
+                avg_factor = tmp_avg_factor
+            else:
+                avg_factor = avg_factor_clone
             avg_factor = paddle.clip(
                 avg_factor / paddle.distributed.get_world_size(), min=1)
         except:
@@ -396,8 +408,15 @@ class OTAVFLHead(OTAHead):
             num_level_anchors)
         num_total_pos = sum(pos_num_l)
         try:
-            num_total_pos = paddle.distributed.all_reduce(num_total_pos.clone(
-            )) / paddle.distributed.get_world_size()
+            cloned_num_total_pos = num_total_pos.clone()
+            reduced_cloned_num_total_pos = paddle.distributed.all_reduce(
+                cloned_num_total_pos)
+            if reduced_cloned_num_total_pos is not None:
+                num_total_pos = reduced_cloned_num_total_pos / paddle.distributed.get_world_size(
+                )
+            else:
+                num_total_pos = cloned_num_total_pos / paddle.distributed.get_world_size(
+                )
         except:
             num_total_pos = max(num_total_pos, 1)
@@ -475,7 +494,12 @@ class OTAVFLHead(OTAHead):
         avg_factor = sum(avg_factor)
         try:
-            avg_factor = paddle.distributed.all_reduce(avg_factor.clone())
+            avg_factor_clone = avg_factor.clone()
+            tmp_avg_factor = paddle.distributed.all_reduce(avg_factor_clone)
+            if tmp_avg_factor is not None:
+                avg_factor = tmp_avg_factor
+            else:
+                avg_factor = avg_factor_clone
             avg_factor = paddle.clip(
                 avg_factor / paddle.distributed.get_world_size(), min=1)
         except:
--
GitLab
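
All five hunks above apply the same compatibility pattern: in eager mode,
paddle.distributed.all_reduce returns the reduced tensor, while in the older
in-place mode it returns None and mutates its argument. Each call site
therefore clones the input, calls all_reduce, and uses the return value when
it is not None, falling back to the mutated clone otherwise. Below is a
minimal sketch of that pattern factored into a helper; the helper name
_all_reduce_compat is hypothetical (the patch inlines this logic at each call
site), and it assumes an already-initialized distributed process group.

    import paddle.distributed as dist


    def _all_reduce_compat(tensor):
        """Sum `tensor` across ranks under either all_reduce behavior."""
        # Hypothetical helper, not part of the patch. Clone first so the
        # caller's tensor is left untouched in both modes.
        clone = tensor.clone()
        reduced = dist.all_reduce(clone)
        # Eager mode returns the reduced tensor; the legacy in-place mode
        # returns None and leaves the summed value in `clone`.
        return reduced if reduced is not None else clone

With such a helper, each avg_factor site above would collapse to
avg_factor = paddle.clip(_all_reduce_compat(avg_factor) / dist.get_world_size(), min=1).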