From 70dea13868a1945a3e7c6dd892d7b880d4fd7cbb Mon Sep 17 00:00:00 2001 From: Yan Chunwei Date: Tue, 7 Dec 2021 18:40:30 +0800 Subject: [PATCH] introduce INF-RT (#37669) * add infrt code refined with Paddle's code style. * rename CinnRtConfig to InfRtConfig * rename CinnRt to InfRt of some code * rename CINNRT to INFRT * remove unnecessary code * replace CINN to INFRT in the source code * replace all "cinn" in code to "infrt" * remove some const_cast --- CMakeLists.txt | 1 + cmake/external/llvm.cmake | 110 ++++ cmake/third_party.cmake | 5 + paddle/CMakeLists.txt | 1 + paddle/infrt/CMakeLists.txt | 79 +++ paddle/infrt/api/CMakeLists.txt | 8 + paddle/infrt/api/infrt_api.cc | 246 ++++++++ paddle/infrt/api/infrt_api.h | 63 ++ paddle/infrt/api/infrt_api_test.cc | 79 +++ paddle/infrt/common/CMakeLists.txt | 14 + paddle/infrt/common/buffer.cc | 98 +++ paddle/infrt/common/buffer.h | 296 ++++++++++ paddle/infrt/common/common.h | 61 ++ paddle/infrt/common/dtype.cc | 50 ++ paddle/infrt/common/dtype.def | 18 + paddle/infrt/common/dtype.h | 85 +++ paddle/infrt/common/global.cc | 30 + paddle/infrt/common/global.h | 32 + paddle/infrt/common/macros.h | 52 ++ paddle/infrt/common/memory.cc | 42 ++ paddle/infrt/common/memory.h | 76 +++ paddle/infrt/common/object.cc | 19 + paddle/infrt/common/object.h | 81 +++ paddle/infrt/common/shared.cc | 15 + paddle/infrt/common/shared.h | 153 +++++ paddle/infrt/common/string.cc | 128 ++++ paddle/infrt/common/string.h | 84 +++ paddle/infrt/common/target.cc | 118 ++++ paddle/infrt/common/target.h | 112 ++++ paddle/infrt/common/type.cc | 358 +++++++++++ paddle/infrt/common/type.h | 223 +++++++ paddle/infrt/dialect/CMakeLists.txt | 61 ++ paddle/infrt/dialect/basic_kernels.cc | 164 +++++ paddle/infrt/dialect/basic_kernels.h | 24 + paddle/infrt/dialect/basic_kernels.td | 139 +++++ paddle/infrt/dialect/dense_tensor.cc | 277 +++++++++ paddle/infrt/dialect/dense_tensor.h | 79 +++ paddle/infrt/dialect/dense_tensor.td | 150 +++++ 
paddle/infrt/dialect/diagnostic_utils.cc | 52 ++ paddle/infrt/dialect/diagnostic_utils.h | 39 ++ paddle/infrt/dialect/dialect.cc | 36 ++ paddle/infrt/dialect/infrt_base.cc | 127 ++++ paddle/infrt/dialect/infrt_base.h | 73 +++ paddle/infrt/dialect/infrt_base.td | 42 ++ paddle/infrt/dialect/init_infrt_dialects.cc | 34 ++ paddle/infrt/dialect/init_infrt_dialects.h | 23 + paddle/infrt/dialect/mlir_loader.cc | 72 +++ paddle/infrt/dialect/mlir_loader.h | 30 + paddle/infrt/dialect/mlir_loader_test.cc | 57 ++ paddle/infrt/dialect/mlir_tests/basic.mlir | 40 ++ .../infrt/dialect/mlir_tests/benchmark.mlir | 23 + .../dialect/mlir_tests/dense_tensor.mlir | 22 + .../infrt/dialect/mlir_tests/paddle_ops.mlir | 8 + paddle/infrt/dialect/mlir_tests/rewrite.mlir | 24 + .../dialect/mlir_tests/rewrite_conv_bn.mlir | 15 + .../infrt/dialect/mlir_tests/tensor_map.mlir | 31 + .../dialect/mlir_tests/tensor_shape.mlir | 5 + .../infrt/dialect/mlir_tests/tensor_type.mlir | 9 + paddle/infrt/dialect/ops.td | 6 + paddle/infrt/dialect/opt.cc | 45 ++ paddle/infrt/dialect/pd_op_base.td | 77 +++ paddle/infrt/dialect/pd_ops.cc | 177 ++++++ paddle/infrt/dialect/pd_ops.h | 57 ++ paddle/infrt/dialect/pd_ops.td | 182 ++++++ paddle/infrt/dialect/pd_types.cc | 15 + paddle/infrt/dialect/pd_types.h | 57 ++ paddle/infrt/dialect/print_ir.cc | 134 +++++ paddle/infrt/dialect/rewrite.td | 90 +++ paddle/infrt/dialect/tensor_shape.cc | 68 +++ paddle/infrt/dialect/tensor_shape.h | 40 ++ paddle/infrt/dialect/tensor_shape.td | 49 ++ paddle/infrt/dialect/tensor_shape_base.td | 36 ++ paddle/infrt/dialect/test_kernels.cc | 163 +++++ paddle/infrt/dialect/test_kernels.h | 23 + paddle/infrt/dialect/test_kernels.td | 65 ++ paddle/infrt/dialect/types.cc | 17 + paddle/infrt/dialect/types.h | 16 + paddle/infrt/external_kernels/CMakeLists.txt | 13 + paddle/infrt/external_kernels/basic.mlir | 21 + .../infrt/external_kernels/basic_kernels.cc | 59 ++ paddle/infrt/external_kernels/fc.mlir | 43 ++ 
paddle/infrt/external_kernels/paddle.mlir | 50 ++ paddle/infrt/gtest_main.cc | 23 + paddle/infrt/host_context/CMakeLists.txt | 29 + paddle/infrt/host_context/core_runtime.cc | 93 +++ paddle/infrt/host_context/core_runtime.h | 86 +++ .../infrt/host_context/core_runtime_test.cc | 96 +++ paddle/infrt/host_context/function.cc | 19 + paddle/infrt/host_context/function.h | 62 ++ paddle/infrt/host_context/kernel_frame.cc | 29 + paddle/infrt/host_context/kernel_frame.h | 166 ++++++ paddle/infrt/host_context/kernel_registry.cc | 70 +++ paddle/infrt/host_context/kernel_registry.h | 67 +++ .../host_context/kernel_registry_test.cc | 47 ++ paddle/infrt/host_context/kernel_utils.cc | 19 + paddle/infrt/host_context/kernel_utils.h | 352 +++++++++++ .../infrt/host_context/kernel_utils_test.cc | 69 +++ paddle/infrt/host_context/mlir_exec.cc | 80 +++ .../host_context/mlir_function_executable.cc | 135 +++++ .../host_context/mlir_function_executable.h | 78 +++ .../host_context/mlir_program_executor.cc | 19 + .../host_context/mlir_program_executor.h | 79 +++ .../infrt/host_context/mlir_tests/basic.mlir | 30 + .../host_context/mlir_tests/dense_tensor.mlir | 9 + .../infrt/host_context/mlir_tests/shape.mlir | 7 + .../host_context/mlir_to_runtime_translate.cc | 558 ++++++++++++++++++ .../host_context/mlir_to_runtime_translate.h | 107 ++++ .../mlir_to_runtime_translate_test.cc | 160 +++++ paddle/infrt/host_context/op_executable.cc | 151 +++++ paddle/infrt/host_context/op_executable.h | 92 +++ .../infrt/host_context/op_executable_test.cc | 56 ++ paddle/infrt/host_context/symbol_table.cc | 82 +++ paddle/infrt/host_context/symbol_table.h | 65 ++ paddle/infrt/host_context/value.cc | 69 +++ paddle/infrt/host_context/value.h | 156 +++++ paddle/infrt/host_context/value_test.cc | 34 ++ paddle/infrt/kernel/CMakeLists.txt | 9 + paddle/infrt/kernel/basic_kernels.cc | 85 +++ paddle/infrt/kernel/basic_kernels.h | 34 ++ paddle/infrt/kernel/control_flow_kernels.cc | 44 ++ 
paddle/infrt/kernel/control_flow_kernels.h | 31 + paddle/infrt/kernel/tensor_kernels.cc | 79 +++ paddle/infrt/kernel/tensor_kernels.h | 25 + paddle/infrt/kernel/tensor_shape_kernels.cc | 38 ++ paddle/infrt/kernel/tensor_shape_kernels.h | 27 + paddle/infrt/kernel/test_kernels.cc | 200 +++++++ paddle/infrt/kernel/test_kernels.h | 31 + paddle/infrt/paddle/CMakeLists.txt | 24 + paddle/infrt/paddle/cpp/CMakeLists.txt | 16 + paddle/infrt/paddle/cpp/desc_api.h | 229 +++++++ paddle/infrt/paddle/framework.proto | 213 +++++++ paddle/infrt/paddle/model_parser.cc | 172 ++++++ paddle/infrt/paddle/model_parser.h | 55 ++ paddle/infrt/paddle/pb/CMakeLists.txt | 20 + paddle/infrt/paddle/pb/block_desc.cc | 43 ++ paddle/infrt/paddle/pb/block_desc.h | 77 +++ paddle/infrt/paddle/pb/op_desc.cc | 139 +++++ paddle/infrt/paddle/pb/op_desc.h | 198 +++++++ paddle/infrt/paddle/pb/program_desc.cc | 35 ++ paddle/infrt/paddle/pb/program_desc.h | 61 ++ paddle/infrt/paddle/pb/var_desc.cc | 367 ++++++++++++ paddle/infrt/paddle/pb/var_desc.h | 124 ++++ paddle/infrt/paddle/scope.cc | 44 ++ paddle/infrt/paddle/scope.h | 68 +++ paddle/infrt/paddle/tensor.cc | 19 + paddle/infrt/paddle/tensor.h | 107 ++++ paddle/infrt/support/CMakeLists.txt | 1 + paddle/infrt/support/type_traits.h | 147 +++++ paddle/infrt/support/variant.h | 219 +++++++ paddle/infrt/tensor/CMakeLists.txt | 20 + paddle/infrt/tensor/dense_host_tensor.cc | 86 +++ paddle/infrt/tensor/dense_host_tensor.h | 92 +++ paddle/infrt/tensor/dense_tensor_view.cc | 17 + paddle/infrt/tensor/dense_tensor_view.h | 64 ++ paddle/infrt/tensor/tensor_map.cc | 95 +++ paddle/infrt/tensor/tensor_map.h | 29 + paddle/infrt/tensor/tensor_metadata.cc | 30 + paddle/infrt/tensor/tensor_metadata.h | 58 ++ paddle/infrt/tensor/tensor_shape.cc | 96 +++ paddle/infrt/tensor/tensor_shape.h | 82 +++ paddle/scripts/paddle_build.sh | 2 + 161 files changed, 12742 insertions(+) create mode 100644 cmake/external/llvm.cmake create mode 100644 paddle/infrt/CMakeLists.txt create mode 
100644 paddle/infrt/api/CMakeLists.txt create mode 100644 paddle/infrt/api/infrt_api.cc create mode 100644 paddle/infrt/api/infrt_api.h create mode 100644 paddle/infrt/api/infrt_api_test.cc create mode 100644 paddle/infrt/common/CMakeLists.txt create mode 100644 paddle/infrt/common/buffer.cc create mode 100644 paddle/infrt/common/buffer.h create mode 100644 paddle/infrt/common/common.h create mode 100644 paddle/infrt/common/dtype.cc create mode 100644 paddle/infrt/common/dtype.def create mode 100644 paddle/infrt/common/dtype.h create mode 100644 paddle/infrt/common/global.cc create mode 100644 paddle/infrt/common/global.h create mode 100644 paddle/infrt/common/macros.h create mode 100644 paddle/infrt/common/memory.cc create mode 100644 paddle/infrt/common/memory.h create mode 100644 paddle/infrt/common/object.cc create mode 100644 paddle/infrt/common/object.h create mode 100644 paddle/infrt/common/shared.cc create mode 100644 paddle/infrt/common/shared.h create mode 100644 paddle/infrt/common/string.cc create mode 100644 paddle/infrt/common/string.h create mode 100644 paddle/infrt/common/target.cc create mode 100644 paddle/infrt/common/target.h create mode 100644 paddle/infrt/common/type.cc create mode 100644 paddle/infrt/common/type.h create mode 100644 paddle/infrt/dialect/CMakeLists.txt create mode 100644 paddle/infrt/dialect/basic_kernels.cc create mode 100644 paddle/infrt/dialect/basic_kernels.h create mode 100644 paddle/infrt/dialect/basic_kernels.td create mode 100644 paddle/infrt/dialect/dense_tensor.cc create mode 100644 paddle/infrt/dialect/dense_tensor.h create mode 100644 paddle/infrt/dialect/dense_tensor.td create mode 100644 paddle/infrt/dialect/diagnostic_utils.cc create mode 100644 paddle/infrt/dialect/diagnostic_utils.h create mode 100644 paddle/infrt/dialect/dialect.cc create mode 100644 paddle/infrt/dialect/infrt_base.cc create mode 100644 paddle/infrt/dialect/infrt_base.h create mode 100644 paddle/infrt/dialect/infrt_base.td create mode 100644 
paddle/infrt/dialect/init_infrt_dialects.cc create mode 100644 paddle/infrt/dialect/init_infrt_dialects.h create mode 100644 paddle/infrt/dialect/mlir_loader.cc create mode 100644 paddle/infrt/dialect/mlir_loader.h create mode 100644 paddle/infrt/dialect/mlir_loader_test.cc create mode 100644 paddle/infrt/dialect/mlir_tests/basic.mlir create mode 100644 paddle/infrt/dialect/mlir_tests/benchmark.mlir create mode 100644 paddle/infrt/dialect/mlir_tests/dense_tensor.mlir create mode 100644 paddle/infrt/dialect/mlir_tests/paddle_ops.mlir create mode 100644 paddle/infrt/dialect/mlir_tests/rewrite.mlir create mode 100644 paddle/infrt/dialect/mlir_tests/rewrite_conv_bn.mlir create mode 100644 paddle/infrt/dialect/mlir_tests/tensor_map.mlir create mode 100644 paddle/infrt/dialect/mlir_tests/tensor_shape.mlir create mode 100644 paddle/infrt/dialect/mlir_tests/tensor_type.mlir create mode 100644 paddle/infrt/dialect/ops.td create mode 100644 paddle/infrt/dialect/opt.cc create mode 100644 paddle/infrt/dialect/pd_op_base.td create mode 100644 paddle/infrt/dialect/pd_ops.cc create mode 100644 paddle/infrt/dialect/pd_ops.h create mode 100644 paddle/infrt/dialect/pd_ops.td create mode 100644 paddle/infrt/dialect/pd_types.cc create mode 100644 paddle/infrt/dialect/pd_types.h create mode 100644 paddle/infrt/dialect/print_ir.cc create mode 100644 paddle/infrt/dialect/rewrite.td create mode 100644 paddle/infrt/dialect/tensor_shape.cc create mode 100644 paddle/infrt/dialect/tensor_shape.h create mode 100644 paddle/infrt/dialect/tensor_shape.td create mode 100644 paddle/infrt/dialect/tensor_shape_base.td create mode 100644 paddle/infrt/dialect/test_kernels.cc create mode 100644 paddle/infrt/dialect/test_kernels.h create mode 100644 paddle/infrt/dialect/test_kernels.td create mode 100644 paddle/infrt/dialect/types.cc create mode 100644 paddle/infrt/dialect/types.h create mode 100644 paddle/infrt/external_kernels/CMakeLists.txt create mode 100644 paddle/infrt/external_kernels/basic.mlir 
create mode 100644 paddle/infrt/external_kernels/basic_kernels.cc create mode 100644 paddle/infrt/external_kernels/fc.mlir create mode 100644 paddle/infrt/external_kernels/paddle.mlir create mode 100644 paddle/infrt/gtest_main.cc create mode 100644 paddle/infrt/host_context/CMakeLists.txt create mode 100644 paddle/infrt/host_context/core_runtime.cc create mode 100644 paddle/infrt/host_context/core_runtime.h create mode 100644 paddle/infrt/host_context/core_runtime_test.cc create mode 100644 paddle/infrt/host_context/function.cc create mode 100644 paddle/infrt/host_context/function.h create mode 100644 paddle/infrt/host_context/kernel_frame.cc create mode 100644 paddle/infrt/host_context/kernel_frame.h create mode 100644 paddle/infrt/host_context/kernel_registry.cc create mode 100644 paddle/infrt/host_context/kernel_registry.h create mode 100644 paddle/infrt/host_context/kernel_registry_test.cc create mode 100644 paddle/infrt/host_context/kernel_utils.cc create mode 100644 paddle/infrt/host_context/kernel_utils.h create mode 100644 paddle/infrt/host_context/kernel_utils_test.cc create mode 100644 paddle/infrt/host_context/mlir_exec.cc create mode 100644 paddle/infrt/host_context/mlir_function_executable.cc create mode 100644 paddle/infrt/host_context/mlir_function_executable.h create mode 100644 paddle/infrt/host_context/mlir_program_executor.cc create mode 100644 paddle/infrt/host_context/mlir_program_executor.h create mode 100644 paddle/infrt/host_context/mlir_tests/basic.mlir create mode 100644 paddle/infrt/host_context/mlir_tests/dense_tensor.mlir create mode 100644 paddle/infrt/host_context/mlir_tests/shape.mlir create mode 100644 paddle/infrt/host_context/mlir_to_runtime_translate.cc create mode 100644 paddle/infrt/host_context/mlir_to_runtime_translate.h create mode 100644 paddle/infrt/host_context/mlir_to_runtime_translate_test.cc create mode 100644 paddle/infrt/host_context/op_executable.cc create mode 100644 paddle/infrt/host_context/op_executable.h create 
mode 100644 paddle/infrt/host_context/op_executable_test.cc create mode 100644 paddle/infrt/host_context/symbol_table.cc create mode 100644 paddle/infrt/host_context/symbol_table.h create mode 100644 paddle/infrt/host_context/value.cc create mode 100644 paddle/infrt/host_context/value.h create mode 100644 paddle/infrt/host_context/value_test.cc create mode 100644 paddle/infrt/kernel/CMakeLists.txt create mode 100644 paddle/infrt/kernel/basic_kernels.cc create mode 100644 paddle/infrt/kernel/basic_kernels.h create mode 100644 paddle/infrt/kernel/control_flow_kernels.cc create mode 100644 paddle/infrt/kernel/control_flow_kernels.h create mode 100644 paddle/infrt/kernel/tensor_kernels.cc create mode 100644 paddle/infrt/kernel/tensor_kernels.h create mode 100644 paddle/infrt/kernel/tensor_shape_kernels.cc create mode 100644 paddle/infrt/kernel/tensor_shape_kernels.h create mode 100644 paddle/infrt/kernel/test_kernels.cc create mode 100644 paddle/infrt/kernel/test_kernels.h create mode 100644 paddle/infrt/paddle/CMakeLists.txt create mode 100644 paddle/infrt/paddle/cpp/CMakeLists.txt create mode 100644 paddle/infrt/paddle/cpp/desc_api.h create mode 100644 paddle/infrt/paddle/framework.proto create mode 100644 paddle/infrt/paddle/model_parser.cc create mode 100644 paddle/infrt/paddle/model_parser.h create mode 100644 paddle/infrt/paddle/pb/CMakeLists.txt create mode 100644 paddle/infrt/paddle/pb/block_desc.cc create mode 100644 paddle/infrt/paddle/pb/block_desc.h create mode 100644 paddle/infrt/paddle/pb/op_desc.cc create mode 100644 paddle/infrt/paddle/pb/op_desc.h create mode 100644 paddle/infrt/paddle/pb/program_desc.cc create mode 100644 paddle/infrt/paddle/pb/program_desc.h create mode 100644 paddle/infrt/paddle/pb/var_desc.cc create mode 100644 paddle/infrt/paddle/pb/var_desc.h create mode 100644 paddle/infrt/paddle/scope.cc create mode 100644 paddle/infrt/paddle/scope.h create mode 100644 paddle/infrt/paddle/tensor.cc create mode 100644 
paddle/infrt/paddle/tensor.h create mode 100644 paddle/infrt/support/CMakeLists.txt create mode 100644 paddle/infrt/support/type_traits.h create mode 100644 paddle/infrt/support/variant.h create mode 100644 paddle/infrt/tensor/CMakeLists.txt create mode 100644 paddle/infrt/tensor/dense_host_tensor.cc create mode 100644 paddle/infrt/tensor/dense_host_tensor.h create mode 100644 paddle/infrt/tensor/dense_tensor_view.cc create mode 100644 paddle/infrt/tensor/dense_tensor_view.h create mode 100644 paddle/infrt/tensor/tensor_map.cc create mode 100644 paddle/infrt/tensor/tensor_map.h create mode 100644 paddle/infrt/tensor/tensor_metadata.cc create mode 100644 paddle/infrt/tensor/tensor_metadata.h create mode 100644 paddle/infrt/tensor/tensor_shape.cc create mode 100644 paddle/infrt/tensor/tensor_shape.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 55f1e4cd22..03f8522ad5 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -216,6 +216,7 @@ option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ${WITH_DISTRIBUTE} option(SANITIZER_TYPE "Choose the type of sanitizer, options are: Address, Leak, Memory, Thread, Undefined" OFF) option(WITH_LITE "Compile Paddle Fluid with Lite Engine" OFF) option(WITH_CINN "Compile PaddlePaddle with CINN" OFF) +option(WITH_INFRT "Compile PaddlePaddle with INFRT" OFF) option(WITH_NCCL "Compile PaddlePaddle with NCCL support" ON) option(WITH_RCCL "Compile PaddlePaddle with RCCL support" ON) option(WITH_XPU_BKCL "Compile PaddlePaddle with BAIDU KUNLUN XPU BKCL" OFF) diff --git a/cmake/external/llvm.cmake b/cmake/external/llvm.cmake new file mode 100644 index 0000000000..8fd4a0741e --- /dev/null +++ b/cmake/external/llvm.cmake @@ -0,0 +1,110 @@ +include(FetchContent) + +set(LLVM_DOWNLOAD_URL https://paddle-inference-dist.bj.bcebos.com/CINN/llvm11.tar.gz) +set(LLVM_MD5 39d32b6be466781dddf5869318dcba53) + +set(FETCHCONTENT_BASE_DIR ${THIRD_PARTY_PATH}/llvm) +set(FETCHCONTENT_QUIET OFF) +FetchContent_Declare(external_llvm + URL 
${LLVM_DOWNLOAD_URL} + URL_MD5 ${LLVM_MD5} + PREFIX ${THIRD_PARTY_PATH}/llvm + SOURCE_DIR ${THIRD_PARTY_PATH}/install/llvm +) +if (NOT LLVM_PATH) + FetchContent_GetProperties(external_llvm) + if (NOT external_llvm_POPULATED) + FetchContent_Populate(external_llvm) + endif() + set(LLVM_PATH ${THIRD_PARTY_PATH}/install/llvm) + set(LLVM_DIR ${THIRD_PARTY_PATH}/install/llvm/lib/cmake/llvm) + set(MLIR_DIR ${THIRD_PARTY_PATH}/install/llvm/lib/cmake/mlir) +else () + set(LLVM_DIR ${LLVM_PATH}/lib/cmake/llvm) + set(MLIR_DIR ${LLVM_PATH}/lib/cmake/mlir) +endif() + +if (${CMAKE_CXX_COMPILER} STREQUAL "clang++") + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -stdlib=libc++ -lc++abi") +endif() + +message(STATUS "set LLVM_DIR: ${LLVM_DIR}") +message(STATUS "set MLIR_DIR: ${MLIR_DIR}") +find_package(LLVM REQUIRED CONFIG HINTS ${LLVM_DIR}) +find_package(MLIR REQUIRED CONFIG HINTS ${MLIR_DIR}) +find_package(ZLIB REQUIRED) + +list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") +include(AddLLVM) + +include_directories(${LLVM_INCLUDE_DIRS}) +list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}") +list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}") +include(AddLLVM) +include(TableGen) +include(AddMLIR) + +message(STATUS "Found MLIR: ${MLIR_DIR}") +message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}") +message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}") + +# To build with MLIR, LLVM is built from source code using the following flags: + +#[==[ +cmake -G Ninja ../llvm \ + -DLLVM_ENABLE_PROJECTS="mlir;clang" \ + -DLLVM_BUILD_EXAMPLES=OFF \ + -DLLVM_TARGETS_TO_BUILD="X86" \ + -DCMAKE_BUILD_TYPE=Release \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DLLVM_ENABLE_ZLIB=OFF \ + -DLLVM_ENABLE_RTTI=ON \ +#]==] +# The matching llvm-project version is f9dc2b7079350d0fed3bb3775f496b90483c9e42 (currently a temporary commit) + +add_definitions(${LLVM_DEFINITIONS}) + +llvm_map_components_to_libnames(llvm_libs Support Core irreader + X86 executionengine orcjit mcjit all codegen) + +message(STATUS "LLVM 
libs: ${llvm_libs}") + +get_property(mlir_libs GLOBAL PROPERTY MLIR_ALL_LIBS) +message(STATUS "MLIR libs: ${mlir_libs}") +add_definitions(${LLVM_DEFINITIONS}) + + +# The minimum set of libraries needed to parse and transform MLIR IR. +set(MLIR_IR_LIBS MLIRAnalysis MLIRStandardOps MLIRPass MLIRParser MLIRDialect MLIRIR MLIROptLib) + + +# td_base is the name of a xxx.td file (without the .td suffix); generates the op declarations/definitions and, when DIALECT is given, the dialect declarations. +function(mlir_tablegen_on td_base) + set(options) + set(oneValueArgs DIALECT) + cmake_parse_arguments(mlir_tablegen_on "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + set(LLVM_TARGET_DEFINITIONS ${td_base}.td) + mlir_tablegen(${td_base}.hpp.inc -gen-op-decls) + mlir_tablegen(${td_base}.cpp.inc -gen-op-defs) + if (mlir_tablegen_on_DIALECT) + mlir_tablegen(${td_base}_dialect.hpp.inc --gen-dialect-decls -dialect=${mlir_tablegen_on_DIALECT}) + endif() + add_public_tablegen_target(${td_base}_IncGen) + add_custom_target(${td_base}_inc DEPENDS ${td_base}_IncGen) +endfunction() + +function(mlir_add_rewriter td_base) + set(LLVM_TARGET_DEFINITIONS ${td_base}.td) + mlir_tablegen(${td_base}.hpp.inc -gen-rewriters "-I${CMAKE_SOURCE_DIR}/infrt/dialect/pass") + add_public_tablegen_target(${td_base}_IncGen) + add_custom_target(${td_base}_inc DEPENDS ${td_base}_IncGen) +endfunction() + +# Execute the mlir script with the infrt-exec program. 
+# @name: name of the test +# @script: path to the mlir script file, relative to CMAKE_CURRENT_SOURCE_DIR; the infrt-exec output is piped to FileCheck against the same file +function (infrt_exec_check name script) + add_test(NAME ${name} + COMMAND sh -c "${CMAKE_BINARY_DIR}/infrt/host_context/infrt-exec -i ${CMAKE_CURRENT_SOURCE_DIR}/${script}| ${LLVM_PATH}/bin/FileCheck ${CMAKE_CURRENT_SOURCE_DIR}/${script}") +endfunction() diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index 7aa1e78abb..71e1856147 100644 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -391,6 +391,11 @@ if (WIN32) list(APPEND third_party_deps extern_dirent) endif (WIN32) +if (WITH_INFRT) + include(external/llvm) + list(APPEND third_party_deps external_llvm) +endif() + if (WITH_IPU) include(external/poplar) list(APPEND third_party_deps extern_poplar) diff --git a/paddle/CMakeLists.txt b/paddle/CMakeLists.txt index b3a1b2e8c9..4b88689b9b 100644 --- a/paddle/CMakeLists.txt +++ b/paddle/CMakeLists.txt @@ -2,4 +2,5 @@ add_subdirectory(scripts) add_subdirectory(testing) set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests CACHE INTERNAL "python tests directory") add_subdirectory(pten) +add_subdirectory(infrt) add_subdirectory(fluid) diff --git a/paddle/infrt/CMakeLists.txt b/paddle/infrt/CMakeLists.txt new file mode 100644 index 0000000000..b8f6f4738d --- /dev/null +++ b/paddle/infrt/CMakeLists.txt @@ -0,0 +1,79 @@ +if (NOT WITH_INFRT) + return() +endif() + +set(infrt_src CACHE INTERNAL "" FORCE) + +# Gather headers for publishing the library. 
+function(core_gather_headers) + file(GLOB includes LIST_DIRECTORIES false RELATIVE ${CMAKE_SOURCE_DIR} *.h) + + foreach(header ${includes}) + set(core_includes "${core_includes};${header}" CACHE INTERNAL "") + endforeach() +endfunction() + +function(gather_srcs SRC_GROUP) + set(options) + set(oneValueArgs) + set(multiValueArgs "SRCS") + cmake_parse_arguments(prefix "" "" "${multiValueArgs}" ${ARGN}) + foreach(cpp ${prefix_SRCS}) + set(${SRC_GROUP} "${${SRC_GROUP}};${CMAKE_CURRENT_SOURCE_DIR}/${cpp}" CACHE INTERNAL "") + endforeach() +endfunction() + +# This function is similar to the global cc_test, but discards the large number of default dependencies that are +# not needed by INFRT. +function(cc_test_tiny TARGET_NAME) + if(WITH_TESTING) + set(options SERIAL) + set(oneValueArgs "") + set(multiValueArgs SRCS DEPS ARGS) + cmake_parse_arguments(cc_test_tiny "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + add_executable(${TARGET_NAME} ${cc_test_tiny_SRCS}) + get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) + target_link_libraries(${TARGET_NAME} ${cc_test_tiny_DEPS} ${os_dependency_modules} infrt_gtest_main gtest ) + add_dependencies(${TARGET_NAME} ${cc_test_tiny_DEPS} infrt_gtest_main gtest extern_gtest) + + add_test(NAME ${TARGET_NAME} + COMMAND ${TARGET_NAME} "${cc_test_tiny_ARGS}" + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) + if (${cc_test_tiny_SERIAL}) + set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1) + endif() + endif() + +endfunction() + +if (WITH_TESTING) + cc_library(infrt_gtest_main SRCS gtest_main.cc DEPS gtest glog gflags) +endif() + + +add_subdirectory(api) +add_subdirectory(common) +add_subdirectory(dialect) +add_subdirectory(host_context) +add_subdirectory(kernel) +add_subdirectory(tensor) +add_subdirectory(support) +add_subdirectory(external_kernels) +add_subdirectory(paddle) + + +# MLIR td file generation targets +set(infrt_mlir_incs + ops_inc + basic_kernels_inc + test_kernels_inc + infrt_base_inc + 
tensor_shape_inc + dense_tensor_inc + pd_ops_inc + rewrite_inc + ) +message(STATUS "infrt srcs:\n${infrt_src}") + +cc_library(infrt SRCS ${infrt_src} DEPS glog ${mlir_libs} paddle_framework_proto) +add_dependencies(infrt ${infrt_mlir_incs}) diff --git a/paddle/infrt/api/CMakeLists.txt b/paddle/infrt/api/CMakeLists.txt new file mode 100644 index 0000000000..93a7ae8369 --- /dev/null +++ b/paddle/infrt/api/CMakeLists.txt @@ -0,0 +1,8 @@ +core_gather_headers() + +gather_srcs(infrt_src SRCS + infrt_api.cc + ) + +# Disabled temporarily because the external kernel's mkldnn is outdated +# cc_test(test_infrt_api SRCS infrt_api_test.cc DEPS infrt ${MLIR_IR_LIBS}) diff --git a/paddle/infrt/api/infrt_api.cc b/paddle/infrt/api/infrt_api.cc new file mode 100644 index 0000000000..c2a4e0aff7 --- /dev/null +++ b/paddle/infrt/api/infrt_api.cc @@ -0,0 +1,246 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/infrt/api/infrt_api.h" + +#include +#include +#include +#include + +#include +#include + +#include "paddle/infrt/common/global.h" +#include "paddle/infrt/dialect/dense_tensor.h" +#include "paddle/infrt/dialect/mlir_loader.h" +#include "paddle/infrt/host_context/core_runtime.h" +#include "paddle/infrt/host_context/kernel_registry.h" +#include "paddle/infrt/host_context/mlir_function_executable.h" +#include "paddle/infrt/host_context/mlir_to_runtime_translate.h" +#include "paddle/infrt/host_context/op_executable.h" +#include "paddle/infrt/host_context/value.h" +#include "paddle/infrt/kernel/basic_kernels.h" +#include "paddle/infrt/kernel/control_flow_kernels.h" +#include "paddle/infrt/kernel/tensor_kernels.h" +#include "paddle/infrt/kernel/tensor_shape_kernels.h" +#include "paddle/infrt/kernel/test_kernels.h" +#include "paddle/infrt/tensor/tensor_map.h" + +using namespace infrt::host_context; // NOLINT +using namespace infrt::tensor; // NOLINT +using namespace infrt::tensor; // NOLINT +using infrt::dt::TensorMapType; // NOLINT +using infrt::dt::TensorType; // NOLINT + +namespace infrt { + +template +std::string DumpToString(T& op) { // NOLINT + std::string buffer; + llvm::raw_string_ostream os(buffer); + op.print(os); + os.flush(); + return buffer; +} + +struct MlirToRuntimeTranslator::Impl { + mlir::ModuleOp module; + // The runtime for a function call. + CoreRuntimeBuilder* runtime{}; + + // The current working op; the translator processes the ops one by one, each + // time updating `cur_op` here to the op currently being + // worked on. + OpExecutableBuilder* cur_op{}; + + // Records the current function name. + std::string cur_func_name; + + // Maps a function name to its function definition. + std::unordered_map func_defs; + + // Maps an operation to its results. + std::unordered_map> op_results; + llvm::DenseMap value_map; +}; + +/** + * Execute the mlir program in predict mode. 
 + */ +class PredictExecutor : public MlirToRuntimeTranslator { + public: + CoreRuntimeBuilder core_runtime; + + PredictExecutor(mlir::ModuleOp module, + KernelRegistry* registry, + TensorMap* map) + : MlirToRuntimeTranslator(module, &core_runtime), + core_runtime(registry), + registry_(registry) { + CHECK(registry_); + Init(map); + } + + void Run() { + auto arguments = llvm::makeArrayRef(arguments_); + auto results = llvm::makeMutableArrayRef(results_.begin(), results_.size()); + function_executable_->Execute(arguments, results); + } + + int GetInputNum() { return inputs_.size(); } + + DenseHostTensor* GetInput(int i) { return inputs_[i]; } + + int GetOutputNum() { return outputs_.size(); } + + DenseHostTensor* GetOutput(int i) { return outputs_[i]; } + + private: + void Init(TensorMap* map) { + EmitFunctions(); + llvm::Optional predict_func_ = llvm::None; + for (auto func_op : impl_->module.getOps()) { + if (func_op.getName().str() != "predict") continue; + predict_func_ = func_op; + break; + } + if (!predict_func_) { + std::cout << "ERROR: init failed, no predict function found in mlir." 
+ << std::endl; + return; + } + auto& predict_func = predict_func_.getValue(); + function_executable_ = + new MlirFunctionExecutable(predict_func, registry_, impl_->func_defs); + + // process parameters: build one runtime Value per argument of `predict` + for (size_t i = 0; i < predict_func.getNumArguments(); ++i) { + auto arg = predict_func.getArgument(i); + auto type = arg.getType(); + // this param is the TensorMap: the contents of *map are moved into a new Value + if (type.isa()) { + auto* value = new host_context::Value(std::move(*map)); + arguments_.push_back(value); + AddValue(predict_func.getArgument(i), value); + } else { + // this param is an input Tensor: a default-constructed DenseHostTensor is wrapped and exposed through inputs_ + auto dht = DenseHostTensor(); + auto* value = new host_context::Value(std::move(dht)); + arguments_.push_back(value); + inputs_.push_back(&(value->get())); + } + } + + // process results: operands of a trailing infrt.return op are recorded as results/outputs + auto& last_op = predict_func.front().back(); + if (last_op.getName().getStringRef() == "infrt.return") { + for (size_t i = 0; i < last_op.getNumOperands(); ++i) { + auto* value = AddValue(mlir::Value(last_op.getOperand(i))); + results_.push_back(ValueRef(value)); + outputs_.push_back(&(value->get())); + } + } + } + + protected: + std::unordered_map func_def_table; + + void EmitFunction(mlir::FuncOp op) override { + CHECK(!impl_->func_defs.count(op.getName().str())) + << "Duplicate function defition found for function [" + << op.getName().str(); + impl_->func_defs.emplace(op.getName().str(), op); + } + + private: + KernelRegistry* registry_{}; + MlirFunctionExecutable* function_executable_; // NOTE(review): no initializer; stays unset when Init() returns early, yet Run() dereferences it - confirm + llvm::SmallVector inputs_; + llvm::SmallVector arguments_; + llvm::SmallVector outputs_; + llvm::SmallVector results_; +}; + +std::shared_ptr CreateInfRtPredictor( + const InfRtConfig& config) { + auto x = std::make_shared(); + x->Init(config); + return x; +} + +struct InfRtPredictor::Impl { + mlir::OwningModuleRef module_ref; + std::unique_ptr executor; +}; + +InfRtPredictor::InfRtPredictor() : impl_(new Impl) {} +InfRtPredictor::~InfRtPredictor() {} + +void InfRtPredictor::Run() { impl_->executor->Run(); } + +int 
InfRtPredictor::Init(const InfRtConfig& config) { + mlir::MLIRContext* context = infrt::Global::getMLIRContext(); + auto module_ref = dialect::LoadMlirFile(config.mlir_path(), context); + + KernelRegistry* registry = new KernelRegistry(); + + kernel::RegisterBasicKernels(registry); + kernel::RegisterTestKernels(registry); + kernel::RegisterTensorShapeKernels(registry); + kernel::RegisterTensorKernels(registry); + kernel::RegisterControlFlowKernels(registry); + + impl_->module_ref = std::move(module_ref); + + // Load the extra shared libraries and call each one's RegisterKernels symbol when present; returns 1 if a library fails to load + for (const std::string& lib_path : config.shared_libs()) { + std::string err; + llvm::sys::DynamicLibrary dynLib = + llvm::sys::DynamicLibrary::getPermanentLibrary(lib_path.c_str(), &err); + if (!dynLib.isValid()) { + llvm::errs() << "Load shared library failed. Error: " << err << "\n"; + return 1; + } + if (auto reg_sym = dynLib.SearchForAddressOfSymbol("RegisterKernels")) { + auto reg_func = reinterpret_cast(reg_sym); + reg_func(registry); + } else { + llvm::outs() << "Symbol \"RegisterKernels\" not found in \"" << lib_path + << "\". Skip.\n"; + } + } + + // Load params from the configured model directory + TensorMap* tensor_map = LoadParams(config.model_dir()); + + // Create the PredictExecutor from the loaded module, the kernel registry and the params + impl_->executor.reset( + new PredictExecutor(impl_->module_ref.get(), registry, tensor_map)); + return 0; +} + +int InfRtPredictor::GetInputNum() { return impl_->executor->GetInputNum(); } + +DenseHostTensor* InfRtPredictor::GetInput(int i) { + return impl_->executor->GetInput(i); +} + +int InfRtPredictor::GetOutputNum() { return impl_->executor->GetOutputNum(); } + +DenseHostTensor* InfRtPredictor::GetOutput(int i) { + return impl_->executor->GetOutput(i); +} + +} // namespace infrt diff --git a/paddle/infrt/api/infrt_api.h b/paddle/infrt/api/infrt_api.h new file mode 100644 index 0000000000..82b6cb8df9 --- /dev/null +++ b/paddle/infrt/api/infrt_api.h @@ -0,0 +1,63 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +#include "paddle/infrt/tensor/dense_host_tensor.h" + +namespace infrt { + +class InfRtConfig { + std::string model_dir_; + std::string mlir_path_; + std::vector shared_libs_; + + public: + InfRtConfig() = default; + void set_model_dir(const std::string& model_dir) { model_dir_ = model_dir; } + const std::string& model_dir() const { return model_dir_; } + + void set_mlir_path(const std::string& mlir_path) { mlir_path_ = mlir_path; } + const std::string& mlir_path() const { return mlir_path_; } + + void set_shared_libs(const std::vector& shared_libs) { + shared_libs_ = shared_libs; + } + const std::vector& shared_libs() const { return shared_libs_; } + + virtual ~InfRtConfig() = default; +}; + +class InfRtPredictor { + public: + InfRtPredictor(); + ~InfRtPredictor(); + void Run(); + int Init(const InfRtConfig& config); + int GetInputNum(); + tensor::DenseHostTensor* GetInput(int i); + int GetOutputNum(); + tensor::DenseHostTensor* GetOutput(int i); + + protected: + struct Impl; + std::unique_ptr impl_; +}; + +std::shared_ptr CreateInfRtPredictor(const InfRtConfig& config); + +} // namespace infrt diff --git a/paddle/infrt/api/infrt_api_test.cc b/paddle/infrt/api/infrt_api_test.cc new file mode 100644 index 0000000000..92e069f475 --- /dev/null +++ b/paddle/infrt/api/infrt_api_test.cc @@ -0,0 +1,79 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/api/infrt_api.h" + +#include + +#include +#include + +#include "llvm/Support/raw_ostream.h" +#include "paddle/infrt/common/buffer.h" +#include "paddle/infrt/common/dtype.h" + +using infrt::InfRtConfig; +using infrt::InfRtPredictor; +using infrt::CreateInfRtPredictor; + +namespace infrt { + +TEST(InfRtPredictor, predictor) { + std::vector shared_libs; + shared_libs.push_back("../../paddle/libexternal_kernels.so"); + + InfRtConfig config; + + // set external shared libraries that contain kernels. 
+ config.set_shared_libs(shared_libs); + // set model dir + config.set_model_dir("../../paddle/paddle_1.8_fc_model"); + // set mlir path + config.set_mlir_path("../../../infrt/dialect/mlir_tests/tensor_map.mlir"); + + std::shared_ptr predictor = CreateInfRtPredictor(config); + + auto* input = predictor->GetInput(0); + std::vector shape = {3, 3}; + input->Init(shape, infrt::GetDType()); + llvm::outs() << input->shape() << "\n"; + + // init input tensor + auto* input_data = reinterpret_cast(input->buffer()->data()->memory); + for (int i = 0; i < input->shape().GetNumElements(); i++) input_data[i] = 1.0; + + predictor->Run(); + + // get and print output tensor + auto* output = predictor->GetOutput(0); + auto* output_data = + reinterpret_cast(output->buffer()->data()->memory); + + std::vector ans = {0.428458, + 0.244493, + 0.572342, + 0.572008, + 0.509771, + 0.495599, + 0.651287, + 0.326426, + 0.404649}; + + ASSERT_EQ(output->shape().GetNumElements(), ans.size()); + for (int i = 0; i < output->shape().GetNumElements(); ++i) { + ASSERT_NEAR(output_data[i], ans[i], 0.000001); + } +} + +} // namespace infrt diff --git a/paddle/infrt/common/CMakeLists.txt b/paddle/infrt/common/CMakeLists.txt new file mode 100644 index 0000000000..931e3e4230 --- /dev/null +++ b/paddle/infrt/common/CMakeLists.txt @@ -0,0 +1,14 @@ +core_gather_headers() +set(core_includes "${core_includes};infrt/common/dtype.def" CACHE INTERNAL "") + +gather_srcs(infrt_src SRCS + dtype.cc + global.cc + target.cc + type.cc + shared.cc + object.cc + string.cc + buffer.cc + memory.cc + ) diff --git a/paddle/infrt/common/buffer.cc b/paddle/infrt/common/buffer.cc new file mode 100644 index 0000000000..bc4ec7fead --- /dev/null +++ b/paddle/infrt/common/buffer.cc @@ -0,0 +1,98 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/common/buffer.h" + +#include +#include + +#include + +namespace infrt { +void Buffer::Resize(uint32_t size) { + if (size_ > 0) { + Free(); + size_ = 0; + } + + if (size_ != size) { + data_.memory = reinterpret_cast(Malloc(size)); + size_ = size; + } +} + +void Buffer::Resize(uint32_t alignment, uint32_t size) { + if (size_ > 0) { + Free(); + size_ = 0; + } + + if (size_ != size) { + data_.memory = reinterpret_cast(AlignedAlloc(alignment, size)); + size_ = size; + } +} + +void Buffer::SetTarget(const infrt::common::Target& target) { + target_ = target; + memory_mng_cache_ = MemoryManager::Global().RetrieveSafely(target_.arch); +} + +void Buffer::ResizeLazy(uint32_t size) { + if (size <= size_) return; + Resize(size); +} + +void Buffer::ResizeLazy(uint32_t alignment, uint32_t size) { + if (size <= size_) return; + Resize(alignment, size); +} + +void Buffer::Resize(uint32_t size, const infrt::common::Target& target) { + if (target.arch != target_.arch) { + Free(); + SetTarget(target); + } + Resize(size); +} + +void Buffer::Resize(uint32_t alignment, + uint32_t size, + const infrt::common::Target& target) { + if (target.arch != target_.arch) { + Free(); + SetTarget(target); + } + Resize(alignment, size); +} + +void Buffer::ResizeLazy(uint32_t size, const infrt::common::Target& target) { + if (target.arch != target_.arch) { + Free(); + SetTarget(target); + } + ResizeLazy(size); +} + +void Buffer::ResizeLazy(uint32_t alignment, + uint32_t size, + const infrt::common::Target& target) { + if (target.arch != target_.arch) { + 
+ Free(); + SetTarget(target); + } + ResizeLazy(alignment, size); +} + +} // namespace infrt diff --git a/paddle/infrt/common/buffer.h b/paddle/infrt/common/buffer.h new file mode 100644 index 0000000000..cae2a7ead9 --- /dev/null +++ b/paddle/infrt/common/buffer.h @@ -0,0 +1,296 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include + +#include "paddle/infrt/common/macros.h" +#include "paddle/infrt/common/memory.h" +#include "paddle/infrt/common/target.h" + +namespace infrt { + +#ifdef __cplusplus +extern "C" { +#endif + +#define INFRT_ALWAYS_INLINE __attribute__((always_inline)) inline + +//! Code for the primitive types supported in INFRT. +typedef enum infrt_type_code_t { + infrt_type_unk = -1, //! Unknown type + infrt_type_int = 0, //! signed int + infrt_type_uint = 1, //! unsigned int + infrt_type_float = 2, //! floating point + infrt_type_handle = 3 //! void* +} infrt_type_code_t; + +#ifndef INFRT_ATTRIBUTE_ALIGN +#define INFRT_ATTRIBUTE_ALIGN(n) __attribute__((aligned(n))) +#endif + +/** + * A runtime tag for a type in the INFRT system. + */ +typedef struct infrt_type_t { +#if __cplusplus >= 201103L + INFRT_ATTRIBUTE_ALIGN(1) infrt_type_code_t code; +#else + uint8_t code; +#endif + + //! Number of bits. + uint8_t bits; + + //! Number of elements in a vector, 1 for scalar. + uint16_t lanes; + + //! Number of '*', e.g.
for `float*`, the num_asterisks is 1, `float**` it is + //! 2. + uint8_t num_asterisks{0}; + +#ifdef __cplusplus + INFRT_ALWAYS_INLINE infrt_type_t() + : code(infrt_type_int), bits(0), lanes(0) {} + INFRT_ALWAYS_INLINE infrt_type_t(infrt_type_code_t code, + uint8_t bits, + uint16_t lanes = 1, + uint8_t num_asterisks = 0) + : code(code), bits(bits), lanes(lanes), num_asterisks(num_asterisks) {} + INFRT_ALWAYS_INLINE bool operator==(const infrt_type_t& other) const { + return code == other.code && bits == other.bits && lanes == other.lanes; + } + INFRT_ALWAYS_INLINE bool operator!=(const infrt_type_t& other) const { + return !(*this == other); + } + INFRT_ALWAYS_INLINE uint16_t bytes() const { return (bits + 7) / 8; } +#endif // __cplusplus +} infrt_type_t; + +//! Help to define the size of a dimension. Due to the polyhedral +//! representation, we do not need to record the extent or +//! min (which defaults to 0). +typedef int infrt_dimension_t; + +//! Help to tell the kind of the device. +typedef enum infrt_device_kind_t { + infrt_unk_device = -1, // Undefined device. + infrt_x86_device = 0, // X86 device + infrt_opencl_device = 1, // OpenCL device + infrt_arm_device = 2 // ARM device +} infrt_device_kind_t; + +struct infrt_buffer_t; + +/** + * Every INFRT backend implementation should provide an interface to be used. + */ +struct infrt_device_interface_impl_t; + +struct infrt_device_interface_t { + int (*malloc)(void* context, struct infrt_buffer_t* buf); + int (*free)(void* context, struct infrt_buffer_t* buf); + int (*sync)(void* context, struct infrt_buffer_t* buf); + int (*release)(void* context, + const struct infrt_device_interface_t* device_interface); + int (*copy_to_host)(void* context, struct infrt_buffer_t* buf); + int (*copy_to_device)(void* context, struct infrt_buffer_t* buf); + int (*buffer_copy)(void* context, + struct infrt_buffer_t* src, + struct infrt_buffer_t* dst); + struct infrt_device_interface_impl_t* impl; +}; + +//!
The raw representation of a buffer,used in the generated code/lib. +#define INFRT_BUFFER_MAX_DIMS 8 +typedef struct infrt_buffer_t { + //! Tell which kind of device this buffer locates. + infrt_device_kind_t device; + + //! The interface used to operate on device. + const struct infrt_device_interface_t* device_interface; + + //! A pointer to the memory in host. + uint8_t* memory; + + //! Extra flags. + uint64_t flag; + + //! Data type. + infrt_type_t type; + + //! Number of dimensions. + int32_t dimensions; + infrt_dimension_t dims[INFRT_BUFFER_MAX_DIMS]; + + //! Allocate and deallocate lazily, default true. + char lazy; + + //! The actual memory size(in bytes). + uint64_t memory_size; + + uint16_t align; + +#ifdef __cplusplus + infrt_buffer_t() + : device(infrt_unk_device), + device_interface(NULL), + memory(NULL), + flag(0UL), + type(infrt_type_t()), + dimensions(0), + lazy(true), + memory_size(0), + align(0) {} + + static void delete_(struct infrt_buffer_t* x) { delete x; } + + ~infrt_buffer_t() {} + + // NOTE the buffer should be resized first. + static void alloc(struct infrt_buffer_t*); + + //! Set the shape of the buffer. NOTE this just record the shape, not allocate + //! the memory. 
+ INFRT_ALWAYS_INLINE void resize(const infrt_dimension_t* dims, + int dimensions) { + this->dimensions = dimensions; + memcpy(this->dims, dims, dimensions * sizeof(infrt_dimension_t)); + } + + INFRT_ALWAYS_INLINE uint64_t num_elements() const { + uint64_t res = 1; + for (int i = 0; i < dimensions; i++) { + res *= dims[i]; + } + return res; + } + + INFRT_ALWAYS_INLINE int device_sync(void* ctx = NULL) { + if (device_interface && device_interface->sync) { + return device_interface->sync(ctx, this); + } + return 0; + } + + //! First byte of the host memory; together with end() it delimits the data. + INFRT_ALWAYS_INLINE uint8_t* begin() const { return memory; } + INFRT_ALWAYS_INLINE uint8_t* end() const { + return memory + num_elements() * type.bytes(); + } + +#endif // __cplusplus +} infrt_buffer_t; + +#ifdef __cplusplus +struct infrt_device_interface_impl_t { + int (*malloc)(void* context, struct infrt_buffer_t* buf); + int (*free)(void* context, struct infrt_buffer_t* buf); + int (*sync)(void* context, struct infrt_buffer_t* buf); + int (*release)(void* context); + int (*copy_to_host)(void* context, struct infrt_buffer_t* buf); + int (*copy_to_device)(void* context, struct infrt_buffer_t* buf); + int (*buffer_copy)(void* context, + struct infrt_buffer_t* src, + struct infrt_buffer_t* dst); +}; + +// The device implementations +extern struct infrt_device_interface_t* infrt_x86_device_interface(); +#endif // __cplusplus + +#ifdef __cplusplus +} // extern "C" +#endif + +#define INFRT_LOG(fmt, ...) \ + do { \ + fprintf(stderr, \ + "%s:%d:%s(): " fmt, \ + __FILE__, \ + __LINE__, \ + __func__, \ + __VA_ARGS__); \ + } while (0) + +#define INFRT_CHECK(cond) \ + if (!(cond)) { \ + INFRT_LOG("check %s failed", #cond); \ + abort(); \ + } +/** + * Buffer helps to hold the memory, and offers a set of methods to help manage + * the memory. + */ +struct Buffer final { + Buffer() = default; + explicit Buffer(const common::Target& target) { SetTarget(target); } + + //! Resize the memory held by this buffer *exactly* to \p size.
+ void Resize(uint32_t size); + void Resize(uint32_t alignment, uint32_t size); + + //! Lazily resize the memory. + void ResizeLazy(uint32_t size); + void ResizeLazy(uint32_t alignment, uint32_t size); + + //! Resize the memory to \p size in target \p target. + void Resize(uint32_t size, const common::Target& target); + void Resize(uint32_t alignment, uint32_t size, const common::Target& target); + + //! Lazily resize the memory to \p size in target \p target. + void ResizeLazy(uint32_t size, const common::Target& target); + void ResizeLazy(uint32_t alignment, + uint32_t size, + const common::Target& target); + + void SetTarget(const common::Target& target); + + const infrt_buffer_t* data() const { return &data_; } + infrt_buffer_t* data() { return &data_; } + + //! Free all the memory owned by this buffer. Safe to call repeatedly: the + //! pointer and the recorded size are reset, so a later Free()/Resize() cannot + //! release the same block twice and ResizeLazy() cannot keep a dangling one. + void Free() { + if (!data_.memory) return; + memory_mng_cache_->free(data_.memory); + data_.memory = nullptr; + size_ = 0; + } + + private: + inline void* Malloc(uint32_t size) INFRT_RESULT_SHOULD_USE { + CHECK(memory_mng_cache_) << "Should set target first"; + return memory_mng_cache_->malloc(size); + } + + inline void* AlignedAlloc(uint32_t alignment, + uint32_t size) INFRT_RESULT_SHOULD_USE { + CHECK(memory_mng_cache_) << "Should set target first"; + return memory_mng_cache_->aligned_alloc(alignment, size); + } + + private: + infrt_buffer_t data_; + + //! The place where this buffer locates. + common::Target target_; + + //! Number of bytes of this buffer. + uint32_t size_{}; + + //! Hold the corresponding memory manager for speed. + MemoryInterface* memory_mng_cache_{}; +}; + +} // namespace infrt diff --git a/paddle/infrt/common/common.h b/paddle/infrt/common/common.h new file mode 100644 index 0000000000..a15bc69b60 --- /dev/null +++ b/paddle/infrt/common/common.h @@ -0,0 +1,61 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/infrt/common/macros.h" +#include "paddle/infrt/common/shared.h" +#include "paddle/infrt/common/target.h" +#include "paddle/infrt/common/type.h" + +namespace infrt { + +// export some general concepts. +using common::make_shared; +using common::Object; +using common::ref_count; +using common::Shared; + +// Type related. +using common::Bool; +using common::Float; +using common::Int; +using common::UInt; +using common::Void; + +using common::type_of; + +using common::Target; +using common::Type; +using common::UnkTarget; + +template +T& Reference(const T* x) { + return *const_cast(x); +} + +static void CheckVarNameValid(const std::string& name) { + CHECK(!name.empty()); + CHECK(name.find(' ') == std::string::npos && // + name.find('.') == std::string::npos && // + name.find('/') == std::string::npos && // + name.find('\t') == std::string::npos && // + name.find('\n') == std::string::npos && // + name.find('\r') == std::string::npos) + << "Some invalid character found"; +} + +} // namespace infrt diff --git a/paddle/infrt/common/dtype.cc b/paddle/infrt/common/dtype.cc new file mode 100644 index 0000000000..d5cf67d8a3 --- /dev/null +++ b/paddle/infrt/common/dtype.cc @@ -0,0 +1,50 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/common/dtype.h" + +namespace infrt { + +const char* DType::name() const { + switch (kind_) { + case Kind::Unk: + return "Unk"; + break; +#define INFRT_DTYPE(enum__, value__) \ + case Kind::enum__: \ + return #enum__; \ + break; +#include "paddle/infrt/common/dtype.def" +#undef INFRT_DTYPE + } + + return ""; +} + +size_t DType::GetHostSize() const { + switch (kind_) { +#define INFRT_DTYPE(enum__, value__) \ + case DType::Kind::enum__: \ + return sizeof(DTypeInternal::type); +#include "paddle/infrt/common/dtype.def" // NOLINT +#undef INFRT_DTYPE + + case Kind::Unk: + return 0; + break; + } + return 0; +} + +} // namespace infrt diff --git a/paddle/infrt/common/dtype.def b/paddle/infrt/common/dtype.def new file mode 100644 index 0000000000..32df72aa76 --- /dev/null +++ b/paddle/infrt/common/dtype.def @@ -0,0 +1,18 @@ +// Define all INFRT dtypes +// DTYPE(ENUM, VALUE) +#ifdef INFRT_DTYPE + +INFRT_DTYPE(UI8, 1) +INFRT_DTYPE(UI16, 2) +INFRT_DTYPE(UI32, 3) +INFRT_DTYPE(UI64, 4) +INFRT_DTYPE(I1, 5) +INFRT_DTYPE(I8, 6) +INFRT_DTYPE(I16, 7) +INFRT_DTYPE(I32, 8) +INFRT_DTYPE(I64, 9) +INFRT_DTYPE(F32, 10) +INFRT_DTYPE(F64, 11) +INFRT_DTYPE(STRING, 12) + +#endif \ No newline at end of file diff --git a/paddle/infrt/common/dtype.h b/paddle/infrt/common/dtype.h new file mode 100644 index 0000000000..8b57299fa9 --- /dev/null +++ b/paddle/infrt/common/dtype.h @@ -0,0 +1,85 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include + +namespace infrt { +class DType { + public: + enum class Kind : uint8_t { + Unk = 0, + +// Automatically generate the enum definition +#define INFRT_DTYPE(enum__, value__) enum__ = value__, +#include "paddle/infrt/common/dtype.def" +#undef INFRT_DTYPE + + BOOL = I1, + }; + + DType() = default; + explicit constexpr DType(Kind kind) : kind_(kind) { assert(IsValid()); } + + DType(const DType&) = default; + DType& operator=(const DType&) = default; + bool operator==(DType other) const { return kind_ == other.kind_; } + bool operator!=(DType other) const { return !(*this == other); } + + constexpr Kind kind() const { return kind_; } + + // constexpr so the assert in the constexpr constructor above can be + // evaluated at compile time. + constexpr bool IsValid() const { return kind_ != Kind::Unk; } + constexpr bool IsInvalid() const { return !IsValid(); } + + const char* name() const; + + size_t GetHostSize() const; + + private: + Kind kind_{Kind::Unk}; +}; + +template +constexpr DType GetDType(); + +template +struct DTypeInternal; + +#define INFRT_IMPL_GET_DTYPE(cpp_type__, enum__) \ + template <> \ + inline constexpr DType GetDType() { \ + return DType{DType::Kind::enum__}; \ + } \ + template <> \ + struct DTypeInternal { \ + using type = cpp_type__; \ + }; + +INFRT_IMPL_GET_DTYPE(bool, I1); +INFRT_IMPL_GET_DTYPE(int8_t, I8); +INFRT_IMPL_GET_DTYPE(int16_t, I16); +INFRT_IMPL_GET_DTYPE(int32_t, I32); +INFRT_IMPL_GET_DTYPE(int64_t, I64); +INFRT_IMPL_GET_DTYPE(uint8_t, UI8);
+INFRT_IMPL_GET_DTYPE(uint16_t, UI16); +INFRT_IMPL_GET_DTYPE(uint32_t, UI32); +INFRT_IMPL_GET_DTYPE(uint64_t, UI64); +INFRT_IMPL_GET_DTYPE(float, F32); +INFRT_IMPL_GET_DTYPE(double, F64); +INFRT_IMPL_GET_DTYPE(std::string, STRING); + +} // namespace infrt diff --git a/paddle/infrt/common/global.cc b/paddle/infrt/common/global.cc new file mode 100644 index 0000000000..54ecf1589a --- /dev/null +++ b/paddle/infrt/common/global.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/common/global.h" + +namespace infrt { + +Global::Global() {} + +mlir::MLIRContext* Global::context = nullptr; + +mlir::MLIRContext* Global::getMLIRContext() { + if (nullptr == context) { + context = new mlir::MLIRContext(); + } + return context; +} + +} // namespace infrt diff --git a/paddle/infrt/common/global.h b/paddle/infrt/common/global.h new file mode 100644 index 0000000000..f89164d03f --- /dev/null +++ b/paddle/infrt/common/global.h @@ -0,0 +1,32 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "mlir/IR/MLIRContext.h" +#include "paddle/infrt/tensor/dense_host_tensor.h" + +namespace infrt { + +// global variables +class Global { + private: + static mlir::MLIRContext *context; + Global(); + + public: + static mlir::MLIRContext *getMLIRContext(); +}; // class Global + +} // namespace infrt diff --git a/paddle/infrt/common/macros.h b/paddle/infrt/common/macros.h new file mode 100644 index 0000000000..4481f6b38a --- /dev/null +++ b/paddle/infrt/common/macros.h @@ -0,0 +1,52 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#if !defined(NDEBUG) +#define INFRT_DEBUG +#endif + +#define INFRT_DISALLOW_COPY_AND_ASSIGN(TypeName) \ + TypeName(const TypeName&) = delete; \ + void operator=(const TypeName&) = delete + +#ifndef INFRT_NOT_IMPLEMENTED +#define INFRT_NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented"; +#endif + +#define INFRT_RESULT_SHOULD_USE __attribute__((warn_unused_result)) + +/** + * A trick to enforce the registry. 
+ * + * usage: + * + * INFRT_REGISTER_HELPER(some_key) { + * // register methods + * } + * + * INFRT_USE_REGISTER(some_key); + */ +#define INFRT_REGISTER_HELPER(symbol__) bool __infrt__##symbol__##__registrar() +#define INFRT_USE_REGISTER(symbol__) \ + extern bool __infrt__##symbol__##__registrar(); \ + [[maybe_unused]] static bool __infrt_extern_registrar_##symbol__ = \ + __infrt__##symbol__##__registrar(); + +#if __cplusplus >= 201703L +#define INFRT_NODISCARD [[nodiscard]] +#else +#define INFRT_NODISCARD +#endif diff --git a/paddle/infrt/common/memory.cc b/paddle/infrt/common/memory.cc new file mode 100644 index 0000000000..aa5983a56c --- /dev/null +++ b/paddle/infrt/common/memory.cc @@ -0,0 +1,42 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/infrt/common/memory.h" + +namespace infrt { + +using infrt::common::Target; + +namespace { + +class X86MemoryMng : public MemoryInterface { + public: + void* malloc(size_t nbytes) override { return ::malloc(nbytes); } + void free(void* data) override { + if (!data) return; + ::free(data); + } + void* aligned_alloc(size_t alignment, size_t nbytes) override { + return ::aligned_alloc(alignment, nbytes); + } +}; + +} // namespace + +MemoryManager::MemoryManager() { + Register(Target::Arch::Unk, new X86MemoryMng); + Register(Target::Arch::X86, new X86MemoryMng); +} + +} // namespace infrt diff --git a/paddle/infrt/common/memory.h b/paddle/infrt/common/memory.h new file mode 100644 index 0000000000..678529b8b7 --- /dev/null +++ b/paddle/infrt/common/memory.h @@ -0,0 +1,76 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include + +#include "paddle/infrt/common/macros.h" +#include "paddle/infrt/common/target.h" + +namespace infrt { + +class MemoryInterface { + public: + virtual void* malloc(size_t nbytes) = 0; + virtual void free(void* data) = 0; + virtual void* aligned_alloc(size_t alignment, size_t nbytes) { + return nullptr; + } + virtual ~MemoryInterface() {} +}; + +/** + * MemoryManager holds a map of MemoryInterface for each architecture.
+ */ +class MemoryManager final { + public: + using key_t = common::Target::Arch; + + static MemoryManager& Global() { + static auto* x = new MemoryManager; + return *x; + } + + MemoryInterface* Retrieve(key_t key) INFRT_RESULT_SHOULD_USE { + auto it = memory_mngs_.find(key); + if (it != memory_mngs_.end()) return it->second.get(); + return nullptr; + } + + MemoryInterface* RetrieveSafely(key_t key) { + auto* res = Retrieve(key); + CHECK(res) << "no MemoryInterface for architecture [" << key << "]"; + return res; + } + + MemoryInterface* Register(key_t key, MemoryInterface* item) { + CHECK(!memory_mngs_.count(key)) << "Duplicate register [" << key << "]"; + memory_mngs_[key].reset(item); + return item; + } + + private: + MemoryManager(); + + std::unordered_map> + memory_mngs_; + + INFRT_DISALLOW_COPY_AND_ASSIGN(MemoryManager); +}; + +} // namespace infrt diff --git a/paddle/infrt/common/object.cc b/paddle/infrt/common/object.cc new file mode 100644 index 0000000000..6842ff7ba0 --- /dev/null +++ b/paddle/infrt/common/object.cc @@ -0,0 +1,19 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/infrt/common/object.h" + +namespace infrt { +namespace common {} // namespace common +} // namespace infrt diff --git a/paddle/infrt/common/object.h b/paddle/infrt/common/object.h new file mode 100644 index 0000000000..ab2d00cce9 --- /dev/null +++ b/paddle/infrt/common/object.h @@ -0,0 +1,81 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +#include "paddle/infrt/common/shared.h" + +namespace infrt { +namespace common { + +template +class Shared; +/** + * Object is the basic element in the INFRT, with `Shared` wrapper, the object + * can be shared across the system. + */ +struct Object { + //! Get the type representation of this object. + virtual const char* type_info() const = 0; + virtual ~Object() {} + + //! Cast to a derived type. + template + T* as() { + return static_cast(this); + } + + //! Cast to a derived type. + template + const T* as() const { + return static_cast(this); + } + + //! Type safe cast. + template + T* safe_as() { + if (std::strcmp(type_info(), T::__type_info__) == 0) { + return static_cast(this); + } + return nullptr; + } + //! Type safe cast. + template + const T* safe_as() const { + if (std::strcmp(type_info(), T::__type_info__) == 0) { + return static_cast(this); + } + return nullptr; + } + + //! Check if the type is right.
+ template + bool is_type() const { + if (std::strcmp(type_info(), T::__type_info__) == 0) { + return true; + } + return false; + } + + //! The reference count, which make all the derived type able to share. + mutable RefCount __ref_count__; +}; + +using object_ptr = Object*; +using shared_object = Shared; + +} // namespace common +} // namespace infrt diff --git a/paddle/infrt/common/shared.cc b/paddle/infrt/common/shared.cc new file mode 100644 index 0000000000..78457b7ed3 --- /dev/null +++ b/paddle/infrt/common/shared.cc @@ -0,0 +1,15 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/common/shared.h" diff --git a/paddle/infrt/common/shared.h b/paddle/infrt/common/shared.h new file mode 100644 index 0000000000..dbcf2b0597 --- /dev/null +++ b/paddle/infrt/common/shared.h @@ -0,0 +1,153 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +namespace infrt { +namespace common { + +class RefCount { + public: + using value_type = int32_t; + RefCount() = default; + + value_type Inc() { return ++count_; } + value_type Dec() { return --count_; } + bool is_zero() const { return 0 == count_; } + std::string to_string() { return std::to_string(count_.load()); } + int32_t val() const { return count_; } + + private: + std::atomic count_{0}; +}; + +class Object; +/** + * The templated methods are used to unify the way to get the RefCount instance + * in client classes. + */ +template +RefCount& ref_count(const T* t) { + static_assert(std::is_base_of::value, "T is not a Object"); + return t->__ref_count__; +} +template +void Destroy(const T* t) { + delete t; +} + +template +struct Shared { + using object_ptr = T*; + + Shared() = default; + explicit Shared(T* p) : p_(p) { + if (p) IncRef(p); + } + Shared(const Shared& other) : p_(other.p_) { IncRef(p_); } + Shared(Shared&& other) : p_(other.p_) { other.p_ = nullptr; } + Shared& operator=(const Shared& other); + + //! Reset to another pointer \p x. + void Reset(T* x = nullptr); + + //! Access the pointer in various ways. + // @{ + inline T* get() const { return p_; } + inline T& operator*() const { return *p_; } + inline T* operator->() const { return p_; } + inline T* self() { return p_; } + inline const T* self() const { return p_; } + // @} + + inline bool same_as(const Shared& other) { return p_ == other.p_; } + inline bool defined() const { return p_; } + inline bool operator<(const Shared& other) const { return p_ < other.p_; } + inline Shared& operator=(T* x); + inline bool operator==(const Shared& other) const { return p_ == other.p_; } + + ~Shared(); + + private: + //! Increase the share count. + void IncRef(T* p); + + //! Decrease the share count. 
+ void DecRef(T* p); + + protected: + T* p_{}; +}; + +template +void Shared::IncRef(T* p) { + if (p) { + ref_count(p).Inc(); + } +} +template +void Shared::DecRef(T* p) { + if (p) { + if (ref_count(p).Dec() == 0) { + Destroy(p); + } + } +} +template +Shared& Shared::operator=(const Shared& other) { + if (other.p_ == p_) return *this; + // Other can be inside of something owned by this, so we should be careful to + // incref other before we decref + // ourselves. + T* tmp = other.p_; + IncRef(tmp); + DecRef(p_); + p_ = tmp; + return *this; +} + +template +T* make_shared(Args&&... args) { + return new T(args...); +} + +template +Shared& Shared::operator=(T* x) { + if (p_ == x) return *this; + + T* tmp = x; + IncRef(tmp); + DecRef(p_); + p_ = tmp; + return *this; +} + +template +Shared::~Shared() { + DecRef(p_); + p_ = nullptr; +} + +template +void Shared::Reset(T* x) { + if (x) IncRef(x); + DecRef(p_); + p_ = x; +} + +} // namespace common +} // namespace infrt diff --git a/paddle/infrt/common/string.cc b/paddle/infrt/common/string.cc new file mode 100644 index 0000000000..d02643825a --- /dev/null +++ b/paddle/infrt/common/string.cc @@ -0,0 +1,128 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/common/string.h" + +#include + +#include + +namespace infrt { +namespace infrt { + +std::string StringFormat(const std::string &fmt_str, ...) 
{ + /* Reserve two times as much as the length of the fmt_str */ + int final_n, n = (static_cast<int>(fmt_str.size())) * 2; + std::unique_ptr<char[]> formatted; + va_list ap; + while (1) { + formatted.reset( + new char[n]); /* Wrap the plain char array into the unique_ptr */ + va_start(ap, fmt_str); // vsnprintf below fills the whole buffer + final_n = vsnprintf(formatted.get(), n, fmt_str.c_str(), ap); + va_end(ap); + if (final_n < 0 || final_n >= n) + n += abs(final_n - n + 1); + else + break; + } + return std::string(formatted.get()); +} + +std::string Trim(const std::string &s, const char *empty) { + if (s.empty()) return s; + auto start = s.find_first_not_of(empty); + if (start == std::string::npos) return ""; + auto end = s.find_last_not_of(empty); + return s.substr(start, end - start + 1); +} + +std::string Uppercase(const std::string &x) { + auto res = x; + for (auto &c : res) { // toupper is only defined for unsigned char values + c = static_cast<char>(toupper(static_cast<unsigned char>(c))); + } + return res; +} + +bool Startswith(const std::string &x, const std::string &str) { + return x.compare(0, str.size(), str) == 0; // inspects the prefix only +} +bool Endswith(const std::string &x, const std::string &str) { + if (x.length() >= str.length()) { + return std::equal(str.rbegin(), str.rend(), x.rbegin()); + } + return false; +} + +std::vector<std::string> Split(const std::string &str, + const std::string &splitter) { + std::vector<std::string> results; + std::string::size_type pos1 = 0; + // An empty splitter matches at every offset; treat it as "no split". + auto pos2 = splitter.empty() ? std::string::npos : str.find(splitter); + while (std::string::npos != pos2) { + results.push_back(str.substr(pos1, pos2 - pos1)); + pos1 = pos2 + splitter.size(); + pos2 = str.find(splitter, pos1); + } + if (pos1 != str.length()) { + results.push_back(str.substr(pos1)); + } + return results; +} + +void Replace(std::string *s, const std::string &from, const std::string &to) { + if (from.empty()) return; // find("") always matches: would never end + size_t pos = 0; + while ((pos = s->find(from, pos)) != std::string::npos) { + s->replace(pos, from.size(), to); + pos += to.length(); + } +} + +size_t Count(std::string *s, const std::string &sub) { + size_t pos = 0; + size_t times = 0; + while ((pos =
s->find(sub, pos)) != std::string::npos) { + if ((pos == 0 || !IsPrefix(s->at(pos - 1))) && + (pos + sub.length() == s->size() || + !IsSuffix(s->at(pos + sub.length())))) { + pos += sub.length(); + times++; + } else { + pos++; + } + } + return times; +} + +bool IsPrefix(const char &c) { + return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c == '_'); +} + +bool IsSuffix(const char &c) { + return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c == '_') || + (c >= '0' && c <= '9') || (c == '\''); +} + +std::string TransValidVarName(std::string name) { + Replace(&name, ".", "__"); + Replace(&name, "/", "___"); + name.erase(0, name.find_first_not_of("_")); + return name; +} + +} // namespace infrt +} // namespace infrt diff --git a/paddle/infrt/common/string.h b/paddle/infrt/common/string.h new file mode 100644 index 0000000000..f744470603 --- /dev/null +++ b/paddle/infrt/common/string.h @@ -0,0 +1,84 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include + +namespace infrt { +namespace infrt { + +//! Get the content of a stream. +template +std::string GetStreamCnt(const T& x); + +/** + * Construct a formatted string with arguments. + * @param fmt_str The format. + * @param ... The parameters of the format. + * @return The formated string. 
+ */ +std::string StringFormat(const std::string& fmt_str, ...); + +/** + * Join multiple fields to a single string. Similar to Python's str.join method. + */ +template +std::string Join(const std::vector& fields, const std::string& splitter) { + if (fields.empty()) return ""; + std::stringstream ss; + for (int i = 0; i < fields.size() - 1; i++) ss << fields[i] << splitter; + ss << fields.back(); + return ss.str(); +} + +std::vector Split(const std::string& str, + const std::string& splitter); + +std::string Trim(const std::string& s, const char* empty = " \n\r\t"); + +//! Convert a string to its uppercase. +std::string Uppercase(const std::string& x); + +//! Replace a substr 'from' to 'to' in string s. +void Replace(std::string* s, const std::string& from, const std::string& to); + +//! Count how many times substr 'sub' appears in string s. +size_t Count(std::string* s, const std::string& sub); + +//! Tell if a char is prefix of a tensor's name. +bool IsPrefix(const char& c); + +//! Tell if a char is suffix of a tensor's name. +bool IsSuffix(const char& c); + +//! Tell if a string \p x start with \p str. +bool Startswith(const std::string& x, const std::string& str); + +//! Tell if a string \p x ends with \p str. +bool Endswith(const std::string& x, const std::string& str); + +template +std::string GetStreamCnt(const T& x) { + std::stringstream os; + os << x; + return os.str(); +} + +std::string TransValidVarName(std::string name); + +} // namespace infrt +} // namespace infrt diff --git a/paddle/infrt/common/target.cc b/paddle/infrt/common/target.cc new file mode 100644 index 0000000000..d376ad7db0 --- /dev/null +++ b/paddle/infrt/common/target.cc @@ -0,0 +1,118 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/common/target.h" + +#include + +namespace infrt { +namespace common { + +bool Target::operator==(const Target &other) const { + return os == other.os && // + arch == other.arch && // + bits == other.bits && // + features == other.features; +} + +int Target::max_num_threads() const { + CHECK(arch == Arch::NVGPU) + << "The target is not NVGPU! Cannot get max number of threads."; + return 1024; +} + +std::vector Target::get_target_libs() const { return libs; } + +int Target::get_target_bits() const { + switch (bits) { + case Bit::k32: + return 32; + case Bit::k64: + return 64; + case Bit::Unk: + return 0; + default: + LOG(FATAL) << "Not supported Bit"; + } + return -1; +} + +std::ostream &operator<<(std::ostream &os, const Target &target) { + os << "Target<"; + switch (target.os) { + case Target::OS::Linux: + os << "linux"; + break; + case Target::OS::Windows: + os << "windows"; + break; + case Target::OS::Unk: + os << "unk"; + break; + } + + os << ","; + + switch (target.arch) { + case Target::Arch::X86: + os << "x86"; + break; + case Target::Arch::ARM: + os << "arm"; + break; + case Target::Arch::NVGPU: + os << "nvgpu"; + break; + case Target::Arch::Unk: + os << "unk"; + break; + } + os << ","; + + switch (target.bits) { + case Target::Bit::k32: + os << "32"; + break; + case Target::Bit::k64: + os << "64"; + break; + case Target::Bit::Unk: + os << "unk"; + break; + } + os << ">"; + + return os; +} + +std::ostream &operator<<(std::ostream &os, Target::Arch arch) { + switch (arch) { + case Target::Arch::Unk: + os << "Unk"; + 
break; + case Target::Arch::X86: + os << "X86"; + break; + case Target::Arch::ARM: + os << "ARM"; + break; + case Target::Arch::NVGPU: + os << "NVGPU"; + break; + } + return os; +} + +} // namespace common +} // namespace infrt diff --git a/paddle/infrt/common/target.h b/paddle/infrt/common/target.h new file mode 100644 index 0000000000..eaf19efbfe --- /dev/null +++ b/paddle/infrt/common/target.h @@ -0,0 +1,112 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +namespace infrt { +namespace common { + +struct Target { + /** + * The operating system used by the target. Determines which system calls to + * generate. + */ + enum class OS : int { + Unk = -1, + Linux, + Windows, + }; + + /** + * The architecture used by the target. Determines the instruction set to use. + */ + enum class Arch : int { + Unk = -1, + X86, + ARM, + NVGPU, + }; + + enum class Bit : int { + Unk = -1, + k32, + k64, + }; + + OS os{OS::Unk}; + Arch arch{Arch::Unk}; + Bit bits{Bit::Unk}; + + enum class Feature : int { + JIT = 0, + Debug, + }; + + /** + * The library used by the target. 
+ */ + enum class Lib : int { + Unk = -1, + MKL, + }; + std::vector features; + std::vector libs; + + explicit Target(OS o = OS::Linux, + Arch a = Arch::Unk, + Bit b = Bit::Unk, + const std::vector& features = {}, + const std::vector& libs = {}) + : os(o), arch(a), bits(b), features(features), libs(libs) {} + + bool defined() const { + return os != OS::Unk && arch != Arch::Unk && bits != Bit::Unk; + } + + int max_num_threads() const; + + int get_target_bits() const; + + std::vector get_target_libs() const; + + bool operator==(const Target& other) const; + bool operator!=(const Target& other) const { return !(*this == other); } + friend std::ostream& operator<<(std::ostream& os, const Target& target); +}; + +static const Target& UnkTarget() { + static Target target( + Target::OS::Unk, Target::Arch::Unk, Target::Bit::Unk, {}, {}); + return target; +} + +static const Target& DefaultHostTarget() { + static Target target( + Target::OS::Linux, Target::Arch::X86, Target::Bit::k64, {}, {}); + return target; +} + +static const Target& DefaultNVGPUTarget() { + static Target target( + Target::OS::Linux, Target::Arch::NVGPU, Target::Bit::k64, {}, {}); + return target; +} + +std::ostream& operator<<(std::ostream& os, Target::Arch arch); + +} // namespace common +} // namespace infrt diff --git a/paddle/infrt/common/type.cc b/paddle/infrt/common/type.cc new file mode 100644 index 0000000000..f262bd4697 --- /dev/null +++ b/paddle/infrt/common/type.cc @@ -0,0 +1,358 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/common/type.h" + +#include + +namespace infrt { +namespace common { + +struct Type::Storage { + Storage() = default; + Storage(type_t t, int b, int w) : type_(t), bits_(b), lanes_(w) {} + + type_t type_{type_t::Unk}; + cpp_type_t cpp_type_{cpp_type_t::None}; + + //! How many bits per element. + int bits_{}; + + //! How many elements(if a vector type), for scalar types, it should be 1. + int lanes_{1}; + + //! Name of the customized type. + std::string customized_type_; +}; + +Type::~Type() {} + +std::ostream &operator<<(std::ostream &os, const Type &t) { + if (t.is_cpp_const()) os << "const "; + switch (t.type()) { + case Type::type_t::Int: + if (t.bits() == 1) { + os << "bool"; + } else { + os << "int" << t.bits(); + } + + break; + case Type::type_t::UInt: + os << "uint" << t.bits(); + break; + + case Type::type_t::Float: + os << "float" << t.bits(); + break; + case Type::type_t::Void: + os << "void"; + break; + case Type::type_t::Customized: + os << t.customized_type(); + break; + case Type::type_t::String: + os << "string"; + break; + case Type::type_t::Unk: + os << "unk"; + break; + } + + if (t.lanes() > 1) os << "<" << t.lanes() << ">"; + if (t.is_cpp_handle()) os << "*"; + if (t.is_cpp_handle2()) os << "**"; + + return os; +} + +std::ostream &operator<<(std::ostream &os, Type::type_t t) { + switch (t) { + case Type::type_t::String: + os << "String"; + break; + case Type::type_t::Void: + os << "Void"; + break; + case Type::type_t::UInt: + os << "UInt"; + break; + case Type::type_t::Int: + os << "Int"; + break; + case 
Type::type_t::Float: + os << "Float"; + break; + case Type::type_t::Unk: + os << "Unk"; + break; + case Type::type_t::Customized: + os << "Customized"; + } + return os; +} + +Type &Type::set_cpp_handle(bool x) { + // unset the other handle-related bits. + set_cpp_handle2(false); + + auto &v = (*reinterpret_cast<uint8_t *>(&GetStorage().cpp_type_)); + // unset the other handle-related bits. + v &= ~static_cast<uint8_t>(cpp_type_t::Handle); + v &= ~static_cast<uint8_t>(cpp_type_t::HandleHandle); + + if (x) + v |= static_cast<uint8_t>(cpp_type_t::Handle); + else + v &= ~static_cast<uint8_t>(cpp_type_t::Handle); + + return *this; +} + +Type &Type::set_cpp_handle2(bool x) { + auto &v = (*reinterpret_cast<uint8_t *>(&GetStorage().cpp_type_)); + + // unset the other handle-related bits. + v &= ~static_cast<uint8_t>(cpp_type_t::Handle); + v &= ~static_cast<uint8_t>(cpp_type_t::HandleHandle); + + if (x) + v |= static_cast<uint8_t>(cpp_type_t::HandleHandle); + else + v &= ~static_cast<uint8_t>(cpp_type_t::HandleHandle); + + return *this; +} + +Type Type::VectorOf(int w) const { + CheckTypeValid(); + return Type(type(), bits(), w); // ctor order is (type, bits, lanes) +} + +Type::Type(const Type &other) { + if (other.storage_) storage_.reset(new Storage(*other.storage_)); +} + +Type Type::ElementOf() const { + CheckTypeValid(); + auto type = *this; + type.storage_->lanes_ = 1; + return type; +} + +void Type::CheckTypeValid() const { CHECK_NE(GetStorage().type_, type_t::Unk); } + +Type Type::PointerOf() const { + CheckTypeValid(); + auto x = *this; + CHECK(!x.is_cpp_handle2()) << "Not support three level of PointerOf"; + if (x.is_cpp_handle()) + x.set_cpp_handle2(); + else + x.set_cpp_handle(); + return x; +} + +Type Type::ConstOf() const { + CheckTypeValid(); + auto x = *this; + x.set_cpp_const(); + return x; +} + +Type Type::IgnoreConst() const { + CheckTypeValid(); + auto x = *this; + x.set_cpp_const(false); + return x; +} + +Type Type::with_bits(int x) const { + CHECK(is_primitive()); + Type type = *this; + type.GetStorage().bits_ = x; + return type; +} + +Type Type::with_type(Type::type_t x) const {
+ Type type = *this; + type.GetStorage().type_ = x; + return type; +} + +Type Type::with_lanes(int x) const { + CHECK(valid()); + Type type = *this; + type.GetStorage().lanes_ = x; + return type; +} + +Type Type::with_cpp_const(bool x) const { + Type type = *this; + type.set_cpp_const(x); + return type; +} + +Type &Type::set_cpp_const(bool is_const) { + uint8_t &data = *reinterpret_cast(&GetStorage().cpp_type_); + if (is_const) { + data |= static_cast(cpp_type_t::Const); + } else { + data &= ~(static_cast(cpp_type_t::Const)); + } + + return *this; +} +Type &Type::set_customized_type(const std::string &t) { + GetStorage().type_ = type_t::Customized; + GetStorage().customized_type_ = t; + + return *this; +} + +bool Type::valid() const { + if (is_unk()) return false; + if (is_customized()) { + return !GetStorage().customized_type_.empty(); + } + if (is_primitive()) { + return bits() != 0; + } + return true; +} + +Type::Type(Type::type_t t, int b, int w) : storage_(new Storage(t, b, w)) {} +bool Type::is_primitive() const { + return !is_unk() && type() != type_t::Customized; +} +bool Type::is_customized() const { + return !is_unk() && type() == type_t::Customized; +} +bool Type::is_unk() const { return type() == type_t::Unk; } +bool Type::is_bool() const { return type() == type_t::UInt && bits() == 1; } +bool Type::is_void() const { return type() == type_t::Void; } +bool Type::is_vector() const { return lanes() > 1; } +bool Type::is_scalar() const { return lanes() == 1; } +bool Type::is_float(int bits) const { + return type() == type_t::Float && (bits < 0 || bits == this->bits()); +} +bool Type::is_uint(int bits) const { + return type() == type_t::UInt && (bits < 0 || bits == this->bits()); +} +bool Type::is_int(int bits) const { + return type() == type_t::Int && (bits < 0 || bits == this->bits()); +} +bool Type::is_integer(int bits) const { + return (type() == type_t::Int || type() == type_t::UInt) && + (bits < 0 || bits == this->bits()); +} +bool 
Type::is_index_type() { + return is_int() && lanes() == 1 && (bits() == 32 || bits() == 64); +} +bool Type::is_cpp_handle() const { + return static_cast(GetStorage().cpp_type_) & + static_cast(cpp_type_t::Handle); +} +bool Type::is_cpp_handle2() const { + return static_cast(GetStorage().cpp_type_) & + static_cast(cpp_type_t::HandleHandle); +} +bool Type::is_cpp_const() const { + return static_cast(cpp_type_t::Const) & + static_cast(GetStorage().cpp_type_); +} +const std::string &Type::customized_type() const { + return GetStorage().customized_type_; +} +bool Type::is_customized_type() const { + return !GetStorage().customized_type_.empty(); +} +Type::type_t Type::type() const { return GetStorage().type_; } +int Type::bits() const { return GetStorage().bits_; } +int Type::lanes() const { return GetStorage().lanes_; } +Type::cpp_type_t Type::cpp_type() const { return GetStorage().cpp_type_; } +bool Type::operator==(const Type &other) const { + return type() == other.type() && bits() == other.bits() && + lanes() == other.lanes() && + GetStorage().cpp_type_ == other.GetStorage().cpp_type_ && + customized_type() == other.customized_type(); +} +bool Type::is_string() const { return type() == type_t::String; } + +Type &Type::operator=(const Type &other) { + if (other.storage_) storage_.reset(new Storage(*other.storage_)); + return *this; +} + +Type::Storage &Type::GetStorage() { return *storage_; } +const Type::Storage &Type::GetStorage() const { return *storage_; } + +Type::Type() : storage_(new Storage) {} +Type::Type(Type &&other) : storage_(std::move(other.storage_)) {} + +const Type &F16() { + static auto t = Float(16); + return t; +} +const Type &F32() { + static auto t = Float(32); + return t; +} +const Type &F64() { + static auto t = Float(64); + return t; +} +const Type &I8() { + static auto t = Int(8); + return t; +} +const Type &I16() { + static auto t = Int(16); + return t; +} +const Type &I32() { + static auto t = Int(32); + return t; +} +const Type &I64() { 
+ static auto t = Int(64); + return t; +} +const Type &UI8() { + static auto t = UInt(8); + return t; +} +const Type &UI16() { + static auto t = UInt(16); + return t; +} +const Type &UI32() { + static auto t = UInt(32); + return t; +} +const Type &UI64() { + static auto t = UInt(64); + return t; +} +const Type &I1() { + static auto t = Int(1); + return t; +} +const Type &UI1() { + static auto t = UInt(1); + return t; +} + +} // namespace common +} // namespace infrt diff --git a/paddle/infrt/common/type.h b/paddle/infrt/common/type.h new file mode 100644 index 0000000000..b532fc154f --- /dev/null +++ b/paddle/infrt/common/type.h @@ -0,0 +1,223 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include +#include + +#include "paddle/infrt/common/macros.h" + +//! Much of the concepts are borrowed from Halide project. + +namespace infrt { +namespace common { + +/** + * Types in the INFRT type system. They can be ints, unsigned ints, or floats of + * various bit-widths. + * They can also be vectors of the same (by setting the `lanes` field to + * something larger than one). + * NOTE: Front-end code other than vectorize shouldn't use vector types. + */ +struct Type { + enum class type_t { + Unk = -1, + Int, + UInt, + Float, + String, + Void, + // stupid idea to mix the Customized with other primitive types, large + // refactor needs here. 
+ Customized, // Customized type + }; + + //! type decorators in C++, the different code can used together. + enum class cpp_type_t : uint8_t { + None = 0, // None information. + Const = 1, // const. + Handle = 1 << 1, // pointer type, such as `infrt_buffer_t*`. + HandleHandle = 1 << 2, // pointer of pointer, such as `infrt_buffer_t**`. + }; + + Type(); + Type(type_t t, int b, int w); + Type(const Type& other); + explicit Type(Type&& other); + Type& operator=(const Type& other); + + INFRT_NODISCARD bool is_primitive() const; + INFRT_NODISCARD bool is_customized() const; + INFRT_NODISCARD bool valid() const; + + //! Some helper functions to check a type. + // @{ + INFRT_NODISCARD bool is_unk() const; + INFRT_NODISCARD bool is_void() const; + INFRT_NODISCARD bool is_bool() const; + INFRT_NODISCARD bool is_vector() const; + INFRT_NODISCARD bool is_scalar() const; + INFRT_NODISCARD bool is_float(int bits = -1) const; + INFRT_NODISCARD bool is_int(int bits = -1) const; + INFRT_NODISCARD bool is_integer(int bits = -1) const; + INFRT_NODISCARD bool is_uint(int bits = -1) const; + INFRT_NODISCARD bool is_string() const; + INFRT_NODISCARD bool is_index_type(); + // @} + + Type& set_cpp_handle(bool x = true); + INFRT_NODISCARD bool is_cpp_handle() const; + + Type& set_cpp_handle2(bool x = true); + INFRT_NODISCARD bool is_cpp_handle2() const; + + Type& set_cpp_const(bool is_const = true); + INFRT_NODISCARD bool is_cpp_const() const; + + Type& set_customized_type(const std::string& t); + const std::string& customized_type() const; + INFRT_NODISCARD bool is_customized_type() const; + + // Get a new type with bits set to \p x. + Type with_bits(int x) const; + // Get a new type with type set to \p x. + Type with_type(type_t x) const; + // Get a new type with lanes set to \p x. + Type with_lanes(int x) const; + // Get a new type with cpp_const set to \p x. + Type with_cpp_const(bool x = true) const; + + //! 
Getters + // @{ + type_t type() const; + int bits() const; + int lanes() const; + cpp_type_t cpp_type() const; + // @} + + //! Compare two types for equality. + bool operator==(const Type& other) const; + + //! Compare two types for inequality. + bool operator!=(const Type& other) const { return !(*this == other); } + + //! Generate a vector of this type, with `w` elements. + Type VectorOf(int w) const; + //! Generate a element type of this type. + Type ElementOf() const; + //! Generate the address type. + Type PointerOf() const; + //! Ignore const. + Type IgnoreConst() const; + //! Add const. + Type ConstOf() const; + + friend std::ostream& operator<<(std::ostream& os, const Type& t); + + ~Type(); + + private: + void CheckTypeValid() const; + + struct Storage; + Storage& GetStorage(); + const Storage& GetStorage() const; + + std::unique_ptr storage_; +}; // namespace common + +inline Type Void() { return Type(Type::type_t::Void, 1, 0); } +inline Type Int(int bits, int lanes = 1) { + return Type(Type::type_t::Int, bits, lanes); +} +inline Type UInt(int bits, int lanes = 1) { + return Type(Type::type_t::UInt, bits, lanes); +} +inline Type Float(int bits, int lanes = 1) { + return Type(Type::type_t::Float, bits, lanes); +} +inline Type Bool(int lanes = 1) { return Type(Type::type_t::UInt, 1, lanes); } +inline Type String() { return Type(Type::type_t::String, 1, 1); } + +//! Builtin native types as global singletons. 
+// @{ +const Type& F16(); +const Type& F32(); +const Type& F64(); +const Type& I8(); +const Type& I16(); +const Type& I32(); +const Type& I64(); +const Type& UI8(); +const Type& UI16(); +const Type& UI32(); +const Type& UI64(); +const Type& I1(); +const Type& UI1(); +// @} + +template +Type type_of(); + +// clang-format off +template <> inline Type type_of() { return F32(); } +template <> inline Type type_of() { return F64(); } +template <> inline Type type_of() { return UI8(); } +template <> inline Type type_of() { return UI16(); } +template <> inline Type type_of() { return I32(); } +template <> inline Type type_of() { return UI32(); } +template <> inline Type type_of() { return UI1(); } +template <> inline Type type_of() { return I8(); } +template <> inline Type type_of() { return I64(); } +template <> inline Type type_of() { return UI64(); } +template <> inline Type type_of() { return I8(); } +template <> inline Type type_of() { return Void(); } +// clang-format on +template <> +inline Type type_of() { + Type x = Int(8); + x.set_cpp_handle(); + return x; +} +template <> +inline Type type_of() { + Type x = type_of(); + x.set_cpp_handle(); + return x; +} +template <> +inline Type type_of() { + Type x = type_of(); + x.set_cpp_handle2(); + return x; +} +template <> +inline Type type_of() { + Type x = type_of(); + x.set_cpp_handle(); + return x; +} +template <> +inline Type type_of() { + Type x = type_of(); + x.set_cpp_handle(); + return x; +} + +std::ostream& operator<<(std::ostream& os, Type::type_t t); + +} // namespace common +} // namespace infrt diff --git a/paddle/infrt/dialect/CMakeLists.txt b/paddle/infrt/dialect/CMakeLists.txt new file mode 100644 index 0000000000..c1517beab0 --- /dev/null +++ b/paddle/infrt/dialect/CMakeLists.txt @@ -0,0 +1,61 @@ +core_gather_headers() + +gather_srcs(infrt_src SRCS + dialect.cc + types.cc + basic_kernels.cc + test_kernels.cc + infrt_base.cc + init_infrt_dialects.cc + tensor_shape.cc + dense_tensor.cc + mlir_loader.cc + 
diagnostic_utils.cc
+        pd_types.cc
+        pd_ops.cc
+        )
+
+mlir_tablegen_on(ops)
+mlir_tablegen_on(basic_kernels)
+mlir_tablegen_on(test_kernels)
+mlir_tablegen_on(infrt_base DIALECT infrt)
+mlir_tablegen_on(tensor_shape DIALECT ts)
+mlir_tablegen_on(dense_tensor DIALECT dt)
+mlir_tablegen_on(pd_op_base DIALECT pd)
+mlir_tablegen_on(pd_ops)
+mlir_add_rewriter(rewrite)
+
+# TODO(Superjomn) add a cmake function cc_executable to encapsulate the following code
+add_executable(infrtopt opt.cc)
+target_link_libraries(infrtopt infrt ${mlir_libs})
+add_dependencies(infrtopt infrt)
+
+add_executable(print-ir print_ir.cc)
+target_link_libraries(print-ir infrt ${mlir_libs})
+add_dependencies(print-ir pd_ops_inc)
+
+
+# MLIR opt tests
+# %{
+set(infrt_opt_path ${CMAKE_CURRENT_BINARY_DIR}/infrtopt)
+
+add_test(test_infrt_mlir_opt_on_basic ${infrt_opt_path}
+        ${CMAKE_CURRENT_SOURCE_DIR}/mlir_tests/basic.mlir)
+add_test(test_infrt_mlir_opt_on_tensor_shape ${infrt_opt_path}
+        ${CMAKE_CURRENT_SOURCE_DIR}/mlir_tests/tensor_shape.mlir)
+add_test(test_infrt_mlir_opt_on_paddle_ops
+        ${infrt_opt_path}
+        ${CMAKE_CURRENT_SOURCE_DIR}/mlir_tests/paddle_ops.mlir)
+# %}
+
+cc_test_tiny(test_infrt_mlir_loader SRCS mlir_loader_test.cc DEPS infrt ${MLIR_IR_LIBS})
+
+# execute mlir and run FileCheck
+infrt_exec_check(run_and_check_tensor_type mlir_tests/tensor_type.mlir)
+infrt_exec_check(run_and_check_basic mlir_tests/basic.mlir)
+infrt_exec_check(run_and_check_benchmark mlir_tests/benchmark.mlir)
+#infrt_exec_check(run_and_check_dense_tensor mlir_tests/dense_tensor.mlir)
+add_test(test_infrt_mlir_dense_tensor
+        ${CMAKE_CURRENT_BINARY_DIR}/../host_context/infrt-exec
+        -i
+        ${CMAKE_CURRENT_SOURCE_DIR}/mlir_tests/dense_tensor.mlir)
diff --git a/paddle/infrt/dialect/basic_kernels.cc b/paddle/infrt/dialect/basic_kernels.cc
new file mode 100644
index 0000000000..b4d2b9182b
--- /dev/null
+++ b/paddle/infrt/dialect/basic_kernels.cc
@@ -0,0 +1,180 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/infrt/dialect/basic_kernels.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "paddle/infrt/dialect/dense_tensor.h"
+
+namespace infrt::dialect {
+using namespace mlir;  // NOLINT
+
+// Parses `infrt.call @callee(operands) attr-dict : function-type`; the
+// function type supplies both the operand types and the result types.
+static ParseResult parseCallOp(OpAsmParser &parser,       // NOLINT
+                               OperationState &result) {  // NOLINT
+  SymbolRefAttr callee_attr;
+  FunctionType callee_type;
+  SmallVector<OpAsmParser::OperandType, 4> operands;
+  auto callee_loc = parser.getNameLoc();
+  if (parser.parseAttribute(callee_attr, "callee", result.attributes) ||
+      parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
+      parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseColonType(callee_type) ||
+      parser.addTypesToList(callee_type.getResults(), result.types) ||
+      parser.resolveOperands(
+          operands, callee_type.getInputs(), callee_loc, result.operands))
+    return failure();
+  return success();
+}
+
+// Shared parser for `infrt.constant.*`: an optional attr-dict followed by a
+// "value" attribute of type `attrType`, which is also the single result type.
+static ParseResult parseConstantOp(Type attrType,
+                                   OpAsmParser &parser,       // NOLINT
+                                   OperationState &result) {  // NOLINT
+  Attribute valueAttr;
+  if (parser.parseOptionalAttrDict(result.attributes) ||
+      parser.parseAttribute(valueAttr, attrType, "value", result.attributes) ||
+      parser.addTypeToList(attrType, result.types))
+    return failure();
+  return success();
+}
+
+static ParseResult parseConstantF32Op(OpAsmParser &parser,       // NOLINT
+                                      OperationState &result) {  // NOLINT
+  return parseConstantOp(
+      FloatType::getF32(result.getContext()), parser, result);
+}
+static ParseResult parseConstantF64Op(OpAsmParser &parser,       // NOLINT
+                                      OperationState &result) {  // NOLINT
+  return parseConstantOp(
+      FloatType::getF64(result.getContext()), parser, result);
+}
+static ParseResult parseConstantI32Op(OpAsmParser &parser,       // NOLINT
+                                      OperationState &result) {  // NOLINT
+  return parseConstantOp(
+      IntegerType::get(32, result.getContext()), parser, result);
+}
+static ParseResult parseConstantI64Op(OpAsmParser &parser,       // NOLINT
+                                      OperationState &result) {  // NOLINT
+  return parseConstantOp(
+      IntegerType::get(64, result.getContext()), parser, result);
+}
+
+// Parses `infrt.return [operands : types]`.
+static ParseResult parseReturnOp(OpAsmParser &parser,       // NOLINT
+                                 OperationState &result) {  // NOLINT
+  SmallVector<OpAsmParser::OperandType, 2> opInfo;
+  SmallVector<Type, 2> types;
+  llvm::SMLoc loc = parser.getCurrentLocation();
+  return failure(parser.parseOperandList(opInfo) ||
+                 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
+                 parser.resolveOperands(opInfo, types, loc, result.operands));
+}
+
+// Prints `infrt.call` in the form accepted by parseCallOp.
+static void print(OpAsmPrinter &p, CallOp op) {  // NOLINT
+  p << "infrt.call " << op.getAttr("callee") << "(";
+  p.printOperands(op.getOperands());
+  p << ")";
+  p.printOptionalAttrDict(op.getAttrs(), {"callee"});
+  p << " : ";
+  // The parser requires a function type after the colon, so print one built
+  // from the operand and result types of the op.
+  p.printFunctionalType(op.getOperation());
+}
+
+// Shared printer for `infrt.constant.*`.
+static void printConstant(OpAsmPrinter &p, mlir::Operation *op) {  // NOLINT
+  p << op->getName() << " ";
+  p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"});
+
+  if (op->getAttrs().size() > 1) p << ' ';
+  Attribute attr = op->getAttr("value");
+  if (auto int_attr = attr.dyn_cast<IntegerAttr>()) {
+    bool is_signed = int_attr.getType().isIndex() ||
+                     int_attr.getType().getIntOrFloatBitWidth() != 1;
+    int_attr.getValue().print(p.getStream(), is_signed);
+  } else if (auto float_attr = attr.dyn_cast<FloatAttr>()) {
+    // getValueAsDouble() handles both f32 and f64 payloads; convertToFloat()
+    // is only valid for single-precision values.
+    p << float_attr.getValueAsDouble();
+  } else {
+    op->emitOpError("unknown attribute type");
+  }
+}
+
+static void print(OpAsmPrinter &p, ConstantF32Op op) {  // NOLINT
+  printConstant(p, op);
+}
+static void print(OpAsmPrinter &p, ConstantF64Op op) {  // NOLINT
+  printConstant(p, op);
+}
+static void print(OpAsmPrinter &p, ConstantI32Op op) {  // NOLINT
+  printConstant(p, op);
+}
+static void print(OpAsmPrinter &p, ConstantI64Op op) {  // NOLINT
+  printConstant(p, op);
+}
+
+// Prints `infrt.return` in the form accepted by parseReturnOp.
+static void print(OpAsmPrinter &p, ReturnOp op) {  // NOLINT
+  p << "infrt.return";
+  if (op.getNumOperands() > 0) {
+    p << ' ';
+    p.printOperands(op.getOperands());
+    p << " : ";
+    // Types, not operands, follow the colon; that is what the parser reads.
+    llvm::interleaveComma(op.getOperandTypes(), p);
+  }
+}
+
+static LogicalResult verify(CallOp op) { return success(); }
+
+static LogicalResult verify(ConstantF32Op op) { return success(); }
+static LogicalResult verify(ConstantI32Op op) { return success(); }
+static LogicalResult verify(ConstantF64Op op) { return success(); }
+static LogicalResult verify(ConstantI64Op op) { return success(); }
+
+// Checks that the number of returned values matches the enclosing function's
+// result count; a return whose parent is not a function is accepted as-is.
+static LogicalResult verify(ReturnOp op) {
+  auto function = dyn_cast<FuncOp>(op.getParentOp());
+
+  if (!function) return success();
+
+  auto results = function.getType().getResults();
+  if (op.getNumOperands() != results.size())
+    return op.emitOpError("has ")
+           << op.getNumOperands()
+           << " operands, but enclosing function returns " << results.size();
+
+  return success();
+}
+
+#define GET_OP_CLASSES
+#include "paddle/infrt/dialect/basic_kernels.cpp.inc"
+
+}  // namespace infrt::dialect
diff --git a/paddle/infrt/dialect/basic_kernels.h b/paddle/infrt/dialect/basic_kernels.h
new file mode 100644
index 0000000000..65316bc143
--- /dev/null
+++ b/paddle/infrt/dialect/basic_kernels.h
@@ -0,0 +1,24 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +using namespace mlir; // NOLINT + +namespace infrt::dialect { +#define GET_OP_CLASSES +#include "paddle/infrt/dialect/basic_kernels.hpp.inc" +} // namespace infrt::dialect diff --git a/paddle/infrt/dialect/basic_kernels.td b/paddle/infrt/dialect/basic_kernels.td new file mode 100644 index 0000000000..df5e4d8a2c --- /dev/null +++ b/paddle/infrt/dialect/basic_kernels.td @@ -0,0 +1,139 @@ +// Operation definitions for basic kernels. + +#ifdef BASIC_OPS +#else +#define BASIC_OPS + +include "paddle/infrt/dialect/infrt_base.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +class INFRT_Op traits = []> : Op { + + // Each registered op needs to provide all of a printer, parser and verifier. + let printer = [{ return infrt::dialect::print(p, *this); }]; + let verifier = [{ return infrt::dialect::verify(*this); }]; + let parser = [{ return infrt::dialect::parse$cppClass(parser, result); }]; +} + +def CallOp : INFRT_Op<"call"> { + let summary = "call a host operation"; + let description = [{ + The "infrt.call" operation represents a direct call to a function. The operands and result types of the call must match the specified function type. + + %2 = infrt.call @add(%0, %1) : (f32, f32) -> f32 + }]; + + let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); + let results = (outs Variadic); + + let extraClassDeclaration = [{ + StringRef getCallee() { return callee(); } + mlir::FunctionType getCalleeType(); + }]; +} + +class ConstantOp + : INFRT_Op<"constant." 
# suffix, [NoSideEffect]> { + let summary = "constant value constructor in host"; + + let arguments = (ins attr:$value); + let results = (outs baseType); +} + +def ConstantI32Op : ConstantOp<"i32", I32, I32Attr>; +def ConstantI64Op : ConstantOp<"i64", I64, I64Attr>; +def ConstantF32Op : ConstantOp<"f32", F32, F32Attr>; +def ConstantF64Op : ConstantOp<"f64", F64, F64Attr>; + +def ReturnOp : INFRT_Op<"return", [Terminator]> { + let summary = "host executor return operation"; + let description = [{ + The "infrt.return" operation represents a return operation within a function. + + func @foo() : (i32, f8) { + infrt.return %0, %1 : i32, f8 + } + }]; + + let arguments = (ins Variadic:$operands); + + let builders = [OpBuilder< + "OpBuilder &b, OperationState &result", + [{ build(b, result, llvm::None); }]>]; +} + +class AddOp : INFRT_Op<"add." # suffix, [NoSideEffect]> { + let summary = "infrt.add operation"; + let description = [{ + An operation that takes two inputs and returns their sum as result. + }]; + + let arguments = (ins type, type); + let results = (outs type); + let assemblyFormat = "operands attr-dict"; + let verifier = ?; +} + +def AddI32Op : AddOp<"i32", I32>; +def AddI64Op : AddOp<"i64", I64>; +def AddF32Op : AddOp<"f32", F32>; +def AddF64Op : AddOp<"f64", F64>; + +class MulOp : INFRT_Op<"mul." # suffix, [NoSideEffect]> { + let summary = "infrt.mul operation"; + let description = [{ + An operation that takes two inputs and returns their mul as result. + }]; + + let arguments = (ins type, type); +let results = (outs type); +let assemblyFormat = "operands attr-dict"; +let verifier = ?; +} + +def MulI32Op : MulOp<"i32", I32>; +def MulI64Op : MulOp<"i64", I64>; +def MulF32Op : MulOp<"f32", F32>; +def MulF64Op : MulOp<"f64", F64>; + +class PrintOp : INFRT_Op<"print." # suffix> { + let summary = "infrt.print operation"; + let description = [{ + An operation takes a number as input and prints to stdout. 
+ }]; + + let arguments = (ins type); + let assemblyFormat = "operands attr-dict"; + let verifier = ?; +} + +//def PrintI32Op : PrintOp<"i32", I32>; +//def PrintI64Op : PrintOp<"i64", I64>; +def PrintF32Op : PrintOp<"f32", F32>; +//def PrintF64Op : PrintOp<"f64", F64>; + +def GetStringOp : INFRT_Op<"get_string"> { + let summary = "infrt.get_string"; + let description = [{ + Get a !infrt.string value from the given string attribute. + }]; + + let arguments = (ins StrAttr:$value); + let results = (outs StringType); + let assemblyFormat = "`(` $value `)` attr-dict"; + let verifier = ?; +} + +def PrintStringOp : INFRT_Op<"print_string"> { + let summary = "infrt.print_string"; + let description = [{ + An operation that prints a string. + }]; + + let arguments = (ins StringType:$input); + let results = (outs); + let assemblyFormat = "`(` $input `)` attr-dict"; + let verifier = ?; +} + +#endif // basic kernels diff --git a/paddle/infrt/dialect/dense_tensor.cc b/paddle/infrt/dialect/dense_tensor.cc new file mode 100644 index 0000000000..629a7b1652 --- /dev/null +++ b/paddle/infrt/dialect/dense_tensor.cc @@ -0,0 +1,277 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/infrt/dialect/dense_tensor.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "paddle/infrt/common/global.h" +#include "paddle/infrt/dialect/tensor_shape.h" + +namespace infrt::dt { + +void DTDialect::initialize() { + allowUnknownTypes(); + addOperations< +#define GET_OP_LIST +#include "paddle/infrt/dialect/dense_tensor.cpp.inc" + >(); +} + +namespace detail { +struct TensorTypeStorage : public mlir::TypeStorage { + TensorTypeStorage(TargetType target, + LayoutType layout, + PrecisionType precision) + : target_(target), layout_(layout), precision_(precision) {} + + using KeyTy = std::tuple; + + bool operator==(const KeyTy &key) const { + return key == KeyTy(target_, layout_, precision_); + } + + static llvm::hash_code hashKey(const KeyTy &key) { + return llvm::hash_value(key); + } + + static TensorTypeStorage *construct( + mlir::TypeStorageAllocator &allocator, // NOLINT + const KeyTy &key) { + return new (allocator.allocate()) + TensorTypeStorage(std::get<0>(key), std::get<1>(key), std::get<2>(key)); + } + + TargetType target_; + LayoutType layout_; + PrecisionType precision_; +}; +} // namespace detail + +llvm::Optional GetTargetType(mlir::StringRef key) { + if (key.equals_lower("x86")) + return TargetType::X86; + else if (key.equals_lower("cuda")) + return TargetType::CUDA; + else + return llvm::None; +} + +llvm::Optional GetLayoutType(mlir::StringRef key) { + if (key.equals_lower("nchw")) + return LayoutType::NCHW; + else if (key.equals_lower("nhwc")) + return LayoutType::NHWC; + else + return llvm::None; +} + +llvm::Optional GetPrecisionType(mlir::StringRef key) { + if (key.equals_lower("i32")) + return PrecisionType::I32; + else if (key.equals_lower("f32")) + return PrecisionType::F32; + else + return llvm::None; +} + +TensorType TensorType::get(TargetType target, + LayoutType layout, + PrecisionType precision) { + return Base::get( + 
::infrt::Global::getMLIRContext(), target, layout, precision); +} + +TargetType TensorType::target() { return getImpl()->target_; } + +LayoutType TensorType::layout() { return getImpl()->layout_; } + +PrecisionType TensorType::precision() { return getImpl()->precision_; } + +raw_ostream &operator<<(raw_ostream &os, TensorType tensorType) { + os << "TensorType<" << tensorType.target() << ", " << tensorType.layout() + << ", " << tensorType.precision() << ">"; + return os; +} + +TensorMapType TensorMapType::get() { + return Base::get(::infrt::Global::getMLIRContext()); +} + +TensorMapType TensorMapType::get(mlir::MLIRContext *context) { + return Base::get(context); +} + +StringType StringType::get() { + return Base::get(::infrt::Global::getMLIRContext()); +} + +StringType StringType::get(mlir::MLIRContext *context) { + return Base::get(context); +} + +raw_ostream &operator<<(raw_ostream &os, TargetType type) { + switch (type) { + case (TargetType::X86): + os << "X86"; + break; + case (TargetType::CUDA): + os << "CUDA"; + break; + default: + os << "Unsupported"; + } + return os; +} + +raw_ostream &operator<<(raw_ostream &os, LayoutType type) { + switch (type) { + case (LayoutType::NCHW): + os << "NCHW"; + break; + case (LayoutType::NHWC): + os << "NHWC"; + break; + default: + os << "Unsupported"; + } + return os; +} + +raw_ostream &operator<<(raw_ostream &os, PrecisionType type) { + switch (type) { + case (PrecisionType::I32): + os << "I32"; + break; + case (PrecisionType::F32): + os << "F32"; + break; + default: + os << "Unsupported"; + } + return os; +} + +static Type getTensorType(mlir::MLIRContext *context) { + auto t_dialect = Identifier::get("t", context); + return OpaqueType::get(t_dialect, "tensor", context); +} + +static ParseResult parseCreateUninitTensorOp( + OpAsmParser &parser, // NOLINT + OperationState &result) { // NOLINT + auto loc = parser.getCurrentLocation(); + ::mlir::Type outputRawTypes[1]; + ::llvm::ArrayRef<::mlir::Type> 
outputTypes(outputRawTypes); + + mlir::ArrayAttr shapeAttr; + if (parser.parseAttribute(shapeAttr, + parser.getBuilder().getI64Type(), + "shape", + result.attributes)) + return failure(); + if (parser.parseOptionalAttrDict(result.attributes)) return failure(); + + if (parser.parseArrow()) return failure(); + if (parser.parseType(outputRawTypes[0])) return failure(); + if (!outputRawTypes[0].isa()) + return parser.emitError(loc, "invalid kind of type specified"); + result.addTypes(outputTypes); + return success(); +} + +template +static void printCreateUninitTensorOp(OpAsmPrinter &p, // NOLINT + CreateUninitTensorOp op) { + p << CreateUninitTensorOp::getOperationName(); + p << " "; + p.printAttributeWithoutType(op.shapeAttr()); + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"}); + p << " -> "; + p << op.getOperation()->getResultTypes(); +} + +// TODO(shibo): can be removed? +// static ParseResult parseFillTensorWithConstantOp(OpAsmParser& parser, +// OperationState& result) { +// auto loc = parser.getCurrentLocation(); +// ::mlir::OpAsmParser::OperandType inputRawOperands[1]; +// ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType> +// inputOperands(inputRawOperands); +// ::mlir::Type inputRawTypes[1]; +// ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes); +// +// if (parser.parseOperand(inputRawOperands[0])) return failure(); +// +// if (parser.parseColon()) return failure(); +// if (parser.parseType(inputRawTypes[0])) return failure(); +// if (!inputRawTypes[0].isa()) +// return parser.emitError(loc, "invalid kind of type specified"); +// +// Attribute value_attr; +// if (parser.resolveOperands(inputOperands, inputTypes, loc, result.operands)) +// return failure(); +// if (parser.parseAttribute(value_attr, "value", result.attributes)) return +// failure(); +// return success(); +//} + +// TODO(shibo): can be removed? 
+// template +// static void printFillTensorWithConstantOp(OpAsmPrinter& p, FillTensorOp op) { +// p << FillTensorOp::getOperationName(); +// p << " "; +// p.printOperand(op.getOperand()); +// p << " : "; +// p << op.getOperation()->getOperandTypes(); +// p << " "; +// p << op.getAttr("value"); +//} + +static ParseResult parseSetTensorOp(OpAsmParser &parser, // NOLINT + OperationState &result) { // NOLINT + SmallVector operands; + if (parser.parseOperandList(operands, 1)) return failure(); + + auto tensor_type = getTensorType(result.getContext()); + + Attribute value_attr; + return failure( + parser.resolveOperand(operands[0], tensor_type, result.operands) || + parser.parseAttribute(value_attr, "values", result.attributes)); +} + +template +static void printSetTensorOp(OpAsmPrinter &p, SetTensorOp op) { // NOLINT + p << SetTensorOp::getOperationName() << " "; + p.printOperand(op.getOperand()); + p << " " << op.getAttr("values"); +} + +#define GET_OP_CLASSES +#include "paddle/infrt/dialect/dense_tensor.cpp.inc" // NOLINT + +} // namespace infrt::dt diff --git a/paddle/infrt/dialect/dense_tensor.h b/paddle/infrt/dialect/dense_tensor.h new file mode 100644 index 0000000000..866c62213a --- /dev/null +++ b/paddle/infrt/dialect/dense_tensor.h @@ -0,0 +1,79 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once +#include +#include +#include + +#include + +using namespace mlir; // NOLINT +namespace infrt::dt { + +namespace detail { +struct TensorTypeStorage; +} // namespace detail + +enum class TargetType : uint8_t { X86, CUDA }; +enum class LayoutType : uint8_t { NCHW, NHWC }; +enum class PrecisionType : uint8_t { I32, F32 }; + +llvm::Optional GetTargetType(mlir::StringRef key); +llvm::Optional GetLayoutType(mlir::StringRef key); +llvm::Optional GetPrecisionType(mlir::StringRef key); + +raw_ostream &operator<<(raw_ostream &os, TargetType type); +raw_ostream &operator<<(raw_ostream &os, LayoutType type); +raw_ostream &operator<<(raw_ostream &os, PrecisionType type); + +class TensorType : public mlir::Type::TypeBase { + public: + using Base::Base; + static TensorType get(TargetType target, + LayoutType layout, + PrecisionType precision); + + TargetType target(); + LayoutType layout(); + PrecisionType precision(); +}; + +raw_ostream &operator<<(raw_ostream &os, TensorType tensorType); + +class TensorMapType : public mlir::Type::TypeBase { + public: + using Base::Base; + static TensorMapType get(); + static TensorMapType get(mlir::MLIRContext *context); +}; + +class StringType + : public mlir::Type::TypeBase { + public: + using Base::Base; + static StringType get(); + static StringType get(mlir::MLIRContext *context); +}; + +#include "paddle/infrt/dialect/dense_tensor_dialect.hpp.inc" + +#define GET_OP_CLASSES +#include "paddle/infrt/dialect/dense_tensor.hpp.inc" + +} // namespace infrt::dt diff --git a/paddle/infrt/dialect/dense_tensor.td b/paddle/infrt/dialect/dense_tensor.td new file mode 100644 index 0000000000..07e70cb2ca --- /dev/null +++ b/paddle/infrt/dialect/dense_tensor.td @@ -0,0 +1,150 @@ +#ifdef DT_OPS +#else +#define DT_OPS + +include "paddle/infrt/dialect/infrt_base.td" +include "paddle/infrt/dialect/tensor_shape_base.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +def DT_Dialect : Dialect { + let name = "dt"; + + let description = [{ 
+ The DenseTensor dialect. + }]; + + let cppNamespace = "::infrt::dt"; +} + +class DT_Op traits = []> : + Op; + +class CreateUninitTensorOp + : DT_Op<"create_uninit_tensor." # dtype, [NoSideEffect]> { + let summary = "dt.create_uninit_tensor operation"; + + let description = [{ + An operation that creates an uninitialized tensor. + }]; + + let arguments = (ins I64ArrayAttr:$shape); + let results = (outs TensorType:$output); + + let parser = [{ return infrt::dt::parseCreateUninitTensorOp(parser, result); }]; + let printer = [{ return infrt::dt::printCreateUninitTensorOp(p, *this); }]; +} + + +def ShallowCopyTensorOp + : DT_Op<"shallow_copy_tensor", [NoSideEffect]> { + let summary = "dt.shallow_copy_tensor operation"; + + let description = [{ + An operation that copy a tensor shallowly. + }]; + + let arguments = (ins TensorType:$input); + let results = (outs TensorType:$output); + + let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)"; +} + + +class FillTensorWithConstantOp : + DT_Op<"fill_tensor_with_constant." # dtype> { + let summary = "dt.fill_tensor_with_constant operation"; + + let description = [{ + An operation that fills an input tensor with a value. + }]; + + let arguments = (ins + TensorType:$input, + AnyAttr:$value + ); + let results = (outs); + + // TODO: can be removed? + //let parser = [{ return infrt::dt::parseFillTensorWithConstantOp(parser, result); }]; + //let printer = [{ return infrt::dt::printFillTensorWithConstantOp(p, *this); }]; + let assemblyFormat = "`(` $input `:` type($input) `)` attr-dict"; +} + +def PrintTensorOp : DT_Op<"print_tensor"> { + let summary = "dt.print_tensor operation"; + + let description = [{ + An operation that prints a tensor. + }]; + + let arguments = (ins TensorType:$input); + let results = (outs); + let assemblyFormat = "`(` $input `:` type($input) `)` attr-dict"; +} + +class SetTensorOp : + DT_Op<"set_tensor_with_constant_values." 
# dtype> { + let summary = "dt.set_tensor_with_constant_values operation"; + + let description = [{ + An operation that sets an input tensor with given values. + }]; + + let arguments = (ins TensorType); + let results = (outs); + + let parser = [{ return infrt::dt::parseSetTensorOp(parser, result); }]; + let printer = [{ return infrt::dt::printSetTensorOp(p, *this); }]; +} + +def LoadParamsOp : DT_Op<"load_params", [NoSideEffect]> { + let summary = "dt.load_params operation"; + + let description = [{ + An operation that can load tensors to TensorMap. + }]; + + // input path of model params. + let arguments = (ins StringType:$path); + let results = (outs TensorMapType); + + let assemblyFormat = "`(` operands `)` attr-dict"; + let verifier = ?; +} + +def GetParamOp : DT_Op<"get_param", [NoSideEffect]> { + let summary = "dt.get_param operation"; + + let description = [{ + An operation that can get a tensor from TensorMap. + }]; + + // input path of model params. + let arguments = (ins + TensorMapType:$map, + StrAttr:$name + ); + let results = (outs TensorType:$output); + let assemblyFormat = "`(` $map `,` $name `)` attr-dict `->` type($output)"; + let verifier = ?; +} + +def GetTensorShapeOp : DT_Op<"get_tensor_shape", [NoSideEffect]> { + let summary = "dt.get_tensor_shape operation"; + + let description = [{ + An operation that returns the shape of the input tensor. 
+  }];
+
+  let arguments = (ins TensorType:$input);
+  let results = (outs TS_Shape:$output);
+  let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
+}
+
+foreach dtype = ["ui8", "ui16", "ui32", "ui64", "i32", "f32", "f64", "i64"] in {
+  def DT_CreateUninitTensorOp_#dtype : CreateUninitTensorOp<dtype>;
+  def DT_FillTensorOp_#dtype : FillTensorWithConstantOp<dtype>;
+  def DT_SetTensorOp_#dtype : SetTensorOp<dtype>;
+}
+
+#endif  // DT_OPS
diff --git a/paddle/infrt/dialect/diagnostic_utils.cc b/paddle/infrt/dialect/diagnostic_utils.cc
new file mode 100644
index 0000000000..a28176e38f
--- /dev/null
+++ b/paddle/infrt/dialect/diagnostic_utils.cc
@@ -0,0 +1,70 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/infrt/dialect/diagnostic_utils.h"
+
+#include
+#include <memory>
+#include <utility>
+
+namespace infrt::dialect {
+
+// Private state of MyScopedDiagnosicHandler. It is heap-allocated, so the
+// references handed to the base class stay valid once impl_ takes ownership.
+struct MyScopedDiagnosicHandler::Impl {
+  Impl() : diag_stream_(diag_str_) {}
+
+  // String stream to assemble the final error message.
+  std::string diag_str_;
+  llvm::raw_string_ostream diag_stream_;
+
+  // A SourceMgr to use for the base handler class.
+  llvm::SourceMgr source_mgr_;
+
+  // Log detail information.
+  bool log_info_{};
+};
+
+// Public constructor. `propagate` is accepted for interface compatibility but
+// is not consulted yet.
+MyScopedDiagnosicHandler::MyScopedDiagnosicHandler(mlir::MLIRContext *ctx,
+                                                   bool propagate)
+    : MyScopedDiagnosicHandler(ctx, std::make_unique<Impl>()) {}
+
+// Delegation target. Base classes are initialized before data members, so the
+// base must be built from the `impl` argument rather than from `impl_`, which
+// is still uninitialized at that point.
+MyScopedDiagnosicHandler::MyScopedDiagnosicHandler(mlir::MLIRContext *ctx,
+                                                   std::unique_ptr<Impl> impl)
+    : mlir::SourceMgrDiagnosticHandler(
+          impl->source_mgr_, ctx, impl->diag_stream_),
+      impl_(std::move(impl)) {
+  setHandler([this](mlir::Diagnostic &diag) { return this->handler(&diag); });
+}
+
+// Defined here, where Impl is complete, as std::unique_ptr<Impl> requires.
+MyScopedDiagnosicHandler::~MyScopedDiagnosicHandler() = default;
+
+// Emits errors (and every diagnostic when log_info_ is set) into the internal
+// stream and returns failure; otherwise returns success without emitting.
+mlir::LogicalResult MyScopedDiagnosicHandler::handler(mlir::Diagnostic *diag) {
+  if (diag->getSeverity() != mlir::DiagnosticSeverity::Error &&
+      !impl_->log_info_)
+    return mlir::success();
+  emitDiagnostic(*diag);
+  impl_->diag_stream_.flush();
+  return mlir::failure(true);
+}
+
+}  // namespace infrt::dialect
diff --git a/paddle/infrt/dialect/diagnostic_utils.h b/paddle/infrt/dialect/diagnostic_utils.h
new file mode 100644
index 0000000000..3a8098cf75
--- /dev/null
+++ b/paddle/infrt/dialect/diagnostic_utils.h
@@ -0,0 +1,44 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include
+#include
+
+#include
+
+namespace infrt::dialect {
+
+/**
+ * A scoped diagnostic handler to help debug MLIR process.
+ */
+class MyScopedDiagnosicHandler : public mlir::SourceMgrDiagnosticHandler {
+ public:
+  MyScopedDiagnosicHandler(mlir::MLIRContext* ctx, bool propagate);
+
+  mlir::LogicalResult handler(mlir::Diagnostic* diag);
+
+  ~MyScopedDiagnosicHandler();
+
+ private:
+  struct Impl;
+
+  // Delegation target of the public constructor; takes the pre-built Impl so
+  // the base class can be initialized from it.
+  MyScopedDiagnosicHandler(mlir::MLIRContext* ctx, std::unique_ptr<Impl> impl);
+
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace infrt::dialect
diff --git a/paddle/infrt/dialect/dialect.cc b/paddle/infrt/dialect/dialect.cc
new file mode 100644
index 0000000000..cbcd5d0f0f
--- /dev/null
+++ b/paddle/infrt/dialect/dialect.cc
@@ -0,0 +1,36 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace infrt::hlir::dialect {
+
+class CinnDialect : public ::mlir::Dialect {
+ public:
+  explicit CinnDialect(::mlir::MLIRContext* ctx);
+
+  //! We should register this function in dialect
+  static llvm::StringRef getDialectNamespace() {
+    return "infrt::hlir::dialect";
+  }
+};
+
+}  // namespace infrt::hlir::dialect
diff --git a/paddle/infrt/dialect/infrt_base.cc b/paddle/infrt/dialect/infrt_base.cc
new file mode 100644
index 0000000000..b28ad5ad4b
--- /dev/null
+++ b/paddle/infrt/dialect/infrt_base.cc
@@ -0,0 +1,127 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/dialect/infrt_base.h" + +#include "paddle/infrt/dialect/basic_kernels.h" +#include "paddle/infrt/dialect/dense_tensor.h" +#include "paddle/infrt/dialect/test_kernels.h" + +namespace infrt::dialect { + +// ----INFRTDialect definition begin---- +void INFRTDialect::initialize() { + allowUnknownTypes(); + allowUnknownOperations(); + + addTypes(); + addTypes(); + addTypes(); + + addOperations< +#define GET_OP_LIST +#include "paddle/infrt/dialect/basic_kernels.cpp.inc" + >(); + addOperations< +#define GET_OP_LIST +#include "paddle/infrt/dialect/test_kernels.cpp.inc" + >(); +} + +mlir::Type INFRTDialect::parseType(mlir::DialectAsmParser &parser) const { + llvm::StringRef keyword; + if (parser.parseKeyword(&keyword)) return mlir::Type(); + // parse TensorType, for example: !infrt.tensor + if (keyword == "tensor") { + llvm::StringRef target; + llvm::StringRef layout; + llvm::StringRef precision; + + // parse "<" + if (parser.parseLess()) return mlir::Type(); + // parse target + if (parser.parseKeyword(&target)) return mlir::Type(); + auto targetType = infrt::dt::GetTargetType(target); + if (!targetType) { + parser.emitError(parser.getCurrentLocation(), "unknown target type: ") + << target; + return mlir::Type(); + } + + // parse "," + if (parser.parseComma()) return mlir::Type(); + // parse layout + if (parser.parseKeyword(&layout)) return mlir::Type(); + auto layoutType = infrt::dt::GetLayoutType(layout); + if (!layoutType) { + parser.emitError(parser.getCurrentLocation(), "unknown layout type: ") + << layout; + return 
mlir::Type(); + } + + // parse "," + if (parser.parseComma()) return mlir::Type(); + // parse precision + if (parser.parseKeyword(&precision)) return mlir::Type(); + auto precisionType = infrt::dt::GetPrecisionType(precision); + if (!precisionType) { + parser.emitError(parser.getCurrentLocation(), "unknown precision type: ") + << precision; + return mlir::Type(); + } + + // parse ">" + if (parser.parseGreater()) return mlir::Type(); + + return infrt::dt::TensorType::get(*targetType, *layoutType, *precisionType); + } + // parse TensorMapType, for example: !infrt.tensor_map + if (keyword == "tensor_map") { + return infrt::dt::TensorMapType::get(); + } + // parse StringType, for example: !infrt.string + if (keyword == "string") { + return infrt::dt::StringType::get(); + } + + parser.emitError(parser.getCurrentLocation(), "unknown infrt type: ") + << keyword; + return mlir::Type(); +} + +void INFRTDialect::printType(mlir::Type type, + mlir::DialectAsmPrinter &printer) const { + // print TensorType, for example: !infrt.tensor + if (type.isa()) { + auto tensorType = type.cast(); + printer << "tensor<" << tensorType.target() << ", " << tensorType.layout() + << ", " << tensorType.precision() << ">"; + return; + } + // print TensorMapType, for example: !infrt.tensor_map + if (type.isa()) { + printer << "tensor_map"; + return; + } + // print StringType, for example: !infrt.string + if (type.isa()) { + printer << "string"; + return; + } + llvm_unreachable("unknown infrt type."); +} + +// ----INFRTDialect definition end---- + +} // namespace infrt::dialect diff --git a/paddle/infrt/dialect/infrt_base.h b/paddle/infrt/dialect/infrt_base.h new file mode 100644 index 0000000000..1398378957 --- /dev/null +++ b/paddle/infrt/dialect/infrt_base.h @@ -0,0 +1,73 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/infrt/dialect/infrt_base.hpp.inc" + +namespace infrt::dialect { + +class INFRTDialect : public ::mlir::Dialect { + explicit INFRTDialect(::mlir::MLIRContext *context) + : ::mlir::Dialect(getDialectNamespace(), + context, + ::mlir::TypeID::get()) { + initialize(); + } + + // parse types registered to the dialect. + mlir::Type parseType(mlir::DialectAsmParser &parser) const override; + // print types registered to the dialect. + void printType(mlir::Type type, + mlir::DialectAsmPrinter &printer) const override; + + void initialize(); + friend class ::mlir::MLIRContext; + + public: + static ::llvm::StringRef getDialectNamespace() { return "infrt"; } +}; + +} // namespace infrt::dialect + +namespace mlir { + +template +static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b, // NOLINT + mlir::Location loc, + T constant) { + return b.getIntegerAttr(b.getI32Type(), constant); +} + +static mlir::ValueRange cvtValueToValueRange(const mlir::Value &operand) { + return mlir::ValueRange(operand); +} + +static mlir::ValueRange concatTwoValueRange(mlir::ValueRange operand_0, + mlir::ValueRange operand_1) { + mlir::SmallVector<::mlir::Value, 4> operands; + operands.append(operand_0.begin(), operand_0.end()); + operands.append(operand_1.begin(), operand_1.end()); + return operands; +} + +} // namespace mlir diff --git a/paddle/infrt/dialect/infrt_base.td b/paddle/infrt/dialect/infrt_base.td new file mode 100644 index 0000000000..61dcfe5bfb --- /dev/null 
+++ b/paddle/infrt/dialect/infrt_base.td @@ -0,0 +1,42 @@ +#ifndef INFRT_BASE +#define INFRT_BASE + +include "mlir/IR/OpBase.td" + +def INFRT_Dialect : Dialect { + let name = "infrt"; + + let description = [{ + The INFRT host dialect. + }]; + + let cppNamespace = "::infrt::dialect"; +} + +// Type definitions +def StringType : + Type()">, "!infrt.string type">, + BuildableType<"$_builder.getType<::infrt::dt::StringType>()">; + +def TensorType : + Type()">, "!infrt.tensor type">; + +def TensorMapType : + Type()">, "!infrt.tensor_map type">, + BuildableType<"$_builder.getType<::infrt::dt::TensorMapType>()">; + +def BufferType : OpaqueType<"b", "buffer", "buffer">; + +class INFRT_createI32Attr : NativeCodeCall< + "mlir::createI32Attr($_builder, $_loc, " # value # ")">; + +def INFRT_cvtValueToValueRange : NativeCodeCall< + "mlir::cvtValueToValueRange($0)">; + +def INFRT_concatTwoValueRange : NativeCodeCall< + "mlir::concatTwoValueRange($0, $1)">; + +class IsBoolAttrEq : Constraint< + CPred<"($0.getValue() ==" # value # ")">, + "Bool attrbute value constraint">; +#endif // INFRT_BASE diff --git a/paddle/infrt/dialect/init_infrt_dialects.cc b/paddle/infrt/dialect/init_infrt_dialects.cc new file mode 100644 index 0000000000..4bc2bf7094 --- /dev/null +++ b/paddle/infrt/dialect/init_infrt_dialects.cc @@ -0,0 +1,34 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/infrt/dialect/init_infrt_dialects.h" + +#include + +#include "paddle/infrt/dialect/basic_kernels.h" +#include "paddle/infrt/dialect/dense_tensor.h" +#include "paddle/infrt/dialect/infrt_base.h" +#include "paddle/infrt/dialect/pd_ops.h" +#include "paddle/infrt/dialect/tensor_shape.h" + +namespace infrt { + +void RegisterCinnDialects(mlir::DialectRegistry& registry) { // NOLINT + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); +} + +} // namespace infrt diff --git a/paddle/infrt/dialect/init_infrt_dialects.h b/paddle/infrt/dialect/init_infrt_dialects.h new file mode 100644 index 0000000000..50caca0189 --- /dev/null +++ b/paddle/infrt/dialect/init_infrt_dialects.h @@ -0,0 +1,23 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "mlir/IR/Dialect.h" + +namespace infrt { + +void RegisterCinnDialects(mlir::DialectRegistry& registry); // NOLINT + +} // namespace infrt diff --git a/paddle/infrt/dialect/mlir_loader.cc b/paddle/infrt/dialect/mlir_loader.cc new file mode 100644 index 0000000000..8df8727dbe --- /dev/null +++ b/paddle/infrt/dialect/mlir_loader.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/dialect/mlir_loader.h" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "paddle/infrt/dialect/diagnostic_utils.h" +#include "paddle/infrt/dialect/init_infrt_dialects.h" + +namespace infrt::dialect { + +mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context, + const std::string& mlir_source) { + context->allowUnregisteredDialects(); + RegisterCinnDialects(context->getDialectRegistry()); + context->getDialectRegistry().insert(); + + mlir::ScopedDiagnosticHandler scope_handler( + context, [](mlir::Diagnostic& diag) { + if (diag.getSeverity() != mlir::DiagnosticSeverity::Error) + return mlir::success(); + LOG(INFO) << "diag: " << diag.str(); + return mlir::failure(true); + }); + + auto res = mlir::parseSourceString( + llvm::StringRef(mlir_source.data(), mlir_source.length()), context); + CHECK(*res) << "failed to parse MLIR string"; + return res; +} + +mlir::OwningModuleRef LoadMlirFile(const std::string& file_name, + mlir::MLIRContext* context) { + context->allowUnregisteredDialects(); + RegisterCinnDialects(context->getDialectRegistry()); + context->getDialectRegistry().insert(); + + mlir::ScopedDiagnosticHandler scope_handler( + context, [](mlir::Diagnostic& diag) { + if (diag.getSeverity() != mlir::DiagnosticSeverity::Error) + return mlir::success(); + LOG(INFO) << "diag: " << diag.str(); + return mlir::failure(true); + }); + + return mlir::parseSourceFile(std::string(file_name), context); +} + +} // namespace infrt::dialect diff --git 
a/paddle/infrt/dialect/mlir_loader.h b/paddle/infrt/dialect/mlir_loader.h new file mode 100644 index 0000000000..092da7d9ce --- /dev/null +++ b/paddle/infrt/dialect/mlir_loader.h @@ -0,0 +1,30 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +#include + +namespace infrt::dialect { + +mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context, + const std::string& mlir_source); +mlir::OwningModuleRef LoadMlirFile(const std::string& file_name, + mlir::MLIRContext* context); + +} // namespace infrt::dialect diff --git a/paddle/infrt/dialect/mlir_loader_test.cc b/paddle/infrt/dialect/mlir_loader_test.cc new file mode 100644 index 0000000000..1b622d585a --- /dev/null +++ b/paddle/infrt/dialect/mlir_loader_test.cc @@ -0,0 +1,57 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/dialect/mlir_loader.h" + +#include +#include +#include +#include +#include + +#include + +#include "paddle/infrt/dialect/init_infrt_dialects.h" + +namespace infrt::dialect { + +TEST(MlirLoader, basic) { + mlir::MLIRContext context; + + auto source = R"ROC( +func @main() -> f32 { + %v0 = infrt.constant.f32 1.0 + %v1 = infrt.constant.f32 2.0 + %value = "infrt.add.f32"(%v0, %v1) : (f32, f32) -> f32 + + "infrt.print.f32"(%v0) : (f32) -> () + + infrt.return %value : f32 +} +)ROC"; + + auto module = LoadMlirSource(&context, source); + module->verify(); + + LOG(INFO) << "module name: " << module->getOperationName().data(); + for (auto func : module->getOps()) { + LOG(INFO) << "get func " << func.getName().str(); + int num_args = func.getNumArguments(); + for (int i = 0; i < num_args; i++) { + LOG(INFO) << "arg: " << func.getArgument(i).getArgNumber(); + } + } +} + +} // namespace infrt::dialect diff --git a/paddle/infrt/dialect/mlir_tests/basic.mlir b/paddle/infrt/dialect/mlir_tests/basic.mlir new file mode 100644 index 0000000000..84b9b0fbd7 --- /dev/null +++ b/paddle/infrt/dialect/mlir_tests/basic.mlir @@ -0,0 +1,40 @@ +// CHECK-LABEL: @basic_f32 +func @basic_f32() -> f32 { + %v0 = infrt.constant.f32 1.0 + %v1 = infrt.constant.f32 2.0 + %value = "infrt.add.f32"(%v0, %v1) : (f32, f32) -> f32 + + // CHECK-NEXT: 3 + "infrt.print.f32"(%value) : (f32) -> () + + infrt.return %value : f32 +} + +/// ================================================================ +/// @caller call the other function @callee +func @callee.add.f32(%x : f32, %y : f32, %y1 : f32) -> f32 { + %z = "infrt.add.f32"(%x, %y) : (f32, f32) -> f32 + %z1 = "infrt.add.f32"(%z, %y1) : (f32, f32) -> f32 + infrt.return %z1 : f32 +} + +// CHECK-LABEL: @caller.add.f32 +func @caller.add.f32() -> f32 { + %x = infrt.constant.f32 1.0 + %y = infrt.constant.f32 2.0 + %y1 = 
infrt.constant.f32 3.0 + %z = infrt.call @callee.add.f32(%x, %y, %y1) : (f32, f32, f32) -> f32 + + // CHECK-NEXT: 6 + "infrt.print.f32"(%z) : (f32) -> () + infrt.return %z : f32 +} +/// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + +// CHECK-LABEL: @string_test +func @string_test() { + %path = infrt.get_string("this is get_string op.") + // CHECK-LABEL: string = this is get_string op. + infrt.print_string(%path) + infrt.return +} diff --git a/paddle/infrt/dialect/mlir_tests/benchmark.mlir b/paddle/infrt/dialect/mlir_tests/benchmark.mlir new file mode 100644 index 0000000000..8b4530689d --- /dev/null +++ b/paddle/infrt/dialect/mlir_tests/benchmark.mlir @@ -0,0 +1,23 @@ +// CHECK-LABEL: @benchmark +func @benchmark() { + // CHECK-LABEL: BM:add.f32:Count: 3 + // CHECK-LABEL: BM:add.f32:Duration(ns) + // CHECK-LABEL: BM:add.f32:Time Min(ns) + // CHECK-LABEL: BM:add.f32:Time 50%(ns) + // CHECK-LABEL: BM:add.f32:Time 95%(ns) + // CHECK-LABEL: BM:add.f32:Time 99%(ns) + // CHECK-LABEL: BM:add.f32:CPU Min(ns) + // CHECK-LABEL: BM:add.f32:CPU 50%(ns) + // CHECK-LABEL: BM:add.f32:CPU 95%(ns) + // CHECK-LABEL: BM:add.f32:CPU 99%(ns) + // CHECK-LABEL: BM:add.f32:CPU utilization(percent) + infrt.benchmark "add.f32"() duration_secs = 1, max_count = 3, num_warmup_runs = 3 + { + %0 = infrt.constant.f32 1.0 + %1 = infrt.constant.f32 2.0 + %res = "infrt.add.f32"(%0, %1) : (f32, f32) -> f32 + "infrt.print.f32"(%res) : (f32) -> () + infrt.return %res : f32 + } + infrt.return +} diff --git a/paddle/infrt/dialect/mlir_tests/dense_tensor.mlir b/paddle/infrt/dialect/mlir_tests/dense_tensor.mlir new file mode 100644 index 0000000000..cca7445cd5 --- /dev/null +++ b/paddle/infrt/dialect/mlir_tests/dense_tensor.mlir @@ -0,0 +1,22 @@ +func @dense_shape0() { + %shape = ts.build_shape [1:i64, 57:i64] + %a = dt.create_uninit_tensor.f32 [12:i64, 23:i64] -> !infrt.tensor + + infrt.return +} + +func @predict(%a: !infrt.tensor, %b: !infrt.tensor) -> (!infrt.tensor, !infrt.tensor) { 
+ %a0 = dt.shallow_copy_tensor %a : !infrt.tensor -> !infrt.tensor + %b0 = dt.shallow_copy_tensor %b : !infrt.tensor -> !infrt.tensor + + infrt.return %a0, %b0: !infrt.tensor, !infrt.tensor +} + + +func @main() { + %shape = ts.build_shape [1:i64, 57:i64] + %a = dt.create_uninit_tensor.f32 [12:i64, 23:i64] -> !infrt.tensor + + %b, %c = infrt.call @predict(%a, %a) : (!infrt.tensor, !infrt.tensor) -> (!infrt.tensor, !infrt.tensor) + infrt.return +} diff --git a/paddle/infrt/dialect/mlir_tests/paddle_ops.mlir b/paddle/infrt/dialect/mlir_tests/paddle_ops.mlir new file mode 100644 index 0000000000..1855a68dd9 --- /dev/null +++ b/paddle/infrt/dialect/mlir_tests/paddle_ops.mlir @@ -0,0 +1,8 @@ +func @ops() { + %a = pd.Feed() : tensor + %b = pd.Feed() : tensor + + %c = "pd.Matmul"(%a, %b) {transpose_x=true, transpose_y=false} : (tensor, tensor) -> tensor + + infrt.return +} diff --git a/paddle/infrt/dialect/mlir_tests/rewrite.mlir b/paddle/infrt/dialect/mlir_tests/rewrite.mlir new file mode 100644 index 0000000000..c984fda3e6 --- /dev/null +++ b/paddle/infrt/dialect/mlir_tests/rewrite.mlir @@ -0,0 +1,24 @@ +// CHECK-LABEL: @main +func @main() -> tensor { + %a = "pd.Feed"() : () -> tensor + %b = "pd.Feed"() : () -> tensor + %bias = "pd.Feed"() : () -> tensor + + %b1 = "pd.Feed"() : () -> tensor + %b2 = "pd.Feed"() : () -> tensor + %bias1 = "pd.Feed"() : () -> tensor + %bias2 = "pd.Feed"() : () -> tensor + + %c = "pd.Matmul"(%a, %b) {transpose_y=false} : (tensor, tensor) -> tensor + %d = "pd.ElementwiseAdd"(%c, %bias) {axis=1:i32} : (tensor, tensor) -> tensor + %e = "pd.Relu6"(%d) {} : (tensor) -> tensor + + %c1 = "pd.Matmul"(%e, %b1) {transpose_x=false, transpose_y=false} : (tensor, tensor) -> tensor + %d1 = "pd.ElementwiseAdd"(%c1, %bias1) {axis=1:i32} : (tensor, tensor) -> tensor + %e1 = "pd.Relu"(%d1) {} : (tensor) -> tensor + + %c2 = "pd.Matmul"(%e1, %b2) {transpose_x=true, transpose_y=false} : (tensor, tensor) -> tensor + %d2 = "pd.ElementwiseAdd"(%c2, %bias2) 
{axis=1:i32} : (tensor, tensor) -> tensor + %e2 = "pd.Relu"(%d2) {} : (tensor) -> tensor + infrt.return %e2 : tensor +} \ No newline at end of file diff --git a/paddle/infrt/dialect/mlir_tests/rewrite_conv_bn.mlir b/paddle/infrt/dialect/mlir_tests/rewrite_conv_bn.mlir new file mode 100644 index 0000000000..d41d4b2f9f --- /dev/null +++ b/paddle/infrt/dialect/mlir_tests/rewrite_conv_bn.mlir @@ -0,0 +1,15 @@ +// CHECK-LABEL: @main +func @main() -> tensor { + %a = "pd.Feed"() : () -> tensor + %filter = "pd.Constant"(){value = dense<1.000000e+00> : tensor<3x64x3x3xf32>} : () -> tensor<3x64x3x3xf32> + %bias = "pd.Constant"(){value = dense<1.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32> + + %scale = "pd.Constant"(){value = dense<1.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32> + %bias2 = "pd.Constant"(){value = dense<1.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32> + %mean = "pd.Constant"(){value = dense<1.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32> + %var = "pd.Constant"(){value = dense<1.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32> + + %c = "pd.conv2d"(%a, %filter, %bias) {} : (tensor, tensor<3x64x3x3xf32>, tensor<64xf32>) -> tensor + %d = "pd.batch_norm"(%c, %scale, %bias2, %mean, %var) {} : (tensor, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor + infrt.return %d : tensor +} \ No newline at end of file diff --git a/paddle/infrt/dialect/mlir_tests/tensor_map.mlir b/paddle/infrt/dialect/mlir_tests/tensor_map.mlir new file mode 100644 index 0000000000..111c01c9a1 --- /dev/null +++ b/paddle/infrt/dialect/mlir_tests/tensor_map.mlir @@ -0,0 +1,31 @@ +// CHECK-LABEL: @predict +func @predict(%input:!infrt.tensor, %map: !infrt.tensor_map) -> (!infrt.tensor) { + %w = dt.get_param(%map, "create_parameter_0.w_0") -> !infrt.tensor + %bias = dt.get_param(%map, "create_parameter_1.w_0") -> !infrt.tensor + + %out = dt.create_uninit_tensor.f32 [3, 3] -> !infrt.tensor + + // fc + "external.matmul"(%input, %w, %out) {}: 
(!infrt.tensor, !infrt.tensor, !infrt.tensor) -> () + "external.elementwise_add"(%out, %bias, %out) {axis = -1}: (!infrt.tensor, !infrt.tensor, !infrt.tensor) -> () + "external.sigmoid"(%out, %out) {}: (!infrt.tensor, !infrt.tensor) -> () + //dt.print_tensor (%out : !infrt.tensor) + + infrt.return %out : !infrt.tensor +} + +// CHECK-LABEL: @main +func @main() { + %input = dt.create_uninit_tensor.f32 [3, 3] -> !infrt.tensor + dt.fill_tensor_with_constant.f32 (%input : !infrt.tensor) {value=1.0:f32} + + %path = infrt.get_string("/infrt/build/paddle/paddle_1.8_fc_model") + // CHECK-LABEL: loading params + %map = dt.load_params(%path) + + %out = infrt.call @predict(%input, %map): (!infrt.tensor, !infrt.tensor_map) -> (!infrt.tensor) + dt.print_tensor (%out : !infrt.tensor) + + infrt.return +} + diff --git a/paddle/infrt/dialect/mlir_tests/tensor_shape.mlir b/paddle/infrt/dialect/mlir_tests/tensor_shape.mlir new file mode 100644 index 0000000000..504b5b36be --- /dev/null +++ b/paddle/infrt/dialect/mlir_tests/tensor_shape.mlir @@ -0,0 +1,5 @@ +func @build_tensor1() { + %a = ts.build_shape [1:i64, 57:i64, 92:i64] + ts.print_shape %a + infrt.return +} diff --git a/paddle/infrt/dialect/mlir_tests/tensor_type.mlir b/paddle/infrt/dialect/mlir_tests/tensor_type.mlir new file mode 100644 index 0000000000..c331097ab1 --- /dev/null +++ b/paddle/infrt/dialect/mlir_tests/tensor_type.mlir @@ -0,0 +1,9 @@ +// CHECK-LABEL: test_tensor_type +func @test_tensor_type() { + %a = dt.create_uninit_tensor.f32 [3, 4] -> !infrt.tensor + dt.fill_tensor_with_constant.f32 (%a : !infrt.tensor) {value=1.0:f32} + // CHECK: tensor: shape=shape[3,4], values=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + dt.print_tensor (%a : !infrt.tensor) + + infrt.return +} diff --git a/paddle/infrt/dialect/ops.td b/paddle/infrt/dialect/ops.td new file mode 100644 index 0000000000..264134a447 --- /dev/null +++ b/paddle/infrt/dialect/ops.td @@ -0,0 +1,6 @@ +include "mlir/IR/OpBase.td" +include 
"paddle/infrt/dialect/infrt_base.td" + + +class INFRT_Op traits = []> : + Op; diff --git a/paddle/infrt/dialect/opt.cc b/paddle/infrt/dialect/opt.cc new file mode 100644 index 0000000000..d90d25230d --- /dev/null +++ b/paddle/infrt/dialect/opt.cc @@ -0,0 +1,45 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "paddle/infrt/common/global.h" +#include "paddle/infrt/dialect/init_infrt_dialects.h" +#include "paddle/infrt/dialect/mlir_loader.h" + +int main(int argc, char **argv) { + mlir::MLIRContext *context = infrt::Global::getMLIRContext(); + + auto ®istry = context->getDialectRegistry(); + infrt::RegisterCinnDialects(registry); + + mlir::registerCanonicalizerPass(); + + return mlir::failed( + mlir::MlirOptMain(argc, argv, "INFRT mlir pass driver", registry)); +} diff --git a/paddle/infrt/dialect/pd_op_base.td b/paddle/infrt/dialect/pd_op_base.td new file mode 100644 index 0000000000..af53df113d --- /dev/null +++ b/paddle/infrt/dialect/pd_op_base.td @@ -0,0 +1,77 @@ +// This file defines some basic elements of Paddle(alias pd) dialect. 
+// We learned much from TensorFlow mlir dialect https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td + +#ifndef PD_OP_BASE +#define PD_OP_BASE + +include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +def PD_Dialect : Dialect { + let name = "pd"; + + let description = [{ + The PaddlePaddle dialect. + + This dialect contains the PaddlePaddle operators. + }]; + + let cppNamespace = "::mlir::pd"; +} + +class PD_Op traits = []> : + Op; + + +class PD_PaddleAttr : + Attr()">, + "PaddlePaddle " # description # " attribute">; + + +//===----------------------------------------------------------------------===// +// PaddlePaddle type definitions +//===----------------------------------------------------------------------===// + +def PD_PDDialectType : Type()">, "PaddlePaddle type">; + +class PD_PaddleType : + Type()">, + "Paddle " # description # " type">, + BuildableType<"getType()">; + +//===----------------------------------------------------------------------===// +// Integer types +def PD_Bool : AnyTypeOf<[I<1>], "bool">; +def PD_Int8 : AnyTypeOf<[I8], "8-bit integer">; +def PD_Int16 : AnyTypeOf<[I16], "16-bit integer">; +def PD_Int32 : AnyTypeOf<[I32], "32-bit integer">; +def PD_Int64 : AnyTypeOf<[I64], "64-bit integer">; + +def PD_UInt8 : AnyTypeOf<[UI<8>], "8-bit unsigned integer">; +def PD_UInt16 : AnyTypeOf<[UI<16>], "16-bit unsigned integer">; +def PD_UInt32 : AnyTypeOf<[UI<32>], "32-bit unsigned integer">; +def PD_UInt64 : AnyTypeOf<[UI<64>], "64-bit unsigned integer">; + +def PD_SInt : AnyTypeOf<[PD_Int8, PD_Int16, PD_Int32, PD_Int64], "signed integer">; +def PD_UInt : AnyTypeOf<[PD_UInt8, PD_UInt16, PD_UInt32, PD_UInt64], "unsigned integer">; +def PD_Int : AnyTypeOf<[PD_SInt, PD_UInt], "integer">; + +// Float types +def PD_Float16 : AnyTypeOf<[F16], "16-bit float">; +def PD_Float32 : AnyTypeOf<[F32], "32-bit float">; +def PD_Float64 : AnyTypeOf<[F64], "64-bit float">; + +def 
PD_Float : AnyTypeOf<[PD_Float16, PD_Float32, PD_Float64], "floating-point">; + + +// Tensor types + +def PD_ElementType : Type, + "pd.dtype">; + +def PD_Tensor : TensorOf<[PD_ElementType]>; + + +#endif // PD_OP_BASE diff --git a/paddle/infrt/dialect/pd_ops.cc b/paddle/infrt/dialect/pd_ops.cc new file mode 100644 index 0000000000..7ca07dd5fc --- /dev/null +++ b/paddle/infrt/dialect/pd_ops.cc @@ -0,0 +1,177 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/infrt/dialect/pd_ops.h" + +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "paddle/infrt/dialect/infrt_base.h" + +namespace mlir { +namespace pd { + +#define GET_OP_CLASSES +#include "paddle/infrt/dialect/pd_ops.hpp.inc" +#undef GET_OP_CLASSES + +PaddleDialect::PaddleDialect(MLIRContext *context) + : Dialect("pd", context, TypeID::get()) { + addOperations< +#define GET_OP_LIST +#include "paddle/infrt/dialect/pd_ops.cpp.inc" // NOLINT + >(); +#undef GET_OP_LIST +} + +mlir::Operation *PaddleDialect::materializeConstant(mlir::OpBuilder &builder, + mlir::Attribute value, + mlir::Type type, + mlir::Location loc) { + return builder.create(loc, value); +} + +#define GET_OP_CLASSES +#include "paddle/infrt/dialect/pd_ops.cpp.inc" // NOLINT +#undef GET_OP_CLASSES + +#include "paddle/infrt/dialect/rewrite.hpp.inc" // NOLINT + +void ConstantOp::build(OpBuilder &builder, + OperationState &state, + Attribute value) { + if (auto elem_attr = value.dyn_cast()) { + return ConstantOp::build(builder, state, elem_attr); + } else if (value.isa()) { + ShapedType type = RankedTensorType::get(/*shape=*/{}, value.getType()); + state.addAttribute("value", DenseElementsAttr::get(type, value)); + state.addTypes(type); + return; + } + llvm_unreachable("unsupported attribute type for building pd.constant"); +} + +LogicalResult ConstantOp::inferReturnTypes( + MLIRContext *context, + Optional location, + ValueRange operands, + DictionaryAttr attributes, + RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + inferredReturnTypes.push_back(attributes.get("value").getType()); + return success(); +} +::mlir::OpFoldResult ConstantOp::fold( + ::llvm::ArrayRef<::mlir::Attribute> operands) { + return value(); +} + +LogicalResult ElementwiseAdd::inferReturnTypes( + MLIRContext *context, + Optional location, + ValueRange operands, + DictionaryAttr attributes, + RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + 
inferredReturnTypes.push_back(operands[0].getType()); + return success(); +} +void ElementwiseAdd::getCanonicalizationPatterns( + ::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) { + results.insert(context); +} + +::mlir::OpFoldResult ElementwiseAdd::fold( + llvm::ArrayRef operands) { + if (getElementTypeOrSelf(getType()).isa()) { + if (!operands[0] || !operands[1]) return {}; + DenseElementsAttr lhs = operands[0].dyn_cast(); + DenseElementsAttr rhs = operands[1].dyn_cast(); + if (!lhs || !rhs) return {}; + ShapedType type = getType().template cast(); + if (!type.hasStaticShape()) return {}; + Type etype = type.getElementType(); + if (!etype.isa()) return {}; + SmallVector values; + values.reserve(lhs.getNumElements()); + for (const auto zip : + llvm::zip(lhs.getValues(), rhs.getValues())) { + values.push_back( + std::plus()(std::get<0>(zip), std::get<1>(zip))); + } + return DenseElementsAttr::get(type, values); + } + return {}; +} + +LogicalResult ElementwiseDiv::inferReturnTypes( + MLIRContext *context, + Optional location, + ValueRange operands, + DictionaryAttr attributes, + RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + inferredReturnTypes.push_back(operands[0].getType()); + return success(); +} + +LogicalResult ElementwiseMul::inferReturnTypes( + MLIRContext *context, + Optional location, + ValueRange operands, + DictionaryAttr attributes, + RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + inferredReturnTypes.push_back(operands[0].getType()); + return success(); +} + +LogicalResult ElementwiseSub::inferReturnTypes( + MLIRContext *context, + Optional location, + ValueRange operands, + DictionaryAttr attributes, + RegionRange regions, + SmallVectorImpl &inferredReturnTypes) { + inferredReturnTypes.push_back(operands[0].getType()); + return success(); +} + +LogicalResult MulOp::inferReturnTypes( + MLIRContext *context, + Optional location, + ValueRange operands, + DictionaryAttr attributes, + RegionRange 
regions, + SmallVectorImpl &inferredReturnTypes) { + inferredReturnTypes.push_back(operands[0].getType()); + return success(); +} + +void ReluOp::getCanonicalizationPatterns( + ::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) { + results.insert(context); +} + +void FusedRepeatedFCRelu::getCanonicalizationPatterns( + ::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) { + results.insert(context); +} + +void BatchNormOp::getCanonicalizationPatterns( + ::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) { + results.insert(context); +} + +} // namespace pd +} // namespace mlir diff --git a/paddle/infrt/dialect/pd_ops.h b/paddle/infrt/dialect/pd_ops.h new file mode 100644 index 0000000000..d09b603225 --- /dev/null +++ b/paddle/infrt/dialect/pd_ops.h @@ -0,0 +1,57 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Function.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/DerivedAttributeOpInterface.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/LoopLikeInterface.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +namespace mlir { +namespace pd { + +class PaddleDialect : public Dialect { + public: + explicit PaddleDialect(MLIRContext* context); + + static StringRef getDialectNamespace() { return "pd"; } + + /// A hook used to materialize constant values with the given type. + Operation* materializeConstant(OpBuilder& builder, + Attribute value, + Type type, + Location loc) override; + + Type parseType(DialectAsmParser& parser) const override { + return Dialect::parseType(parser); + } + void printType(Type type, DialectAsmPrinter& printer) const override { + Dialect::printType(type, printer); + } +}; + +} // namespace pd +} // namespace mlir diff --git a/paddle/infrt/dialect/pd_ops.td b/paddle/infrt/dialect/pd_ops.td new file mode 100644 index 0000000000..9e906ad0c0 --- /dev/null +++ b/paddle/infrt/dialect/pd_ops.td @@ -0,0 +1,182 @@ +#ifndef PD_OPS +#define PD_OPS + +include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/LoopLikeInterface.td" +include "mlir/IR/OpBase.td" +include "paddle/infrt/dialect/pd_op_base.td" + +def PD_FeedOp : PD_Op<"Feed", [NoSideEffect]> { + let summary = "Feed Op"; + + let description = [{ + Feed a tensor into the model. 
+ }]; + + let arguments = (ins); + let results = (outs PD_Tensor:$out); + + let assemblyFormat = [{ + `(` `)` attr-dict `:` type($out) + }]; +} + +def PD_ConstantOp : PD_Op<"Constant", [NoSideEffect, ConstantLike, DeclareOpInterfaceMethods, AllTypesMatch<["value", "output"]>]> { + let summary = "constant Op"; + let description = [{}]; + + let arguments = (ins ElementsAttr:$value); + let results = (outs PD_Tensor:$output); + let hasFolder = 1; + + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &state, Attribute value">, + ]; +} + +def PD_AbsOp : PD_Op<"Abs", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes the absolute value of a tensor"; + + let description = [{ + }]; + + let arguments = (ins PD_Tensor:$x); + let results = (outs PD_Tensor:$y); +} + +def PD_SqrtOp : PD_Op<"sqrt", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes the sqrt value of a tensor"; + + let description = [{ + }]; + + let arguments = (ins PD_Tensor:$x); + let results = (outs PD_Tensor:$y); +} + +def PD_ReluOp : PD_Op<"Relu", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes the Relu of a tensor"; + + let description = [{ + }]; + + let arguments = (ins PD_Tensor:$x); + let results = (outs PD_Tensor:$y); + let hasCanonicalizer = 1; +} + +def PD_Relu6Op : PD_Op<"Relu6", [NoSideEffect, SameOperandsAndResultType]> { + let summary = "Computes the Relu6 of a tensor"; + + let description = [{ + }]; + + let arguments = (ins PD_Tensor:$x); + let results = (outs PD_Tensor:$y); +} + +def PD_ElementwiseAdd : PD_Op<"ElementwiseAdd", [NoSideEffect, Commutative, DeclareOpInterfaceMethods]> { + let summary = "ElementwiseAdd Op"; + let description = [{ + }]; + + let arguments = (ins PD_Tensor:$x, PD_Tensor:$y, DefaultValuedAttr:$axis); + let results = (outs PD_Tensor:$out); + let hasCanonicalizer = 1; + let hasFolder = 1; +} + +def PD_ElementwiseSub : PD_Op<"ElementwiseSub", [NoSideEffect, DeclareOpInterfaceMethods]> { + let 
summary = "ElementwiseSub Op"; + let description = [{ + }]; + + let arguments = (ins PD_Tensor:$x, PD_Tensor:$y, DefaultValuedAttr:$axis); + let results = (outs PD_Tensor:$out); +} + +def PD_ElementwiseMul : PD_Op<"ElementwiseMul", [NoSideEffect, Commutative, DeclareOpInterfaceMethods]> { + let summary = "ElementwiseMul Op"; + let description = [{ + }]; + + let arguments = (ins PD_Tensor:$x, PD_Tensor:$y, DefaultValuedAttr:$axis); + let results = (outs PD_Tensor:$out); +} + +def PD_ElementwiseDiv : PD_Op<"ElementwiseDiv", [NoSideEffect, DeclareOpInterfaceMethods]> { + let summary = "ElementwiseDiv Op"; + let description = [{ + }]; + + let arguments = (ins PD_Tensor:$x, PD_Tensor:$y, DefaultValuedAttr:$axis); + let results = (outs PD_Tensor:$out); +} + +def PD_MatmulOp : PD_Op<"Matmul", [NoSideEffect]> { + let summary = "Computes the matrix mulplication result of two tensors"; + let description = [{ + }]; + + let arguments = (ins PD_Tensor:$x, PD_Tensor:$y, + DefaultValuedAttr:$transpose_x, + DefaultValuedAttr:$transpose_y, + DefaultValuedAttr:$alpha); + let results = (outs PD_Tensor:$out); + + //let hasCanonicalizer = 1; +} + +def PD_MulOp : PD_Op<"mul", [NoSideEffect, DeclareOpInterfaceMethods]> { + let summary = "paddle mul op"; + let description = [{}]; + + let arguments = (ins PD_Tensor:$x, PD_Tensor:$y); + let results = (outs PD_Tensor:$out); + + //let hasCanonicalizer = 1; +} + +def PD_Conv2dOp : PD_Op<"conv2d", [NoSideEffect]> { + let summary = "paddle conv2d operation"; + let description = [{ + }]; + + let arguments = (ins PD_Tensor:$Input, PD_Tensor:$Filter, PD_Tensor:$Bias); + let results = (outs PD_Tensor:$Output); + + //let hasCanonicalizer = 1; +} + +def PD_BatchNormOp : PD_Op<"batch_norm", [NoSideEffect]> { + let summary = "paddle batch_norm operation"; + let description = [{ + }]; + + let arguments = (ins PD_Tensor:$X, PD_Tensor:$Scale, PD_Tensor:$Bias, + PD_Tensor:$Mean, PD_Tensor:$Variance, + DefaultValuedAttr:$epsilon); + let results = (outs 
PD_Tensor:$Y); + + let hasCanonicalizer = 1; +} + +def PD_FusedFC : PD_Op<"FC", [NoSideEffect]> { + let summary = "Computes the Fully Connected result of two tensors"; + let description = [{ + }]; + + let arguments = (ins PD_Tensor:$input, PD_Tensor:$w, PD_Tensor:$bias, DefaultValuedAttr:$in_num_col_dims); + let results = (outs PD_Tensor:$out); +} + +def PD_FusedRepeatedFCRelu : PD_Op<"RepeatedFCRelu", [SameVariadicOperandSize, NoSideEffect]> { + let summary = ""; + let description = [{ }]; + + let arguments = (ins PD_Tensor:$input, Variadic:$w, Variadic:$bias); + let results = (outs PD_Tensor:$out); + let hasCanonicalizer = 1; +} + +#endif // PD_OPS diff --git a/paddle/infrt/dialect/pd_types.cc b/paddle/infrt/dialect/pd_types.cc new file mode 100644 index 0000000000..94856e362d --- /dev/null +++ b/paddle/infrt/dialect/pd_types.cc @@ -0,0 +1,15 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/dialect/pd_types.h" diff --git a/paddle/infrt/dialect/pd_types.h b/paddle/infrt/dialect/pd_types.h new file mode 100644 index 0000000000..6f9fe56338 --- /dev/null +++ b/paddle/infrt/dialect/pd_types.h @@ -0,0 +1,57 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file defines the types used in PaddlePaddle MLIR dialect. +// We borrowed much ideas from tensorflow mlir dialect (tf_types.h in +// tensorflow). + +#pragma once + +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/Location.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" + +namespace mlir { +namespace PD { + +class PaddleType : public Type { + public: + using Type::Type; + + static bool classof(Type type); +}; + +namespace detail { + +template +class PaddleTypeImpl : public Type::TypeBase { + public: + using Base = typename Type::TypeBase; + using PDBase = PaddleTypeImpl; + using Base::Base; +}; + +} // namespace detail + +#define HANDLE_PD_TYPE(pdtype, enumerant, name) \ + class pdtype##Type : public detail::PaddleTypeImpl { \ + public: \ + using PDBase::PDBase; \ + }; + +} // namespace PD +} // namespace mlir diff --git a/paddle/infrt/dialect/print_ir.cc b/paddle/infrt/dialect/print_ir.cc new file mode 100644 index 0000000000..3c5a2b6a7b --- /dev/null +++ b/paddle/infrt/dialect/print_ir.cc @@ -0,0 +1,134 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "llvm/ADT/Optional.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ScopedPrinter.h" +#include "llvm/Support/raw_os_ostream.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/IR/AsmState.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Parser.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/Passes.h" +#include "paddle/infrt/common/global.h" +#include "paddle/infrt/dialect/init_infrt_dialects.h" + +namespace cl = llvm::cl; + +static cl::opt inputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +llvm::raw_ostream &printIndent(int indent = 0) { + for (int i = 0; i < indent; ++i) llvm::outs() << " "; + return llvm::outs(); +} + +void printOperation(mlir::Operation *op, int indent); +void printRegion(mlir::Region ®ion, int indent); // NOLINT +void printBlock(mlir::Block &block, int indent); // NOLINT + +void printOperation(mlir::Operation *op, int indent) { + llvm::Optional module_op = llvm::None; + if (llvm::isa(op)) + module_op = llvm::dyn_cast(op); + llvm::Optional func_op = llvm::None; + if (llvm::isa(op)) func_op = llvm::dyn_cast(op); + + printIndent(indent) << "op: '" << op->getName(); + // This getName is inherited from Operation::getName + if (module_op) { + printIndent() << "@" << module_op->getName(); + 
} + // This getName is inherited from SymbolOpInterfaceTrait::getName, + // which return value of "sym_name" in ModuleOp or FuncOp attributes. + if (func_op) { + printIndent() << "@" << func_op->getName(); + } + printIndent() << "' with " << op->getNumOperands() << " operands" + << ", " << op->getNumResults() << " results" + << ", " << op->getAttrs().size() << " attributes" + << ", " << op->getNumRegions() << " regions" + << ", " << op->getNumSuccessors() << " successors\n"; + if (!op->getAttrs().empty()) { + printIndent(indent) << op->getAttrs().size() << " attributes:\n"; + for (mlir::NamedAttribute attr : op->getAttrs()) { + printIndent(indent + 1) << "- {" << attr.first << " : " << attr.second + << "}\n"; + } + } + + if (op->getNumRegions() > 0) { + printIndent(indent) << op->getNumRegions() << " nested regions:\n"; + for (mlir::Region ®ion : op->getRegions()) { + printRegion(region, indent + 1); + } + } +} + +void printRegion(mlir::Region ®ion, int indent) { // NOLINT + printIndent(indent) << "Region with " << region.getBlocks().size() + << " blocks:\n"; + for (mlir::Block &block : region.getBlocks()) { + printBlock(block, indent + 1); + } +} + +void printBlock(mlir::Block &block, int indent) { // NOLINT + printIndent(indent) << "Block with " << block.getNumArguments() + << " arguments" + << ", " << block.getNumSuccessors() << " successors" + << ", " << block.getOperations().size() + << " operations\n"; + + for (mlir::Operation &operation : block.getOperations()) { + printOperation(&operation, indent + 1); + } +} + +int main(int argc, char **argv) { + mlir::registerAsmPrinterCLOptions(); + mlir::registerMLIRContextCLOptions(); + mlir::registerPassManagerCLOptions(); + cl::ParseCommandLineOptions(argc, argv, "mlir demo"); + + mlir::MLIRContext *context = infrt::Global::getMLIRContext(); + context->allowUnregisteredDialects(); + auto ®istry = context->getDialectRegistry(); + infrt::RegisterCinnDialects(registry); + + // mlir will verify module automatically 
after parsing. + // https://github.com/llvm/llvm-project/blob/38d18d93534d290d045bbbfa86337e70f1139dc2/mlir/lib/Parser/Parser.cpp#L2051 + // mlir::OwningModuleRef module_ref = mlir::parseSourceString(mlir_source, + // context); + mlir::OwningModuleRef module_ref = + mlir::parseSourceFile(inputFilename, context); + std::cout << "----------print IR Structure begin----------" << std::endl; + printOperation(module_ref->getOperation(), 0); + std::cout << "----------print IR Structure end----------" << std::endl; + + module_ref->dump(); + return 0; +} diff --git a/paddle/infrt/dialect/rewrite.td b/paddle/infrt/dialect/rewrite.td new file mode 100644 index 0000000000..aa81dd72d0 --- /dev/null +++ b/paddle/infrt/dialect/rewrite.td @@ -0,0 +1,90 @@ +#ifndef INFRT_REWRITE +#define INFRT_REWRITE + +include "paddle/infrt/dialect/infrt_base.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "paddle/infrt/dialect/pd_ops.td" + +//===----------------------------------------------------------------------===// +// This is to fuse the composition: 'Matmul o ElementwiseAdd' into 'PD_FusedFC'. +// +// We have: +// (Matmul) z = x * y +// (Add) out = z + bias +// +// which corresponds to the following computation: +// (FusedFC) out = x * y + bias +// +// Todo: +// 1. Make the constraint more complete. +// 2. Consider the case of : out = bias + z +//===----------------------------------------------------------------------===// +def FuseMulAdd : Pat<(PD_ElementwiseAdd (PD_MatmulOp $x, $y, $transpose_x, $transpose_y, $alpha), $bias, $axis), + (PD_FusedFC $x, $y, $bias, (INFRT_createI32Attr<"1">)), + [(IsBoolAttrEq<"false"> $transpose_x),(IsBoolAttrEq<"false"> $transpose_y)]>; + + +//===----------------------------------------------------------------------===// +// This is to fuse the composition: 'FusedFC o Relu' into 'FusedRepeatedFCRelu'.
+// +// We have: +// (FusedFC) z = fc(x, y, bias) +// (Relu) out = relu(z) +// +// which corresponds to the following computation: +// (FusedRepeatedFCRelu) out = RepeatedFCRelu(x, [y], [bias]) +// +//===----------------------------------------------------------------------===// +def FuseFCRelu : Pat<(PD_ReluOp (PD_FusedFC $x, $y, $bias, $_)), + (PD_FusedRepeatedFCRelu $x, (INFRT_cvtValueToValueRange $y), (INFRT_cvtValueToValueRange $bias))>; + +//===----------------------------------------------------------------------===// +// This is to fold 'FusedRepeatedFCRelu' op. +// +// We have: +// (FusedRepeatedFCRelu) z = RepeatedFCRelu(x, [y, ...], [bias, ...]) +// (FusedRepeatedFCRelu) out = RepeatedFCRelu(z, [y1, ...], [bias1, ...]) +// +// which corresponds to the following computation: +// (FusedRepeatedFCRelu) out = RepeatedFCRelu(x, [y, ..., y1, ...], [bias, ..., bias1, ....]) +// +//===----------------------------------------------------------------------===// +def FuseRepeatedFCRelu2 : Pat<(PD_FusedRepeatedFCRelu (PD_FusedRepeatedFCRelu $x, $y, $bias), $y_2, $bias_2), + (PD_FusedRepeatedFCRelu $x, (INFRT_concatTwoValueRange $y, $y_2), (INFRT_concatTwoValueRange $bias, $bias_2))>; + + +//===----------------------------------------------------------------------===// +// This is to fuse the composition: 'BatchNorm o Conv' into 'Conv' +// by deriving new 'w' and 'b' for 'Conv': +// +// We have: +// (Conv) z = w * x + b +// (BatchNorm) y = scale * (z - mean) / sqrt(var + eps) + bias +// +// which corresponds to the following computation: +// y = w_ * x + b_ +// where +// w_ = scale * w / sqrt(var + eps) +// b_ = bias + scale * (b - mean) / sqrt(var + eps) +// +//===----------------------------------------------------------------------===// +def FuseBatchNormWithConvPattern: Pat< + (PD_BatchNormOp + (PD_Conv2dOp $input, $filter, $bias), + $scale, $bias_2, $mean, $var, $epsilon), + (PD_Conv2dOp + $input, + (PD_MulOp $filter, + (PD_ElementwiseDiv:$coefficientW + $scale, +
(PD_SqrtOp (PD_ElementwiseAdd $var, (PD_ConstantOp $epsilon), (INFRT_createI32Attr<"1">))), + (INFRT_createI32Attr<"1">))), + (PD_ElementwiseAdd + $bias, + (PD_MulOp + (PD_ElementwiseSub $bias, $mean, (INFRT_createI32Attr<"1">)), + $coefficientW), + (INFRT_createI32Attr<"1">))) +>; + +#endif // INFRT_REWRITE diff --git a/paddle/infrt/dialect/tensor_shape.cc b/paddle/infrt/dialect/tensor_shape.cc new file mode 100644 index 0000000000..ef5a5525cb --- /dev/null +++ b/paddle/infrt/dialect/tensor_shape.cc @@ -0,0 +1,68 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/infrt/dialect/tensor_shape.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace infrt::ts { +using namespace mlir; // NOLINT + +void TensorShapeDialect::initialize() { + allowUnknownTypes(); + addTypes(); + addOperations< +#define GET_OP_LIST +#include "paddle/infrt/dialect/tensor_shape.cpp.inc" + >(); +} + +Type TensorShapeDialect::parseType(DialectAsmParser &parser) const { + StringRef keyword; + if (parser.parseKeyword(&keyword)) return Type(); + if (keyword == "shape") return ShapeType::get(getContext()); + if (keyword == "partial_shape") return PartialShapeType::get(getContext()); + + parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; + return Type(); +} + +void TensorShapeDialect::printType(::mlir::Type type, + ::mlir::DialectAsmPrinter &os) const { + if (type.isa()) { + os << "shape"; + return; + } + + if (type.isa()) { + os << "partial_shape"; + return; + } + llvm_unreachable("unexpected 'shape' type kind"); +} + +#define GET_OP_CLASSES +#include "paddle/infrt/dialect/tensor_shape.cpp.inc" // NOLINT + +} // namespace infrt::ts diff --git a/paddle/infrt/dialect/tensor_shape.h b/paddle/infrt/dialect/tensor_shape.h new file mode 100644 index 0000000000..bd3fa88536 --- /dev/null +++ b/paddle/infrt/dialect/tensor_shape.h @@ -0,0 +1,40 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +namespace infrt::ts { + +class ShapeType + : public mlir::Type::TypeBase { + public: + using Base::Base; +}; + +class PartialShapeType : public mlir::Type::TypeBase { + public: + using Base::Base; +}; + +using namespace mlir; // NOLINT +#define GET_OP_CLASSES +#include "paddle/infrt/dialect/tensor_shape.hpp.inc" +#include "paddle/infrt/dialect/tensor_shape_dialect.hpp.inc" + +} // namespace infrt::ts diff --git a/paddle/infrt/dialect/tensor_shape.td b/paddle/infrt/dialect/tensor_shape.td new file mode 100644 index 0000000000..d3714c8ed1 --- /dev/null +++ b/paddle/infrt/dialect/tensor_shape.td @@ -0,0 +1,49 @@ +#ifdef INFRT_OPS +#else +#define INFRT_OPS + +include "paddle/infrt/dialect/infrt_base.td" +include "paddle/infrt/dialect/tensor_shape_base.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +// Base class for the operation in the TensorShape dialect +class TS_Op traits = []> : + Op { + let parser = [{ return infrt::dialect::parse$cppClass(parser, result); }]; + let printer = " return infrt::dialect::printOpWithOperands(p, *this)" ";"; +} + +def TS_BuildShapeOp : TS_Op<"build_shape", [NoSideEffect]> { + let summary = "Build tensor shape operation"; + let description = [{ + An operation that builds a tensor shape of given ranks and extents. + }]; + + let arguments = (ins I64ArrayAttr:$value); + let results = (outs TS_Shape:$output); + let assemblyFormat = "$value attr-dict"; +} + +def TS_GetNumElementsOp : TS_Op<"get_num_elements"> { + let summary = "Returns the number of elements in the shape"; + + let description = [{ + An operation that returns the number of elements in the given shape. 
+ }]; + + let arguments = (ins TS_Shape); + let results = (outs I64); + let assemblyFormat = "operands attr-dict"; +} + +def TS_PrintShapeOp : TS_Op<"print_shape"> { + let summary = "Print tensor shape operation"; + let description = [{ + An operation that prints a tensor shape. + }]; + + let arguments = (ins TS_Shape:$shape); + let assemblyFormat = "operands attr-dict"; +} + +#endif diff --git a/paddle/infrt/dialect/tensor_shape_base.td b/paddle/infrt/dialect/tensor_shape_base.td new file mode 100644 index 0000000000..ea1c1854d7 --- /dev/null +++ b/paddle/infrt/dialect/tensor_shape_base.td @@ -0,0 +1,36 @@ +#ifdef TS_OPS_BASE +#else +#define TS_OPS_BASE + +// Tensor shape dialect. +def TensorShapeDialect : Dialect { + let name = "ts"; + + let description = [{ + The Tensor Shape dialect. + + This dialect contains operations for working with tensor shapes. + }]; + + let cppNamespace = "::infrt::ts"; +} + +// Type definition. +def TS_Shape : DialectType()">, "!ts.shape type">, +BuildableType<"$_builder.getType<::infrt::ts::ShapeType>()"> { + let typeDescription = [{ + `!ts.shape type` represents a static tensor shape. +}]; +} + +def TS_PartialShape : DialectType()">, "!ts.partial_shape type">, +BuildableType<"$_builder.getType<::infrt::ts::PartialShapeType>()"> { + let typeDescription = [{ + `!ts.partial_shape type` represents either a static tensor shape, unranked + tensor shape or a ranked tensor shape with unknown dimension sizes. +}]; +} + +#endif // TS_OPS_BASE diff --git a/paddle/infrt/dialect/test_kernels.cc b/paddle/infrt/dialect/test_kernels.cc new file mode 100644 index 0000000000..894d96f95a --- /dev/null +++ b/paddle/infrt/dialect/test_kernels.cc @@ -0,0 +1,163 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/dialect/test_kernels.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/TypeUtilities.h" + +namespace infrt::dialect { + +//===----------------------------------------------------------------------===// +// BenchmarkOp +//===----------------------------------------------------------------------===// + +// Parse the BenchmarkOp in the following format +// infrt.benchmark "add.i32"(%c : i32, %d : f32) +// max_count = 100, duration_secs = 1 { +// ... +// } + +static ParseResult parseBenchmarkOp(OpAsmParser &parser, // NOLINT + OperationState &result) { // NOLINT + StringAttr nameAttr; + if (parser.parseAttribute(nameAttr, "name", result.attributes)) + return failure(); + + // Parse the operands, e.g. (%c : i32, %d : f32) + if (parser.parseLParen()) return failure(); + + SmallVector operands; + SmallVector types; + llvm::SMLoc type_loc = parser.getCurrentLocation(); + + if (parser.parseOptionalRParen()) { + // Parse non-empty operands + do { + // Parse %c : i32, + OpAsmParser::OperandType operand; + Type type; + + if (parser.parseOperand(operand) || parser.parseColonType(type)) + return failure(); + + operands.push_back(operand); + types.push_back(type); + } while (succeeded(parser.parseOptionalComma())); + + if (parser.parseRParen()) return failure(); + } + + if (parser.resolveOperands(operands, types, type_loc, result.operands)) + return failure(); + + // Parse the keyword attribute, e.g. 
max_count = 100, duration_secs = 1 + do { + StringRef attr; + Attribute resultAttr; + if (parser.parseKeyword(&attr) || parser.parseEqual() || + parser.parseAttribute(resultAttr, + parser.getBuilder().getIntegerType(32), + attr, + result.attributes)) + return failure(); + } while (succeeded(parser.parseOptionalComma())); + + // Set the default attribute num_warmup_runs to 1 if unset + auto setDefaultAttrIfUnset = [&](const char *attr_name, int value) { + bool found = llvm::any_of(result.attributes, + [attr_name](const NamedAttribute &attr) { + return attr.first == attr_name; + }); + if (!found) { + IntegerAttr default_val = parser.getBuilder().getI32IntegerAttr(value); + result.addAttribute(attr_name, default_val); + } + }; + setDefaultAttrIfUnset("num_warmup_runs", 1); + + Region *target = result.addRegion(); + return parser.parseRegion(*target, + operands, + types, + /*enableNameShadowing=*/true); +} + +// Print the BenchmarkOp in the following format +// infrt.benchmark "add.i32"(%c : i32, %d : f32) +// max_count = 100, duration_secs = 1 { +// ... +// } +static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT + p << "infrt.benchmark "; + + // Print the name attribute, e.g "add.i32" + auto name_attr = op.getAttr("name"); + p << name_attr; + + // Print the operands and types, e.g. (%c : i32, %d : f32) + p << '('; + llvm::interleaveComma(llvm::zip(op.getOperands(), op.getOperandTypes()), + p, + [&](const auto &it) { + p << std::get<0>(it) << " : " << std::get<1>(it); + }); + p << ") "; + + bool need_comma = false; + // Print the attributes, e.g. 
max_count = 100, duration_secs = 1 + for (auto &name_attr : op.getAttrs()) { + auto id = name_attr.first; + if (id == "name") continue; + if (need_comma) p << ", "; + auto attr = name_attr.second; + p << id << " = "; + if (auto int_attr = attr.dyn_cast()) { + int_attr.getValue().print(p.getStream(), /*isSigned=*/false); + } else { + op.emitOpError("Unexpected attribute"); + } + need_comma = true; + } + p << ' '; + + // Print the region + // Reuse the argument names provided to the op for the bbarg names within + // the region. + p.shadowRegionArgs(op.region(), op.getOperands()); + p.printRegion(op.region(), /*printEntryBlockArgs=*/false); +} + +static LogicalResult verify(BenchmarkOp op) { + // Verify that the target benchmark region has exactly one return value. + auto ®ion = op.region(); + auto &last_op = region.front().back(); + if (last_op.getName().getStringRef() != "infrt.return") { + return op.emitOpError("missing return statement"); + } + if (last_op.getNumOperands() != 1) { + return op.emitOpError( + "incorrect number of return values. One return value is expected"); + } + + return success(); +} + +#define GET_OP_CLASSES +#include "paddle/infrt/dialect/test_kernels.cpp.inc" + +} // namespace infrt::dialect diff --git a/paddle/infrt/dialect/test_kernels.h b/paddle/infrt/dialect/test_kernels.h new file mode 100644 index 0000000000..29d4209cb7 --- /dev/null +++ b/paddle/infrt/dialect/test_kernels.h @@ -0,0 +1,23 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +namespace infrt::dialect { +using namespace mlir; // NOLINT +#define GET_OP_CLASSES +#include "paddle/infrt/dialect/test_kernels.hpp.inc" +} // namespace infrt::dialect diff --git a/paddle/infrt/dialect/test_kernels.td b/paddle/infrt/dialect/test_kernels.td new file mode 100644 index 0000000000..6aa12f252d --- /dev/null +++ b/paddle/infrt/dialect/test_kernels.td @@ -0,0 +1,65 @@ +// Operation definitions for testing. + +#ifdef TEST_OPS +#else +#define TEST_OPS + +include "paddle/infrt/dialect/infrt_base.td" +include "mlir/Interfaces/SideEffectInterfaces.td" + +// Base class for Test dialect ops. +class Test_Op traits = []> : + Op { + + // Each registered op in the Test namespace needs to provide all of a printer, + // parser and verifier. + let printer = [{ return infrt::dialect::print(p, *this); }]; + let verifier = [{ return infrt::dialect::verify(*this); }]; + let parser = [{ return infrt::dialect::parse$cppClass(parser, result); }]; +} + +def BenchmarkOp : Test_Op<"benchmark"> { + let summary = "benchmark operation"; + let description = [{ + The "infrt.benchmark" operation benchmarks the performance of an MLIR + region by executing the given MLIR region repeatedly up to the + `duration_secs` seconds or `max_count` times. `num_warmup_runs` specifies + the number of warm up runs to run the given MLIR region before the + benchmark starts. + + The target MLIR region can take an arbitrary number of arguments and + should return exactly one value. The arguments for the MLIR region are + provided as the operands of the infrt.benchmark op. + + Example: + infrt.benchmark "add.i32"(%c : i32, %d : f32) max_count = 100, duration_secs = 1 { + // code for benchmarking + ...
+ } + + infrt.benchmark "add.i32"(%c : i32) + duration_secs = 1, + max_count = 100, + num_warmup_runs = 10 { + // The MLIR code to be benchmarked goes here. + // The following code benchmarks the infrt.add.i32 kernel. + %x = infrt.add.i32 %c, %c + // The benchmarked function needs to return exactly one value. + infrt.return %x : i32 + } + }]; + + let regions = (region SizedRegion<1>:$region); + + let arguments = (ins + Variadic, + I32Attr:$duration_secs, + I32Attr:$max_count, + StrAttr:$name, + DefaultValuedAttr:$num_warmup_runs + ); + + let results = (outs); +} + +#endif // TEST_OPS diff --git a/paddle/infrt/dialect/types.cc b/paddle/infrt/dialect/types.cc new file mode 100644 index 0000000000..6d6f6a20b4 --- /dev/null +++ b/paddle/infrt/dialect/types.cc @@ -0,0 +1,17 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/dialect/types.h" + +namespace infrt::hlir::mlir {} // namespace infrt::hlir::mlir diff --git a/paddle/infrt/dialect/types.h b/paddle/infrt/dialect/types.h new file mode 100644 index 0000000000..a9a2b61871 --- /dev/null +++ b/paddle/infrt/dialect/types.h @@ -0,0 +1,16 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include diff --git a/paddle/infrt/external_kernels/CMakeLists.txt b/paddle/infrt/external_kernels/CMakeLists.txt new file mode 100644 index 0000000000..faffc3909b --- /dev/null +++ b/paddle/infrt/external_kernels/CMakeLists.txt @@ -0,0 +1,13 @@ +set(external_kernels_src "basic_kernels.cc") + +cc_library(external_kernels SHARED SRCS ${external_kernels_src}) +set_target_properties(external_kernels PROPERTIES LINK_FLAGS "${LINK_FLAGS}") + +set(basic_mlir "${CMAKE_CURRENT_SOURCE_DIR}/basic.mlir") +set(external_kernels_lib "${CMAKE_CURRENT_BINARY_DIR}/libexternal_kernels.so") +message(STATUS "basic_mlir: ${basic_mlir}") +message(STATUS "external_kernels_lib: ${external_kernels_lib}") +add_test( + NAME run_and_check_external_kernels + COMMAND sh -c "${CMAKE_BINARY_DIR}/infrt/host_context/infrt-exec -i ${basic_mlir} --shared_libs=${external_kernels_lib} | ${LLVM_PATH}/bin/FileCheck ${basic_mlir}" +) diff --git a/paddle/infrt/external_kernels/basic.mlir b/paddle/infrt/external_kernels/basic.mlir new file mode 100644 index 0000000000..843b12ced2 --- /dev/null +++ b/paddle/infrt/external_kernels/basic.mlir @@ -0,0 +1,21 @@ +// CHECK: basic +func @basic() -> f32 { + %v0 = infrt.constant.f32 1.0 + %v1 = infrt.constant.f32 2.0 + %v2 = "external.add.f32"(%v0, %v1) : (f32, f32) -> f32 + + // CHECK: 1 + "external.print.f32"(%v0) : (f32) -> () + // CHECK: 2 + "external.print.f32"(%v1) : (f32) -> () + + // CHECK: 3 + "external.print.f32"(%v2) : (f32) -> () + + %v3 = "external.mul.f32"(%v2, %v1) : (f32, f32) -> f32 + + // CHECK: 6 + 
"external.print.f32"(%v3) : (f32) -> () + + infrt.return %v3 : f32 +} diff --git a/paddle/infrt/external_kernels/basic_kernels.cc b/paddle/infrt/external_kernels/basic_kernels.cc new file mode 100644 index 0000000000..b59a8881fb --- /dev/null +++ b/paddle/infrt/external_kernels/basic_kernels.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/infrt/host_context/kernel_registry.h" +#include "paddle/infrt/host_context/kernel_utils.h" + +template +T add(T a, T b) { + return a + b; +} + +template +T sub(T a, T b) { + return a - b; +} + +template +T mul(T a, T b) { + return a * b; +} + +template +T div(T a, T b) { + return a / b; +} + +template +void print(T a) { + std::cout << a << std::endl; +} + +void RegisterKernels(infrt::host_context::KernelRegistry *registry) { + // int32 + registry->AddKernel("external.add.i32", INFRT_KERNEL(add)); + registry->AddKernel("external.sub.i32", INFRT_KERNEL(sub)); + registry->AddKernel("external.mul.i32", INFRT_KERNEL(mul)); + registry->AddKernel("external.div.i32", INFRT_KERNEL(div)); + registry->AddKernel("external.print.i32", INFRT_KERNEL(print)); + + // float + registry->AddKernel("external.add.f32", INFRT_KERNEL(add)); + registry->AddKernel("external.sub.f32", INFRT_KERNEL(sub)); + registry->AddKernel("external.mul.f32", INFRT_KERNEL(mul)); + registry->AddKernel("external.div.f32", INFRT_KERNEL(div)); + 
registry->AddKernel("external.print.f32", INFRT_KERNEL(print)); +} diff --git a/paddle/infrt/external_kernels/fc.mlir b/paddle/infrt/external_kernels/fc.mlir new file mode 100644 index 0000000000..bdac9ded2e --- /dev/null +++ b/paddle/infrt/external_kernels/fc.mlir @@ -0,0 +1,43 @@ +// CHECK-LABEL: @fc +func @fc(%input : !infrt.tensor, + %w : !infrt.tensor, + %bias : !infrt.tensor) -> !infrt.tensor +{ + %out = dt.create_uninit_tensor.f32 [30, 50] -> !infrt.tensor + // dt.fill_tensor_with_constant.f32 (%out : !infrt.tensor) {value=0.0:f32} + + // fc1 + "external.matmul"(%input, %w, %out) {}: (!infrt.tensor, !infrt.tensor, !infrt.tensor) -> () + "external.elementwise_add"(%out, %bias, %out) {axis = -1}: (!infrt.tensor, !infrt.tensor, !infrt.tensor) -> () + "external.sigmoid"(%out, %out) {}: (!infrt.tensor, !infrt.tensor) -> () + + // fc2 + "external.matmul"(%out, %w, %out) {}: (!infrt.tensor, !infrt.tensor, !infrt.tensor) -> () + "external.elementwise_add"(%out, %bias, %out) {axis = -1}: (!infrt.tensor, !infrt.tensor, !infrt.tensor) -> () + "external.sigmoid"(%out, %out) {}: (!infrt.tensor, !infrt.tensor) -> () + + infrt.return %out : !infrt.tensor +} + +// CHECK-LABEL: @benchmark +func @benchmark() { + %input = dt.create_uninit_tensor.f32 [30, 50] -> !infrt.tensor + dt.fill_tensor_with_constant.f32 (%input : !infrt.tensor) {value=1.0:f32} + + %w = dt.create_uninit_tensor.f32 [50, 50] -> !infrt.tensor + dt.fill_tensor_with_constant.f32 (%w : !infrt.tensor) {value=2.0:f32} + + %bias = dt.create_uninit_tensor.f32 [30, 50] -> !infrt.tensor + dt.fill_tensor_with_constant.f32 (%bias : !infrt.tensor) {value=3.0:f32} + + infrt.benchmark "add.f32"( + %input:!infrt.tensor, + %w:!infrt.tensor, + %bias:!infrt.tensor) + duration_secs = 100, max_count = 300000, num_warmup_runs = 3 + { + %res = infrt.call @fc(%input, %w, %bias) : (!infrt.tensor, !infrt.tensor, !infrt.tensor) -> (!infrt.tensor) + infrt.return %res : !infrt.tensor + } + infrt.return +} diff --git 
a/paddle/infrt/external_kernels/paddle.mlir b/paddle/infrt/external_kernels/paddle.mlir new file mode 100644 index 0000000000..e7b8e9efba --- /dev/null +++ b/paddle/infrt/external_kernels/paddle.mlir @@ -0,0 +1,50 @@ +// CHECK: paddle_func +func @paddle_func() -> () { + %input = dt.create_uninit_tensor.f32 [3, 5] -> !infrt.tensor + dt.fill_tensor_with_constant.f32 (%input : !infrt.tensor) {value=1.0:f32} + + %w = dt.create_uninit_tensor.f32 [5, 4] -> !infrt.tensor + dt.fill_tensor_with_constant.f32 (%w : !infrt.tensor) {value=2.0:f32} + + %bias = dt.create_uninit_tensor.f32 [4] -> !infrt.tensor + dt.fill_tensor_with_constant.f32 (%bias : !infrt.tensor) {value=3.0:f32} + + %out = dt.create_uninit_tensor.f32 [3, 4] -> !infrt.tensor + dt.fill_tensor_with_constant.f32 (%out : !infrt.tensor) {value=0.0:f32} + + "external.fc2"(%input, %w, %bias, %out) {in_num_col_dims=3:i32, test_attr=5:i32}: (!infrt.tensor, !infrt.tensor, !infrt.tensor, !infrt.tensor) -> () + // CHECK-LABEL: tensor: shape=shape[3,5], values=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + dt.print_tensor (%input : !infrt.tensor) + // CHECK-LABEL: tensor: shape=shape[5,4], values=[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2] + dt.print_tensor (%w : !infrt.tensor) + dt.print_tensor (%bias : !infrt.tensor) + dt.print_tensor (%out : !infrt.tensor) + + // test external.matmul + %out1 = dt.create_uninit_tensor.f32 [3, 4] -> !infrt.tensor + dt.fill_tensor_with_constant.f32 (%out1 : !infrt.tensor) {value=0.0:f32} + "external.matmul"(%input, %w, %out1) {}: (!infrt.tensor, !infrt.tensor, !infrt.tensor) -> () + dt.print_tensor (%out1 : !infrt.tensor) + + // test external.elementwise_add + %out2 = dt.create_uninit_tensor.f32 [3, 4] -> !infrt.tensor + dt.fill_tensor_with_constant.f32 (%out2 : !infrt.tensor) {value=0.0:f32} + %bias1 = dt.create_uninit_tensor.f32 [3, 4] -> !infrt.tensor + dt.fill_tensor_with_constant.f32 (%bias1 : !infrt.tensor) {value=3.0:f32} + "external.elementwise_add"(%out1, %bias1, 
%out2) {axis=-1}: (!infrt.tensor, !infrt.tensor, !infrt.tensor) -> () + dt.print_tensor (%out2 : !infrt.tensor) + + // test external.relu + %out3 = dt.create_uninit_tensor.f32 [3, 4] -> !infrt.tensor + dt.fill_tensor_with_constant.f32 (%out3 : !infrt.tensor) {value=0.0:f32} + "external.relu"(%out1, %out3) {}: (!infrt.tensor, !infrt.tensor) -> () + dt.print_tensor (%out3 : !infrt.tensor) + + // test external.sigmoid + %out4 = dt.create_uninit_tensor.f32 [3, 4] -> !infrt.tensor + dt.fill_tensor_with_constant.f32 (%out4 : !infrt.tensor) {value=0.0:f32} + "external.sigmoid"(%out1, %out4) {}: (!infrt.tensor, !infrt.tensor) -> () + dt.print_tensor (%out4 : !infrt.tensor) + + infrt.return +} diff --git a/paddle/infrt/gtest_main.cc b/paddle/infrt/gtest_main.cc new file mode 100644 index 0000000000..26e2b5dcfc --- /dev/null +++ b/paddle/infrt/gtest_main.cc @@ -0,0 +1,23 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + gflags::ParseCommandLineFlags(&argc, &argv, false); + + return RUN_ALL_TESTS(); +} diff --git a/paddle/infrt/host_context/CMakeLists.txt b/paddle/infrt/host_context/CMakeLists.txt new file mode 100644 index 0000000000..fdba9af4a5 --- /dev/null +++ b/paddle/infrt/host_context/CMakeLists.txt @@ -0,0 +1,29 @@ +core_gather_headers() + +gather_srcs(infrt_src SRCS + kernel_frame.cc + kernel_registry.cc + value.cc + kernel_utils.cc + symbol_table.cc + op_executable.cc + core_runtime.cc + mlir_to_runtime_translate.cc + function.cc + mlir_function_executable.cc + mlir_program_executor.cc + ) + +cc_test_tiny(test_infrt_host_context_value SRCS value_test.cc DEPS infrt ${MLIR_IR_LIBS}) +cc_test_tiny(test_infrt_kernel_utils SRCS kernel_utils_test.cc DEPS infrt ${MLIR_IR_LIBS}) +cc_test_tiny(test_infrt_kernel_registry SRCS kernel_registry_test.cc DEPS infrt ${MLIR_IR_LIBS}) +cc_test_tiny(test_infrt_op_executable SRCS op_executable_test.cc DEPS infrt ${MLIR_IR_LIBS}) +cc_test_tiny(test_infrt_core_runtime SRCS core_runtime_test.cc DEPS infrt ${MLIR_IR_LIBS}) +cc_test_tiny(test_infrt_mlir_to_runtime_translate SRCS mlir_to_runtime_translate_test.cc DEPS infrt ${MLIR_IR_LIBS}) + +infrt_exec_check(test_infrt_mlir_exec_on_basic mlir_tests/basic.mlir) +infrt_exec_check(test_infrt_mlir_exec_on_shape mlir_tests/shape.mlir) +infrt_exec_check(test_infrt_mlir_exec_on_dense_tensor mlir_tests/dense_tensor.mlir) + +add_executable(infrt-exec mlir_exec.cc) +target_link_libraries(infrt-exec infrt ${MLIR_IR_LIBS}) diff --git a/paddle/infrt/host_context/core_runtime.cc b/paddle/infrt/host_context/core_runtime.cc new file mode 100644 index 0000000000..cdb8cc99ec --- /dev/null +++ b/paddle/infrt/host_context/core_runtime.cc @@ -0,0 +1,93 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/infrt/host_context/core_runtime.h"
+
+#include <glog/logging.h>
+
+#include <deque>
+#include <string>
+#include <vector>
+
+#include "paddle/infrt/host_context/kernel_registry.h"
+#include "paddle/infrt/host_context/op_executable.h"
+#include "paddle/infrt/host_context/symbol_table.h"
+
+namespace infrt::host_context {
+
+// Private state of a CoreRuntime: the kernel registry in use, the symbol
+// table, and the ops in the order they were added.
+struct CoreRuntime::Impl {
+  KernelRegistry* kernel_registry{};
+  SymbolTable symbol_table;
+  // A deque rather than a vector: NewOpExecutable() hands out pointers to the
+  // stored builders, and std::deque::emplace_back keeps pointers to existing
+  // elements valid, whereas a vector reallocation would leave them dangling.
+  std::deque<OpExecutableBuilder> op_executables;
+
+  mutable std::vector<ValueRef> results;
+};
+
+SymbolTable* CoreRuntime::symbol_table() { return &impl_->symbol_table; }
+
+// Takes ownership of `impl`, which must not be null.
+CoreRuntime::CoreRuntime(CoreRuntime::Impl* impl) : impl_(impl) { CHECK(impl); }
+
+// Runs every registered op once, in insertion order.
+void CoreRuntime::Execute() {
+  // std::cout << "CoreRuntime::Execute" << std::endl;
+  int op_offset = 0;
+  for (auto& op : impl_->op_executables) {
+    VLOG(3) << "running op " << op_offset++ << " " << op.name();
+    op.Execute();
+  }
+}
+
+KernelRegistry* CoreRuntime::kernel_registry() const {
+  return impl_->kernel_registry;
+}
+
+size_t CoreRuntime::num_ops() const { return impl_->op_executables.size(); }
+
+// Falls back to the global CPU kernel registry when `kernel_registry` is null.
+CoreRuntimeBuilder::CoreRuntimeBuilder(KernelRegistry* kernel_registry)
+    : CoreRuntime(new Impl) {
+  impl_->kernel_registry =
+      kernel_registry ? kernel_registry : GetCpuKernelRegistry();
+}
+
+// Appends a new op named `op_name` and returns a pointer to its builder. The
+// pointer stays valid across later NewOpExecutable() calls (see Impl).
+OpExecutableBuilder* CoreRuntimeBuilder::NewOpExecutable(
+    const std::string& op_name) {
+  CHECK(impl_.get());
+  impl_->op_executables.emplace_back(
+      op_name, symbol_table(), impl_->kernel_registry);
+  return &impl_->op_executables.back();
+}
+
+// Registers every (name, value) pair of `args` in the symbol table.
+void CoreRuntimeBuilder::FeedInArgs(
+    llvm::ArrayRef<std::pair<std::string, ValueRef>> args) {
+  for (auto& item : args) {
+    symbol_table()->Register(item.first, item.second);
+  }
+}
+
+// Replaces the kernel registry; `x` must not be null.
+void CoreRuntimeBuilder::SetKernelRegistry(KernelRegistry* x) {
+  CHECK(x);
+  impl_->kernel_registry = x;
+}
+
+// Looks up each name of `arg_names` in the symbol table and returns the
+// values in the same order.
+llvm::SmallVector<ValueRef, 4> CoreRuntime::GetResults(
+    llvm::ArrayRef<std::string> arg_names) {
+  llvm::SmallVector<ValueRef, 4> results;
+  for (auto& name : arg_names) {
+    results.push_back(ValueRef(symbol_table()->GetValue(name)));
+  }
+
+  return results;
+}
+
+CoreRuntime::~CoreRuntime() {}
+
+}  // namespace infrt::host_context
diff --git a/paddle/infrt/host_context/core_runtime.h b/paddle/infrt/host_context/core_runtime.h
new file mode 100644
index 0000000000..802f8b17bb
--- /dev/null
+++ b/paddle/infrt/host_context/core_runtime.h
@@ -0,0 +1,86 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <llvm/ADT/ArrayRef.h>
+#include <llvm/ADT/SmallVector.h>
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "paddle/infrt/host_context/value.h"
+
+namespace infrt::host_context {
+
+class KernelRegistry;
+class OpExecutable;
+class OpExecutableBuilder;
+class SymbolTable;
+
+/**
+ * CoreRuntime encapsulates the execution of a sequence of ops.
+ * Each function call binds to a CoreRuntime instance, pushes the argument
+ * Values into the argument-list, and gets the result Values from the
+ * return-list.
+ */
+class CoreRuntime : public std::enable_shared_from_this<CoreRuntime> {
+ public:
+  //! Execute a program.
+  void Execute();
+
+  //! Return the number of ops.
+  size_t num_ops() const;
+
+  //! Get the results of the execution.
+  llvm::SmallVector<ValueRef, 4>  //
+  GetResults(llvm::ArrayRef<std::string> arg_names);
+
+  //! Return a shared_ptr that shares ownership with the existing owners.
+  //! The runtime must already be managed by a std::shared_ptr. Building a new
+  //! shared_ptr from `this` would create an independent owner and delete the
+  //! object a second time.
+  std::shared_ptr<CoreRuntime> getptr() { return shared_from_this(); }
+
+  KernelRegistry* kernel_registry() const;
+
+  ~CoreRuntime();
+
+ protected:
+  //! Get the symbol table.
+  SymbolTable* symbol_table();
+
+  // Declared with the same class-key as its definition in core_runtime.cc.
+  struct Impl;
+  explicit CoreRuntime(Impl* impl);
+  std::unique_ptr<Impl> impl_;
+};
+
+/**
+ * The builder for CoreRuntime, helps to construct a function.
+ */
+class CoreRuntimeBuilder : public CoreRuntime {
+ public:
+  explicit CoreRuntimeBuilder(KernelRegistry* kernel_registry);
+
+  using CoreRuntime::symbol_table;
+
+  void SetKernelRegistry(KernelRegistry* x);
+
+  //! Feed the input arguments, each item is a pair of arg-name and arg-value.
+  void FeedInArgs(llvm::ArrayRef<std::pair<std::string, ValueRef>> args);
+
+  llvm::ArrayRef<const char*> attr_names() const;
+
+  //! Append a new op named `op_name` and return its builder.
+  OpExecutableBuilder* NewOpExecutable(const std::string& op_name);
+};
+
+}  // namespace infrt::host_context
diff --git a/paddle/infrt/host_context/core_runtime_test.cc b/paddle/infrt/host_context/core_runtime_test.cc
new file mode 100644
index 0000000000..3c0dadaad4
--- /dev/null
+++ b/paddle/infrt/host_context/core_runtime_test.cc
@@ -0,0 +1,96 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/infrt/host_context/core_runtime.h"
+
+#include <gtest/gtest.h>
+
+#include "paddle/infrt/host_context/kernel_registry.h"
+#include "paddle/infrt/host_context/kernel_utils.h"
+#include "paddle/infrt/host_context/op_executable.h"
+#include "paddle/infrt/host_context/symbol_table.h"
+
+namespace infrt {
+namespace host_context {
+
+// Plain int kernels registered by the tests below.
+int add(int a, int b) { return a + b; }
+int sub(int a, int b) { return a - b; }
+
+// Chains two ops through the symbol table: c = a + b, then e = c - d.
+TEST(CoreRuntime, basic) {
+  KernelRegistry registry;
+  registry.AddKernel("infrt.test.addi32", INFRT_KERNEL(add));
+  registry.AddKernel("infrt.test.subi32", INFRT_KERNEL(sub));
+
+  CoreRuntimeBuilder builder(&registry);
+  auto* table = builder.symbol_table();
+  table->Register("a", 1);
+  table->Register("b", 2);
+  table->Register("d", 4);
+
+  // c = a + b
+  auto* op0 = builder.NewOpExecutable("infrt.test.addi32");
+  op0->AppendArgument("a");
+  op0->AppendArgument("b");
+  op0->SetResults({"c"});
+
+  // e = c - d
+  auto* op1 = builder.NewOpExecutable("infrt.test.subi32");
+  op1->AppendArgument("c");
+  op1->AppendArgument("d");
+  op1->SetResults({"e"});
+
+  builder.Execute();
+
+  ASSERT_EQ(table->GetValue("d")->get<int>(), 4);
+  ASSERT_EQ(table->GetValue("c")->get<int>(), 3);
+  ASSERT_EQ(table->GetValue("e")->get<int>(), -1);
+}
+
+// Feeds the arguments through FeedInArgs and reads the result back through
+// GetResults.
+TEST(CoreRuntime, function) {
+  // The function:
+  // func(int a, int b) {
+  //   int c = a + b
+  //   return c
+  // }
+  KernelRegistry registry;
+  registry.AddKernel("infrt.test.addi32", INFRT_KERNEL(add));
+  registry.AddKernel("infrt.test.subi32", INFRT_KERNEL(sub));
+
+  CoreRuntimeBuilder builder(&registry);
+  auto* table = builder.symbol_table();
+
+  std::vector<std::pair<std::string, ValueRef>> feeds{
+      {std::make_pair("a", ValueRef(new Value(1))),  //
+       std::make_pair("b", ValueRef(new Value(2)))}};
+  builder.FeedInArgs(llvm::ArrayRef<std::pair<std::string, ValueRef>>(
+      feeds.data(), feeds.size()));
+
+  ASSERT_EQ(table->Get<int>("a"), 1);
+  ASSERT_EQ(table->Get<int>("b"), 2);
+  ASSERT_EQ(table->size(), 2UL);
+
+  auto* op = builder.NewOpExecutable("infrt.test.addi32");
+  op->AppendArgument("a");
+  op->AppendArgument("b");
+  op->SetResults({"c"});
+
+  builder.Execute();
+
+  auto res = builder.GetResults({"c"});
+  ASSERT_EQ(res.size(), 1UL);
+  ASSERT_EQ(res[0].get<int>(), 3);
+}
+
+}  // namespace host_context
+}  // namespace infrt
diff --git a/paddle/infrt/host_context/function.cc b/paddle/infrt/host_context/function.cc
new file mode 100644
index 0000000000..8b111f2645
--- /dev/null
+++ b/paddle/infrt/host_context/function.cc
@@ -0,0 +1,19 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/infrt/host_context/function.h"
+
+namespace infrt {
+namespace host_context {}  // namespace host_context
+}  // namespace infrt
diff --git a/paddle/infrt/host_context/function.h b/paddle/infrt/host_context/function.h
new file mode 100644
index 0000000000..030e3b6cfb
--- /dev/null
+++ b/paddle/infrt/host_context/function.h
@@ -0,0 +1,62 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <llvm/ADT/ArrayRef.h>
+
+#include <string>
+#include <utility>
+
+namespace infrt {
+namespace host_context {
+
+struct Value;
+struct ValueRef;
+
+/**
+ * Base class of all executable Function.
+ *
+ * This is used by `infrt.call` op, to execute a function.
+ */
+class Function {
+ public:
+  // Moves the name out of `other` instead of copying it; the two counts are
+  // copied.
+  Function(Function&& other) noexcept
+      : name_(std::move(other.name_)),
+        num_arguments_(other.num_arguments_),
+        num_results_(other.num_results_) {}
+
+  Function() = delete;
+
+  std::string name() const { return name_; }
+
+  size_t num_arguments() const { return num_arguments_; }
+  size_t num_results() const { return num_results_; }
+
+  // Executes the function. The base implementation does nothing; subclasses
+  // override it.
+  virtual void Execute(llvm::ArrayRef<Value*> arguments,
+                       llvm::MutableArrayRef<ValueRef> results,
+                       bool is_region = false) const {}
+
+  virtual ~Function() = default;
+
+ protected:
+  // `name` is taken by value and moved into the member, so an rvalue argument
+  // is not copied.
+  Function(std::string name, size_t num_arguments, size_t num_results)
+      : name_(std::move(name)),
+        num_arguments_(num_arguments),
+        num_results_(num_results) {}
+
+ private:
+  std::string name_;
+  size_t num_arguments_{};
+  size_t num_results_{};
+};
+
+}  // namespace host_context
+}  // namespace infrt
diff --git a/paddle/infrt/host_context/kernel_frame.cc b/paddle/infrt/host_context/kernel_frame.cc
new file mode 100644
index 0000000000..1acb35e898
--- /dev/null
+++ b/paddle/infrt/host_context/kernel_frame.cc
@@ -0,0 +1,29 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/infrt/host_context/kernel_frame.h"
+
+#include <ostream>
+
+namespace infrt {
+namespace host_context {
+
+// Writes a one-line summary of `frame` (argument, result and attribute
+// counts) to `os` and returns `os`.
+std::ostream& operator<<(std::ostream& os, const KernelFrame& frame) {
+  os << "KernelFrame: " << frame.GetNumArgs() << " args, "
+     << frame.GetNumResults() << " res, " << frame.GetNumAttributes()
+     << " attrs";
+  return os;
+}
+
+}  // namespace host_context
+}  // namespace infrt
diff --git a/paddle/infrt/host_context/kernel_frame.h b/paddle/infrt/host_context/kernel_frame.h
new file mode 100644
index 0000000000..20cb17dc7f
--- /dev/null
+++ b/paddle/infrt/host_context/kernel_frame.h
@@ -0,0 +1,166 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <glog/logging.h>
+#include <llvm/ADT/ArrayRef.h>
+
+#include <utility>
+
+#include "llvm/ADT/SmallVector.h"
+#include "paddle/infrt/host_context/value.h"
+
+namespace infrt::host_context {
+
+/**
+ * KernelFrame captures the states(input arguments, attributes, results)
+ * associated with a kernel invocation.
+ *
+ * All values live in one array laid out as
+ *   [ arguments | results | attributes ]
+ * and num_results_ stays -1 until the results have been declared.
+ */
+class KernelFrame {
+ public:
+  int GetNumArgs() const { return num_arguments_; }
+  int GetNumResults() const { return num_results_; }
+  // Number of entries stored after the arguments and results; results count
+  // as zero while they are still undeclared (-1).
+  int GetNumAttributes() const {
+    return value_or_attrs_.size() - num_arguments_ -
+           (num_results_ == -1 ? 0 : num_results_);
+  }
+
+  // Typed access to the argument at `index`.
+  template <typename T>
+  T& GetArgAt(int index) {
+    CHECK_LT(index, GetNumArgs());
+    return value_or_attrs_[index]->get<T>();
+  }
+  template <typename T>
+  const T& GetArgAt(int index) const {
+    CHECK_LT(index, GetNumArgs());
+    return value_or_attrs_[index]->get<T>();
+  }
+
+  // Untyped access to the argument at `index`.
+  Value* GetArgAt(int index) {
+    CHECK_LT(index, GetNumArgs());
+    return value_or_attrs_[index];
+  }
+
+  // Get all arguments.
+  llvm::ArrayRef<Value*> GetArguments() const {
+    return GetValues(0, num_arguments_);
+  }
+
+  // Returns the attribute at `idx`; attributes start right after the results.
+  Value* GetAttributeAt(int idx) {
+    CHECK_NE(num_results_, -1)
+        << "Must call SetNumResults before GetAttributeAt";
+    CHECK_LT(idx,
+             static_cast<int>(value_or_attrs_.size() - num_arguments_ -
+                              num_results_));
+    return value_or_attrs_[num_arguments_ + num_results_ + idx];
+  }
+
+  // Appends an attribute; only allowed once the results have been declared.
+  void AddAttribute(Value* v) {
+    CHECK_NE(num_results_, -1)
+        << "Must call SetNumResults before calling AddAttribute";
+    value_or_attrs_.emplace_back(v);
+  }
+
+  // Constructs a T from `args` and stores it as result 0.
+  template <typename T, typename... Args>
+  void EmplaceResult(Args&&... args) {
+    EmplaceResult<T>(0, std::forward<Args>(args)...);
+  }
+
+  // Constructs a T from `args` and stores it as result `index`.
+  template <typename T, typename... Args>
+  void EmplaceResult(int index, Args&&... args) {
+    SetResultAt(index, T(std::forward<Args>(args)...));
+  }
+
+  // Stores `value` into the already-allocated result slot `index`.
+  // NOTE(review): `value` is a forwarding reference but is passed on with
+  // std::move, so an lvalue argument is moved from as well. Confirm callers
+  // only pass temporaries before switching to std::forward.
+  template <typename T>
+  void SetResultAt(int index, T&& value) {
+    CHECK_LT(index, num_results_) << "Invalid result index";
+    CHECK(value_or_attrs_[num_arguments_ + index]);
+    value_or_attrs_[num_arguments_ + index]->set(std::move(value));
+  }
+
+  llvm::ArrayRef<Value*> GetResults() const {
+    return GetValues(num_arguments_, num_results_);
+  }
+  llvm::MutableArrayRef<Value*> GetResults() {
+    return GetMutableValues(num_arguments_, num_results_);
+  }
+
+  // Returns a view of `length` entries starting at `from`, limited to the
+  // argument and result sections.
+  // NOTE(review): while num_results_ is still -1 the bound below is
+  // num_arguments_ - 1, so GetArguments() on a frame with arguments trips
+  // this check until the results are declared.
+  llvm::ArrayRef<Value*> GetValues(size_t from, size_t length) const {
+    CHECK_LE(static_cast<int>(from + length), num_arguments_ + num_results_);
+    if (length == 0) return {};
+
+    return llvm::makeArrayRef(&value_or_attrs_[from], length);
+  }
+
+  // Mutable counterpart of GetValues().
+  llvm::MutableArrayRef<Value*> GetMutableValues(size_t from, size_t length) {
+    CHECK_LE(static_cast<int>(from + length), num_arguments_ + num_results_);
+    if (length == 0) return {};
+    return llvm::makeMutableArrayRef(&value_or_attrs_[from], length);
+  }
+
+ protected:
+  int num_arguments_{};
+  int num_results_{-1};
+
+  llvm::SmallVector<Value*, 8> value_or_attrs_;
+};
+
+std::ostream& operator<<(std::ostream& os, const KernelFrame& frame);
+
+// Builder that fills a KernelFrame in order: arguments first, then results.
+class KernelFrameBuilder : public KernelFrame {
+ public:
+  // Appends an argument; must happen before the results are declared.
+  void AddArgument(Value* value) {
+    CHECK(value);
+    CHECK_EQ(num_results_, -1)
+        << "Should call AddArgument before calling SetNumResults";
+    value_or_attrs_.push_back(value);
+    ++num_arguments_;
+  }
+
+  // Declares the results by appending the given Value pointers.
+  void SetResults(llvm::ArrayRef<Value*> values) {
+    CHECK_EQ(num_arguments_, static_cast<int>(value_or_attrs_.size()));
+    CHECK_EQ(num_results_, -1);
+    for (Value* x : values) {
+      value_or_attrs_.push_back(x);
+    }
+    num_results_ = values.size();
+  }
+
+  // Declares `n` results, allocating a fresh Value for each slot.
+  void SetNumResults(size_t n) {
+    CHECK_EQ(num_arguments_, static_cast<int>(value_or_attrs_.size()));
+    CHECK_EQ(num_results_, -1);
+    num_results_ = n;
+    for (size_t i = 0; i < n; i++) {
+      value_or_attrs_.emplace_back(new Value);
+    }
+  }
+
+  // Passes `value` to set() on the existing result slot `result_id`. The
+  // first check only holds while no attribute has been added yet.
+  void SetResultAt(int result_id, Value* value) {
+    CHECK_EQ(static_cast<int>(value_or_attrs_.size()),
+             num_arguments_ + num_results_)
+        << "Call SetNumResults first";
+    CHECK_LT(result_id + num_arguments_,
+             static_cast<int>(value_or_attrs_.size()));
+    CHECK(value);
+    value_or_attrs_[num_arguments_ + result_id]->set(value);
+  }
+
+  // Returns the frame to its initial, empty state.
+  void Reset() {
+    value_or_attrs_.clear();
+    num_arguments_ = 0;
+    num_results_ = -1;
+  }
+};
+
+}  // namespace infrt::host_context
diff --git a/paddle/infrt/host_context/kernel_registry.cc b/paddle/infrt/host_context/kernel_registry.cc
new file mode 100644
index 0000000000..f343dfc71b
--- /dev/null
+++ b/paddle/infrt/host_context/kernel_registry.cc
@@ -0,0 +1,70 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/infrt/host_context/kernel_registry.h"
+
+#include <unordered_map>
+
+#include "glog/logging.h"
+#include "llvm/ADT/SmallVector.h"
+
+namespace infrt {
+namespace host_context {
+
+// Private state: kernel implementations and attribute-name lists, both keyed
+// by kernel name.
+struct KernelRegistry::Impl {
+  std::unordered_map<std::string, KernelImplementation> data;
+  std::unordered_map<std::string, llvm::SmallVector<std::string, 4>> attr_names;
+};
+
+KernelRegistry::KernelRegistry() : impl_(std::make_unique<Impl>()) {}
+
+// Registers `fn` under `key`; registering the same key twice is a fatal error.
+void KernelRegistry::AddKernel(const std::string &key,
+                               KernelImplementation fn) {
+  CHECK(!impl_->data.count(key)) << "kernel [" << key
+                                 << "] is registered twice";
+  impl_->data.emplace(key, fn);
+}
+
+// Records the attribute names of kernel `key`; a second list for the same key
+// is a fatal error.
+void KernelRegistry::AddKernelAttrNameList(
+    const std::string &key, const std::vector<std::string> &names) {
+  CHECK(!impl_->attr_names.count(key))
+      << "kernel [" << key << "] is registered twice in attribute names";
+  impl_->attr_names.emplace(
+      key, llvm::SmallVector<std::string, 4>(names.begin(), names.end()));
+}
+
+// Returns the kernel registered under `key`, or a null function pointer when
+// there is none.
+KernelImplementation KernelRegistry::GetKernel(const std::string &key) const {
+  auto it = impl_->data.find(key);
+  return it != impl_->data.end() ? it->second : KernelImplementation{};
+}
+
+// Returns the names of all registered kernels, in the map's iteration order.
+std::vector<std::string> KernelRegistry::GetKernelList() const {
+  // Reserve instead of size-constructing: a size-constructed vector would
+  // start with data.size() empty strings ahead of the pushed-back names.
+  std::vector<std::string> res;
+  res.reserve(impl_->data.size());
+  for (const auto &entry : impl_->data) {
+    res.push_back(entry.first);
+  }
+  return res;
+}
+
+KernelRegistry::~KernelRegistry() {}
+
+// Number of registered kernels.
+size_t KernelRegistry::size() const { return impl_->data.size(); }
+
+// Returns the process-wide CPU kernel registry, created on first use.
+KernelRegistry *GetCpuKernelRegistry() {
+  static auto registry = std::make_unique<KernelRegistry>();
+  return registry.get();
+}
+
+}  // namespace host_context
+}  // namespace infrt
diff --git a/paddle/infrt/host_context/kernel_registry.h b/paddle/infrt/host_context/kernel_registry.h
new file mode 100644
index 0000000000..d65969999f
--- /dev/null
+++ b/paddle/infrt/host_context/kernel_registry.h
@@ -0,0 +1,67 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +namespace infrt { +namespace host_context { + +class KernelFrame; + +using KernelImplementation = void (*)(KernelFrame *frame); + +/** + * Hold the kernels registered in the system. + */ +class KernelRegistry { + public: + KernelRegistry(); + + void AddKernel(const std::string &key, KernelImplementation fn); + void AddKernelAttrNameList(const std::string &key, + const std::vector &names); + + KernelImplementation GetKernel(const std::string &key) const; + std::vector GetKernelList() const; + + size_t size() const; + + ~KernelRegistry(); + + private: + class Impl; + + std::unique_ptr impl_; +}; + +//! The global CPU kernel registry. +KernelRegistry *GetCpuKernelRegistry(); + +} // namespace host_context +} // namespace infrt + +/** + * compile function RegisterKernels in C way to avoid C++ name mangling. + */ +#ifdef __cplusplus +extern "C" { +#endif +void RegisterKernels(infrt::host_context::KernelRegistry *registry); +#ifdef __cplusplus +} +#endif diff --git a/paddle/infrt/host_context/kernel_registry_test.cc b/paddle/infrt/host_context/kernel_registry_test.cc new file mode 100644 index 0000000000..f36ec2a1ca --- /dev/null +++ b/paddle/infrt/host_context/kernel_registry_test.cc @@ -0,0 +1,47 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/host_context/kernel_registry.h" + +#include + +#include "paddle/infrt/host_context/kernel_utils.h" + +namespace infrt::host_context { + +int add_i32(int a, int b) { return a + b; } + +TEST(KernelRegistry, basic) { + KernelRegistry registry; + std::string key = "infrt.test.add.i32"; + registry.AddKernel(key, INFRT_KERNEL(add_i32)); + + auto* kernel_impl = registry.GetKernel(key); + ASSERT_TRUE(kernel_impl); + + ValueRef a(1); + ValueRef b(2); + KernelFrameBuilder fbuilder; + fbuilder.AddArgument(a.get()); + fbuilder.AddArgument(b.get()); + fbuilder.SetNumResults(1); + + kernel_impl(&fbuilder); + + auto results = fbuilder.GetResults(); + ASSERT_EQ(results.size(), 1UL); + ASSERT_EQ(results[0]->get(), 3); +} + +} // namespace infrt::host_context diff --git a/paddle/infrt/host_context/kernel_utils.cc b/paddle/infrt/host_context/kernel_utils.cc new file mode 100644 index 0000000000..cf9476da03 --- /dev/null +++ b/paddle/infrt/host_context/kernel_utils.cc @@ -0,0 +1,19 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/host_context/kernel_utils.h" + +namespace infrt { +namespace host_context {} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/kernel_utils.h b/paddle/infrt/host_context/kernel_utils.h new file mode 100644 index 0000000000..33812912ba --- /dev/null +++ b/paddle/infrt/host_context/kernel_utils.h @@ -0,0 +1,352 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include + +#include "paddle/infrt/host_context/kernel_frame.h" +#include "paddle/infrt/host_context/value.h" + +namespace infrt { +namespace host_context { + +template +class Argument { + public: + explicit Argument(ValueRef value) : value_(value) {} + + ValueRef& value() { return value_; } + const ValueRef& value() const { return value_; } + + T& get() const { return value_.get(); } + + private: + ValueRef value_; +}; + +/** + * RemainingArguments collects all remaining arguments in an ArrayRef. 
+ */ +class RemainingArguments { + public: + explicit RemainingArguments(llvm::ArrayRef remaining_arguments) + : remaining_arguments_(remaining_arguments) {} + + llvm::ArrayRef values() const { return remaining_arguments_; } + size_t size() const { return remaining_arguments_.size(); } + const Value* operator[](size_t i) const { return remaining_arguments_[i]; } + + private: + llvm::ArrayRef remaining_arguments_; +}; + +/** + * RemainingResults collects all remaining results in a MutableArrayRef. + */ +class RemainingResults { + public: + explicit RemainingResults(llvm::MutableArrayRef remaining_results) + : remaining_results_(remaining_results) {} + llvm::MutableArrayRef values() { return remaining_results_; } + size_t size() const { return remaining_results_.size(); } + + template + const ValueRef& AllocateAt(int index) { + // eagerly create a ValueRef + if (remaining_results_[index].get()) return remaining_results_[index]; + remaining_results_[index] = ValueRef(new Value); + return remaining_results_[index]; + } + ValueRef& operator[](size_t i) const { return remaining_results_[i]; } + + private: + llvm::MutableArrayRef remaining_results_; +}; + +template +class Result { + public: + explicit Result(ValueRef* result) : result_(result) {} + + template + void Emplace(Args&&... 
args) { + ValueRef v; + Set(T(std::forward(args)...)); + } + + void Set(Argument argument) { + CHECK(!result_->IsValid()); + *result_ = argument.value(); + } + + private: + ValueRef* result_{}; +}; + +template +class Attribute { + public: + explicit Attribute(const Value* value) : value_(value) {} + + const T& get() const { return value_->get(); } + + private: + const Value* value_; +}; + +template +class ArgumentView { + using UnderlyingT = typename ViewT::UnderlyingT; + + public: + explicit ArgumentView(Value* value) + : value_(value), arg_(&value->template get()) {} + + Value* value() const { return value_; } + ViewT& get() const { return arg_; } + ViewT* operator->() const { return &get(); } + ViewT& operator*() const { return get(); } + + private: + Value* value_{}; + mutable ViewT arg_; +}; + +template +struct KernelImpl; + +template +struct TypeTag {}; + +#define INFRT_KERNEL(...) \ + ::infrt::host_context::KernelImpl::Invoke + +template +struct KernelImpl { + static void Invoke(KernelFrame* frame) { + KernelCallHelper>::template Invoke<0, 0, 0>(frame); + } + + // Helper that introspects the arguments to derive the signature and cast + // parts of the KernelFrame to their type before passing them to impl_fn. + template + struct KernelCallHelper; + + // Casts the return value of the kernel, if non-void. + // bool _ is an unnecessary parameter to make compiler allow templace specific + // in non-namespace scope. + template + struct KernelReturnHelper { + static void Invoke(KernelFrame* frame, const Args&... args) { + HandleReturn(frame, impl_fn(args...)); + } + }; + + template + struct KernelReturnHelper { + static void Invoke(KernelFrame* frame, const Args&... args) { + impl_fn(args...); + } + }; + + // Specialization to cast a single input argument(Head). + template + struct KernelCallHelper, Tail...> { + template + static void Invoke(KernelFrame* frame, const PreviousArgs&... 
pargs) { + static_assert(in_idx != -1, + "Do not place Arguments after RemainingArguments"); + static_assert(out_idx == 0, "Arguments should appear before results"); + static_assert(const_idx == 0, + "Arguments and results should appear before attributes."); + + Argument arg(frame->GetArgAt(in_idx)); + KernelCallHelper< + Tail...>::template Invoke(frame, + pargs..., + arg); + } + }; + + template + struct KernelCallHelper, Tail...> { + template + static void Invoke(KernelFrame* frame, const PreviousArgs&... pargs) { + static_assert(in_idx != -1, + "Do not place Arguments after RemainingArguments"); + static_assert(out_idx == 0, "Arguments should appear before results"); + static_assert(const_idx == 0, + "Arguments and results should appear before attributes."); + + ArgumentView arg(frame->GetArgAt(in_idx)); + KernelCallHelper< + Tail...>::template Invoke(frame, + pargs..., + arg); + } + }; + + // Specialization to cast a single result argument (Head). + template + struct KernelCallHelper, Tail...> { + template + static void Invoke(KernelFrame* frame, const PreviousArgs&... pargs) { + static_assert(out_idx != -1, + "Do not place Results after RemainingResults"); + static_assert(const_idx == 0, + "Arguments and results should appear before attributes"); + Result arg(&frame->GetResults()[out_idx]); + KernelCallHelper< + Tail...>::template Invoke(frame, + pargs..., + arg); + } + }; + + // Specialization to cast a single attribute. + template + struct KernelCallHelper, Tail...> { + template + static void Invoke(KernelFrame* frame, const PreviousArgs&... pargs) { + static_assert(const_idx != -1, + "Do not place Attributes after RemainingAttributes"); + Attribute arg(frame->GetAttributeAt(const_idx)); + KernelCallHelper< + Tail...>::template Invoke(frame, + pargs..., + arg); + } + }; + + // Treat other pointer as an Argument. + template + struct KernelCallHelper { + template + static void Invoke(KernelFrame* frame, const PreviousArgs&... 
pargs) { + static_assert(in_idx != -1, + "Do not place Arguments after RemainingArguments"); + static_assert(out_idx == 0, "Arguments should appear before results"); + static_assert(const_idx == 0, + "Arguments and results should appear before attributes."); + auto* arg = &frame->GetArgAt(in_idx); + KernelCallHelper< + Tail...>::template Invoke(frame, + pargs..., + arg); + } + }; + + // Treat any other type as an Argument. + template + struct KernelCallHelper { + using ArgT = std::decay_t; + + template + static void Invoke(KernelFrame* frame, const PreviousArgs&... pargs) { + static_assert(in_idx != -1, + "Do not place Arguments after RemainingArguments"); + static_assert(out_idx == 0, "Arguments should appear before results"); + static_assert(const_idx == 0, + "Arguments and results should appear before attributes."); + + auto* value = frame->GetArgAt(in_idx); + auto&& arg = value->get(); + + KernelCallHelper< + Tail...>::template Invoke(frame, + pargs..., + arg); + } + }; + + // RemainingArguments provides an ArrayRef containing all + // remaining arguments. Useful for variadic + // kernels. + template + struct KernelCallHelper { + template + static void Invoke(KernelFrame* frame, const PreviousArgs&... pargs) { + static_assert(in_idx != -1, + "Do not use more than one RemainingArguments"); + static_assert(out_idx == 0, "Arguments should appear before results."); + static_assert(const_idx == 0, + "Arguments and results should appear before attributes"); + RemainingArguments remaining_arguments( + frame->GetArguments().drop_front(in_idx)); + + KernelCallHelper::template Invoke<-1, out_idx, const_idx>( + frame, pargs..., remaining_arguments); + } + }; + + // RemainingResults provides an MutableArrayRef containing all + // remaining results. + template + struct KernelCallHelper { + template + static void Invoke(KernelFrame* frame, const PreviousArgs&... 
pargs) { + static_assert(out_idx != -1, "Do not use more than one RemainingResults"); + static_assert(const_idx == 0, + "Arguments and results should appear before attributes"); + llvm::MutableArrayRef returned_results = + frame->GetResults().drop_front(out_idx); + + llvm::SmallVector result_values; + for (size_t i = 0; i < returned_results.size(); i++) + result_values.emplace_back(returned_results[i]); + + RemainingResults remaining_results(result_values); + KernelCallHelper::template Invoke( + frame, pargs..., remaining_results); + } + }; + + // No arguments left. + template + struct KernelCallHelper> { + template + static void Invoke(KernelFrame* frame, const PreviousArgs&... pargs) { + KernelReturnHelper::Invoke(frame, pargs...); + } + }; + + // Handle pair result + template + static void HandleReturn(KernelFrame* frame, std::pair&& t) { + CHECK_EQ(frame->GetNumResults(), 2); + StoreResultAt(frame, 0, std::move(t.first)); + StoreResultAt(frame, 1, std::move(t.second)); + } + + // Store the function result back to the output Value in KernelFrame. + template + static void HandleReturn(KernelFrame* frame, T&& t) { + assert(frame->GetNumResults() == 1 && "Extra results passed to kernel."); + StoreResultAt(frame, 0, std::forward(t)); + } + + // Store result as an Value output in KernelFrame. + template + static void StoreResultAt(KernelFrame* frame, int index, T&& t) { + frame->EmplaceResult>(index, std::forward(t)); + } +}; + +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/kernel_utils_test.cc b/paddle/infrt/host_context/kernel_utils_test.cc new file mode 100644 index 0000000000..1904eb106a --- /dev/null +++ b/paddle/infrt/host_context/kernel_utils_test.cc @@ -0,0 +1,69 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/host_context/kernel_utils.h" + +#include + +namespace infrt::host_context { + +int add_i32(int a, int b) { return a + b; } +float add_f32(float a, float b) { return a + b; } +std::pair add_pair(int a, float b) { return {a, b}; } + +TEST(KernelImpl, i32) { + KernelFrameBuilder fbuilder; + ValueRef a(new Value(1)); + ValueRef b(new Value(2)); + fbuilder.AddArgument(a.get()); + fbuilder.AddArgument(b.get()); + fbuilder.SetNumResults(1); + + INFRT_KERNEL(add_i32)(&fbuilder); + auto results = fbuilder.GetResults(); + ASSERT_EQ(results.size(), 1UL); + ASSERT_EQ(results.front()->get(), 3); +} + +TEST(KernelImpl, f32) { + KernelFrameBuilder fbuilder; + ValueRef a(new Value(1.f)); + ValueRef b(new Value(2.f)); + fbuilder.AddArgument(a.get()); + fbuilder.AddArgument(b.get()); + fbuilder.SetNumResults(1); + + INFRT_KERNEL(add_f32)(&fbuilder); + auto results = fbuilder.GetResults(); + ASSERT_EQ(results.size(), 1UL); + ASSERT_EQ(results.front()->get(), 3.f); +} + +TEST(KernelImpl, pair) { + KernelFrameBuilder fbuilder; + ValueRef a(new Value(1)); + ValueRef b(new Value(3.f)); + + fbuilder.AddArgument(a.get()); + fbuilder.AddArgument(b.get()); + fbuilder.SetNumResults(2); + + INFRT_KERNEL(add_pair)(&fbuilder); + auto results = fbuilder.GetResults(); + ASSERT_EQ(results.size(), 2UL); + ASSERT_EQ(results[0]->get(), 1); + ASSERT_EQ(results[1]->get(), 3.f); +} + +} // namespace infrt::host_context diff --git a/paddle/infrt/host_context/mlir_exec.cc b/paddle/infrt/host_context/mlir_exec.cc new file mode 100644 index 0000000000..b0d70af5ef 
--- /dev/null +++ b/paddle/infrt/host_context/mlir_exec.cc @@ -0,0 +1,80 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include + +#include "llvm/Support/DynamicLibrary.h" +#include "paddle/infrt/common/global.h" +#include "paddle/infrt/dialect/mlir_loader.h" +#include "paddle/infrt/host_context/core_runtime.h" +#include "paddle/infrt/host_context/kernel_registry.h" +#include "paddle/infrt/host_context/mlir_to_runtime_translate.h" +#include "paddle/infrt/kernel/basic_kernels.h" +#include "paddle/infrt/kernel/control_flow_kernels.h" +#include "paddle/infrt/kernel/tensor_kernels.h" +#include "paddle/infrt/kernel/tensor_shape_kernels.h" +#include "paddle/infrt/kernel/test_kernels.h" + +static llvm::cl::list cl_shared_libs( // NOLINT + "shared_libs", + llvm::cl::desc("Specify shared library with kernels."), + llvm::cl::ZeroOrMore, + llvm::cl::MiscFlags::CommaSeparated); + +int main(int argc, char** argv) { + using namespace llvm; // NOLINT + using namespace infrt; // NOLINT + cl::opt input_file("i", + cl::desc("Specify input filename"), + cl::value_desc("input file name")); + cl::ParseCommandLineOptions(argc, argv); + + mlir::MLIRContext* context = infrt::Global::getMLIRContext(); + auto module = dialect::LoadMlirFile(input_file.c_str(), context); + + host_context::KernelRegistry registry; + + kernel::RegisterBasicKernels(&registry); + 
kernel::RegisterTestKernels(&registry); + kernel::RegisterTensorShapeKernels(&registry); + kernel::RegisterTensorKernels(&registry); + kernel::RegisterControlFlowKernels(&registry); + + // load extra shared library + for (const auto& lib_path : cl_shared_libs) { + std::string err; + llvm::sys::DynamicLibrary dynLib = + llvm::sys::DynamicLibrary::getPermanentLibrary(lib_path.c_str(), &err); + if (!dynLib.isValid()) { + llvm::errs() << "Load shared library failed. Error: " << err << "\n"; + return 1; + } + if (auto reg_sym = dynLib.SearchForAddressOfSymbol("RegisterKernels")) { + auto reg_func = + reinterpret_cast(reg_sym); + reg_func(&registry); + } else { + llvm::outs() << "Symbol \"RegisterKernels\" not found in \"" << lib_path + << "\". Skip.\n"; + } + } + + host_context::TestMlir(module.get(), &registry); + + std::cout << std::endl; + return 0; +} diff --git a/paddle/infrt/host_context/mlir_function_executable.cc b/paddle/infrt/host_context/mlir_function_executable.cc new file mode 100644 index 0000000000..5f8dacf8e4 --- /dev/null +++ b/paddle/infrt/host_context/mlir_function_executable.cc @@ -0,0 +1,135 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/infrt/host_context/mlir_function_executable.h" + +#include + +#include // NOLINT + +#include "paddle/infrt/common/common.h" +#include "paddle/infrt/host_context/core_runtime.h" + +namespace infrt { +namespace host_context { + +template +std::string DumpToString(T& op) { // NOLINT + std::string buffer; + llvm::raw_string_ostream os(buffer); + op.print(os); + os.flush(); + return buffer; +} + +MlirFunctionExecutable::MlirFunctionExecutable( + mlir::FuncOp func_op, + KernelRegistry* kernel_registry, + MlirToRuntimeTranslator::function_defs_t& function_table) + : Function(func_op.getName().str(), + func_op.getNumArguments(), + func_op.getNumResults()), + MlirToRuntimeTranslator(&core_runtime_builder_), + region_(&func_op.getRegion()), + core_runtime_builder_(kernel_registry), + function_table_(function_table) {} + +MlirFunctionExecutable::MlirFunctionExecutable( + mlir::Region* region, + mlir::FunctionType func_type, + KernelRegistry* kernel_registry, + MlirToRuntimeTranslator::function_defs_t& function_table) + : Function("", func_type.getNumInputs(), func_type.getNumResults()), + MlirToRuntimeTranslator(&core_runtime_builder_), + region_(region), + core_runtime_builder_(kernel_registry), + function_table_(function_table) {} + +void MlirFunctionExecutable::BuildExecutables( + llvm::ArrayRef arguments, + llvm::MutableArrayRef results, + bool is_region) { + CHECK_EQ(arguments.size(), num_arguments()); + // We use the function call's arguments as op_executable's operands to avoid + // copy. 
+ for (size_t i = 0; i < num_arguments(); i++) { + AddValue(region_->getArgument(i), arguments[i]); + } + + // build the program + auto& blocks = region_->getBlocks(); + CHECK_EQ(blocks.size(), 1UL) + << "function with more than one block is not supported yet"; + + llvm::SmallVector runtime_results; + for (auto& op : blocks.front()) { + if (EmitConstantOp(&op)) continue; + if (EmitBuildShapeOp(&op)) continue; + + llvm::SmallVector mlir_results; + if (EmitReturnOp(&op, &mlir_results)) { + if (!is_region) { + for (auto v : mlir_results) { + runtime_results.push_back(GetValue(v)); + } + } + continue; + } + + if (EmitCallOp(&op, &function_table_)) continue; + + if (EmitGeneralOp(&op)) continue; + LOG(FATAL) << "Not supported op: " << DumpToString(op); + } + + // after the block is built, we can get the result values of the whole + // function call in the runtime_results. + + mlir::SmallVector results_copied; + if (!is_region) { + for (ValueRef& x : results) { + results_copied.push_back(x.get()); + } + } + + // set a lambda function to help copy the results from the runtime results in + // the local function to outer program. + CHECK_EQ(results_copied.size(), runtime_results.size()); + this->copy_res_fn_ = [results_copied, runtime_results] { + VLOG(4) << "copy results to result"; + for (size_t i = 0; i < results_copied.size(); i++) { + VLOG(4) << ".. 
copy " << runtime_results[i] << " to " + << results_copied[i]; + CopyTo(*runtime_results[i], results_copied[i]); + } + }; +} + +void MlirFunctionExecutable::Execute(llvm::ArrayRef arguments, + llvm::MutableArrayRef results, + bool is_region) const { + CHECK_EQ(arguments.size(), num_arguments()); + CHECK_EQ(results.size(), num_results()); + + if (core_runtime_builder_.num_ops() == 0) { + Reference(this).BuildExecutables(arguments, results, is_region); + } + + Reference(&core_runtime_builder_).Execute(); + + copy_res_fn_(); +} + +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/mlir_function_executable.h b/paddle/infrt/host_context/mlir_function_executable.h new file mode 100644 index 0000000000..ba5fa154d6 --- /dev/null +++ b/paddle/infrt/host_context/mlir_function_executable.h @@ -0,0 +1,78 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include +#include + +#include "paddle/infrt/host_context/core_runtime.h" +#include "paddle/infrt/host_context/function.h" +#include "paddle/infrt/host_context/mlir_to_runtime_translate.h" + +namespace infrt { +namespace host_context { + +struct KernelRegistry; + +/** + * Executable function for a given MLIR function definition, mainly used in two + * scenerios: + * 1. infrt.call op + * 2. main function call + * + * A MlirFunctionExecutable might have one or more arguments and results. 
+ */ +class MlirFunctionExecutable : public Function, public MlirToRuntimeTranslator { + public: + using function_defs_t = std::unordered_map; + + MlirFunctionExecutable(mlir::FuncOp func_op, + KernelRegistry* kernel_registry, + function_defs_t& function_table); // NOLINT + + MlirFunctionExecutable( + mlir::Region* region, + mlir::FunctionType func_type, + KernelRegistry* kernel_registry, + MlirToRuntimeTranslator::function_defs_t& function_table); // NOLINT + + /** + * Execute the function with the given arguments and results. + * NOTE the \param arguments and \param results should not be altered. + */ + void Execute(llvm::ArrayRef arguments, + llvm::MutableArrayRef results, + bool is_region = false) const; + + private: + /** + * Build the runtime executables once the function call arguments and results + * are passed in. + * This will trigger in the first execution. + */ + void BuildExecutables(llvm::ArrayRef arguments, + llvm::MutableArrayRef results, + bool is_region); + + private: + mlir::Region* region_{}; + CoreRuntimeBuilder core_runtime_builder_; + MlirToRuntimeTranslator::function_defs_t& function_table_; + std::function copy_res_fn_; +}; + +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/mlir_program_executor.cc b/paddle/infrt/host_context/mlir_program_executor.cc new file mode 100644 index 0000000000..c5009bcc97 --- /dev/null +++ b/paddle/infrt/host_context/mlir_program_executor.cc @@ -0,0 +1,19 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/host_context/mlir_program_executor.h" + +namespace infrt { +namespace host_context {} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/mlir_program_executor.h b/paddle/infrt/host_context/mlir_program_executor.h new file mode 100644 index 0000000000..b2af4d2d79 --- /dev/null +++ b/paddle/infrt/host_context/mlir_program_executor.h @@ -0,0 +1,79 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "paddle/infrt/host_context/core_runtime.h" +#include "paddle/infrt/host_context/kernel_registry.h" +#include "paddle/infrt/host_context/mlir_function_executable.h" +#include "paddle/infrt/host_context/mlir_to_runtime_translate.h" +#include "paddle/infrt/host_context/op_executable.h" + +namespace infrt { +namespace host_context { + +/** + * This get a MLIR program as input, it compiles it into runtime program, and + * one can retrieve the function and execute + * it by passing the input arguments. 
+ */ +class MlirProgramExecutor : public MlirToRuntimeTranslator { + public: + CoreRuntimeBuilder runtime_builder; + mlir::ModuleOp module; + function_defs_t function_defs; + + MlirProgramExecutor(mlir::ModuleOp module, KernelRegistry* registry) + : MlirToRuntimeTranslator(module, &runtime_builder), + runtime_builder(registry), + module(module) {} + + // Build functions and generate executables. + void BuildFunctions() { EmitFunctions(); } + + void EmitFunction(mlir::FuncOp op) override { + LOG(INFO) << "Emit function: " << op.getName().str(); + function_defs[op.getName().str()] = op; + + func_executables_.emplace( + op.getName().str(), + new MlirFunctionExecutable( + op, runtime_builder.kernel_registry(), function_defs)); + } + + MlirFunctionExecutable* LookupFunc(const std::string& name) { + auto it = func_executables_.find(name); + if (it != func_executables_.end()) { + return it->second.get(); + } + return nullptr; + } + + private: + std::unordered_map> + func_executables_; +}; + +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/mlir_tests/basic.mlir b/paddle/infrt/host_context/mlir_tests/basic.mlir new file mode 100644 index 0000000000..263d588413 --- /dev/null +++ b/paddle/infrt/host_context/mlir_tests/basic.mlir @@ -0,0 +1,30 @@ +// CHECK-LABEL: basic +func @basic() -> f32 { + %v0 = infrt.constant.f32 1.0 + %v1 = infrt.constant.f32 2.0 + %v2 = "infrt.add.f32"(%v0, %v1) : (f32, f32) -> f32 + + // CHECK: 1 + "infrt.print.f32"(%v0) : (f32) -> () + // CHECK: 2 + "infrt.print.f32"(%v1) : (f32) -> () + + // CHECK: 3 + "infrt.print.f32"(%v2) : (f32) -> () + + %v3 = "infrt.mul.f32"(%v2, %v1) : (f32, f32) -> f32 + + // CHECK: 6 + "infrt.print.f32"(%v3) : (f32) -> () + + infrt.return %v3 : f32 +} + +// CHECK-LABEL: basic1 +// Check the mlir executor can work with more than one function in a file. 
+func @basic1() -> () { + %v0 = infrt.constant.f32 1.0 + "infrt.print.f32"(%v0) : (f32) -> () + // CHECK: 1 + infrt.return +} \ No newline at end of file diff --git a/paddle/infrt/host_context/mlir_tests/dense_tensor.mlir b/paddle/infrt/host_context/mlir_tests/dense_tensor.mlir new file mode 100644 index 0000000000..83afa1db8a --- /dev/null +++ b/paddle/infrt/host_context/mlir_tests/dense_tensor.mlir @@ -0,0 +1,9 @@ +// CHECK-LABEL: build_tensor1 +func @build_tensor1() { + %a = dt.create_uninit_tensor.f32 [3, 4] -> !infrt.tensor + dt.fill_tensor_with_constant.f32 (%a : !infrt.tensor) {value=1.0:f32} + // CHECK: tensor: shape=shape[3,4], values=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + dt.print_tensor (%a : !infrt.tensor) + + infrt.return +} diff --git a/paddle/infrt/host_context/mlir_tests/shape.mlir b/paddle/infrt/host_context/mlir_tests/shape.mlir new file mode 100644 index 0000000000..a3130857b0 --- /dev/null +++ b/paddle/infrt/host_context/mlir_tests/shape.mlir @@ -0,0 +1,7 @@ +// CHECK-LABEL: build_tensor1 +func @build_tensor1() { + %a = ts.build_shape [1:i64, 57:i64, 92:i64] + // CHECK: shape[1,57,92] + ts.print_shape %a + infrt.return +} \ No newline at end of file diff --git a/paddle/infrt/host_context/mlir_to_runtime_translate.cc b/paddle/infrt/host_context/mlir_to_runtime_translate.cc new file mode 100644 index 0000000000..25324b1291 --- /dev/null +++ b/paddle/infrt/host_context/mlir_to_runtime_translate.cc @@ -0,0 +1,558 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/infrt/host_context/mlir_to_runtime_translate.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "boost/optional.hpp"
+#include "paddle/infrt/common/string.h"
+#include "paddle/infrt/dialect/mlir_loader.h"
+#include "paddle/infrt/dialect/tensor_shape.h"
+#include "paddle/infrt/host_context/core_runtime.h"
+#include "paddle/infrt/host_context/kernel_frame.h"
+#include "paddle/infrt/host_context/kernel_registry.h"
+#include "paddle/infrt/host_context/mlir_function_executable.h"
+#include "paddle/infrt/host_context/op_executable.h"
+#include "paddle/infrt/host_context/value.h"
+#include "paddle/infrt/tensor/tensor_shape.h"
+
+namespace infrt::host_context {
+
+// Render anything exposing `print(llvm::raw_ostream&)` into a std::string.
+// Used below to put MLIR values/operations into log and CHECK messages.
+template
+std::string DumpToString(T& op) {  // NOLINT
+  std::string buffer;
+  llvm::raw_string_ostream os(buffer);
+  op.print(os);
+  // raw_string_ostream buffers; flush so `buffer` holds the full text.
+  os.flush();
+  return buffer;
+}
+
+// Private state of the translator (pimpl). All pointers here are non-owning
+// as far as this file shows: they are only assigned, never deleted.
+struct MlirToRuntimeTranslator::Impl {
+  mlir::ModuleOp module;
+  // The runtime for a function call.
+  CoreRuntimeBuilder* runtime{};
+  // The op currently being translated. The translator processes ops one by
+  // one and repoints `cur_op` at the executable of the op it is working on.
+  OpExecutableBuilder* cur_op{};
+
+  // record the current function name.
+  std::string cur_func_name;
+
+  // Name to function definitions.
+  std::unordered_map func_defs;
+
+  // Map from an operation to its results.
+  std::unordered_map> op_results;
+  // Map from an MLIR value to the runtime value created for it by AddValue().
+  llvm::DenseMap value_map;
+};
+
+// Handle an "infrt.constant*" op by evaluating it at translation time.
+//
+// Returns false (and does nothing) when the op name does not start with
+// "infrt.constant". Otherwise the "value" attribute is read and a single
+// runtime value is recorded in `op_results[op]`:
+//   - floating attribute with type f32 / f64,
+//   - integer attribute with width 32 / 64 / 1.
+// Any other attribute type aborts via LOG(FATAL).
+//
+// NOTE(review): `op->getAttr("value")` is used without a null check; this
+// presumably relies on the op having been verified beforehand -- confirm.
+bool MlirToRuntimeTranslator::EmitConstantOp(mlir::Operation* op) {
+  if (!infrt::Startswith(op->getName().getStringRef().str(), "infrt.constant"))
+    return false;
+  VLOG(3) << "Emitting constant op [" << op->getName().getStringRef().str()
+          << "]";
+
+  auto attr = op->getAttr("value");
+  // Floating-point constants: the payload is read as double and narrowed to
+  // the declared element type.
+  if (attr.isa()) {
+    if (attr.getType().isF32()) {
+      impl_->op_results[op] = {ValueRef(
+          static_cast(attr.cast().getValueAsDouble()))};
+    } else if (attr.getType().isF64()) {
+      impl_->op_results[op] = {ValueRef(static_cast(
+          attr.cast().getValueAsDouble()))};
+    } else {
+      LOG(FATAL) << "Not supported attribute type";
+    }
+    return true;
+  }
+
+  // Integer constants, selected by bit width (1-bit is handled separately and
+  // read with getInt() rather than getSInt()).
+  if (attr.isa()) {
+    if (attr.getType().isInteger(32)) {
+      impl_->op_results[op] = {ValueRef(
+          static_cast(attr.cast().getSInt()))};
+    } else if (attr.getType().isInteger(64)) {
+      impl_->op_results[op] = {ValueRef(
+          static_cast(attr.cast().getSInt()))};
+    } else if (attr.getType().isInteger(1)) {
+      impl_->op_results[op] = {
+          ValueRef(static_cast(attr.cast().getInt()))};
+    } else {
+      LOG(FATAL) << "Not supported attribute type";
+    }
+    return true;
+  }
+
+  LOG(FATAL) << "Not supported constant attribute type";
+  // Not reached when LOG(FATAL) aborts; keeps every path returning a value.
+  return true;
+}
+
+// EmitAttribute specialisation that yields a value only when `attr` has a
+// 32-bit integer type; returns boost::none otherwise.
+template <>
+boost::optional MlirToRuntimeTranslator::EmitAttribute(
+    const mlir::Attribute* attr) {
+  if (!attr->isa()) return boost::none;
+  if (attr->isa()) {
+    auto val = attr->cast();
+    if (val.getType().isInteger(32)) {
+      return val.getInt();
+    }
+  }
+  return boost::none;
+}
+// EmitAttribute specialisation that yields a value only when `attr` has a
+// 64-bit integer type; returns boost::none otherwise.
+template <>
+boost::optional MlirToRuntimeTranslator::EmitAttribute(
+    const mlir::Attribute* attr) {
+  if (!attr->isa()) return boost::none;
+  if (attr->isa()) {
+    auto val = attr->cast();
+    if (val.getType().isInteger(64)) {
+      return val.getInt();
+    }
+  }
+  return boost::none;
+}
+
+// TODO(Superjomn) Make double and float parsing share some thing.
+// EmitAttribute specialisation that yields a value only when `attr` has type
+// f32; returns boost::none otherwise.
+template <>
+boost::optional MlirToRuntimeTranslator::EmitAttribute(
+    const mlir::Attribute* attr) {
+  if (!attr->isa()) return boost::none;
+  if (attr->isa()) {
+    auto val = attr->cast();
+    if (val.getType().isF32()) return val.getValueAsDouble();
+  }
+  return boost::none;
+}
+
+// EmitAttribute specialisation that yields a value only when `attr` has type
+// f64; returns boost::none otherwise.
+template <>
+boost::optional MlirToRuntimeTranslator::EmitAttribute(
+    const mlir::Attribute* attr) {
+  if (!attr->isa()) return boost::none;
+  if (attr->isa()) {
+    auto val = attr->cast();
+    if (val.getType().isF64()) return val.getValueAsDouble();
+  }
+  return boost::none;
+}
+
+// EmitAttribute specialisation returning the attribute's value as a string
+// (via getValue().str()), or boost::none when the attribute kind differs.
+template <>
+boost::optional MlirToRuntimeTranslator::EmitAttribute(
+    const mlir::Attribute* attr) {
+  if (!attr->isa()) return boost::none;
+  return attr->cast().getValue().str();
+}
+
+// Generates an EmitAttribute specialisation for an array of `bits__`-wide
+// integers. The element type is decided from the FIRST element only, and an
+// empty array aborts through CHECK.
+#define PROCESS_ARRAY_INT(type__, bits__)                                  \
+  template <>                                                              \
+  boost::optional> MlirToRuntimeTranslator::EmitAttribute(                 \
+      const mlir::Attribute* attr) {                                       \
+    if (!attr->isa()) return boost::none;                                  \
+    auto array = attr->cast();                                             \
+    CHECK(!array.empty());                                                 \
+                                                                           \
+    if (!array[0].getType().isInteger(bits__)) {                           \
+      return boost::none;                                                  \
+    }                                                                      \
+                                                                           \
+    std::vector res;                                                       \
+    for (auto& v : array) {                                                \
+      res.push_back(v.cast().getInt());                                    \
+    }                                                                      \
+    return res;                                                            \
+  }
+
+PROCESS_ARRAY_INT(int16_t, 16);
+PROCESS_ARRAY_INT(int32_t, 32);
+PROCESS_ARRAY_INT(int64_t, 64);
+
+// Array-of-f32 specialisation; same first-element dispatch and non-empty
+// requirement as PROCESS_ARRAY_INT.
+template <>
+boost::optional> MlirToRuntimeTranslator::EmitAttribute(
+    const mlir::Attribute* attr) {
+  if (!attr->isa()) return boost::none;
+  auto array = attr->cast();
+  CHECK(!array.empty());
+
+  if (!array[0].getType().isF32()) return boost::none;
+
+  std::vector res;
+  for (auto& v : array) {
+    res.push_back(v.cast().getValueAsDouble());
+  }
+  return res;
+}
+
+// Array-of-f64 specialisation; same first-element dispatch and non-empty
+// requirement as PROCESS_ARRAY_INT.
+template <>
+boost::optional> MlirToRuntimeTranslator::EmitAttribute(
+    const mlir::Attribute* attr) {
+  if (!attr->isa()) return boost::none;
+  auto array = attr->cast();
+  CHECK(!array.empty());
+
+  if (!array[0].getType().isF64()) return boost::none;
+
+  std::vector res;
+  for (auto& v : array) {
+    res.push_back(v.cast().getValueAsDouble());
+  }
+  return res;
+}
+
+// True when `op` is the "infrt.return" terminator.
+static bool IsReturn(mlir::Operation* op) {
+  return op->getName().getStringRef() == "infrt.return";
+}
+
+// Translate an arbitrary op into a new OpExecutable on the current runtime.
+//
+// Steps: create the executable (stored in impl_->cur_op), bind each operand
+// to an existing runtime Value, create a fresh Value for each result, convert
+// every attribute through the EmitAttribute specialisations, and finally wrap
+// an attached region (at most one, with exactly one block) as a function
+// executable that is appended as an extra attribute.
+// Aborts on an operand with no runtime value, an unsupported attribute, or an
+// unsupported region layout. Always returns true.
+bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
+  CHECK(impl_->runtime);
+  impl_->cur_op =
+      impl_->runtime->NewOpExecutable(op->getName().getStringRef().str());
+
+  VLOG(3) << "processing general op : " << op->getName().getStringRef().str();
+
+  // process operands
+  for (int i = 0, e = op->getNumOperands(); i < e; i++) {
+    // function argument as value
+    auto operand = op->getOperand(i);
+    if (operand.getKind() == mlir::Value::Kind::BlockArgument) {
+      mlir::BlockArgument arg = operand.dyn_cast();
+      // NOTE(review): unlike the normal-value path below, a missing mapping
+      // is appended as nullptr here without a CHECK -- confirm intended.
+      Value* arg_value = GetValue(arg);
+      impl_->cur_op->AppendArgument(arg_value);
+      VLOG(3) << "* op mlir operand: " << DumpToString(arg) << " "
+              << GetValue(arg);
+      continue;
+    }
+
+    // normal value: first look in value_map, then fall back to the results
+    // recorded for the defining op (constants / build_shape).
+    Value* arg_value = GetValue(operand);
+    if (!arg_value) {
+      auto upstream_op = operand.getDefiningOp();
+      arg_value = GetOpResult(upstream_op);
+    }
+    CHECK(arg_value) << "No-exist argument value found: "
+                     << DumpToString(operand);
+    impl_->cur_op->AppendArgument(arg_value);
+
+    VLOG(3) << "* op mlir operand: " << DumpToString(operand) << " "
+            << GetValue(operand) << " vs " << arg_value;
+  }
+
+  // process results
+  llvm::SmallVector res_values;
+  for (int i = 0, e = op->getNumResults(); i < e; i++) {
+    auto res = op->getResult(i);
+    res_values.push_back(AddValue(res));
+
+    VLOG(3) << "* op mlir res: " << DumpToString(res) << " " << GetValue(res);
+  }
+  impl_->cur_op->SetResults(res_values);
+
+#ifdef INFRT_DEBUG
+  {
+    VLOG(3) << "check result";
+    for (int i = 0; i < impl_->cur_op->frame().GetNumResults(); i++) {
+      VLOG(3) << "+ res value: " << impl_->cur_op->frame().GetResults()[i];
+    }
+  }
+#endif
+
+  // process attributes: the first specialisation that accepts the attribute
+  // wins, so the order of this chain is significant.
+  auto attrs = op->getAttrs();
+
+  for (size_t i = 0; i < attrs.size(); i++) {
+    auto& attr = attrs[i];
+    if (auto v = EmitAttribute(&attr.second)) {
+      impl_->cur_op->AppendAttribute(new Value(*v));
+    } else if (auto v = EmitAttribute(&attr.second)) {
+      impl_->cur_op->AppendAttribute(new Value(*v));
+    } else if (auto v = EmitAttribute(&attr.second)) {
+      impl_->cur_op->AppendAttribute(new Value(*v));
+    } else if (auto v = EmitAttribute(&attr.second)) {
+      impl_->cur_op->AppendAttribute(new Value(*v));
+    } else if (auto v = EmitAttribute(&attr.second)) {
+      impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
+    } else if (auto v = EmitAttribute>(&attr.second)) {
+      impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
+    } else if (auto v = EmitAttribute>(&attr.second)) {
+      impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
+    } else if (auto v = EmitAttribute>(&attr.second)) {
+      impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
+    } else if (auto v = EmitAttribute>(&attr.second)) {
+      impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
+    } else if (auto v = EmitAttribute>(&attr.second)) {
+      impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
+    } else {
+      LOG(FATAL) << "Not supported attribute type";
+    }
+  }
+
+  // process regions, we treat regions as attribute.
+  auto num_regions = op->getNumRegions();
+  if (num_regions > 0) {
+    CHECK_EQ(num_regions, 1UL)
+        << "op with more than one region is not supported yet.";
+    auto& region = op->getRegions().front();
+    auto num_blocks = region.getBlocks().size();
+    CHECK_EQ(num_blocks, 1UL)
+        << "region with more than one block is not supported yet.";
+
+    // process arguments
+    llvm::SmallVector inputs;
+    auto& block = region.getBlocks().front();
+    for (auto arg : block.getArguments()) inputs.push_back(arg.getType());
+
+    // process results
+    // NOTE: if an op contains a region, we simply ignore the region's return
+    // values,
+    // or its return values will conflict with op's return values.
+    llvm::SmallVector results;
+
+    auto func_type =
+        mlir::FunctionType::get(inputs, results, region.getContext());
+    auto* function = impl_->cur_op->CreateFunctionExecutable(
+        &region, func_type, &impl_->func_defs);
+    impl_->cur_op->AppendAttribute(new Value(function));
+  }
+
+  return true;
+}
+
+// If `op` is "infrt.return", append all of its operands to `results` and
+// return true; otherwise leave `results` untouched and return false.
+bool MlirToRuntimeTranslator::EmitReturnOp(
+    mlir::Operation* op, llvm::SmallVectorImpl* results) {
+  CHECK(results);
+  // Share the op-name test with IsReturn() instead of repeating the literal.
+  if (!IsReturn(op)) return false;
+
+  for (size_t i = 0; i < op->getNumOperands(); i++) {
+    results->push_back(op->getOperand(i));
+  }
+  return true;
+}
+
+// Call EmitFunction() for every function op in the module. Always true.
+bool MlirToRuntimeTranslator::EmitFunctions() {
+  for (auto func_op : impl_->module.getOps()) {
+    EmitFunction(func_op);
+  }
+  return true;
+}
+
+// Default behaviour: record the function under its name, silently replacing
+// any earlier definition with the same name.
+void MlirToRuntimeTranslator::EmitFunction(mlir::FuncOp op) {
+  impl_->func_defs[op.getName().str()] = op;
+}
+
+// First recorded result of `op`, or nullptr when nothing was recorded.
+// Every writer in this file stores exactly one element per op.
+Value* MlirToRuntimeTranslator::GetOpResult(mlir::Operation* op) {
+  auto it = impl_->op_results.find(op);
+  return it == impl_->op_results.end() ? nullptr : it->second.front().get();
+}
+
+// Runtime value mapped to `value`, or nullptr when it has no mapping yet.
+Value* MlirToRuntimeTranslator::GetValue(mlir::Value value) {
+  auto it = impl_->value_map.find(value);
+  return it == impl_->value_map.end() ? nullptr : it->second.get();
+}
+
+// Create a fresh runtime value for `value`; aborts if one already exists.
+Value* MlirToRuntimeTranslator::AddValue(mlir::Value value) {
+  auto res = impl_->value_map.try_emplace(value, ValueRef(new Value));
+  CHECK(res.second) << "Duplicate add mlir value [" << DumpToString(value)
+                    << "]";
+  return res.first->second.get();
+}
+
+MlirToRuntimeTranslator::~MlirToRuntimeTranslator() {}
+
+void MlirToRuntimeTranslator::UpdateCurFuncName(const std::string& name) {
+  // Plain assignment already copies; no temporary std::string is needed.
+  impl_->cur_func_name = name;
+}
+
+// Translator bound to a module; `runtime` must be non-null and is not owned.
+MlirToRuntimeTranslator::MlirToRuntimeTranslator(mlir::ModuleOp module,
+                                                 CoreRuntimeBuilder* runtime)
+    : impl_(new Impl) {
+  CHECK(runtime);
+  impl_->module = module;
+  impl_->runtime = runtime;
+}
+
+// Handle "ts.build_shape" at translation time: read the array stored in the
+// "value" attribute and record a TensorShape built from it as the op result.
+// Returns false for any other op.
+bool MlirToRuntimeTranslator::EmitBuildShapeOp(mlir::Operation* op) {
+  if (op->getName().getStringRef() != "ts.build_shape") return false;
+
+  auto value = op->getAttr("value");
+
+  CHECK(value.isa());
+  auto values = value.cast().getValue();
+  std::vector dims;
+  for (auto& attr_v : values) {
+    dims.push_back(attr_v.cast().getInt());
+  }
+  impl_->op_results[op] = {
+      ValueRef(new Value(tensor::TensorShape(llvm::ArrayRef(dims))))};
+
+  return true;
+}
+
+// Translate an "infrt.call" op: bind operands and results like a general op,
+// then look the callee up in `function_table` and attach it as a function
+// executable attribute. Returns false for any other op; aborts when the
+// callee attribute is missing/unexpected or the function is unknown.
+bool MlirToRuntimeTranslator::EmitCallOp(mlir::Operation* op,
+                                         function_defs_t* function_table) {
+  CHECK(op);
+  CHECK(function_table);
+  if (op->getName().getStringRef() != "infrt.call") return false;
+
+  impl_->cur_op =
+      impl_->runtime->NewOpExecutable(op->getName().getStringRef().str());
+
+  // Fail with a clear message instead of dereferencing a null attribute
+  // further down when the op is malformed.
+  auto callee = op->getAttr("callee");
+  CHECK(callee) << "infrt.call op has no [callee] attribute";
+  auto callee_name = callee.dyn_cast();
+  CHECK(callee_name) << "unexpected attribute kind for [callee] of infrt.call";
+
+  // process arguments
+  for (size_t i = 0; i < op->getNumOperands(); i++) {
+    auto operand = op->getOperand(i);
+    auto* arg_value = GetValue(operand);
+
+    if (!arg_value) {
+      auto upstream_op = operand.getDefiningOp();
+      arg_value = GetOpResult(upstream_op);
+    }
+    CHECK(arg_value) << "No-exist argument value found: "
+                     << DumpToString(operand);
+    impl_->cur_op->AppendArgument(arg_value);
+  }
+
+  // process results
+  llvm::SmallVector res_values;
+  for (int i = 0, e = op->getNumResults(); i < e; i++) {
+    auto res = op->getResult(i);
+    res_values.push_back(AddValue(res));
+  }
+  impl_->cur_op->SetResults(res_values);
+
+  // process attribute
+  auto& table = function_table ? *function_table : impl_->func_defs;
+  {
+    // lookup the callee function
+    auto it = table.find(callee_name.getValue().str());
+    CHECK(it != table.end()) << "can't find function ["
+                             << callee_name.getValue().str() << "]";
+    auto* function =
+        impl_->cur_op->CreateFunctionExecutable(it->second, &impl_->func_defs);
+    impl_->cur_op->AppendAttribute(new Value(function));
+  }
+
+  VLOG(3) << "Emit call " << callee_name.getValue().str() << " "
+          << impl_->cur_op->frame();
+  return true;
+}
+
+// Translator without a module; `runtime` must be non-null and is not owned.
+MlirToRuntimeTranslator::MlirToRuntimeTranslator(CoreRuntimeBuilder* runtime)
+    : impl_(new Impl) {
+  CHECK(runtime);
+  impl_->runtime = runtime;
+}
+
+// Map `mlir_value` to an existing runtime `value`; aborts on a duplicate.
+// The pointer is wrapped in a ValueRef that is stored in value_map.
+Value* MlirToRuntimeTranslator::AddValue(mlir::Value mlir_value, Value* value) {
+  auto it = impl_->value_map.try_emplace(mlir_value, ValueRef(value));
+  CHECK(it.second) << "duplicate add value " << DumpToString(mlir_value);
+  return value;
+}
+
+// Convenience wrapper: build a temporary translator and run it.
+void MlirToRuntimeTranslate(mlir::ModuleOp module,
+                            CoreRuntimeBuilder* runtime) {
+  MlirToRuntimeTranslator(module, runtime).Run();
+}
+
+/**
+ * Execute the mlir program in test mode -- print some debug information to
+ * stdout.
+ */
+class MlirProgramTestExecutor : public MlirToRuntimeTranslator {
+ public:
+  CoreRuntimeBuilder core_runtime;
+
+  // The base class only stores the address of `core_runtime` (it is not yet
+  // constructed at that point), so passing it before initialisation is fine.
+  MlirProgramTestExecutor(mlir::ModuleOp module, KernelRegistry* registry)
+      : MlirToRuntimeTranslator(module, &core_runtime),
+        core_runtime(registry),
+        registry(registry) {
+    CHECK(registry);
+  }
+
+  // Collect all function definitions, then emit and execute every function
+  // that takes no arguments. Hides (does not override) the base Run().
+  void Run() {
+    EmitFunctions();
+
+    CHECK(registry);
+    for (auto func_op : impl_->module.getOps()) {
+      VLOG(3) << "Running function " << func_op.getName().str();
+      EmitAndRunFuncWithoutArguments(func_op);
+    }
+  }
+
+ protected:
+  // NOTE(review): not referenced anywhere in this class -- candidate for
+  // removal.
+  std::unordered_map func_def_table;
+
+  // Stricter than the base implementation: a duplicate name is fatal.
+  void EmitFunction(mlir::FuncOp op) override {
+    CHECK(!impl_->func_defs.count(op.getName().str()))
+        << "Duplicate function definition found for function ["
+        << op.getName().str() << "]";
+    impl_->func_defs.emplace(op.getName().str(), op);
+  }
+
+ private:
+  // Print "@<name>" for the function, and if it takes no arguments translate
+  // its single block op by op into a fresh runtime and execute it.
+  void EmitAndRunFuncWithoutArguments(mlir::FuncOp func) {
+    // print the function name for llvm FileChecker macro, CHECK-LABEL
+    std::cout << '@' << func.getName().str() << std::endl;
+    if (func.getNumArguments() ==
+        0) {  // an entry function, execute it immediately
+      VLOG(3) << "executing function " << func.getName().str();
+      // Emit and execute each function
+      CoreRuntimeBuilder runtime(registry);
+      impl_->runtime = &runtime;
+
+      auto& blocks = func.getBlocks();
+      CHECK_EQ(blocks.size(), 1UL)
+          << "function with more than one block is not supported yet";
+
+      // Try the specialised emitters first; EmitGeneralOp is the fallback.
+      for (auto& op : blocks.front()) {
+        if (EmitConstantOp(&op)) continue;
+        if (EmitBuildShapeOp(&op)) continue;
+        llvm::SmallVector results;
+        if (EmitReturnOp(&op, &results)) continue;
+        if (EmitCallOp(&op, &impl_->func_defs)) continue;
+        if (EmitGeneralOp(&op)) continue;
+        LOG(FATAL) << "Not supported op: " << DumpToString(op);
+      }
+
+      runtime.Execute();
+
+      // `runtime` is destroyed at the end of this scope. Point the translator
+      // back at the long-lived builder so impl_->runtime never dangles.
+      impl_->runtime = &core_runtime;
+    } else {
+      VLOG(2) << "get an callable function: " << func.getName().str();
+    }
+  }
+
+ private:
+  KernelRegistry* registry{};
+};
+
+// Run every argument-less function of `module` using kernels from `registry`.
+void TestMlir(mlir::ModuleOp module, KernelRegistry* registry) {
+  MlirProgramTestExecutor execute(module, registry);
+  execute.Run();
+}
+
+}  // namespace infrt::host_context
diff --git a/paddle/infrt/host_context/mlir_to_runtime_translate.h b/paddle/infrt/host_context/mlir_to_runtime_translate.h
new file mode 100644
index 0000000000..598e81bfd9
--- /dev/null
+++ b/paddle/infrt/host_context/mlir_to_runtime_translate.h
@@ -0,0 +1,107 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include
+
+#include
+#include  // NOLINT
+#include  //NOLINT
+#include  // NOLINT
+
+namespace mlir {
+class FuncOp;
+class ModuleOp;
+class Operation;
+class Attribute;
+class Value;
+}  // namespace mlir
+
+namespace infrt::host_context {
+
+class CoreRuntimeBuilder;
+class Value;
+class ValueRef;
+class KernelRegistry;
+
+/**
+ * MlirToRuntimeTranslator helps to translate an MLIR program to a CoreRuntime.
+ * This is the base class of all the modules that parse an MLIR program and
+ * finally generate a CoreRuntime.
+ */
+class MlirToRuntimeTranslator {
+ public:
+  //! Holds all the function definitions.
+  using function_defs_t = std::unordered_map;
+
+  explicit MlirToRuntimeTranslator(CoreRuntimeBuilder* runtime);
+  MlirToRuntimeTranslator(mlir::ModuleOp module, CoreRuntimeBuilder* runtime);
+
+  //! Non-virtual entry point: just collects the function definitions.
+  void Run() { EmitFunctions(); }
+
+  virtual ~MlirToRuntimeTranslator();
+
+ protected:
+  //! Emit an "infrt.constant.*" operation, return true on success.
+  bool EmitConstantOp(mlir::Operation* op);
+  //! Emit an "infrt.return" operation.
+  bool EmitReturnOp(mlir::Operation* op,
+                    llvm::SmallVectorImpl* results);
+  //! Emit a "ts.build_shape" operation.
+  bool EmitBuildShapeOp(mlir::Operation* op);
+  //! Emit an operation other than the special cases above.
+  bool EmitGeneralOp(mlir::Operation* op);
+  //! Emit all the functions.
+  bool EmitFunctions();
+
+  //! Emit a single function, this is an API that should be implemented by
+  //! inheritors.
+  virtual void EmitFunction(mlir::FuncOp op);
+
+  //! Emit an "infrt.call" operation, resolving the callee in function_table.
+  bool EmitCallOp(mlir::Operation* op, function_defs_t* function_table);
+
+  //! Convert an attribute to the requested type, or boost::none if the
+  //! attribute does not hold that type.
+  template
+  boost::optional EmitAttribute(const mlir::Attribute* attr);
+
+  Value* GetOpResult(mlir::Operation* op);
+
+  Value* GetValue(mlir::Value value);
+
+  Value* AddValue(mlir::Value value);
+
+  Value* AddValue(mlir::Value mlir_value, Value* value);
+
+  void UpdateCurFuncName(const std::string& name);
+
+ protected:
+  struct Impl;
+  std::unique_ptr impl_;
+};
+
+/**
+ * Build a CoreRuntime from an MLIR module.
+ */
+void MlirToRuntimeTranslate(mlir::ModuleOp module, CoreRuntimeBuilder* runtime);
+
+/**
+ * Execute an MLIR program, that is, execute all the functions without input
+ * arguments.
+ * This is mainly used by testcases.
+ * @param module an MLIR module.
+ * @param registry the kernel registry containing all the valid kernels.
+ */
+void TestMlir(mlir::ModuleOp module, KernelRegistry* registry);
+
+}  // namespace infrt::host_context
diff --git a/paddle/infrt/host_context/mlir_to_runtime_translate_test.cc b/paddle/infrt/host_context/mlir_to_runtime_translate_test.cc
new file mode 100644
index 0000000000..9b85be977a
--- /dev/null
+++ b/paddle/infrt/host_context/mlir_to_runtime_translate_test.cc
@@ -0,0 +1,160 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/infrt/host_context/mlir_to_runtime_translate.h"
+
+#include
+#include
+
+#include "paddle/infrt/common/global.h"
+#include "paddle/infrt/dialect/mlir_loader.h"
+#include "paddle/infrt/host_context/core_runtime.h"
+#include "paddle/infrt/host_context/kernel_registry.h"
+#include "paddle/infrt/host_context/kernel_utils.h"
+#include "paddle/infrt/host_context/mlir_program_executor.h"
+#include "paddle/infrt/kernel/basic_kernels.h"
+#include "paddle/infrt/kernel/control_flow_kernels.h"
+#include "paddle/infrt/kernel/tensor_kernels.h"
+#include "paddle/infrt/kernel/tensor_shape_kernels.h"
+#include "paddle/infrt/kernel/test_kernels.h"
+
+namespace infrt::host_context {
+
+// Smoke test: a tiny f32 add/mul/print program must load and run via TestMlir.
+// NOTE(review): the body is identical to TEST(TestMlir, basic) below.
+TEST(MlirToRuntimeTranslate, basic) {
+  mlir::MLIRContext context;
+
+  auto source = R"ROC(
+func @main() -> () {
+  %v0 = infrt.constant.f32 1.0
+  %v1 = infrt.constant.f32 2.0
+  %v2 = "infrt.add.f32"(%v0, %v1) : (f32, f32) -> f32
+  %v3 = "infrt.mul.f32"(%v2, %v1) : (f32, f32) -> f32
+
+  "infrt.print.f32"(%v1) : (f32) -> ()
+
+  infrt.return
+}
+)ROC";
+
+  auto module = dialect::LoadMlirSource(&context, source);
+  // NOTE(review): the result of verify() is discarded -- confirm intended.
+  module->verify();
+
+  KernelRegistry registry;
+  kernel::RegisterFloatBasicKernels(&registry);
+  kernel::RegisterIntBasicKernels(&registry);
+
+  TestMlir(module.get(), &registry);
+}
+
+TEST(TestMlir, basic) {
+  mlir::MLIRContext context;
+
+  auto source = R"ROC(
+func @main() -> () {
+  %v0 = infrt.constant.f32 1.0
+  %v1 = infrt.constant.f32 2.0
+  %v2 = "infrt.add.f32"(%v0, %v1) : (f32, f32) -> f32
+  %v3 = "infrt.mul.f32"(%v2, %v1) : (f32, f32) -> f32
+
+  "infrt.print.f32"(%v1) : (f32) -> ()
+
+  infrt.return
+}
+)ROC";
+
+  auto module = dialect::LoadMlirSource(&context, source);
+  module->verify();
+
+  KernelRegistry registry;
+  kernel::RegisterFloatBasicKernels(&registry);
+  kernel::RegisterIntBasicKernels(&registry);
+
+  TestMlir(module.get(), &registry);
+}
+
+// Builds a `predict` function with 2 x 2000 generated shallow-copy ops and
+// executes it 500 times; there are no assertions on the outputs, so this only
+// checks that building and repeated execution do not crash.
+TEST(TestMlir, shadow_copy_tensor_profile) {
+  mlir::MLIRContext* context = infrt::Global::getMLIRContext();
+
+  auto head = R"ROC(
+func @predict(%a: !infrt.tensor, %b: !infrt.tensor) -> (!infrt.tensor, !infrt.tensor) {
+)ROC";
+
+  // `{0}` is substituted by llvm::formatv with the loop index below.
+  auto tpl0 =
+      "%a{0} = dt.shallow_copy_tensor %a : !infrt.tensor -> "
+      "!infrt.tensor";
+  auto tpl1 =
+      "%b{0} = dt.shallow_copy_tensor %b : !infrt.tensor -> "
+      "!infrt.tensor";
+
+  auto end = R"ROC(
+infrt.return %a0, %b0: !infrt.tensor, !infrt.tensor
+}
+  )ROC";
+
+  std::stringstream ss;
+  ss << head;
+  for (int i = 0; i < 2000; i++) {
+    ss << llvm::formatv(tpl0, i).str() << "\n";
+    ss << llvm::formatv(tpl1, i).str() << "\n";
+  }
+  ss << end;
+
+  auto content = ss.str();
+
+  // LOG(INFO) << "content: " << content << std::endl;
+
+  auto module = dialect::LoadMlirSource(context, content);
+  module->verify();
+
+  host_context::KernelRegistry registry;
+
+  kernel::RegisterBasicKernels(&registry);
+  kernel::RegisterTestKernels(&registry);
+  kernel::RegisterTensorShapeKernels(&registry);
+  kernel::RegisterTensorKernels(&registry);
+  kernel::RegisterControlFlowKernels(&registry);
+
+  MlirProgramExecutor executor(*module, &registry);
+  executor.BuildFunctions();
+
+  auto* func = executor.LookupFunc("predict");
+  ASSERT_TRUE(func);
+
+  std::vector in_args;
+  std::vector out_args(
+      {ValueRef(new Value(tensor::DenseHostTensor())),
+       ValueRef(new Value(tensor::DenseHostTensor()))});
+
+  // 200 x 3000 F32 tensor whose element i holds the value i.
+  auto create_tensor = [] {
+    tensor::DenseHostTensor a(tensor::TensorShape{{200, 3000}},
+                              DType(DType::Kind::F32));
+    auto* data = reinterpret_cast(a.raw_data());
+    for (int i = 0; i < a.shape().GetNumElements(); i++) {
+      data[i] = i;
+    }
+    return a;
+  };
+
+  // `inputs` owns the tensors; `in_args` only borrows the raw pointers.
+  std::vector inputs({ValueRef(new Value(create_tensor())),
+                      ValueRef(new Value(create_tensor()))});
+  in_args.assign({inputs[0].get(), inputs[1].get()});
+
+  for (int i = 0; i < 500; i++) {
+    func->Execute(
+        llvm::ArrayRef(in_args.data(), in_args.size()),
+        llvm::MutableArrayRef(out_args.data(), out_args.size()));
+  }
+}
+
+}  // namespace infrt::host_context
diff --git a/paddle/infrt/host_context/op_executable.cc b/paddle/infrt/host_context/op_executable.cc
new file mode 100644
index 0000000000..6b10ed4737
--- /dev/null
+++ b/paddle/infrt/host_context/op_executable.cc
@@ -0,0 +1,151 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/infrt/host_context/op_executable.h"
+
+#include
+
+#include "paddle/infrt/host_context/kernel_frame.h"
+#include "paddle/infrt/host_context/kernel_registry.h"
+#include "paddle/infrt/host_context/mlir_function_executable.h"
+#include "paddle/infrt/host_context/symbol_table.h"
+
+namespace infrt::host_context {
+
+// Private state of an OpExecutable (pimpl).
+struct OpExecutable::Impl {
+  // A null `kernel_registry` selects the CPU registry.
+  Impl(const std::string& op_name,
+       SymbolTable* symbol_table,
+       KernelRegistry* kernel_registry)
+      : name(op_name),
+        symbol_table(symbol_table),
+        kernel_registry(kernel_registry ? kernel_registry
+                                        : GetCpuKernelRegistry()) {
+    // Inside the constructor body the bare name is the PARAMETER, which is
+    // legitimately null when the default is requested; validate the member
+    // that was actually resolved above.
+    CHECK(this->kernel_registry);
+  }
+
+  // An op runs every time unless it is flagged run-once and has already run.
+  inline bool to_execute() const { return !run_once || !has_executed; }
+  inline void MarkRun() { has_executed = true; }
+
+  std::string name;
+  SymbolTable* symbol_table{};
+  KernelFrameBuilder frame;
+  KernelRegistry* kernel_registry{};
+
+  std::unique_ptr mlir_function_executable;
+
+  KernelImplementation kernel_impl{};
+
+  //! Tell whether this Op should be executed only once.
+  bool run_once{};
+  //! Tell whether this op has been executed.
+  bool has_executed{};
+};
+
+// Takes ownership of `impl`.
+OpExecutable::OpExecutable(OpExecutable::Impl* impl) : impl_(impl) {}
+
+const std::string& OpExecutable::name() const { return impl_->name; }
+
+// Resolve the kernel named `op_name` in the registry; aborts when no such
+// kernel exists. "dt.get_param" is the only op marked as run-once here.
+OpExecutableBuilder::OpExecutableBuilder(const std::string& op_name,
+                                         SymbolTable* symbol_table,
+                                         KernelRegistry* kernel_registry)
+    : OpExecutable(new Impl(op_name, symbol_table, kernel_registry)) {
+  CHECK(impl_);
+  // CPU kernel registry is the default KernelRegistry.
+  impl_->kernel_impl = impl_->kernel_registry->GetKernel(
+      std::string(op_name.data(), op_name.size()));
+  // TODO(Superjomn) support other device other than CPU.
+  CHECK(impl_->kernel_impl) << "No CPU kernel called " << op_name;
+
+  if (op_name == "dt.get_param") {
+    impl_->run_once = true;
+  }
+}
+
+// Append the symbol `name` as an argument, registering it in the symbol table
+// first when it does not exist yet.
+void OpExecutableBuilder::AppendArgument(const std::string& name) {
+  if (!impl_->symbol_table->GetValue(name)) {
+    impl_->symbol_table->Register(name);
+  }
+  impl_->frame.AddArgument(impl_->symbol_table->GetValue(name));
+}
+
+void OpExecutableBuilder::AppendArgument(Value* value) {
+  impl_->frame.AddArgument(value);
+}
+
+KernelFrame& OpExecutable::frame() { return impl_->frame; }
+const KernelFrame& OpExecutable::frame() const { return impl_->frame; }
+
+// Register each result name as a new symbol and use those values as results.
+void OpExecutableBuilder::SetResults(llvm::ArrayRef result_names) {
+  llvm::SmallVector results;
+  for (size_t result_id = 0; result_id < result_names.size(); result_id++) {
+    Value* value = impl_->symbol_table->Register(result_names[result_id]);
+    results.push_back(value);
+  }
+  impl_->frame.SetResults(results);
+}
+
+void OpExecutableBuilder::SetResults(llvm::ArrayRef results) {
+  impl_->frame.SetResults(results);
+}
+
+void OpExecutableBuilder::AppendAttribute(Value* value) {
+  impl_->frame.AddAttribute(value);
+}
+
+// Steals the implementation; `other` is left without one.
+OpExecutableBuilder::OpExecutableBuilder(OpExecutableBuilder&& other)
+    : OpExecutable(other.impl_.release()) {}
+
+// Create the (single) function executable owned by this op from a function
+// definition. May be called at most once per op.
+MlirFunctionExecutable* OpExecutableBuilder::CreateFunctionExecutable(
+    mlir::FuncOp op, MlirToRuntimeTranslator::function_defs_t* function_defs) {
+  CHECK(!impl_->mlir_function_executable);
+  impl_->mlir_function_executable.reset(
+      new MlirFunctionExecutable(op, impl_->kernel_registry, *function_defs));
+  return impl_->mlir_function_executable.get();
+}
+
+// Same as above, but built from a region and an explicit function type.
+MlirFunctionExecutable* OpExecutableBuilder::CreateFunctionExecutable(
+    mlir::Region* region,
+    mlir::FunctionType func_type,
+    function_defs_t* function_defs) {
+  CHECK(!impl_->mlir_function_executable);
+  impl_->mlir_function_executable.reset(new MlirFunctionExecutable(
+      region, func_type, impl_->kernel_registry, *function_defs));
+  return impl_->mlir_function_executable.get();
+}
+
+// Invoke the kernel on this op's frame, unless the op is flagged run-once and
+// has already been executed. Debug builds log the frame layout first.
+void OpExecutable::Execute() {
+  auto& frame = impl_->frame;
+
+#ifndef NDEBUG
+  VLOG(3) << "execute " << name()
+          << " --- frame args: " << frame.GetNumArgs() << " results "
+          << frame.GetNumResults() << " attributes "
+          << frame.GetNumAttributes();
+  for (int arg_id = 0; arg_id < frame.GetNumArgs(); ++arg_id) {
+    VLOG(3) << "function arg: " << frame.GetArgAt(arg_id);
+  }
+  for (int res_id = 0; res_id < frame.GetNumResults(); ++res_id) {
+    VLOG(3) << "function result: " << frame.GetResults()[res_id];
+  }
+#endif
+
+  // Nothing to do for a run-once op that has already run.
+  if (!impl_->to_execute()) return;
+
+  impl_->kernel_impl(&frame);
+  impl_->MarkRun();
+}
+
+OpExecutable::~OpExecutable() {}
+
+}  // namespace infrt::host_context
diff --git a/paddle/infrt/host_context/op_executable.h b/paddle/infrt/host_context/op_executable.h
new file mode 100644
index 0000000000..e2248225a5
--- /dev/null
+++ b/paddle/infrt/host_context/op_executable.h
@@ -0,0 +1,92 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include
+
+#include
+#include
+#include
+
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Region.h"
+
+namespace mlir {
+class FuncOp;
+}  // namespace mlir
+
+namespace infrt::host_context {
+
+class SymbolTable;
+class KernelRegistry;
+class KernelFrame;
+class Value;
+class CoreRuntimeBuilder;
+class MlirFunctionExecutable;
+
+/**
+ * OpExecutable is a runtime executable instance for an operation.
It captures
+ * all the information (Tensors, attributes
+ * and so on) needed for execution.
+ * With the SymbolTable and op definition, it creates and holds a KernelFrame
+ * once and can then be executed any number of times.
+ */
+class OpExecutable {
+ public:
+  KernelFrame& frame();
+  const KernelFrame& frame() const;
+
+  void Execute();
+
+  const std::string& name() const;
+
+  ~OpExecutable();
+
+ protected:
+  // Declared with `struct` to match the definition `struct OpExecutable::Impl`
+  // in op_executable.cc and avoid a class/struct tag mismatch.
+  struct Impl;
+  // Takes ownership of `impl`.
+  explicit OpExecutable(Impl* impl);
+
+  std::unique_ptr impl_;
+};
+
+/**
+ * Builder to help construct an OpExecutable.
+ */
+class OpExecutableBuilder : public OpExecutable {
+ public:
+  using function_defs_t = std::unordered_map;
+
+  // A null `kernel_registry` selects the default (CPU) kernel registry.
+  OpExecutableBuilder(const std::string& op_name,
+                      SymbolTable* symbol_table,
+                      KernelRegistry* kernel_registry = nullptr);
+  OpExecutableBuilder(OpExecutableBuilder&& other);
+
+  // Append an argument, either by symbol name or as an existing value.
+  void AppendArgument(const std::string& name);
+  void AppendArgument(Value* value);
+
+  // Set the results, either by symbol names or as existing values.
+  void SetResults(llvm::ArrayRef result_names);
+  void SetResults(llvm::ArrayRef results);
+
+  void AppendAttribute(Value* value);
+
+  // Create the function executable owned by this op; the returned pointer is
+  // non-owning.
+  MlirFunctionExecutable* CreateFunctionExecutable(
+      mlir::FuncOp op, function_defs_t* function_defs);
+
+  MlirFunctionExecutable* CreateFunctionExecutable(
+      mlir::Region* region,
+      mlir::FunctionType func_type,
+      function_defs_t* function_defs);
+};
+
+}  // namespace infrt::host_context
diff --git a/paddle/infrt/host_context/op_executable_test.cc b/paddle/infrt/host_context/op_executable_test.cc
new file mode 100644
index 0000000000..f981cca442
--- /dev/null
+++ b/paddle/infrt/host_context/op_executable_test.cc
@@ -0,0 +1,56 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/infrt/host_context/op_executable.h"
+
+#include
+
+#include "paddle/infrt/host_context/kernel_registry.h"
+#include "paddle/infrt/host_context/kernel_utils.h"
+#include "paddle/infrt/host_context/symbol_table.h"
+
+namespace infrt {
+namespace host_context {
+
+// Plain function registered below as the "infrt.test.add.i32" kernel.
+int add(int a, int b) { return a + b; }
+
+// Registers `add` as a kernel, feeds it the symbols a=1 and b=2 through an
+// OpExecutableBuilder and expects 3 both in the kernel frame and under the
+// symbol "c".
+TEST(OpExecutable, basic) {
+  // register kernel
+  KernelRegistry registry;
+  registry.AddKernel("infrt.test.add.i32", INFRT_KERNEL(add));
+
+  SymbolTable table;
+  table.Register("a", 1);
+  table.Register("b", 2);
+
+  OpExecutableBuilder executable("infrt.test.add.i32", &table, &registry);
+  executable.AppendArgument("a");
+  executable.AppendArgument("b");
+  executable.SetResults({"c"});
+
+  executable.Execute();
+
+  // check the kernel frame has the result.
+  auto results = executable.frame().GetResults();
+  ASSERT_EQ(results.size(), 1UL);
+  ASSERT_EQ(results.front()->get(), 3);
+
+  // check symbol table contains the same result instance.
+  LOG(INFO) << "type: " << table.GetValue("c")->type_info();
+  int c = table.GetValue("c")->get();
+  ASSERT_EQ(c, 3);
+}
+
+}  // namespace host_context
+}  // namespace infrt
diff --git a/paddle/infrt/host_context/symbol_table.cc b/paddle/infrt/host_context/symbol_table.cc
new file mode 100644
index 0000000000..318dc0cc55
--- /dev/null
+++ b/paddle/infrt/host_context/symbol_table.cc
@@ -0,0 +1,82 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/host_context/symbol_table.h" + +#include + +namespace infrt { +namespace host_context { + +struct SymbolTable::Impl { + std::unordered_map data; +}; + +SymbolTable::SymbolTable() : impl_(new Impl) {} + +Value* SymbolTable::Register(const std::string& key) { + CHECK(!impl_->data.count(key)) << "Duplicate register [" << key << "]"; + auto newitem = ValueRef(new Value); + impl_->data.emplace(key, newitem); + return newitem.get(); +} + +Value* SymbolTable::Register(const std::string& key, ValueRef value) { + CHECK(!impl_->data.count(key)) << "Duplicate register [" << key << "]"; + impl_->data.emplace(key, value); + return value.get(); +} + +Value* SymbolTable::GetValue(const std::string& key) const { + auto it = impl_->data.find(std::string(key)); + return it != impl_->data.end() ? 
it->second.get() : nullptr; +} + +// @{ +#define REGISTER_TYPE__(T) \ + template <> \ + T SymbolTable::Get(const std::string& key) { \ + auto it = impl_->data.find(std::string(key)); \ + CHECK(it != impl_->data.end()) << "No value called " << key; \ + return it->second->get(); \ + } +REGISTER_TYPE__(int32_t); +REGISTER_TYPE__(float); +REGISTER_TYPE__(double); +REGISTER_TYPE__(int64_t); +#undef REGISTER_TYPE__ +// @} + +SymbolTable::~SymbolTable() {} + +size_t SymbolTable::size() const { return impl_->data.size(); } + +// @{ +#define REGISTER_TYPE__(T) \ + template <> \ + Value* SymbolTable::Register(const std::string& key, T&& v) { \ + CHECK(!impl_->data.count(key)) << "Duplicate register [" << key << "]"; \ + auto newitem = ValueRef(v); \ + impl_->data.emplace(key, newitem); \ + return newitem.get(); \ + } +REGISTER_TYPE__(int) +REGISTER_TYPE__(float) +REGISTER_TYPE__(double) +REGISTER_TYPE__(bool) +#undef REGISTER_TYPE__ +// @} + +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/symbol_table.h b/paddle/infrt/host_context/symbol_table.h new file mode 100644 index 0000000000..805215a78c --- /dev/null +++ b/paddle/infrt/host_context/symbol_table.h @@ -0,0 +1,65 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include + +#include + +#include "paddle/infrt/host_context/value.h" + +namespace infrt { +namespace host_context { + +/** + * SymbolTable holds all the states of the kernel graph in the runtime. + */ +class SymbolTable { + public: + SymbolTable(); + + /** + * Register a state called \p key. + */ + Value* Register(const std::string& key); + + Value* Register(const std::string& key, ValueRef value); + + /** + * Register a state and set value. + */ + template + Value* Register(const std::string& key, T&& v); + + size_t size() const; + + /** + * Get a state called \p key. + */ + Value* GetValue(const std::string& key) const; + + template + T Get(const std::string& key); + + ~SymbolTable(); + + private: + class Impl; + + std::unique_ptr impl_; +}; + +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/value.cc b/paddle/infrt/host_context/value.cc new file mode 100644 index 0000000000..8c3ccba3d0 --- /dev/null +++ b/paddle/infrt/host_context/value.cc @@ -0,0 +1,69 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/infrt/host_context/value.h" + +#include "paddle/infrt/tensor/dense_tensor_view.h" + +namespace infrt { +namespace host_context { + +ValueRef::ValueRef(int32_t val) : Shared(new Value(val)) {} +ValueRef::ValueRef(int64_t val) : Shared(new Value(val)) {} +ValueRef::ValueRef(float val) : Shared(new Value(val)) {} +ValueRef::ValueRef(double val) : Shared(new Value(val)) {} +ValueRef::ValueRef(bool val) : Shared(new Value(val)) {} + +const char* Value::type_info() const { return __type_info__; } + +void CopyTo(const Value& from, Value* to) { + CHECK(from.valid()) << "from value is not valid, can't be copied"; + CHECK(to) << "to is not valid"; + visit( + [&](auto&& arg) { + using T = std::decay_t; + if (std::is_same::value) + to->data = arg; + else if (std::is_same::value) + to->data = arg; + else if (std::is_same::value) + to->data = arg; + else if (std::is_same::value) + to->data = arg; + else if (std::is_same::value) + to->data = arg; + else if (std::is_same::value) + to->data = arg; + else if (std::is_same::value) + to->data = arg; + else if (std::is_same::value) + to->data = arg; + else if (std::is_same::value) + to->data = arg; + else if (std::is_same::value) + to->data = arg; + else if (std::is_same>::value) + to->data = arg; + else if (std::is_same>::value) + to->data = arg; + else if (std::is_same::value) + to->data = arg; + else + LOG(FATAL) << "Not supported Value copy: " << typeid(T).name(); + }, + from.data); +} + +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/value.h b/paddle/infrt/host_context/value.h new file mode 100644 index 0000000000..4a2b92a7e6 --- /dev/null +++ b/paddle/infrt/host_context/value.h @@ -0,0 +1,156 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +#include +#include +#include + +#include "paddle/infrt/common/object.h" +#include "paddle/infrt/common/shared.h" +#include "paddle/infrt/host_context/function.h" +#include "paddle/infrt/support/variant.h" +#include "paddle/infrt/tensor/dense_host_tensor.h" +#include "paddle/infrt/tensor/dense_tensor_view.h" +#include "paddle/infrt/tensor/tensor_map.h" +#include "paddle/infrt/tensor/tensor_shape.h" + +namespace infrt { +namespace host_context { + +struct MlirFunctionExecutable; + +using ValueVariantType = Variant, + std::vector, + std::vector, + std::vector, + std::vector>; + +//! Copy content from \param from to \param to. +void CopyTo(const Value& from, Value* to); + +/** + * Represents any data type for value in host context. 
+ */ +class Value : public common::Object { + public: + using variant_type = ValueVariantType; + + explicit Value() {} // NOLINT + explicit Value(int32_t x) : data(x) {} + explicit Value(int64_t x) : data(x) {} + explicit Value(float x) : data(x) {} + explicit Value(double x) : data(x) {} + explicit Value(bool x) : data(x) {} + explicit Value(std::string x) : data(x) {} + explicit Value(tensor::TensorMap&& x) : data(x) {} + explicit Value(std::vector&& x) : data(x) {} + explicit Value(std::vector&& x) : data(x) {} + explicit Value(std::vector&& x) : data(x) {} + explicit Value(std::vector&& x) : data(x) {} + explicit Value(std::vector&& x) : data(x) {} + explicit Value(tensor::TensorShape&& x) : data(std::move(x)) {} + explicit Value(tensor::DenseHostTensor&& x) : data(std::move(x)) {} + explicit Value(MlirFunctionExecutable* x) : data(x) {} + + template + const T& get() const { + return data.get(); + } + template + T& get() { + return data.get(); + } + + template + void set(T&& v) { + data = std::move(v); + } + + void set(Value* v) { data = std::move(v->data); } + + bool valid() const { return true; } + + const char* type_info() const override; + + friend void CopyTo(const Value& from, Value* to); + + private: + ValueVariantType data; + static constexpr const char* __type_info__ = "host_context_value"; +}; + +/** + * Represents a counted reference of a Value. + */ +class ValueRef : common::Shared { + public: + ValueRef() = default; + explicit ValueRef(Value* n) : common::Shared(n) {} + explicit ValueRef(int32_t val); + explicit ValueRef(int64_t val); + explicit ValueRef(float val); + explicit ValueRef(double val); + explicit ValueRef(bool val); + + using common::Shared::get; + using common::Shared::Reset; + using common::Shared::operator->; + using common::Shared::operator*; + //! Get a readonly data. + template + const T& get() const { + CHECK(p_); + return p_->get(); + } + + template + T& get() { + CHECK(p_); + return p_->get(); + } + + //! Assign a data. 
+ template + void Assign(const T& x) { + if (!p_) { + p_ = common::make_shared(); + } + *p_ = x; + } + + template + void Assign(Args... args) { + p_ = common::make_shared(std::forward(args)...); + } + + inline bool IsValid() { return p_; } +}; + +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/host_context/value_test.cc b/paddle/infrt/host_context/value_test.cc new file mode 100644 index 0000000000..48d49478ce --- /dev/null +++ b/paddle/infrt/host_context/value_test.cc @@ -0,0 +1,34 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/infrt/host_context/value.h" + +#include + +namespace infrt { +namespace host_context { + +TEST(ValueRef, test) { + ValueRef x(12); + ASSERT_EQ(x.get(), 12); + + ValueRef y(1.2f); + ASSERT_EQ(y.get(), 1.2f); + + ValueRef z(true); + ASSERT_EQ(z.get(), true); +} + +} // namespace host_context +} // namespace infrt diff --git a/paddle/infrt/kernel/CMakeLists.txt b/paddle/infrt/kernel/CMakeLists.txt new file mode 100644 index 0000000000..da858aad28 --- /dev/null +++ b/paddle/infrt/kernel/CMakeLists.txt @@ -0,0 +1,9 @@ +core_gather_headers() + +gather_srcs(infrt_src SRCS + basic_kernels.cc + test_kernels.cc + tensor_shape_kernels.cc + tensor_kernels.cc + control_flow_kernels.cc + ) diff --git a/paddle/infrt/kernel/basic_kernels.cc b/paddle/infrt/kernel/basic_kernels.cc new file mode 100644 index 0000000000..d7f2c38651 --- /dev/null +++ b/paddle/infrt/kernel/basic_kernels.cc @@ -0,0 +1,85 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/infrt/kernel/basic_kernels.h" + +#include +#include + +#include "llvm/Support/raw_ostream.h" +#include "paddle/infrt/host_context/kernel_registry.h" +#include "paddle/infrt/host_context/kernel_utils.h" + +using infrt::host_context::Attribute; + +namespace infrt::kernel { + +template +T add(T a, T b) { + return a + b; +} + +template +T sub(T a, T b) { + return a - b; +} + +template +T mul(T a, T b) { + return a * b; +} + +template +T div(T a, T b) { + return a / b; +} + +template +void print(T a) { + std::cout << a << std::endl; +} + +static std::string GetString(Attribute value) { + return value.get(); +} + +static void PrintString(const std::string &str) { + llvm::outs() << "string = " << str << '\n'; + llvm::outs().flush(); +} + +void RegisterBasicKernels(host_context::KernelRegistry *registry) { + RegisterIntBasicKernels(registry); + RegisterFloatBasicKernels(registry); + registry->AddKernel("infrt.get_string", INFRT_KERNEL(GetString)); + registry->AddKernel("infrt.print_string", INFRT_KERNEL(PrintString)); +} + +void RegisterIntBasicKernels(host_context::KernelRegistry *registry) { + registry->AddKernel("infrt.add.i32", INFRT_KERNEL(add)); + registry->AddKernel("infrt.sub.i32", INFRT_KERNEL(sub)); + registry->AddKernel("infrt.mul.i32", INFRT_KERNEL(mul)); + registry->AddKernel("infrt.div.i32", INFRT_KERNEL(div)); + registry->AddKernel("infrt.print.i32", INFRT_KERNEL(print)); +} + +void RegisterFloatBasicKernels(host_context::KernelRegistry *registry) { + registry->AddKernel("infrt.add.f32", INFRT_KERNEL(add)); + registry->AddKernel("infrt.sub.f32", INFRT_KERNEL(sub)); + registry->AddKernel("infrt.mul.f32", INFRT_KERNEL(mul)); + registry->AddKernel("infrt.div.f32", INFRT_KERNEL(div)); + registry->AddKernel("infrt.print.f32", INFRT_KERNEL(print)); +} + +} // namespace infrt::kernel diff --git a/paddle/infrt/kernel/basic_kernels.h b/paddle/infrt/kernel/basic_kernels.h new file mode 100644 index 0000000000..9e98885cf6 --- /dev/null +++ 
b/paddle/infrt/kernel/basic_kernels.h @@ -0,0 +1,34 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +namespace infrt::host_context { + +struct KernelRegistry; + +} // namespace infrt::host_context + +namespace infrt::kernel { + +/** + * Register all the basic kernels to \p registry. + */ +void RegisterBasicKernels(host_context::KernelRegistry* registry); + +void RegisterIntBasicKernels(host_context::KernelRegistry* registry); +void RegisterFloatBasicKernels(host_context::KernelRegistry* registry); + +} // namespace infrt::kernel diff --git a/paddle/infrt/kernel/control_flow_kernels.cc b/paddle/infrt/kernel/control_flow_kernels.cc new file mode 100644 index 0000000000..6cc94dbcce --- /dev/null +++ b/paddle/infrt/kernel/control_flow_kernels.cc @@ -0,0 +1,44 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/kernel/control_flow_kernels.h" + +#include + +#include "paddle/infrt/host_context/kernel_registry.h" +#include "paddle/infrt/host_context/mlir_function_executable.h" + +namespace infrt { +namespace kernel { + +static void INFRTCall( + host_context::RemainingArguments args, + host_context::RemainingResults results, + host_context::Attribute fn) { + VLOG(3) << "running call kernel ..."; + CHECK_EQ(fn.get()->num_arguments(), args.size()); + CHECK_EQ(fn.get()->num_results(), results.size()); + + for (auto& v : results.values()) { + CHECK(v.get()); + } + fn.get()->Execute(args.values(), results.values()); +} + +void RegisterControlFlowKernels(host_context::KernelRegistry* registry) { + registry->AddKernel("infrt.call", INFRT_KERNEL(INFRTCall)); +} + +} // namespace kernel +} // namespace infrt diff --git a/paddle/infrt/kernel/control_flow_kernels.h b/paddle/infrt/kernel/control_flow_kernels.h new file mode 100644 index 0000000000..5fa6b985f0 --- /dev/null +++ b/paddle/infrt/kernel/control_flow_kernels.h @@ -0,0 +1,31 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/infrt/host_context/function.h" +#include "paddle/infrt/host_context/kernel_utils.h" + +namespace infrt { + +namespace host_context { +struct KernelRegistry; +} // namespace host_context + +namespace kernel { + +void RegisterControlFlowKernels(host_context::KernelRegistry* registry); + +} // namespace kernel +} // namespace infrt diff --git a/paddle/infrt/kernel/tensor_kernels.cc b/paddle/infrt/kernel/tensor_kernels.cc new file mode 100644 index 0000000000..2fa477aa4d --- /dev/null +++ b/paddle/infrt/kernel/tensor_kernels.cc @@ -0,0 +1,79 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/infrt/kernel/tensor_kernels.h" + +#include +#include + +#include "paddle/infrt/common/global.h" +#include "paddle/infrt/host_context/kernel_registry.h" +#include "paddle/infrt/host_context/kernel_utils.h" +#include "paddle/infrt/tensor/dense_host_tensor.h" +#include "paddle/infrt/tensor/dense_tensor_view.h" +#include "paddle/infrt/tensor/tensor_map.h" +#include "paddle/infrt/tensor/tensor_shape.h" + +namespace infrt::kernel { +using namespace host_context; // NOLINT +using namespace tensor; // NOLINT + +/// ===== Kernel begin ==== + +template +DenseHostTensor CreateUninitTensor(Attribute> shape) { + const auto &shape_data = shape.get(); + auto array = llvm::ArrayRef(shape_data.data(), shape_data.size()); + auto type = GetDType(); + return DenseHostTensor(TensorShape(array), type); +} + +void PrintTensor(const DenseHostTensor &tensor) { + std::cout << tensor << std::endl; +} + +template +void FillTensorWithConstant(DenseHostTensor *tensor, Attribute v) { + MutableDTArrayView(tensor).Fill(v.get()); +} + +TensorMap LoadParams(const std::string &path) { + return *(infrt::tensor::LoadParams(path)); +} + +DenseHostTensor GetParam(TensorMap map, Attribute nameAttr) { + auto &name = nameAttr.get(); + return *(map[name]); +} + +DenseHostTensor ShallowCopyTensor(DenseHostTensor v) { return v; } + +/// ===== Kernel end ==== + +void RegisterTensorKernels(host_context::KernelRegistry *registry) { + registry->AddKernel("dt.create_uninit_tensor.f32", + INFRT_KERNEL(CreateUninitTensor)); + registry->AddKernelAttrNameList("dt.create_uninit_tensor.f32", {"shape"}); + registry->AddKernel("dt.print_tensor", INFRT_KERNEL(PrintTensor)); + registry->AddKernel("dt.fill_tensor_with_constant.f32", + INFRT_KERNEL(FillTensorWithConstant)); + registry->AddKernel("dt.fill_tensor_with_constant.f64", + INFRT_KERNEL(FillTensorWithConstant)); + registry->AddKernel("dt.load_params", INFRT_KERNEL(LoadParams)); + registry->AddKernel("dt.get_param", INFRT_KERNEL(GetParam)); + 
registry->AddKernel("dt.shallow_copy_tensor", + INFRT_KERNEL(ShallowCopyTensor)); +} + +} // namespace infrt::kernel diff --git a/paddle/infrt/kernel/tensor_kernels.h b/paddle/infrt/kernel/tensor_kernels.h new file mode 100644 index 0000000000..8f2180ba80 --- /dev/null +++ b/paddle/infrt/kernel/tensor_kernels.h @@ -0,0 +1,25 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace infrt::host_context { +struct KernelRegistry; +} // namespace infrt::host_context + +namespace infrt::kernel { + +void RegisterTensorKernels(host_context::KernelRegistry* registry); + +} // namespace infrt::kernel diff --git a/paddle/infrt/kernel/tensor_shape_kernels.cc b/paddle/infrt/kernel/tensor_shape_kernels.cc new file mode 100644 index 0000000000..a04b492819 --- /dev/null +++ b/paddle/infrt/kernel/tensor_shape_kernels.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/kernel/tensor_shape_kernels.h" + +#include +#include +#include + +#include + +#include "paddle/infrt/host_context/kernel_registry.h" +#include "paddle/infrt/host_context/kernel_utils.h" +#include "paddle/infrt/tensor/tensor_shape.h" + +namespace infrt::kernel { + +void PrintShape(const tensor::TensorShape& shape) { + llvm::raw_os_ostream oos(std::cout); + oos << shape << '\n'; +} + +void RegisterTensorShapeKernels(host_context::KernelRegistry* registry) { + registry->AddKernel("ts.print_shape", INFRT_KERNEL(PrintShape)); +} + +} // namespace infrt::kernel diff --git a/paddle/infrt/kernel/tensor_shape_kernels.h b/paddle/infrt/kernel/tensor_shape_kernels.h new file mode 100644 index 0000000000..e87c6c37e8 --- /dev/null +++ b/paddle/infrt/kernel/tensor_shape_kernels.h @@ -0,0 +1,27 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +namespace infrt::host_context { + +class KernelRegistry; + +} // namespace infrt::host_context + +namespace infrt::kernel { + +void RegisterTensorShapeKernels(host_context::KernelRegistry* registry); + +} // namespace infrt::kernel diff --git a/paddle/infrt/kernel/test_kernels.cc b/paddle/infrt/kernel/test_kernels.cc new file mode 100644 index 0000000000..d5f64d09b6 --- /dev/null +++ b/paddle/infrt/kernel/test_kernels.cc @@ -0,0 +1,200 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/kernel/test_kernels.h" + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "paddle/infrt/host_context/kernel_registry.h" +#include "paddle/infrt/host_context/kernel_utils.h" +#include "paddle/infrt/host_context/mlir_function_executable.h" +#include "paddle/infrt/tensor/dense_host_tensor.h" + +using infrt::host_context::Attribute; +using infrt::host_context::MlirFunctionExecutable; +using infrt::host_context::RemainingArguments; + +namespace infrt::kernel { +namespace { +class BenchmarkStats { + public: + BenchmarkStats(std::string name, + int num_warmup_runs, + int max_count, + std::chrono::microseconds benchmark_duration) + : name_{name}, + num_warmup_runs_{num_warmup_runs}, + max_count_{max_count}, + benchmark_duration_{benchmark_duration} {} + + void StartRun() { + ++cur_count_; + // Start recording CPU time. 
+ cur_start_walltime_ = std::chrono::steady_clock::now(); + cur_start_cpu_ = std::clock(); + } + + void StopRun() { + // Do not collect the runtime statistics if we are still in the warm up + // period. + if (cur_count_ <= num_warmup_runs_) return; + + // Stop the CPU timer. + std::clock_t cur_stop_cpu_ = std::clock(); + + // Stop the wall clock timer. + auto cur_stop_walltime_ = std::chrono::steady_clock::now(); + + // Collect the wall clock duration. + auto duration_walltime_ = cur_stop_walltime_ - cur_start_walltime_; + run_times_walltime_.push_back(duration_walltime_); + + // Collect the CPU duration in nanoseconds (clock ticks scaled by 1e9). + // First cast to an integer nanosecond count with truncation, as + // does std::chrono::duration_cast; then wrap it in the chrono duration type. + std::clock_t duration_cpu_raw = cur_stop_cpu_ - cur_start_cpu_; + auto duration_cpu_ = static_cast( + static_cast(1e9 * duration_cpu_raw / CLOCKS_PER_SEC)); + + run_times_cpu_.push_back(duration_cpu_); + + total_duration_walltime_ += duration_walltime_; + total_duration_cpu_ += duration_cpu_; + } + // Return whether we should run more rounds. + bool MoreRun() const { + return cur_count_ < max_count_ + num_warmup_runs_ && + total_duration_walltime_ < benchmark_duration_; + } + + // Summarize the benchmark results. + void Summarize() { + std::sort(run_times_walltime_.begin(), run_times_walltime_.end()); + std::sort(run_times_cpu_.begin(), run_times_cpu_.end()); + + auto percentile = []( + double p, const std::vector &run_times) { + assert(p >= 0.0 && p <= 1.0); + return run_times[run_times.size() * p]; + }; + + // BM: prefix is added to make grepping results from lit output easier. 
+ std::string prefix; + llvm::raw_string_ostream(prefix) << "BM:" << name_ << ':'; + auto cpu_utilization = + total_duration_cpu_.count() * 100.0 / total_duration_walltime_.count(); + + llvm::outs() << prefix << "Count: " << run_times_walltime_.size() << '\n'; + llvm::outs() << prefix + << "Duration(ns): " << total_duration_walltime_.count() + << '\n'; + llvm::outs() << prefix + << "Time Min(ns): " << run_times_walltime_.front().count() + << '\n'; + llvm::outs() << prefix + << "Time Max(ns): " << run_times_walltime_.back().count() + << '\n'; + llvm::outs() << prefix << "Time 50%(ns): " + << percentile(0.5, run_times_walltime_).count() << '\n'; + llvm::outs() << prefix << "Time 95%(ns): " + << percentile(0.95, run_times_walltime_).count() << '\n'; + llvm::outs() << prefix << "Time 99%(ns): " + << percentile(0.99, run_times_walltime_).count() << '\n'; + // Log CPU time statistics. + llvm::outs() << prefix + << "CPU Duration(ns): " << total_duration_cpu_.count() << '\n'; + llvm::outs() << prefix << "CPU Min(ns): " << run_times_cpu_.front().count() + << '\n'; + llvm::outs() << prefix << "CPU Max(ns): " << run_times_cpu_.back().count() + << '\n'; + llvm::outs() << prefix + << "CPU 50%(ns): " << percentile(0.5, run_times_cpu_).count() + << '\n'; + llvm::outs() << prefix + << "CPU 95%(ns): " << percentile(0.95, run_times_cpu_).count() + << '\n'; + llvm::outs() << prefix + << "CPU 99%(ns): " << percentile(0.99, run_times_cpu_).count() + << '\n'; + llvm::outs() << prefix << "CPU utilization(percent): " << cpu_utilization + << "\n"; + llvm::outs().flush(); + } + + private: + const std::string name_; + const int num_warmup_runs_; + const int max_count_; + int cur_count_ = 0; + const std::chrono::nanoseconds benchmark_duration_; + std::chrono::nanoseconds total_duration_walltime_{}; + std::chrono::nanoseconds total_duration_cpu_{}; + std::chrono::time_point cur_start_walltime_{}; + std::clock_t cur_start_cpu_; + std::vector run_times_walltime_; + // CPU run times in 
microseconds. + std::vector run_times_cpu_; +}; + +} // anonymous namespace + +// This op benchmarks the input function by running the function in a loop +// up to a max count or max time as specified in the function's attributes. +// +// Attributes: +// duration_secs: Benchmark duration in seconds. +// max_count: Max run count of input function. +// name: The name used to tag the benchmark results. +// num_warmup_runs: Number of warm up runs before benchmarking starts. +// fn: The input function to be benchmarked. +static void benchmark(RemainingArguments args, + host_context::RemainingResults results, + Attribute duration_secs, + Attribute max_count, + Attribute name, + Attribute num_warmup_runs, + Attribute fn) { + BenchmarkStats bm_stats{name.get(), + num_warmup_runs.get(), + max_count.get(), + std::chrono::seconds(duration_secs.get())}; + + while (bm_stats.MoreRun()) { + bm_stats.StartRun(); + fn.get()->Execute(args.values(), results.values(), true); + bm_stats.StopRun(); + } + bm_stats.Summarize(); +} + +// Just copy the input to the result. +tensor::DenseHostTensor ShadowCopyTensor(tensor::DenseHostTensor src) { + return src; +} + +void RegisterTestKernels(host_context::KernelRegistry *registry) { + registry->AddKernel("infrt.benchmark", INFRT_KERNEL(benchmark)); + registry->AddKernel("infrt.test.shadow_copy_tensor", + INFRT_KERNEL(ShadowCopyTensor)); +} + +} // namespace infrt::kernel diff --git a/paddle/infrt/kernel/test_kernels.h b/paddle/infrt/kernel/test_kernels.h new file mode 100644 index 0000000000..f42884dfaf --- /dev/null +++ b/paddle/infrt/kernel/test_kernels.h @@ -0,0 +1,31 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +namespace infrt::host_context { + +struct KernelRegistry; + +} // namespace infrt::host_context + +namespace infrt::kernel { + +/** + * Register all the test kernels to registry. + */ +void RegisterTestKernels(host_context::KernelRegistry* registry); + +} // namespace infrt::kernel diff --git a/paddle/infrt/paddle/CMakeLists.txt b/paddle/infrt/paddle/CMakeLists.txt new file mode 100644 index 0000000000..172d78ecde --- /dev/null +++ b/paddle/infrt/paddle/CMakeLists.txt @@ -0,0 +1,24 @@ +proto_library(paddle_framework_proto SRCS framework.proto) + +add_subdirectory(cpp) +add_subdirectory(pb) + +core_gather_headers() + +gather_srcs(infrt_src SRCS + model_parser.cc + scope.cc + tensor.cc + ) + +foreach(cpp ${SRCS}) + set(infrt_src + "${infrt_src};infrt/paddle/${cpp}" + CACHE INTERNAL "") +endforeach() + +file(GLOB includes LIST_DIRECTORIES false RELATIVE ${CMAKE_SOURCE_DIR} *.h) + +foreach(header ${includes}) + set(core_includes "${core_includes};${header}" CACHE INTERNAL "") +endforeach() diff --git a/paddle/infrt/paddle/cpp/CMakeLists.txt b/paddle/infrt/paddle/cpp/CMakeLists.txt new file mode 100644 index 0000000000..0feaabd2fa --- /dev/null +++ b/paddle/infrt/paddle/cpp/CMakeLists.txt @@ -0,0 +1,16 @@ +core_gather_headers() + +gather_srcs(infrt_src SRCS + ) + +foreach(cpp ${SRCS}) + set(infrt_src + "${infrt_src};infrt/paddle/cpp/${cpp}" + CACHE INTERNAL "") +endforeach() + +file(GLOB includes LIST_DIRECTORIES false RELATIVE ${CMAKE_SOURCE_DIR} *.h) + +foreach(header ${includes}) + set(core_includes 
"${core_includes};${header}" CACHE INTERNAL "") +endforeach() diff --git a/paddle/infrt/paddle/cpp/desc_api.h b/paddle/infrt/paddle/cpp/desc_api.h new file mode 100644 index 0000000000..ccd79c048a --- /dev/null +++ b/paddle/infrt/paddle/cpp/desc_api.h @@ -0,0 +1,229 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include +#include + +namespace infrt::paddle::cpp { + +/* + * Compatible interfaces for all the different kinds of XXXDesc. All the XXXDesc + * classes should implement this. + */ +class VarDescAPI { + public: + enum class Type { + // Pod Types + BOOL = 0, + INT16, + INT32, + INT64, + FP16, + FP32, + FP64, + // Tensor is used in C++. 
+ SIZE_T, + UINT8, + INT8, + + // Other types that may need additional descriptions + LOD_TENSOR, + SELECTED_ROWS, + FEED_MINIBATCH, + FETCH_LIST, + STEP_SCOPES, + LOD_RANK_TABLE, + LOD_TENSOR_ARRAY, + PLACE_LIST, + READER, + // Any runtime decided variable type is raw + // raw variables should manage their own allocations + // in operators like nccl_op + RAW, + TUPLE + }; + + using VarDataType = Type; + + virtual ~VarDescAPI() = default; + + // Get var's name + virtual std::string Name() const = 0; + // Set var's name + virtual void SetName(std::string name) = 0; + // Get var's type + virtual Type GetType() const = 0; + // Set var's type + virtual void SetType(Type type) = 0; + // Tell whether var is persistable or not + virtual bool Persistable() const = 0; + // Set var to be persistable or not + virtual void SetPersistable(bool persistable) = 0; + // Get var's shape + virtual std::vector GetShape() const = 0; + // Set var's shape + virtual void SetShape(const std::vector& dims) = 0; +}; + +/* + * NOTE Some interfaces are weried, we remain them unchanged to keep compatible + * with framework::OpDesc in Fluid framework. + */ +class OpDescAPI { + public: + // The AttrType is used to make the proto::AttrType portable. + enum class AttrType { + INT = 0, + FLOAT = 1, + STRING = 2, + INTS = 3, + FLOATS = 4, + STRINGS = 5, + BOOLEAN = 6, + BOOLEANS = 7, + BLOCK = 8, + LONG = 9, + BLOCKS = 10, + LONGS = 11, + UNK, + }; + + virtual ~OpDescAPI() = default; + + /// Get operator's type. + virtual std::string Type() const = 0; + /// Set operator's type. + virtual void SetType(const std::string& type) = 0; + /// Get arguments given the parameter. + virtual std::vector Input(const std::string& param) const = 0; + /// Get parameters. + virtual std::vector InputArgumentNames() const = 0; + /// Get arguments given the parameter. + virtual std::vector Output(const std::string& param) const = 0; + /// Get parameters. 
+ virtual std::vector OutputArgumentNames() const = 0; + /// Set a input given the parameter and arguments. + virtual void SetInput(const std::string& param, + const std::vector& args) = 0; + virtual void SetOutput(const std::string& param, + const std::vector& args) = 0; + /// Tell whether this desc has an attribute. + virtual bool HasAttr(const std::string& name) const = 0; + + /// Get the type of an attribute. + virtual AttrType GetAttrType(const std::string& name) const = 0; + + virtual std::vector AttrNames() const = 0; + + /// Set an attribute. + template + void SetAttr(const std::string& name, const T& v); + + /// Get an attribute. + template + T GetAttr(const std::string& name) const; + + std::string Repr() const { + std::stringstream ss; + ss << Type(); + ss << "("; + for (auto& arg : InputArgumentNames()) { + ss << arg << ":"; + for (auto val : Input(arg)) { + ss << val << " "; + } + } + ss << ") -> ("; + for (auto& arg : OutputArgumentNames()) { + ss << arg << ":"; + for (auto val : Output(arg)) { + ss << val << " "; + } + } + ss << ")"; + return ss.str(); + } +}; + +class BlockDescAPI { + public: + virtual ~BlockDescAPI() = default; + + virtual int32_t Idx() const = 0; + + virtual void SetIdx(int32_t idx) = 0; + + virtual int32_t ParentIdx() const = 0; + + virtual void SetParentIdx(int32_t idx) = 0; + + virtual size_t VarsSize() const = 0; + + virtual void ClearVars() = 0; + + // NOTE: This ugly method is used to compatible interfaces between cpp and + // pb/nb backends + // TODO(sangoly): refine this + template + T* GetVar(int32_t idx); + + template + T* AddVar(); + + virtual size_t OpsSize() const = 0; + + virtual void ClearOps() = 0; + + // NOTE: This ugly method is used to compatible interfaces between cpp and + // pb/nb backends + // TODO(sangoly): refine this + template + T* GetOp(int32_t idx); + + template + T* AddOp(); + + virtual int32_t ForwardBlockIdx() const = 0; + + virtual void SetForwardBlockIdx(int32_t idx) = 0; +}; + +class 
ProgramDescAPI { + public: + virtual ~ProgramDescAPI() = default; + + virtual size_t BlocksSize() const = 0; + + virtual void ClearBlocks() = 0; + + // NOTE: This ugly method is used to compatible interfaces between cpp and + // pb/nb backends + // TODO(sangoly): refine this + template + T* GetBlock(int32_t idx); + + template + T* AddBlock(); + + virtual bool HasVersion() const = 0; + + virtual int64_t Version() const = 0; + + virtual void SetVersion(int64_t version) = 0; +}; + +} // namespace infrt::paddle::cpp diff --git a/paddle/infrt/paddle/framework.proto b/paddle/infrt/paddle/framework.proto new file mode 100644 index 0000000000..634ec9665d --- /dev/null +++ b/paddle/infrt/paddle/framework.proto @@ -0,0 +1,213 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +syntax = "proto2"; +package paddle.framework.proto; + +// Any incompatible changes to ProgramDesc and its dependencies should +// raise the version defined version.h. +// +// Serailization and Deserialization codes should be modified in a way +// that supports old versions following the version and compatibility policy. +message Version { optional int64 version = 1 [ default = 0 ]; } + +enum AttrType { + INT = 0; + FLOAT = 1; + STRING = 2; + INTS = 3; + FLOATS = 4; + STRINGS = 5; + BOOLEAN = 6; + BOOLEANS = 7; + BLOCK = 8; + LONG = 9; + BLOCKS = 10; + LONGS = 11; +} + +// OpDesc describes an instance of a C++ framework::OperatorBase +// derived class type. 
+message OpDesc { + + message Attr { + required string name = 1; + required AttrType type = 2; + optional int32 i = 3; + optional float f = 4; + optional string s = 5; + repeated int32 ints = 6; + repeated float floats = 7; + repeated string strings = 8; + optional bool b = 10; + repeated bool bools = 11; + optional int32 block_idx = 12; + optional int64 l = 13; + repeated int32 blocks_idx = 14; + repeated int64 longs = 15; + }; + + message Var { + required string parameter = 1; + repeated string arguments = 2; + }; + + required string type = 3; + repeated Var inputs = 1; + repeated Var outputs = 2; + repeated Attr attrs = 4; + optional bool is_target = 5 [ default = false ]; +}; + +// OpProto describes a C++ framework::OperatorBase derived class. +message OpProto { + + // VarProto describes the C++ type framework::Variable. + message Var { + required string name = 1; + required string comment = 2; + + optional bool duplicable = 3 [ default = false ]; + optional bool intermediate = 4 [ default = false ]; + optional bool dispensable = 5 [ default = false ]; + } + + // AttrProto describes the C++ type Attribute. + message Attr { + required string name = 1; + required AttrType type = 2; + required string comment = 3; + // If that attribute is generated, it means the Paddle third + // language binding has responsibility to fill that + // attribute. End-User should not set that attribute. + optional bool generated = 4 [ default = false ]; + } + + required string type = 1; + repeated Var inputs = 2; + repeated Var outputs = 3; + repeated Attr attrs = 4; + required string comment = 5; +} + +message VarType { + enum Type { + // Pod Types + BOOL = 0; + INT16 = 1; + INT32 = 2; + INT64 = 3; + FP16 = 4; + FP32 = 5; + FP64 = 6; + // Tensor is used in C++. 
+ SIZE_T = 19; + UINT8 = 20; + INT8 = 21; + + // Other types that may need additional descriptions + LOD_TENSOR = 7; + SELECTED_ROWS = 8; + FEED_MINIBATCH = 9; + FETCH_LIST = 10; + STEP_SCOPES = 11; + LOD_RANK_TABLE = 12; + LOD_TENSOR_ARRAY = 13; + PLACE_LIST = 14; + READER = 15; + // Any runtime decided variable type is raw + // raw variables should manage their own allocations + // in operators like nccl_op + RAW = 17; + TUPLE = 18; + } + + required Type type = 1; + + message TensorDesc { + // Should only be PODType. Is enforced in C++ + required Type data_type = 1; + repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480] + } + optional TensorDesc selected_rows = 2; + + message LoDTensorDesc { + required TensorDesc tensor = 1; + optional int32 lod_level = 2 [ default = 0 ]; + } + optional LoDTensorDesc lod_tensor = 3; + + message LoDTensorArrayDesc { + required TensorDesc tensor = 1; + optional int32 lod_level = 2 [ default = 0 ]; + } + optional LoDTensorArrayDesc tensor_array = 4; + + message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; } + optional ReaderDesc reader = 5; + + message Tuple { repeated Type element_type = 1; } + optional Tuple tuple = 7; +} + +message VarDesc { + required string name = 1; + required VarType type = 2; + optional bool persistable = 3 [ default = false ]; + // True if the variable is an input data and + // have to check the feed data shape and dtype + optional bool need_check_feed = 4 [ default = false ]; +} + +message BlockDesc { + required int32 idx = 1; + required int32 parent_idx = 2; + repeated VarDesc vars = 3; + repeated OpDesc ops = 4; + optional int32 forward_block_idx = 5 [ default = -1 ]; +} + +// CompatibleInfo is used to determine if a feature is compatible and +// provides the information. 
+message CompatibleInfo { + enum Type { + COMPATIBLE = 0; + DEFINITELY_NOT = 1; + POSSIBLE = 2; + BUG_FIX = 3; + PRECISION_CHANGE = 4; + } + required string version = 1; + required Type type = 2; +} + +// In some cases, Paddle Fluid may perform operator definition iterations, +// and the operator uses OpCompatibleMap for compatibility testing. +message OpCompatibleMap { + message OpCompatiblePair { + required string op_name = 1; + required CompatibleInfo compatible_info = 2; + } + repeated OpCompatiblePair pair = 1; + optional string default_required_version = 2; +} + +// Please refer to +// https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/program.md +// for more details. +// TODO(panyx0718): A model can have multiple programs. Need a +// way to distinguish them. Maybe ID or name? +message ProgramDesc { + reserved 2; // For backward compatibility. + repeated BlockDesc blocks = 1; + optional Version version = 4; + optional OpCompatibleMap op_compatible_map = 3; +} \ No newline at end of file diff --git a/paddle/infrt/paddle/model_parser.cc b/paddle/infrt/paddle/model_parser.cc new file mode 100644 index 0000000000..285280e694 --- /dev/null +++ b/paddle/infrt/paddle/model_parser.cc @@ -0,0 +1,172 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "paddle/infrt/paddle/model_parser.h"
+
+#include <fstream>
+#include <vector>
+
+#include "paddle/infrt/common/common.h"
+#include "paddle/infrt/common/string.h"
+#include "paddle/infrt/common/target.h"
+#include "paddle/infrt/common/type.h"
+
+namespace infrt::paddle {
+
+// Returns the size in bytes of one element of the given proto data type.
+// Aborts through LOG(FATAL) for any type that is not listed below.
+int SizeOfType(framework_proto::VarType::Type type) {
+  using Type = framework_proto::VarType::Type;
+  switch (static_cast<int>(type)) {
+#define DO(desc, type)            \
+  case Type::VarType_Type_##desc: \
+    return sizeof(type);
+    DO(BOOL, bool);
+    // FP16 is a 16-bit format, so one element occupies two bytes.
+    DO(FP16, uint16_t);
+    DO(FP32, float);
+    DO(INT8, int8_t);
+    DO(INT16, int16_t);
+    DO(INT32, int);
+    DO(INT64, int64_t);
+#undef DO
+    default:
+      LOG(FATAL) << "unknown data type " << type;
+  }
+  return -1;
+}
+
+// Deserializes one tensor from `is` into `tensor`.
+//
+// The bytes consumed are, in order: a uint32 version (must be 0), an int32
+// byte count, a serialized VarType::TensorDesc of that many bytes, and the
+// raw element data. For an X86 target the data is read directly into the
+// tensor buffer; for an NVGPU target (FP32 only) it is staged in host memory
+// and then copied to the device. Any short or failed read aborts.
+void TensorFromStream(std::istream &is,
+                      _Tensor_ *tensor,
+                      const common::Target &target) {
+  using Type = framework_proto::VarType::Type;
+  uint32_t version{};
+  is.read(reinterpret_cast<char *>(&version), sizeof(version));
+  CHECK(!is.fail()) << "Failed to read the tensor version";
+  CHECK_EQ(version, 0U) << "Only version 0 is supported";
+  // read tensor desc
+  framework_proto::VarType::TensorDesc desc;
+  {
+    // int32_t size
+    // proto buffer
+    int32_t size{};
+    is.read(reinterpret_cast<char *>(&size), sizeof(size));
+    CHECK(!is.fail()) << "Failed to read the tensor desc size";
+    // A negative size would become a huge allocation request below.
+    CHECK_GE(size, 0) << "Invalid tensor desc size " << size;
+    std::unique_ptr<char[]> buf(new char[size]);
+    is.read(buf.get(), size);
+    CHECK(!is.fail()) << "Failed to read the tensor desc";
+    CHECK(desc.ParseFromArray(buf.get(), size)) << "Cannot parse tensor desc";
+  }
+
+  // read tensor
+  std::vector<int32_t> dims_vec;
+  std::copy(
+      desc.dims().begin(), desc.dims().end(), std::back_inserter(dims_vec));
+  Shape dims(dims_vec);
+  tensor->Resize(dims);
+  void *buf = nullptr;
+  size_t size = tensor->shape().numel() * SizeOfType(desc.data_type());
+  // allocate memory
+  if (target.arch == Target::Arch::X86) {
+    switch (static_cast<int>(desc.data_type())) {
+#define SET_TENSOR(desc, type, precision)     \
+  case Type::VarType_Type_##desc:             \
+    buf = tensor->mutable_data<type>(target); \
+    tensor->set_type(precision);              \
+    break
+
+      SET_TENSOR(FP32, float, Float(32));
+      SET_TENSOR(INT8, int8_t, Int(8));
+      SET_TENSOR(INT16, int16_t, Int(16));
+      SET_TENSOR(INT32, int32_t, Int(32));
+      SET_TENSOR(INT64, int64_t, Int(64));
+#undef SET_TENSOR
+      default:
+        LOG(FATAL) << "unknown type " << desc.data_type();
+    }
+    // tensor->set_persistable(true);
+    is.read(static_cast<char *>(buf), size);
+    CHECK(!is.fail()) << "Failed to read the tensor data";
+  } else if (target.arch == Target::Arch::NVGPU) {
+#ifdef INFRT_WITH_CUDA
+    if (desc.data_type() != Type::VarType_Type_FP32)
+      LOG(FATAL) << "[CUDA] The type is not fp32!!";
+    auto *data = tensor->mutable_data<float>(target);
+    tensor->set_type(infrt::common::Float(32));
+    std::vector<float> temp(tensor->shape().numel());
+    // LOG(INFO) <<"[CUDA] The tensor's size is "<< tensor->shape().numel();
+    is.read(reinterpret_cast<char *>(temp.data()), size);
+    CHECK(!is.fail()) << "Failed to read the tensor data";
+    CUDA_CALL(cudaMemcpy(reinterpret_cast<void *>(data),
+                         temp.data(),
+                         tensor->shape().numel() * sizeof(float),
+                         cudaMemcpyHostToDevice));
+#else
+    LOG(FATAL) << "To use CUDA backends, you need to set WITH_CUDA ON!";
+#endif
+  } else {
+    INFRT_NOT_IMPLEMENTED
+  }
+}
+
+// Reads one serialized LoDTensor from `is` into the tensor held by `var`.
+// The LoD section is consumed to advance the stream but is not stored.
+void LoadLoDTensor(std::istream &is, _Variable *var, const Target &target) {
+  auto &tensor = var->get<Tensor>();
+  uint32_t version{};
+  is.read(reinterpret_cast<char *>(&version), sizeof(version));
+  VLOG(3) << "model version " << version;
+
+  // Load LoD information
+  uint64_t lod_level{};
+  is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
+
+  for (uint64_t i = 0; i < lod_level; ++i) {
+    uint64_t size{};
+    is.read(reinterpret_cast<char *>(&size), sizeof(size));
+    CHECK(!is.fail()) << "Failed to read the size of LoD level " << i;
+    // `tmp` holds size / sizeof(uint64_t) elements, so a byte count that is
+    // not a multiple of the element size would make the read overrun it.
+    CHECK_EQ(size % sizeof(uint64_t), 0U) << "Invalid LoD size " << size;
+    std::vector<uint64_t> tmp(size / sizeof(uint64_t));
+    is.read(reinterpret_cast<char *>(tmp.data()),
+            static_cast<std::streamsize>(size));
+    // lod[i] = tmp;
+  }
+
+  TensorFromStream(is, tensor.operator->(), target);
+}
+
+// Replaces `*contents` with the complete bytes of the file `filename`.
+// Aborts if the file cannot be opened or its size cannot be determined.
+// An empty file yields an empty string.
+void ReadBinaryFile(const std::string &filename, std::string *contents) {
+  std::ifstream fin(filename, std::ios::in | std::ios::binary);
+  CHECK(fin.is_open()) << "Cannot open file: " << filename;
+  fin.seekg(0, std::ios::end);
+  // tellg() reports -1 on failure; resizing with that value would request
+  // an enormous buffer.
+  const std::streamoff size = fin.tellg();
+  CHECK(size >= 0) << "Cannot get the size of file: " << filename;
+  contents->clear();
+  contents->resize(static_cast<size_t>(size));
+  fin.seekg(0, std::ios::beg);
+  // at(0) throws on an empty string, so skip the read for an empty file.
+  if (!contents->empty()) {
+    fin.read(&(*contents)[0],
+             static_cast<std::streamsize>(contents->size()));
+  }
+  fin.close();
+}
+
+// Parses a ProgramDesc and returns it.
+//
+// When `program_from_memory` is false, `path` names a file whose bytes are
+// parsed. When it is true, `path` itself holds the serialized bytes.
+// Aborts if the bytes cannot be parsed instead of returning a partially
+// filled message.
+std::unique_ptr<framework_proto::ProgramDesc> LoadProgram(
+    const std::string &path, bool program_from_memory) {
+  std::unique_ptr<framework_proto::ProgramDesc> main_program(
+      new framework_proto::ProgramDesc);
+  if (!program_from_memory) {
+    std::string desc_str;
+    ReadBinaryFile(path, &desc_str);
+    CHECK(main_program->ParseFromString(desc_str))
+        << "Cannot parse program from file: " << path;
+  } else {
+    CHECK(main_program->ParseFromString(path))
+        << "Cannot parse program from memory buffer";
+  }
+  return main_program;
+}
+
+// Not implemented yet: the body is intentionally empty.
+void LoadParams(const std::string &path) {}
+
+// Load directly to CPU, and later transfer to other devices.
+void LoadParam(const std::string &path, _Variable *out, const Target &target) {
+  std::ifstream fin(path, std::ios::binary);
+  CHECK(fin.is_open()) << "failed to open file " << path;
+  LoadLoDTensor(fin, out, target);
+}
+
+}  // namespace infrt::paddle
diff --git a/paddle/infrt/paddle/model_parser.h b/paddle/infrt/paddle/model_parser.h
new file mode 100644
index 0000000000..73125faded
--- /dev/null
+++ b/paddle/infrt/paddle/model_parser.h
@@ -0,0 +1,55 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <fstream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "paddle/infrt/paddle/framework.pb.h"
+#include "paddle/infrt/paddle/pb/block_desc.h"
+#include "paddle/infrt/paddle/pb/op_desc.h"
+#include "paddle/infrt/paddle/pb/program_desc.h"
+#include "paddle/infrt/paddle/scope.h"
+#include "paddle/infrt/paddle/tensor.h"
+
+namespace infrt::paddle {
+namespace framework_proto = ::paddle::framework::proto;
+
+// Read a __model__ file.
+std::unique_ptr LoadProgram( + const std::string& path, bool program_from_memory = false); + +void LoadLoDTensor(std::istream& is, + _Variable* var, + const common::Target& target); + +// Read a single file containing all the parameters. +void LoadParams(const std::string& path); + +// Load a single parameter to an output tensor. +void LoadParam(const std::string& path, + _Variable* out, + const common::Target& target); + +// LoDTensor to ostream +void TensorToStream(std::ostream& os, const _Tensor_& tensor); +void TensorFromStream( + std::istream& is, + _Tensor_* tensor, + const common::Target& target = common::DefaultHostTarget()); +void ReadBinaryFile(const std::string& filename, std::string* contents); + +} // namespace infrt::paddle diff --git a/paddle/infrt/paddle/pb/CMakeLists.txt b/paddle/infrt/paddle/pb/CMakeLists.txt new file mode 100644 index 0000000000..fac38afa62 --- /dev/null +++ b/paddle/infrt/paddle/pb/CMakeLists.txt @@ -0,0 +1,20 @@ +core_gather_headers() + +gather_srcs(infrt_src SRCS + var_desc.cc + op_desc.cc + block_desc.cc + program_desc.cc + ) + +foreach(cpp ${SRCS}) + set(infrt_src + "${infrt_src};infrt/paddle/pb/${cpp}" + CACHE INTERNAL "") +endforeach() + +file(GLOB includes LIST_DIRECTORIES false RELATIVE ${CMAKE_SOURCE_DIR} *.h) + +foreach(header ${includes}) + set(core_includes "${core_includes};${header}" CACHE INTERNAL "") +endforeach() diff --git a/paddle/infrt/paddle/pb/block_desc.cc b/paddle/infrt/paddle/pb/block_desc.cc new file mode 100644 index 0000000000..11186bc68a --- /dev/null +++ b/paddle/infrt/paddle/pb/block_desc.cc @@ -0,0 +1,43 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/infrt/paddle/pb/block_desc.h"
+
+namespace infrt::paddle::pb {
+
+// Returns a mutable pointer to the idx-th variable of the wrapped proto.
+// `idx` must lie in [0, VarsSize()); anything else aborts.
+template <>
+framework_proto::VarDesc* BlockDesc::GetVar<framework_proto::VarDesc>(
+    int32_t idx) {
+  // mutable_vars() is given the index unchecked, so reject negatives here.
+  CHECK_GE(idx, 0) << "idx < 0";
+  CHECK_LT(idx, static_cast<int32_t>(VarsSize())) << "idx >= vars.size()";
+  return desc_->mutable_vars(idx);
+}
+
+// Appends an empty variable to the wrapped proto and returns it.
+template <>
+framework_proto::VarDesc* BlockDesc::AddVar<framework_proto::VarDesc>() {
+  return desc_->add_vars();
+}
+
+// Returns a mutable pointer to the idx-th operator of the wrapped proto.
+// `idx` must lie in [0, OpsSize()); anything else aborts.
+template <>
+framework_proto::OpDesc* BlockDesc::GetOp<framework_proto::OpDesc>(
+    int32_t idx) {
+  // mutable_ops() is given the index unchecked, so reject negatives here.
+  CHECK_GE(idx, 0) << "idx < 0";
+  CHECK_LT(idx, static_cast<int32_t>(OpsSize())) << "idx >= ops.size()";
+  return desc_->mutable_ops(idx);
+}
+
+// Appends an empty operator to the wrapped proto and returns it.
+template <>
+framework_proto::OpDesc* BlockDesc::AddOp<framework_proto::OpDesc>() {
+  return desc_->add_ops();
+}
+
+}  // namespace infrt::paddle::pb
diff --git a/paddle/infrt/paddle/pb/block_desc.h b/paddle/infrt/paddle/pb/block_desc.h
new file mode 100644
index 0000000000..9c1b7f9adf
--- /dev/null
+++ b/paddle/infrt/paddle/pb/block_desc.h
@@ -0,0 +1,77 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#pragma once +#include + +#include "paddle/infrt/paddle/cpp/desc_api.h" +#include "paddle/infrt/paddle/framework.pb.h" + +namespace infrt::paddle::pb { + +namespace framework_proto = ::paddle::framework::proto; + +class BlockDesc : public cpp::BlockDescAPI { + public: + BlockDesc() = delete; + + explicit BlockDesc(framework_proto::BlockDesc* desc) : desc_(desc) { + CHECK(desc_); + } + + framework_proto::BlockDesc* Proto() { return desc_; } + + const framework_proto::BlockDesc& ReadonlyProto() const { return *desc_; } + + int32_t Idx() const override { return desc_->idx(); } + + void SetIdx(int32_t idx) override { desc_->set_idx(idx); } + + int32_t ParentIdx() const override { return desc_->parent_idx(); } + + void SetParentIdx(int32_t idx) override { desc_->set_parent_idx(idx); } + + size_t VarsSize() const override { return desc_->vars_size(); } + + void ClearVars() override { desc_->clear_vars(); } + + template + T* GetVar(int32_t idx); + + template + T* AddVar(); + + size_t OpsSize() const override { return desc_->ops_size(); } + + void ClearOps() override { desc_->clear_ops(); } + + template + T* GetOp(int32_t idx); + + template + T* AddOp(); + + int32_t ForwardBlockIdx() const override { + return desc_->forward_block_idx(); + } + + void SetForwardBlockIdx(int32_t idx) override { + desc_->set_forward_block_idx(idx); + } + + private: + framework_proto::BlockDesc* desc_; // not_own +}; + +} // namespace infrt::paddle::pb diff --git a/paddle/infrt/paddle/pb/op_desc.cc b/paddle/infrt/paddle/pb/op_desc.cc new file mode 100644 index 0000000000..c7b1e66f50 --- /dev/null +++ b/paddle/infrt/paddle/pb/op_desc.cc @@ -0,0 +1,139 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/infrt/paddle/pb/op_desc.h"
+
+namespace infrt::paddle::pb {
+
+// Returns an iterator to the attribute called `name` in `desc`. If no such
+// attribute exists yet, an empty one with that name is appended first and
+// the iterator refers to the new entry.
+google::protobuf::internal::RepeatedPtrIterator<framework_proto::OpDesc_Attr>
+FindAttr(framework_proto::OpDesc *desc, const std::string &name) {
+  auto &xs = *desc->mutable_attrs();
+  auto it = std::find_if(
+      xs.begin(), xs.end(), [&](const framework_proto::OpDesc_Attr &x) {
+        return x.name() == name;
+      });
+  if (it == xs.end()) {
+    auto *attr = xs.Add();
+    attr->set_name(name);
+    // Add() appends, so the new attribute is the last element; take a fresh
+    // iterator to it instead of searching the whole list a second time.
+    it = xs.end() - 1;
+  }
+  return it;
+}
+
+// Defines OpDesc::SetAttr<T> for a scalar attribute: records the proto
+// attribute type `ty__` and stores the value in the proto field `pb_f__`.
+#define SET_IMPL_ONE(T, ty__, pb_f__)                            \
+  template <>                                                    \
+  void OpDesc::SetAttr<T>(const std::string &name, const T &v) { \
+    auto it = FindAttr(desc_, name);                             \
+    it->set_type(framework_proto::ty__);                         \
+    it->set_##pb_f__(v);                                         \
+  }
+SET_IMPL_ONE(int, INT, i);
+SET_IMPL_ONE(float, FLOAT, f);
+SET_IMPL_ONE(bool, BOOLEAN, b);
+SET_IMPL_ONE(int64_t, LONG, l);
+#undef SET_IMPL_ONE
+
+// Stores an INTS attribute, replacing any previously stored values.
+template <>
+void OpDesc::SetAttr<std::vector<int>>(const std::string &name,
+                                       const std::vector<int> &v) {
+  auto it = FindAttr(desc_, name);
+  it->set_type(framework_proto::INTS);
+  it->clear_ints();
+  for (const auto &i : v) {
+    it->add_ints(i);
+  }
+}
+
+// Stores a STRING attribute.
+template <>
+void OpDesc::SetAttr<std::string>(const std::string &name,
+                                  const std::string &v) {
+  auto it = FindAttr(desc_, name);
+  it->set_type(framework_proto::STRING);
+  // Pass the std::string itself: going through c_str() would cut the value
+  // at the first embedded NUL byte.
+  it->set_s(v);
+}
+
+// Stores a FLOATS attribute, replacing any previously stored values.
+template <>
+void OpDesc::SetAttr<std::vector<float>>(const std::string &name,
+                                         const std::vector<float> &v) {
+  auto it = FindAttr(desc_, name);
+  it->set_type(framework_proto::FLOATS);
+  it->clear_floats();
+  for (const auto &i : v) {
+    it->add_floats(i);
+  }
+}
+
+// Stores a STRINGS attribute, replacing any previously stored values.
+template <>
+void OpDesc::SetAttr<std::vector<std::string>>(
+    const std::string &name, const std::vector<std::string> &v) {
+  auto it = FindAttr(desc_, name);
+  it->set_type(framework_proto::STRINGS);
+  it->clear_strings();
+  for (const auto &i : v) {
+    it->add_strings(i);
+  }
+}
+
+// Stores a LONGS attribute, replacing any previously stored values.
+template <>
+void OpDesc::SetAttr<std::vector<int64_t>>(const std::string &name,
+                                           const std::vector<int64_t> &v) {
+  auto it = FindAttr(desc_, name);
+  it->set_type(framework_proto::LONGS);
+  it->clear_longs();
+  for (const auto &i : v) {
+    it->add_longs(i);
+  }
+}
+// Returns an iterator to the attribute called `name` in `desc`, or
+// desc.attrs().end() when there is none. Never modifies `desc`.
+google::protobuf::internal::RepeatedPtrIterator<
+    const framework_proto::OpDesc_Attr>
+GetFindAttr(const framework_proto::OpDesc &desc, const std::string &name) {
+  auto &xs = desc.attrs();
+  auto it = std::find_if(
+      xs.begin(), xs.end(), [&](const framework_proto::OpDesc_Attr &x) {
+        return x.name() == name;
+      });
+  return it;
+}
+
+// Defines OpDesc::GetAttr<T> for a scalar attribute read from the proto
+// field `pb_f__`. Aborts when the attribute does not exist, because
+// GetFindAttr then returns the end iterator, which must not be dereferenced.
+#define GET_ATTR_IMPL(T, pb_f__)                          \
+  template <>                                             \
+  T OpDesc::GetAttr<T>(const std::string &name) const {   \
+    auto it = GetFindAttr(*desc_, name);                  \
+    CHECK(it != desc_->attrs().end())                     \
+        << "Attribute [" << name << "] is not found";     \
+    return it->pb_f__();                                  \
+  }
+
+// Same as GET_ATTR_IMPL for a repeated proto field: the elements are copied
+// into a container of type T.
+#define GET_ATTRS_IMPL(T, pb_f__)                         \
+  template <>                                             \
+  T OpDesc::GetAttr<T>(const std::string &name) const {   \
+    auto it = GetFindAttr(*desc_, name);                  \
+    CHECK(it != desc_->attrs().end())                     \
+        << "Attribute [" << name << "] is not found";     \
+    T res;                                                \
+    for (const auto &v : it->pb_f__()) {                  \
+      res.push_back(v);                                   \
+    }                                                     \
+    return res;                                           \
+  }
+GET_ATTR_IMPL(int32_t, i);
+GET_ATTR_IMPL(int16_t, block_idx);
+GET_ATTR_IMPL(float, f);
+GET_ATTR_IMPL(bool, b);
+GET_ATTR_IMPL(int64_t, l);
+GET_ATTRS_IMPL(std::vector<int>, ints);
+GET_ATTRS_IMPL(std::vector<float>, floats);
+GET_ATTRS_IMPL(std::vector<std::string>, strings);
+GET_ATTR_IMPL(std::string, s);
+GET_ATTRS_IMPL(std::vector<int64_t>, longs);
+#undef GET_ATTR_IMPL
+#undef GET_ATTRS_IMPL
+
+}  // namespace infrt::paddle::pb
diff --git a/paddle/infrt/paddle/pb/op_desc.h b/paddle/infrt/paddle/pb/op_desc.h
new file mode 100644
index 0000000000..81d57d9f32
--- /dev/null
+++ b/paddle/infrt/paddle/pb/op_desc.h
@@ -0,0 +1,198 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include "paddle/infrt/paddle/cpp/desc_api.h" +#include "paddle/infrt/paddle/framework.pb.h" +#include "paddle/infrt/support/variant.h" + +namespace infrt::paddle::pb { + +namespace framework_proto = ::paddle::framework::proto; + +using Attribute = + Variant, std::vector>; +using VariableNameMap = std::map>; + +/* + * The lite::OpDesc, an light-weight implementation of wrapper of proto::OpDesc. + * Unlike the original one in framework::OpDesc, we remove the local members + * except the desc_, to avoid the inconsistent state, which is normal in the + * original interface and results in bugs. 
+ */ +class OpDesc : public cpp::OpDescAPI { + public: + OpDesc() = delete; + + explicit OpDesc(framework_proto::OpDesc *desc) : desc_(desc) { CHECK(desc_); } + + framework_proto::OpDesc *Proto() { return desc_; } + const framework_proto::OpDesc &ReadonlyProto() const { return *desc_; } + + std::string Type() const override { return desc_->type(); } + + void SetType(const std::string &type) override { desc_->set_type(type); } + + // Get the arguments of parameter called `param` + std::vector Input(const std::string ¶m) const override { + return GetArguments(desc_->inputs(), param); + } + + std::vector InputArgumentNames() const override { + return GetArgumentNames(desc_->inputs()); + } + + void SetInput(const std::string ¶m, + const std::vector &args) override { + SetArgument(desc_->mutable_inputs(), param, args); + } + + std::vector Output(const std::string ¶m) const override { + return GetArguments(desc_->outputs(), param); + } + + std::vector OutputArgumentNames() const override { + return GetArgumentNames(desc_->outputs()); + } + + void SetOutput(const std::string ¶m, + const std::vector &args) override { + SetArgument(desc_->mutable_outputs(), param, args); + } + + bool HasAttr(const std::string &name) const override { + const auto &xs = desc_->attrs(); + auto it = std::find_if( + xs.begin(), xs.end(), [&](const framework_proto::OpDesc_Attr &x) { + return x.name() == name; + }); + return it != xs.end(); + } + + AttrType GetAttrType(const std::string &name) const override { + const auto &xs = desc_->attrs(); + auto it = std::find_if( + xs.begin(), xs.end(), [&](const framework_proto::OpDesc_Attr &x) { + return x.name() == name; + }); + CHECK(it != xs.end()); +#define DEF_ONE(type__) \ + case framework_proto::AttrType::type__: \ + return AttrType::type__; + + switch (it->type()) { + DEF_ONE(INT); + DEF_ONE(FLOAT); + DEF_ONE(STRING); + DEF_ONE(INTS); + DEF_ONE(FLOATS); + DEF_ONE(STRINGS); + DEF_ONE(BOOLEAN); + DEF_ONE(BOOLEANS); + DEF_ONE(BLOCK); + DEF_ONE(LONG); 
+ DEF_ONE(BLOCKS); + DEF_ONE(LONGS); + default: + LOG(FATAL) << "Unknown attribute type"; + return static_cast(-1); + } +#undef DEF_ONE + } + + std::vector AttrNames() const override { + std::vector res; + const auto &xs = desc_->attrs(); + std::transform( + xs.begin(), + xs.end(), + std::back_inserter(res), + [](const framework_proto::OpDesc_Attr &x) { return x.name(); }); + return res; + } + + template + void SetAttr(const std::string &name, const T &v); + + template + T GetAttr(const std::string &name) const; + + private: + std::vector GetArguments( + const google::protobuf::RepeatedPtrField &xs, + const std::string ¶m) const { + std::vector res; + auto it = std::find_if( + xs.begin(), xs.end(), [&](const framework_proto::OpDesc_Var &it) { + return it.parameter() == param; + }); + CHECK(it != xs.end()); + + const auto &ys = it->arguments(); + std::transform(ys.begin(), + ys.end(), + std::back_inserter(res), + [](const std::string &x) { return x; }); + return res; + } + + void SetArgument( + google::protobuf::RepeatedPtrField *xs, + const std::string ¶m, + const std::vector &args) { + auto it = std::find_if( + xs->begin(), xs->end(), [&](const framework_proto::OpDesc_Var &it) { + return it.parameter() == param; + }); + if (it == xs->end()) { + auto *new_arg = xs->Add(); + new_arg->set_parameter(param); + for (const auto &arg : args) { + *new_arg->mutable_arguments()->Add() = arg; + } + } else { + it->mutable_arguments()->Clear(); + for (const auto &arg : args) { + *it->mutable_arguments()->Add() = arg; + } + } + } + + std::vector GetArgumentNames( + const google::protobuf::RepeatedPtrField &xs) + const { + std::vector res; + std::transform( + xs.begin(), + xs.end(), + std::back_inserter(res), + [](const framework_proto::OpDesc_Var &x) { return x.parameter(); }); + return res; + } + + private: + framework_proto::OpDesc *desc_; +}; + +template <> +void OpDesc::SetAttr(const std::string &name, + const std::string &v); + +template <> +void OpDesc::SetAttr>(const 
std::string &name, + const std::vector &v); + +} // namespace infrt::paddle::pb diff --git a/paddle/infrt/paddle/pb/program_desc.cc b/paddle/infrt/paddle/pb/program_desc.cc new file mode 100644 index 0000000000..ed8a7e36e0 --- /dev/null +++ b/paddle/infrt/paddle/pb/program_desc.cc @@ -0,0 +1,35 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/paddle/pb/program_desc.h" + +#include +#include + +namespace infrt::paddle::pb { + +template <> +framework_proto::BlockDesc* ProgramDesc::GetBlock( + int32_t idx) { + CHECK_LT(idx, static_cast(BlocksSize())) << "idx >= blocks.size()"; + return desc_->mutable_blocks(idx); +} + +template <> +framework_proto::BlockDesc* +ProgramDesc::AddBlock() { + return desc_->add_blocks(); +} + +} // namespace infrt::paddle::pb diff --git a/paddle/infrt/paddle/pb/program_desc.h b/paddle/infrt/paddle/pb/program_desc.h new file mode 100644 index 0000000000..4adad650c9 --- /dev/null +++ b/paddle/infrt/paddle/pb/program_desc.h @@ -0,0 +1,61 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include +#include + +#include "paddle/infrt/paddle/cpp/desc_api.h" +#include "paddle/infrt/paddle/framework.pb.h" + +namespace infrt::paddle::pb { +namespace framework_proto = ::paddle::framework::proto; + +class ProgramDesc : public cpp::ProgramDescAPI { + public: + ProgramDesc() = delete; + + explicit ProgramDesc(framework_proto::ProgramDesc *desc) : desc_(desc) { + CHECK(desc_); + } + + framework_proto::ProgramDesc *Proto() { return desc_; } + + const framework_proto::ProgramDesc &ReadonlyProto() const { return *desc_; } + + size_t BlocksSize() const override { return desc_->blocks_size(); } + + void ClearBlocks() override { desc_->clear_blocks(); } + + template + T *GetBlock(int32_t idx); + + template + T *AddBlock(); + + bool HasVersion() const override { return desc_->has_version(); } + + int64_t Version() const override { return desc_->version().version(); } + + void SetVersion(int64_t version) override { + desc_->mutable_version()->set_version(version); + } + + private: + framework_proto::ProgramDesc *desc_; // not_own +}; + +} // namespace infrt::paddle::pb diff --git a/paddle/infrt/paddle/pb/var_desc.cc b/paddle/infrt/paddle/pb/var_desc.cc new file mode 100644 index 0000000000..cf80df4f1b --- /dev/null +++ b/paddle/infrt/paddle/pb/var_desc.cc @@ -0,0 +1,367 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/paddle/pb/var_desc.h" + +#include + +#include "paddle/infrt/paddle/cpp/desc_api.h" +#include "paddle/infrt/paddle/framework.pb.h" + +namespace infrt::paddle::pb { + +cpp::VarDescAPI::Type VarDesc::GetType() const { + auto type = desc_->type().type(); + +#define GET_TYPE_CASE_ITEM(type__) \ + case framework_proto::VarType::type__: \ + return cpp::VarDescAPI::Type::type__; + + switch (type) { + GET_TYPE_CASE_ITEM(LOD_TENSOR); + GET_TYPE_CASE_ITEM(LOD_TENSOR_ARRAY); + GET_TYPE_CASE_ITEM(LOD_RANK_TABLE); + GET_TYPE_CASE_ITEM(SELECTED_ROWS); + GET_TYPE_CASE_ITEM(FEED_MINIBATCH); + GET_TYPE_CASE_ITEM(FETCH_LIST); + GET_TYPE_CASE_ITEM(STEP_SCOPES); + GET_TYPE_CASE_ITEM(PLACE_LIST); + GET_TYPE_CASE_ITEM(READER); + default: + LOG(FATAL) << "Unknown var type"; + return VarDescAPI::Type(); + } +#undef GET_TYPE_CASE_ITEM +} + +void VarDesc::SetType(VarDescAPI::Type type) { +#define SET_TYPE_CASE_ITEM(type__) \ + case VarDescAPI::Type::type__: \ + desc_->mutable_type()->set_type(framework_proto::VarType::type__); \ + break; + + switch (type) { + SET_TYPE_CASE_ITEM(LOD_TENSOR); + SET_TYPE_CASE_ITEM(LOD_TENSOR_ARRAY); + SET_TYPE_CASE_ITEM(LOD_RANK_TABLE); + SET_TYPE_CASE_ITEM(SELECTED_ROWS); + SET_TYPE_CASE_ITEM(FEED_MINIBATCH); + SET_TYPE_CASE_ITEM(FETCH_LIST); + SET_TYPE_CASE_ITEM(STEP_SCOPES); + SET_TYPE_CASE_ITEM(PLACE_LIST); + SET_TYPE_CASE_ITEM(READER); + default: + LOG(FATAL) << "Unknown var type"; + } +#undef SET_TYPE_CASE_ITEM +} + +void VarDesc::SetShape(const std::vector &dims) { + VectorToRepeated(dims, 
mutable_tensor_desc()->mutable_dims()); +} + +void VarDesc::SetTensorDescNum(size_t num) { + switch (desc_->type().type()) { + case framework_proto::VarType::READER: { + auto *lod_tensors_ptr = + desc_->mutable_type()->mutable_reader()->mutable_lod_tensor(); + lod_tensors_ptr->Clear(); + for (size_t i = 0; i < num; ++i) { + lod_tensors_ptr->Add(); + } + return; + } break; + default: + LOG(FATAL) << "Setting 'sub_tensor_number' is not supported by the type " + "of var %s." + << this->Name(); + } +} + +size_t VarDesc::GetTensorDescNum() const { + switch (desc_->type().type()) { + case framework_proto::VarType::READER: + return desc_->type().reader().lod_tensor_size(); + break; + default: + LOG(FATAL) << "Getting 'sub_tensor_number' is not supported by the type " + "of var %s." + << this->Name(); + } + return 0; +} + +void VarDesc::SetShapes( + const std::vector> &multiple_dims) { + if (multiple_dims.size() != GetTensorDescNum()) { + VLOG(3) << "WARNING: The number of given shapes(" << multiple_dims.size() + << ") doesn't match the existing tensor number(" + << GetTensorDescNum() + << "). 
The Reader is going to be reinitialized."; + SetTensorDescNum(multiple_dims.size()); + } + std::vector tensors = + mutable_tensor_descs(); + for (size_t i = 0; i < multiple_dims.size(); ++i) { + VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims()); + } +} + +std::vector VarDesc::GetShape() const { + return RepeatedToVector(tensor_desc().dims()); +} + +std::vector> VarDesc::GetShapes() const { + std::vector descs = tensor_descs(); + std::vector> res; + res.reserve(descs.size()); + for (const auto &tensor_desc : descs) { + res.push_back(RepeatedToVector(tensor_desc.dims())); + } + return res; +} + +void VarDesc::SetDataType(VarDescAPI::VarDataType data_type) { +#define SET_DATA_TYPE_CASE_ITEM(type__) \ + case cpp::VarDescAPI::Type::type__: \ + mutable_tensor_desc()->set_data_type(framework_proto::VarType::type__); \ + break; + + switch (data_type) { + SET_DATA_TYPE_CASE_ITEM(BOOL); + SET_DATA_TYPE_CASE_ITEM(SIZE_T); + SET_DATA_TYPE_CASE_ITEM(UINT8); + SET_DATA_TYPE_CASE_ITEM(INT8); + SET_DATA_TYPE_CASE_ITEM(INT16); + SET_DATA_TYPE_CASE_ITEM(INT32); + SET_DATA_TYPE_CASE_ITEM(INT64); + SET_DATA_TYPE_CASE_ITEM(FP16); + SET_DATA_TYPE_CASE_ITEM(FP32); + SET_DATA_TYPE_CASE_ITEM(FP64); + default: + LOG(FATAL) << "Unknown var type: " << static_cast(data_type); + } +#undef SET_DATA_TYPE_CASE_ITEM +} + +void VarDesc::SetDataTypes( + const std::vector &multiple_data_type) { + if (multiple_data_type.size() != GetTensorDescNum()) { + VLOG(3) << "WARNING: The number of given data types(" + << multiple_data_type.size() + << ") doesn't match the existing tensor number(" + << GetTensorDescNum() + << "). 
The Reader is going to be reinitialized."; + SetTensorDescNum(multiple_data_type.size()); + } + std::vector tensor_descs = + mutable_tensor_descs(); + for (size_t i = 0; i < multiple_data_type.size(); ++i) { + tensor_descs[i]->set_data_type(multiple_data_type[i]); + } +} + +// proto::VarType::Type VarDesc::GetDataType() const { +// return tensor_desc().data_type(); +// } +cpp::VarDescAPI::VarDataType VarDesc::GetDataType() const { + CHECK(desc_->has_type()) << "The var's type hasn't been set."; + CHECK(desc_->type().has_type()) << "The var type hasn't been set."; + if (desc_->type().type() != framework_proto::VarType::LOD_TENSOR) { + return VarDescAPI::Type(); + } + auto type = tensor_desc().data_type(); +#define GET_DATA_TYPE_CASE_ITEM(type__) \ + case framework_proto::VarType::Type::VarType_Type_##type__: \ + return VarDescAPI::Type::type__ + + switch (type) { + GET_DATA_TYPE_CASE_ITEM(BOOL); + GET_DATA_TYPE_CASE_ITEM(SIZE_T); + GET_DATA_TYPE_CASE_ITEM(UINT8); + GET_DATA_TYPE_CASE_ITEM(INT8); + GET_DATA_TYPE_CASE_ITEM(INT16); + GET_DATA_TYPE_CASE_ITEM(INT32); + GET_DATA_TYPE_CASE_ITEM(INT64); + GET_DATA_TYPE_CASE_ITEM(FP16); + GET_DATA_TYPE_CASE_ITEM(FP32); + GET_DATA_TYPE_CASE_ITEM(FP64); + default: + LOG(FATAL) << "Unknown var type: " << static_cast(type); + return VarDescAPI::Type(); + } +#undef GET_DATA_TYPE_CASE_ITEM +} + +std::vector VarDesc::GetDataTypes() const { + std::vector descs = tensor_descs(); + std::vector res; + res.reserve(descs.size()); + for (const auto &tensor_desc : descs) { + res.push_back(tensor_desc.data_type()); + } + return res; +} + +void VarDesc::SetLoDLevel(int32_t lod_level) { + switch (desc_->type().type()) { + case framework_proto::VarType::LOD_TENSOR: + desc_->mutable_type()->mutable_lod_tensor()->set_lod_level(lod_level); + break; + case framework_proto::VarType::LOD_TENSOR_ARRAY: + desc_->mutable_type()->mutable_tensor_array()->set_lod_level(lod_level); + break; + default: + LOG(FATAL) + << "Setting 'lod_level' is not supported 
by the type of var %s." + << this->Name(); + } +} + +void VarDesc::SetLoDLevels(const std::vector &multiple_lod_level) { + if (multiple_lod_level.size() != GetTensorDescNum()) { + VLOG(3) << "WARNING: The number of given lod_levels(" + << multiple_lod_level.size() + << ") doesn't match the existing tensor number(" + << GetTensorDescNum() + << "). The Reader is going to be reinitialized."; + SetTensorDescNum(multiple_lod_level.size()); + } + switch (desc_->type().type()) { + case framework_proto::VarType::READER: { + size_t i = 0; + for (auto &lod_tensor : + *desc_->mutable_type()->mutable_reader()->mutable_lod_tensor()) { + lod_tensor.set_lod_level(multiple_lod_level[i++]); + } + } break; + default: + LOG(FATAL) + << "Setting 'lod_levels' is not supported by the type of var %s." + << this->Name(); + } +} + +int32_t VarDesc::GetLoDLevel() const { + switch (desc_->type().type()) { + case framework_proto::VarType::LOD_TENSOR: + return desc_->type().lod_tensor().lod_level(); + case framework_proto::VarType::LOD_TENSOR_ARRAY: + return desc_->type().tensor_array().lod_level(); + default: + LOG(FATAL) + << "Getting 'lod_level' is not supported by the type of var %s." + << this->Name(); + } + return 0; +} + +std::vector VarDesc::GetLoDLevels() const { + std::vector res; + switch (desc_->type().type()) { + case framework_proto::VarType::READER: + res.reserve(desc_->type().reader().lod_tensor_size()); + for (auto &lod_tensor : desc_->type().reader().lod_tensor()) { + res.push_back(lod_tensor.lod_level()); + } + return res; + break; + default: + LOG(FATAL) + << "Getting 'lod_levels' is not supported by the type of var %s." 
+ << this->Name(); + } + return std::vector(); +} + +const framework_proto::VarType::TensorDesc &VarDesc::tensor_desc() const { + CHECK(desc_->has_type()) << "The var's type hasn't been set."; + CHECK(desc_->type().has_type()) << "The var type hasn't been set."; + switch (desc_->type().type()) { + case framework_proto::VarType::SELECTED_ROWS: + return desc_->type().selected_rows(); + case framework_proto::VarType::LOD_TENSOR: + return desc_->type().lod_tensor().tensor(); + case framework_proto::VarType::LOD_TENSOR_ARRAY: + return desc_->type().tensor_array().tensor(); + default: + LOG(FATAL) + << "Getting 'tensor_desc' is not supported by the type of var %s." + << this->Name(); + } + return framework_proto::VarDesc().type().lod_tensor().tensor(); +} + +std::vector VarDesc::tensor_descs() + const { + CHECK(desc_->has_type()) << "The var type hasn't been set."; + std::vector res; + res.reserve(GetTensorDescNum()); + switch (desc_->type().type()) { + case framework_proto::VarType::READER: + for (const auto &lod_tensor : desc_->type().reader().lod_tensor()) { + res.push_back(lod_tensor.tensor()); + } + return res; + default: + LOG(FATAL) + << "Getting 'tensor_descs' is not supported by the type of var " + "%s." + << this->Name(); + } + return std::vector(); +} + +framework_proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() { + CHECK(desc_->has_type()) << "The var type hasn't been set."; + CHECK(desc_->type().has_type()) << "The var type hasn't been set."; + switch (desc_->type().type()) { + case framework_proto::VarType::SELECTED_ROWS: + return desc_->mutable_type()->mutable_selected_rows(); + case framework_proto::VarType::LOD_TENSOR: + return desc_->mutable_type()->mutable_lod_tensor()->mutable_tensor(); + case framework_proto::VarType::LOD_TENSOR_ARRAY: + return desc_->mutable_type()->mutable_tensor_array()->mutable_tensor(); + default: + LOG(FATAL) << "Getting 'mutable_tensor_desc' is not supported by the " + "type of var " + "%s." 
+ << this->Name(); + } + return nullptr; +} + +std::vector +VarDesc::mutable_tensor_descs() { + CHECK(desc_->has_type()) << "The var type hasn't been set."; + CHECK(desc_->type().has_type()) << "The var type hasn't been set."; + std::vector res; + res.reserve(GetTensorDescNum()); + switch (desc_->type().type()) { + case framework_proto::VarType::READER: + for (auto &lod_tensor : + *desc_->mutable_type()->mutable_reader()->mutable_lod_tensor()) { + res.push_back(lod_tensor.mutable_tensor()); + } + return res; + default: + LOG(FATAL) + << "Getting 'tensor_descs' is not supported by the type of var " + "%s." + << this->Name(); + } + return std::vector(); +} + +} // namespace infrt::paddle::pb diff --git a/paddle/infrt/paddle/pb/var_desc.h b/paddle/infrt/paddle/pb/var_desc.h new file mode 100644 index 0000000000..4cff5fdee0 --- /dev/null +++ b/paddle/infrt/paddle/pb/var_desc.h @@ -0,0 +1,124 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include + +#include +#include +#include + +#include "paddle/infrt/paddle/cpp/desc_api.h" +#include "paddle/infrt/paddle/framework.pb.h" + +namespace infrt::paddle::pb { +namespace framework_proto = ::paddle::framework::proto; + +// convert between std::vector and protobuf repeated. 
+template +inline std::vector RepeatedToVector( + const google::protobuf::RepeatedField &repeated_field) { + std::vector ret; + ret.reserve(repeated_field.size()); + std::copy( + repeated_field.begin(), repeated_field.end(), std::back_inserter(ret)); + return ret; +} + +template +inline void VectorToRepeated(const std::vector &vec, + RepeatedField *repeated_field) { + repeated_field->Clear(); + repeated_field->Reserve(vec.size()); + for (const auto &elem : vec) { + *repeated_field->Add() = elem; + } +} + +// Specialize vector. +template +inline void VectorToRepeated(const std::vector &vec, + RepeatedField *repeated_field) { + repeated_field->Clear(); + repeated_field->Reserve(vec.size()); + for (auto elem : vec) { + *repeated_field->Add() = elem; + } +} + +class VarDesc : public cpp::VarDescAPI { + public: + VarDesc() = delete; + + explicit VarDesc(framework_proto::VarDesc *desc) : desc_(desc) { + CHECK(desc_); + } + + ::paddle::framework::proto::VarDesc *Proto() { return desc_; } + const framework_proto::VarDesc &ReadonlyProto() const { return *desc_; } + + std::string Name() const override { return desc_->name(); } + + void SetName(std::string name) override { desc_->set_name(name); } + + void SetTensorDescNum(size_t num); + + size_t GetTensorDescNum() const; + + void SetShape(const std::vector &dims); + + void SetShapes(const std::vector> &multiple_dims); + + std::vector GetShape() const; + + std::vector> GetShapes() const; + + void SetDataType(VarDescAPI::VarDataType data_type); + + void SetDataTypes( + const std::vector &multiple_data_type); + + VarDescAPI::VarDataType GetDataType() const; + + std::vector GetDataTypes() const; + + void SetLoDLevel(int32_t lod_level); + + void SetLoDLevels(const std::vector &multiple_lod_level); + + int32_t GetLoDLevel() const; + + std::vector GetLoDLevels() const; + + VarDescAPI::Type GetType() const override; + + void SetType(VarDescAPI::Type type) override; + + bool Persistable() const override { return desc_->persistable(); 
} + + void SetPersistable(bool persistable) override { + desc_->set_persistable(persistable); + } + + private: + const framework_proto::VarType::TensorDesc &tensor_desc() const; + std::vector tensor_descs() const; + framework_proto::VarType::TensorDesc *mutable_tensor_desc(); + std::vector mutable_tensor_descs(); + + framework_proto::VarDesc *desc_; +}; + +} // namespace infrt::paddle::pb diff --git a/paddle/infrt/paddle/scope.cc b/paddle/infrt/paddle/scope.cc new file mode 100644 index 0000000000..d7bab9f749 --- /dev/null +++ b/paddle/infrt/paddle/scope.cc @@ -0,0 +1,44 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/infrt/paddle/scope.h" + +#include "paddle/infrt/common/common.h" + +namespace infrt { +namespace paddle { + +_Variable* Scope::FindVar(const std::string& name) const { + auto it = data_.find(name); + if (it != data_.end()) return it->second.get(); + return nullptr; +} + +Tensor Scope::GetTensor(const std::string& name) const { + CheckVarNameValid(name); + auto* var = FindVar(name); + CHECK(var) << "No variable called [" << name << "] found"; + return var->get(); +} + +std::vector Scope::var_names() const { + std::vector names; + for (auto& item : data_) { + names.push_back(item.first); + } + return names; +} + +} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/paddle/scope.h b/paddle/infrt/paddle/scope.h new file mode 100644 index 0000000000..4ebf846374 --- /dev/null +++ b/paddle/infrt/paddle/scope.h @@ -0,0 +1,68 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include +#include +#include + +#include "paddle/infrt/common/macros.h" +#include "paddle/infrt/paddle/tensor.h" +#include "paddle/infrt/support/variant.h" + +namespace infrt { +namespace paddle { + +using _Variable = Variant; + +struct _Tensor_; + +class Scope { + public: + static std::shared_ptr Create() { return std::make_shared(); } + + //! Get or create a variable. + template + _Variable* Var(const std::string& name); + + //! 
Find a variable, get null if not exists. + _Variable* FindVar(const std::string& name) const; + + Tensor GetTensor(const std::string& name) const; + + //! Get variable names. + std::vector var_names() const; + + Scope() = default; + + private: + std::unordered_map> data_; + + INFRT_DISALLOW_COPY_AND_ASSIGN(Scope); +}; + +template +_Variable* Scope::Var(const std::string& name) { + VLOG(4) << "Scope insert Var [" << name << "]"; + _Variable* x = FindVar(name); + if (x) return x; + auto* data = new _Variable(T()); + data_[name].reset(data); + return data; +} + +} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/paddle/tensor.cc b/paddle/infrt/paddle/tensor.cc new file mode 100644 index 0000000000..072701ee90 --- /dev/null +++ b/paddle/infrt/paddle/tensor.cc @@ -0,0 +1,19 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/infrt/paddle/tensor.h" + +namespace infrt { +namespace paddle {} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/paddle/tensor.h b/paddle/infrt/paddle/tensor.h new file mode 100644 index 0000000000..5c4458bb62 --- /dev/null +++ b/paddle/infrt/paddle/tensor.h @@ -0,0 +1,107 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include +#include + +#include "paddle/infrt/common/buffer.h" +#include "paddle/infrt/common/common.h" +#include "paddle/infrt/common/object.h" + +namespace infrt { +namespace paddle { +using common::Target; + +struct Shape { + using dim_t = int; + + Shape() = default; + explicit Shape(const std::vector& data) : data_(data) {} + + void SetData(const std::vector& data) { data_ = data; } + + const std::vector& data() const INFRT_RESULT_SHOULD_USE { + return data_; + } + std::vector& data() INFRT_RESULT_SHOULD_USE { return data_; } + size_t size() const INFRT_RESULT_SHOULD_USE { return data_.size(); } + uint32_t numel() const INFRT_RESULT_SHOULD_USE { + return std::accumulate( + data_.begin(), data_.end(), 1, [](dim_t a, dim_t b) { return a * b; }); + } + + private: + std::vector data_; +}; + +class _Tensor_ : public common::Object { + public: + _Tensor_() : buffer_(std::make_shared()) {} + + Shape& shape() { return shape_; } + + void Resize(const Shape& shape) { + shape_ = shape; + buffer_->data()->resize( + reinterpret_cast(shape.data().data()), + shape.size()); + } + + template + inline T* mutable_data(const Target& target) { + set_type(type_of()); + if (target == common::DefaultHostTarget()) { + int alignment = type_of().ElementOf().bits(); + buffer_->ResizeLazy(alignment, shape_.numel() * sizeof(T), target); + } else { + buffer_->ResizeLazy(shape_.numel() * sizeof(T), target); + } + return reinterpret_cast(buffer_->data()->memory); + } + + template + const T* data() const { + return 
reinterpret_cast(buffer_->data()->memory); + } + + const Type& type() { return type_; } + + void set_type(Type type) { type_ = type; } + const Type& type() const { return type_; } + + infrt_buffer_t* buffer() { return buffer_->data(); } + + const char* type_info() const override { return __type_info__; } + + private: + common::Type type_; + // A shared ptr to make it easier to share buffer between tensors. + std::shared_ptr buffer_; + Shape shape_; + + static constexpr const char* __type_info__ = "_frontend_tensor_"; +}; + +class Tensor : public Shared<_Tensor_> { + public: + Tensor() : Shared(new _Tensor_) {} + explicit Tensor(_Tensor_* x) : Shared(x) {} +}; + +} // namespace paddle +} // namespace infrt diff --git a/paddle/infrt/support/CMakeLists.txt b/paddle/infrt/support/CMakeLists.txt new file mode 100644 index 0000000000..9bcce6cab3 --- /dev/null +++ b/paddle/infrt/support/CMakeLists.txt @@ -0,0 +1 @@ +core_gather_headers() diff --git a/paddle/infrt/support/type_traits.h b/paddle/infrt/support/type_traits.h new file mode 100644 index 0000000000..341dabb7c1 --- /dev/null +++ b/paddle/infrt/support/type_traits.h @@ -0,0 +1,147 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file defines type traits related utilities. + +#pragma once + +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" + +namespace infrt { + +// Utility template for tag dispatching. 
+template +struct TypeTag {}; + +// This is the equivalent of std::void_t in C++17. +template +struct make_void { + typedef void type; +}; +template +using void_t = typename make_void::type; + +// The same as std::disjunction in C++17. +template +struct disjunction : std::false_type {}; +template +struct disjunction : B1 {}; +template +struct disjunction + : std::conditional_t> {}; + +// Check whether T may be a base class. +template +using MaybeBase = + llvm::conjunction, llvm::negation>>; + +// Find the index of a type in a tuple. +// +// Example: +// using Tuple = std::tuple; +// static_assert(TupleIndexOf::value == 0); +// static_assert(TupleIndexOf::value == 2); +template +struct TupleIndexOf; + +template +struct TupleIndexOf> + : std::integral_constant {}; + +template +struct TupleIndexOf> + : std::integral_constant>::value> { +}; + +template +struct TupleHasType; + +template +struct TupleHasType> + : disjunction...> {}; + +// The detector pattern in C++ that can be used for checking whether a type has +// a specific property, e.g. whether an internal type is present or whether a +// particular operation is valid. +// +// Sample usage: +// +// struct Foo { +// using difference_type = int; +// int get(); +// }; +// struct Bar {}; +// +// // Check whether a type T has an internal difference_type. +// template +// using diff_t = typename T::difference_type; +// +// static_assert(is_detected_v, "Foo has difference_type"); +// static_assert(!is_detected_v, "Bar has no difference_type"); +// +// // Check whether a type T has a get() member function. +// template +// using has_get_t = decltype(std::declval().get()); +// +// static_assert(is_detected_v, "Foo has get()"); +// static_assert(!is_detected_v, "Bar has no get()"); +// +// See https://en.cppreference.com/w/cpp/experimental/is_detected for details. + +namespace internal { + +// nonesuch is a class type used to indicate detection failure. 
+struct nonesuch { + ~nonesuch() = delete; + nonesuch(nonesuch const&) = delete; + void operator=(nonesuch const&) = delete; +}; + +template class Op, + class... Args> +struct detector : std::false_type { + using value_t = std::false_type; + using type = Default; +}; + +template class Op, class... Args> +struct detector>, Op, Args...> { + using value_t = std::true_type; + using type = Op; +}; + +} // namespace internal + +template