未验证 提交 584ae4d7 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] add MOE support, PART3 (#54676)

上级 ff806111
...@@ -19,6 +19,8 @@ ...@@ -19,6 +19,8 @@
# Copyright 2021, Jiaao He. All rights reserved. # Copyright 2021, Jiaao He. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License"). # Licensed under the Apache License, Version 2.0 (the "License").
import os
import numpy as np import numpy as np
import paddle import paddle
...@@ -352,7 +354,10 @@ class MoELayer(nn.Layer): ...@@ -352,7 +354,10 @@ class MoELayer(nn.Layer):
assert experts is not None assert experts is not None
self.experts = experts self.experts = experts
if self.world_size > 1: if (
self.world_size > 1
and os.getenv("PADDLE_DISTRI_BACKEND", None) != "xccl"
):
check_nccl_version_for_p2p() check_nccl_version_for_p2p()
self.mp_group = mp_group self.mp_group = mp_group
......
...@@ -1913,6 +1913,13 @@ class Layer: ...@@ -1913,6 +1913,13 @@ class Layer:
p = core.Place() p = core.Place()
p.set_place(t._place()) p.set_place(t._place())
place = core.XPUPlace(p.xpu_device_id()) place = core.XPUPlace(p.xpu_device_id())
elif p.is_custom_place():
p = core.Place()
p.set_place(t._place())
place = core.CustomPlace(
paddle.device.get_device().split(':')[0],
p.custom_device_id(),
)
else: else:
p = core.Place() p = core.Place()
p.set_place(t._place()) p.set_place(t._place())
......
...@@ -1540,7 +1540,12 @@ def load(program, model_path, executor=None, var_list=None): ...@@ -1540,7 +1540,12 @@ def load(program, model_path, executor=None, var_list=None):
p = paddle.fluid.core.Place() p = paddle.fluid.core.Place()
p.set_place(t._place()) p.set_place(t._place())
place = paddle.fluid.XPUPlace(p.xpu_device_id()) place = paddle.fluid.XPUPlace(p.xpu_device_id())
elif p.is_custom_place():
p = paddle.fluid.core.Place()
p.set_place(t._place())
place = paddle.fluid.CustomPlace(
paddle.device.get_device().split(':')[0], p.custom_device_id()
)
else: else:
p = paddle.fluid.core.Place() p = paddle.fluid.core.Place()
p.set_place(t._place()) p.set_place(t._place())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册