未验证 提交 70dea138 编写于 作者: Y Yan Chunwei 提交者: GitHub

introduce INF-RT (#37669)

* add infrt code

refined with Paddle's code style.

* rename CinnRtConfig to InfRtConfig

* rename CinnRt to InfRt of some code

* rename CINNRT to INFRT

* remove unnecessary code

* replace CINN to INFRT in the source code

* replace all "cinn" in code to "infrt"

* remove some const_cast
上级 890bd626
......@@ -216,6 +216,7 @@ option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ${WITH_DISTRIBUTE}
# Optional build features. Each entry defines a user-settable cache switch
# with the given help text and default value.
# NOTE(review): option() defines a boolean switch, but the help text of
# SANITIZER_TYPE lists string choices -- confirm how that value is consumed.
option(SANITIZER_TYPE "Choose the type of sanitizer, options are: Address, Leak, Memory, Thread, Undefined" OFF)
option(WITH_LITE "Compile Paddle Fluid with Lite Engine" OFF)
option(WITH_CINN "Compile PaddlePaddle with CINN" OFF)
# INFRT is off by default; when enabled it pulls in the LLVM/MLIR dependency.
option(WITH_INFRT "Compile PaddlePaddle with INFRT" OFF)
option(WITH_NCCL "Compile PaddlePaddle with NCCL support" ON)
option(WITH_RCCL "Compile PaddlePaddle with RCCL support" ON)
option(WITH_XPU_BKCL "Compile PaddlePaddle with BAIDU KUNLUN XPU BKCL" OFF)
......
# Locate (or download) the prebuilt LLVM/MLIR package used by INFRT and load
# the LLVM and MLIR CMake helper modules.
#
# Inputs:
#   THIRD_PARTY_PATH - root directory for third-party downloads and installs.
#   LLVM_PATH        - optional; when set, that LLVM install is used and
#                      nothing is downloaded.
# Outputs:
#   LLVM_PATH (when it was not provided), LLVM_DIR, MLIR_DIR, plus everything
#   the LLVM and MLIR package configs define.
include(FetchContent)

set(LLVM_DOWNLOAD_URL https://paddle-inference-dist.bj.bcebos.com/CINN/llvm11.tar.gz)
set(LLVM_MD5 39d32b6be466781dddf5869318dcba53)

set(FETCHCONTENT_BASE_DIR ${THIRD_PARTY_PATH}/llvm)
set(FETCHCONTENT_QUIET OFF)
FetchContent_Declare(external_llvm
  URL ${LLVM_DOWNLOAD_URL}
  URL_MD5 ${LLVM_MD5}
  PREFIX ${THIRD_PARTY_PATH}/llvm
  SOURCE_DIR ${THIRD_PARTY_PATH}/install/llvm
)

if (NOT LLVM_PATH)
  # No user-provided LLVM: download the archive (once) and point at it.
  FetchContent_GetProperties(external_llvm)
  if (NOT external_llvm_POPULATED)
    FetchContent_Populate(external_llvm)
  endif()
  set(LLVM_PATH ${THIRD_PARTY_PATH}/install/llvm)
  set(LLVM_DIR ${THIRD_PARTY_PATH}/install/llvm/lib/cmake/llvm)
  set(MLIR_DIR ${THIRD_PARTY_PATH}/install/llvm/lib/cmake/mlir)
else ()
  set(LLVM_DIR ${LLVM_PATH}/lib/cmake/llvm)
  set(MLIR_DIR ${LLVM_PATH}/lib/cmake/mlir)
endif()

# The expansion is quoted so that the comparison stays well-formed when the
# variable is empty and its value is never re-interpreted as a variable name.
if ("${CMAKE_CXX_COMPILER}" STREQUAL "clang++")
  set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -stdlib=libc++ -lc++abi")
endif()

message(STATUS "set LLVM_DIR: ${LLVM_DIR}")
message(STATUS "set MLIR_DIR: ${MLIR_DIR}")
find_package(LLVM REQUIRED CONFIG HINTS ${LLVM_DIR})
find_package(MLIR REQUIRED CONFIG HINTS ${MLIR_DIR})
find_package(ZLIB REQUIRED)

# Make the helper modules shipped with LLVM and MLIR visible to include().
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
include(AddLLVM)
include(TableGen)
include(AddMLIR)
include_directories(${LLVM_INCLUDE_DIRS})

message(STATUS "Found MLIR: ${MLIR_DIR}")
message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}")
message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
# To build with MLIR, LLVM was built from source with the following flags:
#[==[
cmake -G Ninja ../llvm \
  -DLLVM_ENABLE_PROJECTS="mlir;clang" \
  -DLLVM_BUILD_EXAMPLES=OFF \
  -DLLVM_TARGETS_TO_BUILD="X86" \
  -DCMAKE_BUILD_TYPE=Release \
  -DLLVM_ENABLE_ASSERTIONS=ON \
  -DLLVM_ENABLE_ZLIB=OFF \
  -DLLVM_ENABLE_RTTI=ON \
#]==]
# The matched llvm-project version is f9dc2b7079350d0fed3bb3775f496b90483c9e42 (currently a temporary commit)

# Compile definitions required by the LLVM headers. They are added exactly
# once so that the definitions are not repeated on every compile line.
add_definitions(${LLVM_DEFINITIONS})

# Resolve the LLVM components used by INFRT to concrete library names.
llvm_map_components_to_libnames(llvm_libs Support Core irreader
        X86 executionengine orcjit mcjit all codegen)
message(STATUS "LLVM libs: ${llvm_libs}")

# All MLIR libraries registered in the MLIR_ALL_LIBS global property.
get_property(mlir_libs GLOBAL PROPERTY MLIR_ALL_LIBS)
message(STATUS "MLIR libs: ${mlir_libs}")

# The minimum needed libraries for MLIR IR parse and transform.
set(MLIR_IR_LIBS MLIRAnalysis MLIRStandardOps MLIRPass MLIRParser MLIRDialect MLIRIR MLIROptLib)
# Generate the op declaration/definition includes for a TableGen file.
#
# Arguments:
#   td_base - name of a xxx.td file (without the .td suffix).
#   DIALECT - optional dialect name; when given, the dialect declarations
#             (${td_base}_dialect.hpp.inc) are generated as well.
#
# Creates the targets ${td_base}_IncGen and ${td_base}_inc.
function(mlir_tablegen_on td_base)
  # Every keyword list is set explicitly: a function can read the variables of
  # its caller, so a list left undefined here would silently pick up whatever
  # the calling scope holds under the same name.
  set(options)
  set(oneValueArgs DIALECT)
  set(multiValueArgs)
  cmake_parse_arguments(mlir_tablegen_on "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

  set(LLVM_TARGET_DEFINITIONS ${td_base}.td)
  mlir_tablegen(${td_base}.hpp.inc -gen-op-decls)
  mlir_tablegen(${td_base}.cpp.inc -gen-op-defs)
  if (mlir_tablegen_on_DIALECT)
    mlir_tablegen(${td_base}_dialect.hpp.inc --gen-dialect-decls -dialect=${mlir_tablegen_on_DIALECT})
  endif()
  add_public_tablegen_target(${td_base}_IncGen)
  add_custom_target(${td_base}_inc DEPENDS ${td_base}_IncGen)
endfunction()
# Generate the rewriter include (${td_base}.hpp.inc) from ${td_base}.td.
#
# Arguments:
#   td_base - name of the .td file without its suffix.
#
# Creates the targets ${td_base}_IncGen and ${td_base}_inc.
# NOTE(review): the extra include directory is rooted at
# ${CMAKE_SOURCE_DIR}/infrt -- confirm that this matches the source layout.
function(mlir_add_rewriter td_base)
  set(rewriter_include_flag "-I${CMAKE_SOURCE_DIR}/infrt/dialect/pass")
  set(incgen_target ${td_base}_IncGen)

  set(LLVM_TARGET_DEFINITIONS ${td_base}.td)
  mlir_tablegen(${td_base}.hpp.inc -gen-rewriters "${rewriter_include_flag}")
  add_public_tablegen_target(${incgen_target})
  add_custom_target(${td_base}_inc DEPENDS ${incgen_target})
endfunction()
# Register a test that runs an mlir script through infrt-exec and verifies the
# output with FileCheck, using the script itself as the check file.
# @name: name of the test
# @script: path to the mlir script file, relative to the current source dir
function(infrt_exec_check name script)
  set(script_path "${CMAKE_CURRENT_SOURCE_DIR}/${script}")
  set(exec_tool "${CMAKE_BINARY_DIR}/infrt/host_context/infrt-exec")
  set(file_check "${LLVM_PATH}/bin/FileCheck")
  # The pipeline is handed to `sh -c` as a single string.
  add_test(NAME ${name}
    COMMAND sh -c "${exec_tool} -i ${script_path}| ${file_check} ${script_path}")
endfunction()
......@@ -391,6 +391,11 @@ if (WIN32)
list(APPEND third_party_deps extern_dirent)
endif (WIN32)
# INFRT needs LLVM/MLIR: load the external/llvm module and add external_llvm
# to the list of third-party dependencies.
# NOTE(review): external_llvm is the name passed to FetchContent_Declare in
# that module -- confirm it is a valid entry for third_party_deps.
if (WITH_INFRT)
include(external/llvm)
list(APPEND third_party_deps external_llvm)
endif()
if (WITH_IPU)
include(external/poplar)
list(APPEND third_party_deps extern_poplar)
......
......@@ -2,4 +2,5 @@ add_subdirectory(scripts)
# Sub-projects of the paddle directory. `infrt` is always entered here; its
# own CMakeLists.txt returns immediately when WITH_INFRT is off.
add_subdirectory(testing)
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests CACHE INTERNAL "python tests directory")
add_subdirectory(pten)
add_subdirectory(infrt)
add_subdirectory(fluid)
# INFRT is optional: stop processing this directory unless it was enabled.
if (NOT WITH_INFRT)
return()
endif()
# Start every configure run with an empty source list; gather_srcs() appends
# to this cache entry from the sub-directories.
set(infrt_src CACHE INTERNAL "" FORCE)
# Gather headers for library publish.
# Appends every *.h file matched by the glob below, as a path relative to the
# top-level source directory, to the `core_includes` cache entry.
function(core_gather_headers)
  file(GLOB local_headers
       LIST_DIRECTORIES false
       RELATIVE ${CMAKE_SOURCE_DIR}
       *.h)
  foreach(hdr IN LISTS local_headers)
    set(core_includes "${core_includes};${hdr}" CACHE INTERNAL "")
  endforeach()
endfunction()
# Append source files of the calling directory to a global source list.
#
# Arguments:
#   SRC_GROUP - name of the cache entry that accumulates the sources.
#   SRCS      - file names relative to the current source directory; each one
#               is stored with ${CMAKE_CURRENT_SOURCE_DIR} prepended.
function(gather_srcs SRC_GROUP)
  cmake_parse_arguments(prefix "" "" "SRCS" ${ARGN})
  foreach(src_file IN LISTS prefix_SRCS)
    set(${SRC_GROUP} "${${SRC_GROUP}};${CMAKE_CURRENT_SOURCE_DIR}/${src_file}" CACHE INTERNAL "")
  endforeach()
endfunction()
# This method is similar to the global cc_test, but discards the huge amount of
# default dependencies that are not needed by INFRT.
#
# Arguments:
#   TARGET_NAME - name of the test executable and of the registered test.
#   SRCS        - source files of the test.
#   DEPS        - libraries to link and targets to build first.
#   ARGS        - command-line arguments passed to the test executable.
#   SERIAL      - flag; when present the test is marked RUN_SERIAL.
#
# Does nothing unless WITH_TESTING is enabled.
function(cc_test_tiny TARGET_NAME)
  if(NOT WITH_TESTING)
    return()
  endif()

  set(options SERIAL)
  set(oneValueArgs "")
  set(multiValueArgs SRCS DEPS ARGS)
  cmake_parse_arguments(cc_test_tiny "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

  add_executable(${TARGET_NAME} ${cc_test_tiny_SRCS})
  get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
  target_link_libraries(${TARGET_NAME} ${cc_test_tiny_DEPS} ${os_dependency_modules} infrt_gtest_main gtest )
  add_dependencies(${TARGET_NAME} ${cc_test_tiny_DEPS} infrt_gtest_main gtest extern_gtest)

  # ARGS is expanded unquoted on purpose: each list element becomes its own
  # command-line argument, and an empty list adds no argument at all. A quoted
  # expansion would hand the whole list to the test as one "a;b;c" string.
  add_test(NAME ${TARGET_NAME}
           COMMAND ${TARGET_NAME} ${cc_test_tiny_ARGS}
           WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
  if (cc_test_tiny_SERIAL)
    set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
  endif()
endfunction()
# gtest entry point shared by the tests created with cc_test_tiny().
if (WITH_TESTING)
cc_library(infrt_gtest_main SRCS gtest_main.cc DEPS gtest glog gflags)
endif()
# Sub-directories of INFRT. They are entered before the `infrt` library is
# declared below, which consumes the accumulated ${infrt_src} list.
add_subdirectory(api)
add_subdirectory(common)
add_subdirectory(dialect)
add_subdirectory(host_context)
add_subdirectory(kernel)
add_subdirectory(tensor)
add_subdirectory(support)
add_subdirectory(external_kernels)
add_subdirectory(paddle)
# MLIR td file generations
# Targets listed here are made dependencies of the `infrt` library below.
set(infrt_mlir_incs
ops_inc
basic_kernels_inc
test_kernels_inc
infrt_base_inc
tensor_shape_inc
dense_tensor_inc
pd_ops_inc
rewrite_inc
)
message(STATUS "infrt srcs:\n${infrt_src}")
# The INFRT library itself, built from every source collected in infrt_src.
cc_library(infrt SRCS ${infrt_src} DEPS glog ${mlir_libs} paddle_framework_proto)
# Build the listed *_inc targets before compiling the library.
add_dependencies(infrt ${infrt_mlir_incs})
# Publish the headers of this directory and add its sources to the INFRT
# source list.
core_gather_headers()
gather_srcs(infrt_src SRCS
infrt_api.cc
)
# Disabled temporarily because the mkldnn used by external-kernel is outdated.
# cc_test(test_infrt_api SRCS infrt_api_test.cc DEPS infrt ${MLIR_IR_LIBS})
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/api/infrt_api.h"
#include <llvm/ADT/SmallVector.h>
#include <llvm/Support/DynamicLibrary.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/Parser.h>
#include <unordered_map>
#include <vector>
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/dense_tensor.h"
#include "paddle/infrt/dialect/mlir_loader.h"
#include "paddle/infrt/host_context/core_runtime.h"
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/mlir_function_executable.h"
#include "paddle/infrt/host_context/mlir_to_runtime_translate.h"
#include "paddle/infrt/host_context/op_executable.h"
#include "paddle/infrt/host_context/value.h"
#include "paddle/infrt/kernel/basic_kernels.h"
#include "paddle/infrt/kernel/control_flow_kernels.h"
#include "paddle/infrt/kernel/tensor_kernels.h"
#include "paddle/infrt/kernel/tensor_shape_kernels.h"
#include "paddle/infrt/kernel/test_kernels.h"
#include "paddle/infrt/tensor/tensor_map.h"
// Names from the runtime and tensor namespaces that this file uses
// unqualified. Each directive is listed once.
using namespace infrt::host_context;  // NOLINT
using namespace infrt::tensor;        // NOLINT
using infrt::dt::TensorMapType;       // NOLINT
using infrt::dt::TensorType;          // NOLINT
namespace infrt {
template <typename T>
std::string DumpToString(T& op) { // NOLINT
std::string buffer;
llvm::raw_string_ostream os(buffer);
op.print(os);
os.flush();
return buffer;
}
// State of MlirToRuntimeTranslator that PredictExecutor accesses through
// `impl_`.
// NOTE(review): this defines a nested type of a class declared in another
// header; if the translator's own source file also defines Impl, the two
// definitions must stay identical -- confirm.
struct MlirToRuntimeTranslator::Impl {
mlir::ModuleOp module;
// The runtime for a function call.
CoreRuntimeBuilder* runtime{};
// The op currently being processed. The translator processes the ops one by
// one and updates `cur_op` each time it moves on to the next op.
OpExecutableBuilder* cur_op{};
// Name of the function currently being processed.
std::string cur_func_name;
// Name to function definitions.
std::unordered_map<std::string, mlir::FuncOp> func_defs;
// Map from an operation to its results.
std::unordered_map<const mlir::Operation*, std::vector<ValueRef>> op_results;
// Map from an mlir value to its runtime value.
llvm::DenseMap<mlir::Value, ValueRef> value_map;
};
/**
* Execute the mlir program in predict mode.
*/
class PredictExecutor : public MlirToRuntimeTranslator {
public:
CoreRuntimeBuilder core_runtime;
PredictExecutor(mlir::ModuleOp module,
KernelRegistry* registry,
TensorMap* map)
: MlirToRuntimeTranslator(module, &core_runtime),
core_runtime(registry),
registry_(registry) {
CHECK(registry_);
Init(map);
}
void Run() {
auto arguments = llvm::makeArrayRef(arguments_);
auto results = llvm::makeMutableArrayRef(results_.begin(), results_.size());
function_executable_->Execute(arguments, results);
}
int GetInputNum() { return inputs_.size(); }
DenseHostTensor* GetInput(int i) { return inputs_[i]; }
int GetOutputNum() { return outputs_.size(); }
DenseHostTensor* GetOutput(int i) { return outputs_[i]; }
private:
void Init(TensorMap* map) {
EmitFunctions();
llvm::Optional<mlir::FuncOp> predict_func_ = llvm::None;
for (auto func_op : impl_->module.getOps<mlir::FuncOp>()) {
if (func_op.getName().str() != "predict") continue;
predict_func_ = func_op;
break;
}
if (!predict_func_) {
std::cout << "ERROR: init failed, no predict function found in mlir."
<< std::endl;
return;
}
auto& predict_func = predict_func_.getValue();
function_executable_ =
new MlirFunctionExecutable(predict_func, registry_, impl_->func_defs);
// process parammeters
for (size_t i = 0; i < predict_func.getNumArguments(); ++i) {
auto arg = predict_func.getArgument(i);
auto type = arg.getType();
// this param is TensorMap
if (type.isa<TensorMapType>()) {
auto* value = new host_context::Value(std::move(*map));
arguments_.push_back(value);
AddValue(predict_func.getArgument(i), value);
} else {
// this param is an input Tensor
auto dht = DenseHostTensor();
auto* value = new host_context::Value(std::move(dht));
arguments_.push_back(value);
inputs_.push_back(&(value->get<DenseHostTensor>()));
}
}
// process results
auto& last_op = predict_func.front().back();
if (last_op.getName().getStringRef() == "infrt.return") {
for (size_t i = 0; i < last_op.getNumOperands(); ++i) {
auto* value = AddValue(mlir::Value(last_op.getOperand(i)));
results_.push_back(ValueRef(value));
outputs_.push_back(&(value->get<DenseHostTensor>()));
}
}
}
protected:
std::unordered_map<std::string, mlir::FuncOp> func_def_table;
void EmitFunction(mlir::FuncOp op) override {
CHECK(!impl_->func_defs.count(op.getName().str()))
<< "Duplicate function defition found for function ["
<< op.getName().str();
impl_->func_defs.emplace(op.getName().str(), op);
}
private:
KernelRegistry* registry_{};
MlirFunctionExecutable* function_executable_;
llvm::SmallVector<DenseHostTensor*, 1> inputs_;
llvm::SmallVector<host_context::Value*, 2> arguments_;
llvm::SmallVector<DenseHostTensor*, 1> outputs_;
llvm::SmallVector<ValueRef, 1> results_;
};
std::shared_ptr<InfRtPredictor> CreateInfRtPredictor(
const InfRtConfig& config) {
auto x = std::make_shared<InfRtPredictor>();
x->Init(config);
return x;
}
// Private state of InfRtPredictor.
struct InfRtPredictor::Impl {
// Owns the loaded mlir module.
mlir::OwningModuleRef module_ref;
// Executor built on top of the module. It is declared after `module_ref`,
// so it is destroyed before the module it was created from.
std::unique_ptr<PredictExecutor> executor;
};
InfRtPredictor::InfRtPredictor() : impl_(new Impl) {}
InfRtPredictor::~InfRtPredictor() {}
void InfRtPredictor::Run() { impl_->executor->Run(); }
int InfRtPredictor::Init(const InfRtConfig& config) {
mlir::MLIRContext* context = infrt::Global::getMLIRContext();
auto module_ref = dialect::LoadMlirFile(config.mlir_path(), context);
KernelRegistry* registry = new KernelRegistry();
kernel::RegisterBasicKernels(registry);
kernel::RegisterTestKernels(registry);
kernel::RegisterTensorShapeKernels(registry);
kernel::RegisterTensorKernels(registry);
kernel::RegisterControlFlowKernels(registry);
impl_->module_ref = std::move(module_ref);
// load extra shared library
for (const std::string& lib_path : config.shared_libs()) {
std::string err;
llvm::sys::DynamicLibrary dynLib =
llvm::sys::DynamicLibrary::getPermanentLibrary(lib_path.c_str(), &err);
if (!dynLib.isValid()) {
llvm::errs() << "Load shared library failed. Error: " << err << "\n";
return 1;
}
if (auto reg_sym = dynLib.SearchForAddressOfSymbol("RegisterKernels")) {
auto reg_func = reinterpret_cast<void (*)(KernelRegistry*)>(reg_sym);
reg_func(registry);
} else {
llvm::outs() << "Symbol \"RegisterKernels\" not found in \"" << lib_path
<< "\". Skip.\n";
}
}
// Load params
TensorMap* tensor_map = LoadParams(config.model_dir());
// Create PredictExecutor
impl_->executor.reset(
new PredictExecutor(impl_->module_ref.get(), registry, tensor_map));
return 0;
}
// Number of input tensors known to the executor.
int InfRtPredictor::GetInputNum() {
  auto& executor = impl_->executor;
  return executor->GetInputNum();
}

// The i-th input tensor, forwarded from the executor.
DenseHostTensor* InfRtPredictor::GetInput(int i) {
  auto& executor = impl_->executor;
  return executor->GetInput(i);
}

// Number of output tensors known to the executor.
int InfRtPredictor::GetOutputNum() {
  auto& executor = impl_->executor;
  return executor->GetOutputNum();
}

// The i-th output tensor, forwarded from the executor.
DenseHostTensor* InfRtPredictor::GetOutput(int i) {
  auto& executor = impl_->executor;
  return executor->GetOutput(i);
}
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/infrt/tensor/dense_host_tensor.h"
namespace infrt {
// Settings read by InfRtPredictor::Init: the model directory, the path of the
// mlir file and the shared libraries to load.
class InfRtConfig {
 public:
  InfRtConfig() = default;
  virtual ~InfRtConfig() = default;

  // Model directory.
  const std::string& model_dir() const { return model_dir_; }
  void set_model_dir(const std::string& model_dir) { model_dir_ = model_dir; }

  // Path of the mlir file.
  const std::string& mlir_path() const { return mlir_path_; }
  void set_mlir_path(const std::string& mlir_path) { mlir_path_ = mlir_path; }

  // Paths of extra shared libraries.
  const std::vector<std::string>& shared_libs() const { return shared_libs_; }
  void set_shared_libs(const std::vector<std::string>& shared_libs) {
    shared_libs_ = shared_libs;
  }

 private:
  std::string model_dir_;
  std::string mlir_path_;
  std::vector<std::string> shared_libs_;
};
// Public entry point of the INFRT inference API. The implementation details
// are hidden behind the `Impl` structure.
class InfRtPredictor {
public:
InfRtPredictor();
~InfRtPredictor();
// Runs the loaded program once.
void Run();
// Initializes the predictor from `config`; returns 0 on success.
int Init(const InfRtConfig& config);
// Number of input tensors.
int GetInputNum();
// The i-th input tensor.
tensor::DenseHostTensor* GetInput(int i);
// Number of output tensors.
int GetOutputNum();
// The i-th output tensor.
tensor::DenseHostTensor* GetOutput(int i);
protected:
struct Impl;
std::unique_ptr<Impl> impl_;
};
// Creates a predictor and calls Init(config) on it.
std::shared_ptr<InfRtPredictor> CreateInfRtPredictor(const InfRtConfig& config);
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/api/infrt_api.h"
#include <gtest/gtest.h>
#include <iostream>
#include <vector>
#include "llvm/Support/raw_ostream.h"
#include "paddle/infrt/common/buffer.h"
#include "paddle/infrt/common/dtype.h"
using infrt::InfRtConfig;
using infrt::InfRtPredictor;
using infrt::CreateInfRtPredictor;
namespace infrt {
// End-to-end check of the predictor: load a model and an external kernel
// library, feed a 3x3 tensor of ones and compare the output with reference
// values.
TEST(InfRtPredictor, predictor) {
  std::vector<std::string> shared_libs;
  shared_libs.push_back("../../paddle/libexternal_kernels.so");

  InfRtConfig config;

  // set external shared libraries that contain kernels.
  config.set_shared_libs(shared_libs);
  // set model dir
  config.set_model_dir("../../paddle/paddle_1.8_fc_model");
  // set mlir path
  config.set_mlir_path("../../../infrt/dialect/mlir_tests/tensor_map.mlir");

  std::shared_ptr<InfRtPredictor> predictor = CreateInfRtPredictor(config);
  // Fail with a readable message instead of crashing when setup went wrong.
  ASSERT_TRUE(predictor != nullptr);
  ASSERT_GT(predictor->GetInputNum(), 0);

  auto* input = predictor->GetInput(0);
  ASSERT_TRUE(input != nullptr);
  std::vector<int64_t> shape = {3, 3};
  input->Init(shape, infrt::GetDType<float>());
  llvm::outs() << input->shape() << "\n";

  // init input tensor
  auto* input_data = reinterpret_cast<float*>(input->buffer()->data()->memory);
  for (int i = 0; i < input->shape().GetNumElements(); i++) {
    input_data[i] = 1.0f;
  }

  predictor->Run();

  // get and print output tensor
  ASSERT_GT(predictor->GetOutputNum(), 0);
  auto* output = predictor->GetOutput(0);
  ASSERT_TRUE(output != nullptr);
  auto* output_data =
      reinterpret_cast<float*>(output->buffer()->data()->memory);

  std::vector<float> ans = {0.428458,
                            0.244493,
                            0.572342,
                            0.572008,
                            0.509771,
                            0.495599,
                            0.651287,
                            0.326426,
                            0.404649};

  // Compare sizes in one unsigned type to avoid a signed/unsigned comparison.
  ASSERT_EQ(static_cast<size_t>(output->shape().GetNumElements()), ans.size());
  for (size_t i = 0; i < ans.size(); ++i) {
    ASSERT_NEAR(output_data[i], ans[i], 0.000001);
  }
}
} // namespace infrt
# Publish the headers of this directory, plus dtype.def, which is not matched
# by the *.h pattern used by core_gather_headers().
# NOTE(review): the dtype.def entry is written as infrt/common/..., without a
# leading paddle/ component -- confirm it matches the paths of the other
# entries.
core_gather_headers()
set(core_includes "${core_includes};infrt/common/dtype.def" CACHE INTERNAL "")
# Sources of this directory that go into the INFRT library.
gather_srcs(infrt_src SRCS
dtype.cc
global.cc
target.cc
type.cc
shared.cc
object.cc
string.cc
buffer.cc
memory.cc
)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/common/buffer.h"
#include <stdarg.h>
#include <stdio.h>
#include <cmath>
namespace infrt {
// Re-allocates the buffer so that it holds exactly `size` bytes. Memory held
// before the call is released, not copied.
void Buffer::Resize(uint32_t size) {
  if (size_ > 0) {
    Free();
    // Forget the released pointer: when `size` is 0 no new allocation follows,
    // and a stale pointer would be released a second time by a later Free().
    data_.memory = nullptr;
    size_ = 0;
  }

  if (size_ != size) {
    data_.memory = reinterpret_cast<uint8_t*>(Malloc(size));
    size_ = size;
  }
}
// Re-allocates the buffer so that it holds exactly `size` bytes, obtained
// through the aligned allocator with the given `alignment`. Memory held before
// the call is released, not copied.
void Buffer::Resize(uint32_t alignment, uint32_t size) {
  if (size_ > 0) {
    Free();
    // Forget the released pointer: when `size` is 0 no new allocation follows,
    // and a stale pointer would be released a second time by a later Free().
    data_.memory = nullptr;
    size_ = 0;
  }

  if (size_ != size) {
    data_.memory = reinterpret_cast<uint8_t*>(AlignedAlloc(alignment, size));
    size_ = size;
  }
}
// Records `target` as the place of this buffer and caches the memory manager
// registered for the target's architecture.
void Buffer::SetTarget(const infrt::common::Target& target) {
  target_ = target;
  auto& managers = MemoryManager::Global();
  memory_mng_cache_ = managers.RetrieveSafely(target_.arch);
}
// Grows the buffer to `size` bytes; a request that already fits is a no-op.
void Buffer::ResizeLazy(uint32_t size) {
  const bool needs_growth = size > size_;
  if (needs_growth) {
    Resize(size);
  }
}
// Grows the buffer to `size` bytes using the aligned allocator; a request that
// already fits is a no-op.
void Buffer::ResizeLazy(uint32_t alignment, uint32_t size) {
  const bool needs_growth = size > size_;
  if (needs_growth) {
    Resize(alignment, size);
  }
}
// Resizes the buffer to exactly `size` bytes on `target`. When the
// architecture changes, the old memory is released through the old manager
// before the target is switched.
void Buffer::Resize(uint32_t size, const infrt::common::Target& target) {
  if (target.arch != target_.arch) {
    Free();
    // Drop the stale pointer and size, otherwise the Resize() below would
    // release the same memory again, this time through the new manager.
    data_.memory = nullptr;
    size_ = 0;
    SetTarget(target);
  }
  Resize(size);
}
// Resizes the buffer to exactly `size` aligned bytes on `target`. When the
// architecture changes, the old memory is released through the old manager
// before the target is switched.
void Buffer::Resize(uint32_t alignment,
                    uint32_t size,
                    const infrt::common::Target& target) {
  if (target.arch != target_.arch) {
    Free();
    // Drop the stale pointer and size, otherwise the Resize() below would
    // release the same memory again, this time through the new manager.
    data_.memory = nullptr;
    size_ = 0;
    SetTarget(target);
  }
  Resize(alignment, size);
}
// Lazily resizes the buffer to `size` bytes on `target`. When the architecture
// changes, the old memory is released through the old manager before the
// target is switched.
void Buffer::ResizeLazy(uint32_t size, const infrt::common::Target& target) {
  if (target.arch != target_.arch) {
    Free();
    // Drop the stale pointer and size, otherwise a request that fits into the
    // old size would keep the released memory as the buffer content.
    data_.memory = nullptr;
    size_ = 0;
    SetTarget(target);
  }
  ResizeLazy(size);
}
// Lazily resizes the buffer to `size` aligned bytes on `target`. When the
// architecture changes, the old memory is released through the old manager
// before the target is switched.
void Buffer::ResizeLazy(uint32_t alignment,
                        uint32_t size,
                        const infrt::common::Target& target) {
  if (target.arch != target_.arch) {
    Free();
    // Drop the stale pointer and size, otherwise a request that fits into the
    // old size would keep the released memory as the buffer content.
    data_.memory = nullptr;
    size_ = 0;
    SetTarget(target);
  }
  ResizeLazy(alignment, size);
}
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <memory>
#include "paddle/infrt/common/macros.h"
#include "paddle/infrt/common/memory.h"
#include "paddle/infrt/common/target.h"
namespace infrt {
#ifdef __cplusplus
extern "C" {
#endif
//! Marks a function as inline and asks the compiler to always inline it.
#define INFRT_ALWAYS_INLINE __attribute__((always_inline)) inline
//! Code for the primitive types supported in INFRT.
typedef enum infrt_type_code_t {
infrt_type_unk = -1, //! Unknown type
infrt_type_int = 0, //! signed int
infrt_type_uint = 1, //! unsigned int
infrt_type_float = 2, //! floating point
infrt_type_handle = 3 //! void*
} infrt_type_code_t;
//! Alignment attribute; it can be pre-defined by the includer to override it.
#ifndef INFRT_ATTRIBUTE_ALIGN
#define INFRT_ATTRIBUTE_ALIGN(n) __attribute__((aligned(n)))
#endif
/**
* A runtime tag for a type in the INFRT system: a type code plus the number of
* bits, the number of vector lanes and the pointer depth.
*/
typedef struct infrt_type_t {
#if __cplusplus >= 201103L
INFRT_ATTRIBUTE_ALIGN(1) infrt_type_code_t code;
#else
uint8_t code;
#endif
//! Number of bits.
uint8_t bits;
//! Number of elements in a vector, 1 for scalar.
uint16_t lanes;
//! Number of '*', e.g. for `float*`, the num_asterisks is 1, `float**` it is
//! 2.
uint8_t num_asterisks{0};
#ifdef __cplusplus
//! Default: a signed-int code with 0 bits and 0 lanes.
INFRT_ALWAYS_INLINE infrt_type_t()
: code(infrt_type_int), bits(0), lanes(0) {}
INFRT_ALWAYS_INLINE infrt_type_t(infrt_type_code_t code,
uint8_t bits,
uint16_t lanes = 1,
uint8_t num_asterisks = 0)
: code(code), bits(bits), lanes(lanes), num_asterisks(num_asterisks) {}
//! Compares code, bits and lanes.
//! NOTE(review): num_asterisks does not take part in the comparison, so a
//! type and a pointer to it compare equal -- confirm this is intended.
INFRT_ALWAYS_INLINE bool operator==(const infrt_type_t& other) const {
return code == other.code && bits == other.bits && lanes == other.lanes;
}
INFRT_ALWAYS_INLINE bool operator!=(const infrt_type_t& other) const {
return !(*this == other);
}
//! Number of bytes needed to hold `bits` bits, rounded up.
INFRT_ALWAYS_INLINE uint16_t bytes() const { return (bits + 7) / 8; }
#endif // __cplusplus
} infrt_type_t;
//! The size of one dimension. Because of the polyhedral representation there
//! is no need to record the extent or the minimum (which defaults to 0).
typedef int infrt_dimension_t;
//! Tells the kind of the device.
typedef enum infrt_device_kind_t {
infrt_unk_device = -1, // Undefined device.
infrt_x86_device = 0, // X86 device
infrt_opencl_device = 1, // OpenCL device
infrt_arm_device = 2 // ARM device
} infrt_device_kind_t;
struct infrt_buffer_t;
/**
* Every INFRT backend implementation should provide this interface: a table of
* function pointers operating on buffers, plus a pointer to the
* backend-specific implementation table.
*/
struct infrt_device_interface_impl_t;
struct infrt_device_interface_t {
int (*malloc)(void* context, struct infrt_buffer_t* buf);
int (*free)(void* context, struct infrt_buffer_t* buf);
int (*sync)(void* context, struct infrt_buffer_t* buf);
int (*release)(void* context,
const struct infrt_device_interface_t* device_interface);
int (*copy_to_host)(void* context, struct infrt_buffer_t* buf);
int (*copy_to_device)(void* context, struct infrt_buffer_t* buf);
int (*buffer_copy)(void* context,
struct infrt_buffer_t* src,
struct infrt_buffer_t* dst);
//! Backend-specific implementation table.
struct infrt_device_interface_impl_t* impl;
};
//! The raw representation of a buffer,used in the generated code/lib.
//! A buffer records at most INFRT_BUFFER_MAX_DIMS dimensions.
#define INFRT_BUFFER_MAX_DIMS 8
typedef struct infrt_buffer_t {
  //! Tell which kind of device this buffer locates.
  infrt_device_kind_t device;

  //! The interface used to operate on device.
  const struct infrt_device_interface_t* device_interface;

  //! A pointer to the memory in host.
  uint8_t* memory;

  //! Extra flags.
  uint64_t flag;

  //! Data type.
  infrt_type_t type;

  //! Number of dimensions.
  int32_t dimensions;
  infrt_dimension_t dims[INFRT_BUFFER_MAX_DIMS];

  //! Allocate and deallocate lazily, default true.
  char lazy;

  //! The actual memory size(in bytes).
  uint64_t memory_size;

  uint16_t align;

#ifdef __cplusplus
  infrt_buffer_t()
      : device(infrt_unk_device),
        device_interface(NULL),
        memory(NULL),
        flag(0UL),
        type(infrt_type_t()),
        dimensions(0),
        lazy(true),
        memory_size(0),
        align(0) {}

  static void delete_(struct infrt_buffer_t* x) { delete x; }

  ~infrt_buffer_t() {}

  // NOTE the buffer should be resized first.
  static void alloc(struct infrt_buffer_t*);

  //! Set the shape of the buffer. NOTE this just record the shape, not allocate
  //! the memory.
  //! `dimensions` must be within [0, INFRT_BUFFER_MAX_DIMS]; anything else is
  //! rejected because it would write past the end of `dims`.
  INFRT_ALWAYS_INLINE void resize(const infrt_dimension_t* dims,
                                  int dimensions) {
    CHECK(dimensions >= 0 && dimensions <= INFRT_BUFFER_MAX_DIMS)
        << "Invalid number of dimensions: " << dimensions;
    this->dimensions = dimensions;
    memcpy(this->dims, dims, dimensions * sizeof(infrt_dimension_t));
  }

  //! Product of all recorded dimensions; 1 when there are no dimensions.
  INFRT_ALWAYS_INLINE uint64_t num_elements() const {
    uint64_t res = 1;
    for (int i = 0; i < dimensions; i++) {
      res *= dims[i];
    }
    return res;
  }

  //! Calls the device's sync hook when one is installed; returns 0 otherwise.
  INFRT_ALWAYS_INLINE int device_sync(void* ctx = NULL) {
    if (device_interface && device_interface->sync) {
      return device_interface->sync(ctx, this);
    }
    return 0;
  }

  //! First byte of the host memory, so that [begin(), end()) spans the data.
  INFRT_ALWAYS_INLINE uint8_t* begin() const { return memory; }

  //! One past the last byte of the host memory.
  INFRT_ALWAYS_INLINE uint8_t* end() const {
    return memory + num_elements() * type.bytes();
  }
#endif  // __cplusplus
} infrt_buffer_t;
#ifdef __cplusplus
// Backend-side function table. The members mirror infrt_device_interface_t,
// except that `release` takes only the context.
struct infrt_device_interface_impl_t {
int (*malloc)(void* context, struct infrt_buffer_t* buf);
int (*free)(void* context, struct infrt_buffer_t* buf);
int (*sync)(void* context, struct infrt_buffer_t* buf);
int (*release)(void* context);
int (*copy_to_host)(void* context, struct infrt_buffer_t* buf);
int (*copy_to_device)(void* context, struct infrt_buffer_t* buf);
int (*buffer_copy)(void* context,
struct infrt_buffer_t* src,
struct infrt_buffer_t* dst);
};
// The device implementations
// Accessor for the x86 device interface; it is defined in another file.
extern struct infrt_device_interface_t* infrt_x86_device_interface();
#endif // __cplusplus
#ifdef __cplusplus
} // extern "C"
#endif
//! printf-style logging to stderr, prefixed with the file, line and function
//! of the call site.
//! `##__VA_ARGS__` (a GNU extension, like the __attribute__ uses in this file)
//! drops the preceding comma when no variadic argument is given, so the macro
//! also accepts a format string on its own.
#define INFRT_LOG(fmt, ...)          \
  do {                               \
    fprintf(stderr,                  \
            "%s:%d:%s(): " fmt,      \
            __FILE__,                \
            __LINE__,                \
            __func__,                \
            ##__VA_ARGS__);          \
  } while (0)

//! Logs the failed condition and aborts the process when `cond` is false.
//! The body is wrapped in do/while(0) so that the macro behaves like a single
//! statement, e.g. inside an un-braced if/else; it must be followed by `;`.
#define INFRT_CHECK(cond)                    \
  do {                                       \
    if (!(cond)) {                           \
      INFRT_LOG("check %s failed", #cond);   \
      abort();                               \
    }                                        \
  } while (0)
/**
* Buffer helps to hold the memory, and offers a set of methods to help manage
* the memory.
*/
struct Buffer final {
Buffer() = default;
explicit Buffer(const common::Target& target) { SetTarget(target); }
//! Resize the memory hold by this buffer *exactlly* to \p size.
void Resize(uint32_t size);
void Resize(uint32_t alignment, uint32_t size);
//! Lazily resize the memory.
void ResizeLazy(uint32_t size);
void ResizeLazy(uint32_t alignment, uint32_t size);
//! Resize the memory to \p size in target \p target.
void Resize(uint32_t size, const common::Target& target);
void Resize(uint32_t alignment, uint32_t size, const common::Target& target);
//! Lazily resize the memory to \p size in target \p target.
void ResizeLazy(uint32_t size, const common::Target& target);
void ResizeLazy(uint32_t alignment,
uint32_t size,
const common::Target& target);
void SetTarget(const common::Target& target);
const infrt_buffer_t* data() const { return &data_; }
infrt_buffer_t* data() { return &data_; }
//! Free all the memory owned by this buffer.
void Free() {
if (!data_.memory) return;
memory_mng_cache_->free(data_.memory);
}
private:
inline void* Malloc(uint32_t size) INFRT_RESULT_SHOULD_USE {
CHECK(memory_mng_cache_) << "Should set target first";
return memory_mng_cache_->malloc(size);
}
inline void* AlignedAlloc(uint32_t alignment,
uint32_t size) INFRT_RESULT_SHOULD_USE {
CHECK(memory_mng_cache_) << "Should set target first";
return memory_mng_cache_->aligned_alloc(alignment, size);
}
private:
infrt_buffer_t data_;
//! The place where this buffer locates.
common::Target target_;
//! Number of bytes of this buffer.
uint32_t size_{};
//! Hold the corresponding memory manager for speed.
MemoryInterface* memory_mng_cache_{};
};
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/infrt/common/macros.h"
#include "paddle/infrt/common/shared.h"
#include "paddle/infrt/common/target.h"
#include "paddle/infrt/common/type.h"
namespace infrt {
// export some general concepts.
using common::make_shared;
using common::Object;
using common::ref_count;
using common::Shared;
// Type related.
using common::Bool;
using common::Float;
using common::Int;
using common::UInt;
using common::Void;
using common::type_of;
using common::Target;
using common::Type;
using common::UnkTarget;
// Returns a mutable reference to the object behind a pointer-to-const.
template <typename T>
T& Reference(const T* x) {
  T* writable = const_cast<T*>(x);
  return *writable;
}
// Aborts (through CHECK) unless `name` is a valid variable name: it must not
// be empty and must not contain a space, '.', '/', tab, newline or carriage
// return.
static void CheckVarNameValid(const std::string& name) {
CHECK(!name.empty());
// The trailing `//` comments only keep each condition on a line of its own.
CHECK(name.find(' ') == std::string::npos && //
name.find('.') == std::string::npos && //
name.find('/') == std::string::npos && //
name.find('\t') == std::string::npos && //
name.find('\n') == std::string::npos && //
name.find('\r') == std::string::npos)
<< "Some invalid character found";
}
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/common/dtype.h"
namespace infrt {
// Returns the spelling of the kind: "Unk", or the enumerator name for every
// entry listed in dtype.def. An unlisted value yields "".
const char* DType::name() const {
// Each dtype.def entry expands to one `case` returning the enumerator's name.
#define INFRT_DTYPE(enum__, value__) \
  case Kind::enum__:                 \
    return #enum__;
  switch (kind_) {
#include "paddle/infrt/common/dtype.def"
    case Kind::Unk:
      return "Unk";
  }
#undef INFRT_DTYPE
  return "";
}
// Returns sizeof() of the C++ type registered for the stored kind through
// DTypeInternal (one case per dtype.def entry), or 0 for Kind::Unk.
// Note that for STRING this is sizeof(std::string), i.e. the size of the
// string object itself, not of the characters it holds.
size_t DType::GetHostSize() const {
switch (kind_) {
#define INFRT_DTYPE(enum__, value__) \
case DType::Kind::enum__: \
return sizeof(DTypeInternal<DType::Kind::enum__>::type);
#include "paddle/infrt/common/dtype.def" // NOLINT
#undef INFRT_DTYPE
case Kind::Unk:
return 0;
break;
}
// Only reached when kind_ holds a value that matches no case above.
return 0;
}
} // namespace infrt
// Define all INFRT dtypes.
// Each entry has the form INFRT_DTYPE(ENUM, VALUE). This file is an X-macro
// table: the including file defines INFRT_DTYPE before the #include (the
// includers in this project #undef it again afterwards). When INFRT_DTYPE is
// not defined the file expands to nothing.
#ifdef INFRT_DTYPE
INFRT_DTYPE(UI8, 1)
INFRT_DTYPE(UI16, 2)
INFRT_DTYPE(UI32, 3)
INFRT_DTYPE(UI64, 4)
INFRT_DTYPE(I1, 5)
INFRT_DTYPE(I8, 6)
INFRT_DTYPE(I16, 7)
INFRT_DTYPE(I32, 8)
INFRT_DTYPE(I64, 9)
INFRT_DTYPE(F32, 10)
INFRT_DTYPE(F64, 11)
INFRT_DTYPE(STRING, 12)
#endif
\ No newline at end of file
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
namespace infrt {
/**
 * DType is a small value type that tags the element type of a piece of data.
 * It wraps a single Kind enumerator; a default-constructed DType holds
 * Kind::Unk and is considered invalid.
 */
class DType {
 public:
  enum class Kind : uint8_t {
    Unk = 0,

// Automatically generate the enum definition
#define INFRT_DTYPE(enum__, value__) enum__ = value__,
#include "paddle/infrt/common/dtype.def"
#undef INFRT_DTYPE

    //! BOOL is an alias of the 1-bit integer kind.
    BOOL = I1,
  };

  //! True unless the kind is Kind::Unk. Declared constexpr (and ahead of the
  //! constructor) so that the assert in the constexpr constructor below can be
  //! evaluated in a constant expression.
  constexpr bool IsValid() const { return kind_ != Kind::Unk; }
  constexpr bool IsInvalid() const { return !IsValid(); }

  //! Construct an invalid (Kind::Unk) dtype.
  DType() = default;

  //! Construct from a concrete kind; \p kind is asserted to be valid.
  explicit constexpr DType(Kind kind) : kind_(kind) { assert(IsValid()); }

  DType(const DType&) = default;
  DType& operator=(const DType&) = default;

  constexpr bool operator==(DType other) const { return kind_ == other.kind_; }
  constexpr bool operator!=(DType other) const { return !(*this == other); }

  constexpr Kind kind() const { return kind_; }

  //! Enumerator name of the kind as a C string.
  const char* name() const;

  //! sizeof() of the host C++ type registered for this kind.
  size_t GetHostSize() const;

 private:
  Kind kind_{Kind::Unk};
};
// Maps a C++ type to its DType. Only the specializations generated below are
// defined; any other T has no definition.
template <typename T>
constexpr DType GetDType();
// Reverse mapping: DTypeInternal<kind>::type is the C++ type registered for a
// DType::Kind. Only the specializations generated below are defined.
template <DType::Kind kind>
struct DTypeInternal;
// Generates both directions of the mapping for one (C++ type, enumerator)
// pair: a GetDType<cpp_type__>() specialization returning the enumerator and a
// DTypeInternal<enumerator> specialization exposing cpp_type__ as `type`.
#define INFRT_IMPL_GET_DTYPE(cpp_type__, enum__) \
template <> \
inline constexpr DType GetDType<cpp_type__>() { \
return DType{DType::Kind::enum__}; \
} \
template <> \
struct DTypeInternal<DType::Kind::enum__> { \
using type = cpp_type__; \
};
// One registration per dtype.def entry.
INFRT_IMPL_GET_DTYPE(bool, I1);
INFRT_IMPL_GET_DTYPE(int8_t, I8);
INFRT_IMPL_GET_DTYPE(int16_t, I16);
INFRT_IMPL_GET_DTYPE(int32_t, I32);
INFRT_IMPL_GET_DTYPE(int64_t, I64);
INFRT_IMPL_GET_DTYPE(uint8_t, UI8);
INFRT_IMPL_GET_DTYPE(uint16_t, UI16);
INFRT_IMPL_GET_DTYPE(uint32_t, UI32);
INFRT_IMPL_GET_DTYPE(uint64_t, UI64);
INFRT_IMPL_GET_DTYPE(float, F32);
INFRT_IMPL_GET_DTYPE(double, F64);
INFRT_IMPL_GET_DTYPE(std::string, STRING);
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/common/global.h"
namespace infrt {
// Nothing to initialise: the class only carries static state.
Global::Global() {}

// Storage for the lazily created context; stays null until first requested.
mlir::MLIRContext* Global::context = nullptr;

// Hand out the shared MLIRContext, allocating it on the first call. The
// context is never freed by this code.
mlir::MLIRContext* Global::getMLIRContext() {
  if (context != nullptr) {
    return context;
  }
  context = new mlir::MLIRContext();
  return context;
}
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mlir/IR/MLIRContext.h"
#include "paddle/infrt/tensor/dense_host_tensor.h"
namespace infrt {
// Holder of the global variables. All members are static and the constructor
// is private, so the class cannot be instantiated from outside.
class Global {
 public:
  //! Access the shared MLIRContext instance.
  static mlir::MLIRContext *getMLIRContext();

 private:
  Global();

  //! Backing pointer for getMLIRContext().
  static mlir::MLIRContext *context;
};  // class Global
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
// INFRT_DEBUG is defined in every build that does not define NDEBUG.
#if !defined(NDEBUG)
#define INFRT_DEBUG
#endif
// Place inside a class body to delete its copy constructor and copy
// assignment. The expansion ends without a semicolon, so the use site supplies
// it.
#define INFRT_DISALLOW_COPY_AND_ASSIGN(TypeName) \
TypeName(const TypeName&) = delete; \
void operator=(const TypeName&) = delete
// Marks a code path as not implemented by logging a fatal message. The
// definition can be overridden by defining the macro beforehand. It relies on
// LOG from a logging header that this file does not include itself.
#ifndef INFRT_NOT_IMPLEMENTED
#define INFRT_NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented";
#endif
// Attribute asking the compiler to warn when the annotated function's return
// value is ignored (GCC/Clang attribute syntax).
#define INFRT_RESULT_SHOULD_USE __attribute__((warn_unused_result))
/**
* A trick to enforce the registry.
*
* usage:
*
* INFRT_REGISTER_HELPER(some_key) {
* // register methods
* }
*
* INFRT_USE_REGISTER(some_key);
*/
// Opens the definition of the registrar function for `symbol__`; the user
// supplies the body, which must return bool.
#define INFRT_REGISTER_HELPER(symbol__) bool __infrt__##symbol__##__registrar()
// Declares the registrar for `symbol__` and calls it from a static
// initialiser, which forces a reference to the symbol from the using
// translation unit.
#define INFRT_USE_REGISTER(symbol__) \
extern bool __infrt__##symbol__##__registrar(); \
[[maybe_unused]] static bool __infrt_extern_registrar_##symbol__ = \
__infrt__##symbol__##__registrar();
// [[nodiscard]] when compiling as C++17 or later, otherwise empty.
#if __cplusplus >= 201703L
#define INFRT_NODISCARD [[nodiscard]]
#else
#define INFRT_NODISCARD
#endif
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/common/memory.h"
namespace infrt {
using infrt::common::Target;
namespace {
//! MemoryInterface implementation that forwards every request to the global
//! C allocation functions.
class X86MemoryMng : public MemoryInterface {
 public:
  void* malloc(size_t nbytes) override {
    void* block = ::malloc(nbytes);
    return block;
  }

  //! Releasing a null pointer is a no-op.
  void free(void* data) override {
    if (data != nullptr) {
      ::free(data);
    }
  }

  void* aligned_alloc(size_t alignment, size_t nbytes) override {
    void* block = ::aligned_alloc(alignment, nbytes);
    return block;
  }
};
} // namespace
// Pre-register a host allocator for both the unknown and the X86
// architecture; each key gets its own X86MemoryMng instance.
MemoryManager::MemoryManager() {
  const Target::Arch host_arches[] = {Target::Arch::Unk, Target::Arch::X86};
  for (Target::Arch arch : host_arches) {
    Register(arch, new X86MemoryMng);
  }
}
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <unordered_map>
#include <memory>
#include "paddle/infrt/common/macros.h"
#include "paddle/infrt/common/target.h"
namespace infrt {
/**
 * Abstract allocator interface. Implementations must provide malloc and free;
 * aligned_alloc is optional and yields nullptr unless overridden.
 */
class MemoryInterface {
 public:
  virtual ~MemoryInterface() = default;

  //! Allocate \p nbytes bytes.
  virtual void* malloc(size_t nbytes) = 0;

  //! Release memory previously obtained from this interface.
  virtual void free(void* data) = 0;

  //! Aligned allocation; the base implementation allocates nothing.
  virtual void* aligned_alloc(size_t /*alignment*/, size_t /*nbytes*/) {
    return nullptr;
  }
};
/**
* MemoryManager holds a map of MemoryInterface for each architecture.
*/
class MemoryManager final {
 public:
  using key_t = common::Target::Arch;

  //! The shared manager instance. It is heap-allocated on first use and never
  //! deleted.
  static MemoryManager& Global() {
    static MemoryManager* instance = new MemoryManager;
    return *instance;
  }

  //! Look up the interface registered for \p key; nullptr when there is none.
  MemoryInterface* Retrieve(key_t key) INFRT_RESULT_SHOULD_USE {
    auto found = memory_mngs_.find(key);
    return found == memory_mngs_.end() ? nullptr : found->second.get();
  }

  //! Like Retrieve, but a missing entry fails a CHECK.
  MemoryInterface* RetrieveSafely(key_t key) {
    auto* res = Retrieve(key);
    CHECK(res) << "no MemoryInterface for architecture [" << key << "]";
    return res;
  }

  //! Register \p item for \p key and take ownership of it. Registering the
  //! same key twice fails a CHECK. Returns \p item.
  MemoryInterface* Register(key_t key, MemoryInterface* item) {
    CHECK(!memory_mngs_.count(key)) << "Duplicate register [" << key << "]";
    memory_mngs_[key] = std::unique_ptr<MemoryInterface>(item);
    return item;
  }

 private:
  // Defined out of line; only Global() creates the instance.
  MemoryManager();

  std::unordered_map<common::Target::Arch, std::unique_ptr<MemoryInterface>>
      memory_mngs_;

  INFRT_DISALLOW_COPY_AND_ASSIGN(MemoryManager);
};
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/common/object.h"
namespace infrt {
namespace common {} // namespace common
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstring>
#include <iostream>
#include "paddle/infrt/common/shared.h"
namespace infrt {
namespace common {
template <typename T>
class Shared;

/**
 * Object is the basic element in INFRT. Wrapped in `Shared`, an object can be
 * shared across the system.
 */
struct Object {
  //! Name identifying the concrete type of this object.
  virtual const char* type_info() const = 0;

  virtual ~Object() = default;

  //! Unchecked cast to a derived type.
  template <typename T>
  T* as() {
    return static_cast<T*>(this);
  }

  //! Unchecked cast to a derived type (const overload).
  template <typename T>
  const T* as() const {
    return static_cast<const T*>(this);
  }

  //! Checked cast: nullptr unless type_info() equals T::__type_info__.
  template <typename T>
  T* safe_as() {
    return is_type<T>() ? static_cast<T*>(this) : nullptr;
  }

  //! Checked cast (const overload).
  template <typename T>
  const T* safe_as() const {
    return is_type<T>() ? static_cast<const T*>(this) : nullptr;
  }

  //! True when type_info() compares equal (as a C string) to
  //! T::__type_info__.
  template <typename T>
  bool is_type() const {
    return std::strcmp(type_info(), T::__type_info__) == 0;
  }

  //! The reference count that lets every derived type be held by `Shared`.
  //! Mutable so that it can be updated through const pointers.
  mutable RefCount __ref_count__;
};

using object_ptr = Object*;
using shared_object = Shared<Object>;
} // namespace common
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/common/shared.h"
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <cstdint>
#include <string>
#include <type_traits>
#include <utility>
namespace infrt {
namespace common {
/**
 * Atomic reference counter, starting at zero.
 */
class RefCount {
 public:
  using value_type = int32_t;

  RefCount() = default;

  //! Increment and return the new count.
  value_type Inc() { return ++count_; }

  //! Decrement and return the new count.
  value_type Dec() { return --count_; }

  bool is_zero() const { return 0 == count_; }

  //! Decimal representation of the current count. Const-qualified so that it
  //! can be used on const counters, like the other read-only accessors.
  std::string to_string() const { return std::to_string(count_.load()); }

  //! Current count.
  value_type val() const { return count_; }

 private:
  std::atomic<value_type> count_{0};
};
class Object;

/**
 * Uniform accessor for the RefCount embedded in an Object-derived instance, so
 * that client classes obtain it the same way regardless of the concrete type.
 * T must derive from Object; this is enforced at compile time.
 */
template <typename T>
RefCount& ref_count(const T* t) {
  static_assert(std::is_base_of<Object, T>::value, "T is not a Object");
  RefCount& counter = t->__ref_count__;
  return counter;
}
//! Delete the object pointed to by \p t. Shared calls this when the reference
//! count of an object drops to zero.
template <typename T>
void Destroy(const T* t) {
delete t;
}
/**
 * Intrusive reference-counting pointer. The count is stored inside the
 * pointee and reached through ref_count(); when releasing a reference brings
 * the count to zero, the pointee is destroyed with Destroy().
 */
template <typename T>
struct Shared {
  using object_ptr = T*;

  //! Construct an empty (undefined) pointer.
  Shared() = default;

  //! Share ownership of \p p (may be null).
  explicit Shared(T* p) : p_(p) {
    if (p) IncRef(p);
  }

  //! Copying adds one reference to the pointee.
  Shared(const Shared& other) : p_(other.p_) { IncRef(p_); }

  //! Moving steals the pointer and leaves \p other empty; the count is
  //! untouched. It cannot throw, hence noexcept.
  Shared(Shared&& other) noexcept : p_(other.p_) { other.p_ = nullptr; }

  Shared<T>& operator=(const Shared<T>& other);

  //! Reset to another pointer \p x.
  void Reset(T* x = nullptr);

  //! Access the pointer in various ways.
  // @{
  inline T* get() const { return p_; }
  inline T& operator*() const { return *p_; }
  inline T* operator->() const { return p_; }
  inline T* self() { return p_; }
  inline const T* self() const { return p_; }
  // @}

  //! True when both wrappers point at the same object. Const-qualified so it
  //! can be called on const wrappers, like the other comparisons.
  inline bool same_as(const Shared& other) const { return p_ == other.p_; }

  //! True when a non-null pointer is held.
  inline bool defined() const { return p_ != nullptr; }

  //! Orders by pointer value.
  inline bool operator<(const Shared& other) const { return p_ < other.p_; }

  inline Shared<T>& operator=(T* x);

  inline bool operator==(const Shared& other) const { return p_ == other.p_; }

  ~Shared();

 private:
  //! Increase the share count.
  void IncRef(T* p);

  //! Decrease the share count.
  void DecRef(T* p);

 protected:
  T* p_{};
};
//! Add one reference to \p p; a null pointer is ignored.
template <typename T>
void Shared<T>::IncRef(T* p) {
  if (p == nullptr) return;
  ref_count(p).Inc();
}
//! Drop one reference to \p p and destroy the object when the count reaches
//! zero; a null pointer is ignored.
template <typename T>
void Shared<T>::DecRef(T* p) {
  if (p == nullptr) return;
  const bool last_reference = (ref_count(p).Dec() == 0);
  if (last_reference) {
    Destroy(p);
  }
}
//! Copy assignment: share the pointee of \p other and release the current one.
template <typename T>
Shared<T>& Shared<T>::operator=(const Shared<T>& other) {
  T* incoming = other.p_;
  if (incoming != p_) {
    // `other` may live inside something owned by this pointer, so take the new
    // reference before giving up the old one.
    IncRef(incoming);
    DecRef(p_);
    p_ = incoming;
  }
  return *this;
}
//! Allocate a new T with `new`, perfectly forwarding \p args to its
//! constructor, and return the raw pointer. The result is typically handed to
//! a Shared<T>, which then manages the lifetime.
template <typename T, typename... Args>
T* make_shared(Args&&... args) {
  // std::forward keeps rvalues as rvalues so that move-only arguments work and
  // needless copies are avoided.
  return new T(std::forward<Args>(args)...);
}
//! Assign from a raw pointer: share \p x and release the current pointee.
template <typename T>
Shared<T>& Shared<T>::operator=(T* x) {
  if (x != p_) {
    // Take the new reference first, then release the old one.
    IncRef(x);
    DecRef(p_);
    p_ = x;
  }
  return *this;
}
//! Release the held reference (destroying the pointee if it was the last one)
//! and clear the pointer.
template <typename T>
Shared<T>::~Shared() {
DecRef(p_);
p_ = nullptr;
}
//! Point at \p x instead (null empties the wrapper). The new reference is
//! taken before the old one is released, so resetting to the currently held
//! pointer is safe.
template <typename T>
void Shared<T>::Reset(T* x) {
  IncRef(x);  // IncRef ignores a null pointer.
  T* previous = p_;
  DecRef(previous);
  p_ = x;
}
} // namespace common
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/common/string.h"
#include <stdarg.h>
#include <cstring>
namespace infrt {
namespace infrt {
// printf-style formatting into a std::string.
//
// The output buffer starts at twice the length of the format string (plus one
// byte so that an empty format still has room for the terminating NUL) and is
// re-allocated to the exact required size if the first attempt was truncated.
// An empty string is returned when vsnprintf reports an error.
//
// NOTE(review): the last named parameter is a reference, which va_start does
// not formally support; the signature is kept as is for the existing callers.
std::string StringFormat(const std::string &fmt_str, ...) {
  int capacity = static_cast<int>(fmt_str.size()) * 2 + 1;
  std::unique_ptr<char[]> formatted;
  while (true) {
    formatted.reset(new char[capacity]);
    va_list ap;
    va_start(ap, fmt_str);
    const int written =
        vsnprintf(formatted.get(), capacity, fmt_str.c_str(), ap);
    va_end(ap);
    // A negative result is an encoding error; retrying cannot succeed.
    if (written < 0) return std::string();
    if (written < capacity) break;
    // Truncated: `written` is the length needed, excluding the NUL.
    capacity = written + 1;
  }
  return std::string(formatted.get());
}
// Strip every leading and trailing character that occurs in `empty` from `s`.
// The result is empty when `s` is empty or consists only of such characters.
std::string Trim(const std::string &s, const char *empty) {
  const std::string::size_type first = s.find_first_not_of(empty);
  if (first == std::string::npos) {
    return "";
  }
  const std::string::size_type last = s.find_last_not_of(empty);
  const std::string::size_type length = last - first + 1;
  return s.substr(first, length);
}
// Return a copy of `x` with every character converted by toupper().
std::string Uppercase(const std::string &x) {
  std::string res = x;
  for (char &c : res) {
    // toupper() requires a value representable as unsigned char (or EOF);
    // passing a negative plain char is undefined, so convert first.
    c = static_cast<char>(toupper(static_cast<unsigned char>(c)));
  }
  return res;
}
// True when `x` begins with `str` (an empty `str` always matches).
bool Startswith(const std::string &x, const std::string &str) {
  // compare() inspects at most str.size() characters, whereas searching with
  // find() would scan all of `x` whenever the prefix does not match.
  return x.compare(0, str.size(), str) == 0;
}
// True when `x` ends with `str` (an empty `str` always matches).
bool Endswith(const std::string &x, const std::string &str) {
  if (str.length() > x.length()) {
    return false;
  }
  const std::string::size_type tail_start = x.length() - str.length();
  return x.compare(tail_start, str.length(), str) == 0;
}
// Split `str` at every occurrence of `splitter`.
//
// Empty fields between adjacent splitters are kept, but an empty field after a
// trailing splitter is dropped, and an empty `str` yields no fields at all.
// An empty `splitter` can never advance the search position, so in that case
// the whole string is returned as a single field instead of looping forever.
std::vector<std::string> Split(const std::string &str,
                               const std::string &splitter) {
  std::vector<std::string> results;
  if (splitter.empty()) {
    if (!str.empty()) results.push_back(str);
    return results;
  }

  std::string::size_type field_begin = 0;
  std::string::size_type hit = str.find(splitter);
  while (hit != std::string::npos) {
    results.push_back(str.substr(field_begin, hit - field_begin));
    field_begin = hit + splitter.size();
    hit = str.find(splitter, field_begin);
  }
  // Remainder after the last splitter, unless there is nothing left.
  if (field_begin != str.length()) {
    results.push_back(str.substr(field_begin));
  }
  return results;
}
// Replace, in place, every occurrence of `from` in `*s` with `to`. Text that
// was just inserted is not searched again, so `to` may contain `from`.
// An empty `from` matches at every position and would never terminate, so it
// is treated as "nothing to replace".
void Replace(std::string *s, const std::string &from, const std::string &to) {
  if (from.empty()) return;
  size_t pos = 0;
  while ((pos = s->find(from, pos)) != std::string::npos) {
    s->replace(pos, from.size(), to);
    pos += to.length();  // continue after the inserted text
  }
}
// Count the occurrences of `sub` in `*s` that stand on their own: the
// character before the match (if any) must not satisfy IsPrefix and the
// character after it (if any) must not satisfy IsSuffix. The string is only
// read, never modified.
// An empty `sub` would match without ever advancing, so it counts as zero
// occurrences.
size_t Count(std::string *s, const std::string &sub) {
  if (sub.empty()) return 0;
  size_t pos = 0;
  size_t times = 0;
  while ((pos = s->find(sub, pos)) != std::string::npos) {
    const size_t match_end = pos + sub.length();
    const bool free_on_left = (pos == 0) || !IsPrefix(s->at(pos - 1));
    const bool free_on_right =
        free_on_left && (match_end == s->size() || !IsSuffix(s->at(match_end)));
    if (free_on_right) {
      pos = match_end;  // skip past the counted match
      times++;
    } else {
      pos++;  // embedded in a longer name; try the next position
    }
  }
  return times;
}
// True for ASCII letters and the underscore.
bool IsPrefix(const char &c) {
  const bool is_lower = 'a' <= c && c <= 'z';
  const bool is_upper = 'A' <= c && c <= 'Z';
  return is_lower || is_upper || c == '_';
}
// True for ASCII letters, ASCII digits, the underscore and the single quote.
bool IsSuffix(const char &c) {
  const bool is_lower = 'a' <= c && c <= 'z';
  const bool is_upper = 'A' <= c && c <= 'Z';
  const bool is_digit = '0' <= c && c <= '9';
  return is_lower || is_upper || c == '_' || is_digit || c == '\'';
}
// Rewrite `name` into a valid variable name: each "." becomes "__", each "/"
// becomes "___", and all leading underscores are then removed.
std::string TransValidVarName(std::string name) {
  Replace(&name, ".", "__");
  Replace(&name, "/", "___");
  // npos (the name is all underscores) erases the whole string.
  const std::string::size_type first_kept = name.find_first_not_of("_");
  name.erase(0, first_kept);
  return name;
}
} // namespace infrt
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <sstream>
#include <string>
#include <vector>
namespace infrt {
namespace infrt {
//! Get the content of a stream.
template <typename T>
std::string GetStreamCnt(const T& x);
/**
* Construct a formatted string with arguments.
* @param fmt_str The format.
* @param ... The parameters of the format.
* @return The formatted string.
*/
std::string StringFormat(const std::string& fmt_str, ...);
/**
* Join multiple fields to a single string. Similar to Python's str.join method.
*/
template <typename T = std::string>
std::string Join(const std::vector<T>& fields, const std::string& splitter) {
  // Each field is streamed with operator<<, so T may be any streamable type.
  if (fields.empty()) return "";
  std::stringstream ss;
  // size_t index: avoids the signed/unsigned comparison against size() and is
  // wide enough for any vector. `i + 1 < size()` stops before the last field.
  for (size_t i = 0; i + 1 < fields.size(); ++i) ss << fields[i] << splitter;
  ss << fields.back();
  return ss.str();
}
//! Split \p str at each occurrence of \p splitter. Empty fields between
//! adjacent splitters are kept; an empty field after a trailing splitter is
//! dropped.
std::vector<std::string> Split(const std::string& str,
const std::string& splitter);
//! Remove the leading and trailing characters of \p s that occur in \p empty
//! (whitespace by default).
std::string Trim(const std::string& s, const char* empty = " \n\r\t");
//! Convert a string to its uppercase.
std::string Uppercase(const std::string& x);
//! Replace every occurrence of the substring 'from' with 'to' in string s, in
//! place.
void Replace(std::string* s, const std::string& from, const std::string& to);
//! Count how many times substr 'sub' appears in string s. Only occurrences
//! that are not preceded by an IsPrefix character and not followed by an
//! IsSuffix character are counted.
size_t Count(std::string* s, const std::string& sub);
//! Tell if a char is prefix of a tensor's name (a letter or '_').
bool IsPrefix(const char& c);
//! Tell if a char is suffix of a tensor's name (a letter, a digit, '_' or a
//! single quote).
bool IsSuffix(const char& c);
//! Tell if a string \p x starts with \p str.
bool Startswith(const std::string& x, const std::string& str);
//! Tell if a string \p x ends with \p str.
bool Endswith(const std::string& x, const std::string& str);
//! Render \p x with its operator<< and return the produced text.
template <typename T>
std::string GetStreamCnt(const T& x) {
  std::ostringstream buffer;
  buffer << x;
  return buffer.str();
}
//! Turn \p name into a valid variable name: "." is replaced by "__", "/" by
//! "___", and leading underscores are stripped.
std::string TransValidVarName(std::string name);
} // namespace infrt
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/common/target.h"
#include <glog/logging.h>
namespace infrt {
namespace common {
// Two targets are equal when os, arch, bits and the feature list all match.
// NOTE(review): `libs` does not take part in the comparison - confirm that
// this is intended.
bool Target::operator==(const Target &other) const {
  if (os != other.os) return false;
  if (arch != other.arch) return false;
  if (bits != other.bits) return false;
  return features == other.features;
}
// Maximum number of threads for an NVGPU target (fixed at 1024). Calling this
// on any other architecture fails the CHECK.
int Target::max_num_threads() const {
  CHECK(arch == Arch::NVGPU)
      << "The target is not NVGPU! Cannot get max number of threads.";
  const int nvgpu_max_threads = 1024;
  return nvgpu_max_threads;
}
// Return a copy of the libraries configured for this target.
std::vector<Target::Lib> Target::get_target_libs() const {
  std::vector<Target::Lib> copy(libs);
  return copy;
}
// Translate the Bit enumerator into a number: 32 or 64, or 0 when unknown.
// Any other value is reported through LOG(FATAL); the trailing -1 only
// satisfies the return type.
int Target::get_target_bits() const {
  if (bits == Bit::k32) return 32;
  if (bits == Bit::k64) return 64;
  if (bits == Bit::Unk) return 0;
  LOG(FATAL) << "Not supported Bit";
  return -1;
}
// Print a target as "Target<os,arch,bits>" using lowercase names, with "unk"
// for unknown parts. A part whose value matches no enumerator prints nothing.
std::ostream &operator<<(std::ostream &os, const Target &target) {
  const char *os_name = "";
  switch (target.os) {
    case Target::OS::Linux:
      os_name = "linux";
      break;
    case Target::OS::Windows:
      os_name = "windows";
      break;
    case Target::OS::Unk:
      os_name = "unk";
      break;
  }

  const char *arch_name = "";
  switch (target.arch) {
    case Target::Arch::X86:
      arch_name = "x86";
      break;
    case Target::Arch::ARM:
      arch_name = "arm";
      break;
    case Target::Arch::NVGPU:
      arch_name = "nvgpu";
      break;
    case Target::Arch::Unk:
      arch_name = "unk";
      break;
  }

  const char *bits_name = "";
  switch (target.bits) {
    case Target::Bit::k32:
      bits_name = "32";
      break;
    case Target::Bit::k64:
      bits_name = "64";
      break;
    case Target::Bit::Unk:
      bits_name = "unk";
      break;
  }

  os << "Target<";
  os << os_name;
  os << ",";
  os << arch_name;
  os << ",";
  os << bits_name;
  os << ">";
  return os;
}
// Print the enumerator name of an architecture ("Unk", "X86", "ARM",
// "NVGPU"). A value matching no enumerator prints nothing.
std::ostream &operator<<(std::ostream &os, Target::Arch arch) {
  switch (arch) {
    case Target::Arch::Unk:
      return os << "Unk";
    case Target::Arch::X86:
      return os << "X86";
    case Target::Arch::ARM:
      return os << "ARM";
    case Target::Arch::NVGPU:
      return os << "NVGPU";
  }
  return os;
}
} // namespace common
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ostream>
#include <vector>
namespace infrt {
namespace common {
/**
 * Description of a compilation/execution target: operating system,
 * architecture, word size, plus optional feature and library lists.
 */
struct Target {
  /**
   * The operating system used by the target. Determines which system calls to
   * generate.
   */
  enum class OS : int {
    Unk = -1,
    Linux,
    Windows,
  };

  /**
   * The architecture used by the target. Determines the instruction set to
   * use.
   */
  enum class Arch : int {
    Unk = -1,
    X86,
    ARM,
    NVGPU,
  };

  //! Word size of the target.
  enum class Bit : int {
    Unk = -1,
    k32,
    k64,
  };

  OS os{OS::Unk};
  Arch arch{Arch::Unk};
  Bit bits{Bit::Unk};

  //! Optional features of the target.
  enum class Feature : int {
    JIT = 0,
    Debug,
  };

  /**
   * The library used by the target.
   */
  enum class Lib : int {
    Unk = -1,
    MKL,
  };

  std::vector<Feature> features;
  std::vector<Lib> libs;

  //! All arguments are optional. Note that the OS defaults to Linux here,
  //! while architecture and bits default to unknown.
  explicit Target(OS o = OS::Linux,
                  Arch a = Arch::Unk,
                  Bit b = Bit::Unk,
                  const std::vector<Feature>& features = {},
                  const std::vector<Lib>& libs = {})
      : os(o), arch(a), bits(b), features(features), libs(libs) {}

  //! True only when os, arch and bits are all known.
  bool defined() const {
    return !(os == OS::Unk || arch == Arch::Unk || bits == Bit::Unk);
  }

  int max_num_threads() const;

  int get_target_bits() const;

  std::vector<Lib> get_target_libs() const;

  bool operator==(const Target& other) const;
  bool operator!=(const Target& other) const { return !(*this == other); }

  friend std::ostream& operator<<(std::ostream& os, const Target& target);
};
//! A target whose os, arch and bits are all unknown. The instance is a
//! function-local static that is built on first use.
static const Target& UnkTarget() {
  static Target target(Target::OS::Unk,
                       Target::Arch::Unk,
                       Target::Bit::Unk,
                       std::vector<Target::Feature>(),
                       std::vector<Target::Lib>());
  return target;
}
//! The default host target: 64-bit X86 on Linux, without features or
//! libraries. The instance is a function-local static built on first use.
static const Target& DefaultHostTarget() {
  static Target target(Target::OS::Linux,
                       Target::Arch::X86,
                       Target::Bit::k64,
                       std::vector<Target::Feature>(),
                       std::vector<Target::Lib>());
  return target;
}
//! The default GPU target: 64-bit NVGPU on Linux, without features or
//! libraries. The instance is a function-local static built on first use.
static const Target& DefaultNVGPUTarget() {
  static Target target(Target::OS::Linux,
                       Target::Arch::NVGPU,
                       Target::Bit::k64,
                       std::vector<Target::Feature>(),
                       std::vector<Target::Lib>());
  return target;
}
std::ostream& operator<<(std::ostream& os, Target::Arch arch);
} // namespace common
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/common/type.h"
#include <utility>
namespace infrt {
namespace common {
// Backing data of a Type.
struct Type::Storage {
  Storage() = default;

  // `t` is the type category, `b` the number of bits per element and `w` the
  // number of lanes.
  Storage(type_t t, int b, int w) : type_{t}, bits_{b}, lanes_{w} {}

  type_t type_{type_t::Unk};

  // Bit flags describing the C++ decoration of the type (const, handle, ...).
  cpp_type_t cpp_type_{cpp_type_t::None};

  //! How many bits per element.
  int bits_{};

  //! How many elements(if a vector type), for scalar types, it should be 1.
  int lanes_{1};

  //! Name of the customized type.
  std::string customized_type_;
};
// Defined in this file, after Storage, where Storage is a complete type.
Type::~Type() = default;
std::ostream &operator<<(std::ostream &os, const Type &t) {
if (t.is_cpp_const()) os << "const ";
switch (t.type()) {
case Type::type_t::Int:
if (t.bits() == 1) {
os << "bool";
} else {
os << "int" << t.bits();
}
break;
case Type::type_t::UInt:
os << "uint" << t.bits();
break;
case Type::type_t::Float:
os << "float" << t.bits();
break;
case Type::type_t::Void:
os << "void";
break;
case Type::type_t::Customized:
os << t.customized_type();
break;
case Type::type_t::String:
os << "string";
break;
case Type::type_t::Unk:
os << "unk";
break;
}
if (t.lanes() > 1) os << "<" << t.lanes() << ">";
if (t.is_cpp_handle()) os << "*";
if (t.is_cpp_handle2()) os << "**";
return os;
}
// Print the enumerator name of a type category. A value matching no
// enumerator prints nothing.
std::ostream &operator<<(std::ostream &os, Type::type_t t) {
  const char *label = nullptr;
  switch (t) {
    case Type::type_t::String:
      label = "String";
      break;
    case Type::type_t::Void:
      label = "Void";
      break;
    case Type::type_t::UInt:
      label = "UInt";
      break;
    case Type::type_t::Int:
      label = "Int";
      break;
    case Type::type_t::Float:
      label = "Float";
      break;
    case Type::type_t::Unk:
      label = "Unk";
      break;
    case Type::type_t::Customized:
      label = "Customized";
      break;
  }
  if (label != nullptr) os << label;
  return os;
}
// Mark (x == true) or unmark the type as a single-level handle. Both
// handle-related bits are cleared first, so Handle and HandleHandle are never
// set together by this call.
Type &Type::set_cpp_handle(bool x) {
  // The flags are manipulated through a byte view of the cpp_type_ member.
  auto &flags = *reinterpret_cast<uint8_t *>(&GetStorage().cpp_type_);
  const uint8_t handle_bit = static_cast<uint8_t>(cpp_type_t::Handle);
  const uint8_t handle2_bit = static_cast<uint8_t>(cpp_type_t::HandleHandle);

  flags &= ~(handle_bit | handle2_bit);
  if (x) {
    flags |= handle_bit;
  }
  return *this;
}
// Mark (x == true) or unmark the type as a two-level handle. Both
// handle-related bits are cleared first, so Handle and HandleHandle are never
// set together by this call.
Type &Type::set_cpp_handle2(bool x) {
  // The flags are manipulated through a byte view of the cpp_type_ member.
  auto &flags = *reinterpret_cast<uint8_t *>(&GetStorage().cpp_type_);
  const uint8_t handle_bit = static_cast<uint8_t>(cpp_type_t::Handle);
  const uint8_t handle2_bit = static_cast<uint8_t>(cpp_type_t::HandleHandle);

  flags &= ~(handle_bit | handle2_bit);
  if (x) {
    flags |= handle2_bit;
  }
  return *this;
}
// Build a vector type with `w` lanes that has the same category and element
// width as this type. The type must not be Unk.
Type Type::VectorOf(int w) const {
  CheckTypeValid();
  // The constructor takes (type, bits, lanes): the element width comes
  // second and the requested lane count last.
  return Type(type(), bits(), w);
}
// Copy constructor: deep-copies the storage of `other`. When `other` has no
// storage, this object is left without storage as well.
Type::Type(const Type &other) {
  if (!other.storage_) return;
  storage_.reset(new Storage(*other.storage_));
}
// Scalar version of this type: a copy whose lane count is reset to 1. The
// type must not be Unk.
Type Type::ElementOf() const {
  CheckTypeValid();
  Type element(*this);
  element.GetStorage().lanes_ = 1;
  return element;
}
// Fails a CHECK when the type category is Unk.
void Type::CheckTypeValid() const { CHECK_NE(GetStorage().type_, type_t::Unk); }
// Return a copy with one more level of indirection: a plain type becomes a
// handle and a handle becomes a handle-of-handle. A third level is rejected by
// the CHECK. The type must not be Unk.
Type Type::PointerOf() const {
  CheckTypeValid();
  auto x = *this;
  CHECK(!x.is_cpp_handle2()) << "Not support three level of PointerOf";
  const bool already_handle = x.is_cpp_handle();
  if (!already_handle) {
    x.set_cpp_handle();
  } else {
    x.set_cpp_handle2();
  }
  return x;
}
// Copy of this type with the const flag set. The type must not be Unk.
Type Type::ConstOf() const {
  CheckTypeValid();
  Type qualified(*this);
  qualified.set_cpp_const();
  return qualified;
}

// Copy of this type with the const flag cleared. The type must not be Unk.
Type Type::IgnoreConst() const {
  CheckTypeValid();
  Type unqualified(*this);
  unqualified.set_cpp_const(false);
  return unqualified;
}
// Copy of this type with the element width replaced by `x`. Only allowed on
// primitive types.
Type Type::with_bits(int x) const {
  CHECK(is_primitive());
  Type updated(*this);
  updated.GetStorage().bits_ = x;
  return updated;
}

// Copy of this type with the category replaced by `x`.
Type Type::with_type(Type::type_t x) const {
  Type updated(*this);
  updated.GetStorage().type_ = x;
  return updated;
}

// Copy of this type with the lane count replaced by `x`. Only allowed on valid
// types.
Type Type::with_lanes(int x) const {
  CHECK(valid());
  Type updated(*this);
  updated.GetStorage().lanes_ = x;
  return updated;
}

// Copy of this type with the const flag set to `x`.
Type Type::with_cpp_const(bool x) const {
  Type updated(*this);
  updated.set_cpp_const(x);
  return updated;
}
// Set or clear the Const flag in place; the other flags are left untouched.
Type &Type::set_cpp_const(bool is_const) {
  // The flags are manipulated through a byte view of the cpp_type_ member.
  uint8_t &flags = *reinterpret_cast<uint8_t *>(&GetStorage().cpp_type_);
  const uint8_t const_bit = static_cast<uint8_t>(cpp_type_t::Const);
  if (!is_const) {
    flags &= ~const_bit;
  } else {
    flags |= const_bit;
  }
  return *this;
}
// Turn this type into a customized type named `t`.
Type &Type::set_customized_type(const std::string &t) {
  Storage &storage = GetStorage();
  storage.type_ = type_t::Customized;
  storage.customized_type_ = t;
  return *this;
}
// A type is valid when it is not Unk and, in addition, a customized type has a
// non-empty name and a primitive type has a non-zero bit width.
bool Type::valid() const {
  if (is_unk()) {
    return false;
  }
  if (is_customized()) {
    const bool has_name = !GetStorage().customized_type_.empty();
    return has_name;
  }
  if (is_primitive()) {
    const bool has_width = bits() != 0;
    return has_width;
  }
  return true;
}
// Construct a type of category `t` with `b` bits per element and `w` lanes.
Type::Type(Type::type_t t, int b, int w) : storage_(new Storage(t, b, w)) {}
// Primitive: any known category except Customized.
bool Type::is_primitive() const {
  return !(is_unk() || type() == type_t::Customized);
}

// Customized: the category is Customized (and hence not Unk).
bool Type::is_customized() const {
  if (is_unk()) return false;
  return type() == type_t::Customized;
}

bool Type::is_unk() const { return type_t::Unk == type(); }

// NOTE(review): a 1-bit UInt is treated as bool here, whereas the stream
// operator in this file prints "bool" for a 1-bit Int - confirm which one is
// intended.
bool Type::is_bool() const {
  if (type() != type_t::UInt) return false;
  return bits() == 1;
}

bool Type::is_void() const { return type_t::Void == type(); }

// More than one lane.
bool Type::is_vector() const { return 1 < lanes(); }

// Exactly one lane.
bool Type::is_scalar() const { return 1 == lanes(); }
bool Type::is_float(int bits) const {
return type() == type_t::Float && (bits < 0 || bits == this->bits());
}
bool Type::is_uint(int bits) const {
return type() == type_t::UInt && (bits < 0 || bits == this->bits());
}
bool Type::is_int(int bits) const {
return type() == type_t::Int && (bits < 0 || bits == this->bits());
}
bool Type::is_integer(int bits) const {
return (type() == type_t::Int || type() == type_t::UInt) &&
(bits < 0 || bits == this->bits());
}
bool Type::is_index_type() {
return is_int() && lanes() == 1 && (bits() == 32 || bits() == 64);
}
// True when the Handle bit is set in the C++ decorator flags.
bool Type::is_cpp_handle() const {
  const uint8_t flags = static_cast<uint8_t>(GetStorage().cpp_type_);
  return (flags & static_cast<uint8_t>(cpp_type_t::Handle)) != 0;
}

// True when the HandleHandle bit is set in the C++ decorator flags.
bool Type::is_cpp_handle2() const {
  const uint8_t flags = static_cast<uint8_t>(GetStorage().cpp_type_);
  return (flags & static_cast<uint8_t>(cpp_type_t::HandleHandle)) != 0;
}

// True when the Const bit is set in the C++ decorator flags.
bool Type::is_cpp_const() const {
  const uint8_t flags = static_cast<uint8_t>(GetStorage().cpp_type_);
  return (flags & static_cast<uint8_t>(cpp_type_t::Const)) != 0;
}
// Name recorded by set_customized_type(); returned by reference to the
// string held in Storage.
const std::string &Type::customized_type() const {
return GetStorage().customized_type_;
}
// True when a non-empty customized type name is stored. Note this checks the
// name only, not the kind tag (see is_customized() for the latter).
bool Type::is_customized_type() const {
return !GetStorage().customized_type_.empty();
}
// Plain getters for the fields held in Storage.
Type::type_t Type::type() const { return GetStorage().type_; }
int Type::bits() const { return GetStorage().bits_; }
int Type::lanes() const { return GetStorage().lanes_; }
Type::cpp_type_t Type::cpp_type() const { return GetStorage().cpp_type_; }
// Two types are equal when kind, bit width, lane count, C++ decorator flags
// and customized type name all match.
bool Type::operator==(const Type &other) const {
  if (type() != other.type()) return false;
  if (bits() != other.bits()) return false;
  if (lanes() != other.lanes()) return false;
  if (GetStorage().cpp_type_ != other.GetStorage().cpp_type_) return false;
  return customized_type() == other.customized_type();
}

// True when the kind is String.
bool Type::is_string() const { return type_t::String == type(); }
// Deep-copies the storage of `other`. When `other` holds no storage (for
// example after being moved from) this object is left unchanged.
Type &Type::operator=(const Type &other) {
  if (!other.storage_) return *this;
  storage_.reset(new Storage(*other.storage_));
  return *this;
}

// Accessors for the owned Storage; they dereference storage_ unconditionally.
Type::Storage &Type::GetStorage() { return *storage_; }
const Type::Storage &Type::GetStorage() const { return *storage_; }

// Default construction allocates a default-constructed Storage.
Type::Type() : storage_(new Storage) {}

// Move construction takes over the storage of `other`, leaving it empty.
Type::Type(Type &&other) : storage_(std::move(other.storage_)) {}
// Built-in native types exposed as function-local singletons. Each accessor
// builds its Type once, on first call, and hands out a const reference to it.

// 16-bit float.
const Type &F16() {
  static const Type kType = Float(16);
  return kType;
}

// 32-bit float.
const Type &F32() {
  static const Type kType = Float(32);
  return kType;
}

// 64-bit float.
const Type &F64() {
  static const Type kType = Float(64);
  return kType;
}

// 8-bit signed integer.
const Type &I8() {
  static const Type kType = Int(8);
  return kType;
}

// 16-bit signed integer.
const Type &I16() {
  static const Type kType = Int(16);
  return kType;
}

// 32-bit signed integer.
const Type &I32() {
  static const Type kType = Int(32);
  return kType;
}

// 64-bit signed integer.
const Type &I64() {
  static const Type kType = Int(64);
  return kType;
}

// 8-bit unsigned integer.
const Type &UI8() {
  static const Type kType = UInt(8);
  return kType;
}

// 16-bit unsigned integer.
const Type &UI16() {
  static const Type kType = UInt(16);
  return kType;
}

// 32-bit unsigned integer.
const Type &UI32() {
  static const Type kType = UInt(32);
  return kType;
}

// 64-bit unsigned integer.
const Type &UI64() {
  static const Type kType = UInt(64);
  return kType;
}

// 1-bit signed integer.
const Type &I1() {
  static const Type kType = Int(1);
  return kType;
}

// 1-bit unsigned integer (the representation used for bool).
const Type &UI1() {
  static const Type kType = UInt(1);
  return kType;
}
} // namespace common
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <memory>
#include <string>
#include "paddle/infrt/common/macros.h"
//! Many of the concepts here are borrowed from the Halide project.
namespace infrt {
namespace common {
/**
* Types in the INFRT type system. They can be ints, unsigned ints, or floats of
* various bit-widths.
* They can also be vectors of the same (by setting the `lanes` field to
* something larger than one).
* NOTE: Front-end code other than vectorize shouldn't use vector types.
*/
// Value-semantic description of a data type: a kind tag (`type_t`), a bit
// width, a lane count, C++ decorator flags (`cpp_type_t`) and an optional
// customized type name. The fields live in a private `Storage` object owned
// through `storage_`.
struct Type {
// Kind of the underlying data.
enum class type_t {
Unk = -1,
Int,
UInt,
Float,
String,
Void,
// NOTE: Customized is mixed in with the primitive kinds here; separating
// them would need a larger refactor.
Customized, // Customized type
};
//! Type decorators in C++. The values are bit flags and can be combined.
enum class cpp_type_t : uint8_t {
None = 0, // None information.
Const = 1, // const.
Handle = 1 << 1, // pointer type, such as `infrt_buffer_t*`.
HandleHandle = 1 << 2, // pointer of pointer, such as `infrt_buffer_t**`.
};
// Default constructor; allocates a default-constructed Storage.
Type();
// Construct from kind `t`, bit width `b` and `w` (the factory helpers below
// pass the lane count as `w`).
Type(type_t t, int b, int w);
Type(const Type& other);
// Move constructor; takes over the storage of `other`. Declared explicit.
explicit Type(Type&& other);
// Copy assignment; deep-copies the storage of `other`.
Type& operator=(const Type& other);
// True for every known kind except Customized.
INFRT_NODISCARD bool is_primitive() const;
// True when the kind is Customized.
INFRT_NODISCARD bool is_customized() const;
// True when the type is not Unk and its kind-specific payload is set
// (non-zero bits for primitives, non-empty name for customized types).
INFRT_NODISCARD bool valid() const;
//! Some helper functions to check a type.
//! For the predicates taking `bits`, a negative value matches any width.
// @{
INFRT_NODISCARD bool is_unk() const;
INFRT_NODISCARD bool is_void() const;
INFRT_NODISCARD bool is_bool() const;
INFRT_NODISCARD bool is_vector() const;
INFRT_NODISCARD bool is_scalar() const;
INFRT_NODISCARD bool is_float(int bits = -1) const;
INFRT_NODISCARD bool is_int(int bits = -1) const;
INFRT_NODISCARD bool is_integer(int bits = -1) const;
INFRT_NODISCARD bool is_uint(int bits = -1) const;
INFRT_NODISCARD bool is_string() const;
// NOTE(review): unlike the other predicates this one is not const-qualified,
// so it cannot be called on a const Type.
INFRT_NODISCARD bool is_index_type();
// @}
// Setters/getters for the C++ decorator flags; setters return *this.
Type& set_cpp_handle(bool x = true);
INFRT_NODISCARD bool is_cpp_handle() const;
Type& set_cpp_handle2(bool x = true);
INFRT_NODISCARD bool is_cpp_handle2() const;
Type& set_cpp_const(bool is_const = true);
INFRT_NODISCARD bool is_cpp_const() const;
// Switch the kind to Customized and record `t` as the customized type name.
Type& set_customized_type(const std::string& t);
const std::string& customized_type() const;
// True when a non-empty customized type name is stored.
INFRT_NODISCARD bool is_customized_type() const;
// Get a new type with bits set to \p x.
Type with_bits(int x) const;
// Get a new type with type set to \p x.
Type with_type(type_t x) const;
// Get a new type with lanes set to \p x.
Type with_lanes(int x) const;
// Get a new type with cpp_const set to \p x.
Type with_cpp_const(bool x = true) const;
//! Getters
// @{
type_t type() const;
int bits() const;
int lanes() const;
cpp_type_t cpp_type() const;
// @}
//! Compare two types for equality.
bool operator==(const Type& other) const;
//! Compare two types for inequality.
bool operator!=(const Type& other) const { return !(*this == other); }
//! Generate a vector of this type, with `w` elements.
Type VectorOf(int w) const;
//! Generate a element type of this type.
Type ElementOf() const;
//! Generate the address type.
Type PointerOf() const;
//! Ignore const.
Type IgnoreConst() const;
//! Add const.
Type ConstOf() const;
friend std::ostream& operator<<(std::ostream& os, const Type& t);
~Type();
private:
// Aborts (CHECK) when the kind is Unk.
void CheckTypeValid() const;
// Holder of the actual fields; defined in the implementation file.
struct Storage;
Storage& GetStorage();
const Storage& GetStorage() const;
std::unique_ptr<Storage> storage_;
}; // struct Type
//! Factory helpers that build a Type from a kind, a bit width and a lane
//! count.
// @{
// Void type: constructed with bits = 1 and lanes = 0.
inline Type Void() { return {Type::type_t::Void, 1, 0}; }

// Signed integer with `bits` bits and `lanes` lanes.
inline Type Int(int bits, int lanes = 1) {
  return {Type::type_t::Int, bits, lanes};
}

// Unsigned integer with `bits` bits and `lanes` lanes.
inline Type UInt(int bits, int lanes = 1) {
  return {Type::type_t::UInt, bits, lanes};
}

// Floating point with `bits` bits and `lanes` lanes.
inline Type Float(int bits, int lanes = 1) {
  return {Type::type_t::Float, bits, lanes};
}

// Boolean: a 1-bit unsigned integer with `lanes` lanes.
inline Type Bool(int lanes = 1) { return {Type::type_t::UInt, 1, lanes}; }

// String type: constructed with bits = 1 and lanes = 1.
inline Type String() { return {Type::type_t::String, 1, 1}; }
// @}
//! Builtin native types as global singletons.
// @{
const Type& F16();
const Type& F32();
const Type& F64();
const Type& I8();
const Type& I16();
const Type& I32();
const Type& I64();
const Type& UI8();
const Type& UI16();
const Type& UI32();
const Type& UI64();
const Type& I1();
const Type& UI1();
// @}
template <typename T>
Type type_of();
// clang-format off
//! Map a C++ scalar type to the matching INFRT Type.
//! Signed C++ integers map to Int kinds and unsigned ones to UInt kinds;
//! plain `char` is treated as a signed 8-bit integer and `bool` as UI1.
template <> inline Type type_of<float>() { return F32(); }
template <> inline Type type_of<double>() { return F64(); }
template <> inline Type type_of<unsigned char>() { return UI8(); }
// int16_t is signed, so it maps to I16 (it previously returned UI16).
template <> inline Type type_of<int16_t>() { return I16(); }
template <> inline Type type_of<int32_t>() { return I32(); }
template <> inline Type type_of<uint32_t>() { return UI32(); }
template <> inline Type type_of<bool>() { return UI1(); }
template <> inline Type type_of<char>() { return I8(); }
template <> inline Type type_of<int64_t>() { return I64(); }
template <> inline Type type_of<uint64_t>() { return UI64(); }
template <> inline Type type_of<signed char>() { return I8(); }
template <> inline Type type_of<void>() { return Void(); }
// clang-format on
//! Pointer specializations: the pointee type decorated with the handle flag
//! (`T*`) or the handle2 flag (`T**`).

// int8_t* -> 8-bit signed integer handle.
template <>
inline Type type_of<int8_t*>() {
  Type handle = Int(8);
  handle.set_cpp_handle();
  return handle;
}

// void* -> void handle.
template <>
inline Type type_of<void*>() {
  Type handle = type_of<void>();
  handle.set_cpp_handle();
  return handle;
}

// void** -> void handle-of-handle.
template <>
inline Type type_of<void**>() {
  Type handle2 = type_of<void>();
  handle2.set_cpp_handle2();
  return handle2;
}

// float* -> 32-bit float handle.
template <>
inline Type type_of<float*>() {
  Type handle = type_of<float>();
  handle.set_cpp_handle();
  return handle;
}

// double* -> 64-bit float handle.
template <>
inline Type type_of<double*>() {
  Type handle = type_of<double>();
  handle.set_cpp_handle();
  return handle;
}
std::ostream& operator<<(std::ostream& os, Type::type_t t);
} // namespace common
} // namespace infrt
core_gather_headers()

# Dialect sources that are added to the infrt source list.
gather_srcs(infrt_src SRCS
    dialect.cc
    types.cc
    basic_kernels.cc
    test_kernels.cc
    infrt_base.cc
    init_infrt_dialects.cc
    tensor_shape.cc
    dense_tensor.cc
    mlir_loader.cc
    diagnostic_utils.cc
    pd_types.cc
    pd_ops.cc
    )

# TableGen inputs (one call per .td file in this directory).
mlir_tablegen_on(ops)
mlir_tablegen_on(basic_kernels)
mlir_tablegen_on(test_kernels)
mlir_tablegen_on(infrt_base DIALECT infrt)
mlir_tablegen_on(tensor_shape DIALECT ts)
mlir_tablegen_on(dense_tensor DIALECT dt)
mlir_tablegen_on(pd_op_base DIALECT pd)
mlir_tablegen_on(pd_ops)

mlir_add_rewriter(rewrite)

# TODO(Superjomn) add a cmake function cc_executable to encapsulate the following code
add_executable(infrtopt opt.cc)
target_link_libraries(infrtopt infrt ${mlir_libs})
add_dependencies(infrtopt infrt)

add_executable(print-ir print_ir.cc)
target_link_libraries(print-ir infrt ${mlir_libs})
add_dependencies(print-ir pd_ops_inc)

# MLIR opt tests
# %{
# infrtopt is built in this directory, so address it (and the .mlir inputs)
# relative to the current binary/source directory instead of assuming a fixed
# `infrt/dialect` location under the top-level tree.
set(infrt_opt_path ${CMAKE_CURRENT_BINARY_DIR}/infrtopt)

add_test(NAME test_infrt_mlir_opt_on_basic
         COMMAND $<TARGET_FILE:infrtopt>
                 ${CMAKE_CURRENT_SOURCE_DIR}/mlir_tests/basic.mlir)
add_test(NAME test_infrt_mlir_opt_on_tensor_shape
         COMMAND $<TARGET_FILE:infrtopt>
                 ${CMAKE_CURRENT_SOURCE_DIR}/mlir_tests/tensor_shape.mlir)
add_test(NAME test_infrt_mlir_opt_on_paddle_ops
         COMMAND $<TARGET_FILE:infrtopt>
                 ${CMAKE_CURRENT_SOURCE_DIR}/mlir_tests/paddle_ops.mlir)
# %}

cc_test_tiny(test_infrt_mlir_loader SRCS mlir_loader_test.cc DEPS infrt ${MLIR_IR_LIBS})

# execute mlir and run FileCheck
infrt_exec_check(run_and_check_tensor_type mlir_tests/tensor_type.mlir)
infrt_exec_check(run_and_check_basic mlir_tests/basic.mlir)
infrt_exec_check(run_and_check_benchmark mlir_tests/benchmark.mlir)
#infrt_exec_check(run_and_check_dense_tensor mlir_tests/dense_tensor.mlir)

# infrt-exec is expected in the sibling host_context build directory
# (NOTE(review): confirm against host_context/CMakeLists.txt).
add_test(test_infrt_mlir_dense_tensor
         ${CMAKE_CURRENT_BINARY_DIR}/../host_context/infrt-exec
         -i
         ${CMAKE_CURRENT_SOURCE_DIR}/mlir_tests/dense_tensor.mlir)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/basic_kernels.h"
#include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Support/LogicalResult.h>
#include "paddle/infrt/dialect/dense_tensor.h"
namespace infrt::dialect {
using namespace mlir; // NOLINT
// Parses the custom form of infrt.call:
//   @callee(operands) attr-dict : function-type
// The results of the function type become the op's result types and its
// inputs are used to resolve the operands.
static ParseResult parseCallOp(OpAsmParser &parser,       // NOLINT
                               OperationState &result) {  // NOLINT
  SymbolRefAttr callee;
  FunctionType signature;
  SmallVector<OpAsmParser::OperandType, 4> call_operands;
  auto name_loc = parser.getNameLoc();

  if (parser.parseAttribute(callee, "callee", result.attributes))
    return failure();
  if (parser.parseOperandList(call_operands, OpAsmParser::Delimiter::Paren))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes)) return failure();
  if (parser.parseColonType(signature)) return failure();
  if (parser.addTypesToList(signature.getResults(), result.types))
    return failure();
  if (parser.resolveOperands(
          call_operands, signature.getInputs(), name_loc, result.operands))
    return failure();
  return success();
}
// Shared parser for the infrt.constant.* ops: an optional attribute dict
// followed by the "value" attribute, which is parsed with type `attrType`.
// `attrType` also becomes the single result type.
static ParseResult parseConstantOp(Type attrType,
                                   OpAsmParser &parser,       // NOLINT
                                   OperationState &result) {  // NOLINT
  Attribute value;
  if (parser.parseOptionalAttrDict(result.attributes)) return failure();
  if (parser.parseAttribute(value, attrType, "value", result.attributes))
    return failure();
  if (parser.addTypeToList(attrType, result.types)) return failure();
  return success();
}
// Typed entry points referenced by the generated op classes; each one fixes
// the attribute/result type and defers to parseConstantOp.

// infrt.constant.f32
static ParseResult parseConstantF32Op(OpAsmParser &parser,       // NOLINT
                                      OperationState &result) {  // NOLINT
  Type value_type = FloatType::getF32(result.getContext());
  return parseConstantOp(value_type, parser, result);
}

// infrt.constant.f64
static ParseResult parseConstantF64Op(OpAsmParser &parser,       // NOLINT
                                      OperationState &result) {  // NOLINT
  Type value_type = FloatType::getF64(result.getContext());
  return parseConstantOp(value_type, parser, result);
}

// infrt.constant.i32
static ParseResult parseConstantI32Op(OpAsmParser &parser,       // NOLINT
                                      OperationState &result) {  // NOLINT
  Type value_type = IntegerType::get(32, result.getContext());
  return parseConstantOp(value_type, parser, result);
}

// infrt.constant.i64
static ParseResult parseConstantI64Op(OpAsmParser &parser,       // NOLINT
                                      OperationState &result) {  // NOLINT
  Type value_type = IntegerType::get(64, result.getContext());
  return parseConstantOp(value_type, parser, result);
}
// Parses infrt.return: an operand list which, when non-empty, must be
// followed by `: type-list`; the operands are then resolved against those
// types.
static ParseResult parseReturnOp(OpAsmParser &parser,       // NOLINT
                                 OperationState &result) {  // NOLINT
  SmallVector<OpAsmParser::OperandType, 2> returned;
  SmallVector<Type, 2> returned_types;
  llvm::SMLoc start_loc = parser.getCurrentLocation();

  if (parser.parseOperandList(returned)) return failure();
  if (!returned.empty() && parser.parseColonTypeList(returned_types))
    return failure();
  if (parser.resolveOperands(
          returned, returned_types, start_loc, result.operands))
    return failure();
  return success();
}
// Prints infrt.call in the form accepted by parseCallOp:
//   infrt.call @callee(operands) attr-dict : (input types) -> result types
static void print(OpAsmPrinter &p, CallOp op) {  // NOLINT
  p << "infrt.call " << op.getAttr("callee") << "(";
  p.printOperands(op.getOperands());
  p << ")";
  p.printOptionalAttrDict(op.getAttrs(), {"callee"});
  p << " : ";
  // The parser requires a function type after the colon; emit it from the
  // op's operand and result types so the printed form can be parsed back.
  p.printFunctionalType(op.getOperation());
}
// Shared printer for the infrt.constant.* ops: op name, the attribute dict
// without "value", then the bare constant value.
static void printConstant(OpAsmPrinter &p, mlir::Operation *op) {  // NOLINT
  p << op->getName() << " ";
  p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"});
  if (op->getAttrs().size() > 1) p << ' ';

  // Use the null-tolerant casts so a missing "value" attribute is reported
  // through the error branch below instead of tripping a cast assertion.
  Attribute attr = op->getAttr("value");
  if (auto int_attr = attr.dyn_cast_or_null<IntegerAttr>()) {
    // 1-bit integers are printed unsigned, everything else signed.
    bool is_signed = int_attr.getType().isIndex() ||
                     int_attr.getType().getIntOrFloatBitWidth() != 1;
    int_attr.getValue().print(p.getStream(), is_signed);
  } else if (auto float_attr = attr.dyn_cast_or_null<FloatAttr>()) {
    // APFloat::convertToFloat() is only valid for single-precision values;
    // getValueAsDouble() handles both f32 and f64 constants.
    p << float_attr.getValueAsDouble();
  } else {
    op->emitOpError("unknown attribute type");
  }
}
// Per-op printers referenced by the generated classes; all four share the
// printConstant implementation.
static void print(OpAsmPrinter &p, ConstantF32Op op) {  // NOLINT
  printConstant(p, op.getOperation());
}

static void print(OpAsmPrinter &p, ConstantF64Op op) {  // NOLINT
  printConstant(p, op.getOperation());
}

static void print(OpAsmPrinter &p, ConstantI32Op op) {  // NOLINT
  printConstant(p, op.getOperation());
}

static void print(OpAsmPrinter &p, ConstantI64Op op) {  // NOLINT
  printConstant(p, op.getOperation());
}
// Prints infrt.return in the form accepted by parseReturnOp:
//   infrt.return                       (no operands)
//   infrt.return %a, %b : type, type   (with operands)
static void print(OpAsmPrinter &p, ReturnOp op) {  // NOLINT
  p << "infrt.return";
  if (op.getNumOperands() > 0) {
    p << ' ';
    p.printOperands(op.getOperands());
    p << " : ";
    // The colon must be followed by the operand *types* (the parser reads a
    // type list here); printing the operands again produced unparsable IR.
    llvm::interleaveComma(op.getOperation()->getOperandTypes(), p);
  }
}
// Verifier hooks required by the INFRT_Op TableGen base class. These ops have
// no extra invariants checked here, so verification always succeeds.
static LogicalResult verify(CallOp op) { return success(); }
static LogicalResult verify(ConstantF32Op op) { return success(); }
static LogicalResult verify(ConstantI32Op op) { return success(); }
static LogicalResult verify(ConstantF64Op op) { return success(); }
static LogicalResult verify(ConstantI64Op op) { return success(); }
// Verifies infrt.return: when the parent op is a FuncOp, the number of
// returned operands must equal the number of results in the function type.
// Only the count is compared; a return outside a FuncOp is accepted.
static LogicalResult verify(ReturnOp op) {
  auto enclosing_func = dyn_cast<FuncOp>(op.getParentOp());
  if (enclosing_func) {
    auto expected = enclosing_func.getType().getResults();
    if (op.getNumOperands() != expected.size()) {
      return op.emitOpError("has ")
             << op.getNumOperands()
             << " operands, but enclosing function returns "
             << expected.size();
    }
  }
  return success();
}
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/basic_kernels.cpp.inc"
} // namespace infrt::dialect
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mlir/IR/OpDefinition.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
using namespace mlir; // NOLINT
namespace infrt::dialect {
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/basic_kernels.hpp.inc"
} // namespace infrt::dialect
// Operation definitions for basic kernels.
#ifdef BASIC_OPS
#else
#define BASIC_OPS
include "paddle/infrt/dialect/infrt_base.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
// Base class for the ops of the INFRT dialect defined in this file.
// Appends the IsolatedFromAbove trait to whatever traits the op lists and
// wires the printer/verifier/parser hooks to functions in infrt::dialect
// (`print`, `verify` and `parse<OpClassName>`).
class INFRT_Op<string mnemonic, list<OpTrait> traits = []> : Op<INFRT_Dialect, mnemonic, !listconcat(traits, [IsolatedFromAbove])> {
// Each registered op needs to provide all of a printer, parser and verifier.
let printer = [{ return infrt::dialect::print(p, *this); }];
let verifier = [{ return infrt::dialect::verify(*this); }];
let parser = [{ return infrt::dialect::parse$cppClass(parser, result); }];
}
// infrt.call: direct call of the function named by the `callee` symbol
// reference, with a variadic operand list and variadic results.
// NOTE(review): getCalleeType() is only declared here; confirm a definition
// exists before using it.
def CallOp : INFRT_Op<"call"> {
let summary = "call a host operation";
let description = [{
The "infrt.call" operation represents a direct call to a function. The operands and result types of the call must match the specified function type.
%2 = infrt.call @add(%0, %1) : (f32, f32) -> f32
}];
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
let results = (outs Variadic<AnyType>);
let extraClassDeclaration = [{
StringRef getCallee() { return callee(); }
mlir::FunctionType getCalleeType();
}];
}
// Template for the infrt.constant.<suffix> ops: a side-effect-free op with a
// single `value` attribute of kind `attr` and one result of type `baseType`.
class ConstantOp<string suffix, Type baseType, Attr attr>
: INFRT_Op<"constant." # suffix, [NoSideEffect]> {
let summary = "constant value constructor in host";
let arguments = (ins attr:$value);
let results = (outs baseType);
}
// Concrete constant ops for the supported scalar types.
def ConstantI32Op : ConstantOp<"i32", I32, I32Attr>;
def ConstantI64Op : ConstantOp<"i64", I64, I64Attr>;
def ConstantF32Op : ConstantOp<"f32", F32, F32Attr>;
def ConstantF64Op : ConstantOp<"f64", F64, F64Attr>;
// infrt.return: terminator taking a variadic operand list.
// The extra builder creates a return with no operands.
def ReturnOp : INFRT_Op<"return", [Terminator]> {
let summary = "host executor return operation";
let description = [{
The "infrt.return" operation represents a return operation within a function.
func @foo() : (i32, f8) {
infrt.return %0, %1 : i32, f8
}
}];
let arguments = (ins Variadic<AnyType>:$operands);
let builders = [OpBuilder<
"OpBuilder &b, OperationState &result",
[{ build(b, result, llvm::None); }]>];
}
// Template for the infrt.add.<suffix> ops: two operands and one result, all
// of type `type`. Uses a declarative assembly format and unsets the verifier
// inherited from INFRT_Op.
class AddOp<string suffix, Type type> : INFRT_Op<"add." # suffix, [NoSideEffect]> {
let summary = "infrt.add operation";
let description = [{
An operation that takes two inputs and returns their sum as result.
}];
let arguments = (ins type, type);
let results = (outs type);
let assemblyFormat = "operands attr-dict";
let verifier = ?;
}
def AddI32Op : AddOp<"i32", I32>;
def AddI64Op : AddOp<"i64", I64>;
def AddF32Op : AddOp<"f32", F32>;
def AddF64Op : AddOp<"f64", F64>;
// Template for the infrt.mul.<suffix> ops; same shape as AddOp (two operands
// and one result of type `type`, declarative format, no verifier).
class MulOp<string suffix, Type type> : INFRT_Op<"mul." # suffix, [NoSideEffect]> {
let summary = "infrt.mul operation";
let description = [{
An operation that takes two inputs and returns their mul as result.
}];
let arguments = (ins type, type);
let results = (outs type);
let assemblyFormat = "operands attr-dict";
let verifier = ?;
}
def MulI32Op : MulOp<"i32", I32>;
def MulI64Op : MulOp<"i64", I64>;
def MulF32Op : MulOp<"f32", F32>;
def MulF64Op : MulOp<"f64", F64>;
// Template for the infrt.print.<suffix> ops: one operand of type `type` and
// no results. Not marked NoSideEffect. Only the f32 variant is currently
// instantiated; the other instantiations are commented out.
class PrintOp<string suffix, Type type> : INFRT_Op<"print." # suffix> {
let summary = "infrt.print operation";
let description = [{
An operation takes a number as input and prints to stdout.
}];
let arguments = (ins type);
let assemblyFormat = "operands attr-dict";
let verifier = ?;
}
//def PrintI32Op : PrintOp<"i32", I32>;
//def PrintI64Op : PrintOp<"i64", I64>;
def PrintF32Op : PrintOp<"f32", F32>;
//def PrintF64Op : PrintOp<"f64", F64>;
// infrt.get_string: takes a string attribute `value` and produces one
// StringType result. Written as `infrt.get_string(<value>) attr-dict`.
def GetStringOp : INFRT_Op<"get_string"> {
let summary = "infrt.get_string";
let description = [{
Get a !infrt.string value from the given string attribute.
}];
let arguments = (ins StrAttr:$value);
let results = (outs StringType);
let assemblyFormat = "`(` $value `)` attr-dict";
let verifier = ?;
}
// infrt.print_string: takes one StringType operand `input` and has no
// results. Written as `infrt.print_string(<input>) attr-dict`.
def PrintStringOp : INFRT_Op<"print_string"> {
let summary = "infrt.print_string";
let description = [{
An operation that prints a string.
}];
let arguments = (ins StringType:$input);
let results = (outs);
let assemblyFormat = "`(` $input `)` attr-dict";
let verifier = ?;
}
#endif // basic kernels
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/dense_tensor.h"
#include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Support/LogicalResult.h>
#include <tuple>
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/tensor_shape.h"
namespace infrt::dt {
// Dialect set-up: allows unknown types and registers every operation named in
// the op list that GET_OP_LIST selects from the generated dense_tensor.cpp.inc.
void DTDialect::initialize() {
allowUnknownTypes();
addOperations<
#define GET_OP_LIST
#include "paddle/infrt/dialect/dense_tensor.cpp.inc"
>();
}
namespace detail {
// Storage for TensorType: the (target, layout, precision) triple that also
// forms the key (`KeyTy`) identifying an instance.
struct TensorTypeStorage : public mlir::TypeStorage {
TensorTypeStorage(TargetType target,
LayoutType layout,
PrecisionType precision)
: target_(target), layout_(layout), precision_(precision) {}
using KeyTy = std::tuple<TargetType, LayoutType, PrecisionType>;
// Key comparison: equal when all three fields match.
bool operator==(const KeyTy &key) const {
return key == KeyTy(target_, layout_, precision_);
}
// Hash of the key tuple.
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_value(key);
}
// Placement-constructs a storage instance in memory obtained from the
// provided type-storage allocator.
static TensorTypeStorage *construct(
mlir::TypeStorageAllocator &allocator, // NOLINT
const KeyTy &key) {
return new (allocator.allocate<TensorTypeStorage>())
TensorTypeStorage(std::get<0>(key), std::get<1>(key), std::get<2>(key));
}
TargetType target_;
LayoutType layout_;
PrecisionType precision_;
};
} // namespace detail
// The three lookups below translate a textual key (compared
// case-insensitively) into the matching enum value, or llvm::None when the
// key is not recognised.

// "x86" / "cuda"
llvm::Optional<TargetType> GetTargetType(mlir::StringRef key) {
  if (key.equals_lower("x86")) return TargetType::X86;
  if (key.equals_lower("cuda")) return TargetType::CUDA;
  return llvm::None;
}

// "nchw" / "nhwc"
llvm::Optional<LayoutType> GetLayoutType(mlir::StringRef key) {
  if (key.equals_lower("nchw")) return LayoutType::NCHW;
  if (key.equals_lower("nhwc")) return LayoutType::NHWC;
  return llvm::None;
}

// "i32" / "f32"
llvm::Optional<PrecisionType> GetPrecisionType(mlir::StringRef key) {
  if (key.equals_lower("i32")) return PrecisionType::I32;
  if (key.equals_lower("f32")) return PrecisionType::F32;
  return llvm::None;
}
// Returns the TensorType for the given triple, created in the global MLIR
// context provided by infrt::Global.
TensorType TensorType::get(TargetType target,
LayoutType layout,
PrecisionType precision) {
return Base::get(
::infrt::Global::getMLIRContext(), target, layout, precision);
}
// Field accessors reading from the type's storage.
TargetType TensorType::target() { return getImpl()->target_; }
LayoutType TensorType::layout() { return getImpl()->layout_; }
PrecisionType TensorType::precision() { return getImpl()->precision_; }
// Prints the type as "TensorType<target, layout, precision>".
raw_ostream &operator<<(raw_ostream &os, TensorType tensorType) {
os << "TensorType<" << tensorType.target() << ", " << tensorType.layout()
<< ", " << tensorType.precision() << ">";
return os;
}
// TensorMapType in the global MLIR context.
TensorMapType TensorMapType::get() {
return Base::get(::infrt::Global::getMLIRContext());
}
// TensorMapType in the given context.
TensorMapType TensorMapType::get(mlir::MLIRContext *context) {
return Base::get(context);
}
// StringType in the global MLIR context.
StringType StringType::get() {
return Base::get(::infrt::Global::getMLIRContext());
}
// StringType in the given context.
StringType StringType::get(mlir::MLIRContext *context) {
return Base::get(context);
}
// Stream printers for the tensor attribute enums. Values outside the listed
// enumerators are printed as "Unsupported".

raw_ostream &operator<<(raw_ostream &os, TargetType type) {
  if (type == TargetType::X86) return os << "X86";
  if (type == TargetType::CUDA) return os << "CUDA";
  return os << "Unsupported";
}

raw_ostream &operator<<(raw_ostream &os, LayoutType type) {
  if (type == LayoutType::NCHW) return os << "NCHW";
  if (type == LayoutType::NHWC) return os << "NHWC";
  return os << "Unsupported";
}

raw_ostream &operator<<(raw_ostream &os, PrecisionType type) {
  if (type == PrecisionType::I32) return os << "I32";
  if (type == PrecisionType::F32) return os << "F32";
  return os << "Unsupported";
}
// Builds the opaque type named "tensor" in the dialect namespace "t". Note
// this is an mlir::OpaqueType, not the dt TensorType class defined above.
static Type getTensorType(mlir::MLIRContext *context) {
auto t_dialect = Identifier::get("t", context);
return OpaqueType::get(t_dialect, "tensor", context);
}
// Parses dt.create_uninit_tensor.*:
//   <shape array attr> attr-dict -> <result type>
// The result type must be a dt TensorType; otherwise an error is emitted at
// the location where parsing of the op body started.
static ParseResult parseCreateUninitTensorOp(
    OpAsmParser &parser,       // NOLINT
    OperationState &result) {  // NOLINT
  auto start_loc = parser.getCurrentLocation();
  ::mlir::Type result_types[1];
  ::llvm::ArrayRef<::mlir::Type> result_type_list(result_types);
  mlir::ArrayAttr shape;

  if (parser.parseAttribute(
          shape, parser.getBuilder().getI64Type(), "shape", result.attributes))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseArrow() || parser.parseType(result_types[0]))
    return failure();

  if (!result_types[0].isa<TensorType>())
    return parser.emitError(start_loc, "invalid kind of type specified");

  result.addTypes(result_type_list);
  return success();
}
// Prints dt.create_uninit_tensor.* as
//   <op name> <shape without type> attr-dict -> <result types>
// with "shape" left out of the attribute dict.
template <typename CreateUninitTensorOp>
static void printCreateUninitTensorOp(OpAsmPrinter &p,  // NOLINT
                                      CreateUninitTensorOp op) {
  p << CreateUninitTensorOp::getOperationName() << " ";
  p.printAttributeWithoutType(op.shapeAttr());
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
  p << " -> " << op.getOperation()->getResultTypes();
}
// TODO(shibo): can be removed?
// static ParseResult parseFillTensorWithConstantOp(OpAsmParser& parser,
// OperationState& result) {
// auto loc = parser.getCurrentLocation();
// ::mlir::OpAsmParser::OperandType inputRawOperands[1];
// ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType>
// inputOperands(inputRawOperands);
// ::mlir::Type inputRawTypes[1];
// ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes);
//
// if (parser.parseOperand(inputRawOperands[0])) return failure();
//
// if (parser.parseColon()) return failure();
// if (parser.parseType(inputRawTypes[0])) return failure();
// if (!inputRawTypes[0].isa<TensorType>())
// return parser.emitError(loc, "invalid kind of type specified");
//
// Attribute value_attr;
// if (parser.resolveOperands(inputOperands, inputTypes, loc, result.operands))
// return failure();
// if (parser.parseAttribute(value_attr, "value", result.attributes)) return
// failure();
// return success();
//}
// TODO(shibo): can be removed?
// template <typename FillTensorOp>
// static void printFillTensorWithConstantOp(OpAsmPrinter& p, FillTensorOp op) {
// p << FillTensorOp::getOperationName();
// p << " ";
// p.printOperand(op.getOperand());
// p << " : ";
// p << op.getOperation()->getOperandTypes();
// p << " ";
// p << op.getAttr("value");
//}
// Parses dt.set_tensor_with_constant_values.*: exactly one operand, resolved
// against the opaque tensor type from getTensorType(), followed by the
// "values" attribute.
static ParseResult parseSetTensorOp(OpAsmParser &parser,       // NOLINT
                                    OperationState &result) {  // NOLINT
  SmallVector<OpAsmParser::OperandType, 1> tensor_operand;
  if (parser.parseOperandList(tensor_operand, 1)) return failure();

  auto opaque_tensor_type = getTensorType(result.getContext());
  Attribute values;
  if (parser.resolveOperand(
          tensor_operand[0], opaque_tensor_type, result.operands))
    return failure();
  if (parser.parseAttribute(values, "values", result.attributes))
    return failure();
  return success();
}
// Prints dt.set_tensor_with_constant_values.* as
//   <op name> <operand> <values attribute>
template <typename SetTensorOp>
static void printSetTensorOp(OpAsmPrinter &p, SetTensorOp op) {  // NOLINT
  p << SetTensorOp::getOperationName();
  p << " ";
  p.printOperand(op.getOperand());
  p << " ";
  p << op.getAttr("values");
}
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/dense_tensor.cpp.inc" // NOLINT
} // namespace infrt::dt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mlir/IR/Dialect.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <string>
using namespace mlir; // NOLINT
namespace infrt::dt {
namespace detail {
struct TensorTypeStorage;
} // namespace detail
// Attributes that parameterise a dense tensor type.
enum class TargetType : uint8_t { X86, CUDA };
enum class LayoutType : uint8_t { NCHW, NHWC };
enum class PrecisionType : uint8_t { I32, F32 };
// Case-insensitive lookup from a textual key to the enum value; llvm::None is
// returned for unrecognised keys.
llvm::Optional<TargetType> GetTargetType(mlir::StringRef key);
llvm::Optional<LayoutType> GetLayoutType(mlir::StringRef key);
llvm::Optional<PrecisionType> GetPrecisionType(mlir::StringRef key);
// Stream printers producing the enumerator name.
raw_ostream &operator<<(raw_ostream &os, TargetType type);
raw_ostream &operator<<(raw_ostream &os, LayoutType type);
raw_ostream &operator<<(raw_ostream &os, PrecisionType type);
// MLIR type of a dense tensor, parameterised by target, layout and precision.
// The parameters are held in detail::TensorTypeStorage.
class TensorType : public mlir::Type::TypeBase<TensorType,
mlir::Type,
detail::TensorTypeStorage> {
public:
using Base::Base;
// Get the TensorType instance for the given parameter triple.
static TensorType get(TargetType target,
LayoutType layout,
PrecisionType precision);
// Parameter accessors.
TargetType target();
LayoutType layout();
PrecisionType precision();
};
// Prints the type together with its three parameters.
raw_ostream &operator<<(raw_ostream &os, TensorType tensorType);
// Parameterless MLIR type (plain mlir::TypeStorage) for a tensor map.
// get() without arguments uses the global MLIR context; the overload takes an
// explicit context.
class TensorMapType : public mlir::Type::TypeBase<TensorMapType,
mlir::Type,
mlir::TypeStorage> {
public:
using Base::Base;
static TensorMapType get();
static TensorMapType get(mlir::MLIRContext *context);
};
// Parameterless MLIR type (plain mlir::TypeStorage) for a string value.
// get() without arguments uses the global MLIR context; the overload takes an
// explicit context.
class StringType
: public mlir::Type::TypeBase<StringType, mlir::Type, mlir::TypeStorage> {
public:
using Base::Base;
static StringType get();
static StringType get(mlir::MLIRContext *context);
};
#include "paddle/infrt/dialect/dense_tensor_dialect.hpp.inc"
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/dense_tensor.hpp.inc"
} // namespace infrt::dt
#ifdef DT_OPS
#else
#define DT_OPS
include "paddle/infrt/dialect/infrt_base.td"
include "paddle/infrt/dialect/tensor_shape_base.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
// The "dt" (DenseTensor) dialect; its generated C++ lives in ::infrt::dt.
def DT_Dialect : Dialect {
let name = "dt";
let description = [{
The DenseTensor dialect.
}];
let cppNamespace = "::infrt::dt";
}
// Base class for all ops of the dt dialect; adds nothing beyond binding the
// op to DT_Dialect.
class DT_Op<string mnemonic, list<OpTrait> traits = []> :
Op<DT_Dialect, mnemonic, traits>;
// Template for dt.create_uninit_tensor.<dtype>: takes an I64 array attribute
// `shape` and produces one TensorType result. Parsing and printing are
// delegated to the hand-written functions in infrt::dt.
class CreateUninitTensorOp<string dtype>
: DT_Op<"create_uninit_tensor." # dtype, [NoSideEffect]> {
let summary = "dt.create_uninit_tensor operation";
let description = [{
An operation that creates an uninitialized tensor.
}];
let arguments = (ins I64ArrayAttr:$shape);
let results = (outs TensorType:$output);
let parser = [{ return infrt::dt::parseCreateUninitTensorOp(parser, result); }];
let printer = [{ return infrt::dt::printCreateUninitTensorOp(p, *this); }];
}
// `dt.shallow_copy_tensor`: one tensor operand, one tensor result, marked
// NoSideEffect. Textual form: `%out = dt.shallow_copy_tensor %in : T -> T`.
def ShallowCopyTensorOp
: DT_Op<"shallow_copy_tensor", [NoSideEffect]> {
let summary = "dt.shallow_copy_tensor operation";
let description = [{
An operation that copy a tensor shallowly.
}];
let arguments = (ins TensorType:$input);
let results = (outs TensorType:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
}
// Template for `dt.fill_tensor_with_constant.<dtype>`. The op has no results
// and no NoSideEffect trait. The fill value is the `value` attribute, which
// the declarative format places in the trailing attribute dictionary, e.g.
//   dt.fill_tensor_with_constant.f32 (%t : !infrt.tensor<...>) {value=1.0:f32}
class FillTensorWithConstantOp<string dtype> :
DT_Op<"fill_tensor_with_constant." # dtype> {
let summary = "dt.fill_tensor_with_constant operation";
let description = [{
An operation that fills an input tensor with a value.
}];
let arguments = (ins
TensorType:$input,
AnyAttr:$value
);
let results = (outs);
// TODO: can be removed?
//let parser = [{ return infrt::dt::parseFillTensorWithConstantOp(parser, result); }];
//let printer = [{ return infrt::dt::printFillTensorWithConstantOp(p, *this); }];
let assemblyFormat = "`(` $input `:` type($input) `)` attr-dict";
}
// `dt.print_tensor`: takes one tensor, produces no results and carries no
// NoSideEffect trait. Textual form: `dt.print_tensor (%t : T)`.
def PrintTensorOp : DT_Op<"print_tensor"> {
let summary = "dt.print_tensor operation";
let description = [{
An operation that prints a tensor.
}];
let arguments = (ins TensorType:$input);
let results = (outs);
let assemblyFormat = "`(` $input `:` type($input) `)` attr-dict";
}
// Template for `dt.set_tensor_with_constant_values.<dtype>`. It takes a single
// (unnamed) tensor operand and has no results; the custom syntax is handled by
// hand-written C++ parse/print helpers in infrt::dt.
class SetTensorOp<string dtype> :
DT_Op<"set_tensor_with_constant_values." # dtype> {
let summary = "dt.set_tensor_with_constant_values operation";
let description = [{
An operation that sets an input tensor with given values.
}];
let arguments = (ins TensorType);
let results = (outs);
let parser = [{ return infrt::dt::parseSetTensorOp(parser, result); }];
let printer = [{ return infrt::dt::printSetTensorOp(p, *this); }];
}
// `dt.load_params`: takes a string operand (the parameter path) and returns a
// tensor map. Textual form: `%map = dt.load_params(%path)`.
// NOTE(review): the op is declared NoSideEffect although its description says
// it loads tensors -- confirm that the trait is intended.
def LoadParamsOp : DT_Op<"load_params", [NoSideEffect]> {
let summary = "dt.load_params operation";
let description = [{
An operation that can load tensors to TensorMap.
}];
// input path of model params.
let arguments = (ins StringType:$path);
let results = (outs TensorMapType);
let assemblyFormat = "`(` operands `)` attr-dict";
// `?` leaves the field unset, i.e. no custom verifier is attached.
let verifier = ?;
}
// `dt.get_param`: looks up one tensor in a tensor map by name.
// Textual form: `%t = dt.get_param(%map, "name") -> !infrt.tensor<...>`.
def GetParamOp : DT_Op<"get_param", [NoSideEffect]> {
let summary = "dt.get_param operation";
let description = [{
An operation that can get a tensor from TensorMap.
}];
// the tensor map to read from and the name of the tensor to fetch.
let arguments = (ins
TensorMapType:$map,
StrAttr:$name
);
let results = (outs TensorType:$output);
let assemblyFormat = "`(` $map `,` $name `)` attr-dict `->` type($output)";
// `?` leaves the field unset, i.e. no custom verifier is attached.
let verifier = ?;
}
// `dt.get_tensor_shape`: takes one tensor and returns a TS_Shape value (the
// shape type comes from tensor_shape_base.td). Marked NoSideEffect.
def GetTensorShapeOp : DT_Op<"get_tensor_shape", [NoSideEffect]> {
let summary = "dt.get_tensor_shape operation";
let description = [{
An operation that returns the shape of the input tensor.
}];
let arguments = (ins TensorType:$input);
let results = (outs TS_Shape:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
}
// Instantiate the three dtype-parameterized op templates once per element
// type; the record name and the op mnemonic both carry the dtype suffix.
foreach dtype = ["ui8", "ui16", "ui32", "ui64", "i32", "f32", "f64", "i64"] in {
def DT_CreateUninitTensorOp_#dtype : CreateUninitTensorOp<dtype>;
def DT_FillTensorOp_#dtype : FillTensorWithConstantOp<dtype>;
def DT_SetTensorOp_#dtype : SetTensorOp<dtype>;
}
#endif // DT_OPS
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/diagnostic_utils.h"
#include <string>
namespace infrt::dialect {
// Private state of MyScopedDiagnosicHandler, kept behind a pointer so the
// header does not need the member types' full definitions.
struct MyScopedDiagnosicHandler::Impl {
// diag_stream_ is bound to diag_str_. Members are initialized in declaration
// order, so diag_str_ must stay declared before diag_stream_.
Impl() : diag_stream_(diag_str_) {}
// String stream to assemble the final error message.
std::string diag_str_;
llvm::raw_string_ostream diag_stream_;
// A SourceMgr to use for the base handler class.
llvm::SourceMgr source_mgr_;
// Log detail information.
// Value-initialized to false; when false, handler() ignores non-error
// diagnostics.
bool log_info_{};
};
// Creates the private state, hands its SourceMgr and output stream to the
// SourceMgrDiagnosticHandler base, and installs handler() as the diagnostic
// callback for `ctx`.
//
// `propagate` is accepted for interface compatibility but is not used.
//
// NOTE(review): base-class sub-objects are initialized before data members,
// so `impl_` has not been constructed yet when `impl_->source_mgr_` and
// `impl_->diag_stream_` are evaluated below; this dereferences an
// uninitialized std::unique_ptr. A proper fix needs the Impl to exist before
// the base is initialized (for example a private delegating constructor that
// receives an already-built std::unique_ptr<Impl>), which requires a change
// to the class declaration in the header.
MyScopedDiagnosicHandler::MyScopedDiagnosicHandler(mlir::MLIRContext *ctx,
                                                   bool propagate)
    : mlir::SourceMgrDiagnosticHandler(
          impl_->source_mgr_, ctx, impl_->diag_stream_),
      impl_(new Impl) {
  setHandler([this](mlir::Diagnostic &diag) { return this->handler(&diag); });
}

// The destructor is declared in the header and must be defined in this
// translation unit: destroying std::unique_ptr<Impl> requires Impl to be a
// complete type, which is only true here.
MyScopedDiagnosicHandler::~MyScopedDiagnosicHandler() = default;
// Diagnostic callback. Errors are always emitted; other severities are
// emitted only when Impl::log_info_ is set. Every emitted diagnostic is
// flushed into the message buffer and reported as failure; everything else
// is reported as success.
mlir::LogicalResult MyScopedDiagnosicHandler::handler(mlir::Diagnostic *diag) {
  const bool is_error =
      diag->getSeverity() == mlir::DiagnosticSeverity::Error;
  if (is_error || impl_->log_info_) {
    emitDiagnostic(*diag);
    impl_->diag_stream_.flush();
    return mlir::failure(true);
  }
  return mlir::success();
}
} // namespace infrt::dialect
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <llvm/Support/SourceMgr.h>
#include <mlir/IR/Diagnostics.h>
#include <memory>
namespace infrt::dialect {
/**
 * A scoped diagnostic handler to help debug MLIR process.
 *
 * It derives from mlir::SourceMgrDiagnosticHandler and keeps its own state
 * (message buffer, stream, SourceMgr) behind an opaque pointer.
 */
class MyScopedDiagnosicHandler : public mlir::SourceMgrDiagnosticHandler {
 public:
  // Installs the handler on `ctx`.
  MyScopedDiagnosicHandler(mlir::MLIRContext* ctx, bool propagate);

  // Callback invoked for each diagnostic; returns failure for diagnostics it
  // emitted and success for those it ignored.
  mlir::LogicalResult handler(mlir::Diagnostic* diag);

  // Defined out of line because Impl is incomplete in this header.
  ~MyScopedDiagnosicHandler();

 private:
  // Forward-declared with the same class-key (`struct`) as its definition in
  // the .cc file, so the two declarations do not disagree on class vs struct.
  struct Impl;
  std::unique_ptr<Impl> impl_;
};
} // namespace infrt::dialect
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mlir/IR/Builders.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Support/LogicalResult.h>
namespace infrt::hlir::dialect {
// Dialect class declared in infrt::hlir::dialect.
// NOTE(review): the class keeps the "Cinn" name while the surrounding code
// uses "infrt", and the namespace string below contains "::" -- confirm both
// are intended.
class CinnDialect : public ::mlir::Dialect {
public:
explicit CinnDialect(::mlir::MLIRContext* ctx);
//! We should register this function in dialect
// Returns the string used as this dialect's namespace.
static llvm::StringRef getDialectNamespace() {
return "infrt::hlir::dialect";
}
};
} // namespace infrt::hlir::dialect
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/infrt_base.h"
#include "paddle/infrt/dialect/basic_kernels.h"
#include "paddle/infrt/dialect/dense_tensor.h"
#include "paddle/infrt/dialect/test_kernels.h"
namespace infrt::dialect {
// ----INFRTDialect definition begin----
// Registers everything the infrt dialect provides: the three infrt::dt types
// plus the operations generated for the basic and the test kernels.
void INFRTDialect::initialize() {
// Accept types and operations that are not registered with this dialect.
allowUnknownTypes();
allowUnknownOperations();
addTypes<infrt::dt::StringType>();
addTypes<infrt::dt::TensorType>();
addTypes<infrt::dt::TensorMapType>();
// With GET_OP_LIST defined, the generated .inc file expands to the list of
// op classes, which becomes the template argument list of addOperations.
addOperations<
#define GET_OP_LIST
#include "paddle/infrt/dialect/basic_kernels.cpp.inc"
>();
addOperations<
#define GET_OP_LIST
#include "paddle/infrt/dialect/test_kernels.cpp.inc"
>();
}
// Parses one type of the infrt dialect from its textual form:
//   tensor<TARGET, LAYOUT, PRECISION>   e.g. !infrt.tensor<X86, NCHW, F32>
//   tensor_map                          e.g. !infrt.tensor_map
//   string                              e.g. !infrt.string
// Returns a null mlir::Type on any syntax error; unknown keywords and unknown
// target/layout/precision names additionally emit a parser error.
mlir::Type INFRTDialect::parseType(mlir::DialectAsmParser &parser) const {
  llvm::StringRef type_name;
  if (parser.parseKeyword(&type_name)) return mlir::Type();

  // The two parameterless types need no further tokens.
  if (type_name == "tensor_map") return infrt::dt::TensorMapType::get();
  if (type_name == "string") return infrt::dt::StringType::get();

  if (type_name != "tensor") {
    parser.emitError(parser.getCurrentLocation(), "unknown infrt type: ")
        << type_name;
    return mlir::Type();
  }

  // "<" TARGET
  llvm::StringRef target_name;
  if (parser.parseLess() || parser.parseKeyword(&target_name))
    return mlir::Type();
  auto target = infrt::dt::GetTargetType(target_name);
  if (!target) {
    parser.emitError(parser.getCurrentLocation(), "unknown target type: ")
        << target_name;
    return mlir::Type();
  }

  // "," LAYOUT
  llvm::StringRef layout_name;
  if (parser.parseComma() || parser.parseKeyword(&layout_name))
    return mlir::Type();
  auto layout = infrt::dt::GetLayoutType(layout_name);
  if (!layout) {
    parser.emitError(parser.getCurrentLocation(), "unknown layout type: ")
        << layout_name;
    return mlir::Type();
  }

  // "," PRECISION
  llvm::StringRef precision_name;
  if (parser.parseComma() || parser.parseKeyword(&precision_name))
    return mlir::Type();
  auto precision = infrt::dt::GetPrecisionType(precision_name);
  if (!precision) {
    parser.emitError(parser.getCurrentLocation(), "unknown precision type: ")
        << precision_name;
    return mlir::Type();
  }

  // ">"
  if (parser.parseGreater()) return mlir::Type();
  return infrt::dt::TensorType::get(*target, *layout, *precision);
}
// Prints a type owned by the infrt dialect in the syntax parseType accepts:
// `tensor<TARGET, LAYOUT, PRECISION>`, `tensor_map` or `string`.
// Any other type is a programming error and hits llvm_unreachable.
void INFRTDialect::printType(mlir::Type type,
                             mlir::DialectAsmPrinter &printer) const {
  if (auto tensor = type.dyn_cast<infrt::dt::TensorType>()) {
    // e.g. !infrt.tensor<X86, NCHW, F32>
    printer << "tensor<" << tensor.target() << ", " << tensor.layout() << ", "
            << tensor.precision() << ">";
  } else if (type.isa<infrt::dt::TensorMapType>()) {
    // e.g. !infrt.tensor_map
    printer << "tensor_map";
  } else if (type.isa<infrt::dt::StringType>()) {
    // e.g. !infrt.string
    printer << "string";
  } else {
    llvm_unreachable("unknown infrt type.");
  }
}
// ----INFRTDialect definition end----
} // namespace infrt::dialect
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mlir/IR/Builders.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/IR/Types.h>
#include "paddle/infrt/dialect/infrt_base.hpp.inc"
namespace infrt::dialect {
// The `infrt` host dialect. Everything before `public:` is private (class
// default); construction is reserved to mlir::MLIRContext via the friend
// declaration below.
class INFRTDialect : public ::mlir::Dialect {
// Registers the dialect under its namespace and then populates it through
// initialize().
explicit INFRTDialect(::mlir::MLIRContext *context)
: ::mlir::Dialect(getDialectNamespace(),
context,
::mlir::TypeID::get<INFRTDialect>()) {
initialize();
}
// parse types registered to the dialect.
mlir::Type parseType(mlir::DialectAsmParser &parser) const override;
// print types registered to the dialect.
void printType(mlir::Type type,
mlir::DialectAsmPrinter &printer) const override;
// Adds the dialect's types and operations (defined in the .cc file).
void initialize();
friend class ::mlir::MLIRContext;
public:
// Prefix used for this dialect's ops and types in textual IR.
static ::llvm::StringRef getDialectNamespace() { return "infrt"; }
};
} // namespace infrt::dialect
namespace mlir {
// Helper referenced from TableGen rewrite patterns (INFRT_createI32Attr):
// wraps `constant` in an IntegerAttr of 32-bit integer type. `loc` is accepted
// but not used.
template <typename T>
static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b,  // NOLINT
                                       mlir::Location loc,
                                       T constant) {
  auto i32_type = b.getI32Type();
  return b.getIntegerAttr(i32_type, constant);
}
// Helper referenced from TableGen rewrite patterns: views a single value as a
// one-element ValueRange. The range does not own the value, so `operand` must
// stay alive for as long as the returned range is used.
static mlir::ValueRange cvtValueToValueRange(const mlir::Value &operand) {
  mlir::ValueRange single_value(operand);
  return single_value;
}
// Helper referenced from TableGen rewrite patterns: returns the values of
// `operand_0` followed by the values of `operand_1`.
//
// The result is returned as an owning SmallVector. mlir::ValueRange is a
// non-owning view, so returning a ValueRange built from the local vector (as
// this function previously did) handed callers a range pointing at storage
// that was already destroyed. SmallVector converts implicitly to ValueRange,
// so call sites that pass the result where a ValueRange is expected keep
// compiling unchanged.
static mlir::SmallVector<::mlir::Value, 4> concatTwoValueRange(
    mlir::ValueRange operand_0, mlir::ValueRange operand_1) {
  mlir::SmallVector<::mlir::Value, 4> operands;
  operands.reserve(operand_0.size() + operand_1.size());
  operands.append(operand_0.begin(), operand_0.end());
  operands.append(operand_1.begin(), operand_1.end());
  return operands;
}
} // namespace mlir
#ifndef INFRT_BASE
#define INFRT_BASE
include "mlir/IR/OpBase.td"
// Dialect record for the host dialect. Ops and types are prefixed with
// `infrt` in textual IR; generated C++ goes into ::infrt::dialect.
def INFRT_Dialect : Dialect {
let name = "infrt";
let description = [{
The INFRT host dialect.
}];
let cppNamespace = "::infrt::dialect";
}
// Type definitions
// Type constraints that match the C++ types declared in ::infrt::dt.
// StringType and TensorMapType are also BuildableType, so generated builders
// can create them from the builder alone; TensorType is not buildable.
def StringType :
Type<CPred<"$_self.isa<::infrt::dt::StringType>()">, "!infrt.string type">,
BuildableType<"$_builder.getType<::infrt::dt::StringType>()">;
def TensorType :
Type<CPred<"$_self.isa<::infrt::dt::TensorType>()">, "!infrt.tensor type">;
def TensorMapType :
Type<CPred<"$_self.isa<::infrt::dt::TensorMapType>()">, "!infrt.tensor_map type">,
BuildableType<"$_builder.getType<::infrt::dt::TensorMapType>()">;
// Opaque type constraint for dialect namespace "b", type name "buffer".
def BufferType : OpaqueType<"b", "buffer", "buffer">;
// Bridges from declarative rewrite patterns to the C++ helpers declared in
// namespace mlir by infrt_base.h. `value` is spliced verbatim into the call.
class INFRT_createI32Attr<string value> : NativeCodeCall<
"mlir::createI32Attr($_builder, $_loc, " # value # ")">;
// Wraps one matched value as a value range.
def INFRT_cvtValueToValueRange : NativeCodeCall<
"mlir::cvtValueToValueRange($0)">;
// Concatenates two matched value ranges.
def INFRT_concatTwoValueRange : NativeCodeCall<
"mlir::concatTwoValueRange($0, $1)">;
// Pattern constraint that holds when the bound attribute's getValue() equals
// `value`; `value` is spliced verbatim into the C++ predicate.
class IsBoolAttrEq<string value> : Constraint<
CPred<"($0.getValue() ==" # value # ")">,
"Bool attrbute value constraint">;
#endif // INFRT_BASE
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/init_infrt_dialects.h"
#include <glog/logging.h>
#include "paddle/infrt/dialect/basic_kernels.h"
#include "paddle/infrt/dialect/dense_tensor.h"
#include "paddle/infrt/dialect/infrt_base.h"
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensor_shape.h"
namespace infrt {
// Adds every dialect used by INFRT to `registry`: the Paddle op dialect, the
// dense-tensor and tensor-shape dialects, and the infrt host dialect.
void RegisterCinnDialects(mlir::DialectRegistry& registry) {  // NOLINT
  // Paddle operators.
  registry.insert<mlir::pd::PaddleDialect>();
  // Tensor-related dialects.
  registry.insert<dt::DTDialect>();
  registry.insert<ts::TensorShapeDialect>();
  // Host dialect.
  registry.insert<dialect::INFRTDialect>();
}
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mlir/IR/Dialect.h"
namespace infrt {
// Registers the dialects used by INFRT (tensor-shape, infrt, dense-tensor and
// Paddle) into `registry`.
void RegisterCinnDialects(mlir::DialectRegistry& registry); // NOLINT
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/mlir_loader.h"
#include <llvm/Support/SourceMgr.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/Diagnostics.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/OperationSupport.h>
#include <mlir/Parser.h>
#include <unordered_map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/infrt/dialect/diagnostic_utils.h"
#include "paddle/infrt/dialect/init_infrt_dialects.h"
namespace infrt::dialect {
// Parses `mlir_source` (textual MLIR) into a module owned by the caller.
// The INFRT dialects and the standard dialect are registered on `context`
// first, and unregistered dialects are allowed. While parsing, error
// diagnostics are logged; a parse failure aborts through CHECK.
mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context,
                                     const std::string& mlir_source) {
  context->allowUnregisteredDialects();
  auto& dialect_registry = context->getDialectRegistry();
  RegisterCinnDialects(dialect_registry);
  dialect_registry.insert<mlir::StandardOpsDialect>();

  // Only errors are intercepted; other severities are reported as handled.
  mlir::ScopedDiagnosticHandler scope_handler(
      context, [](mlir::Diagnostic& diag) {
        if (diag.getSeverity() == mlir::DiagnosticSeverity::Error) {
          LOG(INFO) << "diag: " << diag.str();
          return mlir::failure(true);
        }
        return mlir::success();
      });

  auto parsed_module =
      mlir::parseSourceString(llvm::StringRef(mlir_source), context);
  CHECK(*parsed_module) << "failed to parse MLIR string";
  return parsed_module;
}
// Parses the MLIR file at `file_name` into a module owned by the caller.
// The INFRT dialects and the standard dialect are registered on `context`
// first, and unregistered dialects are allowed. While parsing, error
// diagnostics are logged. The parse result is returned as is, without a
// success check.
mlir::OwningModuleRef LoadMlirFile(const std::string& file_name,
                                   mlir::MLIRContext* context) {
  context->allowUnregisteredDialects();
  auto& dialect_registry = context->getDialectRegistry();
  RegisterCinnDialects(dialect_registry);
  dialect_registry.insert<mlir::StandardOpsDialect>();

  // Only errors are intercepted; other severities are reported as handled.
  mlir::ScopedDiagnosticHandler scope_handler(
      context, [](mlir::Diagnostic& diag) {
        if (diag.getSeverity() == mlir::DiagnosticSeverity::Error) {
          LOG(INFO) << "diag: " << diag.str();
          return mlir::failure(true);
        }
        return mlir::success();
      });

  return mlir::parseSourceFile(file_name, context);
}
} // namespace infrt::dialect
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <mlir/IR/Module.h>
#include <string>
#include <memory>
namespace infrt::dialect {
// Parses MLIR given as text in `mlir_source` using `context`.
mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context,
const std::string& mlir_source);
// Parses the MLIR file named `file_name` using `context`. Note that the
// parameter order is the reverse of LoadMlirSource.
mlir::OwningModuleRef LoadMlirFile(const std::string& file_name,
mlir::MLIRContext* context);
} // namespace infrt::dialect
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/mlir_loader.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <llvm/Support/SourceMgr.h>
#include <mlir/IR/Function.h>
#include <mlir/Parser.h>
#include <string>
#include "paddle/infrt/dialect/init_infrt_dialects.h"
namespace infrt::dialect {
// Loads a small module from text, checks that it verifies, and logs the
// name and arguments of every function it contains.
TEST(MlirLoader, basic) {
  mlir::MLIRContext context;
  auto source = R"ROC(
func @main() -> f32 {
%v0 = infrt.constant.f32 1.0
%v1 = infrt.constant.f32 2.0
%value = "infrt.add.f32"(%v0, %v1) : (f32, f32) -> f32
"infrt.print.f32"(%v0) : (f32) -> ()
infrt.return %value : f32
}
)ROC";
  auto module = LoadMlirSource(&context, source);
  // The verification result used to be discarded, so an invalid module could
  // never fail this test.
  ASSERT_TRUE(mlir::succeeded(module->verify()));
  LOG(INFO) << "module name: " << module->getOperationName().data();
  for (auto func : module->getOps<mlir::FuncOp>()) {
    LOG(INFO) << "get func " << func.getName().str();
    const unsigned num_args = func.getNumArguments();
    for (unsigned i = 0; i < num_args; ++i) {
      LOG(INFO) << "arg: " << func.getArgument(i).getArgNumber();
    }
  }
}
} // namespace infrt::dialect
// Adds the constants 1.0 and 2.0, prints the sum and returns it; the expected
// printed value is 3.
// CHECK-LABEL: @basic_f32
func @basic_f32() -> f32 {
%v0 = infrt.constant.f32 1.0
%v1 = infrt.constant.f32 2.0
%value = "infrt.add.f32"(%v0, %v1) : (f32, f32) -> f32
// CHECK-NEXT: 3
"infrt.print.f32"(%value) : (f32) -> ()
infrt.return %value : f32
}
/// ================================================================
/// @caller call the other function @callee
// Returns the sum of its three f32 arguments.
func @callee.add.f32(%x : f32, %y : f32, %y1 : f32) -> f32 {
%z = "infrt.add.f32"(%x, %y) : (f32, f32) -> f32
%z1 = "infrt.add.f32"(%z, %y1) : (f32, f32) -> f32
infrt.return %z1 : f32
}
// Calls @callee.add.f32 with 1.0, 2.0 and 3.0 through infrt.call and prints
// the result; the expected printed value is 6.
// CHECK-LABEL: @caller.add.f32
func @caller.add.f32() -> f32 {
%x = infrt.constant.f32 1.0
%y = infrt.constant.f32 2.0
%y1 = infrt.constant.f32 3.0
%z = infrt.call @callee.add.f32(%x, %y, %y1) : (f32, f32, f32) -> f32
// CHECK-NEXT: 6
"infrt.print.f32"(%z) : (f32) -> ()
infrt.return %z : f32
}
/// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// Creates a string value with infrt.get_string and prints it with
// infrt.print_string; the expected output line is given below.
// CHECK-LABEL: @string_test
func @string_test() {
%path = infrt.get_string("this is get_string op.")
// CHECK-LABEL: string = this is get_string op.
infrt.print_string(%path)
infrt.return
}
// Runs infrt.benchmark over a region that adds two constants, with
// duration_secs = 1, max_count = 3 and num_warmup_runs = 3. The lines below
// list the statistics the benchmark is expected to report for "add.f32".
// CHECK-LABEL: @benchmark
func @benchmark() {
// CHECK-LABEL: BM:add.f32:Count: 3
// CHECK-LABEL: BM:add.f32:Duration(ns)
// CHECK-LABEL: BM:add.f32:Time Min(ns)
// CHECK-LABEL: BM:add.f32:Time 50%(ns)
// CHECK-LABEL: BM:add.f32:Time 95%(ns)
// CHECK-LABEL: BM:add.f32:Time 99%(ns)
// CHECK-LABEL: BM:add.f32:CPU Min(ns)
// CHECK-LABEL: BM:add.f32:CPU 50%(ns)
// CHECK-LABEL: BM:add.f32:CPU 95%(ns)
// CHECK-LABEL: BM:add.f32:CPU 99%(ns)
// CHECK-LABEL: BM:add.f32:CPU utilization(percent)
infrt.benchmark "add.f32"() duration_secs = 1, max_count = 3, num_warmup_runs = 3
{
%0 = infrt.constant.f32 1.0
%1 = infrt.constant.f32 2.0
%res = "infrt.add.f32"(%0, %1) : (f32, f32) -> f32
"infrt.print.f32"(%res) : (f32) -> ()
infrt.return %res : f32
}
infrt.return
}
// Builds a shape value and an uninitialized f32 tensor; neither result is
// used afterwards, so this only exercises parsing of the two ops.
func @dense_shape0() {
%shape = ts.build_shape [1:i64, 57:i64]
%a = dt.create_uninit_tensor.f32 [12:i64, 23:i64] -> !infrt.tensor<X86, NCHW, F32>
infrt.return
}
// Takes two tensors and returns a shallow copy of each.
func @predict(%a: !infrt.tensor<X86, NCHW, F32>, %b: !infrt.tensor<X86, NCHW, F32>) -> (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) {
%a0 = dt.shallow_copy_tensor %a : !infrt.tensor<X86, NCHW, F32> -> !infrt.tensor<X86, NCHW, F32>
%b0 = dt.shallow_copy_tensor %b : !infrt.tensor<X86, NCHW, F32> -> !infrt.tensor<X86, NCHW, F32>
infrt.return %a0, %b0: !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>
}
// Creates one uninitialized tensor and passes it as both arguments of
// @predict through infrt.call; the two results are not used further.
func @main() {
%shape = ts.build_shape [1:i64, 57:i64]
%a = dt.create_uninit_tensor.f32 [12:i64, 23:i64] -> !infrt.tensor<X86, NCHW, F32>
%b, %c = infrt.call @predict(%a, %a) : (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) -> (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>)
infrt.return
}
// Feeds two tensors and multiplies them with pd.Matmul (transpose_x = true).
// The feeds use the custom op syntax, the matmul uses the generic syntax.
func @ops() {
%a = pd.Feed() : tensor<?xf32>
%b = pd.Feed() : tensor<?xf32>
%c = "pd.Matmul"(%a, %b) {transpose_x=true, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
infrt.return
}
// Three consecutive Matmul -> ElementwiseAdd -> activation stages (Relu6 for
// the first stage, Relu for the other two) over fed tensors; the output of
// the last activation is returned.
// CHECK-LABEL: @main
func @main() -> tensor<?xf32> {
%a = "pd.Feed"() : () -> tensor<?xf32>
%b = "pd.Feed"() : () -> tensor<?xf32>
%bias = "pd.Feed"() : () -> tensor<?xf32>
%b1 = "pd.Feed"() : () -> tensor<?xf32>
%b2 = "pd.Feed"() : () -> tensor<?xf32>
%bias1 = "pd.Feed"() : () -> tensor<?xf32>
%bias2 = "pd.Feed"() : () -> tensor<?xf32>
%c = "pd.Matmul"(%a, %b) {transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%d = "pd.ElementwiseAdd"(%c, %bias) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e = "pd.Relu6"(%d) {} : (tensor<?xf32>) -> tensor<?xf32>
%c1 = "pd.Matmul"(%e, %b1) {transpose_x=false, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%d1 = "pd.ElementwiseAdd"(%c1, %bias1) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e1 = "pd.Relu"(%d1) {} : (tensor<?xf32>) -> tensor<?xf32>
%c2 = "pd.Matmul"(%e1, %b2) {transpose_x=true, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%d2 = "pd.ElementwiseAdd"(%c2, %bias2) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e2 = "pd.Relu"(%d2) {} : (tensor<?xf32>) -> tensor<?xf32>
infrt.return %e2 : tensor<?xf32>
}
\ No newline at end of file
// A conv2d followed by a batch_norm over constant filter, bias, scale, mean
// and variance tensors. The declared result type matches the type of the
// value that is actually returned (it previously said tensor<?xf32> while
// the function returned tensor<?x3x256x256xf32>).
// CHECK-LABEL: @main
func @main() -> tensor<?x3x256x256xf32> {
  %a = "pd.Feed"() : () -> tensor<?x3x256x256xf32>
  %filter = "pd.Constant"(){value = dense<1.000000e+00> : tensor<3x64x3x3xf32>} : () -> tensor<3x64x3x3xf32>
  %bias = "pd.Constant"(){value = dense<1.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32>
  %scale = "pd.Constant"(){value = dense<1.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32>
  %bias2 = "pd.Constant"(){value = dense<1.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32>
  %mean = "pd.Constant"(){value = dense<1.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32>
  %var = "pd.Constant"(){value = dense<1.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32>
  %c = "pd.conv2d"(%a, %filter, %bias) {} : (tensor<?x3x256x256xf32>, tensor<3x64x3x3xf32>, tensor<64xf32>) -> tensor<?x3x256x256xf32>
  %d = "pd.batch_norm"(%c, %scale, %bias2, %mean, %var) {} : (tensor<?x3x256x256xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<?x3x256x256xf32>
  infrt.return %d : tensor<?x3x256x256xf32>
}
\ No newline at end of file
// One fully-connected step: fetches a weight and a bias from the tensor map,
// then runs external matmul, elementwise_add and sigmoid kernels that all
// take %out as their last operand, and returns %out.
// CHECK-LABEL: @predict
func @predict(%input:!infrt.tensor<X86, NCHW, F32>, %map: !infrt.tensor_map) -> (!infrt.tensor<X86, NCHW, F32>) {
%w = dt.get_param(%map, "create_parameter_0.w_0") -> !infrt.tensor<X86, NCHW, F32>
%bias = dt.get_param(%map, "create_parameter_1.w_0") -> !infrt.tensor<X86, NCHW, F32>
%out = dt.create_uninit_tensor.f32 [3, 3] -> !infrt.tensor<X86, NCHW, F32>
// fc
"external.matmul"(%input, %w, %out) {}: (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) -> ()
"external.elementwise_add"(%out, %bias, %out) {axis = -1}: (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) -> ()
"external.sigmoid"(%out, %out) {}: (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) -> ()
//dt.print_tensor (%out : !infrt.tensor<X86, NCHW, F32>)
infrt.return %out : !infrt.tensor<X86, NCHW, F32>
}
// Fills a 3x3 input tensor with 1.0, loads model parameters from a fixed
// path into a tensor map, calls @predict and prints its result.
// NOTE(review): the model path is an absolute path baked into the test --
// confirm it exists in the environment that runs this file.
// CHECK-LABEL: @main
func @main() {
%input = dt.create_uninit_tensor.f32 [3, 3] -> !infrt.tensor<X86, NCHW, F32>
dt.fill_tensor_with_constant.f32 (%input : !infrt.tensor<X86, NCHW, F32>) {value=1.0:f32}
%path = infrt.get_string("/infrt/build/paddle/paddle_1.8_fc_model")
// CHECK-LABEL: loading params
%map = dt.load_params(%path)
%out = infrt.call @predict(%input, %map): (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor_map) -> (!infrt.tensor<X86, NCHW, F32>)
dt.print_tensor (%out : !infrt.tensor<X86, NCHW, F32>)
infrt.return
}
// Builds a three-dimensional shape value and prints it.
func @build_tensor1() {
%a = ts.build_shape [1:i64, 57:i64, 92:i64]
ts.print_shape %a
infrt.return
}
// Creates a 3x4 f32 tensor, fills it with 1.0 and prints it; the expected
// output (shape and twelve ones) is given below.
// CHECK-LABEL: test_tensor_type
func @test_tensor_type() {
%a = dt.create_uninit_tensor.f32 [3, 4] -> !infrt.tensor<X86, NCHW, F32>
dt.fill_tensor_with_constant.f32 (%a : !infrt.tensor<X86, NCHW, F32>) {value=1.0:f32}
// CHECK: tensor: shape=shape[3,4], values=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
dt.print_tensor (%a : !infrt.tensor<X86, NCHW, F32>)
infrt.return
}
include "mlir/IR/OpBase.td"
include "paddle/infrt/dialect/infrt_base.td"
// Base class for ops of the infrt dialect: binds the op to INFRT_Dialect and
// forwards the mnemonic and the (optional) trait list.
class INFRT_Op<string mnemonic, list<OpTrait> traits = []> :
Op<INFRT_Dialect, mnemonic, traits>;
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <llvm/Support/CommandLine.h>
#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/IR/AsmState.h>
#include <mlir/IR/Dialect.h>
#include <mlir/InitAllDialects.h>
#include <mlir/InitAllPasses.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/FileUtilities.h>
#include <mlir/Support/MlirOptMain.h>
#include <mlir/Transforms/Passes.h>
#include <iostream>
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/init_infrt_dialects.h"
#include "paddle/infrt/dialect/mlir_loader.h"
// Entry point of the INFRT pass driver: registers the INFRT dialects and the
// canonicalizer pass, then hands the command line to MlirOptMain.
// Exits with 1 when MlirOptMain fails and 0 otherwise.
int main(int argc, char **argv) {
  mlir::MLIRContext *context = infrt::Global::getMLIRContext();
  auto &registry = context->getDialectRegistry();
  infrt::RegisterCinnDialects(registry);
  mlir::registerCanonicalizerPass();

  const auto status =
      mlir::MlirOptMain(argc, argv, "INFRT mlir pass driver", registry);
  return mlir::failed(status) ? 1 : 0;
}
// This file defines some basic elements of Paddle(alias pd) dialect.
// We learned much from TensorFlow mlir dialect https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
#ifndef PD_OP_BASE
#define PD_OP_BASE
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
// Dialect record for PaddlePaddle operators. Ops are prefixed with `pd.` in
// textual IR; generated C++ goes into ::mlir::pd.
def PD_Dialect : Dialect {
let name = "pd";
let description = [{
The PaddlePaddle dialect.
This dialect contains the PaddlePaddle operators.
}];
let cppNamespace = "::mlir::pd";
}
// Base class for every op of the `pd` dialect: binds the op to PD_Dialect and
// forwards the mnemonic and the (optional) trait list.
class PD_Op<string mnemonic, list<OpTrait> traits = []> :
Op<PD_Dialect, mnemonic, traits>;
// Attribute constraint matching the C++ class mlir::pd::<name>Attr;
// `description` is used in the constraint's human-readable summary.
class PD_PaddleAttr <string name, string description> :
Attr<CPred<"$_self.isa<mlir::pd::" # name # "Attr>()">,
"PaddlePaddle " # description # " attribute">;
//===----------------------------------------------------------------------===//
// PaddlePaddle type definitions
//===----------------------------------------------------------------------===//
// Matches any type that is an mlir::pd::PDType.
def PD_PDDialectType : Type<CPred<"$_self.isa<mlir::pd::PDType>()">, "PaddlePaddle type">;
// Type constraint matching the C++ class mlir::pd::<name>Type.
// NOTE(review): the builder snippet calls getType<...>() without the
// `$_builder.` prefix used by the buildable types in infrt_base.td -- confirm
// it resolves in the context where the generated builder code is emitted.
class PD_PaddleType <string name, string description> :
Type<CPred<"$_self.isa<mlir::pd::" # name #"Type>()">,
"Paddle " # description # " type">,
BuildableType<"getType<mlir::pd::" # name # "Type>()">;
//===----------------------------------------------------------------------===//
// Integer types
// Named integer constraints. PD_Bool is a 1-bit integer; PD_Int8..PD_Int64
// wrap I8..I64 and PD_UInt8..PD_UInt64 wrap UI<8>..UI<64>.
// NOTE(review): PD_SInt is described as "signed integer" but is built from
// the I8..I64 constraints rather than SI<N> -- confirm this is intended.
def PD_Bool : AnyTypeOf<[I<1>], "bool">;
def PD_Int8 : AnyTypeOf<[I8], "8-bit integer">;
def PD_Int16 : AnyTypeOf<[I16], "16-bit integer">;
def PD_Int32 : AnyTypeOf<[I32], "32-bit integer">;
def PD_Int64 : AnyTypeOf<[I64], "64-bit integer">;
def PD_UInt8 : AnyTypeOf<[UI<8>], "8-bit unsigned integer">;
def PD_UInt16 : AnyTypeOf<[UI<16>], "16-bit unsigned integer">;
def PD_UInt32 : AnyTypeOf<[UI<32>], "32-bit unsigned integer">;
def PD_UInt64 : AnyTypeOf<[UI<64>], "64-bit unsigned integer">;
// Unions of the individual widths above.
def PD_SInt : AnyTypeOf<[PD_Int8, PD_Int16, PD_Int32, PD_Int64], "signed integer">;
def PD_UInt : AnyTypeOf<[PD_UInt8, PD_UInt16, PD_UInt32, PD_UInt64], "unsigned integer">;
def PD_Int : AnyTypeOf<[PD_SInt, PD_UInt], "integer">;
// Float types
// Named floating-point constraints for 16, 32 and 64 bits, and their union.
def PD_Float16 : AnyTypeOf<[F16], "16-bit float">;
def PD_Float32 : AnyTypeOf<[F32], "32-bit float">;
def PD_Float64 : AnyTypeOf<[F64], "64-bit float">;
def PD_Float : AnyTypeOf<[PD_Float16, PD_Float32, PD_Float64], "floating-point">;
// Tensor types
// Element types allowed inside a pd tensor: any PD_Float, PD_Bool or PD_Int.
// The predicate is the disjunction of the three underlying predicates.
def PD_ElementType : Type<Or<[PD_Float.predicate,
PD_Bool.predicate,
PD_Int.predicate]>,
"pd.dtype">;
// Tensor constraint used for the operands and results of pd ops: a tensor
// whose element type satisfies PD_ElementType.
def PD_Tensor : TensorOf<[PD_ElementType]>;
#endif // PD_OP_BASE
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/pd_ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "paddle/infrt/dialect/infrt_base.h"
namespace mlir {
namespace pd {
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.hpp.inc"
#undef GET_OP_CLASSES
// Constructs the dialect under the "pd" namespace and registers the ops
// enumerated by the GET_OP_LIST section of the generated pd_ops.cpp.inc.
PaddleDialect::PaddleDialect(MLIRContext *context)
: Dialect("pd", context, TypeID::get<PaddleDialect>()) {
addOperations<
// GET_OP_LIST makes the include below expand to the template argument list
// of addOperations<>.
#define GET_OP_LIST
#include "paddle/infrt/dialect/pd_ops.cpp.inc" // NOLINT
>();
#undef GET_OP_LIST
}
// Materializes `value` as a pd constant op at `loc`.
//
// Returns nullptr when `value` is null or is not one of the attribute kinds
// that ConstantOp::build accepts (ElementsAttr, BoolAttr, FloatAttr,
// IntegerAttr). Without this check an unsupported attribute would reach the
// llvm_unreachable in ConstantOp::build instead of reporting failure to the
// caller, which is what this hook is expected to do.
//
// `type` is not consulted: the result type is derived from the attribute by
// ConstantOp's builders.
mlir::Operation *PaddleDialect::materializeConstant(mlir::OpBuilder &builder,
                                                    mlir::Attribute value,
                                                    mlir::Type type,
                                                    mlir::Location loc) {
  const bool supported =
      value && value.isa<ElementsAttr, BoolAttr, FloatAttr, IntegerAttr>();
  if (!supported) return nullptr;
  return builder.create<ConstantOp>(loc, value);
}
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.cpp.inc" // NOLINT
#undef GET_OP_CLASSES
#include "paddle/infrt/dialect/rewrite.hpp.inc" // NOLINT
// Builds a pd constant from a generic attribute.
//
// - ElementsAttr: forwarded to the builder overload taking an ElementsAttr.
// - BoolAttr / FloatAttr / IntegerAttr: wrapped into a rank-0 tensor whose
//   element type is the attribute's type.
// - anything else: unsupported (llvm_unreachable).
void ConstantOp::build(OpBuilder &builder,
                       OperationState &state,
                       Attribute value) {
  auto elements = value.dyn_cast<ElementsAttr>();
  if (elements) {
    ConstantOp::build(builder, state, elements);
    return;
  }

  if (!value.isa<BoolAttr, FloatAttr, IntegerAttr>()) {
    llvm_unreachable("unsupported attribute type for building pd.constant");
  }

  // A scalar becomes a 0-d ranked tensor holding that single element.
  ShapedType scalar_tensor =
      RankedTensorType::get(/*shape=*/{}, value.getType());
  state.addAttribute("value", DenseElementsAttr::get(scalar_tensor, value));
  state.addTypes(scalar_tensor);
}
// Infers the result type of pd.Constant: it is the type of the "value"
// attribute.
//
// Returns failure() when the attribute dictionary is null or carries no
// "value" entry, instead of dereferencing a null attribute.
LogicalResult ConstantOp::inferReturnTypes(
    MLIRContext *context,
    Optional<Location> location,
    ValueRange operands,
    DictionaryAttr attributes,
    RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (!attributes) return failure();
  Attribute value_attr = attributes.get("value");
  if (!value_attr) return failure();
  inferredReturnTypes.push_back(value_attr.getType());
  return success();
}
// Folds the constant to its own "value" attribute; `operands` is unused.
::mlir::OpFoldResult ConstantOp::fold(
::llvm::ArrayRef<::mlir::Attribute> operands) {
return value();
}
// Result-type inference for pd.ElementwiseAdd: the single result takes the
// type of the first operand. The second operand and the attributes are not
// consulted.
LogicalResult ElementwiseAdd::inferReturnTypes(
    MLIRContext *context,
    Optional<Location> location,
    ValueRange operands,
    DictionaryAttr attributes,
    RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  const Type first_operand_type = operands[0].getType();
  inferredReturnTypes.push_back(first_operand_type);
  return success();
}
// Registers the FuseMulAdd rewrite pattern as a canonicalization of
// pd.ElementwiseAdd.
// NOTE(review): FuseMulAdd is presumably generated into rewrite.hpp.inc,
// which is included above -- confirm.
void ElementwiseAdd::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) {
results.insert<FuseMulAdd>(context);
}
// Constant-folds pd.ElementwiseAdd when both operands are dense constants.
//
// Folding happens only if all of the following hold; otherwise an empty
// OpFoldResult is returned and the op is left untouched:
//   * both operand attributes are present and are DenseElementsAttr,
//   * the result is a statically shaped tensor with a float element type,
//   * both operand attributes have exactly the result type.
//
// The last condition rules out broadcasting and mixed element types. The
// loop below pairs elements positionally, so operands with a different
// element count would produce fewer values than the result type requires,
// and a non-float operand cannot be read as APFloat.
::mlir::OpFoldResult ElementwiseAdd::fold(
    llvm::ArrayRef<mlir::Attribute> operands) {
  if (operands.size() < 2 || !operands[0] || !operands[1]) return {};

  auto lhs = operands[0].dyn_cast<DenseElementsAttr>();
  auto rhs = operands[1].dyn_cast<DenseElementsAttr>();
  if (!lhs || !rhs) return {};

  auto result_type = getType().dyn_cast<ShapedType>();
  if (!result_type || !result_type.hasStaticShape()) return {};
  if (!result_type.getElementType().isa<FloatType>()) return {};

  if (lhs.getType() != result_type || rhs.getType() != result_type) return {};

  SmallVector<APFloat, 6> sums;
  sums.reserve(lhs.getNumElements());
  for (const auto &pair :
       llvm::zip(lhs.getValues<APFloat>(), rhs.getValues<APFloat>())) {
    sums.push_back(std::get<0>(pair) + std::get<1>(pair));
  }
  return DenseElementsAttr::get(result_type, sums);
}
// Result-type inference for pd.ElementwiseDiv: the single result takes the
// type of the first operand. The second operand and the attributes are not
// consulted.
LogicalResult ElementwiseDiv::inferReturnTypes(
    MLIRContext *context,
    Optional<Location> location,
    ValueRange operands,
    DictionaryAttr attributes,
    RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  const Type first_operand_type = operands[0].getType();
  inferredReturnTypes.push_back(first_operand_type);
  return success();
}
// Result-type inference for pd.ElementwiseMul: the single result takes the
// type of the first operand. The second operand and the attributes are not
// consulted.
LogicalResult ElementwiseMul::inferReturnTypes(
    MLIRContext *context,
    Optional<Location> location,
    ValueRange operands,
    DictionaryAttr attributes,
    RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  const Type first_operand_type = operands[0].getType();
  inferredReturnTypes.push_back(first_operand_type);
  return success();
}
// Result-type inference for pd.ElementwiseSub: the single result takes the
// type of the first operand. The second operand and the attributes are not
// consulted.
LogicalResult ElementwiseSub::inferReturnTypes(
    MLIRContext *context,
    Optional<Location> location,
    ValueRange operands,
    DictionaryAttr attributes,
    RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  const Type first_operand_type = operands[0].getType();
  inferredReturnTypes.push_back(first_operand_type);
  return success();
}
// Result-type inference for pd.mul: the single result is given the type of
// the first operand, unchanged.
// NOTE(review): no shape computation is performed from the two operands --
// confirm this is the intended result type for a matrix-style multiply.
LogicalResult MulOp::inferReturnTypes(
    MLIRContext *context,
    Optional<Location> location,
    ValueRange operands,
    DictionaryAttr attributes,
    RegionRange regions,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  const Type first_operand_type = operands[0].getType();
  inferredReturnTypes.push_back(first_operand_type);
  return success();
}
// Registers the FuseFCRelu rewrite pattern as a canonicalization of pd.Relu.
// NOTE(review): FuseFCRelu is presumably generated into rewrite.hpp.inc,
// which is included above -- confirm.
void ReluOp::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) {
results.insert<FuseFCRelu>(context);
}
// Registers the FuseRepeatedFCRelu2 rewrite pattern as a canonicalization of
// pd.RepeatedFCRelu.
// NOTE(review): FuseRepeatedFCRelu2 is presumably generated into
// rewrite.hpp.inc, which is included above -- confirm.
void FusedRepeatedFCRelu::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) {
results.insert<FuseRepeatedFCRelu2>(context);
}
// Registers the FuseBatchNormWithConvPattern rewrite pattern as a
// canonicalization of pd.batch_norm.
// NOTE(review): the pattern is presumably generated into rewrite.hpp.inc,
// which is included above -- confirm.
void BatchNormOp::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) {
results.insert<FuseBatchNormWithConvPattern>(context);
}
} // namespace pd
} // namespace mlir
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/DerivedAttributeOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
namespace mlir {
namespace pd {
// The PaddlePaddle ("pd") MLIR dialect.
class PaddleDialect : public Dialect {
public:
// Creates the dialect inside `context`; defined out of line.
explicit PaddleDialect(MLIRContext* context);
// Namespace prefix under which the dialect is known.
static StringRef getDialectNamespace() { return "pd"; }
/// A hook used to materialize constant values with the given type.
Operation* materializeConstant(OpBuilder& builder,
Attribute value,
Type type,
Location loc) override;
// Type parsing and printing are delegated unchanged to the base Dialect
// implementation.
Type parseType(DialectAsmParser& parser) const override {
return Dialect::parseType(parser);
}
void printType(Type type, DialectAsmPrinter& printer) const override {
Dialect::printType(type, printer);
}
};
} // namespace pd
} // namespace mlir
#ifndef PD_OPS
#define PD_OPS
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/OpBase.td"
include "paddle/infrt/dialect/pd_op_base.td"
// pd.Feed: side-effect-free op with no operands and a single tensor result.
// Uses a custom assembly format: `()` attr-dict `:` result-type.
def PD_FeedOp : PD_Op<"Feed", [NoSideEffect]> {
let summary = "Feed Op";
let description = [{
Feed a tensor into the model.
}];
let arguments = (ins);
let results = (outs PD_Tensor:$out);
let assemblyFormat = [{
`(` `)` attr-dict `:` type($out)
}];
}
// pd.Constant: constant-like op whose single tensor result is described by
// the ElementsAttr `value`. The `value` and `output` types must match
// (AllTypesMatch). The op declares InferTypeOpInterface methods and a folder,
// plus an extra builder that takes a generic Attribute.
def PD_ConstantOp : PD_Op<"Constant", [NoSideEffect, ConstantLike, DeclareOpInterfaceMethods<InferTypeOpInterface>, AllTypesMatch<["value", "output"]>]> {
let summary = "constant Op";
let description = [{}];
let arguments = (ins ElementsAttr:$value);
let results = (outs PD_Tensor:$output);
let hasFolder = 1;
// Extra builder declared here; its definition is written by hand in C++.
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Attribute value">,
];
}
// pd.Abs: side-effect-free unary op; operand and result share one type
// (SameOperandsAndResultType).
def PD_AbsOp : PD_Op<"Abs", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes the absolute value of a tensor";
let description = [{
}];
let arguments = (ins PD_Tensor:$x);
let results = (outs PD_Tensor:$y);
}
// pd.sqrt: side-effect-free unary op; operand and result share one type.
// NOTE(review): the mnemonic is lower-case ("sqrt") while sibling unary ops
// use capitalized mnemonics ("Abs", "Relu") -- confirm which spelling the
// textual IR is expected to use.
def PD_SqrtOp : PD_Op<"sqrt", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes the sqrt value of a tensor";
let description = [{
}];
let arguments = (ins PD_Tensor:$x);
let results = (outs PD_Tensor:$y);
}
// pd.Relu: side-effect-free unary op; operand and result share one type.
// Declares canonicalization patterns (hasCanonicalizer).
def PD_ReluOp : PD_Op<"Relu", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes the Relu of a tensor";
let description = [{
}];
let arguments = (ins PD_Tensor:$x);
let results = (outs PD_Tensor:$y);
let hasCanonicalizer = 1;
}
// pd.Relu6: side-effect-free unary op; operand and result share one type.
def PD_Relu6Op : PD_Op<"Relu6", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes the Relu6 of a tensor";
let description = [{
}];
let arguments = (ins PD_Tensor:$x);
let results = (outs PD_Tensor:$y);
}
// pd.ElementwiseAdd: side-effect-free, commutative binary op on tensors with
// an i32 `axis` attribute defaulting to -1. Declares InferTypeOpInterface
// methods, canonicalization patterns and a folder.
def PD_ElementwiseAdd : PD_Op<"ElementwiseAdd", [NoSideEffect, Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "ElementwiseAdd Op";
let description = [{
}];
let arguments = (ins PD_Tensor:$x, PD_Tensor:$y, DefaultValuedAttr<I32Attr, "-1">:$axis);
let results = (outs PD_Tensor:$out);
let hasCanonicalizer = 1;
let hasFolder = 1;
}
// pd.ElementwiseSub: side-effect-free binary op on tensors with an i32 `axis`
// attribute defaulting to -1. Declares InferTypeOpInterface methods. Not
// marked Commutative.
def PD_ElementwiseSub : PD_Op<"ElementwiseSub", [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "ElementwiseSub Op";
let description = [{
}];
let arguments = (ins PD_Tensor:$x, PD_Tensor:$y, DefaultValuedAttr<I32Attr, "-1">:$axis);
let results = (outs PD_Tensor:$out);
}
// pd.ElementwiseMul: side-effect-free, commutative binary op on tensors with
// an i32 `axis` attribute defaulting to -1. Declares InferTypeOpInterface
// methods.
def PD_ElementwiseMul : PD_Op<"ElementwiseMul", [NoSideEffect, Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "ElementwiseMul Op";
let description = [{
}];
let arguments = (ins PD_Tensor:$x, PD_Tensor:$y, DefaultValuedAttr<I32Attr, "-1">:$axis);
let results = (outs PD_Tensor:$out);
}
// pd.ElementwiseDiv: side-effect-free binary op on tensors with an i32 `axis`
// attribute defaulting to -1. Declares InferTypeOpInterface methods. Not
// marked Commutative.
def PD_ElementwiseDiv : PD_Op<"ElementwiseDiv", [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "ElementwiseDiv Op";
let description = [{
}];
let arguments = (ins PD_Tensor:$x, PD_Tensor:$y, DefaultValuedAttr<I32Attr, "-1">:$axis);
let results = (outs PD_Tensor:$out);
}
// pd.Matmul: side-effect-free op taking two tensors and producing one.
// Attributes (all optional, with defaults):
//   transpose_x / transpose_y : bool, default false
//   alpha                     : f32,  default 1.0
def PD_MatmulOp : PD_Op<"Matmul", [NoSideEffect]> {
  let summary = "Computes the matrix multiplication result of two tensors";
  let description = [{
  }];
  let arguments = (ins PD_Tensor:$x, PD_Tensor:$y,
                       DefaultValuedAttr<BoolAttr, "false">:$transpose_x,
                       DefaultValuedAttr<BoolAttr, "false">:$transpose_y,
                       DefaultValuedAttr<F32Attr, "1.0">:$alpha);
  let results = (outs PD_Tensor:$out);
  // Canonicalization is currently disabled for this op.
  //let hasCanonicalizer = 1;
}
// pd.mul: side-effect-free binary op on tensors. Declares
// InferTypeOpInterface methods.
def PD_MulOp : PD_Op<"mul", [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "paddle mul op";
let description = [{}];
let arguments = (ins PD_Tensor:$x, PD_Tensor:$y);
let results = (outs PD_Tensor:$out);
// Canonicalization is currently disabled for this op.
//let hasCanonicalizer = 1;
}
// pd.conv2d: side-effect-free op taking Input, Filter and Bias tensors and
// producing a single Output tensor. No attributes are declared here.
def PD_Conv2dOp : PD_Op<"conv2d", [NoSideEffect]> {
let summary = "paddle conv2d operation";
let description = [{
}];
let arguments = (ins PD_Tensor:$Input, PD_Tensor:$Filter, PD_Tensor:$Bias);
let results = (outs PD_Tensor:$Output);
// Canonicalization is currently disabled for this op.
//let hasCanonicalizer = 1;
}
// pd.batch_norm: side-effect-free op taking X, Scale, Bias, Mean and Variance
// tensors plus an f32 `epsilon` attribute (default 1e-05), producing Y.
// Declares canonicalization patterns.
def PD_BatchNormOp : PD_Op<"batch_norm", [NoSideEffect]> {
let summary = "paddle batch_norm operation";
let description = [{
}];
let arguments = (ins PD_Tensor:$X, PD_Tensor:$Scale, PD_Tensor:$Bias,
PD_Tensor:$Mean, PD_Tensor:$Variance,
DefaultValuedAttr<F32Attr, "1e-05">:$epsilon);
let results = (outs PD_Tensor:$Y);
let hasCanonicalizer = 1;
}
// pd.FC: side-effect-free fused fully-connected op taking input, weight and
// bias tensors plus an i32 `in_num_col_dims` attribute (default 1).
def PD_FusedFC : PD_Op<"FC", [NoSideEffect]> {
let summary = "Computes the Fully Connected result of two tensors";
let description = [{
}];
let arguments = (ins PD_Tensor:$input, PD_Tensor:$w, PD_Tensor:$bias, DefaultValuedAttr<I32Attr, "1">:$in_num_col_dims);
let results = (outs PD_Tensor:$out);
}
// pd.RepeatedFCRelu: side-effect-free op taking one input tensor and two
// variadic tensor groups, `w` and `bias`, which must have the same number of
// elements (SameVariadicOperandSize). Declares canonicalization patterns.
// NOTE(review): summary and description are empty; presumably this models a
// chain of FC+Relu layers, one per (w, bias) pair -- confirm and document.
def PD_FusedRepeatedFCRelu : PD_Op<"RepeatedFCRelu", [SameVariadicOperandSize, NoSideEffect]> {
let summary = "";
let description = [{ }];
let arguments = (ins PD_Tensor:$input, Variadic<PD_Tensor>:$w, Variadic<PD_Tensor>:$bias);
let results = (outs PD_Tensor:$out);
let hasCanonicalizer = 1;
}
#endif // PD_OPS
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/pd_types.h"
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
// Declarations of the test-kernel ops, pulled from the generated
// test_kernels.hpp.inc into the infrt::dialect namespace.
namespace infrt::dialect {
// The generated declarations refer to MLIR names unqualified, hence the
// using-directive inside this namespace.
using namespace mlir; // NOLINT
// GET_OP_CLASSES selects the op class declarations section of the include.
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/test_kernels.hpp.inc"
} // namespace infrt::dialect
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册