diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index 87201c93c752e94fc4418cc92ce111e2d4a28e34..c34e727486bf29c088b035f6a327684c875a5f8f 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -41,6 +41,9 @@ limitations under the License. */ #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/utils/any.h" +#ifdef PADDLE_WITH_CUSTOM_DEVICE +#include "paddle/phi/backends/device_manager.h" +#endif namespace paddle { namespace framework { @@ -712,6 +715,19 @@ static void RegisterOperatorKernel(const std::string& name, RegisterOperatorKernelWithPlace( name, op_kernel_func, proto::VarType::RAW, platform::XPUPlace()); #endif +#ifdef PADDLE_WITH_CUSTOM_DEVICE + auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes(); + for (const auto& dev_type : device_types) { + for (size_t dev_id = 0; + dev_id < phi::DeviceManager::GetDeviceCount(dev_type); + dev_id++) { + RegisterOperatorKernelWithPlace(name, + op_kernel_func, + proto::VarType::RAW, + platform::CustomPlace(dev_type, dev_id)); + } + } +#endif } void RegisterOperatorWithMetaInfo(const std::vector& op_meta_infos,