未验证 提交 c5ae21f4 编写于 作者: G gongweibao 提交者: GitHub

Fix bugs of pipeline on ascend. (#32737)

上级 f1c68a08
...@@ -639,7 +639,7 @@ class PSGPUWorker : public HogwildWorker { ...@@ -639,7 +639,7 @@ class PSGPUWorker : public HogwildWorker {
#endif #endif
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(WITH_ASCEND_CL) defined(PADDLE_WITH_ASCEND_CL)
class SectionWorker : public DeviceWorker { class SectionWorker : public DeviceWorker {
public: public:
SectionWorker() {} SectionWorker() {}
......
...@@ -80,7 +80,7 @@ REGISTER_DEVICE_WORKER_CLASS(PSGPUWorker); ...@@ -80,7 +80,7 @@ REGISTER_DEVICE_WORKER_CLASS(PSGPUWorker);
#endif #endif
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(WITH_ASCEND_CL) defined(PADDLE_WITH_ASCEND_CL)
REGISTER_DEVICE_WORKER_CLASS(SectionWorker); REGISTER_DEVICE_WORKER_CLASS(SectionWorker);
#endif #endif
} // namespace framework } // namespace framework
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(WITH_ASCEND_CL) defined(PADDLE_WITH_ASCEND_CL)
#include "paddle/fluid/framework/data_feed_factory.h" #include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h" #include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h" #include "paddle/fluid/framework/trainer.h"
...@@ -37,7 +37,7 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc, ...@@ -37,7 +37,7 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
int place_id = section_config.place_id(); int place_id = section_config.place_id();
#if (defined PADDLE_WITH_NCCL) #if (defined PADDLE_WITH_NCCL)
place_ = platform::CUDAPlace(place_id); place_ = platform::CUDAPlace(place_id);
#elif (defined WITH_ASCEND_CL) // NOLINT #elif (defined PADDLE_WITH_ASCEND_CL) // NOLINT
place_ = platform::NPUPlace(place_id); place_ = platform::NPUPlace(place_id);
#endif #endif
worker_ = DeviceWorkerFactory::CreateDeviceWorker( worker_ = DeviceWorkerFactory::CreateDeviceWorker(
......
...@@ -10,7 +10,7 @@ See the License for the specific language governing permissions and ...@@ -10,7 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(WITH_ASCEND_CL) defined(PADDLE_WITH_ASCEND_CL)
#include <float.h> #include <float.h>
#include "paddle/fluid/framework/device_worker.h" #include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/executor_gc_helper.h" #include "paddle/fluid/framework/executor_gc_helper.h"
......
...@@ -332,7 +332,7 @@ class PSGPUTrainer : public TrainerBase { ...@@ -332,7 +332,7 @@ class PSGPUTrainer : public TrainerBase {
#endif #endif
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \ #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(WITH_ASCEND_CL) defined(PADDLE_WITH_ASCEND_CL)
class PipelineTrainer : public TrainerBase { class PipelineTrainer : public TrainerBase {
public: public:
PipelineTrainer() {} PipelineTrainer() {}
......
...@@ -76,7 +76,8 @@ REGISTER_TRAINER_CLASS(HeterBoxTrainer); ...@@ -76,7 +76,8 @@ REGISTER_TRAINER_CLASS(HeterBoxTrainer);
(defined PADDLE_WITH_PSLIB) (defined PADDLE_WITH_PSLIB)
REGISTER_TRAINER_CLASS(PSGPUTrainer); REGISTER_TRAINER_CLASS(PSGPUTrainer);
#endif #endif
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
defined(PADDLE_WITH_ASCEND_CL)
REGISTER_TRAINER_CLASS(PipelineTrainer); REGISTER_TRAINER_CLASS(PipelineTrainer);
#endif #endif
} // namespace framework } // namespace framework
......
...@@ -131,6 +131,7 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> { ...@@ -131,6 +131,7 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
int64_t numel = in->numel(); int64_t numel = in->numel();
void* sendbuff = reinterpret_cast<void*>(const_cast<T*>(in->data<T>())); void* sendbuff = reinterpret_cast<void*>(const_cast<T*>(in->data<T>()));
out->mutable_data<T>(in->dims(), ctx.GetPlace());
void* recvbuff = reinterpret_cast<void*>(out->data<T>()); void* recvbuff = reinterpret_cast<void*>(out->data<T>());
int ring_id = ctx.Attr<int>("ring_id"); int ring_id = ctx.Attr<int>("ring_id");
......
...@@ -6124,9 +6124,9 @@ def device_guard(device=None): ...@@ -6124,9 +6124,9 @@ def device_guard(device=None):
device, index = device.split(':') device, index = device.split(':')
if device == 'cpu': if device == 'cpu':
raise ValueError("Should not set device id for cpu.") raise ValueError("Should not set device id for cpu.")
if device not in ['cpu', 'gpu', '', None]: if device not in ['cpu', 'gpu', 'npu', '', None]:
raise ValueError( raise ValueError(
"The Attr(device) should be 'cpu' or 'gpu', and it can also be empty string or None " "The Attr(device) should be 'cpu' 'npu' or 'gpu', and it can also be empty string or None "
"when there is no need to specify device. But received %s" % device) "when there is no need to specify device. But received %s" % device)
if index: if index:
device = ":".join([device, index]) device = ":".join([device, index])
......
...@@ -4116,7 +4116,7 @@ class PipelineOptimizer(object): ...@@ -4116,7 +4116,7 @@ class PipelineOptimizer(object):
device = op.attr(self._op_device_key) \ device = op.attr(self._op_device_key) \
if op.has_attr(self._op_device_key) else None if op.has_attr(self._op_device_key) else None
if device: if device:
assert device[0:3] == 'gpu', "Now, only gpu devices are " \ assert device[0:3] == 'gpu' or dev_type == 'npu', "Now, only gpu and npu devices are " \
"supported in pipeline parallemism." "supported in pipeline parallemism."
return device return device
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册