未验证 提交 50d92531 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] registering feed_dense_tensor, feed_sparse_coo_tensor,...

[CustomDevice] registering feed_dense_tensor, feed_sparse_coo_tensor, feed_strings kernels for custom device (#50042)

* [CustomDevice] registering feed_dense_tensor, feed_sparse_coo_tensor, feed_strings kernels for custom device

* update

* update

* update
上级 decbb588
......@@ -305,24 +305,15 @@ PD_REGISTER_GENERAL_KERNEL(
ALL_LAYOUT,
paddle::operators::FeedStringsKernel<phi::CustomContext>,
ALL_DTYPE) {}
#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
PD_REGISTER_GENERAL_KERNEL(
feed_dense_tensor,
custom_cpu,
ALL_LAYOUT,
paddle::operators::FeedDenseTensorKernel<phi::CustomContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_sparse_coo_tensor,
custom_cpu,
ALL_LAYOUT,
paddle::operators::FeedSparseCooTensorKernel<phi::CustomContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
feed_strings,
custom_cpu,
ALL_LAYOUT,
paddle::operators::FeedStringsKernel<phi::CustomContext>,
ALL_DTYPE) {}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
namespace paddle {
namespace operators {
template void FeedDenseTensorKernel<phi::CustomContext>(
const phi::CustomContext& dev_ctx,
const phi::ExtendedTensor& x,
int col,
phi::DenseTensor* out);
} // namespace operators
} // namespace paddle
#endif
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/operators/run_program_op.h"
#include "paddle/fluid/operators/save_combine_op.h"
#include "paddle/phi/backends/device_manager.h"
#include "paddle/phi/core/kernel_registry.h"
#define REGISTER_OP_CUSTOM_DEVICE_KERNEL(op_type, dev_type, ...) \
static paddle::framework::OpKernelRegistrar<phi::CustomPlace, __VA_ARGS__> \
......@@ -26,10 +27,30 @@ limitations under the License. */
paddle::framework::OpKernelType::kDefaultCustomizedTypeValue); \
__op_custom_device_kernel_registrar_##op_type##_##__acosf##__.Touch();
#define REGISTER_CUSTOM_DEVICE_GENERAL_KERNEL( \
kernel_name, dev_type, layout, kernel_fn) \
static phi::KernelRegistrar \
__reg_custom_device_phi_kernel_##kernel_name##_##backend##_##layout( \
phi::RegType::INNER, \
#kernel_name, \
dev_type, \
DATALAYOUT(layout), \
::phi::KernelArgsParseFunctor<decltype(&kernel_fn)>::Parse, \
[](const phi::KernelKey& kernel_key, phi::Kernel* kernel) {}, \
PHI_KERNEL(kernel_fn), \
PHI_VARIADIC_KERNEL(kernel_fn))
namespace paddle {
namespace operators {
template <typename Context>
void FeedDenseTensorKernel(const Context& dev_ctx,
const phi::ExtendedTensor& x,
int col,
phi::DenseTensor* out);
void RegisterCustomDeviceCommonKernel(const std::string& dev_type) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
auto device_type = dev_type.c_str();
/* see [Why use single type kernel] */
REGISTER_OP_CUSTOM_DEVICE_KERNEL(
......@@ -66,9 +87,16 @@ void RegisterCustomDeviceCommonKernel(const std::string& dev_type) {
LoadCombineOpKernel<paddle::platform::CustomDeviceContext, int8_t>,
paddle::operators::
LoadCombineOpKernel<paddle::platform::CustomDeviceContext, int64_t>);
REGISTER_CUSTOM_DEVICE_GENERAL_KERNEL(
feed_dense_tensor,
device_type,
ALL_LAYOUT,
paddle::operators::FeedDenseTensorKernel<phi::CustomContext>);
#endif
}
} // namespace operators
} // namespace paddle
#undef REGISTER_OP_CUSTOM_DEVICE_KERNEL
#undef REGISTER_CUSTOM_DEVICE_GENERAL_KERNEL
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册