未验证 提交 85e92421 编写于 作者: H HongyuJia 提交者: GitHub

support custom_device (#49222)

上级 047cd95c
...@@ -41,6 +41,9 @@ limitations under the License. */ ...@@ -41,6 +41,9 @@ limitations under the License. */
#include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/tensor_utils.h"
#include "paddle/utils/any.h" #include "paddle/utils/any.h"
#ifdef PADDLE_WITH_CUSTOM_DEVICE
#include "paddle/phi/backends/device_manager.h"
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -712,6 +715,19 @@ static void RegisterOperatorKernel(const std::string& name, ...@@ -712,6 +715,19 @@ static void RegisterOperatorKernel(const std::string& name,
RegisterOperatorKernelWithPlace( RegisterOperatorKernelWithPlace(
name, op_kernel_func, proto::VarType::RAW, platform::XPUPlace()); name, op_kernel_func, proto::VarType::RAW, platform::XPUPlace());
#endif #endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto device_types = phi::DeviceManager::GetAllCustomDeviceTypes();
for (const auto& dev_type : device_types) {
for (size_t dev_id = 0;
dev_id < phi::DeviceManager::GetDeviceCount(dev_type);
dev_id++) {
RegisterOperatorKernelWithPlace(name,
op_kernel_func,
proto::VarType::RAW,
platform::CustomPlace(dev_type, dev_id));
}
}
#endif
} }
void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos, void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册