未验证 提交 65ced1fa 编写于 作者: S sneaxiy 提交者: GitHub

fix alignment bug (#39747)

上级 496aadfb
......@@ -63,19 +63,6 @@ static size_t GetAlignSize(size_t n, size_t alignment) {
return remainder == 0 ? n : n + alignment - remainder;
}
// Greatest common divisor via the iterative Euclidean algorithm:
//   gcd(x, y) == gcd(y, x % y)
//   gcd(x, 0) == x   (so GCD(0, 0) == 0)
static size_t GCD(size_t x, size_t y) {
  size_t a = x;
  size_t b = y;
  while (b != 0) {
    const size_t r = a % b;
    a = b;
    b = r;
  }
  return a;
}
// Least common multiple. Divides before multiplying (x / gcd * y) to
// reduce overflow risk. Returns 0 when either operand is 0: previously
// LCM(0, 0) divided by GCD(0, 0) == 0, which is undefined behavior.
static size_t LCM(size_t x, size_t y) {
  if (x == 0 || y == 0) return 0;
  return x / GCD(x, y) * y;
}
// Shard the ParamGradInfo list by the numel size [start_size, end_size)
// The final results should be:
//
......@@ -155,11 +142,18 @@ static size_t FillAlignmentPaddingInfo(std::vector<ParamGradInfo> *infos,
size_t total_numel_sum_with_padding = 0;
size_t n = infos->size();
auto lcm = LCM(alignment, nranks);
for (size_t i = 0; i < n; ++i) {
auto &info = (*infos)[i];
size_t numel_with_padding =
GetAlignSize(info.numel, i + 1 == n ? lcm : alignment);
size_t numel_with_padding;
if (i + 1 == n) {
// the total fused numel must be a factor of alignment * nranks
numel_with_padding =
GetAlignSize(info.numel + total_numel_sum_with_padding,
alignment * nranks) -
total_numel_sum_with_padding;
} else {
numel_with_padding = GetAlignSize(info.numel, alignment);
}
info.numel_with_padding = numel_with_padding;
info.numel_offset = total_numel_sum_with_padding;
total_numel_sum_with_padding += numel_with_padding;
......
......@@ -72,7 +72,7 @@ class TestDistributedFusedLambWithClip(unittest.TestCase):
def test_1(self):
    # Exercise run_test with clipping applied after allreduce and a tight
    # global-norm cap (0.01). NOTE(review): run_test is defined elsewhere
    # in this test module — presumably it builds and runs the fused LAMB
    # optimizer under these settings; confirm against the full file.
    run_test(clip_after_allreduce=True, max_global_norm=0.01)
def test_2(self):
    # Same as test_1 but clipping is applied BEFORE allreduce. The method
    # was previously disabled by naming it `_test_2` (unittest only picks
    # up `test*` names); the scrape kept both the old body-less `_test_2`
    # line and the new one, which is a syntax error — keep only the
    # enabled definition.
    run_test(clip_after_allreduce=False, max_global_norm=0.01)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册