Unverified · Commit 63d2333e · Author: Aganlengzi · Committer: GitHub

[PluggableDevice] custom kernel supports multi cpp_dtype registering (#39385)

Parent: 2a5d858c
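Summary of the change, as a minimal before/after sketch drawn from the diff below (illustrative excerpts only, not the full patch): a custom kernel is now written against a template Context parameter rather than a hard-coded paddle::CPUContext, and a single PD_REGISTER_KERNEL invocation registers the kernel function template for a list of C++ dtypes instead of one macro call per dtype.

// Before: one registration per cpp_dtype, with an explicit dtype argument
// and an explicitly instantiated kernel function.
PD_REGISTER_KERNEL(dot, CPU, ALL_LAYOUT, UINT8,
                   custom_kernel::FakeDot<uint8_t>) {
  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UINT8);
}

// After: one registration covering several cpp_dtypes; the kernel function
// template is passed once and instantiated for each listed dtype.
PD_REGISTER_KERNEL(fake_dot, CPU, ALL_LAYOUT, custom_kernel::FakeDot, float,
                   double, int, int64_t, int8_t, uint8_t) {}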
@@ -35,13 +35,12 @@ limitations under the License. */
 // user kernel function
 namespace custom_kernel {
-// Here we use dot <CPU, ANY, UINT8> for test
-// This test will fail when these two kernels are aupported in framework
+// Here we use fake_dot for test
 // input 3: two Tensors and one std::vector<Tensor>
 // attribute 11: fake_attributes
 // output 2: one Tensor* and one std::vector<Tensor*>
-template <typename T>
-void FakeDot(const paddle::CPUContext& dev_ctx, const paddle::Tensor& x,
+template <typename T, typename Context>
+void FakeDot(const Context& dev_ctx, const paddle::Tensor& x,
              const paddle::Tensor& y,
              const std::vector<paddle::Tensor>& fake_input_vec,
              bool fake_attr_bool, int fake_attr_int, float fake_attr_float,
@@ -93,53 +92,91 @@ void FakeDot(const paddle::CPUContext& dev_ctx, const paddle::Tensor& x,
 }
 }  // namespace custom_kernel

-PD_REGISTER_KERNEL(dot, CPU, ALL_LAYOUT, UINT8,
-                   custom_kernel::FakeDot<uint8_t>) {
-  /* do some args define here
-   * the only param can be used is OpKernelInfo* kernel */
-  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UINT8);
-}
+PD_REGISTER_KERNEL(fake_dot, CPU, ALL_LAYOUT, custom_kernel::FakeDot, float,
+                   double, int, int64_t, int8_t, uint8_t) {}

 // Upper code will store dot kernels info into OpKernelInfoMap
 TEST(CustomKernel, custom_kernel_dot) {
-  std::string op_name = "dot";
+  std::string op_name = "fake_dot";
   pten::Backend backend = pten::Backend::CPU;
-  pten::DataLayout layout = pten::DataLayout::ANY;
-  pten::DataType dtype = pten::DataType::UINT8;
+  pten::DataLayout layout = pten::DataLayout::ALL_LAYOUT;

   // 1.custom kernel info parsed and store
-  EXPECT_TRUE(paddle::OpKernelInfoMap::Instance().GetMap().find("dot") !=
+  EXPECT_TRUE(paddle::OpKernelInfoMap::Instance().GetMap().find(op_name) !=
               paddle::OpKernelInfoMap::Instance().GetMap().end());

   // 2.info check
   EXPECT_EQ(
-      1, static_cast<int>(paddle::OpKernelInfoMap::Instance()["dot"].size()));
-  EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()["dot"][0].GetBackend() ==
+      6, static_cast<int>(paddle::OpKernelInfoMap::Instance()[op_name].size()));
+  // index 0
+  EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetBackend() ==
               backend);
-  EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()["dot"][0].GetDataLayout() ==
+  EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetDataLayout() ==
               layout);
-  EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()["dot"][0].GetDataType() ==
-              dtype);
-
-  // 3.register
-  EXPECT_TRUE(pten::KernelFactory::Instance().kernels().end() !=
-              pten::KernelFactory::Instance().kernels().find("dot"));
-
-  pten::KernelKey kernel_key(backend, layout, dtype);
-  EXPECT_TRUE(
-      pten::KernelFactory::Instance().kernels()["dot"].find(kernel_key) ==
-      pten::KernelFactory::Instance().kernels()["dot"].end());
+  EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][0].GetDataType() ==
+              pten::DataType::FLOAT32);
+  // index 5
+  EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetBackend() ==
+              backend);
+  EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetDataLayout() ==
+              layout);
+  EXPECT_TRUE(paddle::OpKernelInfoMap::Instance()[op_name][5].GetDataType() ==
+              pten::DataType::UINT8);
+
+  // 3.before register
+  auto& kernel_factory_instance = pten::KernelFactory::Instance();
+  auto& kernels = pten::KernelFactory::Instance().kernels();
+  EXPECT_TRUE(!kernel_factory_instance.HasCompatiblePtenKernel(op_name));
+  // mock fake_dot is supported by pten for HasCompatiblePtenKernel check while
+  // registering
+  auto& fake_dot_kernels = kernels[op_name];
+
+  EXPECT_TRUE(fake_dot_kernels.find(
+                  pten::KernelKey(backend, layout, pten::DataType::FLOAT32)) ==
+              fake_dot_kernels.end());
+  EXPECT_TRUE(fake_dot_kernels.find(
+                  pten::KernelKey(backend, layout, pten::DataType::FLOAT64)) ==
+              fake_dot_kernels.end());
+  EXPECT_TRUE(fake_dot_kernels.find(
+                  pten::KernelKey(backend, layout, pten::DataType::INT32)) ==
+              fake_dot_kernels.end());
+  EXPECT_TRUE(fake_dot_kernels.find(
+                  pten::KernelKey(backend, layout, pten::DataType::INT64)) ==
+              fake_dot_kernels.end());
+  EXPECT_TRUE(fake_dot_kernels.find(
+                  pten::KernelKey(backend, layout, pten::DataType::INT8)) ==
+              fake_dot_kernels.end());
+  EXPECT_TRUE(fake_dot_kernels.find(
+                  pten::KernelKey(backend, layout, pten::DataType::UINT8)) ==
+              fake_dot_kernels.end());

+  // register
   paddle::framework::RegisterKernelWithMetaInfoMap(
       paddle::OpKernelInfoMap::Instance());

-  EXPECT_TRUE(
-      pten::KernelFactory::Instance().kernels()["dot"].find(kernel_key) !=
-      pten::KernelFactory::Instance().kernels()["dot"].end());
+  EXPECT_TRUE(fake_dot_kernels.find(
+                  pten::KernelKey(backend, layout, pten::DataType::FLOAT32)) !=
+              fake_dot_kernels.end());
+  EXPECT_TRUE(fake_dot_kernels.find(
+                  pten::KernelKey(backend, layout, pten::DataType::FLOAT64)) !=
+              fake_dot_kernels.end());
+  EXPECT_TRUE(fake_dot_kernels.find(
+                  pten::KernelKey(backend, layout, pten::DataType::INT32)) !=
+              fake_dot_kernels.end());
+  EXPECT_TRUE(fake_dot_kernels.find(
+                  pten::KernelKey(backend, layout, pten::DataType::INT64)) !=
+              fake_dot_kernels.end());
+  EXPECT_TRUE(fake_dot_kernels.find(
+                  pten::KernelKey(backend, layout, pten::DataType::INT8)) !=
+              fake_dot_kernels.end());
+  EXPECT_TRUE(fake_dot_kernels.find(
+                  pten::KernelKey(backend, layout, pten::DataType::UINT8)) !=
+              fake_dot_kernels.end());

   // 4.kernel select
-  auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
-      op_name, kernel_key);
+  auto kernel = kernel_factory_instance.SelectKernelOrThrowError(
+      op_name, pten::KernelKey(backend, layout, pten::DataType::UINT8));

   // 5.prepare parameters for kernel
   const auto alloc = std::make_unique<paddle::experimental::DefaultAllocator>(
@@ -252,10 +289,10 @@ TEST(CustomKernel, custom_kernel_dot) {
 // test OpKernelInfoHelper
 TEST(OpKernelInfoHelper, op_kernel_info_help_getters) {
   using OpKernelInfoHelper = paddle::framework::OpKernelInfoHelper;
-  std::string op_name = "dot";
+  std::string op_name = "fake_dot";
   pten::Backend backend = pten::Backend::CPU;
   pten::DataLayout layout = pten::DataLayout::ANY;
-  pten::DataType dtype = pten::DataType::UINT8;
+  pten::DataType dtype = pten::DataType::FLOAT32;

   auto op_kernel_info = paddle::OpKernelInfoMap::Instance()[op_name][0];
@@ -268,10 +305,11 @@ TEST(OpKernelInfoHelper, op_kernel_info_help_getters) {
             OpKernelInfoHelper::GetKernelKey(op_kernel_info));

   paddle::CustomKernelFunc kernel_fn =
-      PD_PT_KERNEL(custom_kernel::FakeDot<uint8_t>);
+      PD_PT_KERNEL(custom_kernel::FakeDot<float, paddle::CPUContext>);
   EXPECT_EQ(kernel_fn, OpKernelInfoHelper::GetKernelFn(op_kernel_info));

-  void* variadic_func = PD_PT_VARIADIC_KERNEL(custom_kernel::FakeDot<uint8_t>);
+  void* variadic_func =
+      PD_PT_VARIADIC_KERNEL(custom_kernel::FakeDot<float, paddle::CPUContext>);
   EXPECT_EQ(variadic_func,
             OpKernelInfoHelper::GetVariadicKernelFn(op_kernel_info));
......
@@ -20,8 +20,8 @@ namespace custom_kernel {
 // Here we use dot <CPU, ANY, INT8> for test
 // This test will fail when this kernel is supported in framework
-template <typename T>
-void Dot(const paddle::CPUContext& dev_ctx,
+template <typename T, typename Context>
+void Dot(const Context& dev_ctx,
          const paddle::Tensor& x,
          const paddle::Tensor& y,
          paddle::Tensor* out) {
@@ -45,9 +45,6 @@ void Dot(const paddle::CPUContext& dev_ctx,
 }  // namespace custom_kernel
 }  // namespace paddle

-PD_REGISTER_KERNEL(
-    dot, CPU, ALL_LAYOUT, INT8, paddle::custom_kernel::Dot<int8_t>) {
-  /* do some args define here
-   * the only param can be used is OpKernelInfo* kernel */
+PD_REGISTER_KERNEL(dot, CPU, ALL_LAYOUT, paddle::custom_kernel::Dot, int8_t) {
   kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT8);
 }