diff --git a/paddle/fluid/framework/executor_cache.cc b/paddle/fluid/framework/executor_cache.cc index 1ce9db6294050bd9b8ef9f2f088fa9c3ee1d7a2a..ae02f4fbfb822c8b9977e9aaf5009d3e5f31128e 100644 --- a/paddle/fluid/framework/executor_cache.cc +++ b/paddle/fluid/framework/executor_cache.cc @@ -54,6 +54,10 @@ static ExecutionStrategy GetExecutionStrategy(const platform::Place &place) { execution_strategy.num_threads_ = 1; break; } + case platform::DeviceType::CUSTOM_DEVICE: { + execution_strategy.num_threads_ = 1; + break; + } default: PADDLE_THROW(platform::errors::Unavailable("Unsupported Device type %d.", device_type)); diff --git a/paddle/fluid/framework/op_registry.h b/paddle/fluid/framework/op_registry.h index 2befc70b2d5ed5f93f86faf939842dee6f575cf7..535480602916a7d18c80e212138908fa88333bdf 100644 --- a/paddle/fluid/framework/op_registry.h +++ b/paddle/fluid/framework/op_registry.h @@ -171,6 +171,17 @@ inline void RegisterKernelClass(const char* op_type, if (library == "MKLDNN") { data_layout = "MKLDNNLAYOUT"; } +#ifdef PADDLE_WITH_CUSTOM_DEVICE + if (std::is_same::value) { + OpKernelType key(ToDataType(std::type_index(typeid(T))), + platform::CustomPlace(library_type), + StringToDataLayout(data_layout), + LibraryType::kPlain, + customized_type_value); + OperatorWithKernel::AllOpKernels()[op_type][key] = func; + return; + } +#endif OpKernelType key(ToDataType(std::type_index(typeid(T))), PlaceType(), StringToDataLayout(data_layout), diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 39faf87406d5899b6bbdb69de18dd665557d776b..d985baf8c9088b3a07ddcd37dba0b1db8c34e29f 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -254,3 +254,7 @@ cc_test(copy_cross_scope_test SRCS copy_cross_scope_test.cc DEPS op_registry cop endif() copy_if_different(${pybind_file} ${pybind_file_final}) + +if (WITH_CUSTOM_DEVICE) +cc_library(custom_device_common_op_registry SRCS custom_device_common_op_registry.cc DEPS operator) +endif() diff --git a/paddle/fluid/operators/custom_device_common_op_registry.cc b/paddle/fluid/operators/custom_device_common_op_registry.cc new file mode 100644 index 0000000000000000000000000000000000000000..704d85acf13621fd33b852176c59eddd2b72f4ce --- /dev/null +++ b/paddle/fluid/operators/custom_device_common_op_registry.cc @@ -0,0 +1,60 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/custom_device_common_op_registry.h" +#include "paddle/fluid/operators/run_program_op.h" +#include "paddle/fluid/operators/save_combine_op.h" +#include "paddle/phi/backends/device_manager.h" + +#define REGISTER_OP_CUSTOM_DEVICE_KERNEL(op_type, dev_type, ...) \ + static paddle::framework::OpKernelRegistrar \ + __op_custom_device_kernel_registrar_##op_type##_##__acosf##__( \ + #op_type, \ + dev_type, \ + paddle::framework::OpKernelType::kDefaultCustomizedTypeValue); \ + __op_custom_device_kernel_registrar_##op_type##_##__acosf##__.Touch(); + +namespace paddle { +namespace operators { + +void RegisterCustomDeviceCommonKernel(const std::string& dev_type) { + auto device_type = dev_type.c_str(); + /* see [Why use single type kernel] */ + REGISTER_OP_CUSTOM_DEVICE_KERNEL( + run_program, + device_type, + paddle::operators:: + RunProgramOpKernel); + REGISTER_OP_CUSTOM_DEVICE_KERNEL( + run_program_grad, + device_type, + paddle::operators :: + RunProgramGradOpKernel); + REGISTER_OP_CUSTOM_DEVICE_KERNEL( + save_combine, + device_type, + paddle::operators :: + SaveCombineOpKernel, + paddle::operators :: + SaveCombineOpKernel, + paddle::operators :: + SaveCombineOpKernel, + paddle::operators :: + SaveCombineOpKernel); +} + +} // namespace operators +} // namespace paddle + +#undef REGISTER_OP_CUSTOM_DEVICE_KERNEL diff --git a/paddle/fluid/operators/custom_device_common_op_registry.h b/paddle/fluid/operators/custom_device_common_op_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..421c745c536845bcfe865513c862e732306d42d0 --- /dev/null +++ b/paddle/fluid/operators/custom_device_common_op_registry.h @@ -0,0 +1,29 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#ifdef PADDLE_WITH_CUSTOM_DEVICE + +#include + +namespace paddle { +namespace operators { + +void RegisterCustomDeviceCommonKernel(const std::string& device_type); + +} // namespace operators +} // namespace paddle + +#endif diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index e2fec11c190d3357f8a2df56153aea5e5bb9f276..f7c715d7905a52d9f24077e42db9ccabb1a0cfd3 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -65,6 +65,8 @@ DeviceType Place2DeviceType(const platform::Place& place) { return platform::DeviceType::NPU; } else if (platform::is_mlu_place(place)) { return platform::DeviceType::MLU; + } else if (platform::is_custom_place(place)) { + return platform::DeviceType::CUSTOM_DEVICE; } else { PADDLE_THROW(platform::errors::Unavailable( "Unsupported place %s to convert into platform::DeviceType.", place)); diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 4b8833f9a6cd6d6478ad03dd05db40c7c5a7a2c4..7939f8ff7c066062036d31f709d9ac0b5d0e768a 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -117,8 +117,9 @@ enum DeviceType { XPU = 3, IPU = 4, MLU = 5, + CUSTOM_DEVICE = 6, - MAX_DEVICE_TYPES = 6, + MAX_DEVICE_TYPES = 7, }; DeviceType Place2DeviceType(const platform::Place& place); @@ -129,6 +130,7 @@ constexpr DeviceType kXPU = DeviceType::XPU; constexpr DeviceType kNPU = DeviceType::NPU; constexpr DeviceType kIPU = DeviceType::IPU; constexpr DeviceType kMLU = DeviceType::MLU; +constexpr DeviceType kCUSOTM_DEVICE = DeviceType::CUSTOM_DEVICE; using DeviceContext = phi::DeviceContext; diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 6aa2e7394c90bd54a526a276ce0407342503f5cb..b784affc07e7ec5d278541bd066db281e4ca7ec5 100755 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -137,6 +137,7 @@ set(PYBIND_SRCS if(WITH_CUSTOM_DEVICE) set(PYBIND_DEPS ${PYBIND_DEPS} phi_capi) + set(PYBIND_DEPS ${PYBIND_DEPS} custom_device_common_op_registry) endif() if(NOT ON_INFER) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 3e079531f89bf299fa56a29f2441c0800931d9b6..b71595373480e54caffc1e2ea005ad23c85b101c 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -155,6 +155,7 @@ limitations under the License. */ #endif #ifdef PADDLE_WITH_CUSTOM_DEVICE +#include "paddle/fluid/operators/custom_device_common_op_registry.h" #include "paddle/phi/capi/capi.h" #endif @@ -1694,7 +1695,14 @@ All parameter, weight, gradient are variables in Paddle. egr::Controller::Instance().MergeOpMetaInfoMap( framework::LoadOpMetaInfoAndRegisterOp(dso_name)); }); - m.def("init_devices", []() { framework::InitDevices(); }); + m.def("init_devices", []() { + framework::InitDevices(); +#ifdef PADDLE_WITH_CUSTOM_DEVICE + for (auto &dev_type : phi::DeviceManager::GetAllCustomDeviceTypes()) { + paddle::operators::RegisterCustomDeviceCommonKernel(dev_type); + } +#endif + }); m.def("init_default_kernel_signatures", []() { framework::InitDefaultKernelSignatureMap(); }); m.def("is_compiled_with_cuda", IsCompiledWithCUDA);