From 6366e0a9bf78bdfca92c509ae3f85a580508d6e7 Mon Sep 17 00:00:00 2001 From: Roc <30228238+sljlp@users.noreply.github.com> Date: Tue, 26 Apr 2022 17:53:48 +0800 Subject: [PATCH] fix recompute (#42128) (#42216) * fix recompute * modify return --- python/paddle/incubate/distributed/models/moe/moe_layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/incubate/distributed/models/moe/moe_layer.py b/python/paddle/incubate/distributed/models/moe/moe_layer.py index eebb635e3e..ba22ffee3e 100644 --- a/python/paddle/incubate/distributed/models/moe/moe_layer.py +++ b/python/paddle/incubate/distributed/models/moe/moe_layer.py @@ -399,7 +399,7 @@ class MoELayer(nn.Layer): def experts_fwd(x, fwd_expert_count, experts): if x.shape[0] == 0: - return paddle.empty(x.shape, x.dtype) + return x y = [] last_index = 0 assert isinstance(fwd_expert_count, np.ndarray) @@ -411,7 +411,7 @@ class MoELayer(nn.Layer): last_index = expert_count + last_index return paddle.concat(y, axis=0) - if self.recompute_interval <= 0: + if self.recompute_interval <= 0 or x.shape[0] == 0: x = experts_fwd(x, fwd_expert_count.numpy(), self.experts) else: x = _hp_recompute(experts_fwd, x, -- GitLab