未验证 提交 f21824d9 编写于 作者: R Roc 提交者: GitHub

fix recompute (#42128)

* fix recompute

* modify return
上级 f4ce8a92
......@@ -399,7 +399,7 @@ class MoELayer(nn.Layer):
def experts_fwd(x, fwd_expert_count, experts):
if x.shape[0] == 0:
return paddle.empty(x.shape, x.dtype)
return x
y = []
last_index = 0
assert isinstance(fwd_expert_count, np.ndarray)
......@@ -411,7 +411,7 @@ class MoELayer(nn.Layer):
last_index = expert_count + last_index
return paddle.concat(y, axis=0)
if self.recompute_interval <= 0:
if self.recompute_interval <= 0 or x.shape[0] == 0:
x = experts_fwd(x, fwd_expert_count.numpy(), self.experts)
else:
x = _hp_recompute(experts_fwd, x,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册