Unverified · Commit 56fded1b authored by ronnywang, committed by GitHub

[CustomDevice] add inference MP support, PART3 (#53703)

Parent commit: e04f8d4a
@@ -29,6 +29,19 @@ namespace pybind {
void BindCustomDevicePy(py::module *m_ptr) {
auto &m = *m_ptr;
// Bind Methods
m.def(
"_get_device_total_memory",
[](const std::string &device_type, int device_id) {
auto place = paddle::platform::CustomPlace(
device_type,
device_id == -1 ? phi::DeviceManager::GetDevice(device_type)
: device_id);
size_t total = 0, free = 0;
phi::DeviceManager::MemoryStats(place, &total, &free);
return total;
},
py::arg("device_type"),
py::arg("device_id") = -1);
m.def(
"_get_current_custom_device_stream",
[](const std::string &device_type, int device_id) {
......
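The new `_get_device_total_memory` binding returns the total memory of a custom device in bytes; passing `device_id = -1` (the default) queries the device currently selected by `DeviceManager`. A minimal usage sketch from Python, assuming a custom device plugin is registered (the device name "npu" is illustrative):

import paddle

device_type = "npu"  # illustrative; any registered custom device type works
# device_id defaults to -1, i.e. the currently selected device of this type
total_bytes = paddle.fluid.core._get_device_total_memory(device_type)
print(f"{device_type} total memory: {total_bytes / (1000 ** 3):.1f} GB")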
@@ -883,6 +883,13 @@ def get_default_cluster(json_config=None):
assert global_device_count % local_device_count == 0
node_count = int(global_device_count) // local_device_count
if os.getenv("PADDLE_DISTRI_BACKEND", None) == "xccl":
gpu_name = os.getenv("PADDLE_XCCL_BACKEND", None)
gpu_model = gpu_name
memory = int(
paddle.fluid.core._get_device_total_memory(gpu_name)
) // (1000**3)
else:
gpu_info = paddle.device.cuda.get_device_properties()
assert gpu_info, "Auto parallel just runs on gpu now."
......
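In `get_default_cluster`, the custom-device branch is selected through environment variables instead of CUDA probing: `PADDLE_DISTRI_BACKEND=xccl` routes to the new path, `PADDLE_XCCL_BACKEND` names the device, and the byte count reported by the binding above is converted to GB. A standalone sketch of that computation (device name illustrative):

import os
import paddle

os.environ["PADDLE_DISTRI_BACKEND"] = "xccl"
os.environ["PADDLE_XCCL_BACKEND"] = "npu"  # illustrative custom device name

gpu_name = os.getenv("PADDLE_XCCL_BACKEND")
memory_gb = int(paddle.fluid.core._get_device_total_memory(gpu_name)) // (1000**3)
print(gpu_name, memory_gb)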
@@ -148,6 +148,11 @@ class ProcessGroup:
core.BKCLParallelContext(strategy, place).init_with_ring_id(
ring_id
)
elif genv.device_type in core.get_all_custom_device_type():
place = core.CustomPlace(genv.device_type, genv.device_id)
core.XCCLParallelContext(strategy, place).init_with_ring_id(
ring_id
)
else:
raise AssertionError('No CUDA device found')
@@ -162,6 +167,14 @@ class ProcessGroup:
paddle.set_device(
'xpu:%d' % paddle.distributed.ParallelEnv().dev_id
)
elif genv.device_type in core.get_all_custom_device_type():
paddle.set_device(
'%s:%d'
% (
paddle.distributed.ParallelEnv().device_type,
paddle.distributed.ParallelEnv().dev_id,
),
)
tmp = (
paddle.to_tensor([1], dtype="int32")
if in_dygraph_mode()
......
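The `ProcessGroup` changes mirror the existing CUDA/XPU branches: a `CustomPlace` is built from the device type and id, an `XCCLParallelContext` is initialized with the ring id, and the current device is set with a "<device_type>:<dev_id>" string. A minimal sketch of the device-selection step, assuming a registered custom device plugin:

import paddle

genv = paddle.distributed.ParallelEnv()
if genv.device_type in paddle.device.get_all_custom_device_type():
    # e.g. "npu:0" for the first device of a plugin named "npu"
    paddle.set_device('%s:%d' % (genv.device_type, genv.dev_id))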
@@ -1040,6 +1040,7 @@ class ShardingOptimizer(MetaOptimizerBase):
"c_calc_comm_stream",
"c_gen_nccl_id",
"c_gen_bkcl_id",
"c_gen_xccl_id",
"c_comm_init", "c_comm_init",
'send_v2', 'send_v2',
'recv_v2', 'recv_v2',
......
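The sharding optimizer keeps a whitelist of communication-setup op types; adding "c_gen_xccl_id" lets programs targeting custom devices pass the same filtering as the NCCL (`c_gen_nccl_id`) and BKCL (`c_gen_bkcl_id`) variants. A minimal sketch of how such a whitelist is typically consulted (the helper below is hypothetical, not part of the patch):

# Hypothetical helper: ops whose type appears in the whitelist are treated as
# communication/initialization ops by the sharding pass.
COMM_SETUP_OP_TYPES = {
    "c_calc_comm_stream",
    "c_gen_nccl_id",
    "c_gen_bkcl_id",
    "c_gen_xccl_id",
    "c_comm_init",
    "send_v2",
    "recv_v2",
}

def is_comm_setup_op(op):
    return op.type in COMM_SETUP_OP_TYPES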
@@ -167,6 +167,7 @@ def broadcast_input_data(hcg, *inputs, **kwargs):
place = paddle.CUDAPlace(dev_idx)
elif dev in paddle.device.get_all_custom_device_type():
place = paddle.CustomPlace(dev, dev_idx)
dev = 'custom'
else:
place = eval(f"paddle.{dev.upper()}Place")(dev_idx)
......
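In `broadcast_input_data`, the new `dev = 'custom'` assignment normalizes the device key once a `CustomPlace` has been built, presumably so that later device-string checks in the function treat every plugin device uniformly. A minimal sketch of the resulting place selection (device names illustrative, helper name hypothetical):

import paddle

def select_place(dev, dev_idx):
    if dev == "gpu":
        return paddle.CUDAPlace(dev_idx), dev
    elif dev in paddle.device.get_all_custom_device_type():
        # normalize the key for any plugin device
        return paddle.CustomPlace(dev, dev_idx), "custom"
    else:
        return eval(f"paddle.{dev.upper()}Place")(dev_idx), dev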
@@ -14,6 +14,7 @@
import os
import paddle
from paddle.distributed.fleet.base.private_helper_function import (
wait_server_ready,
)
@@ -204,6 +205,38 @@ class Collective:
self.op_role_key: OpRole.Forward,
},
)
elif (
paddle.distributed.ParallelEnv().device_type
in paddle.device.get_all_custom_device_type()
):
xccl_id_var = block.create_var(
name=unique_name.generate('xccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW,
)
endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)}
block.append_op(
type='c_gen_xccl_id',
inputs={},
outputs={'Out': xccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
self.op_role_key: OpRole.Forward,
},
)
block.append_op(
type='c_comm_init',
inputs={'X': xccl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
self.op_role_key: OpRole.Forward,
},
)
def _broadcast_params(self):
block = self.startup_program.global_block()
......
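The custom-device branch added to this `Collective` transpiler follows the established pattern: a persistable RAW variable (`xccl_id`) is created in the startup block, `c_gen_xccl_id` exchanges the communicator id between `current_endpoint` and `other_endpoints` for the given rank, and `c_comm_init` then builds the communicator for the ring from that id. The next hunk applies the same change to a second `Collective` transpiler.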
@@ -14,6 +14,7 @@
import os
import paddle
from paddle.distributed.fleet.base.private_helper_function import (
wait_server_ready,
)
@@ -200,6 +201,37 @@ class Collective:
self.op_role_key: OpRole.Forward,
},
)
elif (
paddle.distributed.ParallelEnv().device_type
in paddle.device.get_all_custom_device_type()
):
xccl_id_var = block.create_var(
name=unique_name.generate('xccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW,
)
block.append_op(
type='c_gen_xccl_id',
inputs={},
outputs={'Out': xccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
self.op_role_key: OpRole.Forward,
},
)
block.append_op(
type='c_comm_init',
inputs={'X': xccl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
self.op_role_key: OpRole.Forward,
},
)
def _broadcast_params(self):
block = self.startup_program.global_block()
......
@@ -184,6 +184,37 @@ def init_communicator(
'ring_id': 0,
},
)
elif (
paddle.distributed.ParallelEnv().device_type
in paddle.device.get_all_custom_device_type()
):
xccl_id_var = block.create_var(
name=fluid.unique_name.generate('xccl_id'),
persistable=True,
type=fluid.core.VarDesc.VarType.RAW,
)
block.append_op(
type='c_gen_xccl_id',
inputs={},
outputs={'Out': xccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
},
)
block.append_op(
type='c_comm_init',
inputs={'X': xccl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': 0,
},
)
def prepare_distributed_context(place=None):
......
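This hunk adds the same `c_gen_xccl_id` / `c_comm_init` pair to the `init_communicator` helper used by `prepare_distributed_context`; the only differences are the use of `fluid.unique_name` / `fluid.core` and the absence of the `op_role` attribute.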
@@ -58,6 +58,21 @@ def init_communicator(block, rank, ranks, ring_id):
'ring_id': ring_id,
},
)
elif (
paddle.distributed.ParallelEnv().device_type
in paddle.device.get_all_custom_device_type()
):
block.append_op(
type='c_gen_xccl_id',
inputs={},
outputs={'Out': comm_id_var},
attrs={
'rank': local_rank,
'endpoint': cur_ep,
'other_endpoints': other_eps,
'ring_id': ring_id,
},
)
block.append_op(
type='c_comm_init',
inputs={'X': comm_id_var},
......
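In this last `init_communicator(block, rank, ranks, ring_id)` variant, the comm-id variable (`comm_id_var`) is created before the backend branch and the trailing `c_comm_init` is shared, so only the id-generation op differs per backend. A hypothetical helper sketching that dispatch (the gpu/xpu cases are inferred from the analogous hunks above, not shown in this one):

def gen_comm_id_op_type(device_type, custom_device_types):
    """Pick the comm-id generation op for the current backend (sketch only)."""
    if device_type == "gpu":
        return "c_gen_nccl_id"
    elif device_type == "xpu":
        return "c_gen_bkcl_id"
    elif device_type in custom_device_types:
        return "c_gen_xccl_id"
    raise AssertionError("No supported device found for communicator init")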