未验证 提交 6a1957e7 编写于 作者: B Baibaifan 提交者: GitHub

solve develop bugs (#32560) (#32684)

上级 2c1ed9b8
...@@ -63,7 +63,6 @@ class CSyncCommStreamCudaKernel : public framework::OpKernel<T> { ...@@ -63,7 +63,6 @@ class CSyncCommStreamCudaKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
int ring_id = ctx.Attr<int>("ring_id"); int ring_id = ctx.Attr<int>("ring_id");
auto stream = auto stream =
platform::NCCLCommContext::Instance().Get(ring_id, place)->stream(); platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();
...@@ -75,7 +74,6 @@ class CSyncCommStreamCudaKernel : public framework::OpKernel<T> { ...@@ -75,7 +74,6 @@ class CSyncCommStreamCudaKernel : public framework::OpKernel<T> {
#endif #endif
#elif defined(PADDLE_WITH_ASCEND_CL) #elif defined(PADDLE_WITH_ASCEND_CL)
auto place = ctx.GetPlace();
PADDLE_ENFORCE_EQ(is_npu_place(place), true, PADDLE_ENFORCE_EQ(is_npu_place(place), true,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Sync stream op can run on npu place only for now.")); "Sync stream op can run on npu place only for now."));
......
...@@ -108,12 +108,14 @@ enum AttrType { ...@@ -108,12 +108,14 @@ enum AttrType {
AT_NAMEATTR AT_NAMEATTR
}; };
#ifdef PADDLE_WITH_ASCEND
void BindAscendDevice(py::module *m) { void BindAscendDevice(py::module *m) {
py::class_<platform::ascend::NPUDevice>(*m, "NPUDevice") py::class_<platform::ascend::NPUDevice>(*m, "NPUDevice")
.def_static( .def_static(
"get_device_count", "get_device_count",
static_cast<int (*)()>(&platform::ascend::NPUDevice::GetDeviceCount)); static_cast<int (*)()>(&platform::ascend::NPUDevice::GetDeviceCount));
} }
#endif
void BindAscendGraph(py::module *m) { void BindAscendGraph(py::module *m) {
m->def("ge_initialize", &ge_initialize, "GEInitialize"); m->def("ge_initialize", &ge_initialize, "GEInitialize");
......
...@@ -325,8 +325,8 @@ def which_distributed_mode(args): ...@@ -325,8 +325,8 @@ def which_distributed_mode(args):
if fluid.core.is_compiled_with_cuda(): if fluid.core.is_compiled_with_cuda():
accelerators = fluid.core.get_cuda_device_count() accelerators = fluid.core.get_cuda_device_count()
elif fluid.core.is_compiled_with_ascend(): elif fluid.core.is_compiled_with_npu():
accelerators = fluid.core.NPUDevice.get_device_count() accelerators = fluid.core.get_npu_device_count()
elif fluid.core.is_compiled_with_xpu(): elif fluid.core.is_compiled_with_xpu():
accelerators = fluid.core.get_xpu_device_count() accelerators = fluid.core.get_xpu_device_count()
else: else:
......
...@@ -653,8 +653,8 @@ def get_xpus(xpus): ...@@ -653,8 +653,8 @@ def get_xpus(xpus):
def get_device_mode(): def get_device_mode():
if fluid.core.is_compiled_with_ascend() and \ if fluid.core.is_compiled_with_npu() and \
fluid.core.NPUDevice.get_device_count() > 0: fluid.core.get_npu_device_count() > 0:
print("launch train in ascend npu mode!") print("launch train in ascend npu mode!")
return DeviceMode.ASCEND_NPU return DeviceMode.ASCEND_NPU
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册