未验证 提交 49677636 编写于 作者: A Aganlengzi 提交者: GitHub

[custom kernel] change kernel name judgement and remove macro control for selected_row (#39977)

上级 5471d162
......@@ -20,16 +20,16 @@ void RegisterCustomKernels(const CustomKernelMap& custom_kernel_map) {
auto& kernel_info_map = custom_kernel_map.GetMap();
VLOG(3) << "Size of custom_kernel_map: " << kernel_info_map.size();
auto& kernels = KernelFactory::Instance().kernels();
for (auto& pair : kernel_info_map) {
PADDLE_ENFORCE_EQ(
KernelFactory::Instance().HasCompatiblePhiKernel(pair.first),
true,
PADDLE_ENFORCE_NE(
kernels.find(pair.first),
kernels.end(),
phi::errors::InvalidArgument(
"The kernel %s is not ready for custom kernel registering.",
pair.first));
for (auto& info_pair : pair.second) {
auto& kernels = KernelFactory::Instance().kernels();
PADDLE_ENFORCE_EQ(
kernels[pair.first].find(info_pair.first),
kernels[pair.first].end(),
......
......@@ -87,13 +87,11 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
default_tensor_layout,
default_key.dtype(),
arg_type);
#ifndef PADDLE_WITH_CUSTOM_KERNEL
} else if (arg_type == std::type_index(typeid(const SelectedRows&))) {
args_def->AppendInput(default_key.backend(),
default_tensor_layout,
default_key.dtype(),
arg_type);
#endif
} else if (arg_type == std::type_index(typeid(DenseTensor*))) {
args_def->AppendOutput(default_key.backend(),
default_tensor_layout,
......@@ -105,13 +103,11 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
default_tensor_layout,
default_key.dtype(),
arg_type);
#ifndef PADDLE_WITH_CUSTOM_KERNEL
} else if (arg_type == std::type_index(typeid(SelectedRows*))) {
args_def->AppendOutput(default_key.backend(),
default_tensor_layout,
default_key.dtype(),
arg_type);
#endif
} else {
// Attribute deal with
// TODO(chenweihang): now here allow any types of attribute, maybe
......
......@@ -23,9 +23,7 @@
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_context.h"
#ifndef PADDLE_WITH_CUSTOM_KERNEL
#include "paddle/phi/core/selected_rows.h"
#endif
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/core/type_defs.h"
......@@ -222,9 +220,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor);
#ifndef PADDLE_WITH_CUSTOM_KERNEL
PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows);
#endif
PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SparseCooTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SparseCooTensor);
......@@ -259,9 +255,7 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(DenseTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(DenseTensor);
#ifndef PADDLE_WITH_CUSTOM_KERNEL
PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SelectedRows);
#endif
PT_SPECIALIZE_KernelCallHelper_FOR_OUTPUT(SparseCooTensor);
PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_OUTPUT(SparseCooTensor);
......
......@@ -23,13 +23,6 @@ limitations under the License. */
#include "paddle/utils/any.h"
#include "paddle/utils/optional.h"
// Note: mixed_vector include many header now, LoD will be
// used on CUDA device? Can we use small_vector here?
// @zhanlve: Rollback to original LoD for now
#ifndef PADDLE_WITH_CUSTOM_KERNEL
#include "paddle/fluid/framework/mixed_vector.h"
#endif
namespace phi {
using DDim = phi::DDim;
......
......@@ -146,12 +146,10 @@ TEST(CustomKernel, custom_kernel_dot) {
custom_fake_dot_kernels.end());
// 3.before register
auto& kernel_factory_instance = phi::KernelFactory::Instance();
auto& kernels = phi::KernelFactory::Instance().kernels();
EXPECT_TRUE(!kernel_factory_instance.HasCompatiblePhiKernel(op_name));
EXPECT_TRUE(kernels.find(op_name) == kernels.end());
// mock fake_dot is supported by phi for HasCompatiblePhiKernel check while
// registering
// mock fake_dot is supported by phi for check while registering
auto& fake_dot_kernels = kernels[op_name];
EXPECT_TRUE(fake_dot_kernels.find(
......@@ -196,7 +194,7 @@ TEST(CustomKernel, custom_kernel_dot) {
fake_dot_kernels.end());
// 4.kernel select
auto kernel = kernel_factory_instance.SelectKernelOrThrowError(
auto kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
op_name, phi::KernelKey(backend, layout, phi::DataType::UINT8));
// 5.prepare parameters for kernel
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册