未验证 提交 664199aa 编写于 作者: A Allen Guo 提交者: GitHub

fix running error for ipu (#41533)

cherry from #41481
上级 b5f6d311
...@@ -55,6 +55,8 @@ enum class Backend : uint8_t { ...@@ -55,6 +55,8 @@ enum class Backend : uint8_t {
// paddle kernel primitives backend // paddle kernel primitives backend
KPS, KPS,
IPU,
// end of backend types // end of backend types
NUM_BACKENDS, NUM_BACKENDS,
...@@ -121,6 +123,9 @@ inline std::ostream& operator<<(std::ostream& os, Backend backend) { ...@@ -121,6 +123,9 @@ inline std::ostream& operator<<(std::ostream& os, Backend backend) {
case Backend::KPS: case Backend::KPS:
os << "KPS"; os << "KPS";
break; break;
case Backend::IPU:
os << "IPU";
break;
default: { default: {
size_t device_type_id_ = static_cast<size_t>(backend) - size_t device_type_id_ = static_cast<size_t>(backend) -
static_cast<size_t>(Backend::NUM_BACKENDS); static_cast<size_t>(Backend::NUM_BACKENDS);
...@@ -155,6 +160,8 @@ inline Backend StringToBackend(const char* backend_cstr) { ...@@ -155,6 +160,8 @@ inline Backend StringToBackend(const char* backend_cstr) {
return Backend::GPUDNN; return Backend::GPUDNN;
} else if (s == std::string("KPS")) { } else if (s == std::string("KPS")) {
return Backend::KPS; return Backend::KPS;
} else if (s == std::string("IPU")) {
return Backend::IPU;
} else { } else {
return static_cast<Backend>(static_cast<size_t>(Backend::NUM_BACKENDS) + return static_cast<Backend>(static_cast<size_t>(Backend::NUM_BACKENDS) +
phi::GetOrRegisterGlobalDeviceTypeId(s)); phi::GetOrRegisterGlobalDeviceTypeId(s));
......
...@@ -38,6 +38,8 @@ Backend TransToPhiBackend(const phi::Place& place) { ...@@ -38,6 +38,8 @@ Backend TransToPhiBackend(const phi::Place& place) {
return Backend::XPU; return Backend::XPU;
} else if (allocation_type == phi::AllocationType::NPU) { } else if (allocation_type == phi::AllocationType::NPU) {
return Backend::NPU; return Backend::NPU;
} else if (allocation_type == phi::AllocationType::IPU) {
return Backend::IPU;
} else if (allocation_type == phi::AllocationType::CUSTOM) { } else if (allocation_type == phi::AllocationType::CUSTOM) {
return static_cast<Backend>( return static_cast<Backend>(
static_cast<size_t>(Backend::NUM_BACKENDS) + static_cast<size_t>(Backend::NUM_BACKENDS) +
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册