From 3a16314ed7e56f1008b576931b312ff3a7d8cd01 Mon Sep 17 00:00:00 2001 From: ucsk <53417456+ucsk@users.noreply.github.com> Date: Wed, 4 Jan 2023 14:15:05 +0800 Subject: [PATCH] fix all_reduce (#7488) --- ppdet/modeling/heads/gfl_head.py | 5 +++-- ppdet/modeling/heads/ld_gfl_head.py | 12 ++++-------- ppdet/modeling/heads/simota_head.py | 6 ++++-- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/ppdet/modeling/heads/gfl_head.py b/ppdet/modeling/heads/gfl_head.py index fdecaf6b0..a1f518da5 100644 --- a/ppdet/modeling/heads/gfl_head.py +++ b/ppdet/modeling/heads/gfl_head.py @@ -311,8 +311,9 @@ class GFLHead(nn.Layer): num_level_anchors) num_total_pos = sum(gt_meta['pos_num']) try: - num_total_pos = paddle.distributed.all_reduce(num_total_pos.clone( - )) / paddle.distributed.get_world_size() + paddle.distributed.all_reduce(num_total_pos) + num_total_pos = paddle.clip( + num_total_pos / paddle.distributed.get_world_size(), min=1) except: num_total_pos = max(num_total_pos, 1) diff --git a/ppdet/modeling/heads/ld_gfl_head.py b/ppdet/modeling/heads/ld_gfl_head.py index 49f9ac2ae..dbff7ecba 100644 --- a/ppdet/modeling/heads/ld_gfl_head.py +++ b/ppdet/modeling/heads/ld_gfl_head.py @@ -153,8 +153,9 @@ class LDGFLHead(GFLHead): num_total_pos = sum(gt_meta['pos_num']) try: - num_total_pos = paddle.distributed.all_reduce(num_total_pos.clone( - )) / paddle.distributed.get_world_size() + paddle.distributed.all_reduce(num_total_pos) + num_total_pos = paddle.clip( + num_total_pos / paddle.distributed.get_world_size(), min=1.) except: num_total_pos = max(num_total_pos, 1) @@ -293,12 +294,7 @@ class LDGFLHead(GFLHead): avg_factor = sum(avg_factor) # + 1e-6 try: - avg_factor_clone = avg_factor.clone() - tmp_avg_factor = paddle.distributed.all_reduce(avg_factor_clone) - if tmp_avg_factor is not None: - avg_factor = tmp_avg_factor - else: - avg_factor = avg_factor_clone + paddle.distributed.all_reduce(avg_factor) avg_factor = paddle.clip( avg_factor / paddle.distributed.get_world_size(), min=1) except: diff --git a/ppdet/modeling/heads/simota_head.py b/ppdet/modeling/heads/simota_head.py index e9a786ede..e74f01757 100644 --- a/ppdet/modeling/heads/simota_head.py +++ b/ppdet/modeling/heads/simota_head.py @@ -180,7 +180,8 @@ class OTAHead(GFLHead): num_total_pos = sum(pos_num_l) try: paddle.distributed.all_reduce(num_total_pos) - num_total_pos = num_total_pos / paddle.distributed.get_world_size() + num_total_pos = paddle.clip( + num_total_pos / paddle.distributed.get_world_size(), min=1.) except: num_total_pos = max(num_total_pos, 1) @@ -397,7 +398,8 @@ class OTAVFLHead(OTAHead): num_total_pos = sum(pos_num_l) try: paddle.distributed.all_reduce(num_total_pos) - num_total_pos = num_total_pos / paddle.distributed.get_world_size() + num_total_pos = paddle.clip( + num_total_pos / paddle.distributed.get_world_size(), min=1.) except: num_total_pos = max(num_total_pos, 1) -- GitLab