// Copyright (c) 2021 CINN Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /** * \file This file defines some functions and macros to help register the extern functions into JIT. */ #pragma once #include #include #include #include #include "paddle/cinn/backends/extern_func_emitter.h" #include "paddle/cinn/backends/extern_func_emitter_builtin.h" #include "paddle/cinn/backends/extern_func_protos.h" #include "paddle/cinn/backends/function_prototype.h" #include "paddle/cinn/backends/llvm/codegen_llvm.h" #include "paddle/cinn/backends/llvm/ir_builder_mixin.h" #include "paddle/cinn/backends/llvm/llvm_util.h" #include "paddle/cinn/backends/llvm/runtime_symbol_registry.h" #include "paddle/cinn/common/macros.h" /** * Helper to register an external function into CINN, including the prototype, the function address. * @param fn__: name of the function * @param target__: the Target. */ #define REGISTER_EXTERN_FUNC_HELPER(fn__, target__) \ ::cinn::backends::RegisterExternFunction(#fn__, target__, reinterpret_cast(fn__)) #define REGISTER_FACKED_EXTERN_FUNC_HELPER(fn__, target__) ::cinn::backends::RegisterExternFunction(#fn__, target__) /** * Register an external function with one input and one output. */ #define REGISTER_EXTERN_FUNC_1_IN_1_OUT(fn__, target__, in_type__, out_type__) \ REGISTER_EXTERN_FUNC_HELPER(fn__, target__).SetRetType().AddInputType().End() /** * Register an external function with one input and one output. */ #define REGISTER_EXTERN_FUNC_2_IN_1_OUT(fn__, target__, in_type1__, in_type2__, out_type__) \ REGISTER_EXTERN_FUNC_HELPER(fn__, target__) \ .SetRetType() \ .AddInputType() \ .AddInputType() \ .End() /** * Register a sourced function(No function address, called in generated source code). */ #define REGISTER_EXTERN_SOURCE_FUNC_1_IN_1_OUT(fn__, target__, in_type__, out_type__) \ REGISTER_FACKED_EXTERN_FUNC_HELPER(fn__, target__).SetRetType().AddInputType().End() /** * Register a sourced function(No function address, called in generated source code). */ #define REGISTER_EXTERN_SOURCE_FUNC_2_IN_1_OUT(fn__, target__, in_type1__, in_type2__, out_type__) \ REGISTER_FACKED_EXTERN_FUNC_HELPER(fn__, target__) \ .SetRetType() \ .AddInputType() \ .AddInputType() \ .End() namespace cinn { namespace backends { static const char* TargetToBackendRepr(Target target) { switch (target.arch) { case Target::Arch::X86: return backend_llvm_host; case Target::Arch::NVGPU: return backend_nvgpu; default: CINN_NOT_IMPLEMENTED } return nullptr; } /** * Helper class to register an external function. */ struct RegisterExternFunction { /** * Constructor. * @param fn_name Name of the function. * @param target Target of the function. * @param fn_ptr Address of the function, not valid if leave as null. */ RegisterExternFunction(const std::string& fn_name, Target target, void* fn_ptr = nullptr) : fn_name_(fn_name), target_(target), fn_ptr_(fn_ptr), fn_proto_builder_(fn_name) {} /** * Add an input type. * @tparam T The input type. * @return itself. */ template RegisterExternFunction& AddInputType() { fn_proto_builder_.AddInputType(); return *this; } /** * Add an output type. * @tparam T The output type. * @return itself. */ template RegisterExternFunction& AddOutputType() { fn_proto_builder_.AddOutputType(); return *this; } /** * Add an return type. * @tparam T The return type. * @return itself. */ template RegisterExternFunction& SetRetType() { fn_proto_builder_.SetRetType(); return *this; } /** * Add an shape inference. * @param handle The handle to help inference the shape. * @return itself. */ RegisterExternFunction& SetShapeInference(FunctionProto::shape_inference_t handle) { fn_proto_builder_.SetShapeInference(handle); return *this; } /** * End the register, once end, futher modification is disallowed. */ void End(); private: const std::string& fn_name_; Target target_; void* fn_ptr_{}; FunctionProto::Builder fn_proto_builder_; }; } // namespace backends } // namespace cinn