From 5f2c5b9e06ce1c531d59155b35b91a1ea5bd8764 Mon Sep 17 00:00:00 2001 From: Roc <30228238+sljlp@users.noreply.github.com> Date: Wed, 13 Apr 2022 11:42:31 +0800 Subject: [PATCH] fix moe apis (#41650) --- .../paddle/incubate/distributed/models/moe/gate/gshard_gate.py | 2 +- python/paddle/incubate/distributed/models/moe/utils.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/paddle/incubate/distributed/models/moe/gate/gshard_gate.py b/python/paddle/incubate/distributed/models/moe/gate/gshard_gate.py index b1c0cd4214d..3ab3cf69014 100644 --- a/python/paddle/incubate/distributed/models/moe/gate/gshard_gate.py +++ b/python/paddle/incubate/distributed/models/moe/gate/gshard_gate.py @@ -62,6 +62,6 @@ class GShardGate(NaiveGate): if self.random_routing: rand_routing_prob = paddle.rand( shape=[gate_score.shape[0]], dtype="float32") - topk_idx = paddle.distributed.utils.random_routing( + topk_idx = paddle.distributed.models.moe.utils._random_routing( topk_idx, topk_val, rand_routing_prob) return topk_val, topk_idx diff --git a/python/paddle/incubate/distributed/models/moe/utils.py b/python/paddle/incubate/distributed/models/moe/utils.py index 99e31a16273..0e87fe3e313 100644 --- a/python/paddle/incubate/distributed/models/moe/utils.py +++ b/python/paddle/incubate/distributed/models/moe/utils.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from paddle.distributed.models.moe.utils import * +from paddle.distributed.models.moe.utils import _number_count, _limit_by_capacity, _prune_gate_by_capacity, _assign_pos +import paddle def _alltoall(in_tensor_list, group=None, use_calc_stream=True): -- GitLab