From 65ced1fa6f3cfec8935d808fc4ba5f7931710562 Mon Sep 17 00:00:00 2001
From: sneaxiy <32832641+sneaxiy@users.noreply.github.com>
Date: Mon, 21 Feb 2022 10:38:33 +0800
Subject: [PATCH] fix alignment bug (#39747)

---
 .../distributed_fused_lamb_init_op.cu         | 26 +++++++------------
 ...est_distributed_fused_lamb_op_with_clip.py |  2 +-
 2 files changed, 11 insertions(+), 17 deletions(-)

diff --git a/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu
index 04d6d51acaf..3bb605d7f55 100644
--- a/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu
+++ b/paddle/fluid/operators/optimizers/distributed_fused_lamb_init_op.cu
@@ -63,19 +63,6 @@ static size_t GetAlignSize(size_t n, size_t alignment) {
   return remainder == 0 ? n : n + alignment - remainder;
 }
 
-// gcd(x, y) = gcd(y, x % y)
-// gcd(x, 0) = x
-static size_t GCD(size_t x, size_t y) {
-  while (y > 0) {
-    auto tmp = x;
-    x = y;
-    y = tmp % y;
-  }
-  return x;
-}
-
-static size_t LCM(size_t x, size_t y) { return x / GCD(x, y) * y; }
-
 // Shard the ParamGradInfo list by the numel size [start_size, end_size)
 // The final results should be:
 //
@@ -155,11 +142,18 @@ static size_t FillAlignmentPaddingInfo(std::vector<ParamGradInfo> *infos,
 
   size_t total_numel_sum_with_padding = 0;
   size_t n = infos->size();
-  auto lcm = LCM(alignment, nranks);
   for (size_t i = 0; i < n; ++i) {
     auto &info = (*infos)[i];
-    size_t numel_with_padding =
-        GetAlignSize(info.numel, i + 1 == n ? lcm : alignment);
+    size_t numel_with_padding;
+    if (i + 1 == n) {
+      // the total fused numel must be a multiple of alignment * nranks
+      numel_with_padding =
+          GetAlignSize(info.numel + total_numel_sum_with_padding,
+                       alignment * nranks) -
+          total_numel_sum_with_padding;
+    } else {
+      numel_with_padding = GetAlignSize(info.numel, alignment);
+    }
     info.numel_with_padding = numel_with_padding;
     info.numel_offset = total_numel_sum_with_padding;
     total_numel_sum_with_padding += numel_with_padding;
diff --git a/python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_with_clip.py b/python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_with_clip.py
index 060a790a6e5..af99529adfa 100644
--- a/python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_with_clip.py
+++ b/python/paddle/fluid/tests/unittests/test_distributed_fused_lamb_op_with_clip.py
@@ -72,7 +72,7 @@ class TestDistributedFusedLambWithClip(unittest.TestCase):
     def test_1(self):
         run_test(clip_after_allreduce=True, max_global_norm=0.01)
 
-    def _test_2(self):
+    def test_2(self):
         run_test(clip_after_allreduce=False, max_global_norm=0.01)
 
-- 
GitLab
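
To see what the new padding rule does, here is a minimal standalone sketch of the logic after this patch. It is not the Paddle source: the extra ParamGradInfo fields, the modulo body of GetAlignSize (only its return line is visible in the hunk context), and the main() driver with its toy sizes are illustrative assumptions. Padding every parameter to alignment, and padding only the last one so that the running total becomes a multiple of alignment * nranks, guarantees the fused buffer splits evenly across ranks; the removed approach, which padded just the last parameter's own numel to LCM(alignment, nranks), gives no such guarantee when the earlier padded sizes do not already sum to a multiple of alignment * nranks.

// Standalone sketch of the padding rule introduced by this patch
// (illustrative only, not the Paddle source).
#include <cstddef>
#include <cstdio>
#include <vector>

struct ParamGradInfo {
  size_t numel;
  size_t numel_with_padding;
  size_t numel_offset;
};

// Round n up to the nearest multiple of alignment.
static size_t GetAlignSize(size_t n, size_t alignment) {
  auto remainder = n % alignment;
  return remainder == 0 ? n : n + alignment - remainder;
}

// Pad every parameter to `alignment`; pad the last one so that the
// total fused numel becomes a multiple of alignment * nranks.
static size_t FillAlignmentPaddingInfo(std::vector<ParamGradInfo> *infos,
                                       size_t alignment, size_t nranks) {
  size_t total_numel_sum_with_padding = 0;
  size_t n = infos->size();
  for (size_t i = 0; i < n; ++i) {
    auto &info = (*infos)[i];
    size_t numel_with_padding;
    if (i + 1 == n) {
      // Round the running total (including this parameter) up to a
      // multiple of alignment * nranks, and give the difference to
      // this last parameter as padding.
      numel_with_padding =
          GetAlignSize(info.numel + total_numel_sum_with_padding,
                       alignment * nranks) -
          total_numel_sum_with_padding;
    } else {
      numel_with_padding = GetAlignSize(info.numel, alignment);
    }
    info.numel_with_padding = numel_with_padding;
    info.numel_offset = total_numel_sum_with_padding;
    total_numel_sum_with_padding += numel_with_padding;
  }
  return total_numel_sum_with_padding;
}

int main() {
  // Toy sizes chosen only to show the effect; alignment = 4, nranks = 3.
  std::vector<ParamGradInfo> infos{{5, 0, 0}, {7, 0, 0}, {9, 0, 0}};
  size_t total = FillAlignmentPaddingInfo(&infos, /*alignment=*/4, /*nranks=*/3);
  // 5 -> 8, 7 -> 8; last parameter (9): 16 + 9 = 25 is rounded up to 36,
  // so it gets 20 padded elements and total = 36, a multiple of 12.
  std::printf("total = %zu\n", total);
  return 0;
}

With alignment = 4 and nranks = 3, the sketch pads 5 and 7 up to 8 each and the last parameter (9) up to 20, so the total is 36, a multiple of alignment * nranks = 12. The removed rule would have padded 9 to LCM(4, 3) = 12 for a total of 28, which is not divisible by 12 and therefore cannot be sharded evenly across the 3 ranks; that misalignment is the bug this patch fixes.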