From 145cc2625846f2e2645b4d95eea894e3b46a4bbb Mon Sep 17 00:00:00 2001
From: Roc <30228238+sljlp@users.noreply.github.com>
Date: Fri, 25 Nov 2022 10:54:42 +0800
Subject: [PATCH] [XPU] Support Sharding stage2 on XPU (#48310)

* support xpu scalar inplace

* sharding for xpu

Co-authored-by: heyanru <81976792+heyanru01@users.noreply.github.com>
---
 paddle/phi/api/lib/scalar.cc                         | 3 ++-
 .../sharding/group_sharded_optimizer_stage2.py       | 6 ++++--
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/paddle/phi/api/lib/scalar.cc b/paddle/phi/api/lib/scalar.cc
index 09b15f629d8..78207c7ae66 100644
--- a/paddle/phi/api/lib/scalar.cc
+++ b/paddle/phi/api/lib/scalar.cc
@@ -31,7 +31,8 @@ ScalarBase::ScalarBase(const Tensor& tensor_in)
                           "now Tensor has `%d` elements",
                           tensor_in.numel()));
   auto tensor_in_place = tensor_in.place().GetType();
-  if (tensor_in_place == phi::AllocationType::GPU) {
+  if (tensor_in_place == phi::AllocationType::XPU ||
+      tensor_in_place == phi::AllocationType::GPU) {
     Tensor dst_tensor;
     copy(tensor_in, phi::CPUPlace(), true, &dst_tensor);
     GetDataFromTensor(dst_tensor);
diff --git a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
index 38b03225616..6a0a0b66cbe 100644
--- a/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
+++ b/python/paddle/distributed/fleet/meta_parallel/sharding/group_sharded_optimizer_stage2.py
@@ -46,7 +46,7 @@ from .group_sharded_storage import ParamStorage, GradStorage
 from .group_sharded_utils import Type, device_guard, GroupShardedClipGrad
 
 # CUDA alignment 256 bytes, cpu alignment 4096 bytes
-alignment = {"gpu": 256, "cpu": 4096}
+alignment = {"gpu": 256, "cpu": 4096, "xpu": 256}
 align = {
     Type.fp16.value: 2,
     Type.bf16.value: 2,
@@ -85,7 +85,9 @@ class GroupShardedOptimizerStage2(Optimizer):
     ):
         super().__init__(learning_rate=optim._learning_rate, parameters=params)
 
-        assert core.is_compiled_with_cuda(), "Only GPU is supported now"
+        assert (
+            core.is_compiled_with_cuda() or core.is_compiled_with_xpu()
+        ), "Only GPU and XPU are supported now"
 
         # Segmentation information
         self._dtype_rank_params = (
--
GitLab
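
For context, a minimal usage sketch (not part of the patch) of the code path this change enables: driving sharding stage 2 on an XPU build through paddle.distributed.sharding.group_sharded_parallel. The toy model, optimizer settings, and launch command below are illustrative assumptions, not taken from the patch:

    # Hypothetical example; assumes an XPU build of Paddle launched in
    # collective mode, e.g.:
    #   python -m paddle.distributed.launch --xpus="0,1" train.py
    import paddle
    from paddle.distributed import fleet
    from paddle.distributed.sharding import group_sharded_parallel

    fleet.init(is_collective=True)
    paddle.set_device("xpu")  # place parameters on the current XPU card

    model = paddle.nn.Linear(1024, 1024)  # toy model (assumption)
    optim = paddle.optimizer.AdamW(
        learning_rate=1e-3, parameters=model.parameters()
    )

    # level="os_g" shards optimizer states and gradients (sharding stage 2);
    # with this patch, the stage-2 optimizer's compile-time check accepts
    # XPU builds in addition to CUDA builds.
    model, optim, scaler = group_sharded_parallel(
        model=model, optimizer=optim, level="os_g"
    )

The scalar.cc hunk complements this: Scalar construction from a device tensor previously special-cased only GPU for the device-to-CPU copy, so XPU-resident scalars (e.g. from in-place scalar ops) now take the same copy-then-read path.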