From c5ae21f43503382520badcbd78aad4d2148561f1 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Thu, 6 May 2021 11:16:27 +0800 Subject: [PATCH] Fix bugs of pipeline on ascend. (#32737) --- paddle/fluid/framework/device_worker.h | 2 +- paddle/fluid/framework/device_worker_factory.cc | 2 +- paddle/fluid/framework/pipeline_trainer.cc | 4 ++-- paddle/fluid/framework/section_worker.cc | 2 +- paddle/fluid/framework/trainer.h | 2 +- paddle/fluid/framework/trainer_factory.cc | 3 ++- paddle/fluid/operators/collective/c_allreduce_op.h | 1 + python/paddle/fluid/framework.py | 4 ++-- python/paddle/fluid/optimizer.py | 2 +- 9 files changed, 12 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index a49e492e480..cd5de19bdc0 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -639,7 +639,7 @@ class PSGPUWorker : public HogwildWorker { #endif #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ - defined(WITH_ASCEND_CL) + defined(PADDLE_WITH_ASCEND_CL) class SectionWorker : public DeviceWorker { public: SectionWorker() {} diff --git a/paddle/fluid/framework/device_worker_factory.cc b/paddle/fluid/framework/device_worker_factory.cc index 5780a953433..fb2323d96e2 100644 --- a/paddle/fluid/framework/device_worker_factory.cc +++ b/paddle/fluid/framework/device_worker_factory.cc @@ -80,7 +80,7 @@ REGISTER_DEVICE_WORKER_CLASS(PSGPUWorker); #endif #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ - defined(WITH_ASCEND_CL) + defined(PADDLE_WITH_ASCEND_CL) REGISTER_DEVICE_WORKER_CLASS(SectionWorker); #endif } // namespace framework diff --git a/paddle/fluid/framework/pipeline_trainer.cc b/paddle/fluid/framework/pipeline_trainer.cc index cdd2dbd5b1d..75c42fa3e52 100644 --- a/paddle/fluid/framework/pipeline_trainer.cc +++ b/paddle/fluid/framework/pipeline_trainer.cc @@ -13,7 +13,7 @@ // limitations under the License. 
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ - defined(WITH_ASCEND_CL) + defined(PADDLE_WITH_ASCEND_CL) #include "paddle/fluid/framework/data_feed_factory.h" #include "paddle/fluid/framework/device_worker_factory.h" #include "paddle/fluid/framework/trainer.h" @@ -37,7 +37,7 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc, int place_id = section_config.place_id(); #if (defined PADDLE_WITH_NCCL) place_ = platform::CUDAPlace(place_id); -#elif (defined WITH_ASCEND_CL) // NOLINT +#elif (defined PADDLE_WITH_ASCEND_CL) // NOLINT place_ = platform::NPUPlace(place_id); #endif worker_ = DeviceWorkerFactory::CreateDeviceWorker( diff --git a/paddle/fluid/framework/section_worker.cc b/paddle/fluid/framework/section_worker.cc index 7860b69313e..00ff50abadd 100644 --- a/paddle/fluid/framework/section_worker.cc +++ b/paddle/fluid/framework/section_worker.cc @@ -10,7 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ - defined(WITH_ASCEND_CL) + defined(PADDLE_WITH_ASCEND_CL) #include <float.h> #include "paddle/fluid/framework/device_worker.h" #include "paddle/fluid/framework/executor_gc_helper.h" diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index 10f6c1ddbd0..3ac36bd2e4a 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -332,7 +332,7 @@ class PSGPUTrainer : public TrainerBase { #endif #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ - defined(WITH_ASCEND_CL) + defined(PADDLE_WITH_ASCEND_CL) class PipelineTrainer : public TrainerBase { public: PipelineTrainer() {} diff --git a/paddle/fluid/framework/trainer_factory.cc b/paddle/fluid/framework/trainer_factory.cc index 6b9dbece897..15073b6f78c 100644 --- a/paddle/fluid/framework/trainer_factory.cc +++ b/paddle/fluid/framework/trainer_factory.cc @@ -76,7 +76,8 @@ REGISTER_TRAINER_CLASS(HeterBoxTrainer); (defined PADDLE_WITH_PSLIB) REGISTER_TRAINER_CLASS(PSGPUTrainer); #endif -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ + defined(PADDLE_WITH_ASCEND_CL) REGISTER_TRAINER_CLASS(PipelineTrainer); #endif } // namespace framework diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h index 0eaa377869e..3a74f551e7a 100644 --- a/paddle/fluid/operators/collective/c_allreduce_op.h +++ b/paddle/fluid/operators/collective/c_allreduce_op.h @@ -131,6 +131,7 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> { int64_t numel = in->numel(); void* sendbuff = reinterpret_cast<void*>(const_cast<T*>(in->data<T>())); + out->mutable_data<T>(in->dims(), ctx.GetPlace()); void* recvbuff = reinterpret_cast<void*>(out->data<T>()); int ring_id = ctx.Attr<int>("ring_id"); diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 0e9d756848a..2eac5adcf22 100644 --- 
a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -6124,9 +6124,9 @@ def device_guard(device=None): device, index = device.split(':') if device == 'cpu': raise ValueError("Should not set device id for cpu.") - if device not in ['cpu', 'gpu', '', None]: + if device not in ['cpu', 'gpu', 'npu', '', None]: raise ValueError( - "The Attr(device) should be 'cpu' or 'gpu', and it can also be empty string or None " + "The Attr(device) should be 'cpu' 'npu' or 'gpu', and it can also be empty string or None " "when there is no need to specify device. But received %s" % device) if index: device = ":".join([device, index]) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 4ae90b3c72c..41b2843ea33 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -4116,7 +4116,7 @@ class PipelineOptimizer(object): device = op.attr(self._op_device_key) \ if op.has_attr(self._op_device_key) else None if device: - assert device[0:3] == 'gpu', "Now, only gpu devices are " \ + assert device[0:3] == 'gpu' or device[0:3] == 'npu', "Now, only gpu and npu devices are " \ "supported in pipeline parallemism." return device -- GitLab