Unverified commit 25409dcc, authored by ronnywang, committed by GitHub

[CustomDevice] add sharding support (#54384)

* [CustomDevice] add sharding support

* update
Parent 3535049a
@@ -722,6 +722,43 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Send(
false,
false);
}
std::shared_ptr<ProcessGroup::Task> ProcessGroupCustom::Reduce(
phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream) {
phi::distributed::CommStaticCheck::SameShape(*out_tensor,
in_tensor,
/*dst_rank*/ opts.root_rank,
/*cur_rank*/ rank_,
size_,
phi::AllocationType::CUSTOM);
std::vector<phi::DenseTensor> in_wrapper{in_tensor};
std::vector<phi::DenseTensor> out_wrapper{*out_tensor};
return Collective(
in_wrapper,
out_wrapper,
[&](phi::DenseTensor& input,
phi::DenseTensor& output,
phi::ccl::CCLComm comm,
const phi::stream::Stream& stream) {
phi::DeviceManager::CCLReduce(device_type_,
input.data(),
output.data(),
input.numel(),
phi::ccl::ToCCLDataType(input.dtype()),
ToCustomCCLRedType(opts.reduce_op),
opts.root_rank,
comm,
stream);
},
CommType::REDUCE,
sync_op,
use_calc_stream);
}
std::shared_ptr<ProcessGroupCustom>
ProcessGroupCustom::CreateProcessGroupCustom(
const std::shared_ptr<phi::distributed::Store>& store,
......
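The new Reduce overload follows the same Collective pattern as the existing Send/Recv paths. A minimal sketch of how it could be exercised from Python, assuming a CustomDevice plugin registered as "npu" (the device name and launch setup are illustrative, not part of this commit):

```python
# Hypothetical sketch: driving the new Reduce path from Python.
# Assumes a CustomDevice plugin registered as "npu" and a normal
# multi-process launch (e.g. python -m paddle.distributed.launch).
import paddle
import paddle.distributed as dist

rank = dist.get_rank()
paddle.set_device(f"npu:{rank}")   # select the custom device (name is plugin-specific)
dist.init_parallel_env()

x = paddle.ones([4], dtype="float32") * (rank + 1)
# Sums contributions onto rank 0; on a custom device this dispatches to
# ProcessGroupCustom::Reduce, which calls phi::DeviceManager::CCLReduce.
dist.reduce(x, dst=0, op=dist.ReduceOp.SUM)
if rank == 0:
    print(x.numpy())
```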
@@ -163,6 +163,12 @@ class ProcessGroupCustom : public ProcessGroupWithStream {
std::shared_ptr<ProcessGroup::Task> Recv(
std::vector<phi::DenseTensor>& tensors, int src_rank) override;
std::shared_ptr<ProcessGroup::Task> Reduce(phi::DenseTensor* out_tensor,
const phi::DenseTensor& in_tensor,
const ReduceOptions& opts,
bool sync_op,
bool use_calc_stream) override;
protected:
virtual std::shared_ptr<ProcessGroupCustom::CustomTask> CreateTask(
std::vector<Place> places,
......
@@ -29,6 +29,10 @@ namespace pybind {
void BindCustomDevicePy(py::module *m_ptr) {
auto &m = *m_ptr;
// Bind Methods
m.def("_get_device_min_chunk_size", [](const std::string &device_type) {
auto place = paddle::platform::CustomPlace(device_type);
return phi::DeviceManager::GetMinChunkSize(place);
});
m.def(
"_get_device_total_memory",
[](const std::string &device_type, int device_id) {
......
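The new `_get_device_min_chunk_size` binding exposes the runtime's minimum chunk size to Python; the sharding code below uses it as the buffer alignment for custom devices instead of the hard-coded `alignment` table. A minimal sketch of the query, with `custom_cpu` as an illustrative plugin name:

```python
# Sketch: pick a buffer alignment for the configured device.
# "custom_cpu" is an illustrative plugin name, not part of this commit.
from paddle.framework import core

device = "custom_cpu"
if device in core.get_all_custom_device_type():
    # binding added in this commit: min chunk size reported by the plugin
    device_alignment = core.libpaddle._get_device_min_chunk_size(device)
else:
    device_alignment = 256  # CUDA-style 256-byte alignment

print(f"{device}: aligning shard buffers to {device_alignment} bytes")
```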
@@ -82,8 +82,10 @@ class GroupShardedOptimizerStage2(Optimizer):
super().__init__(learning_rate=optim._learning_rate, parameters=params)
assert (
- core.is_compiled_with_cuda() or core.is_compiled_with_xpu()
- ), "Only GPU and XPU is supported now"
+ core.is_compiled_with_cuda()
+ or core.is_compiled_with_xpu()
+ or (device in core.get_all_custom_device_type())
+ ), "Only GPU and XPU and CustomDevice is supported now"
# Segmentation information
self._dtype_rank_params = (
@@ -371,6 +373,13 @@ class GroupShardedOptimizerStage2(Optimizer):
Count the memory size of the parameters corresponding to rank under the corresponding dtype.
"""
# CUDA alignment 256 bytes
if self._default_device in core.get_all_custom_device_type():
device_alignment = core.libpaddle._get_device_min_chunk_size(
self._default_device
)
else:
device_alignment = alignment[self._default_device]
if len(self._rank_buffer_size) == 0:
for dtype in self.dtype_rank_params.keys():
if dtype not in self._rank_buffer_size.keys():
@@ -384,11 +393,11 @@
if not param.trainable:
continue
size = param._numel() * align[dtype]
- remaining = size % alignment[self._default_device]
+ remaining = size % device_alignment
ali = (
0
if remaining == 0
- else alignment[self._default_device] - remaining
+ else device_alignment - remaining
)
align_ = ali // align[dtype]
self._rank_buffer_size[dtype][dst_rank] += (
@@ -439,14 +448,17 @@
if self.offload:
self._optim._master_weights = self._master_params
cpu_master_params = list(self._master_params.values())
if self._default_device in core.get_all_custom_device_type():
device_alignment = core.libpaddle._get_device_min_chunk_size(
self._default_device
)
else:
device_alignment = alignment[self._default_device]
for param in cpu_master_params:
size = param._numel() * align[Type.fp32.value]
- remaining = size % alignment[self.offload_device]
- ali = (
- 0
- if remaining == 0
- else alignment[self.offload_device] - remaining
- )
+ remaining = size % device_alignment
+ ali = 0 if remaining == 0 else device_alignment - remaining
align_ = ali // align[Type.fp32.value]
self.offload_buffer_size += param._numel() + align_
self.offload_param2align[param.name] = align_
@@ -528,11 +540,26 @@ class GroupShardedOptimizerStage2(Optimizer):
for param in self._local_params:
if param.name in self._master_params.keys():
- param.set_value(
- self._master_params[param.name]
- .cuda(self.dev_id)
- .cast(dtype=param.dtype)
- )
+ if (
+ self._default_device
+ in core.get_all_custom_device_type()
+ ):
+ param.set_value(
+ self._master_params[param.name]
+ ._copy_to(
+ paddle.CustomPlace(
+ self._default_device, self.dev_id
+ ),
+ True,
+ )
+ .cast(dtype=param.dtype)
+ )
+ else:
+ param.set_value(
+ self._master_params[param.name]
+ .cuda(self.dev_id)
+ .cast(dtype=param.dtype)
+ )
else:
self._optim.step()
......
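Both stage-2 code paths above size their rank buffers with the same padding rule: round each parameter's byte size up to a multiple of the device alignment, then convert the padding back into elements of the parameter's dtype. A small worked sketch of that arithmetic (the helper name and values are illustrative):

```python
# Sketch of the buffer-padding arithmetic used when sizing per-rank buffers.
def padded_numel(numel, bytes_per_elem, device_alignment):
    size = numel * bytes_per_elem              # parameter size in bytes
    remaining = size % device_alignment
    ali = 0 if remaining == 0 else device_alignment - remaining
    align_ = ali // bytes_per_elem             # padding expressed in elements
    return numel + align_

# e.g. 1000 fp16 elements (2 bytes each) with a 256-byte alignment:
# 2000 bytes -> pad by 48 bytes -> 24 extra elements -> 1024 elements total.
assert padded_numel(1000, 2, 256) == 1024
```

For custom devices the `device_alignment` comes from `_get_device_min_chunk_size`; for GPU/XPU it stays the hard-coded value from the `alignment` table.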
@@ -89,7 +89,10 @@ class GroupShardedStage3(nn.Layer):
super().__init__()
# Default configs
- assert core.is_compiled_with_cuda(), "Only support CUDA."
+ assert core.is_compiled_with_cuda() or (
+ device in core.get_all_custom_device_type()
+ ), "Only support CUDA / CustomDevice."
self._layer = layer
self._default_device = device
self.__sync_buffers = sync_buffers
@@ -243,7 +246,15 @@ class GroupShardedStage3(nn.Layer):
else:
for param in list(self._unslice_params):
param.clear_gradient(False)
- tmp_var = param.cuda(DEV_ID)
+ if (
+ self._default_device
+ in paddle.device.get_all_custom_device_type()
+ ):
+ tmp_var = param._copy_to(
+ paddle.CustomPlace(self._default_device, DEV_ID), True
+ )
+ else:
+ tmp_var = param.cuda(DEV_ID)
if (
tmp_var.dtype == Type.fp32.value
@@ -718,10 +729,14 @@ class GroupShardedStage3(nn.Layer):
def _param2align(self, param):
# CUDA alignment 256 bytes
size = param._numel() * align[param.dtype]
- remaining = size % alignment[self._default_device]
- ali = (
- 0 if remaining == 0 else alignment[self._default_device] - remaining
- )
+ if self._default_device in core.get_all_custom_device_type():
+ device_alignment = core.libpaddle._get_device_min_chunk_size(
+ self._default_device
+ )
+ else:
+ device_alignment = alignment[self._default_device]
+ remaining = size % device_alignment
+ ali = 0 if remaining == 0 else device_alignment - remaining
align_ = ali // align[param.dtype]
return align_
@@ -1095,7 +1110,10 @@ def _device2cpu(trans_param, convert_dtype=False):
def _cpu2device(param):
- tmp_p = param.fw_storage.cuda(DEV_ID)
+ if DEV in paddle.device.get_all_custom_device_type():
+ tmp_p = param.fw_storage._copy_to(paddle.CustomPlace(DEV, DEV_ID), True)
+ else:
+ tmp_p = param.fw_storage.cuda(DEV_ID)
if (
tmp_p.dtype == Type.fp32.value
and param2dtype[param.name] == Type.fp16.value
......
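Stage 3 repeats one dispatch pattern in several places: tensors that were previously moved with `.cuda(dev_id)` are moved with a blocking `_copy_to` onto a `paddle.CustomPlace` when the configured device is a registered custom type. A condensed sketch of that pattern, with a hypothetical helper name:

```python
# Sketch of the device-dispatch pattern used throughout the stage-3 changes.
# `to_configured_device` is a hypothetical helper, not part of this commit.
import paddle
from paddle.framework import core


def to_configured_device(tensor, device, dev_id):
    if device in core.get_all_custom_device_type():
        # blocking copy onto the custom device
        return tensor._copy_to(paddle.CustomPlace(device, dev_id), True)
    # original CUDA path
    return tensor.cuda(dev_id)
```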
@@ -76,11 +76,16 @@ class InternalStorage:
), "Conversion type is not supported now"
if self._device != device:
- tmp_buffer = (
- cvt_to_device(self.buffer, self.dev_id)
- if device in ["gpu", "xpu"]
- else self.buffer.cpu()
- )
+ if device in paddle.device.get_all_custom_device_type():
+ tmp_buffer = self.buffer._copy_to(
+ paddle.CustomPlace(device, self.dev_id), True
+ )
+ else:
+ tmp_buffer = (
+ cvt_to_device(self.buffer, self.dev_id)
+ if device in ["gpu", "xpu"]
+ else self.buffer.cpu()
+ )
for param in self._params:
param.clear_gradient(False)
@@ -133,8 +138,13 @@ class ParamStorage(InternalStorage):
cpu_param_shape.append(p_shape)
if convert_gpu:
- # buffer convert from cpu to cuda
- self.buffer = cvt_to_device(self.buffer, self.dev_id)
+ if self._device in paddle.device.get_all_custom_device_type():
+ self.buffer = self.buffer._copy_to(
+ paddle.CustomPlace(self._device, self.dev_id), True
+ )
+ else:
+ # buffer convert from cpu to cuda
+ self.buffer = cvt_to_device(self.buffer, self.dev_id)
self._fill = 0
......
@@ -162,8 +162,14 @@ class GroupShardedClipGrad:
# add all reduce to get global norm of distributed params_and_grads
dev_id = int(self._device.split(":")[1])
+ dev_type = self._device.split(':')[0]
if paddle.device.get_device() == "cpu":
- global_norm_var = global_norm_var.cuda(dev_id)
+ if dev_type in paddle.device.get_all_custom_device_type():
+ global_norm_var = global_norm_var._copy_to(
+ paddle.CustomPlace(dev_type, dev_id), True
+ )
+ else:
+ global_norm_var = global_norm_var.cuda(dev_id)
with device_guard(dev_id, self._device.split(":")[0]):
paddle.distributed.all_reduce(global_norm_var, group=self._group)
@@ -207,6 +213,8 @@ def device_guard(dev_id=0, device="cpu"):
paddle.set_device(device)
elif device in ["gpu", "xpu"]:
paddle.set_device(f"{device}:{dev_id}")
elif device in paddle.device.get_all_custom_device_type():
paddle.set_device(f"{device}:{dev_id}")
try:
yield
......
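`device_guard` now also recognizes registered custom device types and formats them as `<type>:<id>`, the same way it already handled `gpu`/`xpu`. A standalone sketch of the resulting behavior (the save/restore of the current device mirrors the existing helper; this copy is for illustration only, not the library's import path):

```python
# Standalone sketch of the extended device_guard context manager.
import contextlib
import paddle


@contextlib.contextmanager
def device_guard(dev_id=0, device="cpu"):
    origin_device = paddle.device.get_device()
    if device == "cpu":
        paddle.set_device(device)
    elif device in ["gpu", "xpu"] or device in paddle.device.get_all_custom_device_type():
        # custom device types are now formatted the same way as gpu/xpu
        paddle.set_device(f"{device}:{dev_id}")
    try:
        yield
    finally:
        paddle.set_device(origin_device)
```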