未验证 提交 145cc262 编写于 作者: Roc 提交者: GitHub

[XPU] Support Sharding stage2 on XPU (#48310)

* support xpu scalar inplace

* sharding for xpu
Co-authored-by: heyanru <81976792+heyanru01@users.noreply.github.com>
上级 db7d6808
......@@ -31,7 +31,8 @@ ScalarBase<Tensor>::ScalarBase(const Tensor& tensor_in)
"now Tensor has `%d` elements",
tensor_in.numel()));
auto tensor_in_place = tensor_in.place().GetType();
if (tensor_in_place == phi::AllocationType::GPU) {
if (tensor_in_place == phi::AllocationType::XPU ||
tensor_in_place == phi::AllocationType::GPU) {
Tensor dst_tensor;
copy(tensor_in, phi::CPUPlace(), true, &dst_tensor);
GetDataFromTensor(dst_tensor);
......
......@@ -46,7 +46,7 @@ from .group_sharded_storage import ParamStorage, GradStorage
from .group_sharded_utils import Type, device_guard, GroupShardedClipGrad
# CUDA alignment 256 bytes, cpu alignment 4096 bytes
alignment = {"gpu": 256, "cpu": 4096}
alignment = {"gpu": 256, "cpu": 4096, "xpu": 256}
align = {
Type.fp16.value: 2,
Type.bf16.value: 2,
......@@ -85,7 +85,9 @@ class GroupShardedOptimizerStage2(Optimizer):
):
super().__init__(learning_rate=optim._learning_rate, parameters=params)
assert core.is_compiled_with_cuda(), "Only GPU is supported now"
assert (
core.is_compiled_with_cuda() or core.is_compiled_with_xpu()
), "Only GPU and XPU is supported now"
# Segmentation information
self._dtype_rank_params = (
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册