未验证 提交 5f2c5b9e 编写于 作者: R Roc 提交者: GitHub

fix moe apis (#41650)

上级 d95280c7
@@ -62,6 +62,6 @@ class GShardGate(NaiveGate):
         if self.random_routing:
             rand_routing_prob = paddle.rand(
                 shape=[gate_score.shape[0]], dtype="float32")
-            topk_idx = paddle.distributed.utils.random_routing(
+            topk_idx = paddle.distributed.models.moe.utils._random_routing(
                 topk_idx, topk_val, rand_routing_prob)
         return topk_val, topk_idx
@@ -11,7 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from paddle.distributed.models.moe.utils import *
+from paddle.distributed.models.moe.utils import _number_count, _limit_by_capacity, _prune_gate_by_capacity, _assign_pos
+import paddle
 def _alltoall(in_tensor_list, group=None, use_calc_stream=True):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册