未验证 提交 70dea138 编写于 作者: Y Yan Chunwei 提交者: GitHub

introduce INF-RT (#37669)

* add infrt code

refined with Paddle's code style.

* rename CinnRtConfig to InfRtConfig

* rename CinnRt to InfRt of some code

* rename CINNRT to INFRT

* remove unnecessary code

* replace CINN to INFRT in the source code

* replace all "cinn" in code to "infrt"

* remove some const_cast
上级 890bd626
...@@ -216,6 +216,7 @@ option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ${WITH_DISTRIBUTE} ...@@ -216,6 +216,7 @@ option(WITH_DGC "Use DGC(Deep Gradient Compression) or not" ${WITH_DISTRIBUTE}
option(SANITIZER_TYPE "Choose the type of sanitizer, options are: Address, Leak, Memory, Thread, Undefined" OFF) option(SANITIZER_TYPE "Choose the type of sanitizer, options are: Address, Leak, Memory, Thread, Undefined" OFF)
option(WITH_LITE "Compile Paddle Fluid with Lite Engine" OFF) option(WITH_LITE "Compile Paddle Fluid with Lite Engine" OFF)
option(WITH_CINN "Compile PaddlePaddle with CINN" OFF) option(WITH_CINN "Compile PaddlePaddle with CINN" OFF)
option(WITH_INFRT "Compile PaddlePaddle with INFRT" OFF)
option(WITH_NCCL "Compile PaddlePaddle with NCCL support" ON) option(WITH_NCCL "Compile PaddlePaddle with NCCL support" ON)
option(WITH_RCCL "Compile PaddlePaddle with RCCL support" ON) option(WITH_RCCL "Compile PaddlePaddle with RCCL support" ON)
option(WITH_XPU_BKCL "Compile PaddlePaddle with BAIDU KUNLUN XPU BKCL" OFF) option(WITH_XPU_BKCL "Compile PaddlePaddle with BAIDU KUNLUN XPU BKCL" OFF)
......
include(FetchContent)
set(LLVM_DOWNLOAD_URL https://paddle-inference-dist.bj.bcebos.com/CINN/llvm11.tar.gz)
set(LLVM_MD5 39d32b6be466781dddf5869318dcba53)
set(FETCHCONTENT_BASE_DIR ${THIRD_PARTY_PATH}/llvm)
set(FETCHCONTENT_QUIET OFF)
FetchContent_Declare(external_llvm
URL ${LLVM_DOWNLOAD_URL}
URL_MD5 ${LLVM_MD5}
PREFIX ${THIRD_PARTY_PATH}/llvm
SOURCE_DIR ${THIRD_PARTY_PATH}/install/llvm
)
if (NOT LLVM_PATH)
FetchContent_GetProperties(external_llvm)
if (NOT external_llvm_POPULATED)
FetchContent_Populate(external_llvm)
endif()
set(LLVM_PATH ${THIRD_PARTY_PATH}/install/llvm)
set(LLVM_DIR ${THIRD_PARTY_PATH}/install/llvm/lib/cmake/llvm)
set(MLIR_DIR ${THIRD_PARTY_PATH}/install/llvm/lib/cmake/mlir)
else ()
set(LLVM_DIR ${LLVM_PATH}/lib/cmake/llvm)
set(MLIR_DIR ${LLVM_PATH}/lib/cmake/mlir)
endif()
if (${CMAKE_CXX_COMPILER} STREQUAL "clang++")
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -stdlib=libc++ -lc++abi")
endif()
message(STATUS "set LLVM_DIR: ${LLVM_DIR}")
message(STATUS "set MLIR_DIR: ${MLIR_DIR}")
find_package(LLVM REQUIRED CONFIG HINTS ${LLVM_DIR})
find_package(MLIR REQUIRED CONFIG HINTS ${MLIR_DIR})
find_package(ZLIB REQUIRED)
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
include(AddLLVM)
include_directories(${LLVM_INCLUDE_DIRS})
list(APPEND CMAKE_MODULE_PATH "${LLVM_CMAKE_DIR}")
list(APPEND CMAKE_MODULE_PATH "${MLIR_CMAKE_DIR}")
include(AddLLVM)
include(TableGen)
include(AddMLIR)
message(STATUS "Found MLIR: ${MLIR_DIR}")
message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}")
message(STATUS "Using LLVMConfig.cmake in: ${LLVM_DIR}")
# To build with MLIR, the LLVM is build from source code using the following flags:
#[==[
cmake -G Ninja ../llvm \
-DLLVM_ENABLE_PROJECTS="mlir;clang" \
-DLLVM_BUILD_EXAMPLES=OFF \
-DLLVM_TARGETS_TO_BUILD="X86" \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_ENABLE_ZLIB=OFF \
-DLLVM_ENABLE_RTTI=ON \
#]==]
# The matched llvm-project version is f9dc2b7079350d0fed3bb3775f496b90483c9e42 (currently a temporary commit)
add_definitions(${LLVM_DEFINITIONS})
llvm_map_components_to_libnames(llvm_libs Support Core irreader
X86 executionengine orcjit mcjit all codegen)
message(STATUS "LLVM libs: ${llvm_libs}")
get_property(mlir_libs GLOBAL PROPERTY MLIR_ALL_LIBS)
message(STATUS "MLIR libs: ${mlir_libs}")
add_definitions(${LLVM_DEFINITIONS})
# The minimum needed libraries for MLIR IR parse and transform.
set(MLIR_IR_LIBS MLIRAnalysis MLIRStandardOps MLIRPass MLIRParser MLIRDialect MLIRIR MLIROptLib)
# tb_base is the name of a xxx.td file (without the .td suffix)
function(mlir_tablegen_on td_base)
set(options)
set(oneValueArgs DIALECT)
cmake_parse_arguments(mlir_tablegen_on "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(LLVM_TARGET_DEFINITIONS ${td_base}.td)
mlir_tablegen(${td_base}.hpp.inc -gen-op-decls)
mlir_tablegen(${td_base}.cpp.inc -gen-op-defs)
if (mlir_tablegen_on_DIALECT)
mlir_tablegen(${td_base}_dialect.hpp.inc --gen-dialect-decls -dialect=${mlir_tablegen_on_DIALECT})
endif()
add_public_tablegen_target(${td_base}_IncGen)
add_custom_target(${td_base}_inc DEPENDS ${td_base}_IncGen)
endfunction()
function(mlir_add_rewriter td_base)
set(LLVM_TARGET_DEFINITIONS ${td_base}.td)
mlir_tablegen(${td_base}.hpp.inc -gen-rewriters "-I${CMAKE_SOURCE_DIR}/infrt/dialect/pass")
add_public_tablegen_target(${td_base}_IncGen)
add_custom_target(${td_base}_inc DEPENDS ${td_base}_IncGen)
endfunction()
# Execute the mlir script with infrt-exec program.
# @name: name of the test
# @script: path to the mlir script file
function (infrt_exec_check name script)
add_test(NAME ${name}
COMMAND sh -c "${CMAKE_BINARY_DIR}/infrt/host_context/infrt-exec -i ${CMAKE_CURRENT_SOURCE_DIR}/${script}| ${LLVM_PATH}/bin/FileCheck ${CMAKE_CURRENT_SOURCE_DIR}/${script}")
endfunction()
...@@ -391,6 +391,11 @@ if (WIN32) ...@@ -391,6 +391,11 @@ if (WIN32)
list(APPEND third_party_deps extern_dirent) list(APPEND third_party_deps extern_dirent)
endif (WIN32) endif (WIN32)
if (WITH_INFRT)
include(external/llvm)
list(APPEND third_party_deps external_llvm)
endif()
if (WITH_IPU) if (WITH_IPU)
include(external/poplar) include(external/poplar)
list(APPEND third_party_deps extern_poplar) list(APPEND third_party_deps extern_poplar)
......
...@@ -2,4 +2,5 @@ add_subdirectory(scripts) ...@@ -2,4 +2,5 @@ add_subdirectory(scripts)
add_subdirectory(testing) add_subdirectory(testing)
set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests CACHE INTERNAL "python tests directory") set(PYTHON_TESTS_DIR ${PADDLE_BINARY_DIR}/python/paddle/fluid/tests CACHE INTERNAL "python tests directory")
add_subdirectory(pten) add_subdirectory(pten)
add_subdirectory(infrt)
add_subdirectory(fluid) add_subdirectory(fluid)
if (NOT WITH_INFRT)
return()
endif()
set(infrt_src CACHE INTERNAL "" FORCE)
# Gather headers for library publish.
function(core_gather_headers)
file(GLOB includes LIST_DIRECTORIES false RELATIVE ${CMAKE_SOURCE_DIR} *.h)
foreach(header ${includes})
set(core_includes "${core_includes};${header}" CACHE INTERNAL "")
endforeach()
endfunction()
function(gather_srcs SRC_GROUP)
set(options)
set(oneValueArgs)
set(multiValueArgs "SRCS")
cmake_parse_arguments(prefix "" "" "${multiValueArgs}" ${ARGN})
foreach(cpp ${prefix_SRCS})
set(${SRC_GROUP} "${${SRC_GROUP}};${CMAKE_CURRENT_SOURCE_DIR}/${cpp}" CACHE INTERNAL "")
endforeach()
endfunction()
# This method is similar to the global cc_test, but discard the huge amount default dependencies those are
# not needed by INFRT.
function(cc_test_tiny TARGET_NAME)
if(WITH_TESTING)
set(options SERIAL)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS ARGS)
cmake_parse_arguments(cc_test_tiny "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
add_executable(${TARGET_NAME} ${cc_test_tiny_SRCS})
get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
target_link_libraries(${TARGET_NAME} ${cc_test_tiny_DEPS} ${os_dependency_modules} infrt_gtest_main gtest )
add_dependencies(${TARGET_NAME} ${cc_test_tiny_DEPS} infrt_gtest_main gtest extern_gtest)
add_test(NAME ${TARGET_NAME}
COMMAND ${TARGET_NAME} "${cc_test_tiny_ARGS}"
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
if (${cc_test_tiny_SERIAL})
set_property(TEST ${TARGET_NAME} PROPERTY RUN_SERIAL 1)
endif()
endif()
endfunction()
if (WITH_TESTING)
cc_library(infrt_gtest_main SRCS gtest_main.cc DEPS gtest glog gflags)
endif()
add_subdirectory(api)
add_subdirectory(common)
add_subdirectory(dialect)
add_subdirectory(host_context)
add_subdirectory(kernel)
add_subdirectory(tensor)
add_subdirectory(support)
add_subdirectory(external_kernels)
add_subdirectory(paddle)
# MLIR td file generations
set(infrt_mlir_incs
ops_inc
basic_kernels_inc
test_kernels_inc
infrt_base_inc
tensor_shape_inc
dense_tensor_inc
pd_ops_inc
rewrite_inc
)
message(STATUS "infrt srcs:\n${infrt_src}")
cc_library(infrt SRCS ${infrt_src} DEPS glog ${mlir_libs} paddle_framework_proto)
add_dependencies(infrt ${infrt_mlir_incs})
core_gather_headers()
gather_srcs(infrt_src SRCS
infrt_api.cc
)
# Disable temporarily for the external-kernel's mkldnn is outdate
# cc_test(test_infrt_api SRCS infrt_api_test.cc DEPS infrt ${MLIR_IR_LIBS})
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/api/infrt_api.h"
#include <llvm/ADT/SmallVector.h>
#include <llvm/Support/DynamicLibrary.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/Parser.h>
#include <unordered_map>
#include <vector>
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/dense_tensor.h"
#include "paddle/infrt/dialect/mlir_loader.h"
#include "paddle/infrt/host_context/core_runtime.h"
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/mlir_function_executable.h"
#include "paddle/infrt/host_context/mlir_to_runtime_translate.h"
#include "paddle/infrt/host_context/op_executable.h"
#include "paddle/infrt/host_context/value.h"
#include "paddle/infrt/kernel/basic_kernels.h"
#include "paddle/infrt/kernel/control_flow_kernels.h"
#include "paddle/infrt/kernel/tensor_kernels.h"
#include "paddle/infrt/kernel/tensor_shape_kernels.h"
#include "paddle/infrt/kernel/test_kernels.h"
#include "paddle/infrt/tensor/tensor_map.h"
using namespace infrt::host_context; // NOLINT
using namespace infrt::tensor; // NOLINT
using namespace infrt::tensor; // NOLINT
using infrt::dt::TensorMapType; // NOLINT
using infrt::dt::TensorType; // NOLINT
namespace infrt {
template <typename T>
std::string DumpToString(T& op) { // NOLINT
std::string buffer;
llvm::raw_string_ostream os(buffer);
op.print(os);
os.flush();
return buffer;
}
struct MlirToRuntimeTranslator::Impl {
mlir::ModuleOp module;
// The runtime for a function call.
CoreRuntimeBuilder* runtime{};
// The current working op, the translator process the ops one by one, each
// time it updates `cur_op` here to current op
// working on.
OpExecutableBuilder* cur_op{};
// record the current function name.
std::string cur_func_name;
// Name to function definitions.
std::unordered_map<std::string, mlir::FuncOp> func_defs;
// Map from an operation to its results.
std::unordered_map<const mlir::Operation*, std::vector<ValueRef>> op_results;
llvm::DenseMap<mlir::Value, ValueRef> value_map;
};
/**
* Execute the mlir program in predict mode.
*/
class PredictExecutor : public MlirToRuntimeTranslator {
public:
CoreRuntimeBuilder core_runtime;
PredictExecutor(mlir::ModuleOp module,
KernelRegistry* registry,
TensorMap* map)
: MlirToRuntimeTranslator(module, &core_runtime),
core_runtime(registry),
registry_(registry) {
CHECK(registry_);
Init(map);
}
void Run() {
auto arguments = llvm::makeArrayRef(arguments_);
auto results = llvm::makeMutableArrayRef(results_.begin(), results_.size());
function_executable_->Execute(arguments, results);
}
int GetInputNum() { return inputs_.size(); }
DenseHostTensor* GetInput(int i) { return inputs_[i]; }
int GetOutputNum() { return outputs_.size(); }
DenseHostTensor* GetOutput(int i) { return outputs_[i]; }
private:
void Init(TensorMap* map) {
EmitFunctions();
llvm::Optional<mlir::FuncOp> predict_func_ = llvm::None;
for (auto func_op : impl_->module.getOps<mlir::FuncOp>()) {
if (func_op.getName().str() != "predict") continue;
predict_func_ = func_op;
break;
}
if (!predict_func_) {
std::cout << "ERROR: init failed, no predict function found in mlir."
<< std::endl;
return;
}
auto& predict_func = predict_func_.getValue();
function_executable_ =
new MlirFunctionExecutable(predict_func, registry_, impl_->func_defs);
// process parammeters
for (size_t i = 0; i < predict_func.getNumArguments(); ++i) {
auto arg = predict_func.getArgument(i);
auto type = arg.getType();
// this param is TensorMap
if (type.isa<TensorMapType>()) {
auto* value = new host_context::Value(std::move(*map));
arguments_.push_back(value);
AddValue(predict_func.getArgument(i), value);
} else {
// this param is an input Tensor
auto dht = DenseHostTensor();
auto* value = new host_context::Value(std::move(dht));
arguments_.push_back(value);
inputs_.push_back(&(value->get<DenseHostTensor>()));
}
}
// process results
auto& last_op = predict_func.front().back();
if (last_op.getName().getStringRef() == "infrt.return") {
for (size_t i = 0; i < last_op.getNumOperands(); ++i) {
auto* value = AddValue(mlir::Value(last_op.getOperand(i)));
results_.push_back(ValueRef(value));
outputs_.push_back(&(value->get<DenseHostTensor>()));
}
}
}
protected:
std::unordered_map<std::string, mlir::FuncOp> func_def_table;
void EmitFunction(mlir::FuncOp op) override {
CHECK(!impl_->func_defs.count(op.getName().str()))
<< "Duplicate function defition found for function ["
<< op.getName().str();
impl_->func_defs.emplace(op.getName().str(), op);
}
private:
KernelRegistry* registry_{};
MlirFunctionExecutable* function_executable_;
llvm::SmallVector<DenseHostTensor*, 1> inputs_;
llvm::SmallVector<host_context::Value*, 2> arguments_;
llvm::SmallVector<DenseHostTensor*, 1> outputs_;
llvm::SmallVector<ValueRef, 1> results_;
};
std::shared_ptr<InfRtPredictor> CreateInfRtPredictor(
const InfRtConfig& config) {
auto x = std::make_shared<InfRtPredictor>();
x->Init(config);
return x;
}
struct InfRtPredictor::Impl {
mlir::OwningModuleRef module_ref;
std::unique_ptr<PredictExecutor> executor;
};
InfRtPredictor::InfRtPredictor() : impl_(new Impl) {}
InfRtPredictor::~InfRtPredictor() {}
void InfRtPredictor::Run() { impl_->executor->Run(); }
int InfRtPredictor::Init(const InfRtConfig& config) {
mlir::MLIRContext* context = infrt::Global::getMLIRContext();
auto module_ref = dialect::LoadMlirFile(config.mlir_path(), context);
KernelRegistry* registry = new KernelRegistry();
kernel::RegisterBasicKernels(registry);
kernel::RegisterTestKernels(registry);
kernel::RegisterTensorShapeKernels(registry);
kernel::RegisterTensorKernels(registry);
kernel::RegisterControlFlowKernels(registry);
impl_->module_ref = std::move(module_ref);
// load extra shared library
for (const std::string& lib_path : config.shared_libs()) {
std::string err;
llvm::sys::DynamicLibrary dynLib =
llvm::sys::DynamicLibrary::getPermanentLibrary(lib_path.c_str(), &err);
if (!dynLib.isValid()) {
llvm::errs() << "Load shared library failed. Error: " << err << "\n";
return 1;
}
if (auto reg_sym = dynLib.SearchForAddressOfSymbol("RegisterKernels")) {
auto reg_func = reinterpret_cast<void (*)(KernelRegistry*)>(reg_sym);
reg_func(registry);
} else {
llvm::outs() << "Symbol \"RegisterKernels\" not found in \"" << lib_path
<< "\". Skip.\n";
}
}
// Load params
TensorMap* tensor_map = LoadParams(config.model_dir());
// Create PredictExecutor
impl_->executor.reset(
new PredictExecutor(impl_->module_ref.get(), registry, tensor_map));
return 0;
}
int InfRtPredictor::GetInputNum() { return impl_->executor->GetInputNum(); }
DenseHostTensor* InfRtPredictor::GetInput(int i) {
return impl_->executor->GetInput(i);
}
int InfRtPredictor::GetOutputNum() { return impl_->executor->GetOutputNum(); }
DenseHostTensor* InfRtPredictor::GetOutput(int i) {
return impl_->executor->GetOutput(i);
}
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/infrt/tensor/dense_host_tensor.h"
namespace infrt {
class InfRtConfig {
std::string model_dir_;
std::string mlir_path_;
std::vector<std::string> shared_libs_;
public:
InfRtConfig() = default;
void set_model_dir(const std::string& model_dir) { model_dir_ = model_dir; }
const std::string& model_dir() const { return model_dir_; }
void set_mlir_path(const std::string& mlir_path) { mlir_path_ = mlir_path; }
const std::string& mlir_path() const { return mlir_path_; }
void set_shared_libs(const std::vector<std::string>& shared_libs) {
shared_libs_ = shared_libs;
}
const std::vector<std::string>& shared_libs() const { return shared_libs_; }
virtual ~InfRtConfig() = default;
};
class InfRtPredictor {
public:
InfRtPredictor();
~InfRtPredictor();
void Run();
int Init(const InfRtConfig& config);
int GetInputNum();
tensor::DenseHostTensor* GetInput(int i);
int GetOutputNum();
tensor::DenseHostTensor* GetOutput(int i);
protected:
struct Impl;
std::unique_ptr<Impl> impl_;
};
std::shared_ptr<InfRtPredictor> CreateInfRtPredictor(const InfRtConfig& config);
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/api/infrt_api.h"
#include <gtest/gtest.h>
#include <iostream>
#include <vector>
#include "llvm/Support/raw_ostream.h"
#include "paddle/infrt/common/buffer.h"
#include "paddle/infrt/common/dtype.h"
using infrt::InfRtConfig;
using infrt::InfRtPredictor;
using infrt::CreateInfRtPredictor;
namespace infrt {
TEST(InfRtPredictor, predictor) {
std::vector<std::string> shared_libs;
shared_libs.push_back("../../paddle/libexternal_kernels.so");
InfRtConfig config;
// set external shared libraries that contain kernels.
config.set_shared_libs(shared_libs);
// set model dir
config.set_model_dir("../../paddle/paddle_1.8_fc_model");
// set mlir path
config.set_mlir_path("../../../infrt/dialect/mlir_tests/tensor_map.mlir");
std::shared_ptr<InfRtPredictor> predictor = CreateInfRtPredictor(config);
auto* input = predictor->GetInput(0);
std::vector<int64_t> shape = {3, 3};
input->Init(shape, infrt::GetDType<float>());
llvm::outs() << input->shape() << "\n";
// init input tensor
auto* input_data = reinterpret_cast<float*>(input->buffer()->data()->memory);
for (int i = 0; i < input->shape().GetNumElements(); i++) input_data[i] = 1.0;
predictor->Run();
// get and print output tensor
auto* output = predictor->GetOutput(0);
auto* output_data =
reinterpret_cast<float*>(output->buffer()->data()->memory);
std::vector<float> ans = {0.428458,
0.244493,
0.572342,
0.572008,
0.509771,
0.495599,
0.651287,
0.326426,
0.404649};
ASSERT_EQ(output->shape().GetNumElements(), ans.size());
for (int i = 0; i < output->shape().GetNumElements(); ++i) {
ASSERT_NEAR(output_data[i], ans[i], 0.000001);
}
}
} // namespace infrt
core_gather_headers()
set(core_includes "${core_includes};infrt/common/dtype.def" CACHE INTERNAL "")
gather_srcs(infrt_src SRCS
dtype.cc
global.cc
target.cc
type.cc
shared.cc
object.cc
string.cc
buffer.cc
memory.cc
)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/common/buffer.h"
#include <stdarg.h>
#include <stdio.h>
#include <cmath>
namespace infrt {
void Buffer::Resize(uint32_t size) {
if (size_ > 0) {
Free();
size_ = 0;
}
if (size_ != size) {
data_.memory = reinterpret_cast<uint8_t*>(Malloc(size));
size_ = size;
}
}
void Buffer::Resize(uint32_t alignment, uint32_t size) {
if (size_ > 0) {
Free();
size_ = 0;
}
if (size_ != size) {
data_.memory = reinterpret_cast<uint8_t*>(AlignedAlloc(alignment, size));
size_ = size;
}
}
void Buffer::SetTarget(const infrt::common::Target& target) {
target_ = target;
memory_mng_cache_ = MemoryManager::Global().RetrieveSafely(target_.arch);
}
void Buffer::ResizeLazy(uint32_t size) {
if (size <= size_) return;
Resize(size);
}
void Buffer::ResizeLazy(uint32_t alignment, uint32_t size) {
if (size <= size_) return;
Resize(alignment, size);
}
void Buffer::Resize(uint32_t size, const infrt::common::Target& target) {
if (target.arch != target_.arch) {
Free();
SetTarget(target);
}
Resize(size);
}
void Buffer::Resize(uint32_t alignment,
uint32_t size,
const infrt::common::Target& target) {
if (target.arch != target_.arch) {
Free();
SetTarget(target);
}
Resize(alignment, size);
}
void Buffer::ResizeLazy(uint32_t size, const infrt::common::Target& target) {
if (target.arch != target_.arch) {
Free();
SetTarget(target);
}
ResizeLazy(size);
}
void Buffer::ResizeLazy(uint32_t alignment,
uint32_t size,
const infrt::common::Target& target) {
if (target.arch != target_.arch) {
Free();
SetTarget(target);
}
ResizeLazy(alignment, size);
}
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <memory>
#include "paddle/infrt/common/macros.h"
#include "paddle/infrt/common/memory.h"
#include "paddle/infrt/common/target.h"
namespace infrt {
#ifdef __cplusplus
extern "C" {
#endif
#define INFRT_ALWAYS_INLINE __attribute__((always_inline)) inline
//! Code for the primitive types supported in INFRT.
typedef enum infrt_type_code_t {
infrt_type_unk = -1, //! Unknown type
infrt_type_int = 0, //! signed int
infrt_type_uint = 1, //! unsigned int
infrt_type_float = 2, //! floating point
infrt_type_handle = 3 //! void*
} infrt_type_code_t;
#ifndef INFRT_ATTRIBUTE_ALIGN
#define INFRT_ATTRIBUTE_ALIGN(n) __attribute__((aligned(n)))
#endif
/**
* A tuntime tag for type in INFRT system.
*/
typedef struct infrt_type_t {
#if __cplusplus >= 201103L
INFRT_ATTRIBUTE_ALIGN(1) infrt_type_code_t code;
#else
uint8_t code;
#endif
//! Number of bits.
uint8_t bits;
//! Number of elements in a vector, 1 for scalar.
uint16_t lanes;
//! Number of '*', e.g. for `float*`, the num_asterisks is 1, `float**` it is
//! 2.
uint8_t num_asterisks{0};
#ifdef __cplusplus
INFRT_ALWAYS_INLINE infrt_type_t()
: code(infrt_type_int), bits(0), lanes(0) {}
INFRT_ALWAYS_INLINE infrt_type_t(infrt_type_code_t code,
uint8_t bits,
uint16_t lanes = 1,
uint8_t num_asterisks = 0)
: code(code), bits(bits), lanes(lanes), num_asterisks(num_asterisks) {}
INFRT_ALWAYS_INLINE bool operator==(const infrt_type_t& other) const {
return code == other.code && bits == other.bits && lanes == other.lanes;
}
INFRT_ALWAYS_INLINE bool operator!=(const infrt_type_t& other) const {
return !(*this == other);
}
INFRT_ALWAYS_INLINE uint16_t bytes() const { return (bits + 7) / 8; }
#endif // __cplusplus
} infrt_type_t;
//! Help to define the size of a dimension, due to polyhedral representation, we
//! no need to record the extend or
//! min(default to 0).
typedef int infrt_dimension_t;
//! Help to tell the kind of the device.
typedef enum infrt_device_kind_t {
infrt_unk_device = -1, // Undefined device.
infrt_x86_device = 0, // X86 device
infrt_opencl_device = 1, // OpenCL device
infrt_arm_device = 2 // ARM device
} infrt_device_kind_t;
struct infrt_buffer_t;
/**
* All INFRT backends implementation should provide an interface to be used.
*/
struct infrt_device_interface_impl_t;
struct infrt_device_interface_t {
int (*malloc)(void* context, struct infrt_buffer_t* buf);
int (*free)(void* context, struct infrt_buffer_t* buf);
int (*sync)(void* context, struct infrt_buffer_t* buf);
int (*release)(void* context,
const struct infrt_device_interface_t* device_interface);
int (*copy_to_host)(void* context, struct infrt_buffer_t* buf);
int (*copy_to_device)(void* context, struct infrt_buffer_t* buf);
int (*buffer_copy)(void* context,
struct infrt_buffer_t* src,
struct infrt_buffer_t* dst);
struct infrt_device_interface_impl_t* impl;
};
//! The raw representation of a buffer,used in the generated code/lib.
#define INFRT_BUFFER_MAX_DIMS 8
typedef struct infrt_buffer_t {
//! Tell which kind of device this buffer locates.
infrt_device_kind_t device;
//! The interface used to operate on device.
const struct infrt_device_interface_t* device_interface;
//! A pointer to the memory in host.
uint8_t* memory;
//! Extra flags.
uint64_t flag;
//! Data type.
infrt_type_t type;
//! Number of dimensions.
int32_t dimensions;
infrt_dimension_t dims[INFRT_BUFFER_MAX_DIMS];
//! Allocate and deallocate lazily, default true.
char lazy;
//! The actual memory size(in bytes).
uint64_t memory_size;
uint16_t align;
#ifdef __cplusplus
infrt_buffer_t()
: device(infrt_unk_device),
device_interface(NULL),
memory(NULL),
flag(0UL),
type(infrt_type_t()),
dimensions(0),
lazy(true),
memory_size(0),
align(0) {}
static void delete_(struct infrt_buffer_t* x) { delete x; }
~infrt_buffer_t() {}
// NOTE the buffer should be resized first.
static void alloc(struct infrt_buffer_t*);
//! Set the shape of the buffer. NOTE this just record the shape, not allocate
//! the memory.
INFRT_ALWAYS_INLINE void resize(const infrt_dimension_t* dims,
int dimensions) {
this->dimensions = dimensions;
memcpy(this->dims, dims, dimensions * sizeof(infrt_dimension_t));
}
INFRT_ALWAYS_INLINE uint64_t num_elements() const {
uint64_t res = 1;
for (int i = 0; i < dimensions; i++) {
res *= dims[i];
}
return res;
}
INFRT_ALWAYS_INLINE int device_sync(void* ctx = NULL) {
if (device_interface && device_interface->sync) {
return device_interface->sync(ctx, this);
}
return 0;
}
INFRT_ALWAYS_INLINE uint8_t* begin() const { return 0; }
INFRT_ALWAYS_INLINE uint8_t* end() const {
return memory + num_elements() * type.bytes();
}
#endif // __cplusplus
} infrt_buffer_t;
#ifdef __cplusplus
struct infrt_device_interface_impl_t {
int (*malloc)(void* context, struct infrt_buffer_t* buf);
int (*free)(void* context, struct infrt_buffer_t* buf);
int (*sync)(void* context, struct infrt_buffer_t* buf);
int (*release)(void* context);
int (*copy_to_host)(void* context, struct infrt_buffer_t* buf);
int (*copy_to_device)(void* context, struct infrt_buffer_t* buf);
int (*buffer_copy)(void* context,
struct infrt_buffer_t* src,
struct infrt_buffer_t* dst);
};
// The device implementations
extern struct infrt_device_interface_t* infrt_x86_device_interface();
#endif // __cplusplus
#ifdef __cplusplus
} // extern "C"
#endif
#define INFRT_LOG(fmt, ...) \
do { \
fprintf(stderr, \
"%s:%d:%s(): " fmt, \
__FILE__, \
__LINE__, \
__func__, \
__VA_ARGS__); \
} while (0)
#define INFRT_CHECK(cond) \
if (!(cond)) { \
INFRT_LOG("check %s failed", #cond); \
abort(); \
}
/**
* Buffer helps to hold the memory, and offers a set of methods to help manage
* the memory.
*/
struct Buffer final {
Buffer() = default;
explicit Buffer(const common::Target& target) { SetTarget(target); }
//! Resize the memory hold by this buffer *exactlly* to \p size.
void Resize(uint32_t size);
void Resize(uint32_t alignment, uint32_t size);
//! Lazily resize the memory.
void ResizeLazy(uint32_t size);
void ResizeLazy(uint32_t alignment, uint32_t size);
//! Resize the memory to \p size in target \p target.
void Resize(uint32_t size, const common::Target& target);
void Resize(uint32_t alignment, uint32_t size, const common::Target& target);
//! Lazily resize the memory to \p size in target \p target.
void ResizeLazy(uint32_t size, const common::Target& target);
void ResizeLazy(uint32_t alignment,
uint32_t size,
const common::Target& target);
void SetTarget(const common::Target& target);
const infrt_buffer_t* data() const { return &data_; }
infrt_buffer_t* data() { return &data_; }
//! Free all the memory owned by this buffer.
void Free() {
if (!data_.memory) return;
memory_mng_cache_->free(data_.memory);
}
private:
inline void* Malloc(uint32_t size) INFRT_RESULT_SHOULD_USE {
CHECK(memory_mng_cache_) << "Should set target first";
return memory_mng_cache_->malloc(size);
}
inline void* AlignedAlloc(uint32_t alignment,
uint32_t size) INFRT_RESULT_SHOULD_USE {
CHECK(memory_mng_cache_) << "Should set target first";
return memory_mng_cache_->aligned_alloc(alignment, size);
}
private:
infrt_buffer_t data_;
//! The place where this buffer locates.
common::Target target_;
//! Number of bytes of this buffer.
uint32_t size_{};
//! Hold the corresponding memory manager for speed.
MemoryInterface* memory_mng_cache_{};
};
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/infrt/common/macros.h"
#include "paddle/infrt/common/shared.h"
#include "paddle/infrt/common/target.h"
#include "paddle/infrt/common/type.h"
namespace infrt {
// export some general concepts.
using common::make_shared;
using common::Object;
using common::ref_count;
using common::Shared;
// Type related.
using common::Bool;
using common::Float;
using common::Int;
using common::UInt;
using common::Void;
using common::type_of;
using common::Target;
using common::Type;
using common::UnkTarget;
template <typename T>
T& Reference(const T* x) {
return *const_cast<T*>(x);
}
static void CheckVarNameValid(const std::string& name) {
CHECK(!name.empty());
CHECK(name.find(' ') == std::string::npos && //
name.find('.') == std::string::npos && //
name.find('/') == std::string::npos && //
name.find('\t') == std::string::npos && //
name.find('\n') == std::string::npos && //
name.find('\r') == std::string::npos)
<< "Some invalid character found";
}
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/common/dtype.h"
namespace infrt {
const char* DType::name() const {
switch (kind_) {
case Kind::Unk:
return "Unk";
break;
#define INFRT_DTYPE(enum__, value__) \
case Kind::enum__: \
return #enum__; \
break;
#include "paddle/infrt/common/dtype.def"
#undef INFRT_DTYPE
}
return "";
}
size_t DType::GetHostSize() const {
switch (kind_) {
#define INFRT_DTYPE(enum__, value__) \
case DType::Kind::enum__: \
return sizeof(DTypeInternal<DType::Kind::enum__>::type);
#include "paddle/infrt/common/dtype.def" // NOLINT
#undef INFRT_DTYPE
case Kind::Unk:
return 0;
break;
}
return 0;
}
} // namespace infrt
// Define all INFRT dtypes
// DTYPE(ENUM, VALUE)
#ifdef INFRT_DTYPE
INFRT_DTYPE(UI8, 1)
INFRT_DTYPE(UI16, 2)
INFRT_DTYPE(UI32, 3)
INFRT_DTYPE(UI64, 4)
INFRT_DTYPE(I1, 5)
INFRT_DTYPE(I8, 6)
INFRT_DTYPE(I16, 7)
INFRT_DTYPE(I32, 8)
INFRT_DTYPE(I64, 9)
INFRT_DTYPE(F32, 10)
INFRT_DTYPE(F64, 11)
INFRT_DTYPE(STRING, 12)
#endif
\ No newline at end of file
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
namespace infrt {
class DType {
public:
enum class Kind : uint8_t {
Unk = 0,
// Automatically generate the enum definition
#define INFRT_DTYPE(enum__, value__) enum__ = value__,
#include "paddle/infrt/common/dtype.def"
#undef INFRT_DTYPE
BOOL = I1,
};
DType() = default;
explicit constexpr DType(Kind kind) : kind_(kind) { assert(IsValid()); }
DType(const DType&) = default;
DType& operator=(const DType&) = default;
bool operator==(DType other) const { return kind_ == other.kind_; }
bool operator!=(DType other) const { return !(*this == other); }
constexpr Kind kind() const { return kind_; }
bool IsValid() const { return kind_ != Kind::Unk; }
bool IsInvalid() const { return !IsValid(); }
const char* name() const;
size_t GetHostSize() const;
private:
Kind kind_{Kind::Unk};
};
template <typename T>
constexpr DType GetDType();
template <DType::Kind kind>
struct DTypeInternal;
#define INFRT_IMPL_GET_DTYPE(cpp_type__, enum__) \
template <> \
inline constexpr DType GetDType<cpp_type__>() { \
return DType{DType::Kind::enum__}; \
} \
template <> \
struct DTypeInternal<DType::Kind::enum__> { \
using type = cpp_type__; \
};
INFRT_IMPL_GET_DTYPE(bool, I1);
INFRT_IMPL_GET_DTYPE(int8_t, I8);
INFRT_IMPL_GET_DTYPE(int16_t, I16);
INFRT_IMPL_GET_DTYPE(int32_t, I32);
INFRT_IMPL_GET_DTYPE(int64_t, I64);
INFRT_IMPL_GET_DTYPE(uint8_t, UI8);
INFRT_IMPL_GET_DTYPE(uint16_t, UI16);
INFRT_IMPL_GET_DTYPE(uint32_t, UI32);
INFRT_IMPL_GET_DTYPE(uint64_t, UI64);
INFRT_IMPL_GET_DTYPE(float, F32);
INFRT_IMPL_GET_DTYPE(double, F64);
INFRT_IMPL_GET_DTYPE(std::string, STRING);
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/common/global.h"
namespace infrt {
Global::Global() {}
mlir::MLIRContext* Global::context = nullptr;
mlir::MLIRContext* Global::getMLIRContext() {
if (nullptr == context) {
context = new mlir::MLIRContext();
}
return context;
}
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mlir/IR/MLIRContext.h"
#include "paddle/infrt/tensor/dense_host_tensor.h"
namespace infrt {
// global variables
class Global {
private:
static mlir::MLIRContext *context;
Global();
public:
static mlir::MLIRContext *getMLIRContext();
}; // class Global
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#if !defined(NDEBUG)
#define INFRT_DEBUG
#endif
#define INFRT_DISALLOW_COPY_AND_ASSIGN(TypeName) \
TypeName(const TypeName&) = delete; \
void operator=(const TypeName&) = delete
#ifndef INFRT_NOT_IMPLEMENTED
#define INFRT_NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented";
#endif
#define INFRT_RESULT_SHOULD_USE __attribute__((warn_unused_result))
/**
* A trick to enforce the registry.
*
* usage:
*
* INFRT_REGISTER_HELPER(some_key) {
* // register methods
* }
*
* INFRT_USE_REGISTER(some_key);
*/
#define INFRT_REGISTER_HELPER(symbol__) bool __infrt__##symbol__##__registrar()
#define INFRT_USE_REGISTER(symbol__) \
extern bool __infrt__##symbol__##__registrar(); \
[[maybe_unused]] static bool __infrt_extern_registrar_##symbol__ = \
__infrt__##symbol__##__registrar();
#if __cplusplus >= 201703L
#define INFRT_NODISCARD [[nodiscard]]
#else
#define INFRT_NODISCARD
#endif
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/common/memory.h"
namespace infrt {
using infrt::common::Target;
namespace {
class X86MemoryMng : public MemoryInterface {
public:
void* malloc(size_t nbytes) override { return ::malloc(nbytes); }
void free(void* data) override {
if (!data) return;
::free(data);
}
void* aligned_alloc(size_t alignment, size_t nbytes) override {
return ::aligned_alloc(alignment, nbytes);
}
};
} // namespace
MemoryManager::MemoryManager() {
Register(Target::Arch::Unk, new X86MemoryMng);
Register(Target::Arch::X86, new X86MemoryMng);
}
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <unordered_map>
#include <memory>
#include "paddle/infrt/common/macros.h"
#include "paddle/infrt/common/target.h"
namespace infrt {
class MemoryInterface {
public:
virtual void* malloc(size_t nbytes) = 0;
virtual void free(void* data) = 0;
virtual void* aligned_alloc(size_t alignment, size_t nbytes) {
return nullptr;
}
virtual ~MemoryInterface() {}
};
/**
* MemoryManager holds a map of MemoryInterface for each articture.
*/
class MemoryManager final {
public:
using key_t = common::Target::Arch;
static MemoryManager& Global() {
static auto* x = new MemoryManager;
return *x;
}
MemoryInterface* Retrieve(key_t key) INFRT_RESULT_SHOULD_USE {
auto it = memory_mngs_.find(key);
if (it != memory_mngs_.end()) return it->second.get();
return nullptr;
}
MemoryInterface* RetrieveSafely(key_t key) {
auto* res = Retrieve(key);
CHECK(res) << "no MemoryInterface for architecture [" << key << "]";
return res;
}
MemoryInterface* Register(key_t key, MemoryInterface* item) {
CHECK(!memory_mngs_.count(key)) << "Duplicate register [" << key << "]";
memory_mngs_[key].reset(item);
return item;
}
private:
MemoryManager();
std::unordered_map<common::Target::Arch, std::unique_ptr<MemoryInterface>>
memory_mngs_;
INFRT_DISALLOW_COPY_AND_ASSIGN(MemoryManager);
};
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/common/object.h"
namespace infrt {
namespace common {} // namespace common
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cstring>
#include <iostream>
#include "paddle/infrt/common/shared.h"
namespace infrt {
namespace common {
template <typename T>
class Shared;
/**
* Object is the basic element in the INFRT, with `Shared` wrapper, the object
* can be shared accross the system.
*/
struct Object {
//! Get the type representation of this object.
virtual const char* type_info() const = 0;
virtual ~Object() {}
//! Cast to a derived type.
template <typename T>
T* as() {
return static_cast<T*>(this);
}
//! Cast to a derived type.
template <typename T>
const T* as() const {
return static_cast<const T*>(this);
}
//! Type safe cast.
template <typename T>
T* safe_as() {
if (std::strcmp(type_info(), T::__type_info__) == 0) {
return static_cast<T*>(this);
}
return nullptr;
}
//! Type safe cast.
template <typename T>
const T* safe_as() const {
if (std::strcmp(type_info(), T::__type_info__) == 0) {
return static_cast<const T*>(this);
}
return nullptr;
}
//! Check if the type is right.
template <typename T>
bool is_type() const {
if (std::strcmp(type_info(), T::__type_info__) == 0) {
return true;
}
return false;
}
//! The reference count, which make all the derived type able to share.
mutable RefCount __ref_count__;
};
using object_ptr = Object*;
using shared_object = Shared<Object>;
} // namespace common
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/common/shared.h"
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <string>
#include <type_traits>
namespace infrt {
namespace common {
class RefCount {
public:
using value_type = int32_t;
RefCount() = default;
value_type Inc() { return ++count_; }
value_type Dec() { return --count_; }
bool is_zero() const { return 0 == count_; }
std::string to_string() { return std::to_string(count_.load()); }
int32_t val() const { return count_; }
private:
std::atomic<value_type> count_{0};
};
class Object;
/**
* The templated methods are used to unify the way to get the RefCount instance
* in client classes.
*/
template <typename T>
RefCount& ref_count(const T* t) {
static_assert(std::is_base_of<Object, T>::value, "T is not a Object");
return t->__ref_count__;
}
template <typename T>
void Destroy(const T* t) {
delete t;
}
template <typename T>
struct Shared {
using object_ptr = T*;
Shared() = default;
explicit Shared(T* p) : p_(p) {
if (p) IncRef(p);
}
Shared(const Shared& other) : p_(other.p_) { IncRef(p_); }
Shared(Shared&& other) : p_(other.p_) { other.p_ = nullptr; }
Shared<T>& operator=(const Shared<T>& other);
//! Reset to another pointer \p x.
void Reset(T* x = nullptr);
//! Access the pointer in various ways.
// @{
inline T* get() const { return p_; }
inline T& operator*() const { return *p_; }
inline T* operator->() const { return p_; }
inline T* self() { return p_; }
inline const T* self() const { return p_; }
// @}
inline bool same_as(const Shared& other) { return p_ == other.p_; }
inline bool defined() const { return p_; }
inline bool operator<(const Shared& other) const { return p_ < other.p_; }
inline Shared<T>& operator=(T* x);
inline bool operator==(const Shared& other) const { return p_ == other.p_; }
~Shared();
private:
//! Increase the share count.
void IncRef(T* p);
//! Decrease the share count.
void DecRef(T* p);
protected:
T* p_{};
};
template <typename T>
void Shared<T>::IncRef(T* p) {
if (p) {
ref_count(p).Inc();
}
}
template <typename T>
void Shared<T>::DecRef(T* p) {
if (p) {
if (ref_count(p).Dec() == 0) {
Destroy(p);
}
}
}
template <typename T>
Shared<T>& Shared<T>::operator=(const Shared<T>& other) {
if (other.p_ == p_) return *this;
// Other can be inside of something owned by this, so we should be careful to
// incref other before we decref
// ourselves.
T* tmp = other.p_;
IncRef(tmp);
DecRef(p_);
p_ = tmp;
return *this;
}
template <typename T, typename... Args>
T* make_shared(Args&&... args) {
return new T(args...);
}
template <typename T>
Shared<T>& Shared<T>::operator=(T* x) {
if (p_ == x) return *this;
T* tmp = x;
IncRef(tmp);
DecRef(p_);
p_ = tmp;
return *this;
}
template <typename T>
Shared<T>::~Shared() {
DecRef(p_);
p_ = nullptr;
}
template <typename T>
void Shared<T>::Reset(T* x) {
if (x) IncRef(x);
DecRef(p_);
p_ = x;
}
} // namespace common
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/common/string.h"
#include <stdarg.h>
#include <cstring>
namespace infrt {
namespace infrt {
std::string StringFormat(const std::string &fmt_str, ...) {
/* Reserve two times as much as the length of the fmt_str */
int final_n, n = (static_cast<int>(fmt_str.size())) * 2;
std::unique_ptr<char[]> formatted;
va_list ap;
while (1) {
formatted.reset(
new char[n]); /* Wrap the plain char array into the unique_ptr */
std::strcpy(&formatted[0], fmt_str.c_str()); // NOLINT
va_start(ap, fmt_str);
final_n = vsnprintf(&formatted[0], n, fmt_str.c_str(), ap);
va_end(ap);
if (final_n < 0 || final_n >= n)
n += abs(final_n - n + 1);
else
break;
}
return std::string(formatted.get());
}
std::string Trim(const std::string &s, const char *empty) {
if (s.empty()) return s;
auto start = s.find_first_not_of(empty);
if (start == std::string::npos) return "";
auto end = s.find_last_not_of(empty);
return s.substr(start, end - start + 1);
}
std::string Uppercase(const std::string &x) {
auto res = x;
for (auto &c : res) {
c = toupper(c);
}
return res;
}
bool Startswith(const std::string &x, const std::string &str) {
return x.find(str) == 0;
}
bool Endswith(const std::string &x, const std::string &str) {
if (x.length() >= str.length()) {
return std::equal(str.rbegin(), str.rend(), x.rbegin());
}
return false;
}
std::vector<std::string> Split(const std::string &str,
const std::string &splitter) {
std::vector<std::string> results;
std::string::size_type pos1, pos2;
pos2 = str.find(splitter);
pos1 = 0;
while (std::string::npos != pos2) {
results.push_back(str.substr(pos1, pos2 - pos1));
pos1 = pos2 + splitter.size();
pos2 = str.find(splitter, pos1);
}
if (pos1 != str.length()) {
results.push_back(str.substr(pos1));
}
return results;
}
void Replace(std::string *s, const std::string &from, const std::string &to) {
size_t pos = 0;
while ((pos = s->find(from, pos)) != std::string::npos) {
s->replace(pos, from.size(), to);
pos += to.length();
}
}
size_t Count(std::string *s, const std::string &sub) {
size_t pos = 0;
size_t times = 0;
while ((pos = s->find(sub, pos)) != std::string::npos) {
if ((pos == 0 || !IsPrefix(s->at(pos - 1))) &&
(pos + sub.length() == s->size() ||
!IsSuffix(s->at(pos + sub.length())))) {
pos += sub.length();
times++;
} else {
pos++;
}
}
return times;
}
bool IsPrefix(const char &c) {
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c == '_');
}
bool IsSuffix(const char &c) {
return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c == '_') ||
(c >= '0' && c <= '9') || (c == '\'');
}
std::string TransValidVarName(std::string name) {
Replace(&name, ".", "__");
Replace(&name, "/", "___");
name.erase(0, name.find_first_not_of("_"));
return name;
}
} // namespace infrt
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <sstream>
#include <string>
#include <vector>
namespace infrt {
namespace infrt {
//! Get the content of a stream.
template <typename T>
std::string GetStreamCnt(const T& x);
/**
* Construct a formatted string with arguments.
* @param fmt_str The format.
* @param ... The parameters of the format.
* @return The formated string.
*/
std::string StringFormat(const std::string& fmt_str, ...);
/**
* Join multiple fields to a single string. Similar to Python's str.join method.
*/
template <typename T = std::string>
std::string Join(const std::vector<T>& fields, const std::string& splitter) {
if (fields.empty()) return "";
std::stringstream ss;
for (int i = 0; i < fields.size() - 1; i++) ss << fields[i] << splitter;
ss << fields.back();
return ss.str();
}
std::vector<std::string> Split(const std::string& str,
const std::string& splitter);
std::string Trim(const std::string& s, const char* empty = " \n\r\t");
//! Convert a string to its uppercase.
std::string Uppercase(const std::string& x);
//! Replace a substr 'from' to 'to' in string s.
void Replace(std::string* s, const std::string& from, const std::string& to);
//! Count how many times substr 'sub' appears in string s.
size_t Count(std::string* s, const std::string& sub);
//! Tell if a char is prefix of a tensor's name.
bool IsPrefix(const char& c);
//! Tell if a char is suffix of a tensor's name.
bool IsSuffix(const char& c);
//! Tell if a string \p x start with \p str.
bool Startswith(const std::string& x, const std::string& str);
//! Tell if a string \p x ends with \p str.
bool Endswith(const std::string& x, const std::string& str);
template <typename T>
std::string GetStreamCnt(const T& x) {
std::stringstream os;
os << x;
return os.str();
}
std::string TransValidVarName(std::string name);
} // namespace infrt
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/common/target.h"
#include <glog/logging.h>
namespace infrt {
namespace common {
bool Target::operator==(const Target &other) const {
return os == other.os && //
arch == other.arch && //
bits == other.bits && //
features == other.features;
}
int Target::max_num_threads() const {
CHECK(arch == Arch::NVGPU)
<< "The target is not NVGPU! Cannot get max number of threads.";
return 1024;
}
std::vector<Target::Lib> Target::get_target_libs() const { return libs; }
int Target::get_target_bits() const {
switch (bits) {
case Bit::k32:
return 32;
case Bit::k64:
return 64;
case Bit::Unk:
return 0;
default:
LOG(FATAL) << "Not supported Bit";
}
return -1;
}
std::ostream &operator<<(std::ostream &os, const Target &target) {
os << "Target<";
switch (target.os) {
case Target::OS::Linux:
os << "linux";
break;
case Target::OS::Windows:
os << "windows";
break;
case Target::OS::Unk:
os << "unk";
break;
}
os << ",";
switch (target.arch) {
case Target::Arch::X86:
os << "x86";
break;
case Target::Arch::ARM:
os << "arm";
break;
case Target::Arch::NVGPU:
os << "nvgpu";
break;
case Target::Arch::Unk:
os << "unk";
break;
}
os << ",";
switch (target.bits) {
case Target::Bit::k32:
os << "32";
break;
case Target::Bit::k64:
os << "64";
break;
case Target::Bit::Unk:
os << "unk";
break;
}
os << ">";
return os;
}
std::ostream &operator<<(std::ostream &os, Target::Arch arch) {
switch (arch) {
case Target::Arch::Unk:
os << "Unk";
break;
case Target::Arch::X86:
os << "X86";
break;
case Target::Arch::ARM:
os << "ARM";
break;
case Target::Arch::NVGPU:
os << "NVGPU";
break;
}
return os;
}
} // namespace common
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ostream>
#include <vector>
namespace infrt {
namespace common {
struct Target {
/**
* The operating system used by the target. Determines which system calls to
* generate.
*/
enum class OS : int {
Unk = -1,
Linux,
Windows,
};
/**
* The architecture used by the target. Determines the instruction set to use.
*/
enum class Arch : int {
Unk = -1,
X86,
ARM,
NVGPU,
};
enum class Bit : int {
Unk = -1,
k32,
k64,
};
OS os{OS::Unk};
Arch arch{Arch::Unk};
Bit bits{Bit::Unk};
enum class Feature : int {
JIT = 0,
Debug,
};
/**
* The library used by the target.
*/
enum class Lib : int {
Unk = -1,
MKL,
};
std::vector<Feature> features;
std::vector<Lib> libs;
explicit Target(OS o = OS::Linux,
Arch a = Arch::Unk,
Bit b = Bit::Unk,
const std::vector<Feature>& features = {},
const std::vector<Lib>& libs = {})
: os(o), arch(a), bits(b), features(features), libs(libs) {}
bool defined() const {
return os != OS::Unk && arch != Arch::Unk && bits != Bit::Unk;
}
int max_num_threads() const;
int get_target_bits() const;
std::vector<Lib> get_target_libs() const;
bool operator==(const Target& other) const;
bool operator!=(const Target& other) const { return !(*this == other); }
friend std::ostream& operator<<(std::ostream& os, const Target& target);
};
static const Target& UnkTarget() {
static Target target(
Target::OS::Unk, Target::Arch::Unk, Target::Bit::Unk, {}, {});
return target;
}
static const Target& DefaultHostTarget() {
static Target target(
Target::OS::Linux, Target::Arch::X86, Target::Bit::k64, {}, {});
return target;
}
static const Target& DefaultNVGPUTarget() {
static Target target(
Target::OS::Linux, Target::Arch::NVGPU, Target::Bit::k64, {}, {});
return target;
}
std::ostream& operator<<(std::ostream& os, Target::Arch arch);
} // namespace common
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/common/type.h"
#include <utility>
namespace infrt {
namespace common {
struct Type::Storage {
Storage() = default;
Storage(type_t t, int b, int w) : type_(t), bits_(b), lanes_(w) {}
type_t type_{type_t::Unk};
cpp_type_t cpp_type_{cpp_type_t::None};
//! How many bits per element.
int bits_{};
//! How many elements(if a vector type), for scalar types, it should be 1.
int lanes_{1};
//! Name of the customized type.
std::string customized_type_;
};
Type::~Type() {}
std::ostream &operator<<(std::ostream &os, const Type &t) {
if (t.is_cpp_const()) os << "const ";
switch (t.type()) {
case Type::type_t::Int:
if (t.bits() == 1) {
os << "bool";
} else {
os << "int" << t.bits();
}
break;
case Type::type_t::UInt:
os << "uint" << t.bits();
break;
case Type::type_t::Float:
os << "float" << t.bits();
break;
case Type::type_t::Void:
os << "void";
break;
case Type::type_t::Customized:
os << t.customized_type();
break;
case Type::type_t::String:
os << "string";
break;
case Type::type_t::Unk:
os << "unk";
break;
}
if (t.lanes() > 1) os << "<" << t.lanes() << ">";
if (t.is_cpp_handle()) os << "*";
if (t.is_cpp_handle2()) os << "**";
return os;
}
std::ostream &operator<<(std::ostream &os, Type::type_t t) {
switch (t) {
case Type::type_t::String:
os << "String";
break;
case Type::type_t::Void:
os << "Void";
break;
case Type::type_t::UInt:
os << "UInt";
break;
case Type::type_t::Int:
os << "Int";
break;
case Type::type_t::Float:
os << "Float";
break;
case Type::type_t::Unk:
os << "Unk";
break;
case Type::type_t::Customized:
os << "Customized";
}
return os;
}
Type &Type::set_cpp_handle(bool x) {
// unset the other handle-related bits.
set_cpp_handle2(false);
auto &v = (*reinterpret_cast<uint8_t *>(&GetStorage().cpp_type_));
// unset the other handle-related bits.
v &= ~static_cast<uint8_t>(cpp_type_t::Handle);
v &= ~static_cast<uint8_t>(cpp_type_t::HandleHandle);
if (x)
v |= static_cast<uint8_t>(cpp_type_t::Handle);
else
v &= ~static_cast<uint8_t>(cpp_type_t::Handle);
return *this;
}
Type &Type::set_cpp_handle2(bool x) {
auto &v = (*reinterpret_cast<uint8_t *>(&GetStorage().cpp_type_));
// unset the other handle-related bits.
v &= ~static_cast<uint8_t>(cpp_type_t::Handle);
v &= ~static_cast<uint8_t>(cpp_type_t::HandleHandle);
if (x)
v |= static_cast<uint8_t>(cpp_type_t::HandleHandle);
else
v &= ~static_cast<uint8_t>(cpp_type_t::HandleHandle);
return *this;
}
Type Type::VectorOf(int w) const {
CheckTypeValid();
return Type(type(), w, bits());
}
Type::Type(const Type &other) {
if (other.storage_) storage_.reset(new Storage(*other.storage_));
}
Type Type::ElementOf() const {
CheckTypeValid();
auto type = *this;
type.storage_->lanes_ = 1;
return type;
}
void Type::CheckTypeValid() const { CHECK_NE(GetStorage().type_, type_t::Unk); }
Type Type::PointerOf() const {
CheckTypeValid();
auto x = *this;
CHECK(!x.is_cpp_handle2()) << "Not support three level of PointerOf";
if (x.is_cpp_handle())
x.set_cpp_handle2();
else
x.set_cpp_handle();
return x;
}
Type Type::ConstOf() const {
CheckTypeValid();
auto x = *this;
x.set_cpp_const();
return x;
}
Type Type::IgnoreConst() const {
CheckTypeValid();
auto x = *this;
x.set_cpp_const(false);
return x;
}
Type Type::with_bits(int x) const {
CHECK(is_primitive());
Type type = *this;
type.GetStorage().bits_ = x;
return type;
}
Type Type::with_type(Type::type_t x) const {
Type type = *this;
type.GetStorage().type_ = x;
return type;
}
Type Type::with_lanes(int x) const {
CHECK(valid());
Type type = *this;
type.GetStorage().lanes_ = x;
return type;
}
Type Type::with_cpp_const(bool x) const {
Type type = *this;
type.set_cpp_const(x);
return type;
}
Type &Type::set_cpp_const(bool is_const) {
uint8_t &data = *reinterpret_cast<uint8_t *>(&GetStorage().cpp_type_);
if (is_const) {
data |= static_cast<uint8_t>(cpp_type_t::Const);
} else {
data &= ~(static_cast<uint8_t>(cpp_type_t::Const));
}
return *this;
}
Type &Type::set_customized_type(const std::string &t) {
GetStorage().type_ = type_t::Customized;
GetStorage().customized_type_ = t;
return *this;
}
bool Type::valid() const {
if (is_unk()) return false;
if (is_customized()) {
return !GetStorage().customized_type_.empty();
}
if (is_primitive()) {
return bits() != 0;
}
return true;
}
Type::Type(Type::type_t t, int b, int w) : storage_(new Storage(t, b, w)) {}
bool Type::is_primitive() const {
return !is_unk() && type() != type_t::Customized;
}
bool Type::is_customized() const {
return !is_unk() && type() == type_t::Customized;
}
bool Type::is_unk() const { return type() == type_t::Unk; }
bool Type::is_bool() const { return type() == type_t::UInt && bits() == 1; }
bool Type::is_void() const { return type() == type_t::Void; }
bool Type::is_vector() const { return lanes() > 1; }
bool Type::is_scalar() const { return lanes() == 1; }
bool Type::is_float(int bits) const {
return type() == type_t::Float && (bits < 0 || bits == this->bits());
}
bool Type::is_uint(int bits) const {
return type() == type_t::UInt && (bits < 0 || bits == this->bits());
}
bool Type::is_int(int bits) const {
return type() == type_t::Int && (bits < 0 || bits == this->bits());
}
bool Type::is_integer(int bits) const {
return (type() == type_t::Int || type() == type_t::UInt) &&
(bits < 0 || bits == this->bits());
}
bool Type::is_index_type() {
return is_int() && lanes() == 1 && (bits() == 32 || bits() == 64);
}
bool Type::is_cpp_handle() const {
return static_cast<uint8_t>(GetStorage().cpp_type_) &
static_cast<uint8_t>(cpp_type_t::Handle);
}
bool Type::is_cpp_handle2() const {
return static_cast<uint8_t>(GetStorage().cpp_type_) &
static_cast<uint8_t>(cpp_type_t::HandleHandle);
}
bool Type::is_cpp_const() const {
return static_cast<uint8_t>(cpp_type_t::Const) &
static_cast<uint8_t>(GetStorage().cpp_type_);
}
const std::string &Type::customized_type() const {
return GetStorage().customized_type_;
}
bool Type::is_customized_type() const {
return !GetStorage().customized_type_.empty();
}
Type::type_t Type::type() const { return GetStorage().type_; }
int Type::bits() const { return GetStorage().bits_; }
int Type::lanes() const { return GetStorage().lanes_; }
Type::cpp_type_t Type::cpp_type() const { return GetStorage().cpp_type_; }
bool Type::operator==(const Type &other) const {
return type() == other.type() && bits() == other.bits() &&
lanes() == other.lanes() &&
GetStorage().cpp_type_ == other.GetStorage().cpp_type_ &&
customized_type() == other.customized_type();
}
bool Type::is_string() const { return type() == type_t::String; }
Type &Type::operator=(const Type &other) {
if (other.storage_) storage_.reset(new Storage(*other.storage_));
return *this;
}
Type::Storage &Type::GetStorage() { return *storage_; }
const Type::Storage &Type::GetStorage() const { return *storage_; }
Type::Type() : storage_(new Storage) {}
Type::Type(Type &&other) : storage_(std::move(other.storage_)) {}
const Type &F16() {
static auto t = Float(16);
return t;
}
const Type &F32() {
static auto t = Float(32);
return t;
}
const Type &F64() {
static auto t = Float(64);
return t;
}
const Type &I8() {
static auto t = Int(8);
return t;
}
const Type &I16() {
static auto t = Int(16);
return t;
}
const Type &I32() {
static auto t = Int(32);
return t;
}
const Type &I64() {
static auto t = Int(64);
return t;
}
const Type &UI8() {
static auto t = UInt(8);
return t;
}
const Type &UI16() {
static auto t = UInt(16);
return t;
}
const Type &UI32() {
static auto t = UInt(32);
return t;
}
const Type &UI64() {
static auto t = UInt(64);
return t;
}
const Type &I1() {
static auto t = Int(1);
return t;
}
const Type &UI1() {
static auto t = UInt(1);
return t;
}
} // namespace common
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <memory>
#include <string>
#include "paddle/infrt/common/macros.h"
//! Much of the concepts are borrowed from Halide project.
namespace infrt {
namespace common {
/**
* Types in the INFRT type system. They can be ints, unsigned ints, or floats of
* various bit-widths.
* They can also be vectors of the same (by setting the `lanes` field to
* something larger than one).
* NOTE: Front-end code other than vectorize shouldn't use vector types.
*/
struct Type {
enum class type_t {
Unk = -1,
Int,
UInt,
Float,
String,
Void,
// stupid idea to mix the Customized with other primitive types, large
// refactor needs here.
Customized, // Customized type
};
//! type decorators in C++, the different code can used together.
enum class cpp_type_t : uint8_t {
None = 0, // None information.
Const = 1, // const.
Handle = 1 << 1, // pointer type, such as `infrt_buffer_t*`.
HandleHandle = 1 << 2, // pointer of pointer, such as `infrt_buffer_t**`.
};
Type();
Type(type_t t, int b, int w);
Type(const Type& other);
explicit Type(Type&& other);
Type& operator=(const Type& other);
INFRT_NODISCARD bool is_primitive() const;
INFRT_NODISCARD bool is_customized() const;
INFRT_NODISCARD bool valid() const;
//! Some helper functions to check a type.
// @{
INFRT_NODISCARD bool is_unk() const;
INFRT_NODISCARD bool is_void() const;
INFRT_NODISCARD bool is_bool() const;
INFRT_NODISCARD bool is_vector() const;
INFRT_NODISCARD bool is_scalar() const;
INFRT_NODISCARD bool is_float(int bits = -1) const;
INFRT_NODISCARD bool is_int(int bits = -1) const;
INFRT_NODISCARD bool is_integer(int bits = -1) const;
INFRT_NODISCARD bool is_uint(int bits = -1) const;
INFRT_NODISCARD bool is_string() const;
INFRT_NODISCARD bool is_index_type();
// @}
Type& set_cpp_handle(bool x = true);
INFRT_NODISCARD bool is_cpp_handle() const;
Type& set_cpp_handle2(bool x = true);
INFRT_NODISCARD bool is_cpp_handle2() const;
Type& set_cpp_const(bool is_const = true);
INFRT_NODISCARD bool is_cpp_const() const;
Type& set_customized_type(const std::string& t);
const std::string& customized_type() const;
INFRT_NODISCARD bool is_customized_type() const;
// Get a new type with bits set to \p x.
Type with_bits(int x) const;
// Get a new type with type set to \p x.
Type with_type(type_t x) const;
// Get a new type with lanes set to \p x.
Type with_lanes(int x) const;
// Get a new type with cpp_const set to \p x.
Type with_cpp_const(bool x = true) const;
//! Getters
// @{
type_t type() const;
int bits() const;
int lanes() const;
cpp_type_t cpp_type() const;
// @}
//! Compare two types for equality.
bool operator==(const Type& other) const;
//! Compare two types for inequality.
bool operator!=(const Type& other) const { return !(*this == other); }
//! Generate a vector of this type, with `w` elements.
Type VectorOf(int w) const;
//! Generate a element type of this type.
Type ElementOf() const;
//! Generate the address type.
Type PointerOf() const;
//! Ignore const.
Type IgnoreConst() const;
//! Add const.
Type ConstOf() const;
friend std::ostream& operator<<(std::ostream& os, const Type& t);
~Type();
private:
void CheckTypeValid() const;
struct Storage;
Storage& GetStorage();
const Storage& GetStorage() const;
std::unique_ptr<Storage> storage_;
}; // namespace common
inline Type Void() { return Type(Type::type_t::Void, 1, 0); }
inline Type Int(int bits, int lanes = 1) {
return Type(Type::type_t::Int, bits, lanes);
}
inline Type UInt(int bits, int lanes = 1) {
return Type(Type::type_t::UInt, bits, lanes);
}
inline Type Float(int bits, int lanes = 1) {
return Type(Type::type_t::Float, bits, lanes);
}
inline Type Bool(int lanes = 1) { return Type(Type::type_t::UInt, 1, lanes); }
inline Type String() { return Type(Type::type_t::String, 1, 1); }
//! Builtin native types as global singletons.
// @{
const Type& F16();
const Type& F32();
const Type& F64();
const Type& I8();
const Type& I16();
const Type& I32();
const Type& I64();
const Type& UI8();
const Type& UI16();
const Type& UI32();
const Type& UI64();
const Type& I1();
const Type& UI1();
// @}
template <typename T>
Type type_of();
// clang-format off
template <> inline Type type_of<float>() { return F32(); }
template <> inline Type type_of<double>() { return F64(); }
template <> inline Type type_of<unsigned char>() { return UI8(); }
template <> inline Type type_of<int16_t>() { return UI16(); }
template <> inline Type type_of<int32_t>() { return I32(); }
template <> inline Type type_of<uint32_t>() { return UI32(); }
template <> inline Type type_of<bool>() { return UI1(); }
template <> inline Type type_of<char>() { return I8(); }
template <> inline Type type_of<int64_t>() { return I64(); }
template <> inline Type type_of<uint64_t>() { return UI64(); }
template <> inline Type type_of<signed char>() { return I8(); }
template <> inline Type type_of<void>() { return Void(); }
// clang-format on
template <>
inline Type type_of<int8_t*>() {
Type x = Int(8);
x.set_cpp_handle();
return x;
}
template <>
inline Type type_of<void*>() {
Type x = type_of<void>();
x.set_cpp_handle();
return x;
}
template <>
inline Type type_of<void**>() {
Type x = type_of<void>();
x.set_cpp_handle2();
return x;
}
template <>
inline Type type_of<float*>() {
Type x = type_of<float>();
x.set_cpp_handle();
return x;
}
template <>
inline Type type_of<double*>() {
Type x = type_of<double>();
x.set_cpp_handle();
return x;
}
std::ostream& operator<<(std::ostream& os, Type::type_t t);
} // namespace common
} // namespace infrt
core_gather_headers()
gather_srcs(infrt_src SRCS
dialect.cc
types.cc
basic_kernels.cc
test_kernels.cc
infrt_base.cc
init_infrt_dialects.cc
tensor_shape.cc
dense_tensor.cc
mlir_loader.cc
diagnostic_utils.cc
pd_types.cc
pd_ops.cc
)
mlir_tablegen_on(ops)
mlir_tablegen_on(basic_kernels)
mlir_tablegen_on(test_kernels)
mlir_tablegen_on(infrt_base DIALECT infrt)
mlir_tablegen_on(tensor_shape DIALECT ts)
mlir_tablegen_on(dense_tensor DIALECT dt)
mlir_tablegen_on(pd_op_base DIALECT pd)
mlir_tablegen_on(pd_ops)
mlir_add_rewriter(rewrite)
# TODO(Superjomn) add a cmake function cc_executable to ecapsulate the following code
add_executable(infrtopt opt.cc)
target_link_libraries(infrtopt infrt ${mlir_libs})
add_dependencies(infrtopt infrt)
add_executable(print-ir print_ir.cc)
target_link_libraries(print-ir infrt ${mlir_libs})
add_dependencies(print-ir pd_ops_inc)
# MLIR opt tests
# %{
set(infrt_opt_path ${CMAKE_BINARY_DIR}/infrt/dialect/infrtopt)
add_test(test_infrt_mlir_opt_on_basic ${infrt_opt_path}
${CMAKE_SOURCE_DIR}/infrt/dialect/mlir_tests/basic.mlir)
add_test(test_infrt_mlir_opt_on_tensor_shape ${infrt_opt_path}
${CMAKE_SOURCE_DIR}/infrt/dialect/mlir_tests/tensor_shape.mlir)
add_test(test_infrt_mlir_opt_on_paddle_ops
${infrt_opt_path}
${CMAKE_SOURCE_DIR}/infrt/dialect/mlir_tests/paddle_ops.mlir)
# %}
cc_test_tiny(test_infrt_mlir_loader SRCS mlir_loader_test.cc DEPS infrt ${MLIR_IR_LIBS})
# execute mlir and run FileCheck
infrt_exec_check(run_and_check_tensor_type mlir_tests/tensor_type.mlir)
infrt_exec_check(run_and_check_basic mlir_tests/basic.mlir)
infrt_exec_check(run_and_check_benchmark mlir_tests/benchmark.mlir)
#infrt_exec_check(run_and_check_dense_tensor mlir_tests/dense_tensor.mlir)
add_test(test_infrt_mlir_dense_tensor
${CMAKE_BINARY_DIR}/infrt/host_context/infrt-exec
-i
${CMAKE_CURRENT_SOURCE_DIR}/mlir_tests/dense_tensor.mlir)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/basic_kernels.h"
#include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Support/LogicalResult.h>
#include "paddle/infrt/dialect/dense_tensor.h"
namespace infrt::dialect {
using namespace mlir; // NOLINT
static ParseResult parseCallOp(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
SymbolRefAttr callee_attr;
FunctionType callee_type;
SmallVector<OpAsmParser::OperandType, 4> operands;
auto callee_loc = parser.getNameLoc();
if (parser.parseAttribute(callee_attr, "callee", result.attributes) ||
parser.parseOperandList(operands, OpAsmParser::Delimiter::Paren) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(callee_type) ||
parser.addTypesToList(callee_type.getResults(), result.types) ||
parser.resolveOperands(
operands, callee_type.getInputs(), callee_loc, result.operands))
return failure();
return success();
}
static ParseResult parseConstantOp(Type attrType,
OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
Attribute valueAttr;
if (parser.parseOptionalAttrDict(result.attributes) ||
parser.parseAttribute(valueAttr, attrType, "value", result.attributes) ||
parser.addTypeToList(attrType, result.types))
return failure();
return success();
}
static ParseResult parseConstantF32Op(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
return parseConstantOp(
FloatType::getF32(result.getContext()), parser, result);
}
static ParseResult parseConstantF64Op(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
return parseConstantOp(
FloatType::getF64(result.getContext()), parser, result);
}
static ParseResult parseConstantI32Op(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
return parseConstantOp(
IntegerType::get(32, result.getContext()), parser, result);
}
static ParseResult parseConstantI64Op(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
return parseConstantOp(
IntegerType::get(64, result.getContext()), parser, result);
}
static ParseResult parseReturnOp(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
SmallVector<OpAsmParser::OperandType, 2> opInfo;
SmallVector<Type, 2> types;
llvm::SMLoc loc = parser.getCurrentLocation();
return failure(parser.parseOperandList(opInfo) ||
(!opInfo.empty() && parser.parseColonTypeList(types)) ||
parser.resolveOperands(opInfo, types, loc, result.operands));
}
static void print(OpAsmPrinter &p, CallOp op) { // NOLINT
p << "infrt.call " << op.getAttr("callee") << "(";
p.printOperands(op.getOperands());
p << ")";
p.printOptionalAttrDict(op.getAttrs(), {"callee"});
p << " : ";
}
static void printConstant(OpAsmPrinter &p, mlir::Operation *op) { // NOLINT
p << op->getName() << " ";
p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"value"});
if (op->getAttrs().size() > 1) p << ' ';
Attribute attr = op->getAttr("value");
if (auto int_attr = attr.dyn_cast<IntegerAttr>()) {
bool is_signed = int_attr.getType().isIndex() ||
int_attr.getType().getIntOrFloatBitWidth() != 1;
int_attr.getValue().print(p.getStream(), is_signed);
} else if (auto float_attr = attr.dyn_cast<FloatAttr>()) {
p << float_attr.getValue().convertToFloat();
} else {
op->emitOpError("unknown attribute type");
}
}
static void print(OpAsmPrinter &p, ConstantF32Op op) { // NOLINT
printConstant(p, op);
}
static void print(OpAsmPrinter &p, ConstantF64Op op) { // NOLINT
printConstant(p, op);
}
static void print(OpAsmPrinter &p, ConstantI32Op op) { // NOLINT
printConstant(p, op);
}
static void print(OpAsmPrinter &p, ConstantI64Op op) { // NOLINT
printConstant(p, op);
}
static void print(OpAsmPrinter &p, ReturnOp op) { // NOLINT
p << "infrt.return";
if (op.getNumOperands() > 0) {
p << ' ';
p.printOperands(op.getOperands());
p << " : ";
llvm::interleaveComma(op.getOperands(), p);
}
}
static LogicalResult verify(CallOp op) { return success(); }
static LogicalResult verify(ConstantF32Op op) { return success(); }
static LogicalResult verify(ConstantI32Op op) { return success(); }
static LogicalResult verify(ConstantF64Op op) { return success(); }
static LogicalResult verify(ConstantI64Op op) { return success(); }
static LogicalResult verify(ReturnOp op) {
auto function = dyn_cast<FuncOp>(op.getParentOp());
if (!function) return success();
auto results = function.getType().getResults();
if (op.getNumOperands() != results.size())
return op.emitOpError("has ")
<< op.getNumOperands()
<< " operands, but enclosing function returns " << results.size();
return success();
}
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/basic_kernels.cpp.inc"
} // namespace infrt::dialect
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mlir/IR/OpDefinition.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
using namespace mlir; // NOLINT
namespace infrt::dialect {
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/basic_kernels.hpp.inc"
} // namespace infrt::dialect
// Operation definitions for basic kernels.
#ifdef BASIC_OPS
#else
#define BASIC_OPS
include "paddle/infrt/dialect/infrt_base.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
class INFRT_Op<string mnemonic, list<OpTrait> traits = []> : Op<INFRT_Dialect, mnemonic, !listconcat(traits, [IsolatedFromAbove])> {
// Each registered op needs to provide all of a printer, parser and verifier.
let printer = [{ return infrt::dialect::print(p, *this); }];
let verifier = [{ return infrt::dialect::verify(*this); }];
let parser = [{ return infrt::dialect::parse$cppClass(parser, result); }];
}
def CallOp : INFRT_Op<"call"> {
let summary = "call a host operation";
let description = [{
The "infrt.call" operation represents a direct call to a function. The operands and result types of the call must match the specified function type.
%2 = infrt.call @add(%0, %1) : (f32, f32) -> f32
}];
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
let results = (outs Variadic<AnyType>);
let extraClassDeclaration = [{
StringRef getCallee() { return callee(); }
mlir::FunctionType getCalleeType();
}];
}
class ConstantOp<string suffix, Type baseType, Attr attr>
: INFRT_Op<"constant." # suffix, [NoSideEffect]> {
let summary = "constant value constructor in host";
let arguments = (ins attr:$value);
let results = (outs baseType);
}
def ConstantI32Op : ConstantOp<"i32", I32, I32Attr>;
def ConstantI64Op : ConstantOp<"i64", I64, I64Attr>;
def ConstantF32Op : ConstantOp<"f32", F32, F32Attr>;
def ConstantF64Op : ConstantOp<"f64", F64, F64Attr>;
def ReturnOp : INFRT_Op<"return", [Terminator]> {
let summary = "host executor return operation";
let description = [{
The "infrt.return" operation represents a return operation within a function.
func @foo() : (i32, f8) {
infrt.return %0, %1 : i32, f8
}
}];
let arguments = (ins Variadic<AnyType>:$operands);
let builders = [OpBuilder<
"OpBuilder &b, OperationState &result",
[{ build(b, result, llvm::None); }]>];
}
class AddOp<string suffix, Type type> : INFRT_Op<"add." # suffix, [NoSideEffect]> {
let summary = "infrt.add operation";
let description = [{
An operation that takes two inputs and returns their sum as result.
}];
let arguments = (ins type, type);
let results = (outs type);
let assemblyFormat = "operands attr-dict";
let verifier = ?;
}
def AddI32Op : AddOp<"i32", I32>;
def AddI64Op : AddOp<"i64", I64>;
def AddF32Op : AddOp<"f32", F32>;
def AddF64Op : AddOp<"f64", F64>;
class MulOp<string suffix, Type type> : INFRT_Op<"mul." # suffix, [NoSideEffect]> {
let summary = "infrt.mul operation";
let description = [{
An operation that takes two inputs and returns their mul as result.
}];
let arguments = (ins type, type);
let results = (outs type);
let assemblyFormat = "operands attr-dict";
let verifier = ?;
}
def MulI32Op : MulOp<"i32", I32>;
def MulI64Op : MulOp<"i64", I64>;
def MulF32Op : MulOp<"f32", F32>;
def MulF64Op : MulOp<"f64", F64>;
class PrintOp<string suffix, Type type> : INFRT_Op<"print." # suffix> {
let summary = "infrt.print operation";
let description = [{
An operation takes a number as input and prints to stdout.
}];
let arguments = (ins type);
let assemblyFormat = "operands attr-dict";
let verifier = ?;
}
//def PrintI32Op : PrintOp<"i32", I32>;
//def PrintI64Op : PrintOp<"i64", I64>;
def PrintF32Op : PrintOp<"f32", F32>;
//def PrintF64Op : PrintOp<"f64", F64>;
def GetStringOp : INFRT_Op<"get_string"> {
let summary = "infrt.get_string";
let description = [{
Get a !infrt.string value from the given string attribute.
}];
let arguments = (ins StrAttr:$value);
let results = (outs StringType);
let assemblyFormat = "`(` $value `)` attr-dict";
let verifier = ?;
}
def PrintStringOp : INFRT_Op<"print_string"> {
let summary = "infrt.print_string";
let description = [{
An operation that prints a string.
}];
let arguments = (ins StringType:$input);
let results = (outs);
let assemblyFormat = "`(` $input `)` attr-dict";
let verifier = ?;
}
#endif // basic kernels
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/dense_tensor.h"
#include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Support/LogicalResult.h>
#include <tuple>
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/tensor_shape.h"
namespace infrt::dt {
void DTDialect::initialize() {
allowUnknownTypes();
addOperations<
#define GET_OP_LIST
#include "paddle/infrt/dialect/dense_tensor.cpp.inc"
>();
}
namespace detail {
struct TensorTypeStorage : public mlir::TypeStorage {
TensorTypeStorage(TargetType target,
LayoutType layout,
PrecisionType precision)
: target_(target), layout_(layout), precision_(precision) {}
using KeyTy = std::tuple<TargetType, LayoutType, PrecisionType>;
bool operator==(const KeyTy &key) const {
return key == KeyTy(target_, layout_, precision_);
}
static llvm::hash_code hashKey(const KeyTy &key) {
return llvm::hash_value(key);
}
static TensorTypeStorage *construct(
mlir::TypeStorageAllocator &allocator, // NOLINT
const KeyTy &key) {
return new (allocator.allocate<TensorTypeStorage>())
TensorTypeStorage(std::get<0>(key), std::get<1>(key), std::get<2>(key));
}
TargetType target_;
LayoutType layout_;
PrecisionType precision_;
};
} // namespace detail
llvm::Optional<TargetType> GetTargetType(mlir::StringRef key) {
if (key.equals_lower("x86"))
return TargetType::X86;
else if (key.equals_lower("cuda"))
return TargetType::CUDA;
else
return llvm::None;
}
llvm::Optional<LayoutType> GetLayoutType(mlir::StringRef key) {
if (key.equals_lower("nchw"))
return LayoutType::NCHW;
else if (key.equals_lower("nhwc"))
return LayoutType::NHWC;
else
return llvm::None;
}
llvm::Optional<PrecisionType> GetPrecisionType(mlir::StringRef key) {
if (key.equals_lower("i32"))
return PrecisionType::I32;
else if (key.equals_lower("f32"))
return PrecisionType::F32;
else
return llvm::None;
}
TensorType TensorType::get(TargetType target,
LayoutType layout,
PrecisionType precision) {
return Base::get(
::infrt::Global::getMLIRContext(), target, layout, precision);
}
TargetType TensorType::target() { return getImpl()->target_; }
LayoutType TensorType::layout() { return getImpl()->layout_; }
PrecisionType TensorType::precision() { return getImpl()->precision_; }
raw_ostream &operator<<(raw_ostream &os, TensorType tensorType) {
os << "TensorType<" << tensorType.target() << ", " << tensorType.layout()
<< ", " << tensorType.precision() << ">";
return os;
}
TensorMapType TensorMapType::get() {
return Base::get(::infrt::Global::getMLIRContext());
}
TensorMapType TensorMapType::get(mlir::MLIRContext *context) {
return Base::get(context);
}
StringType StringType::get() {
return Base::get(::infrt::Global::getMLIRContext());
}
StringType StringType::get(mlir::MLIRContext *context) {
return Base::get(context);
}
raw_ostream &operator<<(raw_ostream &os, TargetType type) {
switch (type) {
case (TargetType::X86):
os << "X86";
break;
case (TargetType::CUDA):
os << "CUDA";
break;
default:
os << "Unsupported";
}
return os;
}
raw_ostream &operator<<(raw_ostream &os, LayoutType type) {
switch (type) {
case (LayoutType::NCHW):
os << "NCHW";
break;
case (LayoutType::NHWC):
os << "NHWC";
break;
default:
os << "Unsupported";
}
return os;
}
raw_ostream &operator<<(raw_ostream &os, PrecisionType type) {
switch (type) {
case (PrecisionType::I32):
os << "I32";
break;
case (PrecisionType::F32):
os << "F32";
break;
default:
os << "Unsupported";
}
return os;
}
static Type getTensorType(mlir::MLIRContext *context) {
auto t_dialect = Identifier::get("t", context);
return OpaqueType::get(t_dialect, "tensor", context);
}
static ParseResult parseCreateUninitTensorOp(
OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
auto loc = parser.getCurrentLocation();
::mlir::Type outputRawTypes[1];
::llvm::ArrayRef<::mlir::Type> outputTypes(outputRawTypes);
mlir::ArrayAttr shapeAttr;
if (parser.parseAttribute(shapeAttr,
parser.getBuilder().getI64Type(),
"shape",
result.attributes))
return failure();
if (parser.parseOptionalAttrDict(result.attributes)) return failure();
if (parser.parseArrow()) return failure();
if (parser.parseType(outputRawTypes[0])) return failure();
if (!outputRawTypes[0].isa<TensorType>())
return parser.emitError(loc, "invalid kind of type specified");
result.addTypes(outputTypes);
return success();
}
template <typename CreateUninitTensorOp>
static void printCreateUninitTensorOp(OpAsmPrinter &p, // NOLINT
CreateUninitTensorOp op) {
p << CreateUninitTensorOp::getOperationName();
p << " ";
p.printAttributeWithoutType(op.shapeAttr());
p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
p << " -> ";
p << op.getOperation()->getResultTypes();
}
// TODO(shibo): can be removed?
// static ParseResult parseFillTensorWithConstantOp(OpAsmParser& parser,
// OperationState& result) {
// auto loc = parser.getCurrentLocation();
// ::mlir::OpAsmParser::OperandType inputRawOperands[1];
// ::llvm::ArrayRef<::mlir::OpAsmParser::OperandType>
// inputOperands(inputRawOperands);
// ::mlir::Type inputRawTypes[1];
// ::llvm::ArrayRef<::mlir::Type> inputTypes(inputRawTypes);
//
// if (parser.parseOperand(inputRawOperands[0])) return failure();
//
// if (parser.parseColon()) return failure();
// if (parser.parseType(inputRawTypes[0])) return failure();
// if (!inputRawTypes[0].isa<TensorType>())
// return parser.emitError(loc, "invalid kind of type specified");
//
// Attribute value_attr;
// if (parser.resolveOperands(inputOperands, inputTypes, loc, result.operands))
// return failure();
// if (parser.parseAttribute(value_attr, "value", result.attributes)) return
// failure();
// return success();
//}
// TODO(shibo): can be removed?
// template <typename FillTensorOp>
// static void printFillTensorWithConstantOp(OpAsmPrinter& p, FillTensorOp op) {
// p << FillTensorOp::getOperationName();
// p << " ";
// p.printOperand(op.getOperand());
// p << " : ";
// p << op.getOperation()->getOperandTypes();
// p << " ";
// p << op.getAttr("value");
//}
static ParseResult parseSetTensorOp(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
SmallVector<OpAsmParser::OperandType, 1> operands;
if (parser.parseOperandList(operands, 1)) return failure();
auto tensor_type = getTensorType(result.getContext());
Attribute value_attr;
return failure(
parser.resolveOperand(operands[0], tensor_type, result.operands) ||
parser.parseAttribute(value_attr, "values", result.attributes));
}
template <typename SetTensorOp>
static void printSetTensorOp(OpAsmPrinter &p, SetTensorOp op) { // NOLINT
p << SetTensorOp::getOperationName() << " ";
p.printOperand(op.getOperand());
p << " " << op.getAttr("values");
}
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/dense_tensor.cpp.inc" // NOLINT
} // namespace infrt::dt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mlir/IR/Dialect.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <string>
using namespace mlir; // NOLINT
namespace infrt::dt {
namespace detail {
struct TensorTypeStorage;
} // namespace detail
enum class TargetType : uint8_t { X86, CUDA };
enum class LayoutType : uint8_t { NCHW, NHWC };
enum class PrecisionType : uint8_t { I32, F32 };
llvm::Optional<TargetType> GetTargetType(mlir::StringRef key);
llvm::Optional<LayoutType> GetLayoutType(mlir::StringRef key);
llvm::Optional<PrecisionType> GetPrecisionType(mlir::StringRef key);
raw_ostream &operator<<(raw_ostream &os, TargetType type);
raw_ostream &operator<<(raw_ostream &os, LayoutType type);
raw_ostream &operator<<(raw_ostream &os, PrecisionType type);
class TensorType : public mlir::Type::TypeBase<TensorType,
mlir::Type,
detail::TensorTypeStorage> {
public:
using Base::Base;
static TensorType get(TargetType target,
LayoutType layout,
PrecisionType precision);
TargetType target();
LayoutType layout();
PrecisionType precision();
};
raw_ostream &operator<<(raw_ostream &os, TensorType tensorType);
class TensorMapType : public mlir::Type::TypeBase<TensorMapType,
mlir::Type,
mlir::TypeStorage> {
public:
using Base::Base;
static TensorMapType get();
static TensorMapType get(mlir::MLIRContext *context);
};
class StringType
: public mlir::Type::TypeBase<StringType, mlir::Type, mlir::TypeStorage> {
public:
using Base::Base;
static StringType get();
static StringType get(mlir::MLIRContext *context);
};
#include "paddle/infrt/dialect/dense_tensor_dialect.hpp.inc"
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/dense_tensor.hpp.inc"
} // namespace infrt::dt
#ifdef DT_OPS
#else
#define DT_OPS
include "paddle/infrt/dialect/infrt_base.td"
include "paddle/infrt/dialect/tensor_shape_base.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
def DT_Dialect : Dialect {
let name = "dt";
let description = [{
The DenseTensor dialect.
}];
let cppNamespace = "::infrt::dt";
}
class DT_Op<string mnemonic, list<OpTrait> traits = []> :
Op<DT_Dialect, mnemonic, traits>;
class CreateUninitTensorOp<string dtype>
: DT_Op<"create_uninit_tensor." # dtype, [NoSideEffect]> {
let summary = "dt.create_uninit_tensor operation";
let description = [{
An operation that creates an uninitialized tensor.
}];
let arguments = (ins I64ArrayAttr:$shape);
let results = (outs TensorType:$output);
let parser = [{ return infrt::dt::parseCreateUninitTensorOp(parser, result); }];
let printer = [{ return infrt::dt::printCreateUninitTensorOp(p, *this); }];
}
def ShallowCopyTensorOp
: DT_Op<"shallow_copy_tensor", [NoSideEffect]> {
let summary = "dt.shallow_copy_tensor operation";
let description = [{
An operation that copy a tensor shallowly.
}];
let arguments = (ins TensorType:$input);
let results = (outs TensorType:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
}
class FillTensorWithConstantOp<string dtype> :
DT_Op<"fill_tensor_with_constant." # dtype> {
let summary = "dt.fill_tensor_with_constant operation";
let description = [{
An operation that fills an input tensor with a value.
}];
let arguments = (ins
TensorType:$input,
AnyAttr:$value
);
let results = (outs);
// TODO: can be removed?
//let parser = [{ return infrt::dt::parseFillTensorWithConstantOp(parser, result); }];
//let printer = [{ return infrt::dt::printFillTensorWithConstantOp(p, *this); }];
let assemblyFormat = "`(` $input `:` type($input) `)` attr-dict";
}
def PrintTensorOp : DT_Op<"print_tensor"> {
let summary = "dt.print_tensor operation";
let description = [{
An operation that prints a tensor.
}];
let arguments = (ins TensorType:$input);
let results = (outs);
let assemblyFormat = "`(` $input `:` type($input) `)` attr-dict";
}
class SetTensorOp<string dtype> :
DT_Op<"set_tensor_with_constant_values." # dtype> {
let summary = "dt.set_tensor_with_constant_values operation";
let description = [{
An operation that sets an input tensor with given values.
}];
let arguments = (ins TensorType);
let results = (outs);
let parser = [{ return infrt::dt::parseSetTensorOp(parser, result); }];
let printer = [{ return infrt::dt::printSetTensorOp(p, *this); }];
}
def LoadParamsOp : DT_Op<"load_params", [NoSideEffect]> {
let summary = "dt.load_params operation";
let description = [{
An operation that can load tensors to TensorMap.
}];
// input path of model params.
let arguments = (ins StringType:$path);
let results = (outs TensorMapType);
let assemblyFormat = "`(` operands `)` attr-dict";
let verifier = ?;
}
def GetParamOp : DT_Op<"get_param", [NoSideEffect]> {
let summary = "dt.get_param operation";
let description = [{
An operation that can get a tensor from TensorMap.
}];
// input path of model params.
let arguments = (ins
TensorMapType:$map,
StrAttr:$name
);
let results = (outs TensorType:$output);
let assemblyFormat = "`(` $map `,` $name `)` attr-dict `->` type($output)";
let verifier = ?;
}
def GetTensorShapeOp : DT_Op<"get_tensor_shape", [NoSideEffect]> {
let summary = "dt.get_tensor_shape operation";
let description = [{
An operation that returns the shape of the input tensor.
}];
let arguments = (ins TensorType:$input);
let results = (outs TS_Shape:$output);
let assemblyFormat = "$input attr-dict `:` type($input) `->` type($output)";
}
foreach dtype = ["ui8", "ui16", "ui32", "ui64", "i32", "f32", "f64", "i64"] in {
def DT_CreateUninitTensorOp_#dtype : CreateUninitTensorOp<dtype>;
def DT_FillTensorOp_#dtype : FillTensorWithConstantOp<dtype>;
def DT_SetTensorOp_#dtype : SetTensorOp<dtype>;
}
#endif // DT_OPS
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/diagnostic_utils.h"
#include <string>
namespace infrt::dialect {
struct MyScopedDiagnosicHandler::Impl {
Impl() : diag_stream_(diag_str_) {}
// String stream to assemble the final error message.
std::string diag_str_;
llvm::raw_string_ostream diag_stream_;
// A SourceMgr to use for the base handler class.
llvm::SourceMgr source_mgr_;
// Log detail information.
bool log_info_{};
};
MyScopedDiagnosicHandler::MyScopedDiagnosicHandler(mlir::MLIRContext *ctx,
bool propagate)
: mlir::SourceMgrDiagnosticHandler(
impl_->source_mgr_, ctx, impl_->diag_stream_),
impl_(new Impl) {
setHandler([this](mlir::Diagnostic &diag) { return this->handler(&diag); });
}
mlir::LogicalResult MyScopedDiagnosicHandler::handler(mlir::Diagnostic *diag) {
if (diag->getSeverity() != mlir::DiagnosticSeverity::Error &&
!impl_->log_info_)
return mlir::success();
emitDiagnostic(*diag);
impl_->diag_stream_.flush();
return mlir::failure(true);
}
} // namespace infrt::dialect
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <llvm/Support/SourceMgr.h>
#include <mlir/IR/Diagnostics.h>
#include <memory>
namespace infrt::dialect {
/**
* A scoped diagnostic handler to help debug MLIR process.
*/
class MyScopedDiagnosicHandler : public mlir::SourceMgrDiagnosticHandler {
public:
MyScopedDiagnosicHandler(mlir::MLIRContext* ctx, bool propagate);
mlir::LogicalResult handler(mlir::Diagnostic* diag);
~MyScopedDiagnosicHandler();
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace infrt::dialect
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mlir/IR/Builders.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Support/LogicalResult.h>
namespace infrt::hlir::dialect {
class CinnDialect : public ::mlir::Dialect {
public:
explicit CinnDialect(::mlir::MLIRContext* ctx);
//! We should register this function in dialect
static llvm::StringRef getDialectNamespace() {
return "infrt::hlir::dialect";
}
};
} // namespace infrt::hlir::dialect
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/infrt_base.h"
#include "paddle/infrt/dialect/basic_kernels.h"
#include "paddle/infrt/dialect/dense_tensor.h"
#include "paddle/infrt/dialect/test_kernels.h"
namespace infrt::dialect {
// ----INFRTDialect definition begin----
void INFRTDialect::initialize() {
allowUnknownTypes();
allowUnknownOperations();
addTypes<infrt::dt::StringType>();
addTypes<infrt::dt::TensorType>();
addTypes<infrt::dt::TensorMapType>();
addOperations<
#define GET_OP_LIST
#include "paddle/infrt/dialect/basic_kernels.cpp.inc"
>();
addOperations<
#define GET_OP_LIST
#include "paddle/infrt/dialect/test_kernels.cpp.inc"
>();
}
mlir::Type INFRTDialect::parseType(mlir::DialectAsmParser &parser) const {
llvm::StringRef keyword;
if (parser.parseKeyword(&keyword)) return mlir::Type();
// parse TensorType, for example: !infrt.tensor<X86, CUDA, F32>
if (keyword == "tensor") {
llvm::StringRef target;
llvm::StringRef layout;
llvm::StringRef precision;
// parse "<"
if (parser.parseLess()) return mlir::Type();
// parse target
if (parser.parseKeyword(&target)) return mlir::Type();
auto targetType = infrt::dt::GetTargetType(target);
if (!targetType) {
parser.emitError(parser.getCurrentLocation(), "unknown target type: ")
<< target;
return mlir::Type();
}
// parse ","
if (parser.parseComma()) return mlir::Type();
// parse layout
if (parser.parseKeyword(&layout)) return mlir::Type();
auto layoutType = infrt::dt::GetLayoutType(layout);
if (!layoutType) {
parser.emitError(parser.getCurrentLocation(), "unknown layout type: ")
<< layout;
return mlir::Type();
}
// parse ","
if (parser.parseComma()) return mlir::Type();
// parse precision
if (parser.parseKeyword(&precision)) return mlir::Type();
auto precisionType = infrt::dt::GetPrecisionType(precision);
if (!precisionType) {
parser.emitError(parser.getCurrentLocation(), "unknown precision type: ")
<< precision;
return mlir::Type();
}
// parse ">"
if (parser.parseGreater()) return mlir::Type();
return infrt::dt::TensorType::get(*targetType, *layoutType, *precisionType);
}
// parse TensorMapType, for example: !infrt.tensor_map
if (keyword == "tensor_map") {
return infrt::dt::TensorMapType::get();
}
// parse StringType, for example: !infrt.string
if (keyword == "string") {
return infrt::dt::StringType::get();
}
parser.emitError(parser.getCurrentLocation(), "unknown infrt type: ")
<< keyword;
return mlir::Type();
}
void INFRTDialect::printType(mlir::Type type,
mlir::DialectAsmPrinter &printer) const {
// print TensorType, for example: !infrt.tensor<X86, CUDA, F32>
if (type.isa<infrt::dt::TensorType>()) {
auto tensorType = type.cast<infrt::dt::TensorType>();
printer << "tensor<" << tensorType.target() << ", " << tensorType.layout()
<< ", " << tensorType.precision() << ">";
return;
}
// print TensorMapType, for example: !infrt.tensor_map
if (type.isa<infrt::dt::TensorMapType>()) {
printer << "tensor_map";
return;
}
// print StringType, for example: !infrt.string
if (type.isa<infrt::dt::StringType>()) {
printer << "string";
return;
}
llvm_unreachable("unknown infrt type.");
}
// ----INFRTDialect definition end----
} // namespace infrt::dialect
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mlir/IR/Builders.h>
#include <mlir/IR/Dialect.h>
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/IR/Types.h>
#include "paddle/infrt/dialect/infrt_base.hpp.inc"
namespace infrt::dialect {
class INFRTDialect : public ::mlir::Dialect {
explicit INFRTDialect(::mlir::MLIRContext *context)
: ::mlir::Dialect(getDialectNamespace(),
context,
::mlir::TypeID::get<INFRTDialect>()) {
initialize();
}
// parse types registered to the dialect.
mlir::Type parseType(mlir::DialectAsmParser &parser) const override;
// print types registered to the dialect.
void printType(mlir::Type type,
mlir::DialectAsmPrinter &printer) const override;
void initialize();
friend class ::mlir::MLIRContext;
public:
static ::llvm::StringRef getDialectNamespace() { return "infrt"; }
};
} // namespace infrt::dialect
namespace mlir {
template <typename T>
static mlir::IntegerAttr createI32Attr(mlir::OpBuilder &b, // NOLINT
mlir::Location loc,
T constant) {
return b.getIntegerAttr(b.getI32Type(), constant);
}
static mlir::ValueRange cvtValueToValueRange(const mlir::Value &operand) {
return mlir::ValueRange(operand);
}
static mlir::ValueRange concatTwoValueRange(mlir::ValueRange operand_0,
mlir::ValueRange operand_1) {
mlir::SmallVector<::mlir::Value, 4> operands;
operands.append(operand_0.begin(), operand_0.end());
operands.append(operand_1.begin(), operand_1.end());
return operands;
}
} // namespace mlir
#ifndef INFRT_BASE
#define INFRT_BASE
include "mlir/IR/OpBase.td"
def INFRT_Dialect : Dialect {
let name = "infrt";
let description = [{
The INFRT host dialect.
}];
let cppNamespace = "::infrt::dialect";
}
// Type definitions
def StringType :
Type<CPred<"$_self.isa<::infrt::dt::StringType>()">, "!infrt.string type">,
BuildableType<"$_builder.getType<::infrt::dt::StringType>()">;
def TensorType :
Type<CPred<"$_self.isa<::infrt::dt::TensorType>()">, "!infrt.tensor type">;
def TensorMapType :
Type<CPred<"$_self.isa<::infrt::dt::TensorMapType>()">, "!infrt.tensor_map type">,
BuildableType<"$_builder.getType<::infrt::dt::TensorMapType>()">;
def BufferType : OpaqueType<"b", "buffer", "buffer">;
class INFRT_createI32Attr<string value> : NativeCodeCall<
"mlir::createI32Attr($_builder, $_loc, " # value # ")">;
def INFRT_cvtValueToValueRange : NativeCodeCall<
"mlir::cvtValueToValueRange($0)">;
def INFRT_concatTwoValueRange : NativeCodeCall<
"mlir::concatTwoValueRange($0, $1)">;
class IsBoolAttrEq<string value> : Constraint<
CPred<"($0.getValue() ==" # value # ")">,
"Bool attrbute value constraint">;
#endif // INFRT_BASE
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/init_infrt_dialects.h"
#include <glog/logging.h>
#include "paddle/infrt/dialect/basic_kernels.h"
#include "paddle/infrt/dialect/dense_tensor.h"
#include "paddle/infrt/dialect/infrt_base.h"
#include "paddle/infrt/dialect/pd_ops.h"
#include "paddle/infrt/dialect/tensor_shape.h"
namespace infrt {
void RegisterCinnDialects(mlir::DialectRegistry& registry) { // NOLINT
registry.insert<ts::TensorShapeDialect>();
registry.insert<dialect::INFRTDialect>();
registry.insert<dt::DTDialect>();
registry.insert<mlir::pd::PaddleDialect>();
}
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mlir/IR/Dialect.h"
namespace infrt {
void RegisterCinnDialects(mlir::DialectRegistry& registry); // NOLINT
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/mlir_loader.h"
#include <llvm/Support/SourceMgr.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/Diagnostics.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/OperationSupport.h>
#include <mlir/Parser.h>
#include <unordered_map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/infrt/dialect/diagnostic_utils.h"
#include "paddle/infrt/dialect/init_infrt_dialects.h"
namespace infrt::dialect {
mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context,
const std::string& mlir_source) {
context->allowUnregisteredDialects();
RegisterCinnDialects(context->getDialectRegistry());
context->getDialectRegistry().insert<mlir::StandardOpsDialect>();
mlir::ScopedDiagnosticHandler scope_handler(
context, [](mlir::Diagnostic& diag) {
if (diag.getSeverity() != mlir::DiagnosticSeverity::Error)
return mlir::success();
LOG(INFO) << "diag: " << diag.str();
return mlir::failure(true);
});
auto res = mlir::parseSourceString(
llvm::StringRef(mlir_source.data(), mlir_source.length()), context);
CHECK(*res) << "failed to parse MLIR string";
return res;
}
mlir::OwningModuleRef LoadMlirFile(const std::string& file_name,
mlir::MLIRContext* context) {
context->allowUnregisteredDialects();
RegisterCinnDialects(context->getDialectRegistry());
context->getDialectRegistry().insert<mlir::StandardOpsDialect>();
mlir::ScopedDiagnosticHandler scope_handler(
context, [](mlir::Diagnostic& diag) {
if (diag.getSeverity() != mlir::DiagnosticSeverity::Error)
return mlir::success();
LOG(INFO) << "diag: " << diag.str();
return mlir::failure(true);
});
return mlir::parseSourceFile(std::string(file_name), context);
}
} // namespace infrt::dialect
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <mlir/IR/Module.h>
#include <string>
#include <memory>
namespace infrt::dialect {
mlir::OwningModuleRef LoadMlirSource(mlir::MLIRContext* context,
const std::string& mlir_source);
mlir::OwningModuleRef LoadMlirFile(const std::string& file_name,
mlir::MLIRContext* context);
} // namespace infrt::dialect
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/mlir_loader.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <llvm/Support/SourceMgr.h>
#include <mlir/IR/Function.h>
#include <mlir/Parser.h>
#include <string>
#include "paddle/infrt/dialect/init_infrt_dialects.h"
namespace infrt::dialect {
TEST(MlirLoader, basic) {
mlir::MLIRContext context;
auto source = R"ROC(
func @main() -> f32 {
%v0 = infrt.constant.f32 1.0
%v1 = infrt.constant.f32 2.0
%value = "infrt.add.f32"(%v0, %v1) : (f32, f32) -> f32
"infrt.print.f32"(%v0) : (f32) -> ()
infrt.return %value : f32
}
)ROC";
auto module = LoadMlirSource(&context, source);
module->verify();
LOG(INFO) << "module name: " << module->getOperationName().data();
for (auto func : module->getOps<mlir::FuncOp>()) {
LOG(INFO) << "get func " << func.getName().str();
int num_args = func.getNumArguments();
for (int i = 0; i < num_args; i++) {
LOG(INFO) << "arg: " << func.getArgument(i).getArgNumber();
}
}
}
} // namespace infrt::dialect
// CHECK-LABEL: @basic_f32
func @basic_f32() -> f32 {
%v0 = infrt.constant.f32 1.0
%v1 = infrt.constant.f32 2.0
%value = "infrt.add.f32"(%v0, %v1) : (f32, f32) -> f32
// CHECK-NEXT: 3
"infrt.print.f32"(%value) : (f32) -> ()
infrt.return %value : f32
}
/// ================================================================
/// @caller call the other function @callee
func @callee.add.f32(%x : f32, %y : f32, %y1 : f32) -> f32 {
%z = "infrt.add.f32"(%x, %y) : (f32, f32) -> f32
%z1 = "infrt.add.f32"(%z, %y1) : (f32, f32) -> f32
infrt.return %z1 : f32
}
// CHECK-LABEL: @caller.add.f32
func @caller.add.f32() -> f32 {
%x = infrt.constant.f32 1.0
%y = infrt.constant.f32 2.0
%y1 = infrt.constant.f32 3.0
%z = infrt.call @callee.add.f32(%x, %y, %y1) : (f32, f32, f32) -> f32
// CHECK-NEXT: 6
"infrt.print.f32"(%z) : (f32) -> ()
infrt.return %z : f32
}
/// <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// CHECK-LABEL: @string_test
func @string_test() {
%path = infrt.get_string("this is get_string op.")
// CHECK-LABEL: string = this is get_string op.
infrt.print_string(%path)
infrt.return
}
// CHECK-LABEL: @benchmark
func @benchmark() {
// CHECK-LABEL: BM:add.f32:Count: 3
// CHECK-LABEL: BM:add.f32:Duration(ns)
// CHECK-LABEL: BM:add.f32:Time Min(ns)
// CHECK-LABEL: BM:add.f32:Time 50%(ns)
// CHECK-LABEL: BM:add.f32:Time 95%(ns)
// CHECK-LABEL: BM:add.f32:Time 99%(ns)
// CHECK-LABEL: BM:add.f32:CPU Min(ns)
// CHECK-LABEL: BM:add.f32:CPU 50%(ns)
// CHECK-LABEL: BM:add.f32:CPU 95%(ns)
// CHECK-LABEL: BM:add.f32:CPU 99%(ns)
// CHECK-LABEL: BM:add.f32:CPU utilization(percent)
infrt.benchmark "add.f32"() duration_secs = 1, max_count = 3, num_warmup_runs = 3
{
%0 = infrt.constant.f32 1.0
%1 = infrt.constant.f32 2.0
%res = "infrt.add.f32"(%0, %1) : (f32, f32) -> f32
"infrt.print.f32"(%res) : (f32) -> ()
infrt.return %res : f32
}
infrt.return
}
func @dense_shape0() {
%shape = ts.build_shape [1:i64, 57:i64]
%a = dt.create_uninit_tensor.f32 [12:i64, 23:i64] -> !infrt.tensor<X86, NCHW, F32>
infrt.return
}
func @predict(%a: !infrt.tensor<X86, NCHW, F32>, %b: !infrt.tensor<X86, NCHW, F32>) -> (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) {
%a0 = dt.shallow_copy_tensor %a : !infrt.tensor<X86, NCHW, F32> -> !infrt.tensor<X86, NCHW, F32>
%b0 = dt.shallow_copy_tensor %b : !infrt.tensor<X86, NCHW, F32> -> !infrt.tensor<X86, NCHW, F32>
infrt.return %a0, %b0: !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>
}
func @main() {
%shape = ts.build_shape [1:i64, 57:i64]
%a = dt.create_uninit_tensor.f32 [12:i64, 23:i64] -> !infrt.tensor<X86, NCHW, F32>
%b, %c = infrt.call @predict(%a, %a) : (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) -> (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>)
infrt.return
}
func @ops() {
%a = pd.Feed() : tensor<?xf32>
%b = pd.Feed() : tensor<?xf32>
%c = "pd.Matmul"(%a, %b) {transpose_x=true, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
infrt.return
}
// CHECK-LABEL: @main
func @main() -> tensor<?xf32> {
%a = "pd.Feed"() : () -> tensor<?xf32>
%b = "pd.Feed"() : () -> tensor<?xf32>
%bias = "pd.Feed"() : () -> tensor<?xf32>
%b1 = "pd.Feed"() : () -> tensor<?xf32>
%b2 = "pd.Feed"() : () -> tensor<?xf32>
%bias1 = "pd.Feed"() : () -> tensor<?xf32>
%bias2 = "pd.Feed"() : () -> tensor<?xf32>
%c = "pd.Matmul"(%a, %b) {transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%d = "pd.ElementwiseAdd"(%c, %bias) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e = "pd.Relu6"(%d) {} : (tensor<?xf32>) -> tensor<?xf32>
%c1 = "pd.Matmul"(%e, %b1) {transpose_x=false, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%d1 = "pd.ElementwiseAdd"(%c1, %bias1) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e1 = "pd.Relu"(%d1) {} : (tensor<?xf32>) -> tensor<?xf32>
%c2 = "pd.Matmul"(%e1, %b2) {transpose_x=true, transpose_y=false} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%d2 = "pd.ElementwiseAdd"(%c2, %bias2) {axis=1:i32} : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
%e2 = "pd.Relu"(%d2) {} : (tensor<?xf32>) -> tensor<?xf32>
infrt.return %e2 : tensor<?xf32>
}
\ No newline at end of file
// CHECK-LABEL: @main
func @main() -> tensor<?xf32> {
%a = "pd.Feed"() : () -> tensor<?x3x256x256xf32>
%filter = "pd.Constant"(){value = dense<1.000000e+00> : tensor<3x64x3x3xf32>} : () -> tensor<3x64x3x3xf32>
%bias = "pd.Constant"(){value = dense<1.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32>
%scale = "pd.Constant"(){value = dense<1.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32>
%bias2 = "pd.Constant"(){value = dense<1.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32>
%mean = "pd.Constant"(){value = dense<1.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32>
%var = "pd.Constant"(){value = dense<1.000000e+00> : tensor<64xf32>} : () -> tensor<64xf32>
%c = "pd.conv2d"(%a, %filter, %bias) {} : (tensor<?x3x256x256xf32>, tensor<3x64x3x3xf32>, tensor<64xf32>) -> tensor<?x3x256x256xf32>
%d = "pd.batch_norm"(%c, %scale, %bias2, %mean, %var) {} : (tensor<?x3x256x256xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) -> tensor<?x3x256x256xf32>
infrt.return %d : tensor<?x3x256x256xf32>
}
\ No newline at end of file
// CHECK-LABEL: @predict
func @predict(%input:!infrt.tensor<X86, NCHW, F32>, %map: !infrt.tensor_map) -> (!infrt.tensor<X86, NCHW, F32>) {
%w = dt.get_param(%map, "create_parameter_0.w_0") -> !infrt.tensor<X86, NCHW, F32>
%bias = dt.get_param(%map, "create_parameter_1.w_0") -> !infrt.tensor<X86, NCHW, F32>
%out = dt.create_uninit_tensor.f32 [3, 3] -> !infrt.tensor<X86, NCHW, F32>
// fc
"external.matmul"(%input, %w, %out) {}: (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) -> ()
"external.elementwise_add"(%out, %bias, %out) {axis = -1}: (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) -> ()
"external.sigmoid"(%out, %out) {}: (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) -> ()
//dt.print_tensor (%out : !infrt.tensor<X86, NCHW, F32>)
infrt.return %out : !infrt.tensor<X86, NCHW, F32>
}
// CHECK-LABEL: @main
func @main() {
%input = dt.create_uninit_tensor.f32 [3, 3] -> !infrt.tensor<X86, NCHW, F32>
dt.fill_tensor_with_constant.f32 (%input : !infrt.tensor<X86, NCHW, F32>) {value=1.0:f32}
%path = infrt.get_string("/infrt/build/paddle/paddle_1.8_fc_model")
// CHECK-LABEL: loading params
%map = dt.load_params(%path)
%out = infrt.call @predict(%input, %map): (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor_map) -> (!infrt.tensor<X86, NCHW, F32>)
dt.print_tensor (%out : !infrt.tensor<X86, NCHW, F32>)
infrt.return
}
func @build_tensor1() {
%a = ts.build_shape [1:i64, 57:i64, 92:i64]
ts.print_shape %a
infrt.return
}
// CHECK-LABEL: test_tensor_type
func @test_tensor_type() {
%a = dt.create_uninit_tensor.f32 [3, 4] -> !infrt.tensor<X86, NCHW, F32>
dt.fill_tensor_with_constant.f32 (%a : !infrt.tensor<X86, NCHW, F32>) {value=1.0:f32}
// CHECK: tensor: shape=shape[3,4], values=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
dt.print_tensor (%a : !infrt.tensor<X86, NCHW, F32>)
infrt.return
}
include "mlir/IR/OpBase.td"
include "paddle/infrt/dialect/infrt_base.td"
class INFRT_Op<string mnemonic, list<OpTrait> traits = []> :
Op<INFRT_Dialect, mnemonic, traits>;
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <llvm/Support/CommandLine.h>
#include <mlir/Dialect/Affine/IR/AffineOps.h>
#include <mlir/Dialect/LLVMIR/LLVMDialect.h>
#include <mlir/IR/AsmState.h>
#include <mlir/IR/Dialect.h>
#include <mlir/InitAllDialects.h>
#include <mlir/InitAllPasses.h>
#include <mlir/Pass/Pass.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/FileUtilities.h>
#include <mlir/Support/MlirOptMain.h>
#include <mlir/Transforms/Passes.h>
#include <iostream>
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/init_infrt_dialects.h"
#include "paddle/infrt/dialect/mlir_loader.h"
int main(int argc, char **argv) {
mlir::MLIRContext *context = infrt::Global::getMLIRContext();
auto &registry = context->getDialectRegistry();
infrt::RegisterCinnDialects(registry);
mlir::registerCanonicalizerPass();
return mlir::failed(
mlir::MlirOptMain(argc, argv, "INFRT mlir pass driver", registry));
}
// This file defines some basic elements of Paddle(alias pd) dialect.
// We learned much from TensorFlow mlir dialect https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td
#ifndef PD_OP_BASE
#define PD_OP_BASE
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
def PD_Dialect : Dialect {
let name = "pd";
let description = [{
The PaddlePaddle dialect.
This dialect contains the PaddlePaddle operators.
}];
let cppNamespace = "::mlir::pd";
}
class PD_Op<string mnemonic, list<OpTrait> traits = []> :
Op<PD_Dialect, mnemonic, traits>;
class PD_PaddleAttr <string name, string description> :
Attr<CPred<"$_self.isa<mlir::pd::" # name # "Attr>()">,
"PaddlePaddle " # description # " attribute">;
//===----------------------------------------------------------------------===//
// PaddlePaddle type definitions
//===----------------------------------------------------------------------===//
def PD_PDDialectType : Type<CPred<"$_self.isa<mlir::pd::PDType>()">, "PaddlePaddle type">;
class PD_PaddleType <string name, string description> :
Type<CPred<"$_self.isa<mlir::pd::" # name #"Type>()">,
"Paddle " # description # " type">,
BuildableType<"getType<mlir::pd::" # name # "Type>()">;
//===----------------------------------------------------------------------===//
// Integer types
def PD_Bool : AnyTypeOf<[I<1>], "bool">;
def PD_Int8 : AnyTypeOf<[I8], "8-bit integer">;
def PD_Int16 : AnyTypeOf<[I16], "16-bit integer">;
def PD_Int32 : AnyTypeOf<[I32], "32-bit integer">;
def PD_Int64 : AnyTypeOf<[I64], "64-bit integer">;
def PD_UInt8 : AnyTypeOf<[UI<8>], "8-bit unsigned integer">;
def PD_UInt16 : AnyTypeOf<[UI<16>], "16-bit unsigned integer">;
def PD_UInt32 : AnyTypeOf<[UI<32>], "32-bit unsigned integer">;
def PD_UInt64 : AnyTypeOf<[UI<64>], "64-bit unsigned integer">;
def PD_SInt : AnyTypeOf<[PD_Int8, PD_Int16, PD_Int32, PD_Int64], "signed integer">;
def PD_UInt : AnyTypeOf<[PD_UInt8, PD_UInt16, PD_UInt32, PD_UInt64], "unsigned integer">;
def PD_Int : AnyTypeOf<[PD_SInt, PD_UInt], "integer">;
// Float types
def PD_Float16 : AnyTypeOf<[F16], "16-bit float">;
def PD_Float32 : AnyTypeOf<[F32], "32-bit float">;
def PD_Float64 : AnyTypeOf<[F64], "64-bit float">;
def PD_Float : AnyTypeOf<[PD_Float16, PD_Float32, PD_Float64], "floating-point">;
// Tensor types
def PD_ElementType : Type<Or<[PD_Float.predicate,
PD_Bool.predicate,
PD_Int.predicate]>,
"pd.dtype">;
def PD_Tensor : TensorOf<[PD_ElementType]>;
#endif // PD_OP_BASE
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/pd_ops.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "paddle/infrt/dialect/infrt_base.h"
namespace mlir {
namespace pd {
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.hpp.inc"
#undef GET_OP_CLASSES
PaddleDialect::PaddleDialect(MLIRContext *context)
: Dialect("pd", context, TypeID::get<PaddleDialect>()) {
addOperations<
#define GET_OP_LIST
#include "paddle/infrt/dialect/pd_ops.cpp.inc" // NOLINT
>();
#undef GET_OP_LIST
}
mlir::Operation *PaddleDialect::materializeConstant(mlir::OpBuilder &builder,
mlir::Attribute value,
mlir::Type type,
mlir::Location loc) {
return builder.create<ConstantOp>(loc, value);
}
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/pd_ops.cpp.inc" // NOLINT
#undef GET_OP_CLASSES
#include "paddle/infrt/dialect/rewrite.hpp.inc" // NOLINT
void ConstantOp::build(OpBuilder &builder,
OperationState &state,
Attribute value) {
if (auto elem_attr = value.dyn_cast<ElementsAttr>()) {
return ConstantOp::build(builder, state, elem_attr);
} else if (value.isa<BoolAttr, FloatAttr, IntegerAttr>()) {
ShapedType type = RankedTensorType::get(/*shape=*/{}, value.getType());
state.addAttribute("value", DenseElementsAttr::get(type, value));
state.addTypes(type);
return;
}
llvm_unreachable("unsupported attribute type for building pd.constant");
}
LogicalResult ConstantOp::inferReturnTypes(
MLIRContext *context,
Optional<Location> location,
ValueRange operands,
DictionaryAttr attributes,
RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.push_back(attributes.get("value").getType());
return success();
}
::mlir::OpFoldResult ConstantOp::fold(
::llvm::ArrayRef<::mlir::Attribute> operands) {
return value();
}
LogicalResult ElementwiseAdd::inferReturnTypes(
MLIRContext *context,
Optional<Location> location,
ValueRange operands,
DictionaryAttr attributes,
RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.push_back(operands[0].getType());
return success();
}
void ElementwiseAdd::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) {
results.insert<FuseMulAdd>(context);
}
::mlir::OpFoldResult ElementwiseAdd::fold(
llvm::ArrayRef<mlir::Attribute> operands) {
if (getElementTypeOrSelf(getType()).isa<FloatType>()) {
if (!operands[0] || !operands[1]) return {};
DenseElementsAttr lhs = operands[0].dyn_cast<DenseElementsAttr>();
DenseElementsAttr rhs = operands[1].dyn_cast<DenseElementsAttr>();
if (!lhs || !rhs) return {};
ShapedType type = getType().template cast<ShapedType>();
if (!type.hasStaticShape()) return {};
Type etype = type.getElementType();
if (!etype.isa<FloatType>()) return {};
SmallVector<APFloat, 6> values;
values.reserve(lhs.getNumElements());
for (const auto zip :
llvm::zip(lhs.getValues<APFloat>(), rhs.getValues<APFloat>())) {
values.push_back(
std::plus<APFloat>()(std::get<0>(zip), std::get<1>(zip)));
}
return DenseElementsAttr::get(type, values);
}
return {};
}
LogicalResult ElementwiseDiv::inferReturnTypes(
MLIRContext *context,
Optional<Location> location,
ValueRange operands,
DictionaryAttr attributes,
RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.push_back(operands[0].getType());
return success();
}
LogicalResult ElementwiseMul::inferReturnTypes(
MLIRContext *context,
Optional<Location> location,
ValueRange operands,
DictionaryAttr attributes,
RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.push_back(operands[0].getType());
return success();
}
LogicalResult ElementwiseSub::inferReturnTypes(
MLIRContext *context,
Optional<Location> location,
ValueRange operands,
DictionaryAttr attributes,
RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.push_back(operands[0].getType());
return success();
}
LogicalResult MulOp::inferReturnTypes(
MLIRContext *context,
Optional<Location> location,
ValueRange operands,
DictionaryAttr attributes,
RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.push_back(operands[0].getType());
return success();
}
void ReluOp::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) {
results.insert<FuseFCRelu>(context);
}
void FusedRepeatedFCRelu::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) {
results.insert<FuseRepeatedFCRelu2>(context);
}
void BatchNormOp::getCanonicalizationPatterns(
::mlir::OwningRewritePatternList &results, ::mlir::MLIRContext *context) {
results.insert<FuseBatchNormWithConvPattern>(context);
}
} // namespace pd
} // namespace mlir
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/Interfaces/DerivedAttributeOpInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
namespace mlir {
namespace pd {
class PaddleDialect : public Dialect {
public:
explicit PaddleDialect(MLIRContext* context);
static StringRef getDialectNamespace() { return "pd"; }
/// A hook used to materialize constant values with the given type.
Operation* materializeConstant(OpBuilder& builder,
Attribute value,
Type type,
Location loc) override;
Type parseType(DialectAsmParser& parser) const override {
return Dialect::parseType(parser);
}
void printType(Type type, DialectAsmPrinter& printer) const override {
Dialect::printType(type, printer);
}
};
} // namespace pd
} // namespace mlir
#ifndef PD_OPS
#define PD_OPS
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/IR/OpBase.td"
include "paddle/infrt/dialect/pd_op_base.td"
def PD_FeedOp : PD_Op<"Feed", [NoSideEffect]> {
let summary = "Feed Op";
let description = [{
Feed a tensor into the model.
}];
let arguments = (ins);
let results = (outs PD_Tensor:$out);
let assemblyFormat = [{
`(` `)` attr-dict `:` type($out)
}];
}
def PD_ConstantOp : PD_Op<"Constant", [NoSideEffect, ConstantLike, DeclareOpInterfaceMethods<InferTypeOpInterface>, AllTypesMatch<["value", "output"]>]> {
let summary = "constant Op";
let description = [{}];
let arguments = (ins ElementsAttr:$value);
let results = (outs PD_Tensor:$output);
let hasFolder = 1;
let builders = [
OpBuilder<"OpBuilder &builder, OperationState &state, Attribute value">,
];
}
def PD_AbsOp : PD_Op<"Abs", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes the absolute value of a tensor";
let description = [{
}];
let arguments = (ins PD_Tensor:$x);
let results = (outs PD_Tensor:$y);
}
def PD_SqrtOp : PD_Op<"sqrt", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes the sqrt value of a tensor";
let description = [{
}];
let arguments = (ins PD_Tensor:$x);
let results = (outs PD_Tensor:$y);
}
def PD_ReluOp : PD_Op<"Relu", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes the Relu of a tensor";
let description = [{
}];
let arguments = (ins PD_Tensor:$x);
let results = (outs PD_Tensor:$y);
let hasCanonicalizer = 1;
}
def PD_Relu6Op : PD_Op<"Relu6", [NoSideEffect, SameOperandsAndResultType]> {
let summary = "Computes the Relu6 of a tensor";
let description = [{
}];
let arguments = (ins PD_Tensor:$x);
let results = (outs PD_Tensor:$y);
}
def PD_ElementwiseAdd : PD_Op<"ElementwiseAdd", [NoSideEffect, Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "ElementwiseAdd Op";
let description = [{
}];
let arguments = (ins PD_Tensor:$x, PD_Tensor:$y, DefaultValuedAttr<I32Attr, "-1">:$axis);
let results = (outs PD_Tensor:$out);
let hasCanonicalizer = 1;
let hasFolder = 1;
}
def PD_ElementwiseSub : PD_Op<"ElementwiseSub", [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "ElementwiseSub Op";
let description = [{
}];
let arguments = (ins PD_Tensor:$x, PD_Tensor:$y, DefaultValuedAttr<I32Attr, "-1">:$axis);
let results = (outs PD_Tensor:$out);
}
def PD_ElementwiseMul : PD_Op<"ElementwiseMul", [NoSideEffect, Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "ElementwiseMul Op";
let description = [{
}];
let arguments = (ins PD_Tensor:$x, PD_Tensor:$y, DefaultValuedAttr<I32Attr, "-1">:$axis);
let results = (outs PD_Tensor:$out);
}
def PD_ElementwiseDiv : PD_Op<"ElementwiseDiv", [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "ElementwiseDiv Op";
let description = [{
}];
let arguments = (ins PD_Tensor:$x, PD_Tensor:$y, DefaultValuedAttr<I32Attr, "-1">:$axis);
let results = (outs PD_Tensor:$out);
}
def PD_MatmulOp : PD_Op<"Matmul", [NoSideEffect]> {
let summary = "Computes the matrix mulplication result of two tensors";
let description = [{
}];
let arguments = (ins PD_Tensor:$x, PD_Tensor:$y,
DefaultValuedAttr<BoolAttr, "false">:$transpose_x,
DefaultValuedAttr<BoolAttr, "false">:$transpose_y,
DefaultValuedAttr<F32Attr, "1.0">:$alpha);
let results = (outs PD_Tensor:$out);
//let hasCanonicalizer = 1;
}
def PD_MulOp : PD_Op<"mul", [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "paddle mul op";
let description = [{}];
let arguments = (ins PD_Tensor:$x, PD_Tensor:$y);
let results = (outs PD_Tensor:$out);
//let hasCanonicalizer = 1;
}
def PD_Conv2dOp : PD_Op<"conv2d", [NoSideEffect]> {
let summary = "paddle conv2d operation";
let description = [{
}];
let arguments = (ins PD_Tensor:$Input, PD_Tensor:$Filter, PD_Tensor:$Bias);
let results = (outs PD_Tensor:$Output);
//let hasCanonicalizer = 1;
}
def PD_BatchNormOp : PD_Op<"batch_norm", [NoSideEffect]> {
let summary = "paddle batch_norm operation";
let description = [{
}];
let arguments = (ins PD_Tensor:$X, PD_Tensor:$Scale, PD_Tensor:$Bias,
PD_Tensor:$Mean, PD_Tensor:$Variance,
DefaultValuedAttr<F32Attr, "1e-05">:$epsilon);
let results = (outs PD_Tensor:$Y);
let hasCanonicalizer = 1;
}
def PD_FusedFC : PD_Op<"FC", [NoSideEffect]> {
let summary = "Computes the Fully Connected result of two tensors";
let description = [{
}];
let arguments = (ins PD_Tensor:$input, PD_Tensor:$w, PD_Tensor:$bias, DefaultValuedAttr<I32Attr, "1">:$in_num_col_dims);
let results = (outs PD_Tensor:$out);
}
def PD_FusedRepeatedFCRelu : PD_Op<"RepeatedFCRelu", [SameVariadicOperandSize, NoSideEffect]> {
let summary = "";
let description = [{ }];
let arguments = (ins PD_Tensor:$input, Variadic<PD_Tensor>:$w, Variadic<PD_Tensor>:$bias);
let results = (outs PD_Tensor:$out);
let hasCanonicalizer = 1;
}
#endif // PD_OPS
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/pd_types.h"
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This file defines the types used in PaddlePaddle MLIR dialect.
// We borrowed much ideas from tensorflow mlir dialect (tf_types.h in
// tensorflow).
#pragma once
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Types.h"
namespace mlir {
namespace PD {
class PaddleType : public Type {
public:
using Type::Type;
static bool classof(Type type);
};
namespace detail {
template <typename Derived>
class PaddleTypeImpl : public Type::TypeBase<Derived, PaddleType, TypeStorage> {
public:
using Base = typename Type::TypeBase<Derived, PaddleType, TypeStorage>;
using PDBase = PaddleTypeImpl<Derived>;
using Base::Base;
};
} // namespace detail
#define HANDLE_PD_TYPE(pdtype, enumerant, name) \
class pdtype##Type : public detail::PaddleTypeImpl<pdtype##Type> { \
public: \
using PDBase::PDBase; \
};
} // namespace PD
} // namespace mlir
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include "llvm/ADT/Optional.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/ScopedPrinter.h"
#include "llvm/Support/raw_os_ostream.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/init_infrt_dialects.h"
namespace cl = llvm::cl;
static cl::opt<std::string> inputFilename(cl::Positional,
cl::desc("<input toy file>"),
cl::init("-"),
cl::value_desc("filename"));
llvm::raw_ostream &printIndent(int indent = 0) {
for (int i = 0; i < indent; ++i) llvm::outs() << " ";
return llvm::outs();
}
void printOperation(mlir::Operation *op, int indent);
void printRegion(mlir::Region &region, int indent); // NOLINT
void printBlock(mlir::Block &block, int indent); // NOLINT
void printOperation(mlir::Operation *op, int indent) {
llvm::Optional<mlir::ModuleOp> module_op = llvm::None;
if (llvm::isa<mlir::ModuleOp>(op))
module_op = llvm::dyn_cast<mlir::ModuleOp>(op);
llvm::Optional<mlir::FuncOp> func_op = llvm::None;
if (llvm::isa<mlir::FuncOp>(op)) func_op = llvm::dyn_cast<mlir::FuncOp>(op);
printIndent(indent) << "op: '" << op->getName();
// This getName is inherited from Operation::getName
if (module_op) {
printIndent() << "@" << module_op->getName();
}
// This getName is inherited from SymbolOpInterfaceTrait::getName,
// which return value of "sym_name" in ModuleOp or FuncOp attributes.
if (func_op) {
printIndent() << "@" << func_op->getName();
}
printIndent() << "' with " << op->getNumOperands() << " operands"
<< ", " << op->getNumResults() << " results"
<< ", " << op->getAttrs().size() << " attributes"
<< ", " << op->getNumRegions() << " regions"
<< ", " << op->getNumSuccessors() << " successors\n";
if (!op->getAttrs().empty()) {
printIndent(indent) << op->getAttrs().size() << " attributes:\n";
for (mlir::NamedAttribute attr : op->getAttrs()) {
printIndent(indent + 1) << "- {" << attr.first << " : " << attr.second
<< "}\n";
}
}
if (op->getNumRegions() > 0) {
printIndent(indent) << op->getNumRegions() << " nested regions:\n";
for (mlir::Region &region : op->getRegions()) {
printRegion(region, indent + 1);
}
}
}
void printRegion(mlir::Region &region, int indent) { // NOLINT
printIndent(indent) << "Region with " << region.getBlocks().size()
<< " blocks:\n";
for (mlir::Block &block : region.getBlocks()) {
printBlock(block, indent + 1);
}
}
void printBlock(mlir::Block &block, int indent) { // NOLINT
printIndent(indent) << "Block with " << block.getNumArguments()
<< " arguments"
<< ", " << block.getNumSuccessors() << " successors"
<< ", " << block.getOperations().size()
<< " operations\n";
for (mlir::Operation &operation : block.getOperations()) {
printOperation(&operation, indent + 1);
}
}
int main(int argc, char **argv) {
mlir::registerAsmPrinterCLOptions();
mlir::registerMLIRContextCLOptions();
mlir::registerPassManagerCLOptions();
cl::ParseCommandLineOptions(argc, argv, "mlir demo");
mlir::MLIRContext *context = infrt::Global::getMLIRContext();
context->allowUnregisteredDialects();
auto &registry = context->getDialectRegistry();
infrt::RegisterCinnDialects(registry);
// mlir will verify module automatically after parsing.
// https://github.com/llvm/llvm-project/blob/38d18d93534d290d045bbbfa86337e70f1139dc2/mlir/lib/Parser/Parser.cpp#L2051
// mlir::OwningModuleRef module_ref = mlir::parseSourceString(mlir_source,
// context);
mlir::OwningModuleRef module_ref =
mlir::parseSourceFile(inputFilename, context);
std::cout << "----------print IR Structure begin----------" << std::endl;
printOperation(module_ref->getOperation(), 0);
std::cout << "----------print IR Structure end----------" << std::endl;
module_ref->dump();
return 0;
}
#ifndef INFRT_REWRITE
#define INFRT_REWRITE
include "paddle/infrt/dialect/infrt_base.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "paddle/infrt/dialect/pd_ops.td"
//===----------------------------------------------------------------------===//
// This is to fuse the composition: 'Matmul o ElementwiseAdd' into 'PD_FusedFC'.
//
// We have:
// (Matmul) z = x * y
// (Add) out = z + bias
//
// which corresponds to the following computation:
// (FusedFC) out = x * y + bias
//
// Todo:
// 1. Make the constrait more completely.
// 2. Consider the case of : out = bias + z
//===----------------------------------------------------------------------===//
def FuseMulAdd : Pat<(PD_ElementwiseAdd (PD_MatmulOp $x, $y, $transpose_x, $transpose_y, $alpha), $bias, $axis),
(PD_FusedFC $x, $y, $bias, (INFRT_createI32Attr<"1">)),
[(IsBoolAttrEq<"false"> $transpose_x),(IsBoolAttrEq<"false"> $transpose_y)]>;
//===----------------------------------------------------------------------===//
// This is to fuse the composition: 'FusedFC o Relu' into 'FusedRepeatedFCRelu'.
//
// We have:
// (FusedFC) z = fc(x, y, bias)
// (Relu) out = relu(z)
//
// which corresponds to the following computation:
// (FusedRepeatedFCRelu) out = RepeatedFCRelu(x, [y], [bias])
//
//===----------------------------------------------------------------------===//
def FuseFCRelu : Pat<(PD_ReluOp (PD_FusedFC $x, $y, $bias, $_)),
(PD_FusedRepeatedFCRelu $x, (INFRT_cvtValueToValueRange $y), (INFRT_cvtValueToValueRange $bias))>;
//===----------------------------------------------------------------------===//
// This is to fold 'FusedRepeatedFCRelu' op.
//
// We have:
// (FusedRepeatedFCRelu) z = RepeatedFCRelu(x, [y, ...], [bias, ...])
// (FusedRepeatedFCRelu) out = RepeatedFCRelu(z, [y1, ...], [bias1, ...])
//
// which corresponds to the following computation:
// (FusedRepeatedFCRelu) out = RepeatedFCRelu(x, [y, ..., y1, ...], [bias, ..., bias1, ....])
//
//===----------------------------------------------------------------------===//
def FuseRepeatedFCRelu2 : Pat<(PD_FusedRepeatedFCRelu (PD_FusedRepeatedFCRelu $x, $y, $bias), $y_2, $bias_2),
(PD_FusedRepeatedFCRelu $x, (INFRT_concatTwoValueRange $y, $y_2), (INFRT_concatTwoValueRange $bias, $bias_2))>;
//===----------------------------------------------------------------------===//
// This is to fuse the composition: 'BatchNorm o Conv' into 'Conv'
// by deriving new 'w' and 'b' for 'Conv':
//
// We have:
// (Conv) z = w * x + b
// (BatchNorm) y = scale * (z - mean) / sqrt(var + eps) + bias
//
// which corresponds to the following computation:
// y = w_ * x + b_
// where
// w_ = scale * w / sqrt(var + eps)
// b_ = B + scale * (b - mean) / sqrt(var + eps)
//
//===----------------------------------------------------------------------===//
def FuseBatchNormWithConvPattern: Pat<
(PD_BatchNormOp
(PD_Conv2dOp $input, $filter, $bias),
$scale, $bias_2, $mean, $var, $epsilon),
(PD_Conv2dOp
$input,
(PD_MulOp $filter,
(PD_ElementwiseDiv:$coefficientW
$scale,
(PD_SqrtOp (PD_ElementwiseAdd $var, (PD_ConstantOp $epsilon), (INFRT_createI32Attr<"1">))),
(INFRT_createI32Attr<"1">))),
(PD_ElementwiseAdd
$bias,
(PD_MulOp
(PD_ElementwiseSub $bias, $mean, (INFRT_createI32Attr<"1">)),
$coefficientW),
(INFRT_createI32Attr<"1">)))
>;
#endif // INFRT_REWRITE
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/tensor_shape.h"
#include <llvm/ADT/STLExtras.h>
#include <mlir/IR/Attributes.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/DialectImplementation.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/OpImplementation.h>
#include <mlir/IR/StandardTypes.h>
#include <mlir/IR/TypeUtilities.h>
#include <mlir/Support/LogicalResult.h>
namespace infrt::ts {
using namespace mlir; // NOLINT
void TensorShapeDialect::initialize() {
allowUnknownTypes();
addTypes<ShapeType, PartialShapeType>();
addOperations<
#define GET_OP_LIST
#include "paddle/infrt/dialect/tensor_shape.cpp.inc"
>();
}
Type TensorShapeDialect::parseType(DialectAsmParser &parser) const {
StringRef keyword;
if (parser.parseKeyword(&keyword)) return Type();
if (keyword == "shape") return ShapeType::get(getContext());
if (keyword == "partial_shape") return PartialShapeType::get(getContext());
parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
return Type();
}
void TensorShapeDialect::printType(::mlir::Type type,
::mlir::DialectAsmPrinter &os) const {
if (type.isa<ShapeType>()) {
os << "shape";
return;
}
if (type.isa<PartialShapeType>()) {
os << "partial_shape";
return;
}
llvm_unreachable("unexpected 'shape' type kind");
}
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensor_shape.cpp.inc" // NOLINT
} // namespace infrt::ts
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mlir/IR/Dialect.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
namespace infrt::ts {
class ShapeType
: public mlir::Type::TypeBase<ShapeType, mlir::Type, mlir::TypeStorage> {
public:
using Base::Base;
};
class PartialShapeType : public mlir::Type::TypeBase<PartialShapeType,
mlir::Type,
mlir::TypeStorage> {
public:
using Base::Base;
};
using namespace mlir; // NOLINT
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/tensor_shape.hpp.inc"
#include "paddle/infrt/dialect/tensor_shape_dialect.hpp.inc"
} // namespace infrt::ts
#ifdef INFRT_OPS
#else
#define INFRT_OPS
include "paddle/infrt/dialect/infrt_base.td"
include "paddle/infrt/dialect/tensor_shape_base.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
// Base class for the operation in the TensorShape dialect
class TS_Op<string mnemonic, list<OpTrait> traits = []> :
Op<TensorShapeDialect, mnemonic, traits> {
let parser = [{ return infrt::dialect::parse$cppClass(parser, result); }];
let printer = " return infrt::dialect::printOpWithOperands(p, *this)" ";";
}
def TS_BuildShapeOp : TS_Op<"build_shape", [NoSideEffect]> {
let summary = "Build tensor shape operation";
let description = [{
An operation that builds a tensor shape of given ranks and extents.
}];
let arguments = (ins I64ArrayAttr:$value);
let results = (outs TS_Shape:$output);
let assemblyFormat = "$value attr-dict";
}
def TS_GetNumElementsOp : TS_Op<"get_num_elements"> {
let summary = "Returns the number of elements in the shape";
let description = [{
An operation that returns the number of elements in the given shape.
}];
let arguments = (ins TS_Shape);
let results = (outs I64);
let assemblyFormat = "operands attr-dict";
}
def TS_PrintShapeOp : TS_Op<"print_shape"> {
let summary = "Print tensor shape operation";
let description = [{
An operation that prints a tensor shape.
}];
let arguments = (ins TS_Shape:$shape);
let assemblyFormat = "operands attr-dict";
}
#endif
#ifdef TS_OPS_BASE
#else
#define TS_OPS_BASE
// Tensor shape dialect.
def TensorShapeDialect : Dialect {
let name = "ts";
let description = [{
The Tensor Shape dialect.
This dialect contains operations for working with tensor shapes.
}];
let cppNamespace = "::infrt::ts";
}
// Type definition.
def TS_Shape : DialectType<TensorShapeDialect,
CPred<"$_self.isa<::infrt::ts::ShapeType>()">, "!ts.shape type">,
BuildableType<"$_builder.getType<::infrt::ts::ShapeType>()"> {
let typeDescription = [{
`!ts.shape type` represents a static tensor shape.
}];
}
def TS_PartialShape : DialectType<TensorShapeDialect,
CPred<"$_self.isa<::infrt::ts::PartialShapeType>()">, "!ts.partial_shape type">,
BuildableType<"$_builder.getType<::infrt::ts::PartialShapeType>()"> {
let typeDescription = [{
`!ts.partial_shape type` represents either a static tensor shape, unranked
tensor shape or a ranked tensor shape with unknown dimension sizes.
}];
}
#endif // TS_OPS_BASE
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/test_kernels.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/IR/TypeUtilities.h"
namespace infrt::dialect {
//===----------------------------------------------------------------------===//
// BenchmarkOp
//===----------------------------------------------------------------------===//
// Parse the BenchmarkOp in the following format
// infrt.benchmark "add.i32"(%c : i32, %d : f32)
// max_count = 100, duration_secs = 1 {
// ...
// }
static ParseResult parseBenchmarkOp(OpAsmParser &parser, // NOLINT
OperationState &result) { // NOLINT
StringAttr nameAttr;
if (parser.parseAttribute(nameAttr, "name", result.attributes))
return failure();
// Parse the operands, e.g. (%c : i32, %d : f32)
if (parser.parseLParen()) return failure();
SmallVector<OpAsmParser::OperandType, 4> operands;
SmallVector<Type, 4> types;
llvm::SMLoc type_loc = parser.getCurrentLocation();
if (parser.parseOptionalRParen()) {
// Parse non-empty operands
do {
// Parse %c : i32,
OpAsmParser::OperandType operand;
Type type;
if (parser.parseOperand(operand) || parser.parseColonType(type))
return failure();
operands.push_back(operand);
types.push_back(type);
} while (succeeded(parser.parseOptionalComma()));
if (parser.parseRParen()) return failure();
}
if (parser.resolveOperands(operands, types, type_loc, result.operands))
return failure();
// Parse the keyword attribute, e.g. max_count = 100, duration_secs = 1
do {
StringRef attr;
Attribute resultAttr;
if (parser.parseKeyword(&attr) || parser.parseEqual() ||
parser.parseAttribute(resultAttr,
parser.getBuilder().getIntegerType(32),
attr,
result.attributes))
return failure();
} while (succeeded(parser.parseOptionalComma()));
// Set the default attribute num_warmup_runs to 1 if unset
auto setDefaultAttrIfUnset = [&](const char *attr_name, int value) {
bool found = llvm::any_of(result.attributes,
[attr_name](const NamedAttribute &attr) {
return attr.first == attr_name;
});
if (!found) {
IntegerAttr default_val = parser.getBuilder().getI32IntegerAttr(value);
result.addAttribute(attr_name, default_val);
}
};
setDefaultAttrIfUnset("num_warmup_runs", 1);
Region *target = result.addRegion();
return parser.parseRegion(*target,
operands,
types,
/*enableNameShadowing=*/true);
}
// Print the BenchmarkOp in the following format
// infrt.benchmark "add.i32"(%c : i32, %d : f32)
// max_count = 100, duration_secs = 1 {
// ...
// }
static void print(OpAsmPrinter &p, BenchmarkOp op) { // NOLINT
p << "infrt.benchmark ";
// Print the name attribute, e.g "add.i32"
auto name_attr = op.getAttr("name");
p << name_attr;
// Print the operands and types, e.g. (%c : i32, %d : f32)
p << '(';
llvm::interleaveComma(llvm::zip(op.getOperands(), op.getOperandTypes()),
p,
[&](const auto &it) {
p << std::get<0>(it) << " : " << std::get<1>(it);
});
p << ") ";
bool need_comma = false;
// Print the attributes, e.g. max_count = 100, duration_secs = 1
for (auto &name_attr : op.getAttrs()) {
auto id = name_attr.first;
if (id == "name") continue;
if (need_comma) p << ", ";
auto attr = name_attr.second;
p << id << " = ";
if (auto int_attr = attr.dyn_cast<IntegerAttr>()) {
int_attr.getValue().print(p.getStream(), /*isSigned=*/false);
} else {
op.emitOpError("Unexpected attribute");
}
need_comma = true;
}
p << ' ';
// Print the region
// Reuse the argument names provided to the op for the bbarg names within
// the region.
p.shadowRegionArgs(op.region(), op.getOperands());
p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
}
static LogicalResult verify(BenchmarkOp op) {
// Verify that the target benchmark region has exactly one return value.
auto &region = op.region();
auto &last_op = region.front().back();
if (last_op.getName().getStringRef() != "infrt.return") {
return op.emitOpError("missing return statement");
}
if (last_op.getNumOperands() != 1) {
return op.emitOpError(
"incorrect number of return values. One return value is expected");
}
return success();
}
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/test_kernels.cpp.inc"
} // namespace infrt::dialect
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
namespace infrt::dialect {
using namespace mlir; // NOLINT
#define GET_OP_CLASSES
#include "paddle/infrt/dialect/test_kernels.hpp.inc"
} // namespace infrt::dialect
// Operation definitions for testing.
#ifdef TEST_OPS
#else
#define TEST_OPS
include "paddle/infrt/dialect/infrt_base.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
// Base class for Test dialect ops.
class Test_Op<string mnemonic, list<OpTrait> traits = []> :
Op<INFRT_Dialect, mnemonic, !listconcat(traits, [IsolatedFromAbove])> {
// Each registered op in the Test namespace needs to provide all of a printer,
// parser and verifier.
let printer = [{ return infrt::dialect::print(p, *this); }];
let verifier = [{ return infrt::dialect::verify(*this); }];
let parser = [{ return infrt::dialect::parse$cppClass(parser, result); }];
}
def BenchmarkOp : Test_Op<"benchmark"> {
let summary = "benchmark operation";
let description = [{
The "infrt.benchmark" operation benchmarks the performance of an MLIR
region by executing the given MLIR region repeatedly up to the
`duratino_secs` seconds or `max_count` times. `num_warmup_runs` specifies
the number of warm up runs to run the given MLIR region before the
benchmark starts.
The target MLIR region can take an arbitrary number of arguments and
should return exactly one value. The arguments for the MLIR region are
provided as the operands of the infrt.benchmark op.
Example:
infrt.benchmark "add.i32"(%c : i32, %d : f32) max_count = 100, duration_secs = 1 {
// code for benchmarking
...
}
infrt.benchmark "add.i32"(%c : i32)
duration_secs = 1,
max_count = 100,
num_warmup_runs = 10 {
// The MLIR code to be benchmarked goes here.
// The following code benchmarks the infrt.add.i32 kernel.
%x = infrt.add.i32 %c, %c
// The benchmarked function needs to return exactly one value.
infrt.return %x : i32
}
}];
let regions = (region SizedRegion<1>:$region);
let arguments = (ins
Variadic<AnyType>,
I32Attr:$duration_secs,
I32Attr:$max_count,
StrAttr:$name,
DefaultValuedAttr<I32Attr, "1">:$num_warmup_runs
);
let results = (outs);
}
#endif // TEST_OPS
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/dialect/types.h"
namespace infrt::hlir::mlir {} // namespace infrt::hlir::mlir
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mlir/IR/StandardTypes.h>
set(external_kernels_src "basic_kernels.cc")
cc_library(external_kernels SHARED SRCS ${external_kernels_src})
set_target_properties(external_kernels PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
set(basic_mlir "${CMAKE_CURRENT_SOURCE_DIR}/basic.mlir")
set(external_kernels_lib "${CMAKE_CURRENT_BINARY_DIR}/libexternal_kernels.so")
message(STATUS "basic_mlir: ${basic_mlir}")
message(STATUS "external_kernels_lib: ${external_kernels_lib}")
add_test(
NAME run_and_check_external_kernels
COMMAND sh -c "${CMAKE_BINARY_DIR}/infrt/host_context/infrt-exec -i ${basic_mlir} --shared_libs=${external_kernels_lib} | ${LLVM_PATH}/bin/FileCheck ${basic_mlir}"
)
// CHECK: basic
func @basic() -> f32 {
%v0 = infrt.constant.f32 1.0
%v1 = infrt.constant.f32 2.0
%v2 = "external.add.f32"(%v0, %v1) : (f32, f32) -> f32
// CHECK: 1
"external.print.f32"(%v0) : (f32) -> ()
// CHECK: 2
"external.print.f32"(%v1) : (f32) -> ()
// CHECK: 3
"external.print.f32"(%v2) : (f32) -> ()
%v3 = "external.mul.f32"(%v2, %v1) : (f32, f32) -> f32
// CHECK: 6
"external.print.f32"(%v3) : (f32) -> ()
infrt.return %v3 : f32
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/kernel_utils.h"
template <typename T>
T add(T a, T b) {
return a + b;
}
template <typename T>
T sub(T a, T b) {
return a - b;
}
template <typename T>
T mul(T a, T b) {
return a * b;
}
template <typename T>
T div(T a, T b) {
return a / b;
}
template <typename T>
void print(T a) {
std::cout << a << std::endl;
}
void RegisterKernels(infrt::host_context::KernelRegistry *registry) {
// int32
registry->AddKernel("external.add.i32", INFRT_KERNEL(add<int32_t>));
registry->AddKernel("external.sub.i32", INFRT_KERNEL(sub<int32_t>));
registry->AddKernel("external.mul.i32", INFRT_KERNEL(mul<int32_t>));
registry->AddKernel("external.div.i32", INFRT_KERNEL(div<int32_t>));
registry->AddKernel("external.print.i32", INFRT_KERNEL(print<int32_t>));
// float
registry->AddKernel("external.add.f32", INFRT_KERNEL(add<float>));
registry->AddKernel("external.sub.f32", INFRT_KERNEL(sub<float>));
registry->AddKernel("external.mul.f32", INFRT_KERNEL(mul<float>));
registry->AddKernel("external.div.f32", INFRT_KERNEL(div<float>));
registry->AddKernel("external.print.f32", INFRT_KERNEL(print<float>));
}
// CHECK-LABEL: @fc
func @fc(%input : !infrt.tensor<X86, NCHW, F32>,
%w : !infrt.tensor<X86, NCHW, F32>,
%bias : !infrt.tensor<X86, NCHW, F32>) -> !infrt.tensor<X86, NCHW, F32>
{
%out = dt.create_uninit_tensor.f32 [30, 50] -> !infrt.tensor<X86, NCHW, F32>
// dt.fill_tensor_with_constant.f32 (%out : !infrt.tensor<X86, NCHW, F32>) {value=0.0:f32}
// fc1
"external.matmul"(%input, %w, %out) {}: (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) -> ()
"external.elementwise_add"(%out, %bias, %out) {axis = -1}: (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) -> ()
"external.sigmoid"(%out, %out) {}: (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) -> ()
// fc2
"external.matmul"(%out, %w, %out) {}: (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) -> ()
"external.elementwise_add"(%out, %bias, %out) {axis = -1}: (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) -> ()
"external.sigmoid"(%out, %out) {}: (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) -> ()
infrt.return %out : !infrt.tensor<X86, NCHW, F32>
}
// CHECK-LABEL: @benchmark
func @benchmark() {
%input = dt.create_uninit_tensor.f32 [30, 50] -> !infrt.tensor<X86, NCHW, F32>
dt.fill_tensor_with_constant.f32 (%input : !infrt.tensor<X86, NCHW, F32>) {value=1.0:f32}
%w = dt.create_uninit_tensor.f32 [50, 50] -> !infrt.tensor<X86, NCHW, F32>
dt.fill_tensor_with_constant.f32 (%w : !infrt.tensor<X86, NCHW, F32>) {value=2.0:f32}
%bias = dt.create_uninit_tensor.f32 [30, 50] -> !infrt.tensor<X86, NCHW, F32>
dt.fill_tensor_with_constant.f32 (%bias : !infrt.tensor<X86, NCHW, F32>) {value=3.0:f32}
infrt.benchmark "add.f32"(
%input:!infrt.tensor<X86, NCHW, F32>,
%w:!infrt.tensor<X86, NCHW, F32>,
%bias:!infrt.tensor<X86, NCHW, F32>)
duration_secs = 100, max_count = 300000, num_warmup_runs = 3
{
%res = infrt.call @fc(%input, %w, %bias) : (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) -> (!infrt.tensor<X86, NCHW, F32>)
infrt.return %res : !infrt.tensor<X86, NCHW, F32>
}
infrt.return
}
// CHECK: paddle_func
func @paddle_func() -> () {
%input = dt.create_uninit_tensor.f32 [3, 5] -> !infrt.tensor<X86, NCHW, F32>
dt.fill_tensor_with_constant.f32 (%input : !infrt.tensor<X86, NCHW, F32>) {value=1.0:f32}
%w = dt.create_uninit_tensor.f32 [5, 4] -> !infrt.tensor<X86, NCHW, F32>
dt.fill_tensor_with_constant.f32 (%w : !infrt.tensor<X86, NCHW, F32>) {value=2.0:f32}
%bias = dt.create_uninit_tensor.f32 [4] -> !infrt.tensor<X86, NCHW, F32>
dt.fill_tensor_with_constant.f32 (%bias : !infrt.tensor<X86, NCHW, F32>) {value=3.0:f32}
%out = dt.create_uninit_tensor.f32 [3, 4] -> !infrt.tensor<X86, NCHW, F32>
dt.fill_tensor_with_constant.f32 (%out : !infrt.tensor<X86, NCHW, F32>) {value=0.0:f32}
"external.fc2"(%input, %w, %bias, %out) {in_num_col_dims=3:i32, test_attr=5:i32}: (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) -> ()
// CHECK-LABEL: tensor: shape=shape[3,5], values=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
dt.print_tensor (%input : !infrt.tensor<X86, NCHW, F32>)
// CHECK-LABEL: tensor: shape=shape[5,4], values=[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
dt.print_tensor (%w : !infrt.tensor<X86, NCHW, F32>)
dt.print_tensor (%bias : !infrt.tensor<X86, NCHW, F32>)
dt.print_tensor (%out : !infrt.tensor<X86, NCHW, F32>)
// test external.matmul
%out1 = dt.create_uninit_tensor.f32 [3, 4] -> !infrt.tensor<X86, NCHW, F32>
dt.fill_tensor_with_constant.f32 (%out1 : !infrt.tensor<X86, NCHW, F32>) {value=0.0:f32}
"external.matmul"(%input, %w, %out1) {}: (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) -> ()
dt.print_tensor (%out1 : !infrt.tensor<X86, NCHW, F32>)
// test external.elementwise_add
%out2 = dt.create_uninit_tensor.f32 [3, 4] -> !infrt.tensor<X86, NCHW, F32>
dt.fill_tensor_with_constant.f32 (%out2 : !infrt.tensor<X86, NCHW, F32>) {value=0.0:f32}
%bias1 = dt.create_uninit_tensor.f32 [3, 4] -> !infrt.tensor<X86, NCHW, F32>
dt.fill_tensor_with_constant.f32 (%bias1 : !infrt.tensor<X86, NCHW, F32>) {value=3.0:f32}
"external.elementwise_add"(%out1, %bias1, %out2) {axis=-1}: (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) -> ()
dt.print_tensor (%out2 : !infrt.tensor<X86, NCHW, F32>)
// test external.relu
%out3 = dt.create_uninit_tensor.f32 [3, 4] -> !infrt.tensor<X86, NCHW, F32>
dt.fill_tensor_with_constant.f32 (%out3 : !infrt.tensor<X86, NCHW, F32>) {value=0.0:f32}
"external.relu"(%out1, %out3) {}: (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) -> ()
dt.print_tensor (%out3 : !infrt.tensor<X86, NCHW, F32>)
// test external.sigmoid
%out4 = dt.create_uninit_tensor.f32 [3, 4] -> !infrt.tensor<X86, NCHW, F32>
dt.fill_tensor_with_constant.f32 (%out4 : !infrt.tensor<X86, NCHW, F32>) {value=0.0:f32}
"external.sigmoid"(%out1, %out4) {}: (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) -> ()
dt.print_tensor (%out4 : !infrt.tensor<X86, NCHW, F32>)
infrt.return
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
int main(int argc, char **argv) {
testing::InitGoogleTest(&argc, argv);
gflags::ParseCommandLineFlags(&argc, &argv, false);
return RUN_ALL_TESTS();
}
core_gather_headers()
gather_srcs(infrt_src SRCS
kernel_frame.cc
kernel_registry.cc
value.cc
kernel_utils.cc
symbol_table.cc
op_executable.cc
core_runtime.cc
mlir_to_runtime_translate.cc
function.cc
mlir_function_executable.cc
mlir_program_executor.cc
)
cc_test_tiny(test_infrt_host_context_value SRCS value_test.cc DEPS infrt ${MLIR_IR_LIBS})
cc_test_tiny(test_infrt_kernel_utils SRCS kernel_utils_test.cc DEPS infrt ${MLIR_IR_LIBS})
cc_test_tiny(test_infrt_kernel_registry SRCS kernel_registry_test.cc DEPS infrt ${MLIR_IR_LIBS})
cc_test_tiny(test_infrt_op_executable SRCS op_executable_test.cc DEPS infrt ${MLIR_IR_LIBS})
cc_test_tiny(test_infrt_core_runtime SRCS core_runtime_test.cc DEPS infrt ${MLIR_IR_LIBS})
cc_test_tiny(test_infrt_mlir_to_runtime_translate SRCS mlir_to_runtime_translate_test.cc DEPS infrt ${MLIR_IR_LIBS})
infrt_exec_check(test_infrt_mlir_exec_on_basic mlir_tests/basic.mlir)
infrt_exec_check(test_infrt_mlir_exec_on_shape mlir_tests/shape.mlir)
infrt_exec_check(test_infrt_mlir_exec_on_dense_tensor mlir_tests/dense_tensor.mlir)
add_executable(infrt-exec mlir_exec.cc)
target_link_libraries(infrt-exec infrt ${MLIR_IR_LIBS})
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/host_context/core_runtime.h"
#include <unordered_map>
#include <string>
#include <vector>
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/op_executable.h"
#include "paddle/infrt/host_context/symbol_table.h"
namespace infrt::host_context {
struct CoreRuntime::Impl {
KernelRegistry* kernel_registry{};
SymbolTable symbol_table;
std::vector<OpExecutableBuilder> op_executables;
mutable std::vector<ValueRef> results;
};
SymbolTable* CoreRuntime::symbol_table() { return &impl_->symbol_table; }
CoreRuntime::CoreRuntime(CoreRuntime::Impl* impl) : impl_(impl) { CHECK(impl); }
void CoreRuntime::Execute() {
// std::cout << "CoreRuntime::Execute" << std::endl;
int op_offset = 0;
for (auto& op : impl_->op_executables) {
VLOG(3) << "running op " << op_offset++ << " " << op.name();
op.Execute();
}
}
KernelRegistry* CoreRuntime::kernel_registry() const {
return impl_->kernel_registry;
}
size_t CoreRuntime::num_ops() const { return impl_->op_executables.size(); }
CoreRuntimeBuilder::CoreRuntimeBuilder(KernelRegistry* kernel_registry)
: CoreRuntime(new Impl) {
impl_->kernel_registry =
kernel_registry ? kernel_registry : GetCpuKernelRegistry();
}
OpExecutableBuilder* CoreRuntimeBuilder::NewOpExecutable(
const std::string& op_name) {
CHECK(impl_.get());
impl_->op_executables.emplace_back(
op_name, symbol_table(), impl_->kernel_registry);
return &impl_->op_executables.back();
}
void CoreRuntimeBuilder::FeedInArgs(
llvm::ArrayRef<std::pair<std::string, ValueRef>> args) {
for (auto& item : args) {
symbol_table()->Register(item.first, item.second);
}
}
void CoreRuntimeBuilder::SetKernelRegistry(KernelRegistry* x) {
CHECK(x);
impl_->kernel_registry = x;
}
llvm::SmallVector<ValueRef, 4> CoreRuntime::GetResults(
llvm::ArrayRef<std::string> arg_names) {
llvm::SmallVector<ValueRef, 4> results;
for (auto& name : arg_names) {
results.push_back(ValueRef(symbol_table()->GetValue(name)));
}
return results;
}
CoreRuntime::~CoreRuntime() {}
} // namespace infrt::host_context
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <llvm/ADT/ArrayRef.h>
#include <llvm/ADT/SmallVector.h>
#include <memory>
#include <string>
#include <utility>
#include "paddle/infrt/host_context/value.h"
namespace infrt::host_context {
class KernelRegistry;
class OpExecutable;
class OpExecutableBuilder;
class SymbolTable;
/**
* CoreRuntime encapsulate the execution for a sequence of ops.
* Each function call will bind to a CoreRuntime instance, push the argument
* Values in to the argument-list, and get the
* result Values from the return-list.
*/
class CoreRuntime : public std::enable_shared_from_this<CoreRuntime> {
public:
//! Execute a program.
void Execute();
//! Return the number of ops.
size_t num_ops() const;
//! Get the results of the execution.
llvm::SmallVector<ValueRef, 4> //
GetResults(llvm::ArrayRef<std::string> arg_names);
std::shared_ptr<CoreRuntime> getptr() {
return std::shared_ptr<CoreRuntime>(this);
}
KernelRegistry* kernel_registry() const;
~CoreRuntime();
protected:
//! Get the symbol table.
SymbolTable* symbol_table();
class Impl;
explicit CoreRuntime(Impl* impl);
std::unique_ptr<Impl> impl_;
};
/**
* The builder for CoreRuntime, help to construct a function.
*/
class CoreRuntimeBuilder : public CoreRuntime {
public:
explicit CoreRuntimeBuilder(KernelRegistry* kernel_registry);
using CoreRuntime::symbol_table;
void SetKernelRegistry(KernelRegistry* x);
//! Feed the input arguments, each item is a pair of arg-name and arg-value.
void FeedInArgs(llvm::ArrayRef<std::pair<std::string, ValueRef>> args);
llvm::ArrayRef<const std::string&> attr_names() const;
OpExecutableBuilder* NewOpExecutable(const std::string& op_name);
};
} // namespace infrt::host_context
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/host_context/core_runtime.h"
#include <gtest/gtest.h>
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/host_context/op_executable.h"
#include "paddle/infrt/host_context/symbol_table.h"
namespace infrt {
namespace host_context {
int add(int a, int b) { return a + b; }
int sub(int a, int b) { return a - b; }
TEST(CoreRuntime, basic) {
KernelRegistry registry;
registry.AddKernel("infrt.test.addi32", INFRT_KERNEL(add));
registry.AddKernel("infrt.test.subi32", INFRT_KERNEL(sub));
CoreRuntimeBuilder builder(&registry);
auto* table = builder.symbol_table();
table->Register("a", 1);
table->Register("b", 2);
table->Register("d", 4);
// c = a + b
auto* op0 = builder.NewOpExecutable("infrt.test.addi32");
op0->AppendArgument("a");
op0->AppendArgument("b");
op0->SetResults({"c"});
// e = c - d
auto* op1 = builder.NewOpExecutable("infrt.test.subi32");
op1->AppendArgument("c");
op1->AppendArgument("d");
op1->SetResults({"e"});
builder.Execute();
ASSERT_EQ(table->GetValue("d")->get<int>(), 4);
ASSERT_EQ(table->GetValue("c")->get<int>(), 3);
ASSERT_EQ(table->GetValue("e")->get<int>(), -1);
}
TEST(CoreRuntime, function) {
// The function:
// func(int a, int b) {
// int c = a + b
// return c
// }
KernelRegistry registry;
registry.AddKernel("infrt.test.addi32", INFRT_KERNEL(add));
registry.AddKernel("infrt.test.subi32", INFRT_KERNEL(sub));
CoreRuntimeBuilder builder(&registry);
auto* table = builder.symbol_table();
std::vector<std::pair<std::string, ValueRef>> feeds{
{std::make_pair("a", ValueRef(new Value(1))), //
std::make_pair("b", ValueRef(new Value(2)))}};
builder.FeedInArgs(llvm::ArrayRef<std::pair<std::string, ValueRef>>(
feeds.data(), feeds.size()));
ASSERT_EQ(table->Get<int>("a"), 1);
ASSERT_EQ(table->Get<int>("b"), 2);
ASSERT_EQ(table->size(), 2UL);
auto* op = builder.NewOpExecutable("infrt.test.addi32");
op->AppendArgument("a");
op->AppendArgument("b");
op->SetResults({"c"});
builder.Execute();
auto res = builder.GetResults({"c"});
ASSERT_EQ(res.size(), 1UL);
ASSERT_EQ(res[0].get<int>(), 3);
}
} // namespace host_context
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/host_context/function.h"
namespace infrt {
namespace host_context {} // namespace host_context
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <llvm/ADT/ArrayRef.h>
#include <string>
namespace infrt {
namespace host_context {
struct Value;
struct ValueRef;
/**
* Base class of all executable Function.
*
* This is used by `infrt.call` op, to execute a function.
*/
class Function {
public:
Function(Function&& other)
: name_(other.name_),
num_arguments_(other.num_arguments_),
num_results_(other.num_results_) {}
Function() = delete;
std::string name() const { return name_; }
size_t num_arguments() const { return num_arguments_; }
size_t num_results() const { return num_results_; }
virtual void Execute(llvm::ArrayRef<Value*> arguments,
llvm::MutableArrayRef<ValueRef> results,
bool is_region = false) const {}
virtual ~Function() = default;
protected:
Function(std::string name, size_t num_arguments, size_t num_results)
: name_(name), num_arguments_(num_arguments), num_results_(num_results) {}
private:
std::string name_;
size_t num_arguments_{};
size_t num_results_{};
};
} // namespace host_context
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/host_context/kernel_frame.h"
#include <memory>
namespace infrt {
namespace host_context {
std::ostream& operator<<(std::ostream& os, const KernelFrame& frame) {
os << "KernelFrame: " << frame.GetNumArgs() << " args, "
<< frame.GetNumResults() << " res, " << frame.GetNumResults() << " attrs";
return os;
}
} // namespace host_context
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <llvm/ADT/ArrayRef.h>
#include <utility>
#include "llvm/ADT/SmallVector.h"
#include "paddle/infrt/host_context/value.h"
namespace infrt::host_context {
/**
* KernelFrame captures the states(input arguments, attributes, results)
* associated with a kernel invocation.
*/
class KernelFrame {
public:
int GetNumArgs() const { return num_arguments_; }
int GetNumResults() const { return num_results_; }
int GetNumAttributes() const {
return value_or_attrs_.size() - num_arguments_ -
(num_results_ == -1 ? 0 : num_results_);
}
template <typename T>
T& GetArgAt(int index) {
CHECK_LT(index, GetNumArgs());
return value_or_attrs_[index]->get<T>();
}
template <typename T>
const T& GetArgAt(int index) const {
CHECK_LT(index, GetNumArgs());
return value_or_attrs_[index]->get<T>();
}
Value* GetArgAt(int index) {
CHECK_LT(index, GetNumArgs());
return value_or_attrs_[index];
}
// Get all arguments.
llvm::ArrayRef<Value*> GetArguments() const {
return GetValues(0, num_arguments_);
}
Value* GetAttributeAt(int idx) {
CHECK_NE(num_results_, -1)
<< "Must call SetNumResults before GetAttributeAt";
CHECK_LT(idx,
static_cast<int>(value_or_attrs_.size() - num_arguments_ -
num_results_));
return value_or_attrs_[num_arguments_ + num_results_ + idx];
}
void AddAttribute(Value* v) {
CHECK_NE(num_results_, -1)
<< "Must call SetNumResults before calling AddAttribute";
value_or_attrs_.emplace_back(v);
}
template <typename T, typename... Args>
void EmplaceResult(Args&&... args) {
EmplaceResult<T>(0, std::forward<Args>(args)...);
}
template <typename T, typename... Args>
void EmplaceResult(int index, Args&&... args) {
SetResultAt(index, T(std::forward<Args>(args)...));
}
template <typename T>
void SetResultAt(int index, T&& value) {
CHECK_LT(index, num_results_) << "Invalid result index";
CHECK(value_or_attrs_[num_arguments_ + index]);
value_or_attrs_[num_arguments_ + index]->set(std::move(value));
}
llvm::ArrayRef<Value*> GetResults() const {
return GetValues(num_arguments_, num_results_);
}
llvm::MutableArrayRef<Value*> GetResults() {
return GetMutableValues(num_arguments_, num_results_);
}
llvm::ArrayRef<Value*> GetValues(size_t from, size_t length) const {
CHECK_LE(static_cast<int>(from + length), num_arguments_ + num_results_);
if (length == 0) return {};
return llvm::makeArrayRef(&value_or_attrs_[from], length);
}
llvm::MutableArrayRef<Value*> GetMutableValues(size_t from, size_t length) {
CHECK_LE(static_cast<int>(from + length), num_arguments_ + num_results_);
if (length == 0) return {};
return llvm::makeMutableArrayRef(&value_or_attrs_[from], length);
}
protected:
int num_arguments_{};
int num_results_{-1};
llvm::SmallVector<Value*, 8> value_or_attrs_;
};
std::ostream& operator<<(std::ostream& os, const KernelFrame& frame);
class KernelFrameBuilder : public KernelFrame {
public:
void AddArgument(Value* value) {
CHECK(value);
CHECK_EQ(num_results_, -1)
<< "Should call AddArgument before calling SetNumResults";
value_or_attrs_.push_back(value);
++num_arguments_;
}
void SetResults(llvm::ArrayRef<Value*> values) {
CHECK_EQ(num_arguments_, static_cast<int>(value_or_attrs_.size()));
CHECK_EQ(num_results_, -1);
for (Value* x : values) {
value_or_attrs_.push_back(x);
}
num_results_ = values.size();
}
void SetNumResults(size_t n) {
CHECK_EQ(num_arguments_, static_cast<int>(value_or_attrs_.size()));
CHECK_EQ(num_results_, -1);
num_results_ = n;
for (size_t i = 0; i < n; i++) {
value_or_attrs_.emplace_back(new Value);
}
}
void SetResultAt(int result_id, Value* value) {
CHECK_EQ(static_cast<int>(value_or_attrs_.size()),
num_arguments_ + num_results_)
<< "Call SetNumResults first";
CHECK_LT(result_id + num_arguments_,
static_cast<int>(value_or_attrs_.size()));
CHECK(value);
value_or_attrs_[num_arguments_ + result_id]->set(value);
}
void Reset() {
value_or_attrs_.clear();
num_arguments_ = 0;
num_results_ = -1;
}
};
} // namespace infrt::host_context
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/host_context/kernel_registry.h"
#include <unordered_map>
#include "glog/logging.h"
#include "llvm/ADT/SmallVector.h"
namespace infrt {
namespace host_context {
struct KernelRegistry::Impl {
std::unordered_map<std::string, KernelImplementation> data;
std::unordered_map<std::string, llvm::SmallVector<std::string, 4>> attr_names;
};
KernelRegistry::KernelRegistry() : impl_(std::make_unique<Impl>()) {}
void KernelRegistry::AddKernel(const std::string &key,
KernelImplementation fn) {
CHECK(!impl_->data.count(key)) << "kernel [" << key
<< "] is registered twice";
impl_->data.emplace(key, fn);
}
void KernelRegistry::AddKernelAttrNameList(
const std::string &key, const std::vector<std::string> &names) {
CHECK(!impl_->attr_names.count(key))
<< "kernel [" << key << "] is registered twice in attribute names";
impl_->attr_names.emplace(
key, llvm::SmallVector<std::string, 4>(names.begin(), names.end()));
}
KernelImplementation KernelRegistry::GetKernel(const std::string &key) const {
auto it = impl_->data.find(key);
return it != impl_->data.end() ? it->second : KernelImplementation{};
}
std::vector<std::string> KernelRegistry::GetKernelList() const {
std::vector<std::string> res(impl_->data.size());
for (auto i : impl_->data) {
res.push_back(i.first);
}
return res;
}
KernelRegistry::~KernelRegistry() {}
size_t KernelRegistry::size() const { return impl_->data.size(); }
KernelRegistry *GetCpuKernelRegistry() {
static auto registry = std::make_unique<KernelRegistry>();
return registry.get();
}
} // namespace host_context
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
namespace infrt {
namespace host_context {
class KernelFrame;
using KernelImplementation = void (*)(KernelFrame *frame);
/**
* Hold the kernels registered in the system.
*/
class KernelRegistry {
public:
KernelRegistry();
void AddKernel(const std::string &key, KernelImplementation fn);
void AddKernelAttrNameList(const std::string &key,
const std::vector<std::string> &names);
KernelImplementation GetKernel(const std::string &key) const;
std::vector<std::string> GetKernelList() const;
size_t size() const;
~KernelRegistry();
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
//! The global CPU kernel registry.
KernelRegistry *GetCpuKernelRegistry();
} // namespace host_context
} // namespace infrt
/**
* compile function RegisterKernels in C way to avoid C++ name mangling.
*/
#ifdef __cplusplus
extern "C" {
#endif
void RegisterKernels(infrt::host_context::KernelRegistry *registry);
#ifdef __cplusplus
}
#endif
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/host_context/kernel_registry.h"
#include <gtest/gtest.h>
#include "paddle/infrt/host_context/kernel_utils.h"
namespace infrt::host_context {
int add_i32(int a, int b) { return a + b; }
TEST(KernelRegistry, basic) {
KernelRegistry registry;
std::string key = "infrt.test.add.i32";
registry.AddKernel(key, INFRT_KERNEL(add_i32));
auto* kernel_impl = registry.GetKernel(key);
ASSERT_TRUE(kernel_impl);
ValueRef a(1);
ValueRef b(2);
KernelFrameBuilder fbuilder;
fbuilder.AddArgument(a.get());
fbuilder.AddArgument(b.get());
fbuilder.SetNumResults(1);
kernel_impl(&fbuilder);
auto results = fbuilder.GetResults();
ASSERT_EQ(results.size(), 1UL);
ASSERT_EQ(results[0]->get<int>(), 3);
}
} // namespace infrt::host_context
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/host_context/kernel_utils.h"
namespace infrt {
namespace host_context {} // namespace host_context
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <llvm/ADT/ArrayRef.h>
#include <utility>
#include "paddle/infrt/host_context/kernel_frame.h"
#include "paddle/infrt/host_context/value.h"
namespace infrt {
namespace host_context {
template <typename T>
class Argument {
public:
explicit Argument(ValueRef value) : value_(value) {}
ValueRef& value() { return value_; }
const ValueRef& value() const { return value_; }
T& get() const { return value_.get<T>(); }
private:
ValueRef value_;
};
/**
* RemainingArguments collects all remaining arguments in an ArrayRef.
*/
class RemainingArguments {
public:
explicit RemainingArguments(llvm::ArrayRef<Value*> remaining_arguments)
: remaining_arguments_(remaining_arguments) {}
llvm::ArrayRef<Value*> values() const { return remaining_arguments_; }
size_t size() const { return remaining_arguments_.size(); }
const Value* operator[](size_t i) const { return remaining_arguments_[i]; }
private:
llvm::ArrayRef<Value*> remaining_arguments_;
};
/**
* RemainingResults collects all remaining results in a MutableArrayRef.
*/
class RemainingResults {
public:
explicit RemainingResults(llvm::MutableArrayRef<ValueRef> remaining_results)
: remaining_results_(remaining_results) {}
llvm::MutableArrayRef<ValueRef> values() { return remaining_results_; }
size_t size() const { return remaining_results_.size(); }
template <typename T>
const ValueRef& AllocateAt(int index) {
// eagerly create a ValueRef
if (remaining_results_[index].get()) return remaining_results_[index];
remaining_results_[index] = ValueRef(new Value);
return remaining_results_[index];
}
ValueRef& operator[](size_t i) const { return remaining_results_[i]; }
private:
llvm::MutableArrayRef<ValueRef> remaining_results_;
};
template <typename T>
class Result {
public:
explicit Result(ValueRef* result) : result_(result) {}
template <typename... Args>
void Emplace(Args&&... args) {
ValueRef v;
Set(T(std::forward<Args>(args)...));
}
void Set(Argument<T> argument) {
CHECK(!result_->IsValid());
*result_ = argument.value();
}
private:
ValueRef* result_{};
};
template <typename T>
class Attribute {
public:
explicit Attribute(const Value* value) : value_(value) {}
const T& get() const { return value_->get<T>(); }
private:
const Value* value_;
};
template <typename ViewT>
class ArgumentView {
using UnderlyingT = typename ViewT::UnderlyingT;
public:
explicit ArgumentView(Value* value)
: value_(value), arg_(&value->template get<UnderlyingT>()) {}
Value* value() const { return value_; }
ViewT& get() const { return arg_; }
ViewT* operator->() const { return &get(); }
ViewT& operator*() const { return get(); }
private:
Value* value_{};
mutable ViewT arg_;
};
template <typename F, F f>
struct KernelImpl;
template <typename T>
struct TypeTag {};
#define INFRT_KERNEL(...) \
::infrt::host_context::KernelImpl<decltype(&__VA_ARGS__), \
&__VA_ARGS__>::Invoke
template <typename Return, typename... Args, Return (*impl_fn)(Args...)>
struct KernelImpl<Return (*)(Args...), impl_fn> {
static void Invoke(KernelFrame* frame) {
KernelCallHelper<Args..., TypeTag<int>>::template Invoke<0, 0, 0>(frame);
}
// Helper that introspects the arguments to derive the signature and cast
// parts of the KernelFrame to their type before passing them to impl_fn.
template <typename... RemainingArgs>
struct KernelCallHelper;
// Casts the return value of the kernel, if non-void.
// bool _ is an unnecessary parameter to make compiler allow templace specific
// in non-namespace scope.
template <typename T, bool _>
struct KernelReturnHelper {
static void Invoke(KernelFrame* frame, const Args&... args) {
HandleReturn(frame, impl_fn(args...));
}
};
template <bool _>
struct KernelReturnHelper<void, _> {
static void Invoke(KernelFrame* frame, const Args&... args) {
impl_fn(args...);
}
};
// Specialization to cast a single input argument(Head).
template <typename Head, typename... Tail>
struct KernelCallHelper<Argument<Head>, Tail...> {
template <int in_idx, int out_idx, int const_idx, typename... PreviousArgs>
static void Invoke(KernelFrame* frame, const PreviousArgs&... pargs) {
static_assert(in_idx != -1,
"Do not place Arguments after RemainingArguments");
static_assert(out_idx == 0, "Arguments should appear before results");
static_assert(const_idx == 0,
"Arguments and results should appear before attributes.");
Argument<Head> arg(frame->GetArgAt(in_idx));
KernelCallHelper<
Tail...>::template Invoke<in_idx + 1, out_idx, const_idx>(frame,
pargs...,
arg);
}
};
template <typename Head, typename... Tail>
struct KernelCallHelper<ArgumentView<Head>, Tail...> {
template <int in_idx, int out_idx, int const_idx, typename... PreviousArgs>
static void Invoke(KernelFrame* frame, const PreviousArgs&... pargs) {
static_assert(in_idx != -1,
"Do not place Arguments after RemainingArguments");
static_assert(out_idx == 0, "Arguments should appear before results");
static_assert(const_idx == 0,
"Arguments and results should appear before attributes.");
ArgumentView<Head> arg(frame->GetArgAt(in_idx));
KernelCallHelper<
Tail...>::template Invoke<in_idx + 1, out_idx, const_idx>(frame,
pargs...,
arg);
}
};
// Specialization to cast a single result argument (Head).
template <typename Head, typename... Tail>
struct KernelCallHelper<Result<Head>, Tail...> {
template <int in_idx, int out_idx, int const_idx, typename... PreviousArgs>
static void Invoke(KernelFrame* frame, const PreviousArgs&... pargs) {
static_assert(out_idx != -1,
"Do not place Results after RemainingResults");
static_assert(const_idx == 0,
"Arguments and results should appear before attributes");
Result<Head> arg(&frame->GetResults()[out_idx]);
KernelCallHelper<
Tail...>::template Invoke<in_idx, out_idx + 1, const_idx>(frame,
pargs...,
arg);
}
};
// Specialization to cast a single attribute.
template <typename Head, typename... Tail>
struct KernelCallHelper<Attribute<Head>, Tail...> {
template <int in_idx, int out_idx, int const_idx, typename... PreviousArgs>
static void Invoke(KernelFrame* frame, const PreviousArgs&... pargs) {
static_assert(const_idx != -1,
"Do not place Attributes after RemainingAttributes");
Attribute<Head> arg(frame->GetAttributeAt(const_idx));
KernelCallHelper<
Tail...>::template Invoke<in_idx, out_idx, const_idx + 1>(frame,
pargs...,
arg);
}
};
// Treat other pointer as an Argument.
template <typename Head, typename... Tail>
struct KernelCallHelper<Head*, Tail...> {
template <int in_idx, int out_idx, int const_idx, typename... PreviousArgs>
static void Invoke(KernelFrame* frame, const PreviousArgs&... pargs) {
static_assert(in_idx != -1,
"Do not place Arguments after RemainingArguments");
static_assert(out_idx == 0, "Arguments should appear before results");
static_assert(const_idx == 0,
"Arguments and results should appear before attributes.");
auto* arg = &frame->GetArgAt<Head>(in_idx);
KernelCallHelper<
Tail...>::template Invoke<in_idx + 1, out_idx, const_idx>(frame,
pargs...,
arg);
}
};
// Treat any other type as an Argument.
template <typename Head, typename... Tail>
struct KernelCallHelper<Head, Tail...> {
using ArgT = std::decay_t<Head>;
template <int in_idx, int out_idx, int const_idx, typename... PreviousArgs>
static void Invoke(KernelFrame* frame, const PreviousArgs&... pargs) {
static_assert(in_idx != -1,
"Do not place Arguments after RemainingArguments");
static_assert(out_idx == 0, "Arguments should appear before results");
static_assert(const_idx == 0,
"Arguments and results should appear before attributes.");
auto* value = frame->GetArgAt(in_idx);
auto&& arg = value->get<ArgT>();
KernelCallHelper<
Tail...>::template Invoke<in_idx + 1, out_idx, const_idx>(frame,
pargs...,
arg);
}
};
// RemainingArguments provides an ArrayRef<AsyncValue*> containing all
// remaining arguments. Useful for variadic
// kernels.
template <typename... Tail>
struct KernelCallHelper<RemainingArguments, Tail...> {
template <int in_idx, int out_idx, int const_idx, typename... PreviousArgs>
static void Invoke(KernelFrame* frame, const PreviousArgs&... pargs) {
static_assert(in_idx != -1,
"Do not use more than one RemainingArguments");
static_assert(out_idx == 0, "Arguments should appear before results.");
static_assert(const_idx == 0,
"Arguments and results should appear before attributes");
RemainingArguments remaining_arguments(
frame->GetArguments().drop_front(in_idx));
KernelCallHelper<Tail...>::template Invoke<-1, out_idx, const_idx>(
frame, pargs..., remaining_arguments);
}
};
// RemainingResults provides an MutableArrayRef<AsyncValue*> containing all
// remaining results.
template <typename... Tail>
struct KernelCallHelper<RemainingResults, Tail...> {
template <int in_idx, int out_idx, int const_idx, typename... PreviousArgs>
static void Invoke(KernelFrame* frame, const PreviousArgs&... pargs) {
static_assert(out_idx != -1, "Do not use more than one RemainingResults");
static_assert(const_idx == 0,
"Arguments and results should appear before attributes");
llvm::MutableArrayRef<Value*> returned_results =
frame->GetResults().drop_front(out_idx);
llvm::SmallVector<ValueRef, 4> result_values;
for (size_t i = 0; i < returned_results.size(); i++)
result_values.emplace_back(returned_results[i]);
RemainingResults remaining_results(result_values);
KernelCallHelper<Tail...>::template Invoke<in_idx, -1, const_idx>(
frame, pargs..., remaining_results);
}
};
// No arguments left.
template <typename T>
struct KernelCallHelper<TypeTag<T>> {
template <int in_idx, int out_idx, int const_idx, typename... PreviousArgs>
static void Invoke(KernelFrame* frame, const PreviousArgs&... pargs) {
KernelReturnHelper<Return, false>::Invoke(frame, pargs...);
}
};
// Handle pair result
template <typename T0, typename T1>
static void HandleReturn(KernelFrame* frame, std::pair<T0, T1>&& t) {
CHECK_EQ(frame->GetNumResults(), 2);
StoreResultAt(frame, 0, std::move(t.first));
StoreResultAt(frame, 1, std::move(t.second));
}
// Store the function result back to the output Value in KernelFrame.
template <typename T>
static void HandleReturn(KernelFrame* frame, T&& t) {
assert(frame->GetNumResults() == 1 && "Extra results passed to kernel.");
StoreResultAt(frame, 0, std::forward<T>(t));
}
// Store result as an Value output in KernelFrame.
template <typename T>
static void StoreResultAt(KernelFrame* frame, int index, T&& t) {
frame->EmplaceResult<std::decay_t<T>>(index, std::forward<T>(t));
}
};
} // namespace host_context
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/host_context/kernel_utils.h"
#include <gtest/gtest.h>
namespace infrt::host_context {
int add_i32(int a, int b) { return a + b; }
float add_f32(float a, float b) { return a + b; }
std::pair<int, float> add_pair(int a, float b) { return {a, b}; }
TEST(KernelImpl, i32) {
KernelFrameBuilder fbuilder;
ValueRef a(new Value(1));
ValueRef b(new Value(2));
fbuilder.AddArgument(a.get());
fbuilder.AddArgument(b.get());
fbuilder.SetNumResults(1);
INFRT_KERNEL(add_i32)(&fbuilder);
auto results = fbuilder.GetResults();
ASSERT_EQ(results.size(), 1UL);
ASSERT_EQ(results.front()->get<int>(), 3);
}
TEST(KernelImpl, f32) {
KernelFrameBuilder fbuilder;
ValueRef a(new Value(1.f));
ValueRef b(new Value(2.f));
fbuilder.AddArgument(a.get());
fbuilder.AddArgument(b.get());
fbuilder.SetNumResults(1);
INFRT_KERNEL(add_f32)(&fbuilder);
auto results = fbuilder.GetResults();
ASSERT_EQ(results.size(), 1UL);
ASSERT_EQ(results.front()->get<float>(), 3.f);
}
TEST(KernelImpl, pair) {
KernelFrameBuilder fbuilder;
ValueRef a(new Value(1));
ValueRef b(new Value(3.f));
fbuilder.AddArgument(a.get());
fbuilder.AddArgument(b.get());
fbuilder.SetNumResults(2);
INFRT_KERNEL(add_pair)(&fbuilder);
auto results = fbuilder.GetResults();
ASSERT_EQ(results.size(), 2UL);
ASSERT_EQ(results[0]->get<int>(), 1);
ASSERT_EQ(results[1]->get<float>(), 3.f);
}
} // namespace infrt::host_context
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <llvm/Support/CommandLine.h>
#include <iostream>
#include <string>
#include "llvm/Support/DynamicLibrary.h"
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/mlir_loader.h"
#include "paddle/infrt/host_context/core_runtime.h"
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/mlir_to_runtime_translate.h"
#include "paddle/infrt/kernel/basic_kernels.h"
#include "paddle/infrt/kernel/control_flow_kernels.h"
#include "paddle/infrt/kernel/tensor_kernels.h"
#include "paddle/infrt/kernel/tensor_shape_kernels.h"
#include "paddle/infrt/kernel/test_kernels.h"
static llvm::cl::list<std::string> cl_shared_libs( // NOLINT
"shared_libs",
llvm::cl::desc("Specify shared library with kernels."),
llvm::cl::ZeroOrMore,
llvm::cl::MiscFlags::CommaSeparated);
int main(int argc, char** argv) {
using namespace llvm; // NOLINT
using namespace infrt; // NOLINT
cl::opt<std::string> input_file("i",
cl::desc("Specify input filename"),
cl::value_desc("input file name"));
cl::ParseCommandLineOptions(argc, argv);
mlir::MLIRContext* context = infrt::Global::getMLIRContext();
auto module = dialect::LoadMlirFile(input_file.c_str(), context);
host_context::KernelRegistry registry;
kernel::RegisterBasicKernels(&registry);
kernel::RegisterTestKernels(&registry);
kernel::RegisterTensorShapeKernels(&registry);
kernel::RegisterTensorKernels(&registry);
kernel::RegisterControlFlowKernels(&registry);
// load extra shared library
for (const auto& lib_path : cl_shared_libs) {
std::string err;
llvm::sys::DynamicLibrary dynLib =
llvm::sys::DynamicLibrary::getPermanentLibrary(lib_path.c_str(), &err);
if (!dynLib.isValid()) {
llvm::errs() << "Load shared library failed. Error: " << err << "\n";
return 1;
}
if (auto reg_sym = dynLib.SearchForAddressOfSymbol("RegisterKernels")) {
auto reg_func =
reinterpret_cast<void (*)(host_context::KernelRegistry*)>(reg_sym);
reg_func(&registry);
} else {
llvm::outs() << "Symbol \"RegisterKernels\" not found in \"" << lib_path
<< "\". Skip.\n";
}
}
host_context::TestMlir(module.get(), &registry);
std::cout << std::endl;
return 0;
}
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/host_context/mlir_function_executable.h"
#include <glog/logging.h>
#include <string> // NOLINT
#include "paddle/infrt/common/common.h"
#include "paddle/infrt/host_context/core_runtime.h"
namespace infrt {
namespace host_context {
template <typename T>
std::string DumpToString(T& op) { // NOLINT
std::string buffer;
llvm::raw_string_ostream os(buffer);
op.print(os);
os.flush();
return buffer;
}
MlirFunctionExecutable::MlirFunctionExecutable(
mlir::FuncOp func_op,
KernelRegistry* kernel_registry,
MlirToRuntimeTranslator::function_defs_t& function_table)
: Function(func_op.getName().str(),
func_op.getNumArguments(),
func_op.getNumResults()),
MlirToRuntimeTranslator(&core_runtime_builder_),
region_(&func_op.getRegion()),
core_runtime_builder_(kernel_registry),
function_table_(function_table) {}
MlirFunctionExecutable::MlirFunctionExecutable(
mlir::Region* region,
mlir::FunctionType func_type,
KernelRegistry* kernel_registry,
MlirToRuntimeTranslator::function_defs_t& function_table)
: Function("", func_type.getNumInputs(), func_type.getNumResults()),
MlirToRuntimeTranslator(&core_runtime_builder_),
region_(region),
core_runtime_builder_(kernel_registry),
function_table_(function_table) {}
void MlirFunctionExecutable::BuildExecutables(
llvm::ArrayRef<Value*> arguments,
llvm::MutableArrayRef<ValueRef> results,
bool is_region) {
CHECK_EQ(arguments.size(), num_arguments());
// We use the function call's arguments as op_executable's operands to avoid
// copy.
for (size_t i = 0; i < num_arguments(); i++) {
AddValue(region_->getArgument(i), arguments[i]);
}
// build the program
auto& blocks = region_->getBlocks();
CHECK_EQ(blocks.size(), 1UL)
<< "function with more than one block is not supported yet";
llvm::SmallVector<Value*, 3> runtime_results;
for (auto& op : blocks.front()) {
if (EmitConstantOp(&op)) continue;
if (EmitBuildShapeOp(&op)) continue;
llvm::SmallVector<mlir::Value, 3> mlir_results;
if (EmitReturnOp(&op, &mlir_results)) {
if (!is_region) {
for (auto v : mlir_results) {
runtime_results.push_back(GetValue(v));
}
}
continue;
}
if (EmitCallOp(&op, &function_table_)) continue;
if (EmitGeneralOp(&op)) continue;
LOG(FATAL) << "Not supported op: " << DumpToString(op);
}
// after the block is built, we can get the result values of the whole
// function call in the runtime_results.
mlir::SmallVector<Value*, 3> results_copied;
if (!is_region) {
for (ValueRef& x : results) {
results_copied.push_back(x.get());
}
}
// set a lambda function to help copy the results from the runtime results in
// the local function to outer program.
CHECK_EQ(results_copied.size(), runtime_results.size());
this->copy_res_fn_ = [results_copied, runtime_results] {
VLOG(4) << "copy results to result";
for (size_t i = 0; i < results_copied.size(); i++) {
VLOG(4) << ".. copy " << runtime_results[i] << " to "
<< results_copied[i];
CopyTo(*runtime_results[i], results_copied[i]);
}
};
}
void MlirFunctionExecutable::Execute(llvm::ArrayRef<Value*> arguments,
llvm::MutableArrayRef<ValueRef> results,
bool is_region) const {
CHECK_EQ(arguments.size(), num_arguments());
CHECK_EQ(results.size(), num_results());
if (core_runtime_builder_.num_ops() == 0) {
Reference(this).BuildExecutables(arguments, results, is_region);
}
Reference(&core_runtime_builder_).Execute();
copy_res_fn_();
}
} // namespace host_context
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mlir/IR/Function.h>
#include <string>
#include <unordered_map>
#include "paddle/infrt/host_context/core_runtime.h"
#include "paddle/infrt/host_context/function.h"
#include "paddle/infrt/host_context/mlir_to_runtime_translate.h"
namespace infrt {
namespace host_context {
struct KernelRegistry;
/**
* Executable function for a given MLIR function definition, mainly used in two
* scenerios:
* 1. infrt.call op
* 2. main function call
*
* A MlirFunctionExecutable might have one or more arguments and results.
*/
class MlirFunctionExecutable : public Function, public MlirToRuntimeTranslator {
public:
using function_defs_t = std::unordered_map<std::string, mlir::FuncOp>;
MlirFunctionExecutable(mlir::FuncOp func_op,
KernelRegistry* kernel_registry,
function_defs_t& function_table); // NOLINT
MlirFunctionExecutable(
mlir::Region* region,
mlir::FunctionType func_type,
KernelRegistry* kernel_registry,
MlirToRuntimeTranslator::function_defs_t& function_table); // NOLINT
/**
* Execute the function with the given arguments and results.
* NOTE the \param arguments and \param results should not be altered.
*/
void Execute(llvm::ArrayRef<Value*> arguments,
llvm::MutableArrayRef<ValueRef> results,
bool is_region = false) const;
private:
/**
* Build the runtime executables once the function call arguments and results
* are passed in.
* This will trigger in the first execution.
*/
void BuildExecutables(llvm::ArrayRef<Value*> arguments,
llvm::MutableArrayRef<ValueRef> results,
bool is_region);
private:
mlir::Region* region_{};
CoreRuntimeBuilder core_runtime_builder_;
MlirToRuntimeTranslator::function_defs_t& function_table_;
std::function<void()> copy_res_fn_;
};
} // namespace host_context
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/host_context/mlir_program_executor.h"
namespace infrt {
namespace host_context {} // namespace host_context
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/Diagnostics.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/Module.h>
#include <mlir/IR/OperationSupport.h>
#include <unordered_map>
#include <memory>
#include <string>
#include "paddle/infrt/host_context/core_runtime.h"
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/mlir_function_executable.h"
#include "paddle/infrt/host_context/mlir_to_runtime_translate.h"
#include "paddle/infrt/host_context/op_executable.h"
namespace infrt {
namespace host_context {
/**
* This get a MLIR program as input, it compiles it into runtime program, and
* one can retrieve the function and execute
* it by passing the input arguments.
*/
class MlirProgramExecutor : public MlirToRuntimeTranslator {
public:
CoreRuntimeBuilder runtime_builder;
mlir::ModuleOp module;
function_defs_t function_defs;
MlirProgramExecutor(mlir::ModuleOp module, KernelRegistry* registry)
: MlirToRuntimeTranslator(module, &runtime_builder),
runtime_builder(registry),
module(module) {}
// Build functions and generate executables.
void BuildFunctions() { EmitFunctions(); }
void EmitFunction(mlir::FuncOp op) override {
LOG(INFO) << "Emit function: " << op.getName().str();
function_defs[op.getName().str()] = op;
func_executables_.emplace(
op.getName().str(),
new MlirFunctionExecutable(
op, runtime_builder.kernel_registry(), function_defs));
}
MlirFunctionExecutable* LookupFunc(const std::string& name) {
auto it = func_executables_.find(name);
if (it != func_executables_.end()) {
return it->second.get();
}
return nullptr;
}
private:
std::unordered_map<std::string, std::unique_ptr<MlirFunctionExecutable>>
func_executables_;
};
} // namespace host_context
} // namespace infrt
// CHECK-LABEL: basic
func @basic() -> f32 {
%v0 = infrt.constant.f32 1.0
%v1 = infrt.constant.f32 2.0
%v2 = "infrt.add.f32"(%v0, %v1) : (f32, f32) -> f32
// CHECK: 1
"infrt.print.f32"(%v0) : (f32) -> ()
// CHECK: 2
"infrt.print.f32"(%v1) : (f32) -> ()
// CHECK: 3
"infrt.print.f32"(%v2) : (f32) -> ()
%v3 = "infrt.mul.f32"(%v2, %v1) : (f32, f32) -> f32
// CHECK: 6
"infrt.print.f32"(%v3) : (f32) -> ()
infrt.return %v3 : f32
}
// CHECK-LABEL: basic1
// Check the mlir executor can work with more than one function in a file.
func @basic1() -> () {
%v0 = infrt.constant.f32 1.0
"infrt.print.f32"(%v0) : (f32) -> ()
// CHECK: 1
infrt.return
}
\ No newline at end of file
// CHECK-LABEL: build_tensor1
func @build_tensor1() {
%a = dt.create_uninit_tensor.f32 [3, 4] -> !infrt.tensor<X86, NCHW, F32>
dt.fill_tensor_with_constant.f32 (%a : !infrt.tensor<X86, NCHW, F32>) {value=1.0:f32}
// CHECK: tensor: shape=shape[3,4], values=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
dt.print_tensor (%a : !infrt.tensor<X86, NCHW, F32>)
infrt.return
}
// CHECK-LABEL: build_tensor1
func @build_tensor1() {
%a = ts.build_shape [1:i64, 57:i64, 92:i64]
// CHECK: shape[1,57,92]
ts.print_shape %a
infrt.return
}
\ No newline at end of file
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/host_context/mlir_to_runtime_translate.h"
#include <llvm/Support/SourceMgr.h>
#include <mlir/Dialect/StandardOps/IR/Ops.h>
#include <mlir/IR/Diagnostics.h>
#include <mlir/IR/Function.h>
#include <mlir/IR/OperationSupport.h>
#include <mlir/Parser.h>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "boost/optional.hpp"
#include "paddle/infrt/common/string.h"
#include "paddle/infrt/dialect/mlir_loader.h"
#include "paddle/infrt/dialect/tensor_shape.h"
#include "paddle/infrt/host_context/core_runtime.h"
#include "paddle/infrt/host_context/kernel_frame.h"
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/mlir_function_executable.h"
#include "paddle/infrt/host_context/op_executable.h"
#include "paddle/infrt/host_context/value.h"
#include "paddle/infrt/tensor/tensor_shape.h"
namespace infrt::host_context {
template <typename T>
std::string DumpToString(T& op) { // NOLINT
std::string buffer;
llvm::raw_string_ostream os(buffer);
op.print(os);
os.flush();
return buffer;
}
struct MlirToRuntimeTranslator::Impl {
mlir::ModuleOp module;
// The runtime for a function call.
CoreRuntimeBuilder* runtime{};
// The current working op, the translator process the ops one by one, each
// time it updates `cur_op` here to current op
// working on.
OpExecutableBuilder* cur_op{};
// record the current function name.
std::string cur_func_name;
// Name to function definitions.
std::unordered_map<std::string, mlir::FuncOp> func_defs;
// Map from an operation to its results.
std::unordered_map<const mlir::Operation*, std::vector<ValueRef>> op_results;
llvm::DenseMap<mlir::Value, ValueRef> value_map;
};
bool MlirToRuntimeTranslator::EmitConstantOp(mlir::Operation* op) {
if (!infrt::Startswith(op->getName().getStringRef().str(), "infrt.constant"))
return false;
VLOG(3) << "Emitting constant op [" << op->getName().getStringRef().str()
<< "]";
auto attr = op->getAttr("value");
if (attr.isa<mlir::FloatAttr>()) {
if (attr.getType().isF32()) {
impl_->op_results[op] = {ValueRef(
static_cast<float>(attr.cast<mlir::FloatAttr>().getValueAsDouble()))};
} else if (attr.getType().isF64()) {
impl_->op_results[op] = {ValueRef(static_cast<double>(
attr.cast<mlir::FloatAttr>().getValueAsDouble()))};
} else {
LOG(FATAL) << "Not supported attribute type";
}
return true;
}
if (attr.isa<mlir::IntegerAttr>()) {
if (attr.getType().isInteger(32)) {
impl_->op_results[op] = {ValueRef(
static_cast<int32_t>(attr.cast<mlir::IntegerAttr>().getSInt()))};
} else if (attr.getType().isInteger(64)) {
impl_->op_results[op] = {ValueRef(
static_cast<int64_t>(attr.cast<mlir::IntegerAttr>().getSInt()))};
} else if (attr.getType().isInteger(1)) {
impl_->op_results[op] = {
ValueRef(static_cast<bool>(attr.cast<mlir::IntegerAttr>().getInt()))};
} else {
LOG(FATAL) << "Not supported attribute type";
}
return true;
}
LOG(FATAL) << "Not supported constant attribute type";
return true;
}
template <>
boost::optional<int32_t> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::IntegerAttr>()) return boost::none;
if (attr->isa<mlir::IntegerAttr>()) {
auto val = attr->cast<mlir::IntegerAttr>();
if (val.getType().isInteger(32)) {
return val.getInt();
}
}
return boost::none;
}
template <>
boost::optional<int64_t> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::IntegerAttr>()) return boost::none;
if (attr->isa<mlir::IntegerAttr>()) {
auto val = attr->cast<mlir::IntegerAttr>();
if (val.getType().isInteger(64)) {
return val.getInt();
}
}
return boost::none;
}
// TODO(Superjomn) Make double and float parsing share some thing.
template <>
boost::optional<float> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::FloatAttr>()) return boost::none;
if (attr->isa<mlir::FloatAttr>()) {
auto val = attr->cast<mlir::FloatAttr>();
if (val.getType().isF32()) return val.getValueAsDouble();
}
return boost::none;
}
template <>
boost::optional<double> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::FloatAttr>()) return boost::none;
if (attr->isa<mlir::FloatAttr>()) {
auto val = attr->cast<mlir::FloatAttr>();
if (val.getType().isF64()) return val.getValueAsDouble();
}
return boost::none;
}
template <>
boost::optional<std::string> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::StringAttr>()) return boost::none;
return attr->cast<mlir::StringAttr>().getValue().str();
}
#define PROCESS_ARRAY_INT(type__, bits__) \
template <> \
boost::optional<std::vector<type__>> MlirToRuntimeTranslator::EmitAttribute( \
const mlir::Attribute* attr) { \
if (!attr->isa<mlir::ArrayAttr>()) return boost::none; \
auto array = attr->cast<mlir::ArrayAttr>(); \
CHECK(!array.empty()); \
\
if (!array[0].getType().isInteger(bits__)) { \
return boost::none; \
} \
\
std::vector<type__> res; \
for (auto& v : array) { \
res.push_back(v.cast<mlir::IntegerAttr>().getInt()); \
} \
return res; \
}
PROCESS_ARRAY_INT(int16_t, 16);
PROCESS_ARRAY_INT(int32_t, 32);
PROCESS_ARRAY_INT(int64_t, 64);
template <>
boost::optional<std::vector<float>> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::ArrayAttr>()) return boost::none;
auto array = attr->cast<mlir::ArrayAttr>();
CHECK(!array.empty());
if (!array[0].getType().isF32()) return boost::none;
std::vector<float> res;
for (auto& v : array) {
res.push_back(v.cast<mlir::FloatAttr>().getValueAsDouble());
}
return res;
}
template <>
boost::optional<std::vector<double>> MlirToRuntimeTranslator::EmitAttribute(
const mlir::Attribute* attr) {
if (!attr->isa<mlir::ArrayAttr>()) return boost::none;
auto array = attr->cast<mlir::ArrayAttr>();
CHECK(!array.empty());
if (!array[0].getType().isF64()) return boost::none;
std::vector<double> res;
for (auto& v : array) {
res.push_back(v.cast<mlir::FloatAttr>().getValueAsDouble());
}
return res;
}
static bool IsReturn(mlir::Operation* op) {
return op->getName().getStringRef() == "infrt.return";
}
bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
CHECK(impl_->runtime);
impl_->cur_op =
impl_->runtime->NewOpExecutable(op->getName().getStringRef().str());
VLOG(3) << "processing general op : " << op->getName().getStringRef().str();
// process operands
for (int i = 0, e = op->getNumOperands(); i < e; i++) {
// function argument as value
auto operand = op->getOperand(i);
if (operand.getKind() == mlir::Value::Kind::BlockArgument) {
mlir::BlockArgument arg = operand.dyn_cast<mlir::BlockArgument>();
Value* arg_value = GetValue(arg);
impl_->cur_op->AppendArgument(arg_value);
VLOG(3) << "* op mlir operand: " << DumpToString(arg) << " "
<< GetValue(arg);
continue;
}
// normal value
Value* arg_value = GetValue(operand);
if (!arg_value) {
auto upstream_op = operand.getDefiningOp();
arg_value = GetOpResult(upstream_op);
}
CHECK(arg_value) << "No-exist argument value found: "
<< DumpToString(operand);
impl_->cur_op->AppendArgument(arg_value);
VLOG(3) << "* op mlir operand: " << DumpToString(operand) << " "
<< GetValue(operand) << " vs " << arg_value;
}
// process results
llvm::SmallVector<Value*, 4> res_values;
for (int i = 0, e = op->getNumResults(); i < e; i++) {
auto res = op->getResult(i);
res_values.push_back(AddValue(res));
VLOG(3) << "* op mlir res: " << DumpToString(res) << " " << GetValue(res);
}
impl_->cur_op->SetResults(res_values);
#ifdef INFRT_DEBUG
{
VLOG(3) << "check result";
for (int i = 0; i < impl_->cur_op->frame().GetNumResults(); i++) {
VLOG(3) << "+ res value: " << impl_->cur_op->frame().GetResults()[i];
}
}
#endif
// process attributes
auto attrs = op->getAttrs();
for (size_t i = 0; i < attrs.size(); i++) {
auto& attr = attrs[i];
if (auto v = EmitAttribute<int32_t>(&attr.second)) {
impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<int64_t>(&attr.second)) {
impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<float>(&attr.second)) {
impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<double>(&attr.second)) {
impl_->cur_op->AppendAttribute(new Value(*v));
} else if (auto v = EmitAttribute<std::string>(&attr.second)) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<int16_t>>(&attr.second)) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<int32_t>>(&attr.second)) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<int64_t>>(&attr.second)) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<float>>(&attr.second)) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else if (auto v = EmitAttribute<std::vector<double>>(&attr.second)) {
impl_->cur_op->AppendAttribute(new Value(std::move(*v)));
} else {
LOG(FATAL) << "Not supported attribute type";
}
}
// process regions, we treat regions as attribute.
auto num_regions = op->getNumRegions();
if (num_regions > 0) {
CHECK_EQ(num_regions, 1UL)
<< "op with more than one region is not supported yet.";
auto& region = op->getRegions().front();
auto num_blocks = region.getBlocks().size();
CHECK_EQ(num_blocks, 1UL)
<< "region with more than one block is not supported yet.";
// process arguments
llvm::SmallVector<mlir::Type, 4> inputs;
auto& block = region.getBlocks().front();
for (auto arg : block.getArguments()) inputs.push_back(arg.getType());
// process results
// NOTE: if an op contains a region, we simply ignore the region's return
// values,
// or its return values will conflict with op's return values.
llvm::SmallVector<mlir::Type, 0> results;
auto func_type =
mlir::FunctionType::get(inputs, results, region.getContext());
auto* function = impl_->cur_op->CreateFunctionExecutable(
&region, func_type, &impl_->func_defs);
impl_->cur_op->AppendAttribute(new Value(function));
}
return true;
}
bool MlirToRuntimeTranslator::EmitReturnOp(
mlir::Operation* op, llvm::SmallVectorImpl<mlir::Value>* results) {
CHECK(results);
if (op->getName().getStringRef() == "infrt.return") {
for (size_t i = 0; i < op->getNumOperands(); i++) {
results->push_back(op->getOperand(i));
}
return true;
}
return false;
}
bool MlirToRuntimeTranslator::EmitFunctions() {
for (auto func_op : impl_->module.getOps<mlir::FuncOp>()) {
EmitFunction(func_op);
}
return true;
}
void MlirToRuntimeTranslator::EmitFunction(mlir::FuncOp op) {
impl_->func_defs[op.getName().str()] = op;
}
Value* MlirToRuntimeTranslator::GetOpResult(mlir::Operation* op) {
auto it = impl_->op_results.find(op);
return it == impl_->op_results.end() ? nullptr : it->second.front().get();
}
Value* MlirToRuntimeTranslator::GetValue(mlir::Value value) {
auto it = impl_->value_map.find(value);
return it == impl_->value_map.end() ? nullptr : it->second.get();
}
Value* MlirToRuntimeTranslator::AddValue(mlir::Value value) {
auto res = impl_->value_map.try_emplace(value, ValueRef(new Value));
CHECK(res.second) << "Duplicate add mlir value [" << DumpToString(value)
<< "]";
return res.first->second.get();
}
MlirToRuntimeTranslator::~MlirToRuntimeTranslator() {}
void MlirToRuntimeTranslator::UpdateCurFuncName(const std::string& name) {
impl_->cur_func_name = std::string(name);
}
MlirToRuntimeTranslator::MlirToRuntimeTranslator(mlir::ModuleOp module,
CoreRuntimeBuilder* runtime)
: impl_(new Impl) {
CHECK(runtime);
impl_->module = module;
impl_->runtime = runtime;
}
bool MlirToRuntimeTranslator::EmitBuildShapeOp(mlir::Operation* op) {
if (op->getName().getStringRef() != "ts.build_shape") return false;
auto value = op->getAttr("value");
CHECK(value.isa<mlir::ArrayAttr>());
auto values = value.cast<mlir::ArrayAttr>().getValue();
std::vector<int64_t> dims;
for (auto& attr_v : values) {
dims.push_back(attr_v.cast<mlir::IntegerAttr>().getInt());
}
impl_->op_results[op] = {
ValueRef(new Value(tensor::TensorShape(llvm::ArrayRef<int64_t>(dims))))};
return true;
}
bool MlirToRuntimeTranslator::EmitCallOp(mlir::Operation* op,
function_defs_t* function_table) {
CHECK(op);
CHECK(function_table);
if (op->getName().getStringRef() != "infrt.call") return false;
impl_->cur_op =
impl_->runtime->NewOpExecutable(op->getName().getStringRef().str());
auto callee = op->getAttr("callee");
auto callee_name = callee.dyn_cast<mlir::FlatSymbolRefAttr>();
// process arguments
for (size_t i = 0; i < op->getNumOperands(); i++) {
auto operand = op->getOperand(i);
auto* arg_value = GetValue(operand);
if (!arg_value) {
auto upstream_op = operand.getDefiningOp();
arg_value = GetOpResult(upstream_op);
}
CHECK(arg_value) << "No-exist argument value found: "
<< DumpToString(operand);
impl_->cur_op->AppendArgument(arg_value);
}
// process results
llvm::SmallVector<Value*, 4> res_values;
for (int i = 0, e = op->getNumResults(); i < e; i++) {
auto res = op->getResult(i);
res_values.push_back(AddValue(res));
}
impl_->cur_op->SetResults(res_values);
// process attribute
auto& table = function_table ? *function_table : impl_->func_defs;
{
// lookup the callee function
auto it = table.find(callee_name.getValue().str());
CHECK(it != table.end()) << "can't find function ["
<< callee_name.getValue().str() << "]";
auto* function =
impl_->cur_op->CreateFunctionExecutable(it->second, &impl_->func_defs);
impl_->cur_op->AppendAttribute(new Value(function));
}
VLOG(3) << "Emit call " << callee_name.getValue().str() << " "
<< impl_->cur_op->frame();
return true;
}
MlirToRuntimeTranslator::MlirToRuntimeTranslator(CoreRuntimeBuilder* runtime)
: impl_(new Impl) {
CHECK(runtime);
impl_->runtime = runtime;
}
Value* MlirToRuntimeTranslator::AddValue(mlir::Value mlir_value, Value* value) {
auto it = impl_->value_map.try_emplace(mlir_value, ValueRef(value));
CHECK(it.second) << "duplicate add value " << DumpToString(mlir_value);
return value;
}
void MlirToRuntimeTranslate(mlir::ModuleOp module,
CoreRuntimeBuilder* runtime) {
MlirToRuntimeTranslator(module, runtime).Run();
}
/**
* Execute the mlir program in test mode -- print some debug information to
* stdout.
*/
class MlirProgramTestExecutor : public MlirToRuntimeTranslator {
public:
CoreRuntimeBuilder core_runtime;
MlirProgramTestExecutor(mlir::ModuleOp module, KernelRegistry* registry)
: MlirToRuntimeTranslator(module, &core_runtime),
core_runtime(registry),
registry(registry) {
CHECK(registry);
}
void Run() {
EmitFunctions();
CHECK(registry);
for (auto func_op : impl_->module.getOps<mlir::FuncOp>()) {
VLOG(3) << "Running function " << func_op.getName().str();
EmitAndRunFuncWithoutArguments(func_op);
}
}
protected:
std::unordered_map<std::string, mlir::FuncOp> func_def_table;
void EmitFunction(mlir::FuncOp op) override {
CHECK(!impl_->func_defs.count(op.getName().str()))
<< "Duplicate function defition found for function ["
<< op.getName().str();
impl_->func_defs.emplace(op.getName().str(), op);
}
private:
void EmitAndRunFuncWithoutArguments(mlir::FuncOp func) {
// print the function name for llvm FileChecker macro, CHECK-LABEL
std::cout << '@' << func.getName().str() << std::endl;
if (func.getNumArguments() ==
0) { // an entry function, execute it immediately
VLOG(3) << "executing function " << func.getName().str();
// Emit and execute each function
CoreRuntimeBuilder runtime(registry);
impl_->runtime = &runtime;
auto& blocks = func.getBlocks();
CHECK_EQ(blocks.size(), 1UL)
<< "function with more than one block is not supported yet";
for (auto& op : blocks.front()) {
if (EmitConstantOp(&op)) continue;
if (EmitBuildShapeOp(&op)) continue;
llvm::SmallVector<mlir::Value, 3> results;
if (EmitReturnOp(&op, &results)) continue;
if (EmitCallOp(&op, &impl_->func_defs)) continue;
if (EmitGeneralOp(&op)) continue;
LOG(FATAL) << "Not supported op: " << DumpToString(op);
}
runtime.Execute();
} else {
VLOG(2) << "get an callable function: " << func.getName().str();
}
}
private:
KernelRegistry* registry{};
};
void TestMlir(mlir::ModuleOp module, KernelRegistry* registry) {
MlirProgramTestExecutor execute(module, registry);
execute.Run();
}
} // namespace infrt::host_context
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <llvm/ADT/SmallVector.h>
#include <boost/optional.hpp>
#include <memory> // NOLINT
#include <string> //NOLINT
#include <unordered_map> // NOLINT
namespace mlir {
class FuncOp;
class ModuleOp;
class Operation;
class Attribute;
class Value;
} // namespace mlir
namespace infrt::host_context {
class CoreRuntimeBuilder;
class Value;
class ValueRef;
class KernelRegistry;
/**
* MlirToRuntimeTranslator helps to translate a MLIR program to a CoreRuntime.
* This is the base class of all the modules those parse a MLIR program and
* finally generate a CoreRuntime.
*/
class MlirToRuntimeTranslator {
public:
//! Holds all the function definitions.
using function_defs_t = std::unordered_map<std::string, mlir::FuncOp>;
explicit MlirToRuntimeTranslator(CoreRuntimeBuilder* runtime);
MlirToRuntimeTranslator(mlir::ModuleOp module, CoreRuntimeBuilder* runtime);
void Run() { EmitFunctions(); }
virtual ~MlirToRuntimeTranslator();
protected:
//! Emit a "infrt.constant.*" operation, return true if succeed.
bool EmitConstantOp(mlir::Operation* op);
//! Emit a "infrt.return" operation.
bool EmitReturnOp(mlir::Operation* op,
llvm::SmallVectorImpl<mlir::Value>* results);
//! Emit a "ts.build_shape" operation.
bool EmitBuildShapeOp(mlir::Operation* op);
//! Emit an operation other than the special cases above.
bool EmitGeneralOp(mlir::Operation* op);
//! Emit all the functions.
bool EmitFunctions();
//! Emit a single function, this is an API that should be implemented by
//! inherients.
virtual void EmitFunction(mlir::FuncOp op);
bool EmitCallOp(mlir::Operation* op, function_defs_t* function_table);
template <typename T>
boost::optional<T> EmitAttribute(const mlir::Attribute* attr);
Value* GetOpResult(mlir::Operation* op);
Value* GetValue(mlir::Value value);
Value* AddValue(mlir::Value value);
Value* AddValue(mlir::Value mlir_value, Value* value);
void UpdateCurFuncName(const std::string& name);
protected:
struct Impl;
std::unique_ptr<Impl> impl_;
};
/**
* Build a CoreRuntime from a MLIR module.
*/
void MlirToRuntimeTranslate(mlir::ModuleOp module, CoreRuntimeBuilder* runtime);
/**
* Execute a MLIR program, that is execute all the functions without input
* arguments.
* This is mainly used by testcase.
* @param module a MLIR module.
* @param registry the kernel registry containing all the valid kernels.
*/
void TestMlir(mlir::ModuleOp module, KernelRegistry* registry);
} // namespace infrt::host_context
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/host_context/mlir_to_runtime_translate.h"
#include <gtest/gtest.h>
#include <llvm/Support/FormatVariadic.h>
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/dialect/mlir_loader.h"
#include "paddle/infrt/host_context/core_runtime.h"
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/host_context/mlir_program_executor.h"
#include "paddle/infrt/kernel/basic_kernels.h"
#include "paddle/infrt/kernel/control_flow_kernels.h"
#include "paddle/infrt/kernel/tensor_kernels.h"
#include "paddle/infrt/kernel/tensor_shape_kernels.h"
#include "paddle/infrt/kernel/test_kernels.h"
namespace infrt::host_context {
TEST(MlirToRuntimeTranslate, basic) {
mlir::MLIRContext context;
auto source = R"ROC(
func @main() -> () {
%v0 = infrt.constant.f32 1.0
%v1 = infrt.constant.f32 2.0
%v2 = "infrt.add.f32"(%v0, %v1) : (f32, f32) -> f32
%v3 = "infrt.mul.f32"(%v2, %v1) : (f32, f32) -> f32
"infrt.print.f32"(%v1) : (f32) -> ()
infrt.return
}
)ROC";
auto module = dialect::LoadMlirSource(&context, source);
module->verify();
KernelRegistry registry;
kernel::RegisterFloatBasicKernels(&registry);
kernel::RegisterIntBasicKernels(&registry);
TestMlir(module.get(), &registry);
}
TEST(TestMlir, basic) {
mlir::MLIRContext context;
auto source = R"ROC(
func @main() -> () {
%v0 = infrt.constant.f32 1.0
%v1 = infrt.constant.f32 2.0
%v2 = "infrt.add.f32"(%v0, %v1) : (f32, f32) -> f32
%v3 = "infrt.mul.f32"(%v2, %v1) : (f32, f32) -> f32
"infrt.print.f32"(%v1) : (f32) -> ()
infrt.return
}
)ROC";
auto module = dialect::LoadMlirSource(&context, source);
module->verify();
KernelRegistry registry;
kernel::RegisterFloatBasicKernels(&registry);
kernel::RegisterIntBasicKernels(&registry);
TestMlir(module.get(), &registry);
}
TEST(TestMlir, shadow_copy_tensor_profile) {
mlir::MLIRContext* context = infrt::Global::getMLIRContext();
auto head = R"ROC(
func @predict(%a: !infrt.tensor<X86, NCHW, F32>, %b: !infrt.tensor<X86, NCHW, F32>) -> (!infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>) {
)ROC";
auto tpl0 =
"%a{0} = dt.shallow_copy_tensor %a : !infrt.tensor<X86, NCHW, F32> -> "
"!infrt.tensor<X86, NCHW, F32>";
auto tpl1 =
"%b{0} = dt.shallow_copy_tensor %b : !infrt.tensor<X86, NCHW, F32> -> "
"!infrt.tensor<X86, NCHW, F32>";
auto end = R"ROC(
infrt.return %a0, %b0: !infrt.tensor<X86, NCHW, F32>, !infrt.tensor<X86, NCHW, F32>
}
)ROC";
std::stringstream ss;
ss << head;
for (int i = 0; i < 2000; i++) {
ss << llvm::formatv(tpl0, i).str() << "\n";
ss << llvm::formatv(tpl1, i).str() << "\n";
}
ss << end;
auto content = ss.str();
// LOG(INFO) << "content: " << content << std::endl;
auto module = dialect::LoadMlirSource(context, content);
module->verify();
host_context::KernelRegistry registry;
kernel::RegisterBasicKernels(&registry);
kernel::RegisterTestKernels(&registry);
kernel::RegisterTensorShapeKernels(&registry);
kernel::RegisterTensorKernels(&registry);
kernel::RegisterControlFlowKernels(&registry);
MlirProgramExecutor executor(*module, &registry);
executor.BuildFunctions();
auto* func = executor.LookupFunc("predict");
ASSERT_TRUE(func);
std::vector<Value*> in_args;
std::vector<ValueRef> out_args(
{ValueRef(new Value(tensor::DenseHostTensor())),
ValueRef(new Value(tensor::DenseHostTensor()))});
auto create_tensor = [] {
tensor::DenseHostTensor a(tensor::TensorShape{{200, 3000}},
DType(DType::Kind::F32));
auto* data = reinterpret_cast<float*>(a.raw_data());
for (int i = 0; i < a.shape().GetNumElements(); i++) {
data[i] = i;
}
return a;
};
std::vector<ValueRef> inputs({ValueRef(new Value(create_tensor())),
ValueRef(new Value(create_tensor()))});
in_args.assign({inputs[0].get(), inputs[1].get()});
for (int i = 0; i < 500; i++) {
func->Execute(
llvm::ArrayRef<Value*>(in_args.data(), in_args.size()),
llvm::MutableArrayRef<ValueRef>(out_args.data(), out_args.size()));
}
}
} // namespace infrt::host_context
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/host_context/op_executable.h"
#include <string>
#include "paddle/infrt/host_context/kernel_frame.h"
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/mlir_function_executable.h"
#include "paddle/infrt/host_context/symbol_table.h"
namespace infrt::host_context {
struct OpExecutable::Impl {
Impl(const std::string& op_name,
SymbolTable* symbol_table,
KernelRegistry* kernel_registry)
: name(op_name),
symbol_table(symbol_table),
kernel_registry(kernel_registry ? kernel_registry
: GetCpuKernelRegistry()) {
CHECK(kernel_registry);
}
inline bool to_execute() const {
return !run_once || run_once && !has_executed;
}
inline void MarkRun() { has_executed = true; }
std::string name;
SymbolTable* symbol_table{};
KernelFrameBuilder frame;
KernelRegistry* kernel_registry{};
std::unique_ptr<MlirFunctionExecutable> mlir_function_executable;
KernelImplementation kernel_impl{};
//! Tell whether this Op should be executed only once.
bool run_once{};
//! Tell whether this op has been executed.
bool has_executed{};
};
OpExecutable::OpExecutable(OpExecutable::Impl* impl) : impl_(impl) {}
const std::string& OpExecutable::name() const { return impl_->name; }
OpExecutableBuilder::OpExecutableBuilder(const std::string& op_name,
SymbolTable* symbol_table,
KernelRegistry* kernel_registry)
: OpExecutable(new Impl(op_name, symbol_table, kernel_registry)) {
CHECK(impl_);
// CPU kernel registry is the default KernelRegistry.
impl_->kernel_impl = impl_->kernel_registry->GetKernel(
std::string(op_name.data(), op_name.size()));
// TODO(Superjomn) support other device other than CPU.
CHECK(impl_->kernel_impl) << "No CPU kernel called " << op_name;
if (op_name == "dt.get_param") {
impl_->run_once = true;
}
}
void OpExecutableBuilder::AppendArgument(const std::string& name) {
if (!impl_->symbol_table->GetValue(name)) {
impl_->symbol_table->Register(name);
}
impl_->frame.AddArgument(impl_->symbol_table->GetValue(name));
}
void OpExecutableBuilder::AppendArgument(Value* value) {
impl_->frame.AddArgument(value);
}
KernelFrame& OpExecutable::frame() { return impl_->frame; }
const KernelFrame& OpExecutable::frame() const { return impl_->frame; }
void OpExecutableBuilder::SetResults(llvm::ArrayRef<std::string> result_names) {
llvm::SmallVector<Value*, 3> results;
for (size_t result_id = 0; result_id < result_names.size(); result_id++) {
Value* value = impl_->symbol_table->Register(result_names[result_id]);
results.push_back(value);
}
impl_->frame.SetResults(results);
}
void OpExecutableBuilder::SetResults(llvm::ArrayRef<Value*> results) {
impl_->frame.SetResults(results);
}
void OpExecutableBuilder::AppendAttribute(Value* value) {
impl_->frame.AddAttribute(value);
}
OpExecutableBuilder::OpExecutableBuilder(OpExecutableBuilder&& other)
: OpExecutable(other.impl_.release()) {}
MlirFunctionExecutable* OpExecutableBuilder::CreateFunctionExecutable(
mlir::FuncOp op, MlirToRuntimeTranslator::function_defs_t* function_defs) {
CHECK(!impl_->mlir_function_executable);
impl_->mlir_function_executable.reset(
new MlirFunctionExecutable(op, impl_->kernel_registry, *function_defs));
return impl_->mlir_function_executable.get();
}
MlirFunctionExecutable* OpExecutableBuilder::CreateFunctionExecutable(
mlir::Region* region,
mlir::FunctionType func_type,
function_defs_t* function_defs) {
CHECK(!impl_->mlir_function_executable);
impl_->mlir_function_executable.reset(new MlirFunctionExecutable(
region, func_type, impl_->kernel_registry, *function_defs));
return impl_->mlir_function_executable.get();
}
void OpExecutable::Execute() {
#ifndef NDEBUG
VLOG(3) << "execute " << name()
<< " --- frame args: " << impl_->frame.GetNumArgs() << " results "
<< impl_->frame.GetNumResults() << " attributes "
<< impl_->frame.GetNumAttributes();
for (int i = 0; i < impl_->frame.GetNumArgs(); i++) {
VLOG(3) << "function arg: " << impl_->frame.GetArgAt(i);
}
for (int i = 0; i < impl_->frame.GetNumResults(); i++) {
VLOG(3) << "function result: " << impl_->frame.GetResults()[i];
}
#endif
if (impl_->to_execute()) {
impl_->kernel_impl(&impl_->frame);
impl_->MarkRun();
}
}
OpExecutable::~OpExecutable() {}
} // namespace infrt::host_context
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <llvm/ADT/ArrayRef.h>
#include <memory>
#include <string>
#include <unordered_map>
#include "mlir/IR/Function.h"
#include "mlir/IR/Region.h"
namespace mlir {
class FuncOp;
} // namespace mlir
namespace infrt::host_context {
class SymbolTable;
class KernelRegistry;
class KernelFrame;
class Value;
class CoreRuntimeBuilder;
class MlirFunctionExecutable;
/**
* OpExecutable is a runtime executable instance for an operation. It captures
* all the information(Tensors, attributes
* and so on) needed for execution.
* With the SymbolTable and op definition, it create and hold a KernelFrame once
* and execute any times.
*/
class OpExecutable {
public:
KernelFrame& frame();
const KernelFrame& frame() const;
void Execute();
const std::string& name() const;
~OpExecutable();
protected:
class Impl;
explicit OpExecutable(Impl* impl);
std::unique_ptr<Impl> impl_;
};
/**
* Builder to help contruct an OpExecutable.
*/
class OpExecutableBuilder : public OpExecutable {
public:
using function_defs_t = std::unordered_map<std::string, mlir::FuncOp>;
OpExecutableBuilder(const std::string& op_name,
SymbolTable* symbol_table,
KernelRegistry* kernel_registry = nullptr);
OpExecutableBuilder(OpExecutableBuilder&& other);
void AppendArgument(const std::string& name);
void AppendArgument(Value* value);
void SetResults(llvm::ArrayRef<std::string> result_names);
void SetResults(llvm::ArrayRef<Value*> results);
void AppendAttribute(Value* value);
MlirFunctionExecutable* CreateFunctionExecutable(
mlir::FuncOp op, function_defs_t* function_defs);
MlirFunctionExecutable* CreateFunctionExecutable(
mlir::Region* region,
mlir::FunctionType func_type,
function_defs_t* function_defs);
};
} // namespace infrt::host_context
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/host_context/op_executable.h"
#include <gtest/gtest.h>
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/host_context/symbol_table.h"
namespace infrt {
namespace host_context {
int add(int a, int b) { return a + b; }
TEST(OpExecutable, basic) {
// register kernel
KernelRegistry registry;
registry.AddKernel("infrt.test.add.i32", INFRT_KERNEL(add));
SymbolTable table;
table.Register("a", 1);
table.Register("b", 2);
OpExecutableBuilder executable("infrt.test.add.i32", &table, &registry);
executable.AppendArgument("a");
executable.AppendArgument("b");
executable.SetResults({"c"});
executable.Execute();
// check the kernel frame has the result.
auto results = executable.frame().GetResults();
ASSERT_EQ(results.size(), 1UL);
ASSERT_EQ(results.front()->get<int32_t>(), 3);
// check symbol table contains the same result instance.
LOG(INFO) << "type: " << table.GetValue("c")->type_info();
int c = table.GetValue("c")->get<int32_t>();
ASSERT_EQ(c, 3);
}
} // namespace host_context
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/host_context/symbol_table.h"
#include <string>
namespace infrt {
namespace host_context {
struct SymbolTable::Impl {
std::unordered_map<std::string, ValueRef> data;
};
SymbolTable::SymbolTable() : impl_(new Impl) {}
Value* SymbolTable::Register(const std::string& key) {
CHECK(!impl_->data.count(key)) << "Duplicate register [" << key << "]";
auto newitem = ValueRef(new Value);
impl_->data.emplace(key, newitem);
return newitem.get();
}
Value* SymbolTable::Register(const std::string& key, ValueRef value) {
CHECK(!impl_->data.count(key)) << "Duplicate register [" << key << "]";
impl_->data.emplace(key, value);
return value.get();
}
Value* SymbolTable::GetValue(const std::string& key) const {
auto it = impl_->data.find(std::string(key));
return it != impl_->data.end() ? it->second.get() : nullptr;
}
// @{
#define REGISTER_TYPE__(T) \
template <> \
T SymbolTable::Get<T>(const std::string& key) { \
auto it = impl_->data.find(std::string(key)); \
CHECK(it != impl_->data.end()) << "No value called " << key; \
return it->second->get<T>(); \
}
REGISTER_TYPE__(int32_t);
REGISTER_TYPE__(float);
REGISTER_TYPE__(double);
REGISTER_TYPE__(int64_t);
#undef REGISTER_TYPE__
// @}
SymbolTable::~SymbolTable() {}
size_t SymbolTable::size() const { return impl_->data.size(); }
// @{
#define REGISTER_TYPE__(T) \
template <> \
Value* SymbolTable::Register(const std::string& key, T&& v) { \
CHECK(!impl_->data.count(key)) << "Duplicate register [" << key << "]"; \
auto newitem = ValueRef(v); \
impl_->data.emplace(key, newitem); \
return newitem.get(); \
}
REGISTER_TYPE__(int)
REGISTER_TYPE__(float)
REGISTER_TYPE__(double)
REGISTER_TYPE__(bool)
#undef REGISTER_TYPE__
// @}
} // namespace host_context
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <unordered_map>
#include <memory>
#include "paddle/infrt/host_context/value.h"
namespace infrt {
namespace host_context {
/**
* SymbolTable holds all the states of the kernel graph in the runtime.
*/
class SymbolTable {
public:
SymbolTable();
/**
* Register a state called \p key.
*/
Value* Register(const std::string& key);
Value* Register(const std::string& key, ValueRef value);
/**
* Register a state and set value.
*/
template <typename T>
Value* Register(const std::string& key, T&& v);
size_t size() const;
/**
* Get a state called \p key.
*/
Value* GetValue(const std::string& key) const;
template <typename T>
T Get(const std::string& key);
~SymbolTable();
private:
class Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace host_context
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/host_context/value.h"
#include "paddle/infrt/tensor/dense_tensor_view.h"
namespace infrt {
namespace host_context {
ValueRef::ValueRef(int32_t val) : Shared<Value>(new Value(val)) {}
ValueRef::ValueRef(int64_t val) : Shared<Value>(new Value(val)) {}
ValueRef::ValueRef(float val) : Shared<Value>(new Value(val)) {}
ValueRef::ValueRef(double val) : Shared<Value>(new Value(val)) {}
ValueRef::ValueRef(bool val) : Shared<Value>(new Value(val)) {}
const char* Value::type_info() const { return __type_info__; }
void CopyTo(const Value& from, Value* to) {
CHECK(from.valid()) << "from value is not valid, can't be copied";
CHECK(to) << "to is not valid";
visit(
[&](auto&& arg) {
using T = std::decay_t<decltype(arg)>;
if (std::is_same<T, int16_t>::value)
to->data = arg;
else if (std::is_same<T, int32_t>::value)
to->data = arg;
else if (std::is_same<T, float>::value)
to->data = arg;
else if (std::is_same<T, double>::value)
to->data = arg;
else if (std::is_same<T, uint32_t>::value)
to->data = arg;
else if (std::is_same<T, uint64_t>::value)
to->data = arg;
else if (std::is_same<T, bool>::value)
to->data = arg;
else if (std::is_same<T, tensor::TensorShape>::value)
to->data = arg;
else if (std::is_same<T, MlirFunctionExecutable*>::value)
to->data = arg;
else if (std::is_same<T, tensor::DenseHostTensor>::value)
to->data = arg;
else if (std::is_same<T, std::vector<int16_t>>::value)
to->data = arg;
else if (std::is_same<T, std::vector<int64_t>>::value)
to->data = arg;
else if (std::is_same<T, tensor::TensorMap>::value)
to->data = arg;
else
LOG(FATAL) << "Not supported Value copy: " << typeid(T).name();
},
from.data);
}
} // namespace host_context
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <llvm/ADT/SmallVector.h>
#include <string>
#include <utility>
#include <vector>
#include "paddle/infrt/common/object.h"
#include "paddle/infrt/common/shared.h"
#include "paddle/infrt/host_context/function.h"
#include "paddle/infrt/support/variant.h"
#include "paddle/infrt/tensor/dense_host_tensor.h"
#include "paddle/infrt/tensor/dense_tensor_view.h"
#include "paddle/infrt/tensor/tensor_map.h"
#include "paddle/infrt/tensor/tensor_shape.h"
namespace infrt {
namespace host_context {
struct MlirFunctionExecutable;
using ValueVariantType = Variant<int16_t,
int32_t,
int64_t,
float,
double,
bool,
std::string,
tensor::TensorShape,
tensor::DenseHostTensor,
MlirFunctionExecutable*,
tensor::TensorMap,
std::vector<int16_t>,
std::vector<int32_t>,
std::vector<int64_t>,
std::vector<float>,
std::vector<double>>;
//! Copy content from \param from to \param to.
void CopyTo(const Value& from, Value* to);
/**
* Represents any data type for value in host context.
*/
class Value : public common::Object {
public:
using variant_type = ValueVariantType;
explicit Value() {} // NOLINT
explicit Value(int32_t x) : data(x) {}
explicit Value(int64_t x) : data(x) {}
explicit Value(float x) : data(x) {}
explicit Value(double x) : data(x) {}
explicit Value(bool x) : data(x) {}
explicit Value(std::string x) : data(x) {}
explicit Value(tensor::TensorMap&& x) : data(x) {}
explicit Value(std::vector<int16_t>&& x) : data(x) {}
explicit Value(std::vector<int32_t>&& x) : data(x) {}
explicit Value(std::vector<int64_t>&& x) : data(x) {}
explicit Value(std::vector<float>&& x) : data(x) {}
explicit Value(std::vector<double>&& x) : data(x) {}
explicit Value(tensor::TensorShape&& x) : data(std::move(x)) {}
explicit Value(tensor::DenseHostTensor&& x) : data(std::move(x)) {}
explicit Value(MlirFunctionExecutable* x) : data(x) {}
template <typename T>
const T& get() const {
return data.get<T>();
}
template <typename T>
T& get() {
return data.get<T>();
}
template <typename T>
void set(T&& v) {
data = std::move(v);
}
void set(Value* v) { data = std::move(v->data); }
bool valid() const { return true; }
const char* type_info() const override;
friend void CopyTo(const Value& from, Value* to);
private:
ValueVariantType data;
static constexpr const char* __type_info__ = "host_context_value";
};
/**
* Represents a counted reference of a Value.
*/
class ValueRef : common::Shared<Value> {
public:
ValueRef() = default;
explicit ValueRef(Value* n) : common::Shared<Value>(n) {}
explicit ValueRef(int32_t val);
explicit ValueRef(int64_t val);
explicit ValueRef(float val);
explicit ValueRef(double val);
explicit ValueRef(bool val);
using common::Shared<Value>::get;
using common::Shared<Value>::Reset;
using common::Shared<Value>::operator->;
using common::Shared<Value>::operator*;
//! Get a readonly data.
template <typename T>
const T& get() const {
CHECK(p_);
return p_->get<T>();
}
template <typename T>
T& get() {
CHECK(p_);
return p_->get<T>();
}
//! Assign a data.
template <typename T>
void Assign(const T& x) {
if (!p_) {
p_ = common::make_shared<Value>();
}
*p_ = x;
}
template <typename T, typename... Args>
void Assign(Args... args) {
p_ = common::make_shared<T>(std::forward<Args>(args)...);
}
inline bool IsValid() { return p_; }
};
} // namespace host_context
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/host_context/value.h"
#include <gtest/gtest.h>
namespace infrt {
namespace host_context {
TEST(ValueRef, test) {
ValueRef x(12);
ASSERT_EQ(x.get<int>(), 12);
ValueRef y(1.2f);
ASSERT_EQ(y.get<float>(), 1.2f);
ValueRef z(true);
ASSERT_EQ(z.get<bool>(), true);
}
} // namespace host_context
} // namespace infrt
core_gather_headers()
gather_srcs(infrt_src SRCS
basic_kernels.cc
test_kernels.cc
tensor_shape_kernels.cc
tensor_kernels.cc
control_flow_kernels.cc
)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/kernel/basic_kernels.h"
#include <iostream>
#include <string>
#include "llvm/Support/raw_ostream.h"
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/kernel_utils.h"
using infrt::host_context::Attribute;
namespace infrt::kernel {
template <typename T>
T add(T a, T b) {
return a + b;
}
template <typename T>
T sub(T a, T b) {
return a - b;
}
template <typename T>
T mul(T a, T b) {
return a * b;
}
template <typename T>
T div(T a, T b) {
return a / b;
}
template <typename T>
void print(T a) {
std::cout << a << std::endl;
}
static std::string GetString(Attribute<std::string> value) {
return value.get();
}
static void PrintString(const std::string &str) {
llvm::outs() << "string = " << str << '\n';
llvm::outs().flush();
}
void RegisterBasicKernels(host_context::KernelRegistry *registry) {
RegisterIntBasicKernels(registry);
RegisterFloatBasicKernels(registry);
registry->AddKernel("infrt.get_string", INFRT_KERNEL(GetString));
registry->AddKernel("infrt.print_string", INFRT_KERNEL(PrintString));
}
void RegisterIntBasicKernels(host_context::KernelRegistry *registry) {
registry->AddKernel("infrt.add.i32", INFRT_KERNEL(add<int32_t>));
registry->AddKernel("infrt.sub.i32", INFRT_KERNEL(sub<int32_t>));
registry->AddKernel("infrt.mul.i32", INFRT_KERNEL(mul<int32_t>));
registry->AddKernel("infrt.div.i32", INFRT_KERNEL(div<int32_t>));
registry->AddKernel("infrt.print.i32", INFRT_KERNEL(print<int32_t>));
}
void RegisterFloatBasicKernels(host_context::KernelRegistry *registry) {
registry->AddKernel("infrt.add.f32", INFRT_KERNEL(add<float>));
registry->AddKernel("infrt.sub.f32", INFRT_KERNEL(sub<float>));
registry->AddKernel("infrt.mul.f32", INFRT_KERNEL(mul<float>));
registry->AddKernel("infrt.div.f32", INFRT_KERNEL(div<float>));
registry->AddKernel("infrt.print.f32", INFRT_KERNEL(print<float>));
}
} // namespace infrt::kernel
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
namespace infrt::host_context {
struct KernelRegistry;
} // namespace infrt::host_context
namespace infrt::kernel {
/**
* Register all the basic kernels to \p registry.
*/
void RegisterBasicKernels(host_context::KernelRegistry* registry);
void RegisterIntBasicKernels(host_context::KernelRegistry* registry);
void RegisterFloatBasicKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/kernel/control_flow_kernels.h"
#include <glog/logging.h>
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/mlir_function_executable.h"
namespace infrt {
namespace kernel {
static void INFRTCall(
host_context::RemainingArguments args,
host_context::RemainingResults results,
host_context::Attribute<host_context::MlirFunctionExecutable*> fn) {
VLOG(3) << "running call kernel ...";
CHECK_EQ(fn.get()->num_arguments(), args.size());
CHECK_EQ(fn.get()->num_results(), results.size());
for (auto& v : results.values()) {
CHECK(v.get());
}
fn.get()->Execute(args.values(), results.values());
}
void RegisterControlFlowKernels(host_context::KernelRegistry* registry) {
registry->AddKernel("infrt.call", INFRT_KERNEL(INFRTCall));
}
} // namespace kernel
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/infrt/host_context/function.h"
#include "paddle/infrt/host_context/kernel_utils.h"
namespace infrt {
namespace host_context {
struct KernelRegistry;
} // namespace host_context
namespace kernel {
void RegisterControlFlowKernels(host_context::KernelRegistry* registry);
} // namespace kernel
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/kernel/tensor_kernels.h"
#include <iostream>
#include <vector>
#include "paddle/infrt/common/global.h"
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/tensor/dense_host_tensor.h"
#include "paddle/infrt/tensor/dense_tensor_view.h"
#include "paddle/infrt/tensor/tensor_map.h"
#include "paddle/infrt/tensor/tensor_shape.h"
namespace infrt::kernel {
using namespace host_context; // NOLINT
using namespace tensor; // NOLINT
/// ===== Kernel begin ====
template <typename T>
DenseHostTensor CreateUninitTensor(Attribute<std::vector<int64_t>> shape) {
const auto &shape_data = shape.get();
auto array = llvm::ArrayRef<int64_t>(shape_data.data(), shape_data.size());
auto type = GetDType<T>();
return DenseHostTensor(TensorShape(array), type);
}
void PrintTensor(const DenseHostTensor &tensor) {
std::cout << tensor << std::endl;
}
template <typename T>
void FillTensorWithConstant(DenseHostTensor *tensor, Attribute<T> v) {
MutableDTArrayView<T>(tensor).Fill(v.get());
}
TensorMap LoadParams(const std::string &path) {
return *(infrt::tensor::LoadParams(path));
}
DenseHostTensor GetParam(TensorMap map, Attribute<std::string> nameAttr) {
auto &name = nameAttr.get();
return *(map[name]);
}
DenseHostTensor ShallowCopyTensor(DenseHostTensor v) { return v; }
/// ===== Kernel end ====
void RegisterTensorKernels(host_context::KernelRegistry *registry) {
registry->AddKernel("dt.create_uninit_tensor.f32",
INFRT_KERNEL(CreateUninitTensor<float>));
registry->AddKernelAttrNameList("dt.create_uninit_tensor.f32", {"shape"});
registry->AddKernel("dt.print_tensor", INFRT_KERNEL(PrintTensor));
registry->AddKernel("dt.fill_tensor_with_constant.f32",
INFRT_KERNEL(FillTensorWithConstant<float>));
registry->AddKernel("dt.fill_tensor_with_constant.f64",
INFRT_KERNEL(FillTensorWithConstant<double>));
registry->AddKernel("dt.load_params", INFRT_KERNEL(LoadParams));
registry->AddKernel("dt.get_param", INFRT_KERNEL(GetParam));
registry->AddKernel("dt.shallow_copy_tensor",
INFRT_KERNEL(ShallowCopyTensor));
}
} // namespace infrt::kernel
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace infrt::host_context {
struct KernelRegistry;
} // namespace infrt::host_context
namespace infrt::kernel {
void RegisterTensorKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/kernel/tensor_shape_kernels.h"
#include <llvm/ADT/ArrayRef.h>
#include <llvm/ADT/SmallVector.h>
#include <llvm/Support/raw_os_ostream.h>
#include <iostream>
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/tensor/tensor_shape.h"
namespace infrt::kernel {
void PrintShape(const tensor::TensorShape& shape) {
llvm::raw_os_ostream oos(std::cout);
oos << shape << '\n';
}
void RegisterTensorShapeKernels(host_context::KernelRegistry* registry) {
registry->AddKernel("ts.print_shape", INFRT_KERNEL(PrintShape));
}
} // namespace infrt::kernel
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace infrt::host_context {
class KernelRegistry;
} // namespace infrt::host_context
namespace infrt::kernel {
void RegisterTensorShapeKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/kernel/test_kernels.h"
#include <llvm/ADT/FunctionExtras.h>
#include <llvm/Support/raw_ostream.h>
#include <cassert>
#include <chrono>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <string>
#include "paddle/infrt/host_context/kernel_registry.h"
#include "paddle/infrt/host_context/kernel_utils.h"
#include "paddle/infrt/host_context/mlir_function_executable.h"
#include "paddle/infrt/tensor/dense_host_tensor.h"
using infrt::host_context::Attribute;
using infrt::host_context::MlirFunctionExecutable;
using infrt::host_context::RemainingArguments;
namespace infrt::kernel {
namespace {
class BenchmarkStats {
public:
BenchmarkStats(std::string name,
int num_warmup_runs,
int max_count,
std::chrono::microseconds benchmark_duration)
: name_{name},
num_warmup_runs_{num_warmup_runs},
max_count_{max_count},
benchmark_duration_{benchmark_duration} {}
void StartRun() {
++cur_count_;
// Start recording CPU time.
cur_start_walltime_ = std::chrono::steady_clock::now();
cur_start_cpu_ = std::clock();
}
void StopRun() {
// Do not collect the runtime statistics if we are still in the warm up
// period.
if (cur_count_ <= num_warmup_runs_) return;
// Stop the CPU timer.
std::clock_t cur_stop_cpu_ = std::clock();
// Stop the wall clock timer.
auto cur_stop_walltime_ = std::chrono::steady_clock::now();
// Collect the wall clock duration.
auto duration_walltime_ = cur_stop_walltime_ - cur_start_walltime_;
run_times_walltime_.push_back(duration_walltime_);
// Collect the CPU duration in microseconds.
// First cast to integer that represents microseconds with truncation, as
// does std::chrono::duration_cast. Then cast to std::chrono::microseconds.
std::clock_t duration_cpu_raw = cur_stop_cpu_ - cur_start_cpu_;
auto duration_cpu_ = static_cast<std::chrono::nanoseconds>(
static_cast<int64_t>(1e9 * duration_cpu_raw / CLOCKS_PER_SEC));
run_times_cpu_.push_back(duration_cpu_);
total_duration_walltime_ += duration_walltime_;
total_duration_cpu_ += duration_cpu_;
}
// Return if we should we run more rounds.
bool MoreRun() const {
return cur_count_ < max_count_ + num_warmup_runs_ &&
total_duration_walltime_ < benchmark_duration_;
}
// Summarize the benchmark results.
void Summarize() {
std::sort(run_times_walltime_.begin(), run_times_walltime_.end());
std::sort(run_times_cpu_.begin(), run_times_cpu_.end());
auto percentile = [](
double p, const std::vector<std::chrono::nanoseconds> &run_times) {
assert(p >= 0.0 && p <= 1.0);
return run_times[run_times.size() * p];
};
// BM: prefix is added to make grepping results from lit output easier.
std::string prefix;
llvm::raw_string_ostream(prefix) << "BM:" << name_ << ':';
auto cpu_utilization =
total_duration_cpu_.count() * 100.0 / total_duration_walltime_.count();
llvm::outs() << prefix << "Count: " << run_times_walltime_.size() << '\n';
llvm::outs() << prefix
<< "Duration(ns): " << total_duration_walltime_.count()
<< '\n';
llvm::outs() << prefix
<< "Time Min(ns): " << run_times_walltime_.front().count()
<< '\n';
llvm::outs() << prefix
<< "Time Max(ns): " << run_times_walltime_.back().count()
<< '\n';
llvm::outs() << prefix << "Time 50%(ns): "
<< percentile(0.5, run_times_walltime_).count() << '\n';
llvm::outs() << prefix << "Time 95%(ns): "
<< percentile(0.95, run_times_walltime_).count() << '\n';
llvm::outs() << prefix << "Time 99%(ns): "
<< percentile(0.99, run_times_walltime_).count() << '\n';
// Log CPU time statistics.
llvm::outs() << prefix
<< "CPU Duration(ns): " << total_duration_cpu_.count() << '\n';
llvm::outs() << prefix << "CPU Min(ns): " << run_times_cpu_.front().count()
<< '\n';
llvm::outs() << prefix << "CPU Max(ns): " << run_times_cpu_.back().count()
<< '\n';
llvm::outs() << prefix
<< "CPU 50%(ns): " << percentile(0.5, run_times_cpu_).count()
<< '\n';
llvm::outs() << prefix
<< "CPU 95%(ns): " << percentile(0.95, run_times_cpu_).count()
<< '\n';
llvm::outs() << prefix
<< "CPU 99%(ns): " << percentile(0.99, run_times_cpu_).count()
<< '\n';
llvm::outs() << prefix << "CPU utilization(percent): " << cpu_utilization
<< "\n";
llvm::outs().flush();
}
private:
const std::string name_;
const int num_warmup_runs_;
const int max_count_;
int cur_count_ = 0;
const std::chrono::nanoseconds benchmark_duration_;
std::chrono::nanoseconds total_duration_walltime_{};
std::chrono::nanoseconds total_duration_cpu_{};
std::chrono::time_point<std::chrono::steady_clock> cur_start_walltime_{};
std::clock_t cur_start_cpu_;
std::vector<std::chrono::nanoseconds> run_times_walltime_;
// CPU run times in microseconds.
std::vector<std::chrono::nanoseconds> run_times_cpu_;
};
} // anonymous namespace
// This op benchmarks the input function by running the function in a loop
// up to a max count or max time as specified in the function's attributes.
//
// Attributes:
// duration_secs: Benchmark duration in seconds.
// max_count: Max run count of input function.
// name: The name used to tag the benchmark results.
// num_warmup_runs: Number of warm up runs before benchmarking starts.
// fn: The input function to be benchmarked.
static void benchmark(RemainingArguments args,
host_context::RemainingResults results,
Attribute<int32_t> duration_secs,
Attribute<int32_t> max_count,
Attribute<std::string> name,
Attribute<int32_t> num_warmup_runs,
Attribute<MlirFunctionExecutable *> fn) {
BenchmarkStats bm_stats{name.get(),
num_warmup_runs.get(),
max_count.get(),
std::chrono::seconds(duration_secs.get())};
while (bm_stats.MoreRun()) {
bm_stats.StartRun();
fn.get()->Execute(args.values(), results.values(), true);
bm_stats.StopRun();
}
bm_stats.Summarize();
}
// Just copy the input to the result.
tensor::DenseHostTensor ShadowCopyTensor(tensor::DenseHostTensor src) {
return src;
}
void RegisterTestKernels(host_context::KernelRegistry *registry) {
registry->AddKernel("infrt.benchmark", INFRT_KERNEL(benchmark));
registry->AddKernel("infrt.test.shadow_copy_tensor",
INFRT_KERNEL(ShadowCopyTensor));
}
} // namespace infrt::kernel
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
namespace infrt::host_context {
struct KernelRegistry;
} // namespace infrt::host_context
namespace infrt::kernel {
/**
* Register all the test kernels to registry.
*/
void RegisterTestKernels(host_context::KernelRegistry* registry);
} // namespace infrt::kernel
proto_library(paddle_framework_proto SRCS framework.proto)
add_subdirectory(cpp)
add_subdirectory(pb)
core_gather_headers()
gather_srcs(infrt_src SRCS
model_parser.cc
scope.cc
tensor.cc
)
foreach(cpp ${SRCS})
set(infrt_src
"${infrt_src};infrt/paddle/${cpp}"
CACHE INTERNAL "")
endforeach()
file(GLOB includes LIST_DIRECTORIES false RELATIVE ${CMAKE_SOURCE_DIR} *.h)
foreach(header ${includes})
set(core_includes "${core_includes};${header}" CACHE INTERNAL "")
endforeach()
core_gather_headers()
gather_srcs(infrt_src SRCS
)
foreach(cpp ${SRCS})
set(infrt_src
"${infrt_src};infrt/paddle/cpp/${cpp}"
CACHE INTERNAL "")
endforeach()
file(GLOB includes LIST_DIRECTORIES false RELATIVE ${CMAKE_SOURCE_DIR} *.h)
foreach(header ${includes})
set(core_includes "${core_includes};${header}" CACHE INTERNAL "")
endforeach()
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <sstream>
#include <string>
#include <vector>
namespace infrt::paddle::cpp {
/*
* Compatible interfaces for all the different kinds of XXXDesc. All the XXXDesc
* classes should implement this.
*/
class VarDescAPI {
public:
enum class Type {
// Pod Types
BOOL = 0,
INT16,
INT32,
INT64,
FP16,
FP32,
FP64,
// Tensor<size_t> is used in C++.
SIZE_T,
UINT8,
INT8,
// Other types that may need additional descriptions
LOD_TENSOR,
SELECTED_ROWS,
FEED_MINIBATCH,
FETCH_LIST,
STEP_SCOPES,
LOD_RANK_TABLE,
LOD_TENSOR_ARRAY,
PLACE_LIST,
READER,
// Any runtime decided variable type is raw
// raw variables should manage their own allocations
// in operators like nccl_op
RAW,
TUPLE
};
using VarDataType = Type;
virtual ~VarDescAPI() = default;
// Get var's name
virtual std::string Name() const = 0;
// Set var's name
virtual void SetName(std::string name) = 0;
// Get var's type
virtual Type GetType() const = 0;
// Set var's type
virtual void SetType(Type type) = 0;
// Tell whether var is persistable or not
virtual bool Persistable() const = 0;
// Set var to be persistable or not
virtual void SetPersistable(bool persistable) = 0;
// Get var's shape
virtual std::vector<int64_t> GetShape() const = 0;
// Set var's shape
virtual void SetShape(const std::vector<int64_t>& dims) = 0;
};
/*
* NOTE Some interfaces are weried, we remain them unchanged to keep compatible
* with framework::OpDesc in Fluid framework.
*/
class OpDescAPI {
public:
// The AttrType is used to make the proto::AttrType portable.
enum class AttrType {
INT = 0,
FLOAT = 1,
STRING = 2,
INTS = 3,
FLOATS = 4,
STRINGS = 5,
BOOLEAN = 6,
BOOLEANS = 7,
BLOCK = 8,
LONG = 9,
BLOCKS = 10,
LONGS = 11,
UNK,
};
virtual ~OpDescAPI() = default;
/// Get operator's type.
virtual std::string Type() const = 0;
/// Set operator's type.
virtual void SetType(const std::string& type) = 0;
/// Get arguments given the parameter.
virtual std::vector<std::string> Input(const std::string& param) const = 0;
/// Get parameters.
virtual std::vector<std::string> InputArgumentNames() const = 0;
/// Get arguments given the parameter.
virtual std::vector<std::string> Output(const std::string& param) const = 0;
/// Get parameters.
virtual std::vector<std::string> OutputArgumentNames() const = 0;
/// Set a input given the parameter and arguments.
virtual void SetInput(const std::string& param,
const std::vector<std::string>& args) = 0;
virtual void SetOutput(const std::string& param,
const std::vector<std::string>& args) = 0;
/// Tell whether this desc has an attribute.
virtual bool HasAttr(const std::string& name) const = 0;
/// Get the type of an attribute.
virtual AttrType GetAttrType(const std::string& name) const = 0;
virtual std::vector<std::string> AttrNames() const = 0;
/// Set an attribute.
template <typename T>
void SetAttr(const std::string& name, const T& v);
/// Get an attribute.
template <typename T>
T GetAttr(const std::string& name) const;
std::string Repr() const {
std::stringstream ss;
ss << Type();
ss << "(";
for (auto& arg : InputArgumentNames()) {
ss << arg << ":";
for (auto val : Input(arg)) {
ss << val << " ";
}
}
ss << ") -> (";
for (auto& arg : OutputArgumentNames()) {
ss << arg << ":";
for (auto val : Output(arg)) {
ss << val << " ";
}
}
ss << ")";
return ss.str();
}
};
class BlockDescAPI {
public:
virtual ~BlockDescAPI() = default;
virtual int32_t Idx() const = 0;
virtual void SetIdx(int32_t idx) = 0;
virtual int32_t ParentIdx() const = 0;
virtual void SetParentIdx(int32_t idx) = 0;
virtual size_t VarsSize() const = 0;
virtual void ClearVars() = 0;
// NOTE: This ugly method is used to compatible interfaces between cpp and
// pb/nb backends
// TODO(sangoly): refine this
template <typename T>
T* GetVar(int32_t idx);
template <typename T>
T* AddVar();
virtual size_t OpsSize() const = 0;
virtual void ClearOps() = 0;
// NOTE: This ugly method is used to compatible interfaces between cpp and
// pb/nb backends
// TODO(sangoly): refine this
template <typename T>
T* GetOp(int32_t idx);
template <typename T>
T* AddOp();
virtual int32_t ForwardBlockIdx() const = 0;
virtual void SetForwardBlockIdx(int32_t idx) = 0;
};
class ProgramDescAPI {
public:
virtual ~ProgramDescAPI() = default;
virtual size_t BlocksSize() const = 0;
virtual void ClearBlocks() = 0;
// NOTE: This ugly method is used to compatible interfaces between cpp and
// pb/nb backends
// TODO(sangoly): refine this
template <typename T>
T* GetBlock(int32_t idx);
template <typename T>
T* AddBlock();
virtual bool HasVersion() const = 0;
virtual int64_t Version() const = 0;
virtual void SetVersion(int64_t version) = 0;
};
} // namespace infrt::paddle::cpp
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
package paddle.framework.proto;
// Any incompatible changes to ProgramDesc and its dependencies should
// raise the version defined version.h.
//
// Serailization and Deserialization codes should be modified in a way
// that supports old versions following the version and compatibility policy.
message Version { optional int64 version = 1 [ default = 0 ]; }
enum AttrType {
INT = 0;
FLOAT = 1;
STRING = 2;
INTS = 3;
FLOATS = 4;
STRINGS = 5;
BOOLEAN = 6;
BOOLEANS = 7;
BLOCK = 8;
LONG = 9;
BLOCKS = 10;
LONGS = 11;
}
// OpDesc describes an instance of a C++ framework::OperatorBase
// derived class type.
message OpDesc {
message Attr {
required string name = 1;
required AttrType type = 2;
optional int32 i = 3;
optional float f = 4;
optional string s = 5;
repeated int32 ints = 6;
repeated float floats = 7;
repeated string strings = 8;
optional bool b = 10;
repeated bool bools = 11;
optional int32 block_idx = 12;
optional int64 l = 13;
repeated int32 blocks_idx = 14;
repeated int64 longs = 15;
};
message Var {
required string parameter = 1;
repeated string arguments = 2;
};
required string type = 3;
repeated Var inputs = 1;
repeated Var outputs = 2;
repeated Attr attrs = 4;
optional bool is_target = 5 [ default = false ];
};
// OpProto describes a C++ framework::OperatorBase derived class.
message OpProto {
// VarProto describes the C++ type framework::Variable.
message Var {
required string name = 1;
required string comment = 2;
optional bool duplicable = 3 [ default = false ];
optional bool intermediate = 4 [ default = false ];
optional bool dispensable = 5 [ default = false ];
}
// AttrProto describes the C++ type Attribute.
message Attr {
required string name = 1;
required AttrType type = 2;
required string comment = 3;
// If that attribute is generated, it means the Paddle third
// language binding has responsibility to fill that
// attribute. End-User should not set that attribute.
optional bool generated = 4 [ default = false ];
}
required string type = 1;
repeated Var inputs = 2;
repeated Var outputs = 3;
repeated Attr attrs = 4;
required string comment = 5;
}
message VarType {
enum Type {
// Pod Types
BOOL = 0;
INT16 = 1;
INT32 = 2;
INT64 = 3;
FP16 = 4;
FP32 = 5;
FP64 = 6;
// Tensor<size_t> is used in C++.
SIZE_T = 19;
UINT8 = 20;
INT8 = 21;
// Other types that may need additional descriptions
LOD_TENSOR = 7;
SELECTED_ROWS = 8;
FEED_MINIBATCH = 9;
FETCH_LIST = 10;
STEP_SCOPES = 11;
LOD_RANK_TABLE = 12;
LOD_TENSOR_ARRAY = 13;
PLACE_LIST = 14;
READER = 15;
// Any runtime decided variable type is raw
// raw variables should manage their own allocations
// in operators like nccl_op
RAW = 17;
TUPLE = 18;
}
required Type type = 1;
message TensorDesc {
// Should only be PODType. Is enforced in C++
required Type data_type = 1;
repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480]
}
optional TensorDesc selected_rows = 2;
message LoDTensorDesc {
required TensorDesc tensor = 1;
optional int32 lod_level = 2 [ default = 0 ];
}
optional LoDTensorDesc lod_tensor = 3;
message LoDTensorArrayDesc {
required TensorDesc tensor = 1;
optional int32 lod_level = 2 [ default = 0 ];
}
optional LoDTensorArrayDesc tensor_array = 4;
message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; }
optional ReaderDesc reader = 5;
message Tuple { repeated Type element_type = 1; }
optional Tuple tuple = 7;
}
message VarDesc {
required string name = 1;
required VarType type = 2;
optional bool persistable = 3 [ default = false ];
// True if the variable is an input data and
// have to check the feed data shape and dtype
optional bool need_check_feed = 4 [ default = false ];
}
message BlockDesc {
required int32 idx = 1;
required int32 parent_idx = 2;
repeated VarDesc vars = 3;
repeated OpDesc ops = 4;
optional int32 forward_block_idx = 5 [ default = -1 ];
}
// CompatibleInfo is used to determine if a feature is compatible and
// provides the information.
message CompatibleInfo {
enum Type {
COMPATIBLE = 0;
DEFINITELY_NOT = 1;
POSSIBLE = 2;
BUG_FIX = 3;
PRECISION_CHANGE = 4;
}
required string version = 1;
required Type type = 2;
}
// In some cases, Paddle Fluid may perform operator definition iterations,
// and the operator uses OpCompatibleMap for compatibility testing.
message OpCompatibleMap {
message OpCompatiblePair {
required string op_name = 1;
required CompatibleInfo compatible_info = 2;
}
repeated OpCompatiblePair pair = 1;
optional string default_required_version = 2;
}
// Please refer to
// https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/program.md
// for more details.
// TODO(panyx0718): A model can have multiple programs. Need a
// way to distinguish them. Maybe ID or name?
message ProgramDesc {
reserved 2; // For backward compatibility.
repeated BlockDesc blocks = 1;
optional Version version = 4;
optional OpCompatibleMap op_compatible_map = 3;
}
\ No newline at end of file
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/paddle/model_parser.h"
#include <fstream>
#include <vector>
#include "paddle/infrt/common/common.h"
#include "paddle/infrt/common/string.h"
#include "paddle/infrt/common/target.h"
#include "paddle/infrt/common/type.h"
namespace infrt::paddle {
int SizeOfType(framework_proto::VarType::Type type) {
using Type = framework_proto::VarType::Type;
switch (static_cast<int>(type)) {
#define DO(desc, type) \
case Type::VarType_Type_##desc: \
return sizeof(type);
DO(BOOL, bool);
DO(FP16, float);
DO(FP32, float);
DO(INT8, int8_t);
DO(INT16, int16_t);
DO(INT32, int);
DO(INT64, int64_t);
#undef DO
default:
LOG(FATAL) << "unknown data type " << type;
}
return -1;
}
void TensorFromStream(std::istream &is,
_Tensor_ *tensor,
const common::Target &target) {
using Type = framework_proto::VarType::Type;
uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version));
CHECK_EQ(version, 0U) << "Only version 0 is supported";
// read tensor desc
framework_proto::VarType::TensorDesc desc;
{
// int32_t size
// proto buffer
int32_t size;
is.read(reinterpret_cast<char *>(&size), sizeof(size));
std::unique_ptr<char[]> buf(new char[size]);
is.read(reinterpret_cast<char *>(buf.get()), size);
CHECK(desc.ParseFromArray(buf.get(), size)) << "Cannot parse tensor desc";
}
// read tensor
std::vector<int32_t> dims_vec;
std::copy(
desc.dims().begin(), desc.dims().end(), std::back_inserter(dims_vec));
Shape dims(dims_vec);
tensor->Resize(dims);
void *buf;
size_t size = tensor->shape().numel() * SizeOfType(desc.data_type());
// alllocate memory
if (target.arch == Target::Arch::X86) {
switch (static_cast<int>(desc.data_type())) {
#define SET_TENSOR(desc, type, precision) \
case Type::VarType_Type_##desc: \
buf = tensor->mutable_data<type>(target); \
tensor->set_type(precision); \
break
SET_TENSOR(FP32, float, Float(32));
SET_TENSOR(INT8, int8_t, Int(8));
SET_TENSOR(INT16, int16_t, Int(16));
SET_TENSOR(INT32, int32_t, Int(32));
SET_TENSOR(INT64, int64_t, Int(64));
#undef SET_TENSOR
default:
LOG(FATAL) << "unknown type " << desc.data_type();
}
// tensor->set_persistable(true);
is.read(static_cast<char *>(buf), size);
} else if (target.arch == Target::Arch::NVGPU) {
#ifdef INFRT_WITH_CUDA
if (desc.data_type() != Type::VarType_Type_FP32)
LOG(FATAL) << "[CUDA] The type is not fp32!!";
auto *data = tensor->mutable_data<float>(target);
tensor->set_type(infrt::common::Float(32));
std::vector<float> temp(tensor->shape().numel());
// LOG(INFO) <<"[CUDA] The tensor's size is "<< tensor->shape().numel();
is.read(reinterpret_cast<char *>(temp.data()), size);
CUDA_CALL(cudaMemcpy(reinterpret_cast<void *>(data),
temp.data(),
tensor->shape().numel() * sizeof(float),
cudaMemcpyHostToDevice));
#else
LOG(FATAL) << "To use CUDA backends, you need to set WITH_CUDA ON!";
#endif
} else {
INFRT_NOT_IMPLEMENTED
}
}
void LoadLoDTensor(std::istream &is, _Variable *var, const Target &target) {
auto &tensor = var->get<Tensor>();
uint32_t version{};
is.read(reinterpret_cast<char *>(&version), sizeof(version));
VLOG(3) << "model version " << version;
// Load LoD information
uint64_t lod_level{};
is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
for (uint64_t i = 0; i < lod_level; ++i) {
uint64_t size;
is.read(reinterpret_cast<char *>(&size), sizeof(size));
std::vector<uint64_t> tmp(size / sizeof(uint64_t));
is.read(reinterpret_cast<char *>(tmp.data()),
static_cast<std::streamsize>(size));
// lod[i] = tmp;
}
TensorFromStream(is, tensor.operator->(), target);
}
void ReadBinaryFile(const std::string &filename, std::string *contents) {
std::ifstream fin(filename, std::ios::in | std::ios::binary);
CHECK(fin.is_open()) << "Cannot open file: " << filename;
fin.seekg(0, std::ios::end);
auto size = fin.tellg();
contents->clear();
contents->resize(size);
fin.seekg(0, std::ios::beg);
fin.read(&(contents->at(0)), contents->size());
fin.close();
}
std::unique_ptr<framework_proto::ProgramDesc> LoadProgram(
const std::string &path, bool program_from_memory) {
std::unique_ptr<framework_proto::ProgramDesc> main_program(
new framework_proto::ProgramDesc);
if (!program_from_memory) {
std::string desc_str;
ReadBinaryFile(path, &desc_str);
main_program->ParseFromString(desc_str);
} else {
main_program->ParseFromString(path);
}
return main_program;
}
void LoadParams(const std::string &path) {}
// Load directly to CPU, and latter transfer to other devices.
void LoadParam(const std::string &path, _Variable *out, const Target &target) {
std::ifstream fin(path, std::ios::binary);
CHECK(fin.is_open()) << "failed to open file " << path;
LoadLoDTensor(fin, out, target);
}
} // namespace infrt::paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "paddle/infrt/paddle/framework.pb.h"
#include "paddle/infrt/paddle/pb/block_desc.h"
#include "paddle/infrt/paddle/pb/op_desc.h"
#include "paddle/infrt/paddle/pb/program_desc.h"
#include "paddle/infrt/paddle/scope.h"
#include "paddle/infrt/paddle/tensor.h"
namespace infrt::paddle {
namespace framework_proto = ::paddle::framework::proto;
// Read a __model__ file.
std::unique_ptr<framework_proto::ProgramDesc> LoadProgram(
const std::string& path, bool program_from_memory = false);
void LoadLoDTensor(std::istream& is,
_Variable* var,
const common::Target& target);
// Read a single file containing all the parameters.
void LoadParams(const std::string& path);
// Load a single parameter to an output tensor.
void LoadParam(const std::string& path,
_Variable* out,
const common::Target& target);
// LoDTensor to ostream
void TensorToStream(std::ostream& os, const _Tensor_& tensor);
void TensorFromStream(
std::istream& is,
_Tensor_* tensor,
const common::Target& target = common::DefaultHostTarget());
void ReadBinaryFile(const std::string& filename, std::string* contents);
} // namespace infrt::paddle
core_gather_headers()
gather_srcs(infrt_src SRCS
var_desc.cc
op_desc.cc
block_desc.cc
program_desc.cc
)
foreach(cpp ${SRCS})
set(infrt_src
"${infrt_src};infrt/paddle/pb/${cpp}"
CACHE INTERNAL "")
endforeach()
file(GLOB includes LIST_DIRECTORIES false RELATIVE ${CMAKE_SOURCE_DIR} *.h)
foreach(header ${includes})
set(core_includes "${core_includes};${header}" CACHE INTERNAL "")
endforeach()
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/paddle/pb/block_desc.h"
namespace infrt::paddle::pb {
template <>
framework_proto::VarDesc* BlockDesc::GetVar<framework_proto::VarDesc>(
int32_t idx) {
CHECK_LT(idx, static_cast<int>(VarsSize())) << "idx >= vars.size()";
return desc_->mutable_vars(idx);
}
template <>
framework_proto::VarDesc* BlockDesc::AddVar<framework_proto::VarDesc>() {
return desc_->add_vars();
}
template <>
framework_proto::OpDesc* BlockDesc::GetOp<framework_proto::OpDesc>(
int32_t idx) {
CHECK_LT(idx, static_cast<int>(OpsSize())) << "idx >= ops.size()";
return desc_->mutable_ops(idx);
}
template <>
framework_proto::OpDesc* BlockDesc::AddOp<framework_proto::OpDesc>() {
return desc_->add_ops();
}
} // namespace infrt::paddle::pb
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb {
namespace framework_proto = ::paddle::framework::proto;
class BlockDesc : public cpp::BlockDescAPI {
public:
BlockDesc() = delete;
explicit BlockDesc(framework_proto::BlockDesc* desc) : desc_(desc) {
CHECK(desc_);
}
framework_proto::BlockDesc* Proto() { return desc_; }
const framework_proto::BlockDesc& ReadonlyProto() const { return *desc_; }
int32_t Idx() const override { return desc_->idx(); }
void SetIdx(int32_t idx) override { desc_->set_idx(idx); }
int32_t ParentIdx() const override { return desc_->parent_idx(); }
void SetParentIdx(int32_t idx) override { desc_->set_parent_idx(idx); }
size_t VarsSize() const override { return desc_->vars_size(); }
void ClearVars() override { desc_->clear_vars(); }
template <typename T>
T* GetVar(int32_t idx);
template <typename T>
T* AddVar();
size_t OpsSize() const override { return desc_->ops_size(); }
void ClearOps() override { desc_->clear_ops(); }
template <typename T>
T* GetOp(int32_t idx);
template <typename T>
T* AddOp();
int32_t ForwardBlockIdx() const override {
return desc_->forward_block_idx();
}
void SetForwardBlockIdx(int32_t idx) override {
desc_->set_forward_block_idx(idx);
}
private:
framework_proto::BlockDesc* desc_; // not_own
};
} // namespace infrt::paddle::pb
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/paddle/pb/op_desc.h"
namespace infrt::paddle::pb {
google::protobuf::internal::RepeatedPtrIterator<framework_proto::OpDesc_Attr>
FindAttr(framework_proto::OpDesc *desc, const std::string &name) {
auto &xs = *desc->mutable_attrs();
auto it = std::find_if(
xs.begin(), xs.end(), [&](const framework_proto::OpDesc_Attr &x) {
return x.name() == name;
});
if (it == xs.end()) {
auto *attr = xs.Add();
attr->set_name(name);
it = std::find_if(
xs.begin(), xs.end(), [&](const framework_proto::OpDesc_Attr &x) {
return x.name() == name;
});
}
return it;
}
#define SET_IMPL_ONE(T, ty__, pb_f__) \
template <> \
void OpDesc::SetAttr<T>(const std::string &name, const T &v) { \
auto it = FindAttr(desc_, name); \
it->set_type(framework_proto::ty__); \
it->set_##pb_f__(v); \
}
SET_IMPL_ONE(int, INT, i);
SET_IMPL_ONE(float, FLOAT, f);
SET_IMPL_ONE(bool, BOOLEAN, b);
SET_IMPL_ONE(int64_t, LONG, l);
template <>
void OpDesc::SetAttr<std::vector<int>>(const std::string &name,
const std::vector<int> &v) {
auto it = FindAttr(desc_, name);
it->set_type(framework_proto::INTS);
it->clear_ints();
for (auto &i : v) {
it->add_ints(i);
}
}
template <>
void OpDesc::SetAttr<std::string>(const std::string &name,
const std::string &v) {
auto it = FindAttr(desc_, name);
it->set_type(framework_proto::STRING);
it->set_s(v.c_str());
}
template <>
void OpDesc::SetAttr<std::vector<float>>(const std::string &name,
const std::vector<float> &v) {
auto it = FindAttr(desc_, name);
it->set_type(framework_proto::FLOATS);
it->clear_floats();
for (auto &i : v) {
it->add_floats(i);
}
}
template <>
void OpDesc::SetAttr<std::vector<std::string>>(
const std::string &name, const std::vector<std::string> &v) {
auto it = FindAttr(desc_, name);
it->set_type(framework_proto::STRINGS);
it->clear_strings();
for (auto &i : v) {
it->add_strings(i);
}
}
template <>
void OpDesc::SetAttr<std::vector<int64_t>>(const std::string &name,
const std::vector<int64_t> &v) {
auto it = FindAttr(desc_, name);
it->set_type(framework_proto::LONGS);
it->clear_longs();
for (auto &i : v) {
it->add_longs(i);
}
}
google::protobuf::internal::RepeatedPtrIterator<
const framework_proto::OpDesc_Attr>
GetFindAttr(const framework_proto::OpDesc &desc, const std::string &name) {
auto &xs = desc.attrs();
auto it = std::find_if(
xs.begin(), xs.end(), [&](const framework_proto::OpDesc_Attr &x) {
return x.name() == name;
});
return it;
}
#define GET_ATTR_IMPL(T, pb_f__) \
template <> \
T OpDesc::GetAttr<T>(const std::string &name) const { \
auto it = GetFindAttr(*desc_, name); \
return it->pb_f__(); \
}
#define GET_ATTRS_IMPL(T, pb_f__) \
template <> \
T OpDesc::GetAttr<T>(const std::string &name) const { \
auto it = GetFindAttr(*desc_, name); \
T res; \
for (const auto &v : it->pb_f__()) { \
res.push_back(v); \
} \
return res; \
}
GET_ATTR_IMPL(int32_t, i);
GET_ATTR_IMPL(int16_t, block_idx);
GET_ATTR_IMPL(float, f);
GET_ATTR_IMPL(bool, b);
GET_ATTR_IMPL(int64_t, l);
GET_ATTRS_IMPL(std::vector<int>, ints);
GET_ATTRS_IMPL(std::vector<float>, floats);
GET_ATTRS_IMPL(std::vector<std::string>, strings);
GET_ATTR_IMPL(std::string, s);
GET_ATTRS_IMPL(std::vector<int64_t>, longs);
} // namespace infrt::paddle::pb
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h"
#include "paddle/infrt/support/variant.h"
namespace infrt::paddle::pb {
namespace framework_proto = ::paddle::framework::proto;
using Attribute =
Variant<int, float, bool, std::vector<std::string>, std::vector<int>>;
using VariableNameMap = std::map<std::string, std::vector<std::string>>;
/*
* The lite::OpDesc, an light-weight implementation of wrapper of proto::OpDesc.
* Unlike the original one in framework::OpDesc, we remove the local members
* except the desc_, to avoid the inconsistent state, which is normal in the
* original interface and results in bugs.
*/
class OpDesc : public cpp::OpDescAPI {
public:
OpDesc() = delete;
explicit OpDesc(framework_proto::OpDesc *desc) : desc_(desc) { CHECK(desc_); }
framework_proto::OpDesc *Proto() { return desc_; }
const framework_proto::OpDesc &ReadonlyProto() const { return *desc_; }
std::string Type() const override { return desc_->type(); }
void SetType(const std::string &type) override { desc_->set_type(type); }
// Get the arguments of parameter called `param`
std::vector<std::string> Input(const std::string &param) const override {
return GetArguments(desc_->inputs(), param);
}
std::vector<std::string> InputArgumentNames() const override {
return GetArgumentNames(desc_->inputs());
}
void SetInput(const std::string &param,
const std::vector<std::string> &args) override {
SetArgument(desc_->mutable_inputs(), param, args);
}
std::vector<std::string> Output(const std::string &param) const override {
return GetArguments(desc_->outputs(), param);
}
std::vector<std::string> OutputArgumentNames() const override {
return GetArgumentNames(desc_->outputs());
}
void SetOutput(const std::string &param,
const std::vector<std::string> &args) override {
SetArgument(desc_->mutable_outputs(), param, args);
}
bool HasAttr(const std::string &name) const override {
const auto &xs = desc_->attrs();
auto it = std::find_if(
xs.begin(), xs.end(), [&](const framework_proto::OpDesc_Attr &x) {
return x.name() == name;
});
return it != xs.end();
}
AttrType GetAttrType(const std::string &name) const override {
const auto &xs = desc_->attrs();
auto it = std::find_if(
xs.begin(), xs.end(), [&](const framework_proto::OpDesc_Attr &x) {
return x.name() == name;
});
CHECK(it != xs.end());
#define DEF_ONE(type__) \
case framework_proto::AttrType::type__: \
return AttrType::type__;
switch (it->type()) {
DEF_ONE(INT);
DEF_ONE(FLOAT);
DEF_ONE(STRING);
DEF_ONE(INTS);
DEF_ONE(FLOATS);
DEF_ONE(STRINGS);
DEF_ONE(BOOLEAN);
DEF_ONE(BOOLEANS);
DEF_ONE(BLOCK);
DEF_ONE(LONG);
DEF_ONE(BLOCKS);
DEF_ONE(LONGS);
default:
LOG(FATAL) << "Unknown attribute type";
return static_cast<AttrType>(-1);
}
#undef DEF_ONE
}
std::vector<std::string> AttrNames() const override {
std::vector<std::string> res;
const auto &xs = desc_->attrs();
std::transform(
xs.begin(),
xs.end(),
std::back_inserter(res),
[](const framework_proto::OpDesc_Attr &x) { return x.name(); });
return res;
}
template <typename T>
void SetAttr(const std::string &name, const T &v);
template <typename T>
T GetAttr(const std::string &name) const;
private:
std::vector<std::string> GetArguments(
const google::protobuf::RepeatedPtrField<framework_proto::OpDesc_Var> &xs,
const std::string &param) const {
std::vector<std::string> res;
auto it = std::find_if(
xs.begin(), xs.end(), [&](const framework_proto::OpDesc_Var &it) {
return it.parameter() == param;
});
CHECK(it != xs.end());
const auto &ys = it->arguments();
std::transform(ys.begin(),
ys.end(),
std::back_inserter(res),
[](const std::string &x) { return x; });
return res;
}
void SetArgument(
google::protobuf::RepeatedPtrField<framework_proto::OpDesc_Var> *xs,
const std::string &param,
const std::vector<std::string> &args) {
auto it = std::find_if(
xs->begin(), xs->end(), [&](const framework_proto::OpDesc_Var &it) {
return it.parameter() == param;
});
if (it == xs->end()) {
auto *new_arg = xs->Add();
new_arg->set_parameter(param);
for (const auto &arg : args) {
*new_arg->mutable_arguments()->Add() = arg;
}
} else {
it->mutable_arguments()->Clear();
for (const auto &arg : args) {
*it->mutable_arguments()->Add() = arg;
}
}
}
std::vector<std::string> GetArgumentNames(
const google::protobuf::RepeatedPtrField<framework_proto::OpDesc_Var> &xs)
const {
std::vector<std::string> res;
std::transform(
xs.begin(),
xs.end(),
std::back_inserter(res),
[](const framework_proto::OpDesc_Var &x) { return x.parameter(); });
return res;
}
private:
framework_proto::OpDesc *desc_;
};
template <>
void OpDesc::SetAttr<std::string>(const std::string &name,
const std::string &v);
template <>
void OpDesc::SetAttr<std::vector<int>>(const std::string &name,
const std::vector<int> &v);
} // namespace infrt::paddle::pb
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/paddle/pb/program_desc.h"
#include <algorithm>
#include <limits>
namespace infrt::paddle::pb {
template <>
framework_proto::BlockDesc* ProgramDesc::GetBlock<framework_proto::BlockDesc>(
int32_t idx) {
CHECK_LT(idx, static_cast<int>(BlocksSize())) << "idx >= blocks.size()";
return desc_->mutable_blocks(idx);
}
template <>
framework_proto::BlockDesc*
ProgramDesc::AddBlock<framework_proto::BlockDesc>() {
return desc_->add_blocks();
}
} // namespace infrt::paddle::pb
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <string>
#include <vector>
#include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb {
namespace framework_proto = ::paddle::framework::proto;
class ProgramDesc : public cpp::ProgramDescAPI {
public:
ProgramDesc() = delete;
explicit ProgramDesc(framework_proto::ProgramDesc *desc) : desc_(desc) {
CHECK(desc_);
}
framework_proto::ProgramDesc *Proto() { return desc_; }
const framework_proto::ProgramDesc &ReadonlyProto() const { return *desc_; }
size_t BlocksSize() const override { return desc_->blocks_size(); }
void ClearBlocks() override { desc_->clear_blocks(); }
template <typename T>
T *GetBlock(int32_t idx);
template <typename T>
T *AddBlock();
bool HasVersion() const override { return desc_->has_version(); }
int64_t Version() const override { return desc_->version().version(); }
void SetVersion(int64_t version) override {
desc_->mutable_version()->set_version(version);
}
private:
framework_proto::ProgramDesc *desc_; // not_own
};
} // namespace infrt::paddle::pb
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/paddle/pb/var_desc.h"
#include <google/protobuf/map.h>
#include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb {
cpp::VarDescAPI::Type VarDesc::GetType() const {
auto type = desc_->type().type();
#define GET_TYPE_CASE_ITEM(type__) \
case framework_proto::VarType::type__: \
return cpp::VarDescAPI::Type::type__;
switch (type) {
GET_TYPE_CASE_ITEM(LOD_TENSOR);
GET_TYPE_CASE_ITEM(LOD_TENSOR_ARRAY);
GET_TYPE_CASE_ITEM(LOD_RANK_TABLE);
GET_TYPE_CASE_ITEM(SELECTED_ROWS);
GET_TYPE_CASE_ITEM(FEED_MINIBATCH);
GET_TYPE_CASE_ITEM(FETCH_LIST);
GET_TYPE_CASE_ITEM(STEP_SCOPES);
GET_TYPE_CASE_ITEM(PLACE_LIST);
GET_TYPE_CASE_ITEM(READER);
default:
LOG(FATAL) << "Unknown var type";
return VarDescAPI::Type();
}
#undef GET_TYPE_CASE_ITEM
}
void VarDesc::SetType(VarDescAPI::Type type) {
#define SET_TYPE_CASE_ITEM(type__) \
case VarDescAPI::Type::type__: \
desc_->mutable_type()->set_type(framework_proto::VarType::type__); \
break;
switch (type) {
SET_TYPE_CASE_ITEM(LOD_TENSOR);
SET_TYPE_CASE_ITEM(LOD_TENSOR_ARRAY);
SET_TYPE_CASE_ITEM(LOD_RANK_TABLE);
SET_TYPE_CASE_ITEM(SELECTED_ROWS);
SET_TYPE_CASE_ITEM(FEED_MINIBATCH);
SET_TYPE_CASE_ITEM(FETCH_LIST);
SET_TYPE_CASE_ITEM(STEP_SCOPES);
SET_TYPE_CASE_ITEM(PLACE_LIST);
SET_TYPE_CASE_ITEM(READER);
default:
LOG(FATAL) << "Unknown var type";
}
#undef SET_TYPE_CASE_ITEM
}
void VarDesc::SetShape(const std::vector<int64_t> &dims) {
VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
}
void VarDesc::SetTensorDescNum(size_t num) {
switch (desc_->type().type()) {
case framework_proto::VarType::READER: {
auto *lod_tensors_ptr =
desc_->mutable_type()->mutable_reader()->mutable_lod_tensor();
lod_tensors_ptr->Clear();
for (size_t i = 0; i < num; ++i) {
lod_tensors_ptr->Add();
}
return;
} break;
default:
LOG(FATAL) << "Setting 'sub_tensor_number' is not supported by the type "
"of var %s."
<< this->Name();
}
}
size_t VarDesc::GetTensorDescNum() const {
switch (desc_->type().type()) {
case framework_proto::VarType::READER:
return desc_->type().reader().lod_tensor_size();
break;
default:
LOG(FATAL) << "Getting 'sub_tensor_number' is not supported by the type "
"of var %s."
<< this->Name();
}
return 0;
}
void VarDesc::SetShapes(
const std::vector<std::vector<int64_t>> &multiple_dims) {
if (multiple_dims.size() != GetTensorDescNum()) {
VLOG(3) << "WARNING: The number of given shapes(" << multiple_dims.size()
<< ") doesn't match the existing tensor number("
<< GetTensorDescNum()
<< "). The Reader is going to be reinitialized.";
SetTensorDescNum(multiple_dims.size());
}
std::vector<framework_proto::VarType::TensorDesc *> tensors =
mutable_tensor_descs();
for (size_t i = 0; i < multiple_dims.size(); ++i) {
VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims());
}
}
std::vector<int64_t> VarDesc::GetShape() const {
return RepeatedToVector(tensor_desc().dims());
}
std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
std::vector<framework_proto::VarType::TensorDesc> descs = tensor_descs();
std::vector<std::vector<int64_t>> res;
res.reserve(descs.size());
for (const auto &tensor_desc : descs) {
res.push_back(RepeatedToVector(tensor_desc.dims()));
}
return res;
}
void VarDesc::SetDataType(VarDescAPI::VarDataType data_type) {
#define SET_DATA_TYPE_CASE_ITEM(type__) \
case cpp::VarDescAPI::Type::type__: \
mutable_tensor_desc()->set_data_type(framework_proto::VarType::type__); \
break;
switch (data_type) {
SET_DATA_TYPE_CASE_ITEM(BOOL);
SET_DATA_TYPE_CASE_ITEM(SIZE_T);
SET_DATA_TYPE_CASE_ITEM(UINT8);
SET_DATA_TYPE_CASE_ITEM(INT8);
SET_DATA_TYPE_CASE_ITEM(INT16);
SET_DATA_TYPE_CASE_ITEM(INT32);
SET_DATA_TYPE_CASE_ITEM(INT64);
SET_DATA_TYPE_CASE_ITEM(FP16);
SET_DATA_TYPE_CASE_ITEM(FP32);
SET_DATA_TYPE_CASE_ITEM(FP64);
default:
LOG(FATAL) << "Unknown var type: " << static_cast<int>(data_type);
}
#undef SET_DATA_TYPE_CASE_ITEM
}
void VarDesc::SetDataTypes(
const std::vector<framework_proto::VarType::Type> &multiple_data_type) {
if (multiple_data_type.size() != GetTensorDescNum()) {
VLOG(3) << "WARNING: The number of given data types("
<< multiple_data_type.size()
<< ") doesn't match the existing tensor number("
<< GetTensorDescNum()
<< "). The Reader is going to be reinitialized.";
SetTensorDescNum(multiple_data_type.size());
}
std::vector<framework_proto::VarType::TensorDesc *> tensor_descs =
mutable_tensor_descs();
for (size_t i = 0; i < multiple_data_type.size(); ++i) {
tensor_descs[i]->set_data_type(multiple_data_type[i]);
}
}
// proto::VarType::Type VarDesc::GetDataType() const {
// return tensor_desc().data_type();
// }
cpp::VarDescAPI::VarDataType VarDesc::GetDataType() const {
CHECK(desc_->has_type()) << "The var's type hasn't been set.";
CHECK(desc_->type().has_type()) << "The var type hasn't been set.";
if (desc_->type().type() != framework_proto::VarType::LOD_TENSOR) {
return VarDescAPI::Type();
}
auto type = tensor_desc().data_type();
#define GET_DATA_TYPE_CASE_ITEM(type__) \
case framework_proto::VarType::Type::VarType_Type_##type__: \
return VarDescAPI::Type::type__
switch (type) {
GET_DATA_TYPE_CASE_ITEM(BOOL);
GET_DATA_TYPE_CASE_ITEM(SIZE_T);
GET_DATA_TYPE_CASE_ITEM(UINT8);
GET_DATA_TYPE_CASE_ITEM(INT8);
GET_DATA_TYPE_CASE_ITEM(INT16);
GET_DATA_TYPE_CASE_ITEM(INT32);
GET_DATA_TYPE_CASE_ITEM(INT64);
GET_DATA_TYPE_CASE_ITEM(FP16);
GET_DATA_TYPE_CASE_ITEM(FP32);
GET_DATA_TYPE_CASE_ITEM(FP64);
default:
LOG(FATAL) << "Unknown var type: " << static_cast<int>(type);
return VarDescAPI::Type();
}
#undef GET_DATA_TYPE_CASE_ITEM
}
std::vector<framework_proto::VarType::Type> VarDesc::GetDataTypes() const {
std::vector<framework_proto::VarType::TensorDesc> descs = tensor_descs();
std::vector<framework_proto::VarType::Type> res;
res.reserve(descs.size());
for (const auto &tensor_desc : descs) {
res.push_back(tensor_desc.data_type());
}
return res;
}
void VarDesc::SetLoDLevel(int32_t lod_level) {
switch (desc_->type().type()) {
case framework_proto::VarType::LOD_TENSOR:
desc_->mutable_type()->mutable_lod_tensor()->set_lod_level(lod_level);
break;
case framework_proto::VarType::LOD_TENSOR_ARRAY:
desc_->mutable_type()->mutable_tensor_array()->set_lod_level(lod_level);
break;
default:
LOG(FATAL)
<< "Setting 'lod_level' is not supported by the type of var %s."
<< this->Name();
}
}
void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
if (multiple_lod_level.size() != GetTensorDescNum()) {
VLOG(3) << "WARNING: The number of given lod_levels("
<< multiple_lod_level.size()
<< ") doesn't match the existing tensor number("
<< GetTensorDescNum()
<< "). The Reader is going to be reinitialized.";
SetTensorDescNum(multiple_lod_level.size());
}
switch (desc_->type().type()) {
case framework_proto::VarType::READER: {
size_t i = 0;
for (auto &lod_tensor :
*desc_->mutable_type()->mutable_reader()->mutable_lod_tensor()) {
lod_tensor.set_lod_level(multiple_lod_level[i++]);
}
} break;
default:
LOG(FATAL)
<< "Setting 'lod_levels' is not supported by the type of var %s."
<< this->Name();
}
}
int32_t VarDesc::GetLoDLevel() const {
switch (desc_->type().type()) {
case framework_proto::VarType::LOD_TENSOR:
return desc_->type().lod_tensor().lod_level();
case framework_proto::VarType::LOD_TENSOR_ARRAY:
return desc_->type().tensor_array().lod_level();
default:
LOG(FATAL)
<< "Getting 'lod_level' is not supported by the type of var %s."
<< this->Name();
}
return 0;
}
std::vector<int32_t> VarDesc::GetLoDLevels() const {
std::vector<int32_t> res;
switch (desc_->type().type()) {
case framework_proto::VarType::READER:
res.reserve(desc_->type().reader().lod_tensor_size());
for (auto &lod_tensor : desc_->type().reader().lod_tensor()) {
res.push_back(lod_tensor.lod_level());
}
return res;
break;
default:
LOG(FATAL)
<< "Getting 'lod_levels' is not supported by the type of var %s."
<< this->Name();
}
return std::vector<int32_t>();
}
const framework_proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
CHECK(desc_->has_type()) << "The var's type hasn't been set.";
CHECK(desc_->type().has_type()) << "The var type hasn't been set.";
switch (desc_->type().type()) {
case framework_proto::VarType::SELECTED_ROWS:
return desc_->type().selected_rows();
case framework_proto::VarType::LOD_TENSOR:
return desc_->type().lod_tensor().tensor();
case framework_proto::VarType::LOD_TENSOR_ARRAY:
return desc_->type().tensor_array().tensor();
default:
LOG(FATAL)
<< "Getting 'tensor_desc' is not supported by the type of var %s."
<< this->Name();
}
return framework_proto::VarDesc().type().lod_tensor().tensor();
}
std::vector<framework_proto::VarType::TensorDesc> VarDesc::tensor_descs()
const {
CHECK(desc_->has_type()) << "The var type hasn't been set.";
std::vector<framework_proto::VarType::TensorDesc> res;
res.reserve(GetTensorDescNum());
switch (desc_->type().type()) {
case framework_proto::VarType::READER:
for (const auto &lod_tensor : desc_->type().reader().lod_tensor()) {
res.push_back(lod_tensor.tensor());
}
return res;
default:
LOG(FATAL)
<< "Getting 'tensor_descs' is not supported by the type of var "
"%s."
<< this->Name();
}
return std::vector<framework_proto::VarType::TensorDesc>();
}
framework_proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
CHECK(desc_->has_type()) << "The var type hasn't been set.";
CHECK(desc_->type().has_type()) << "The var type hasn't been set.";
switch (desc_->type().type()) {
case framework_proto::VarType::SELECTED_ROWS:
return desc_->mutable_type()->mutable_selected_rows();
case framework_proto::VarType::LOD_TENSOR:
return desc_->mutable_type()->mutable_lod_tensor()->mutable_tensor();
case framework_proto::VarType::LOD_TENSOR_ARRAY:
return desc_->mutable_type()->mutable_tensor_array()->mutable_tensor();
default:
LOG(FATAL) << "Getting 'mutable_tensor_desc' is not supported by the "
"type of var "
"%s."
<< this->Name();
}
return nullptr;
}
std::vector<framework_proto::VarType::TensorDesc *>
VarDesc::mutable_tensor_descs() {
CHECK(desc_->has_type()) << "The var type hasn't been set.";
CHECK(desc_->type().has_type()) << "The var type hasn't been set.";
std::vector<framework_proto::VarType::TensorDesc *> res;
res.reserve(GetTensorDescNum());
switch (desc_->type().type()) {
case framework_proto::VarType::READER:
for (auto &lod_tensor :
*desc_->mutable_type()->mutable_reader()->mutable_lod_tensor()) {
res.push_back(lod_tensor.mutable_tensor());
}
return res;
default:
LOG(FATAL)
<< "Getting 'tensor_descs' is not supported by the type of var "
"%s."
<< this->Name();
}
return std::vector<framework_proto::VarType::TensorDesc *>();
}
} // namespace infrt::paddle::pb
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include <google/protobuf/map.h>
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/infrt/paddle/cpp/desc_api.h"
#include "paddle/infrt/paddle/framework.pb.h"
namespace infrt::paddle::pb {
namespace framework_proto = ::paddle::framework::proto;
// convert between std::vector and protobuf repeated.
template <typename T>
inline std::vector<T> RepeatedToVector(
const google::protobuf::RepeatedField<T> &repeated_field) {
std::vector<T> ret;
ret.reserve(repeated_field.size());
std::copy(
repeated_field.begin(), repeated_field.end(), std::back_inserter(ret));
return ret;
}
template <typename T, typename RepeatedField>
inline void VectorToRepeated(const std::vector<T> &vec,
RepeatedField *repeated_field) {
repeated_field->Clear();
repeated_field->Reserve(vec.size());
for (const auto &elem : vec) {
*repeated_field->Add() = elem;
}
}
// Specialize vector<bool>.
template <typename RepeatedField>
inline void VectorToRepeated(const std::vector<bool> &vec,
RepeatedField *repeated_field) {
repeated_field->Clear();
repeated_field->Reserve(vec.size());
for (auto elem : vec) {
*repeated_field->Add() = elem;
}
}
class VarDesc : public cpp::VarDescAPI {
public:
VarDesc() = delete;
explicit VarDesc(framework_proto::VarDesc *desc) : desc_(desc) {
CHECK(desc_);
}
::paddle::framework::proto::VarDesc *Proto() { return desc_; }
const framework_proto::VarDesc &ReadonlyProto() const { return *desc_; }
std::string Name() const override { return desc_->name(); }
void SetName(std::string name) override { desc_->set_name(name); }
void SetTensorDescNum(size_t num);
size_t GetTensorDescNum() const;
void SetShape(const std::vector<int64_t> &dims);
void SetShapes(const std::vector<std::vector<int64_t>> &multiple_dims);
std::vector<int64_t> GetShape() const;
std::vector<std::vector<int64_t>> GetShapes() const;
void SetDataType(VarDescAPI::VarDataType data_type);
void SetDataTypes(
const std::vector<framework_proto::VarType::Type> &multiple_data_type);
VarDescAPI::VarDataType GetDataType() const;
std::vector<framework_proto::VarType::Type> GetDataTypes() const;
void SetLoDLevel(int32_t lod_level);
void SetLoDLevels(const std::vector<int32_t> &multiple_lod_level);
int32_t GetLoDLevel() const;
std::vector<int32_t> GetLoDLevels() const;
VarDescAPI::Type GetType() const override;
void SetType(VarDescAPI::Type type) override;
bool Persistable() const override { return desc_->persistable(); }
void SetPersistable(bool persistable) override {
desc_->set_persistable(persistable);
}
private:
const framework_proto::VarType::TensorDesc &tensor_desc() const;
std::vector<framework_proto::VarType::TensorDesc> tensor_descs() const;
framework_proto::VarType::TensorDesc *mutable_tensor_desc();
std::vector<framework_proto::VarType::TensorDesc *> mutable_tensor_descs();
framework_proto::VarDesc *desc_;
};
} // namespace infrt::paddle::pb
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/paddle/scope.h"
#include "paddle/infrt/common/common.h"
namespace infrt {
namespace paddle {
_Variable* Scope::FindVar(const std::string& name) const {
auto it = data_.find(name);
if (it != data_.end()) return it->second.get();
return nullptr;
}
Tensor Scope::GetTensor(const std::string& name) const {
CheckVarNameValid(name);
auto* var = FindVar(name);
CHECK(var) << "No variable called [" << name << "] found";
return var->get<Tensor>();
}
std::vector<std::string> Scope::var_names() const {
std::vector<std::string> names;
for (auto& item : data_) {
names.push_back(item.first);
}
return names;
}
} // namespace paddle
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <unordered_map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/infrt/common/macros.h"
#include "paddle/infrt/paddle/tensor.h"
#include "paddle/infrt/support/variant.h"
namespace infrt {
namespace paddle {
using _Variable = Variant<Tensor>;
struct _Tensor_;
class Scope {
public:
static std::shared_ptr<Scope> Create() { return std::make_shared<Scope>(); }
//! Get or create a variable.
template <typename T>
_Variable* Var(const std::string& name);
//! Find a variable, get null if not exists.
_Variable* FindVar(const std::string& name) const;
Tensor GetTensor(const std::string& name) const;
//! Get variable names.
std::vector<std::string> var_names() const;
Scope() = default;
private:
std::unordered_map<std::string, std::unique_ptr<_Variable>> data_;
INFRT_DISALLOW_COPY_AND_ASSIGN(Scope);
};
template <typename T>
_Variable* Scope::Var(const std::string& name) {
VLOG(4) << "Scope insert Var [" << name << "]";
_Variable* x = FindVar(name);
if (x) return x;
auto* data = new _Variable(T());
data_[name].reset(data);
return data;
}
} // namespace paddle
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/paddle/tensor.h"
namespace infrt {
namespace paddle {} // namespace paddle
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <memory>
#include <numeric>
#include <vector>
#include "paddle/infrt/common/buffer.h"
#include "paddle/infrt/common/common.h"
#include "paddle/infrt/common/object.h"
namespace infrt {
namespace paddle {
using common::Target;
struct Shape {
using dim_t = int;
Shape() = default;
explicit Shape(const std::vector<dim_t>& data) : data_(data) {}
void SetData(const std::vector<dim_t>& data) { data_ = data; }
const std::vector<dim_t>& data() const INFRT_RESULT_SHOULD_USE {
return data_;
}
std::vector<dim_t>& data() INFRT_RESULT_SHOULD_USE { return data_; }
size_t size() const INFRT_RESULT_SHOULD_USE { return data_.size(); }
uint32_t numel() const INFRT_RESULT_SHOULD_USE {
return std::accumulate(
data_.begin(), data_.end(), 1, [](dim_t a, dim_t b) { return a * b; });
}
private:
std::vector<dim_t> data_;
};
class _Tensor_ : public common::Object {
public:
_Tensor_() : buffer_(std::make_shared<Buffer>()) {}
Shape& shape() { return shape_; }
void Resize(const Shape& shape) {
shape_ = shape;
buffer_->data()->resize(
reinterpret_cast<const infrt_dimension_t*>(shape.data().data()),
shape.size());
}
template <typename T>
inline T* mutable_data(const Target& target) {
set_type(type_of<T>());
if (target == common::DefaultHostTarget()) {
int alignment = type_of<T>().ElementOf().bits();
buffer_->ResizeLazy(alignment, shape_.numel() * sizeof(T), target);
} else {
buffer_->ResizeLazy(shape_.numel() * sizeof(T), target);
}
return reinterpret_cast<T*>(buffer_->data()->memory);
}
template <typename T>
const T* data() const {
return reinterpret_cast<T*>(buffer_->data()->memory);
}
const Type& type() { return type_; }
void set_type(Type type) { type_ = type; }
const Type& type() const { return type_; }
infrt_buffer_t* buffer() { return buffer_->data(); }
const char* type_info() const override { return __type_info__; }
private:
common::Type type_;
// A shared ptr to make it easier to share buffer between tensors.
std::shared_ptr<Buffer> buffer_;
Shape shape_;
static constexpr const char* __type_info__ = "_frontend_tensor_";
};
class Tensor : public Shared<_Tensor_> {
public:
Tensor() : Shared(new _Tensor_) {}
explicit Tensor(_Tensor_* x) : Shared(x) {}
};
} // namespace paddle
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This file defines type traits related utilities.
#pragma once
#include <tuple>
#include <type_traits>
#include <utility>
#include "llvm/ADT/STLExtras.h"
namespace infrt {
// Utility template for tag dispatching.
template <typename T>
struct TypeTag {};
// This is the equivalent of std::void_t in C++17.
template <typename... Ts>
struct make_void {
typedef void type;
};
template <typename... Ts>
using void_t = typename make_void<Ts...>::type;
// The same as std::disjunction in C++17.
template <class...>
struct disjunction : std::false_type {};
template <class B1>
struct disjunction<B1> : B1 {};
template <class B1, class... Bn>
struct disjunction<B1, Bn...>
: std::conditional_t<bool(B1::value), B1, disjunction<Bn...>> {};
// Check whether T may be a base class.
template <typename T>
using MaybeBase =
llvm::conjunction<std::is_class<T>, llvm::negation<std::is_final<T>>>;
// Find the index of a type in a tuple.
//
// Example:
// using Tuple = std::tuple<int, float, double>;
// static_assert(TupleIndexOf<int, Tuple>::value == 0);
// static_assert(TupleIndexOf<double, Tuple>::value == 2);
template <class T, class Tuple>
struct TupleIndexOf;
template <class T, class... Types>
struct TupleIndexOf<T, std::tuple<T, Types...>>
: std::integral_constant<size_t, 0> {};
template <class T, class U, class... Types>
struct TupleIndexOf<T, std::tuple<U, Types...>>
: std::integral_constant<size_t,
1 + TupleIndexOf<T, std::tuple<Types...>>::value> {
};
template <typename T, typename Tuple>
struct TupleHasType;
template <typename T, typename... Us>
struct TupleHasType<T, std::tuple<Us...>>
: disjunction<std::is_same<T, Us>...> {};
// The detector pattern in C++ that can be used for checking whether a type has
// a specific property, e.g. whether an internal type is present or whether a
// particular operation is valid.
//
// Sample usage:
//
// struct Foo {
// using difference_type = int;
// int get();
// };
// struct Bar {};
//
// // Check whether a type T has an internal difference_type.
// template<class T>
// using diff_t = typename T::difference_type;
//
// static_assert(is_detected_v<diff_t, Foo>, "Foo has difference_type");
// static_assert(!is_detected_v<diff_t, Bar>, "Bar has no difference_type");
//
// // Check whether a type T has a get() member function.
// template<class T>
// using has_get_t = decltype(std::declval<T>().get());
//
// static_assert(is_detected_v<has_get_t, Foo>, "Foo has get()");
// static_assert(!is_detected_v<has_get_t, Bar>, "Bar has no get()");
//
// See https://en.cppreference.com/w/cpp/experimental/is_detected for details.
namespace internal {
// nonesuch is a class type used to indicate detection failure.
struct nonesuch {
~nonesuch() = delete;
nonesuch(nonesuch const&) = delete;
void operator=(nonesuch const&) = delete;
};
template <class Default,
class AlwaysVoid,
template <class...> class Op,
class... Args>
struct detector : std::false_type {
using value_t = std::false_type;
using type = Default;
};
template <class Default, template <class...> class Op, class... Args>
struct detector<Default, void_t<Op<Args...>>, Op, Args...> {
using value_t = std::true_type;
using type = Op<Args...>;
};
} // namespace internal
template <template <class...> class Op, class... Args>
using is_detected =
typename internal::detector<internal::nonesuch, void, Op, Args...>::value_t;
template <template <class...> class Op, class... Args>
using detected_t =
typename internal::detector<internal::nonesuch, void, Op, Args...>::type;
template <class Default, template <class...> class Op, class... Args>
using detected_or = internal::detector<Default, void, Op, Args...>;
template <template <class...> class Op, class... Args>
constexpr bool is_detected_v = is_detected<Op, Args...>::value;
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This file implements the variant data structure similar to
// absl::variant in C++17.
#pragma once
#include <algorithm>
#include <tuple>
#include <type_traits>
#include <utility>
#include "paddle/infrt/support/type_traits.h"
namespace infrt {
// A Variant similar to absl::variant in C++17.
//
// Example usage:
//
// Variant<int, float, double> v;
//
// v = 1;
// assert(v.get<int>() == 1);
// assert(v.is<int>());
// assert(v.get_if<float>() == nullptr);
//
// // Print the variant.
// visit([](auto& t) { std::cout << t; }, v);
//
// v.emplace<float>(3);
//
template <typename... Ts>
class Variant {
// Convenient constant to check if a type is a variant.
template <typename T>
static constexpr bool IsVariant =
std::is_same<std::decay_t<T>, Variant>::value;
public:
using IndexT = int16_t;
using Types = std::tuple<Ts...>;
template <int N>
using TypeOf = typename std::tuple_element<N, Types>::type;
static constexpr size_t kNTypes = sizeof...(Ts);
// Default constructor sets the Variant to the default constructed fisrt type.
Variant() {
using Type0 = TypeOf<0>;
index_ = 0;
new (&storage_) Type0();
}
template <typename T, std::enable_if_t<!IsVariant<T>, int> = 0>
explicit Variant(T&& t) {
fillValue(std::forward<T>(t));
}
Variant(const Variant& v) {
visit([this](auto& t) { fillValue(t); }, v);
}
Variant(Variant&& v) {
visit([this](auto&& t) { fillValue(std::move(t)); }, v);
}
~Variant() { destroy(); }
Variant& operator=(Variant&& v) {
visit([this](auto& t) { *this = std::move(t); }, v);
return *this;
}
Variant& operator=(const Variant& v) {
visit([this](auto& t) { *this = t; }, v);
return *this;
}
template <typename T, std::enable_if_t<!IsVariant<T>, int> = 0>
Variant& operator=(T&& t) {
destroy();
fillValue(std::forward<T>(t));
return *this;
}
template <typename T, typename... Args>
T& emplace(Args&&... args) {
AssertHasType<T>();
destroy();
index_ = IndexOf<T>;
auto* t = new (&storage_) T(std::forward<Args>(args)...);
return *t;
}
template <typename T>
bool is() const {
AssertHasType<T>();
return IndexOf<T> == index_;
}
template <typename T>
const T& get() const {
AssertHasType<T>();
return *reinterpret_cast<const T*>(&storage_);
}
template <typename T>
T& get() {
AssertHasType<T>();
return *reinterpret_cast<T*>(&storage_);
}
template <typename T>
const T* get_if() const {
if (is<T>()) return &get<T>();
return nullptr;
}
template <typename T>
T* get_if() {
if (is<T>()) return &get<T>();
return nullptr;
}
IndexT index() { return index_; }
private:
template <typename T>
static constexpr size_t IndexOf = TupleIndexOf<T, Types>::value;
static constexpr size_t kStorageSize = std::max({sizeof(Ts)...});
static constexpr size_t kAlignment = std::max({alignof(Ts)...});
template <typename T>
static constexpr void AssertHasType() {
constexpr bool has_type = TupleHasType<T, Types>::value;
static_assert(has_type, "Invalid Type used for Variant");
}
void destroy() {
visit(
[](auto& t) {
using T = std::decay_t<decltype(t)>;
t.~T();
},
*this);
}
template <typename T>
void fillValue(T&& t) {
using Type = std::decay_t<T>;
AssertHasType<Type>();
index_ = IndexOf<Type>;
new (&storage_) Type(std::forward<T>(t));
}
using StorageT = std::aligned_storage_t<kStorageSize, kAlignment>;
StorageT storage_;
IndexT index_ = -1;
};
struct Monostate {};
namespace internal {
template <typename F, typename Variant>
decltype(auto) visitHelper(
F&& f,
Variant&& v,
std::integral_constant<int, std::decay_t<Variant>::kNTypes>) {
assert(false && "Unexpected index_ in Variant");
}
// Disable clang-format as it does not format less-than (<) in the template
// parameter properly.
//
// clang-format off
template <
typename F, typename Variant, int N,
std::enable_if_t<N < std::decay_t<Variant>::kNTypes, int> = 0>
decltype(auto) visitHelper(F&& f, Variant&& v, std::integral_constant<int, N>) {
// clang-format on
using VariantT = std::decay_t<Variant>;
using T = typename VariantT::template TypeOf<N>;
if (auto* t = v.template get_if<T>()) {
return f(*t);
} else {
return visitHelper(std::forward<F>(f),
std::forward<Variant>(v),
std::integral_constant<int, N + 1>());
}
}
} // namespace internal
template <typename F, typename Variant>
decltype(auto) visit(F&& f, Variant&& v) {
return internal::visitHelper(std::forward<F>(f),
std::forward<Variant>(v),
std::integral_constant<int, 0>());
}
} // namespace infrt
core_gather_headers()
gather_srcs(infrt_src SRCS
tensor_map.cc
tensor_metadata.cc
dense_tensor_view.cc
dense_host_tensor.cc
tensor_shape.cc
)
# set(tensor_map_mlir "${CMAKE_SOURCE_DIR}/infrt/dialect/mlir_tests/tensor_map.mlir")
# set(external_kernels_lib "${CMAKE_BINARY_DIR}/paddle/libexternal_kernels.so")
# message(STATUS "tensor_map_mlir: ${tensor_map_mlir}")
# message(STATUS "external_kernels_lib: ${external_kernels_lib}")
# Disable temporarily for the external-kernel's mkldnn is outdate
# add_test(
# NAME run_and_check_tensor_map
# COMMAND sh -c "sed -e 's|/infrt/build|${CMAKE_BINARY_DIR}|' ${tensor_map_mlir} > /tmp/tensor_map.mlir && ${CMAKE_BINARY_DIR}/infrt/host_context/infrt-exec -i /tmp/tensor_map.mlir --shared_libs=${external_kernels_lib} | ${LLVM_PATH}/bin/FileCheck ${tensor_map_mlir}"
# )
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/tensor/dense_host_tensor.h"
#include <llvm/Support/raw_os_ostream.h>
#include "paddle/infrt/common/buffer.h"
namespace infrt::tensor {
DenseHostTensor::DenseHostTensor(const TensorShape& shape, DType dtype)
: HostTensor(TensorMetadata{dtype, shape}) {
CHECK(metadata().IsValid()) << "Tensor construct get invalid metadata";
buffer_.reset(new infrt::Buffer(infrt::common::DefaultHostTarget()));
buffer_->ResizeLazy(dtype.GetHostSize() * shape.GetNumElements());
}
const TensorShape& DenseHostTensor::shape() const { return metadata().shape; }
void DenseHostTensor::Init(const std::vector<int64_t>& shape, DType dtype) {
auto shape_array = llvm::ArrayRef<int64_t>(shape.data(), shape.size());
auto metadata = TensorMetadata(dtype, shape_array);
setTensorMetadata(metadata);
buffer_.reset(new infrt::Buffer(infrt::common::DefaultHostTarget()));
buffer_->ResizeLazy(dtype.GetHostSize() * metadata.shape.GetNumElements());
}
const infrt::Buffer* DenseHostTensor::buffer() const { return buffer_.get(); }
template <typename T>
void DisplayArray(std::ostream& os, T* data, int num_elements) {
for (int i = 0; i < num_elements - 1; i++) os << data[i] << ", ";
if (num_elements > 0) os << data[num_elements - 1];
}
std::ostream& operator<<(std::ostream& os, const DenseHostTensor& instance) {
CHECK(instance.metadata().IsValid())
<< "Cann't print tensor with invalid metadata";
llvm::raw_os_ostream oos(os);
oos << "tensor: ";
oos << "shape=";
oos << instance.shape();
oos << ", values=[";
oos.flush();
if (instance.metadata().dtype == GetDType<float>()) {
auto* data = reinterpret_cast<float*>(instance.buffer()->data()->memory);
DisplayArray(os, data, instance.shape().GetNumElements());
} else if (instance.metadata().dtype == GetDType<double>()) {
auto* data = reinterpret_cast<double*>(instance.buffer()->data()->memory);
DisplayArray(os, data, instance.shape().GetNumElements());
} else if (instance.metadata().dtype == GetDType<int32_t>()) {
auto* data = reinterpret_cast<int32_t*>(instance.buffer()->data()->memory);
DisplayArray(os, data, instance.shape().GetNumElements());
} else if (instance.metadata().dtype == GetDType<int64_t>()) {
auto* data = reinterpret_cast<int64_t*>(instance.buffer()->data()->memory);
DisplayArray(os, data, instance.shape().GetNumElements());
} else {
LOG(FATAL) << "Not supported dtype [" << instance.metadata().dtype.name()
<< " " << static_cast<int>(instance.metadata().dtype.kind())
<< "] in print";
}
os << "]";
return os;
}
DenseHostTensor::~DenseHostTensor() {}
void* DenseHostTensor::raw_data() const { return buffer_->data()->memory; }
} // namespace infrt::tensor
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <utility>
#include "paddle/infrt/tensor/tensor_metadata.h"
#include "paddle/infrt/tensor/tensor_shape.h"
namespace infrt {
class Buffer;
} // namespace infrt
namespace infrt::tensor {
enum class DeviceKind {
kCPU = 0,
};
class Tensor {
public:
virtual bool IsHostTensor() const = 0;
virtual ~Tensor() = default;
const TensorMetadata& metadata() const { return metadata_; }
protected:
Tensor() = default;
void setTensorMetadata(TensorMetadata& metadata) { // NOLINT
metadata_ = metadata;
}
explicit Tensor(const TensorMetadata& metadata) : metadata_(metadata) {}
explicit Tensor(TensorMetadata&& metadata) : metadata_(std::move(metadata)) {}
private:
TensorMetadata metadata_;
};
class HostTensor : public Tensor {
public:
bool IsHostTensor() const override { return true; }
virtual ~HostTensor() {}
protected:
HostTensor() = default;
explicit HostTensor(const TensorMetadata& metadata) : Tensor(metadata) {}
explicit HostTensor(TensorMetadata&& metadata)
: Tensor(std::move(metadata)) {}
};
// TODO(Superjomn) Replace the hlir/framework/Tensor with this.
/**
* DenseTensor is a dense tensor, it holds a TensorShape and a buffer.
*/
class DenseHostTensor : public HostTensor {
public:
DenseHostTensor() = default;
DenseHostTensor(const TensorShape& shape, DType dtype);
void Init(const std::vector<int64_t>& shape, DType dtype);
const TensorShape& shape() const;
const Buffer* buffer() const;
void* raw_data() const;
friend std::ostream& operator<<(std::ostream& os,
const DenseHostTensor& instance);
virtual ~DenseHostTensor();
private:
// TODO(Superjomn) Discard the dependency of the Buffer in infrtcore or create
// a general buffer in common.
std::shared_ptr<Buffer> buffer_;
};
} // namespace infrt::tensor
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/tensor/dense_tensor_view.h"
namespace infrt::tensor {} // namespace infrt::tensor
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include "paddle/infrt/tensor/dense_host_tensor.h"
namespace infrt::tensor {
template <typename DType>
class DTArrayView {
public:
using UnderlyingT = DenseHostTensor;
explicit DTArrayView(const DenseHostTensor* tensor) : tensor_(*tensor) {}
const TensorShape& shape() { return tensor_.shape(); }
size_t GetNumElements() const { return tensor_.shape().GetNumElements(); }
const DType* data() const {
return static_cast<const DType*>(tensor_.raw_data());
}
DType* data() { return static_cast<DType*>(tensor_.raw_data()); }
llvm::ArrayRef<DType> Elements() const {
return llvm::ArrayRef<DType>(data(), GetNumElements());
}
private:
const DenseHostTensor& tensor_;
};
template <typename DType>
class MutableDTArrayView : public DTArrayView<DType> {
public:
explicit MutableDTArrayView(DenseHostTensor* tensor)
: DTArrayView<DType>(tensor) {}
void Fill(const DType& v) {
std::fill(this->data(), this->data() + this->GetNumElements(), v);
}
using DTArrayView<DType>::data;
using DTArrayView<DType>::GetNumElements;
llvm::MutableArrayRef<DType> Elements() {
return llvm::MutableArrayRef<DType>(data(), this->GetNumElements());
}
};
} // namespace infrt::tensor
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/tensor/tensor_map.h"
#include <fstream>
#include <iostream>
#include "paddle/infrt/common/string.h"
#include "paddle/infrt/paddle/model_parser.h"
using Scope = infrt::paddle::Scope;
using Target = infrt::common::Target;
using Type = infrt::common::Type;
namespace infrt {
namespace tensor {
DType CinnType2DType_(Type type) {
if (type.is_bool()) return GetDType<bool>();
if (type.is_int(8)) return GetDType<int8_t>();
if (type.is_int(16)) return GetDType<int16_t>();
if (type.is_int(32)) return GetDType<int32_t>();
if (type.is_int(64)) return GetDType<int64_t>();
if (type.is_uint(8)) return GetDType<uint8_t>();
if (type.is_uint(16)) return GetDType<uint16_t>();
if (type.is_uint(32)) return GetDType<uint32_t>();
if (type.is_uint(64)) return GetDType<uint64_t>();
if (type.is_float(32)) return GetDType<float>();
if (type.is_float(64)) return GetDType<double>();
if (type.is_string()) return GetDType<std::string>();
return DType(DType::Kind::Unk);
}
TensorMap *LoadParams(const std::string &path) {
std::cout << "loading params from: " << path << std::endl;
TensorMap *map = new TensorMap();
Scope scope;
const Target &target = common::DefaultHostTarget();
std::string model_path = path + "/__model__";
// paddle::framework::proto::ProgramDesc pb_proto_prog =
// *infrt::frontend::paddle::LoadProgram(model_path);
auto pb_proto_prog = *paddle::LoadProgram(model_path);
// infrt::frontend::paddle::pb::ProgramDesc pb_prog_desc(&pb_proto_prog);
// infrt::frontend::paddle::TransformProgramDescAnyToCpp(pb_prog_desc,
// cpp_prog);
auto main_block = pb_proto_prog.blocks(0);
for (auto &var : main_block.vars()) {
if (var.name() == "feed" || var.name() == "fetch" || !var.persistable())
continue;
std::string param_path = path + "/" + var.name();
std::ifstream param_file(param_path, std::ios::binary);
switch (var.type().type()) {
case ::paddle::framework::proto::VarType_Type_LOD_TENSOR: {
auto var_name = infrt::TransValidVarName(var.name());
// std::cout << "var name: " << var.name() << " " << var_name <<
// std::endl;
auto *_var = scope.Var<paddle::Tensor>(var_name);
paddle::LoadLoDTensor(param_file, _var, target);
auto tensor = scope.GetTensor(var_name);
auto *src_data = tensor->data<float>();
auto &infrt_type = tensor->type();
std::vector<int64_t> shape;
for (int dim : tensor->shape().data()) shape.push_back(dim);
auto shape_array = llvm::ArrayRef<int64_t>(shape.data(), shape.size());
auto dtype = CinnType2DType_(infrt_type);
auto *dht = new DenseHostTensor(TensorShape(shape_array), dtype);
int num_elements = dht->shape().GetNumElements();
auto *dst_data = reinterpret_cast<float *>(dht->raw_data());
for (int i = 0; i < num_elements; ++i) dst_data[i] = src_data[i];
(*map)[var.name()] = dht;
break;
}
default:
std::cout << "unknown weight type" << std::endl;
break;
}
}
return map;
}
} // namespace tensor
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <unordered_map>
#include "paddle/infrt/tensor/dense_host_tensor.h"
namespace infrt {
namespace tensor {
using TensorMap = std::unordered_map<std::string, tensor::DenseHostTensor*>;
TensorMap* LoadParams(const std::string& path);
} // namespace tensor
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/tensor/tensor_metadata.h"
#include <llvm/Support/raw_ostream.h>
namespace infrt {
namespace tensor {
llvm::raw_ostream& operator<<(llvm::raw_ostream& os, TensorMetadata& meta) {
os << meta.dtype.name();
os << "\n";
os << meta.shape;
return os;
}
} // namespace tensor
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <glog/logging.h>
#include "paddle/infrt/common/dtype.h"
#include "paddle/infrt/tensor/tensor_shape.h"
namespace infrt {
namespace tensor {
struct TensorMetadata {
DType dtype;
TensorShape shape;
TensorMetadata() = default;
TensorMetadata(DType dtype, const TensorShape& shape)
: dtype(dtype), shape(shape) {
CHECK(IsValid());
}
TensorMetadata(DType dtype, llvm::ArrayRef<int64_t> shape)
: dtype(dtype), shape(shape) {
CHECK(IsValid());
}
size_t GetHostSizeInBytes() const {
return dtype.GetHostSize() * shape.GetNumElements();
}
bool IsValid() const { return dtype.IsValid(); }
bool IsInvalid() const { return !dtype.IsValid(); }
bool operator==(const TensorMetadata& other) const {
return dtype == other.dtype && shape == other.shape;
}
bool operator!=(const TensorMetadata& other) const {
return !(*this == other);
}
friend llvm::raw_ostream& operator<<(llvm::raw_ostream& os,
TensorMetadata& meta);
};
} // namespace tensor
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/infrt/tensor/tensor_shape.h"
#include <glog/logging.h>
#include <llvm/Support/raw_ostream.h>
#include <algorithm>
#include <functional>
namespace infrt {
namespace tensor {
TensorShape::TensorShape(llvm::ArrayRef<int64_t> dims)
: dims_(dims.begin(), dims.end()) {}
int TensorShape::GetRank() const { return dims_.size(); }
int64_t TensorShape::GetDim(int idx) const {
CHECK_GE(idx, 0);
CHECK_LT(idx, GetRank());
return dims_[idx];
}
int TensorShape::GetNumElements() const {
int64_t size = 1;
for (int v : dims_) size *= v;
return size;
}
DynamicTensorShape::DynamicTensorShape(
llvm::Optional<llvm::ArrayRef<int64_t>> dims) {
if (dims.hasValue()) {
dims_ = llvm::SmallVector<int64_t, 4>(dims->begin(), dims->end());
}
}
int DynamicTensorShape::GetRank() const {
if (dims_.hasValue()) return dims_->size();
return kUnknownDimSize;
}
int64_t DynamicTensorShape::GetDim(int idx) const {
CHECK_GE(idx, 0);
CHECK_LT(idx, GetRank());
return (*dims_)[idx];
}
bool DynamicTensorShape::IsShapeKnown() const {
if (!dims_.hasValue()) return false;
for (int64_t v : *dims_) {
if (IsDimUnknown(v)) return false;
}
return true;
}
llvm::Optional<TensorShape> DynamicTensorShape::ToTensorShape() const {
if (IsShapeKnown()) {
return TensorShape(*dims_);
}
return llvm::None;
}
llvm::raw_ostream& operator<<(llvm::raw_ostream& os, const TensorShape& v) {
os << "shape[";
for (int i = 0; i < v.GetRank() - 1; i++) {
os << v.dims_[i] << ",";
}
if (v.GetRank() > 0) os << v.dims_.back();
os << "]";
return os;
}
std::ostream& operator<<(std::ostream& os, const DynamicTensorShape& v) {
os << "dynamic_shape[";
for (int i = 0; i < v.GetRank() - 1; i++) {
os << v << ",";
}
if (v.GetRank() > 0) os << v.dims_->back();
os << "]";
return os;
}
} // namespace tensor
} // namespace infrt
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <llvm/ADT/ArrayRef.h>
namespace infrt {
namespace tensor {
/**
* TensorShape represents the shape of a Tensor, all the dimensions should be
* known.
*/
class TensorShape {
public:
TensorShape() = default;
explicit TensorShape(llvm::ArrayRef<int64_t> dims);
int GetRank() const;
int64_t GetDim(int idx) const;
int GetNumElements() const;
friend llvm::raw_ostream& operator<<(llvm::raw_ostream& os,
const TensorShape& v);
friend bool operator==(const TensorShape& a, const TensorShape& b) {
return a.dims_ == b.dims_;
}
private:
llvm::SmallVector<int64_t, 4> dims_;
};
/**
* DynamicTensorShape represents the shape of a Tensor, with some dimensions or
* even the rank is unknown.
*/
class DynamicTensorShape {
public:
explicit DynamicTensorShape(llvm::Optional<llvm::ArrayRef<int64_t>> dims);
//! Returns the rank if rank is known, or kUnknownDimSize.
int GetRank() const;
int64_t GetDim(int idx) const;
bool IsShapeKnown() const;
//! Convert to a TensorShape if all the dimensions are known.
llvm::Optional<TensorShape> ToTensorShape() const;
static constexpr int64_t kUnknownDimSize = -1;
static bool IsDimUnknown(int64_t dim) { return dim == kUnknownDimSize; }
friend std::ostream& operator<<(std::ostream& os,
const DynamicTensorShape& v);
friend bool operator==(const DynamicTensorShape& a,
const DynamicTensorShape& b) {
return a.dims_ == b.dims_;
}
private:
//! Will be std::nullopt if no dim is known.
llvm::Optional<llvm::SmallVector<int64_t, 4>> dims_;
};
} // namespace tensor
} // namespace infrt
...@@ -216,6 +216,7 @@ function cmake_base() { ...@@ -216,6 +216,7 @@ function cmake_base() {
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON -DCMAKE_EXPORT_COMPILE_COMMANDS=ON
-DWITH_CONTRIB=${WITH_CONTRIB:-ON} -DWITH_CONTRIB=${WITH_CONTRIB:-ON}
-DWITH_INFERENCE_API_TEST=${WITH_INFERENCE_API_TEST:-ON} -DWITH_INFERENCE_API_TEST=${WITH_INFERENCE_API_TEST:-ON}
-DWITH_INFRT=${WITH_INFRT:-OFF}
-DINFERENCE_DEMO_INSTALL_DIR=${INFERENCE_DEMO_INSTALL_DIR} -DINFERENCE_DEMO_INSTALL_DIR=${INFERENCE_DEMO_INSTALL_DIR}
-DPY_VERSION=${PY_VERSION:-2.7} -DPY_VERSION=${PY_VERSION:-2.7}
-DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX:-/paddle/build} -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX:-/paddle/build}
...@@ -262,6 +263,7 @@ EOF ...@@ -262,6 +263,7 @@ EOF
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \
-DWITH_CONTRIB=${WITH_CONTRIB:-ON} \ -DWITH_CONTRIB=${WITH_CONTRIB:-ON} \
-DWITH_INFERENCE_API_TEST=${WITH_INFERENCE_API_TEST:-ON} \ -DWITH_INFERENCE_API_TEST=${WITH_INFERENCE_API_TEST:-ON} \
-DWITH_INFRT=${WITH_INFRT:-OFF} \
-DINFERENCE_DEMO_INSTALL_DIR=${INFERENCE_DEMO_INSTALL_DIR} \ -DINFERENCE_DEMO_INSTALL_DIR=${INFERENCE_DEMO_INSTALL_DIR} \
-DPY_VERSION=${PY_VERSION:-2.7} \ -DPY_VERSION=${PY_VERSION:-2.7} \
-DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX:-/paddle/build} \ -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX:-/paddle/build} \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册