diff --git a/python/paddle/incubate/distributed/models/moe/gate/gshard_gate.py b/python/paddle/incubate/distributed/models/moe/gate/gshard_gate.py index b1c0cd4214dbb1b66cea91224fb5e3eaa094b991..3ab3cf6901402d450082fcfe2344f707e3353e13 100644 --- a/python/paddle/incubate/distributed/models/moe/gate/gshard_gate.py +++ b/python/paddle/incubate/distributed/models/moe/gate/gshard_gate.py @@ -62,6 +62,6 @@ class GShardGate(NaiveGate): if self.random_routing: rand_routing_prob = paddle.rand( shape=[gate_score.shape[0]], dtype="float32") - topk_idx = paddle.distributed.utils.random_routing( + topk_idx = paddle.distributed.models.moe.utils._random_routing( topk_idx, topk_val, rand_routing_prob) return topk_val, topk_idx diff --git a/python/paddle/incubate/distributed/models/moe/utils.py b/python/paddle/incubate/distributed/models/moe/utils.py index 99e31a16273bf7ef939d724c00d35e7fb647aada..0e87fe3e313600a520f57caa04baa1ccf59f4d21 100644 --- a/python/paddle/incubate/distributed/models/moe/utils.py +++ b/python/paddle/incubate/distributed/models/moe/utils.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from paddle.distributed.models.moe.utils import * +from paddle.distributed.models.moe.utils import _number_count, _limit_by_capacity, _prune_gate_by_capacity, _assign_pos +import paddle def _alltoall(in_tensor_list, group=None, use_calc_stream=True):