From 07dad6d6ec415758d520e33960a0c53e50ef2ab5 Mon Sep 17 00:00:00 2001 From: huzhiqiang <912790387@qq.com> Date: Wed, 2 Mar 2022 02:16:04 -0600 Subject: [PATCH] [Infrt]add phi kernel dialect (#39726) --- .gitignore | 3 + .../pybind/kernel_signature_generator.cc | 26 +- paddle/infrt/dialect/infrt/common_type.h | 18 +- paddle/infrt/dialect/infrt/infrt_ops_base.td | 7 +- paddle/infrt/dialect/init_infrt_dialects.cc | 4 + paddle/infrt/dialect/phi/ir/CMakeLists.txt | 7 +- .../infrt/dialect/phi/ir/infrt_phi_kernel.td | 24 +- .../infrt/dialect/phi/ir/infrt_phi_tensor.td | 11 +- paddle/infrt/dialect/phi/ir/phi_kernels.cc | 44 +++ paddle/infrt/dialect/phi/ir/phi_kernels.h | 42 +++ .../infrt/dialect/phi/pass/kernel_op_desc.cc | 45 ++- paddle/infrt/host_context/mlir_exec.cc | 2 + paddle/infrt/kernel/phi/context_kernels.cc | 8 +- paddle/infrt/kernel/phi/context_kernels.h | 3 +- .../infrt/kernel/phi/dense_tensor_kernels.cc | 34 ++- .../infrt/kernel/phi/dense_tensor_kernels.h | 3 +- .../infershaped/infershape_launchers_test.cc | 2 +- paddle/infrt/kernel/phi/registry.cc | 2 + .../tests/dialect/pten/dense_tensor.mlir | 12 +- paddle/scripts/infrt_build.sh | 4 +- tools/infrt/generate_phi_kernel_dialect.py | 276 ++++++++++++++++++ tools/infrt/get_phi_kernel_info.py | 12 +- 22 files changed, 536 insertions(+), 53 deletions(-) create mode 100644 paddle/infrt/dialect/phi/ir/phi_kernels.cc create mode 100644 paddle/infrt/dialect/phi/ir/phi_kernels.h create mode 100644 tools/infrt/generate_phi_kernel_dialect.py diff --git a/.gitignore b/.gitignore index cecd6fa91c..debec551d9 100644 --- a/.gitignore +++ b/.gitignore @@ -49,6 +49,9 @@ tools/__pycache__ # This file is automatically generated. # TODO(zhiqiang) Move this file to build directory. paddle/infrt/dialect/pd_ops.td +paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td +paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td +tools/infrt/kernels.json paddle/infrt/dialect/pd_ops_info.h .lit_test_times.txt paddle/infrt/tests/dialect/Output diff --git a/paddle/fluid/pybind/kernel_signature_generator.cc b/paddle/fluid/pybind/kernel_signature_generator.cc index 8283a249de..f0d5a4e477 100644 --- a/paddle/fluid/pybind/kernel_signature_generator.cc +++ b/paddle/fluid/pybind/kernel_signature_generator.cc @@ -49,24 +49,30 @@ int main(int argc, char **argv) { if (kernel_signature_map.Has(op_kernel_pair.first)) { std::cout << "\"" << op_kernel_pair.first << "\":{"; auto &args = kernel_signature_map.Get(op_kernel_pair.first).args; + std::cout << "\"inputs\":["; - for (auto name : std::get<0>(args)) { - std::cout << "\"" << name << "\","; + auto inputs_ = std::get<0>(args); + if (inputs_.size() > 0) std::cout << inputs_[0]; + for (size_t i = 1; i < inputs_.size(); i++) { + std::cout << ",\"" << inputs_[i] << "\""; } - if (std::get<0>(args).size() > 0) std::cout << "\b"; + std::cout << "],\"attrs\":["; - for (auto name : std::get<1>(args)) { - std::cout << "\"" << name << "\","; + auto attrs_ = std::get<1>(args); + if (attrs_.size() > 0) std::cout << attrs_[0]; + for (size_t i = 1; i < attrs_.size(); i++) { + std::cout << ",\"" << attrs_[i] << "\""; } - if (std::get<1>(args).size() > 0) std::cout << "\b"; + std::cout << "],\"outputs\":["; - for (auto name : std::get<2>(args)) { - std::cout << "\"" << name << "\","; + auto outputs_ = std::get<2>(args); + for (size_t i = 1; i < outputs_.size(); i++) { + std::cout << ",\"" << outputs_[i] << "\""; } - if (std::get<2>(args).size() > 0) std::cout << "\b"; + std::cout << "]},"; } } - std::cout << "\b}" << std::endl; + std::cout << "}" << std::endl; return 0; } diff --git a/paddle/infrt/dialect/infrt/common_type.h b/paddle/infrt/dialect/infrt/common_type.h index d6d6503c03..436e7920ca 100644 --- a/paddle/infrt/dialect/infrt/common_type.h +++ b/paddle/infrt/dialect/infrt/common_type.h @@ -21,8 +21,22 @@ namespace infrt { enum class TargetType : uint8_t { CPU, GPU, UNK }; -enum class PrecisionType : uint8_t { FLOAT32, FLOAT16, UNK }; -enum class LayoutType : uint8_t { NCHW, NHWC, UNK }; +enum class LayoutType : uint8_t { NCHW, NHWC, ANY, UNK }; +enum class PrecisionType : uint8_t { + UINT8, + INT8, + INT16, + INT32, + INT64, + FLOAT16, + BFLOAT16, + FLOAT32, + FLOAT64, + COMPLEX64, + COMPLEX128, + BOOL, + UNK +}; struct Place { TargetType target; diff --git a/paddle/infrt/dialect/infrt/infrt_ops_base.td b/paddle/infrt/dialect/infrt/infrt_ops_base.td index 978b126d75..f19912dc0c 100644 --- a/paddle/infrt/dialect/infrt/infrt_ops_base.td +++ b/paddle/infrt/dialect/infrt/infrt_ops_base.td @@ -34,9 +34,10 @@ def DenseTensor : Infrt_Type<"DenseTensor"> { let summary = "infrt dense tensor"; let description = [{dense_tensor<, 3>}]; let parameters = (ins - "TargetType":$target, - "PrecisionType":$precision, - "LayoutType":$layout + "::infrt::TargetType":$target, + "::infrt::PrecisionType":$precision, + "::infrt::LayoutType":$layout + ); } diff --git a/paddle/infrt/dialect/init_infrt_dialects.cc b/paddle/infrt/dialect/init_infrt_dialects.cc index c5c81b4b0f..5eae017193 100644 --- a/paddle/infrt/dialect/init_infrt_dialects.cc +++ b/paddle/infrt/dialect/init_infrt_dialects.cc @@ -23,6 +23,8 @@ #include "paddle/infrt/dialect/pd_ops.h" #include "paddle/infrt/dialect/phi/ir/infrt_phi_tensor.h" #include "paddle/infrt/dialect/phi/ir/phi_base.h" +#include "paddle/infrt/dialect/phi/ir/phi_kernels.h" + #include "paddle/infrt/dialect/tensor_shape.h" namespace infrt { @@ -34,6 +36,8 @@ void registerCinnDialects(mlir::DialectRegistry ®istry) { // NOLINT mlir::pd::PaddleDialect, #ifdef INFRT_WITH_PHI phi::PHIDenseTensorDialect, + phi::PHICPUKernelDialect, + phi::PHIGPUKernelDialect, phi::PHIDialect #endif >(); diff --git a/paddle/infrt/dialect/phi/ir/CMakeLists.txt b/paddle/infrt/dialect/phi/ir/CMakeLists.txt index 8c1d75629d..0497b98321 100644 --- a/paddle/infrt/dialect/phi/ir/CMakeLists.txt +++ b/paddle/infrt/dialect/phi/ir/CMakeLists.txt @@ -1,9 +1,12 @@ #mlir_tablegen_on(infrt_phi_base DIALECT phi) add_mlir_dialect(infrt_phi_base phi) add_mlir_dialect(infrt_phi_tensor phi_dt) -add_mlir_dialect(infrt_phi_kernel phi_kernel) +add_mlir_dialect(phi_cpu_kernels phi_cpu) +add_mlir_dialect(phi_gpu_kernels phi_gpu) + #mlir_tablegen_on(infrt_phi_tensor) gather_srcs(infrt_src SRCS phi_base.cc - infrt_phi_tensor.cc) + infrt_phi_tensor.cc + phi_kernels.cc) diff --git a/paddle/infrt/dialect/phi/ir/infrt_phi_kernel.td b/paddle/infrt/dialect/phi/ir/infrt_phi_kernel.td index 37bf0b5ef2..ee23470fc7 100644 --- a/paddle/infrt/dialect/phi/ir/infrt_phi_kernel.td +++ b/paddle/infrt/dialect/phi/ir/infrt_phi_kernel.td @@ -6,24 +6,32 @@ include "mlir/IR/OpBase.td" include "paddle/infrt/dialect/infrt_base.td" include "paddle/infrt/dialect/phi/ir/infrt_phi_base.td" -def PHI_KernelDialect : Dialect { - let name = "phi_kernel"; +def PHI_CPUKernelDialect : Dialect { + let name = "phi_cpu"; let description = [{ - The PHI Kernel dialect. + The PHI CPU Kernel dialect. + }]; + + let cppNamespace = "::infrt::phi"; +} + +def PHI_GPUKernelDialect : Dialect { + let name = "phi_gpu"; + + let description = [{ + The PHI GPU Kernel dialect. }]; let cppNamespace = "::infrt::phi"; } // PHI Kernel related ops. -class PDT_Kernel traits = []> : Op { +class PDTCPU_Kernel traits = []> : Op { } -def PDCK_AbsOp : PDT_Kernel<"phi.abs.host.fp32"> { - let arguments = (ins CPU_Context:$dev_ctx, DenseTensor:$x); - let results = (outs DenseTensor:$output); +// PHI Kernel related ops. +class PDTGPU_Kernel traits = []> : Op { } #endif - diff --git a/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td b/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td index dc3a4b340d..39677871ff 100644 --- a/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td +++ b/paddle/infrt/dialect/phi/ir/infrt_phi_tensor.td @@ -34,6 +34,14 @@ class FillDenseTensorOp : attr_type:$value ); let results = (outs); + let assemblyFormat = "`(` $input `:` type($input) `)` attr-dict"; +} + +class PrintDenseTensorOp: + PDT_Op<"print_tensor"> { + let arguments = (ins DenseTensor:$input); + let results = (outs); + let assemblyFormat = "`(` $input `:` type($input) `)` attr-dict"; } class CreateCPUAllocatorOp @@ -44,7 +52,7 @@ class CreateCPUAllocatorOp class CreateCPUContextOp : PDT_Op<"create_context." # "cpu", [NoSideEffect]> { - let arguments = (ins); + let arguments = (ins CPU_Allocator:$input); let results = (outs CPU_Context:$output); } @@ -52,6 +60,7 @@ def PDT_CreateDenseTensorOp_cpu_f32_nchw : CreateDenseTensorOp<"cpu", "f32", "nc def PDT_FillDenseTensorOp_f32 : FillDenseTensorOp; def PDT_CreateAllocatorOp_cpu : CreateCPUAllocatorOp; def PDT_CreateContextOp_cpu : CreateCPUContextOp; +def PDT_PrintDenseTensor_cpu : PrintDenseTensorOp; def FakeKernelOp : PDT_Op<"fake_phi_kernel"> { let arguments = (ins CPU_Context:$dev_ctx, DenseTensor:$x, DenseTensor:$y, BoolAttr:$transpose_x, BoolAttr:$transpose_y); diff --git a/paddle/infrt/dialect/phi/ir/phi_kernels.cc b/paddle/infrt/dialect/phi/ir/phi_kernels.cc new file mode 100644 index 0000000000..c7a837b83f --- /dev/null +++ b/paddle/infrt/dialect/phi/ir/phi_kernels.cc @@ -0,0 +1,44 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/dialect/phi/ir/phi_kernels.h" +#include + +#include "paddle/infrt/dialect/phi/ir/phi_gpu_kernelsDialect.cpp.inc" +#define GET_OP_CLASSES +#include "paddle/infrt/dialect/phi/ir/phi_cpu_kernels.cpp.inc" // NOLINT + +#include "paddle/infrt/dialect/phi/ir/phi_cpu_kernelsDialect.cpp.inc" +#define GET_OP_CLASSES +#include "paddle/infrt/dialect/phi/ir/phi_gpu_kernels.cpp.inc" // NOLINT + +namespace infrt { +namespace phi { + +void PHICPUKernelDialect::initialize() { +#define GET_OP_LIST + addOperations< +#include "paddle/infrt/dialect/phi/ir/phi_cpu_kernels.cpp.inc" // NOLINT + >(); +} + +void PHIGPUKernelDialect::initialize() { +#define GET_OP_LIST + addOperations< +#include "paddle/infrt/dialect/phi/ir/phi_gpu_kernels.cpp.inc" // NOLINT + >(); +} + +} // namespace phi +} // namespace infrt diff --git a/paddle/infrt/dialect/phi/ir/phi_kernels.h b/paddle/infrt/dialect/phi/ir/phi_kernels.h new file mode 100644 index 0000000000..b84d1b2b72 --- /dev/null +++ b/paddle/infrt/dialect/phi/ir/phi_kernels.h @@ -0,0 +1,42 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/infrt/dialect/dense_tensor.h" +#include "paddle/infrt/dialect/infrt/infrt_dialect.h" +#include "paddle/infrt/dialect/phi/ir/phi_base.h" + +#include "paddle/infrt/dialect/phi/ir/phi_cpu_kernelsDialect.h.inc" +#define GET_OP_CLASSES +#include "paddle/infrt/dialect/phi/ir/phi_cpu_kernels.h.inc" + +#include "paddle/infrt/dialect/phi/ir/phi_gpu_kernelsDialect.h.inc" +#define GET_OP_CLASSES +#include "paddle/infrt/dialect/phi/ir/phi_gpu_kernels.h.inc" diff --git a/paddle/infrt/dialect/phi/pass/kernel_op_desc.cc b/paddle/infrt/dialect/phi/pass/kernel_op_desc.cc index 63869b7d7b..6c0f6df892 100644 --- a/paddle/infrt/dialect/phi/pass/kernel_op_desc.cc +++ b/paddle/infrt/dialect/phi/pass/kernel_op_desc.cc @@ -41,26 +41,49 @@ TargetType cvtTargetFromPhi(phi::Backend backend) { } phi::DataType cvtPrecision2Phi(PrecisionType precision) { +#define CONVERT_PRECISION_TO_PHI(Precision) \ + case PrecisionType::Precision: \ + return phi::DataType::Precision; + switch (precision) { - case PrecisionType::FLOAT32: - return phi::DataType::FLOAT32; - break; - case PrecisionType::FLOAT16: - return phi::DataType::FLOAT16; + CONVERT_PRECISION_TO_PHI(FLOAT32) + CONVERT_PRECISION_TO_PHI(FLOAT16) + CONVERT_PRECISION_TO_PHI(FLOAT64) + CONVERT_PRECISION_TO_PHI(UINT8) + CONVERT_PRECISION_TO_PHI(INT8) + CONVERT_PRECISION_TO_PHI(INT16) + CONVERT_PRECISION_TO_PHI(INT32) + CONVERT_PRECISION_TO_PHI(INT64) + CONVERT_PRECISION_TO_PHI(COMPLEX64) + CONVERT_PRECISION_TO_PHI(COMPLEX128) + CONVERT_PRECISION_TO_PHI(BOOL) default: return phi::DataType::UNDEFINED; } +#undef CONVERT_PRECISION_TO_PHI } PrecisionType cvtPrecisionFromPhi(phi::DataType datatype) { +#define CONVERT_PRECISION_FROM_PHI(Precision) \ + case phi::DataType::Precision: \ + return PrecisionType::Precision; + switch (datatype) { - case phi::DataType::FLOAT32: - return PrecisionType::FLOAT32; - case phi::DataType::FLOAT16: - return PrecisionType::FLOAT16; + CONVERT_PRECISION_FROM_PHI(FLOAT32) + CONVERT_PRECISION_FROM_PHI(FLOAT16) + CONVERT_PRECISION_FROM_PHI(FLOAT64) + CONVERT_PRECISION_FROM_PHI(UINT8) + CONVERT_PRECISION_FROM_PHI(INT8) + CONVERT_PRECISION_FROM_PHI(INT16) + CONVERT_PRECISION_FROM_PHI(INT32) + CONVERT_PRECISION_FROM_PHI(INT64) + CONVERT_PRECISION_FROM_PHI(COMPLEX64) + CONVERT_PRECISION_FROM_PHI(COMPLEX128) + CONVERT_PRECISION_FROM_PHI(BOOL) default: return PrecisionType::UNK; } +#undef CONVERT_PRECISION_FROM_PHI } phi::DataLayout cvtLayout2Phi(LayoutType layout) { @@ -69,6 +92,8 @@ phi::DataLayout cvtLayout2Phi(LayoutType layout) { return phi::DataLayout::NCHW; case LayoutType::NHWC: return phi::DataLayout::NHWC; + case LayoutType::ANY: + return phi::DataLayout::ANY; default: return phi::DataLayout::UNDEFINED; } @@ -80,6 +105,8 @@ LayoutType cvtLayoutFromPhi(phi::DataLayout layout) { return LayoutType::NCHW; case phi::DataLayout::NHWC: return LayoutType::NHWC; + case phi::DataLayout::ANY: + return LayoutType::ANY; default: return LayoutType::UNK; } diff --git a/paddle/infrt/host_context/mlir_exec.cc b/paddle/infrt/host_context/mlir_exec.cc index 79717ba2cc..7823681079 100644 --- a/paddle/infrt/host_context/mlir_exec.cc +++ b/paddle/infrt/host_context/mlir_exec.cc @@ -29,6 +29,7 @@ #include "paddle/infrt/kernel/tensor_shape_kernels.h" #include "paddle/infrt/kernel/test_kernels.h" #ifdef INFRT_WITH_PHI +#include "paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launchers.h" #include "paddle/infrt/kernel/phi/registry.h" #endif @@ -58,6 +59,7 @@ int main(int argc, char** argv) { kernel::RegisterControlFlowKernels(®istry); #ifdef INFRT_WITH_PHI kernel::RegisterPhiKernels(®istry); + kernel::RegisterInferShapeLaunchers(®istry); #endif // load extra shared library diff --git a/paddle/infrt/kernel/phi/context_kernels.cc b/paddle/infrt/kernel/phi/context_kernels.cc index 5284f49991..3caaf1788e 100644 --- a/paddle/infrt/kernel/phi/context_kernels.cc +++ b/paddle/infrt/kernel/phi/context_kernels.cc @@ -18,7 +18,13 @@ namespace infrt { namespace kernel { namespace phi { -::phi::CPUContext CreateCpuContext() { return {}; } +::phi::CPUContext CreateCpuContext( + infrt::backends::CpuPhiAllocator* allocator) { + ::phi::CPUContext context; + context.SetAllocator(allocator); + context.Init(); + return context; +} } // namespace phi } // namespace kernel diff --git a/paddle/infrt/kernel/phi/context_kernels.h b/paddle/infrt/kernel/phi/context_kernels.h index 8082dc6c2f..7f1e7ef6cd 100644 --- a/paddle/infrt/kernel/phi/context_kernels.h +++ b/paddle/infrt/kernel/phi/context_kernels.h @@ -14,6 +14,7 @@ #pragma once +#include "paddle/infrt/backends/host/phi_allocator.h" #include "paddle/infrt/backends/host/phi_context.h" #include "paddle/phi/core/dense_tensor.h" @@ -21,7 +22,7 @@ namespace infrt { namespace kernel { namespace phi { -::phi::CPUContext CreateCpuContext(); +::phi::CPUContext CreateCpuContext(::infrt::backends::CpuPhiAllocator*); } // namespace phi } // namespace kernel diff --git a/paddle/infrt/kernel/phi/dense_tensor_kernels.cc b/paddle/infrt/kernel/phi/dense_tensor_kernels.cc index ce9200b991..871336e876 100644 --- a/paddle/infrt/kernel/phi/dense_tensor_kernels.cc +++ b/paddle/infrt/kernel/phi/dense_tensor_kernels.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/infrt/kernel/phi/dense_tensor_kernels.h" - +#include namespace infrt { namespace kernel { namespace phi { @@ -30,8 +30,38 @@ namespace phi { } void FillDenseTensorF32(::phi::DenseTensor* dense_tensor, - host_context::Attribute> values) {} + host_context::Attribute> values) { + auto place = ::phi::CPUPlace(); + float* a_data = dense_tensor->mutable_data(place); + for (int64_t i = 0; i < dense_tensor->numel(); ++i) { + a_data[i] = (values.get())[i]; + } +} +void PrintDenseTensor(::phi::DenseTensor* dense_tensor) { +#define PRINT_META_DATA(PHI_DATATYPE, DTYPE) \ + case ::phi::DataType::PHI_DATATYPE: { \ + DTYPE* data = dense_tensor->data(); \ + if (dense_tensor->numel() == 0) break; \ + std::cout << data[0]; \ + for (int64_t i = 1; i < dense_tensor->numel(); i++) { \ + std::cout << "," << data[i]; \ + } \ + break; \ + } + + ::phi::DDim dims = dense_tensor->dims(); + std::cout << "dense_tensor: shape=shape" << dims.to_str() << "," + << " values=["; + switch (dense_tensor->dtype()) { + PRINT_META_DATA(FLOAT32, float); + PRINT_META_DATA(INT32, int32_t); + default: + std::cout << "Error! Unsupported data type!\n"; + } + std::cout << "]\n"; +#undef PRINT_META_DATA +} } // namespace phi } // namespace kernel } // namespace infrt diff --git a/paddle/infrt/kernel/phi/dense_tensor_kernels.h b/paddle/infrt/kernel/phi/dense_tensor_kernels.h index 25daf7027e..920c0b1c8a 100644 --- a/paddle/infrt/kernel/phi/dense_tensor_kernels.h +++ b/paddle/infrt/kernel/phi/dense_tensor_kernels.h @@ -28,7 +28,8 @@ namespace phi { host_context::Attribute> lod); void FillDenseTensorF32(::phi::DenseTensor* dense_tensor, - host_context::Attribute> values); + host_context::Attribute> values); +void PrintDenseTensor(::phi::DenseTensor* dense_tensor); } // namespace phi } // namespace kernel diff --git a/paddle/infrt/kernel/phi/infershaped/infershape_launchers_test.cc b/paddle/infrt/kernel/phi/infershaped/infershape_launchers_test.cc index 2161e98fac..37f9197edb 100644 --- a/paddle/infrt/kernel/phi/infershaped/infershape_launchers_test.cc +++ b/paddle/infrt/kernel/phi/infershaped/infershape_launchers_test.cc @@ -54,7 +54,7 @@ TEST(ElementwiseAdd, launcher_registry) { host_context::KernelRegistry registry; RegisterInferShapeLaunchers(®istry); ASSERT_GE(registry.size(), 1UL); - auto creator = registry.GetKernel("pten.add.cpu.any.fp32"); + auto creator = registry.GetKernel("phi_cpu.add.any.float32"); const phi::DDim dims({1, 2}); const phi::DataType dtype{phi::DataType::FLOAT32}; diff --git a/paddle/infrt/kernel/phi/registry.cc b/paddle/infrt/kernel/phi/registry.cc index 5d79814d4b..15e2d21005 100644 --- a/paddle/infrt/kernel/phi/registry.cc +++ b/paddle/infrt/kernel/phi/registry.cc @@ -42,6 +42,8 @@ void RegisterPhiKernels(host_context::KernelRegistry* registry) { INFRT_KERNEL(infrt::kernel::phi::CreateDenseTensorCpuF32Nchw)); registry->AddKernel("phi_dt.fill_dense_tensor.f32", INFRT_KERNEL(infrt::kernel::phi::FillDenseTensorF32)); + registry->AddKernel("phi_dt.print_tensor", + INFRT_KERNEL(infrt::kernel::phi::PrintDenseTensor)); registry->AddKernel( "phi_dt.fake_phi_kernel", std::bind(&KernelLauncherFunc !phi.CPU_allocator - %ctx = "phi_dt.create_context.cpu" (): () -> !phi.CPU_context + %ctx = "phi_dt.create_context.cpu" (%allocator): (!phi.CPU_allocator) -> !phi.CPU_context %t = "phi_dt.create_dense_tensor.cpu.f32.nchw" (%allocator) {dims=[1:i64], lod=[1:i64]}: (!phi.CPU_allocator) -> (!infrt.dense_tensor) + "phi_dt.fill_dense_tensor.f32"(%t) {value=[3.8:f32]} : (!infrt.dense_tensor) -> () + %e = "phi_cpu.sign.any.float32"(%ctx, %t) : (!phi.CPU_context, !infrt.dense_tensor) -> (!infrt.dense_tensor) - // CHECK: @FakePhiKernel@ - %d = "phi_dt.fake_phi_kernel" (%ctx, %t, %t) {transpose_x=false, transpose_y=false} : (!phi.CPU_context, !infrt.dense_tensor, !infrt.dense_tensor) -> (!infrt.dense_tensor) + // CHECK: dense_tensor: shape=shape[1], values=[1] + "phi_dt.print_tensor" (%e) : (!infrt.dense_tensor) -> () Infrt.return } diff --git a/paddle/scripts/infrt_build.sh b/paddle/scripts/infrt_build.sh index a013250138..75b27e4165 100755 --- a/paddle/scripts/infrt_build.sh +++ b/paddle/scripts/infrt_build.sh @@ -33,14 +33,16 @@ function update_pd_ops() { rm -rf ${PADDLE_ROOT}/build && mkdir -p ${PADDLE_ROOT}/build cd ${PADDLE_ROOT}/build cmake .. -DWITH_PYTHON=ON -DWITH_GPU=OFF -DPYTHON_EXECUTABLE=`which python3` -DWITH_XBYAK=OFF -DWITH_NCCL=OFF -DWITH_RCCL=OFF -DWITH_CRYPTO=OFF - make -j8 paddle_python + make -j8 paddle_python print_pten_kernels cd ${PADDLE_ROOT}/build + ./paddle/phi/tools/print_pten_kernels > ../tools/infrt/kernels.json cd python/dist/ python3 -m pip uninstall -y paddlepaddle python3 -m pip install *whl # update pd_ops.td cd ${PADDLE_ROOT}/tools/infrt/ python3 generate_pd_op_dialect_from_paddle_op_maker.py + python3 generate_phi_kernel_dialect.py ./kernels.json } function init() { diff --git a/tools/infrt/generate_phi_kernel_dialect.py b/tools/infrt/generate_phi_kernel_dialect.py new file mode 100644 index 0000000000..80cf3958b1 --- /dev/null +++ b/tools/infrt/generate_phi_kernel_dialect.py @@ -0,0 +1,276 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import sys + +attr_type_converter = {"i": 'SI32Attr', "b": 'BoolAttr', "l": 'SI64Attr'} +supported_kernels = ['sign', 'dot', 'digamma', 'conj'] + +target_type_converter = {"CPU": "CPU", "GPU": "GPU"} +layout_type_converter = { + "NCHW": "NCHW", + "NHWC": "NHWC", + "Undefined(AnyLayout)": "ANY" +} +precision_type_converter = { + "uint8": "UINT8", + "int8": "INT8", + "int16": "INT16", + "int32": "INT32", + "int64": "INT64", + "float16": "FLOAT16", + "bfloat16": "BFLOAT16", + "float32": "FLOAT32", + "float64": "FLOAT64", + "complex64": "COMPLEX64", + "complex128": "COMPLEX128", + "bool": "BOOL" +} + + +def generate_kernel_name(op_name, place_str): + [target_, layout_, precision_] = place_str[1:-1].split(',') + target_ = target_type_converter[target_.strip()] + layout_ = layout_type_converter[layout_.strip()] + precision_ = precision_type_converter[precision_.strip()] + alias_ = "{}.{}".format(op_name, ".".join( + [target_.strip(), layout_.strip(), precision_.strip()])) + return alias_ + + +def generate_attrs_info(op_name, attrs_info): + kernel_attrs_names = { + 'split': ['sections', 'num', 'axis', 'mkldnn_data_type'], + 'sign': [], + 'masked_select': [], + 'trace': ['offset', 'axis1', 'axis2'], + 'concat': ['axis'], + 'empty': ['shape', 'dtype'], + 'conj': [], + 'norm': ['axis', 'epsilon', 'is_test'], + 'histogram': ['bins', 'min', 'max'], + 'dot': [], + 'scale': ['scale', 'bias', 'bias_after_scale'], + 'digamma': [], + 'lerp': [], + 'cast': ['out_dtype', 'in_dtype'], + 'abs': [] + } + attrs_args_ = "" + if len(kernel_attrs_names[op_name]) == len(attrs_info): + for index in range(len(attrs_info)): + attr_name = kernel_attrs_names[op_name][index] + attr_type = attr_type_converter[attrs_info[index]] + attrs_args_ += '{type_}:${name_},'.format( + type_=attr_type, name_=attr_name) + return attrs_args_[:-1] + + +def generate_inputs_info(input_info): + input_args_ = "" + for index in range(len(input_info)): + [target_, layout_, precision_] = input_info[index].split(',') + # todo: check vadility + target_ = target_type_converter[target_.strip()] + layout_ = layout_type_converter[layout_.strip()] + precision_ = precision_type_converter[precision_.strip()] + input_args_ += " DenseTensor<\"{}\",\"{}\",\"{}\">:$in{},".format( + target_.strip(), precision_.strip(), layout_.strip(), str(index)) + input_args_ = input_args_[:-1] + return input_args_ + + +def generate_arguments_info(op_name, input_info, attr_info): + input_args = generate_inputs_info(input_info) + attr_args = generate_attrs_info(op_name, attr_info) + context_args = "CPU_Context:$dev_ctx" + argument_ = "{},{},{}".format(context_args, input_args, attr_args) + return (("let arguments = (ins {});".format(argument_.strip(",")))) + + +def generate_results_info(output_info): + output_args_ = "let results = (outs " + for index in range(len(output_info)): + [target_, layout_, precision_] = output_info[index].split(',') + # todo: check vadility + target_ = target_type_converter[target_.strip()] + layout_ = layout_type_converter[layout_.strip()] + precision_ = precision_type_converter[precision_.strip()] + output_args_ += " DenseTensor<\"{}\",\"{}\",\"{}\">:$out{},".format( + target_.strip(), precision_.strip(), layout_.strip(), str(index)) + return ("{});".format(output_args_[:-1])) + + +def generate_supported_kernel_list(load_dict): + supported_kernels_list_ = [] + for op_name in load_dict: + kernel_list = load_dict[op_name] + for kernel_info in kernel_list: + for kernel_alias_ in kernel_info: + attributes = kernel_info[kernel_alias_]["attribute"] + flag = True + for attribute in attributes: + if attribute not in attr_type_converter: + flag = False + if flag: + supported_kernels_list_.append(op_name) + + alias_ = generate_kernel_dialect(op_name, kernel_alias_, + kernel_info[kernel_alias_]) + supported_kernels_list_ = list(set(supported_kernels_list_)) + print(supported_kernels_list_) + + +def scan_kernel_info(load_dict): + target_type_ = [] + layout_type_ = [] + precision_type_ = [] + for op_name in load_dict: + kernel_list = load_dict[op_name] + for kernel_info in kernel_list: + for kernel_alias_ in kernel_info: + [target_, layout_, precision_] = kernel_alias_[1:-1].split(',') + target_type_.append(target_.strip()) + layout_type_.append(layout_.strip()) + precision_type_.append(precision_.strip()) + target_type_ = list(set(target_type_)) + layout_type_ = list(set(layout_type_)) + precision_type_ = list(set(precision_type_)) + print(target_type_) + print(layout_type_) + print(precision_type_) + + +def generate_cpu_kernel_dialect(op_name, kernel_alias_, kernel_info): + + alias = generate_kernel_name(op_name, kernel_alias_) + summary = 'let summary = "{name}";'.format(name=alias) + dialect_name = alias.split(".") + dialect_name = dialect_name[0] + "." + dialect_name[2] + "." + dialect_name[ + 3] + + header = 'def {kernel_name} : PDTCPU_Kernel<"{name}",[NoSideEffect]> {left_brace}'.format( + kernel_name=alias.replace(".", ""), + name=dialect_name.lower(), + left_brace="{") + + inputs_ = kernel_info["input"] + attributes = kernel_info["attribute"] + arguments = generate_arguments_info(op_name, inputs_, attributes) + + outputs = kernel_info["output"] + results = generate_results_info(outputs) + + kernel_dialect = '{header_}\n {summary_}\n {arguments_}\n {results_}\n{right_brace}\n'.format( + header_=header, + summary_=summary, + arguments_=arguments, + results_=results, + right_brace="}") + return kernel_dialect + + +def generate_gpu_kernel_dialect(op_name, kernel_alias_, kernel_info): + + alias = generate_kernel_name(op_name, kernel_alias_) + summary = 'let summary = "{name}";'.format(name=alias) + dialect_name = alias.split(".") + dialect_name = dialect_name[0] + "." + dialect_name[2] + "." + dialect_name[ + 3] + + header = 'def {kernel_name} : PDTGPU_Kernel<"{name}",[NoSideEffect]> {left_brace}'.format( + kernel_name=alias.replace(".", ""), + name=dialect_name.lower(), + left_brace="{") + inputs_ = kernel_info["input"] + attributes = kernel_info["attribute"] + arguments = generate_arguments_info(op_name, inputs_, attributes) + + outputs = kernel_info["output"] + results = generate_results_info(outputs) + + kernel_dialect = '{header_}\n {summary_}\n {arguments_}\n {results_}\n{right_brace}\n'.format( + header_=header, + summary_=summary, + arguments_=arguments, + results_=results, + right_brace="}") + return kernel_dialect + + +def generate_dialect_head(): + comment_ = "/*===- TableGen'source file -----------------------------------------------===*\\\n\ +|* *|\n\ +|* Kernel Definitions *|\n\ +|* *|\n\ +|* Automatically generated file, do not edit! *|\n\ +|* Generated by tools/infrt/generate_pten_kernel_dialect.py *|\n\ +|* *|\n\ +\*===----------------------------------------------------------------------===*/\n" + + includes_ = "#ifndef PTEN_KERNELS\n\ +#define PTEN_KERNELS\n\ +include \"mlir/Interfaces/InferTypeOpInterface.td\"\n\ +include \"mlir/Interfaces/LoopLikeInterface.td\"\n\ +include \"mlir/IR/OpBase.td\"\n\ +include \"paddle/infrt/dialect/phi/ir/infrt_phi_kernel.td\"" + + return (comment_ + includes_) + + +def get_kernel_target(kernel_alias_): + target = kernel_alias_[1:-1].split(",") + return target[0] + + +def main(path_): + with open(path_, "r") as f: + load_dict = json.load(f) + + head = generate_dialect_head() + + cpu_registry_ = "" + gpu_registry_ = "" + for op_name in load_dict: + if op_name not in supported_kernels: + continue + kernel_list = load_dict[op_name] + for kernel_info in kernel_list: + for kernel_alias_ in kernel_info: + if get_kernel_target(kernel_alias_) == "CPU": + kernel_registry = generate_cpu_kernel_dialect( + op_name, kernel_alias_, kernel_info[kernel_alias_]) + cpu_registry_ += kernel_registry + elif get_kernel_target(kernel_alias_) == "GPU": + kernel_registry = generate_gpu_kernel_dialect( + op_name, kernel_alias_, kernel_info[kernel_alias_]) + gpu_registry_ += kernel_registry + else: + print("Unsupported backend:" + get_kernel_target( + kernel_alias_)) + end = "#endif // PTEN_KERNELS" + with open("../../paddle/infrt/dialect/phi/ir/phi_cpu_kernels.td", + "w") as dst: + dst.write('{start_}\n{dialect_}\n{end_}'.format( + start_=head, dialect_=cpu_registry_, end_=end)) + with open("../../paddle/infrt/dialect/phi/ir/phi_gpu_kernels.td", + "w") as dst: + dst.write('{start_}\n{dialect_}\n{end_}'.format( + start_=head, dialect_=gpu_registry_, end_=end)) + + +if __name__ == '__main__': + path = sys.argv[1] + main(path) diff --git a/tools/infrt/get_phi_kernel_info.py b/tools/infrt/get_phi_kernel_info.py index f3e9f345da..9ea3fef003 100644 --- a/tools/infrt/get_phi_kernel_info.py +++ b/tools/infrt/get_phi_kernel_info.py @@ -150,19 +150,19 @@ def gen_dtype(vals: List[str]): ir_dtypes, origin_dtypes = [], [] for val in vals: if val == "float": - ir_dtypes.append("fp32") + ir_dtypes.append("float32") origin_dtypes.append("float") elif val == "double": - ir_dtypes.append("fp64") + ir_dtypes.append("float64") origin_dtypes.append("double") elif val == "float16": - ir_dtypes.append("fp16") + ir_dtypes.append("float16") origin_dtypes.append("paddle::experimental::float16") elif val == "bfloat16": ir_dtypes.append("bf16") origin_dtypes.append("paddle::experimental::bfloat16") elif val == "bool": - ir_dtypes.append("int1") + ir_dtypes.append("bool") origin_dtypes.append("bool") elif val == "int8_t": ir_dtypes.append("int8") @@ -219,8 +219,8 @@ def gen_register_info(resources: List[List[str]]): for ir_dtype, origin_dtype in zip(ir_dtypes, origin_dtypes): kernel_func = gen_kernel_func(update_item[3], ctx_name, origin_dtype) - ir_name = 'pten.' + '.'.join( - [it.lower() for it in update_item[:3]]) + "." + ir_dtype + ir_name = 'phi_cpu.' + update_item[0].lower() + '.' + update_item[ + 2].lower() + '.' + ir_dtype res += f""" registry->AddKernel("{ir_name}",""" -- GitLab