Unverified commit 56fded1b, authored by ronnywang and committed by GitHub

[CustomDevice] add inference MP support, PART3 (#53703)

Parent commit: e04f8d4a
@@ -29,6 +29,19 @@ namespace pybind {
void BindCustomDevicePy(py::module *m_ptr) {
auto &m = *m_ptr;
// Bind Methods
m.def(
"_get_device_total_memory",
[](const std::string &device_type, int device_id) {
auto place = paddle::platform::CustomPlace(
device_type,
device_id == -1 ? phi::DeviceManager::GetDevice(device_type)
: device_id);
size_t total = 0, free = 0;
phi::DeviceManager::MemoryStats(place, &total, &free);
return total;
},
py::arg("device_type"),
py::arg("device_id") = -1);
m.def(
"_get_current_custom_device_stream",
[](const std::string &device_type, int device_id) {
......
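For reference, a minimal usage sketch of the new `_get_device_total_memory` binding from Python. The `npu` device type is a hypothetical placeholder for whatever CustomDevice plugin is registered; it is not part of this change.

```python
# Query the total memory of a custom device through the new pybind helper.
# "npu" is a hypothetical plugin name; device_id defaults to -1, which means
# "the device currently selected for this device type".
import paddle

total_bytes = paddle.fluid.core._get_device_total_memory("npu", 0)
print(f"total device memory: {total_bytes / (1024 ** 3):.2f} GiB")
```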
@@ -883,6 +883,13 @@ def get_default_cluster(json_config=None):
assert global_device_count % local_device_count == 0
node_count = int(global_device_count) // local_device_count
if os.getenv("PADDLE_DISTRI_BACKEND", None) == "xccl":
gpu_name = os.getenv("PADDLE_XCCL_BACKEND", None)
gpu_model = gpu_name
memory = int(
paddle.fluid.core._get_device_total_memory(gpu_name)
) // (1000**3)
else:
gpu_info = paddle.device.cuda.get_device_properties()
assert gpu_info, "Auto parallel just runs on gpu now."
......
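A minimal sketch of how the new branch in `get_default_cluster` is reached. The environment variable names come from the hunk above; the plugin name `npu` is a placeholder assumption.

```python
# Steer get_default_cluster() into the XCCL/CustomDevice path.
import os
import paddle

os.environ["PADDLE_DISTRI_BACKEND"] = "xccl"  # routes into the new branch
os.environ["PADDLE_XCCL_BACKEND"] = "npu"     # hypothetical plugin name

# Device memory is then taken from the pybind helper added in this PR and
# reported in GB, mirroring the expression in the hunk above.
memory = int(paddle.fluid.core._get_device_total_memory("npu")) // (1000**3)
```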
@@ -148,6 +148,11 @@ class ProcessGroup:
core.BKCLParallelContext(strategy, place).init_with_ring_id(
ring_id
)
elif genv.device_type in core.get_all_custom_device_type():
place = core.CustomPlace(genv.device_type, genv.device_id)
core.XCCLParallelContext(strategy, place).init_with_ring_id(
ring_id
)
else:
raise AssertionError('No CUDA device found')
@@ -162,6 +167,14 @@ class ProcessGroup:
paddle.set_device(
'xpu:%d' % paddle.distributed.ParallelEnv().dev_id
)
elif genv.device_type in core.get_all_custom_device_type():
paddle.set_device(
'%s:%d'
% (
paddle.distributed.ParallelEnv().device_type,
paddle.distributed.ParallelEnv().dev_id,
),
)
tmp = (
paddle.to_tensor([1], dtype="int32")
if in_dygraph_mode()
......
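A standalone sketch of the device selection the new `ProcessGroup` branch performs, assuming a registered CustomDevice plugin; the `ParallelEnv` fields are the same ones used in the hunk above.

```python
# Bind the current process to its custom device before the warm-up tensor
# is created, mirroring the existing CUDA/XPU branches.
import paddle

genv = paddle.distributed.ParallelEnv()
if genv.device_type in paddle.device.get_all_custom_device_type():
    paddle.set_device('%s:%d' % (genv.device_type, genv.dev_id))
```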
@@ -1040,6 +1040,7 @@ class ShardingOptimizer(MetaOptimizerBase):
"c_calc_comm_stream",
"c_gen_nccl_id",
"c_gen_bkcl_id",
"c_gen_xccl_id",
"c_comm_init",
'send_v2',
'recv_v2',
......
@@ -167,6 +167,7 @@ def broadcast_input_data(hcg, *inputs, **kwargs):
place = paddle.CUDAPlace(dev_idx)
elif dev in paddle.device.get_all_custom_device_type():
place = paddle.CustomPlace(dev, dev_idx)
dev = 'custom'
else:
place = eval(f"paddle.{dev.upper()}Place")(dev_idx)
......
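For completeness, a small sketch of the place construction this branch performs; the plugin name is hypothetical.

```python
import paddle

# A custom device is mapped to a CustomPlace built from its type and index;
# the broadcast helper above then treats it under the generic 'custom' key.
place = paddle.CustomPlace("npu", 0)  # hypothetical plugin "npu", device 0
x = paddle.to_tensor([1], dtype="int32", place=place)  # lives on that device
```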
@@ -14,6 +14,7 @@
import os
import paddle
from paddle.distributed.fleet.base.private_helper_function import (
wait_server_ready,
)
@@ -204,6 +205,38 @@ class Collective:
self.op_role_key: OpRole.Forward,
},
)
elif (
paddle.distributed.ParallelEnv().device_type
in paddle.device.get_all_custom_device_type()
):
xccl_id_var = block.create_var(
name=unique_name.generate('xccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW,
)
endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
block.append_op(
type='c_gen_xccl_id',
inputs={},
outputs={'Out': xccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
self.op_role_key: OpRole.Forward,
},
)
block.append_op(
type='c_comm_init',
inputs={'X': xccl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
self.op_role_key: OpRole.Forward,
},
)
def _broadcast_params(self):
block = self.startup_program.global_block()
......
@@ -14,6 +14,7 @@
import os
import paddle
from paddle.distributed.fleet.base.private_helper_function import (
wait_server_ready,
)
@@ -200,6 +201,37 @@ class Collective:
self.op_role_key: OpRole.Forward,
},
)
elif (
paddle.distributed.ParallelEnv().device_type
in paddle.device.get_all_custom_device_type()
):
xccl_id_var = block.create_var(
name=unique_name.generate('xccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW,
)
block.append_op(
type='c_gen_xccl_id',
inputs={},
outputs={'Out': xccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
self.op_role_key: OpRole.Forward,
},
)
block.append_op(
type='c_comm_init',
inputs={'X': xccl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
self.op_role_key: OpRole.Forward,
},
)
def _broadcast_params(self):
block = self.startup_program.global_block()
......
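The two nearly identical hunks above emit the same op pair. Below is a self-contained sketch of that pattern which builds the ops into a fresh startup program; the endpoints, ranks, and ring id are illustrative placeholders, not values from this PR, and appending the ops requires a Paddle build in which `c_gen_xccl_id` is registered.

```python
# Build the c_gen_xccl_id -> c_comm_init pair into a startup block,
# mirroring the CustomDevice branch added above.
import paddle
from paddle.fluid import core, unique_name

paddle.enable_static()
startup = paddle.static.Program()
block = startup.global_block()

rank, nranks, ring_id = 0, 2, 0
current_endpoint = "127.0.0.1:6170"          # placeholder endpoints
other_endpoints = ["127.0.0.1:6171"]

# RAW variable that carries the generated XCCL unique id between the two ops.
xccl_id_var = block.create_var(
    name=unique_name.generate("xccl_id"),
    persistable=True,
    type=core.VarDesc.VarType.RAW,
)
block.append_op(
    type="c_gen_xccl_id",
    inputs={},
    outputs={"Out": xccl_id_var},
    attrs={
        "rank": rank,
        "endpoint": current_endpoint,
        "other_endpoints": other_endpoints,
    },
)
block.append_op(
    type="c_comm_init",
    inputs={"X": xccl_id_var},
    outputs={},
    attrs={"nranks": nranks, "rank": rank, "ring_id": ring_id},
)
```

This mirrors the existing NCCL (`c_gen_nccl_id`) and BKCL (`c_gen_bkcl_id`) branches, which is also why `c_gen_xccl_id` is added to the ShardingOptimizer op whitelist earlier in this diff.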
@@ -184,6 +184,37 @@ def init_communicator(
'ring_id': 0,
},
)
elif (
paddle.distributed.ParallelEnv().device_type
in paddle.device.get_all_custom_device_type()
):
xccl_id_var = block.create_var(
name=fluid.unique_name.generate('xccl_id'),
persistable=True,
type=fluid.core.VarDesc.VarType.RAW,
)
block.append_op(
type='c_gen_xccl_id',
inputs={},
outputs={'Out': xccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
},
)
block.append_op(
type='c_comm_init',
inputs={'X': xccl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': 0,
},
)
def prepare_distributed_context(place=None):
......
@@ -58,6 +58,21 @@ def init_communicator(block, rank, ranks, ring_id):
'ring_id': ring_id,
},
)
elif (
paddle.distributed.ParallelEnv().device_type
in paddle.device.get_all_custom_device_type()
):
block.append_op(
type='c_gen_xccl_id',
inputs={},
outputs={'Out': comm_id_var},
attrs={
'rank': local_rank,
'endpoint': cur_ep,
'other_endpoints': other_eps,
'ring_id': ring_id,
},
)
block.append_op(
type='c_comm_init',
inputs={'X': comm_id_var},
......