Unverified commit abc85c50, authored by ronnywang, committed by GitHub

[CustomDevice] add dy2static support (#45878)

* [CustomDevice] add dy2static support

* update
Parent d0096eaf
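To illustrate what this change enables from the Python side, here is a minimal usage sketch (not part of this diff). It assumes a CustomDevice plug-in is installed; the device name "custom_cpu" is only a placeholder for whatever name that plug-in registers.

import paddle

# Select the plug-in device; "custom_cpu" is a placeholder plug-in name.
paddle.set_device("custom_cpu")

@paddle.jit.to_static
def forward(x):
    return paddle.nn.functional.relu(x) * 2.0

x = paddle.randn([4, 8])
# The dynamic graph is converted to a static program, which executes through
# the run_program kernels this commit registers for custom devices.
y = forward(x)
print(y.shape)
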
@@ -54,6 +54,10 @@ static ExecutionStrategy GetExecutionStrategy(const platform::Place &place) {
       execution_strategy.num_threads_ = 1;
       break;
     }
+    case platform::DeviceType::CUSTOM_DEVICE: {
+      execution_strategy.num_threads_ = 1;
+      break;
+    }
     default:
       PADDLE_THROW(platform::errors::Unavailable("Unsupported Device type %d.",
                                                  device_type));
...

@@ -171,6 +171,17 @@ inline void RegisterKernelClass(const char* op_type,
   if (library == "MKLDNN") {
     data_layout = "MKLDNNLAYOUT";
   }
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+  if (std::is_same<PlaceType, platform::CustomPlace>::value) {
+    OpKernelType key(ToDataType(std::type_index(typeid(T))),
+                     platform::CustomPlace(library_type),
+                     StringToDataLayout(data_layout),
+                     LibraryType::kPlain,
+                     customized_type_value);
+    OperatorWithKernel::AllOpKernels()[op_type][key] = func;
+    return;
+  }
+#endif
   OpKernelType key(ToDataType(std::type_index(typeid(T))),
                    PlaceType(),
                    StringToDataLayout(data_layout),
...

@@ -254,3 +254,7 @@ cc_test(copy_cross_scope_test SRCS copy_cross_scope_test.cc DEPS op_registry cop
 endif()
 copy_if_different(${pybind_file} ${pybind_file_final})
+
+if (WITH_CUSTOM_DEVICE)
+  cc_library(custom_device_common_op_registry SRCS custom_device_common_op_registry.cc DEPS operator)
+endif()

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/custom_device_common_op_registry.h"
#include "paddle/fluid/operators/run_program_op.h"
#include "paddle/fluid/operators/save_combine_op.h"
#include "paddle/phi/backends/device_manager.h"

#define REGISTER_OP_CUSTOM_DEVICE_KERNEL(op_type, dev_type, ...)             \
  static paddle::framework::OpKernelRegistrar<phi::CustomPlace, __VA_ARGS__> \
      __op_custom_device_kernel_registrar_##op_type##_##__acosf##__(         \
          #op_type,                                                          \
          dev_type,                                                          \
          paddle::framework::OpKernelType::kDefaultCustomizedTypeValue);     \
  __op_custom_device_kernel_registrar_##op_type##_##__acosf##__.Touch();

namespace paddle {
namespace operators {

void RegisterCustomDeviceCommonKernel(const std::string& dev_type) {
  auto device_type = dev_type.c_str();
  /* see [Why use single type kernel] */
  REGISTER_OP_CUSTOM_DEVICE_KERNEL(
      run_program,
      device_type,
      paddle::operators::
          RunProgramOpKernel<paddle::platform::CustomDeviceContext, float>);
  REGISTER_OP_CUSTOM_DEVICE_KERNEL(
      run_program_grad,
      device_type,
      paddle::operators::
          RunProgramGradOpKernel<paddle::platform::CustomDeviceContext, float>);
  REGISTER_OP_CUSTOM_DEVICE_KERNEL(
      save_combine,
      device_type,
      paddle::operators::
          SaveCombineOpKernel<paddle::platform::CustomDeviceContext, float>,
      paddle::operators::
          SaveCombineOpKernel<paddle::platform::CustomDeviceContext, double>,
      paddle::operators::
          SaveCombineOpKernel<paddle::platform::CustomDeviceContext, int>,
      paddle::operators::
          SaveCombineOpKernel<paddle::platform::CustomDeviceContext, int64_t>);
}

}  // namespace operators
}  // namespace paddle

#undef REGISTER_OP_CUSTOM_DEVICE_KERNEL

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#ifdef PADDLE_WITH_CUSTOM_DEVICE

#include <string>

namespace paddle {
namespace operators {

void RegisterCustomDeviceCommonKernel(const std::string& device_type);

}  // namespace operators
}  // namespace paddle

#endif

@@ -65,6 +65,8 @@ DeviceType Place2DeviceType(const platform::Place& place) {
     return platform::DeviceType::NPU;
   } else if (platform::is_mlu_place(place)) {
     return platform::DeviceType::MLU;
+  } else if (platform::is_custom_place(place)) {
+    return platform::DeviceType::CUSTOM_DEVICE;
   } else {
     PADDLE_THROW(platform::errors::Unavailable(
         "Unsupported place %s to convert into platform::DeviceType.", place));
...

@@ -117,8 +117,9 @@ enum DeviceType {
   XPU = 3,
   IPU = 4,
   MLU = 5,
+  CUSTOM_DEVICE = 6,
-  MAX_DEVICE_TYPES = 6,
+  MAX_DEVICE_TYPES = 7,
 };
 DeviceType Place2DeviceType(const platform::Place& place);

@@ -129,6 +130,7 @@ constexpr DeviceType kXPU = DeviceType::XPU;
 constexpr DeviceType kNPU = DeviceType::NPU;
 constexpr DeviceType kIPU = DeviceType::IPU;
 constexpr DeviceType kMLU = DeviceType::MLU;
+constexpr DeviceType kCUSTOM_DEVICE = DeviceType::CUSTOM_DEVICE;
 using DeviceContext = phi::DeviceContext;
...

@@ -137,6 +137,7 @@ set(PYBIND_SRCS
 if(WITH_CUSTOM_DEVICE)
   set(PYBIND_DEPS ${PYBIND_DEPS} phi_capi)
+  set(PYBIND_DEPS ${PYBIND_DEPS} custom_device_common_op_registry)
 endif()
 if(NOT ON_INFER)
...

@@ -155,6 +155,7 @@ limitations under the License. */
 #endif
 #ifdef PADDLE_WITH_CUSTOM_DEVICE
+#include "paddle/fluid/operators/custom_device_common_op_registry.h"
 #include "paddle/phi/capi/capi.h"
 #endif

@@ -1694,7 +1695,14 @@ All parameter, weight, gradient are variables in Paddle.
     egr::Controller::Instance().MergeOpMetaInfoMap(
         framework::LoadOpMetaInfoAndRegisterOp(dso_name));
   });
-  m.def("init_devices", []() { framework::InitDevices(); });
+  m.def("init_devices", []() {
+    framework::InitDevices();
+#ifdef PADDLE_WITH_CUSTOM_DEVICE
+    for (auto &dev_type : phi::DeviceManager::GetAllCustomDeviceTypes()) {
+      paddle::operators::RegisterCustomDeviceCommonKernel(dev_type);
+    }
+#endif
+  });
   m.def("init_default_kernel_signatures",
         []() { framework::InitDefaultKernelSignatureMap(); });
   m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
...

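As a quick sanity check that the custom device registered above is visible from Python, one might run the sketch below. This is not part of the diff; it assumes a plug-in is installed and that the listing call is paddle.device.get_all_custom_device_type().

import paddle

# Expected to print the device types contributed by installed CustomDevice
# plug-ins, e.g. ['custom_cpu']; an empty list means no plug-in was found.
print(paddle.device.get_all_custom_device_type())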