diff --git a/python/paddle/incubate/distributed/models/moe/moe_layer.py b/python/paddle/incubate/distributed/models/moe/moe_layer.py index eebb635e3ead76e0ef95e3f6eb558c43b4008c1c..ba22ffee3e4d6b846ce7207561219ed7acec7e34 100644 --- a/python/paddle/incubate/distributed/models/moe/moe_layer.py +++ b/python/paddle/incubate/distributed/models/moe/moe_layer.py @@ -399,7 +399,7 @@ class MoELayer(nn.Layer): def experts_fwd(x, fwd_expert_count, experts): if x.shape[0] == 0: - return paddle.empty(x.shape, x.dtype) + return x y = [] last_index = 0 assert isinstance(fwd_expert_count, np.ndarray) @@ -411,7 +411,7 @@ class MoELayer(nn.Layer): last_index = expert_count + last_index return paddle.concat(y, axis=0) - if self.recompute_interval <= 0: + if self.recompute_interval <= 0 or x.shape[0] == 0: x = experts_fwd(x, fwd_expert_count.numpy(), self.experts) else: x = _hp_recompute(experts_fwd, x,