Unverified commit 49677636, authored by Aganlengzi, committed by GitHub

[custom kernel] change kernel name judgement and remove macro control for selected_row (#39977)

Parent: 5471d162
@@ -20,16 +20,16 @@ void RegisterCustomKernels(const CustomKernelMap& custom_kernel_map) {
   auto& kernel_info_map = custom_kernel_map.GetMap();
   VLOG(3) << "Size of custom_kernel_map: " << kernel_info_map.size();

+  auto& kernels = KernelFactory::Instance().kernels();
   for (auto& pair : kernel_info_map) {
-    PADDLE_ENFORCE_EQ(
-        KernelFactory::Instance().HasCompatiblePhiKernel(pair.first),
-        true,
+    PADDLE_ENFORCE_NE(
+        kernels.find(pair.first),
+        kernels.end(),
         phi::errors::InvalidArgument(
             "The kernel %s is not ready for custom kernel registering.",
             pair.first));

     for (auto& info_pair : pair.second) {
-      auto& kernels = KernelFactory::Instance().kernels();
       PADDLE_ENFORCE_EQ(
           kernels[pair.first].find(info_pair.first),
           kernels[pair.first].end(),
...
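The judgement change above swaps `HasCompatiblePhiKernel` for a direct membership test on the factory's kernel map, and hoists the `kernels` reference out of the inner loop. A minimal sketch of the new lookup semantics, using a plain map in place of the real `KernelFactory` (the names below are illustrative, not the Paddle API):

```cpp
#include <iostream>
#include <string>
#include <unordered_map>

// Stand-in for KernelFactory::kernels(): op name -> {kernel key -> kernel}.
using KernelMap =
    std::unordered_map<std::string, std::unordered_map<std::string, int>>;

int main() {
  KernelMap kernels;
  kernels["dot"]["CPU-ANY-FLOAT32"] = 1;

  // New judgement: the op name itself must already be present
  // (kernels.find(name) != kernels.end()) before a custom kernel
  // may be registered under it.
  for (const auto& name : {std::string("dot"), std::string("fake_dot")}) {
    bool ready = kernels.find(name) != kernels.end();
    std::cout << name << (ready ? " is ready" : " is not ready")
              << " for custom kernel registering\n";
  }
  return 0;
}
```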
@@ -87,13 +87,11 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
                               default_tensor_layout,
                               default_key.dtype(),
                               arg_type);
-#ifndef PADDLE_WITH_CUSTOM_KERNEL
     } else if (arg_type == std::type_index(typeid(const SelectedRows&))) {
       args_def->AppendInput(default_key.backend(),
                             default_tensor_layout,
                             default_key.dtype(),
                             arg_type);
-#endif
     } else if (arg_type == std::type_index(typeid(DenseTensor*))) {
       args_def->AppendOutput(default_key.backend(),
                              default_tensor_layout,
@@ -105,13 +103,11 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
                              default_tensor_layout,
                              default_key.dtype(),
                              arg_type);
-#ifndef PADDLE_WITH_CUSTOM_KERNEL
     } else if (arg_type == std::type_index(typeid(SelectedRows*))) {
       args_def->AppendOutput(default_key.backend(),
                              default_tensor_layout,
                              default_key.dtype(),
                              arg_type);
-#endif
     } else {
       // Attribute deal with
       // TODO(chenweihang): now here allow any types of attribute, maybe
...
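With the guards gone, the `SelectedRows` branches take part in argument parsing unconditionally. A minimal sketch of the `std::type_index` dispatch technique that `KernelArgsParseFunctor` uses above (simplified stand-in types, not the real phi signatures):

```cpp
#include <iostream>
#include <typeindex>
#include <typeinfo>

struct DenseTensor {};
struct SelectedRows {};

// Classify one kernel argument by its runtime type index, mirroring the
// if/else-if chain in KernelArgsParseFunctor above.
const char* Classify(std::type_index arg_type) {
  if (arg_type == std::type_index(typeid(const DenseTensor&))) {
    return "dense input";
  } else if (arg_type == std::type_index(typeid(const SelectedRows&))) {
    return "selected_rows input";  // now compiled in unconditionally
  } else if (arg_type == std::type_index(typeid(SelectedRows*))) {
    return "selected_rows output";
  }
  return "attribute";
}

int main() {
  std::cout << Classify(std::type_index(typeid(const SelectedRows&))) << "\n";
  std::cout << Classify(std::type_index(typeid(SelectedRows*))) << "\n";
}
```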
@@ -23,9 +23,7 @@
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/phi/core/kernel_context.h"
-#ifndef PADDLE_WITH_CUSTOM_KERNEL
 #include "paddle/phi/core/selected_rows.h"
-#endif
 #include "paddle/phi/core/sparse_coo_tensor.h"
 #include "paddle/phi/core/sparse_csr_tensor.h"
 #include "paddle/phi/core/type_defs.h"
@@ -222,9 +220,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
   PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor);
   PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor);
   PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor);
-#ifndef PADDLE_WITH_CUSTOM_KERNEL
   PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows);
-#endif
   PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SparseCooTensor);
   PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SparseCooTensor);
@@ -259,9 +255,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
   PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(DenseTensor);
   PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(DenseTensor);
-#ifndef PADDLE_WITH_CUSTOM_KERNEL
   PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SelectedRows);
-#endif
   PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SparseCooTensor);
   PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(SparseCooTensor);
...
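The two hunks above delete the guards around the `SelectedRows` kernel-call-helper specializations, so custom kernels can take `SelectedRows` inputs and outputs like any other type. For readers unfamiliar with the pattern, here is a minimal sketch of the recursive specialization machinery that macros like `PT_SPECIALIZE_KernelCallHelper_FOR_INPUT` generate (a toy `Context` and toy types, not the real phi implementation):

```cpp
#include <iostream>
#include <utility>

struct DenseTensor { float v = 2.5f; };
struct SelectedRows { int rows = 3; };

// Toy stand-in for phi::KernelContext, holding one input of each kind.
struct Context {
  DenseTensor dense;
  SelectedRows selected;
};

template <typename... Args>
struct KernelCallHelper;

// Roughly what the DenseTensor input specialization does: fetch the
// argument from the context, then recurse on the remaining types.
template <typename... Tail>
struct KernelCallHelper<const DenseTensor&, Tail...> {
  template <typename Fn, typename... Fetched>
  static void Call(const Context& ctx, Fn fn, Fetched&&... fetched) {
    KernelCallHelper<Tail...>::Call(ctx, fn, std::forward<Fetched>(fetched)...,
                                    ctx.dense);
  }
};

// The SelectedRows specialization: after this change it always exists,
// with no PADDLE_WITH_CUSTOM_KERNEL guard around it.
template <typename... Tail>
struct KernelCallHelper<const SelectedRows&, Tail...> {
  template <typename Fn, typename... Fetched>
  static void Call(const Context& ctx, Fn fn, Fetched&&... fetched) {
    KernelCallHelper<Tail...>::Call(ctx, fn, std::forward<Fetched>(fetched)...,
                                    ctx.selected);
  }
};

// Base case: every argument has been fetched; invoke the kernel function.
template <>
struct KernelCallHelper<> {
  template <typename Fn, typename... Fetched>
  static void Call(const Context& /*ctx*/, Fn fn, Fetched&&... fetched) {
    fn(std::forward<Fetched>(fetched)...);
  }
};

void FakeDotKernel(const DenseTensor& x, const SelectedRows& y) {
  std::cout << "x.v=" << x.v << " y.rows=" << y.rows << "\n";
}

int main() {
  Context ctx;
  KernelCallHelper<const DenseTensor&, const SelectedRows&>::Call(
      ctx, &FakeDotKernel);
}
```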
@@ -23,13 +23,6 @@ limitations under the License. */
 #include "paddle/utils/any.h"
 #include "paddle/utils/optional.h"

-// Note: mixed_vector include many header now, LoD will be
-// used on CUDA device? Can we use small_vector here?
-// @zhanlve: Rollback to original LoD for now
-#ifndef PADDLE_WITH_CUSTOM_KERNEL
-#include "paddle/fluid/framework/mixed_vector.h"
-#endif

 namespace phi {
 using DDim = phi::DDim;
...
@@ -146,12 +146,10 @@ TEST(CustomKernel, custom_kernel_dot) {
                 custom_fake_dot_kernels.end());

   // 3.before register
-  auto& kernel_factory_instance = phi::KernelFactory::Instance();
   auto& kernels = phi::KernelFactory::Instance().kernels();
-  EXPECT_TRUE(!kernel_factory_instance.HasCompatiblePhiKernel(op_name));
+  EXPECT_TRUE(kernels.find(op_name) == kernels.end());

-  // mock fake_dot is supported by phi for HasCompatiblePhiKernel check while
-  // registering
+  // mock fake_dot is supported by phi for check while registering
   auto& fake_dot_kernels = kernels[op_name];
   EXPECT_TRUE(fake_dot_kernels.find(
@@ -196,7 +194,7 @@ TEST(CustomKernel, custom_kernel_dot) {
                   fake_dot_kernels.end());

   // 4.kernel select
-  auto kernel = kernel_factory_instance.SelectKernelOrThrowError(
+  auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
       op_name, phi::KernelKey(backend, layout, phi::DataType::UINT8));

   // 5.prepare parameters for kernel
...
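The test now checks name presence directly on the kernel map and mocks the op by inserting an initially empty entry via `operator[]`. A sketch of that before/after behavior with a plain map and asserts standing in for gtest (illustrative only):

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

int main() {
  std::unordered_map<std::string, std::unordered_map<std::string, int>> kernels;
  const std::string op_name = "fake_dot";

  // 3. before register: the op name is absent from the factory map.
  assert(kernels.find(op_name) == kernels.end());

  // Mock fake_dot as supported by phi: operator[] default-constructs the
  // entry, so the registration-time PADDLE_ENFORCE_NE check would pass.
  auto& fake_dot_kernels = kernels[op_name];
  assert(kernels.find(op_name) != kernels.end());
  assert(fake_dot_kernels.empty());
  return 0;
}
```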