// Copyright (c) 2021 CINN Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_operators.h" #include "paddle/cinn/lang/placeholder.h" #include "paddle/cinn/poly/schedule.h" namespace cinn { namespace lang { using compute_handler_t = std::function &)>; using attr_t = absl::variant; //! Compute methods for one to five Vars as arguments. // @{ // The shape are constant integers. ir::Tensor Compute(const std::vector &domain, std::function fn, const std::string &name, const std::vector &shape = {}); ir::Tensor Compute(const std::vector &domain, std::function fn, const std::string &name, const std::vector &shape = {}); ir::Tensor Compute(const std::vector &domain, std::function fn, const std::string &name, const std::vector &shape = {}); ir::Tensor Compute(const std::vector &domain, std::function fn, const std::string &name, const std::vector &shape = {}); ir::Tensor Compute(const std::vector &domain, std::function fn, const std::string &name, const std::vector &shape = {}); ir::Tensor Compute(const std::vector &domain, std::function fn, const std::string &name, const std::vector &shape = {}); ir::Tensor Compute(const std::vector &domain, std::function fn, const std::string &name, const std::vector &shape = {}); ir::Tensor Compute(const std::vector &domain, compute_handler_t fn, const std::string &name, const std::vector &shape = {}); // @} struct ReturnType { Type type; std::vector dims; std::string name; }; /** * \brief Call a lowered function and return one or more tensors as result. * * A lowered function is generated by lang::Lower method. * * TODO(Superjomn) Add a registry (symbol table?) to make return result inference automatically. * * @param func_name The name of the function to call. * @param args The readonly arguments(while the mutable tensors are return result). * @param return_types The types of the return values. * @return Return one or more tensors as result. */ std::vector CallLowered(const std::string &func_name, const std::vector &args, const std::vector &return_types); /** * \brief Call an external function and get some tensors as result. * * There are two kinds of extern functions distinguished by the return type. * * 1. Void, there are one or more mutable tensors in the argument list. * \code * Tensor tuple = Compute({M}, []() { return CallExtern("mkl_gemm", {X, W}); }); * \endcode * * To support returning multiple value one time, we include the tuple concept, it is a Tensor with CallOp marked with * value_offset(from 0 to num_returns-1). * * 2. POD value, return an expression directly, and it can be inline expand in following computations. * \code * Tensor tanh_out = Compute({M}, [](Var i) { return CallExtern("tanh", X(i)); }); * \endcode * * Will generate something like * * \code * for (i) { * gemm_mkl(X[i], gemm_out[i]) * } * \endcode * * @param func_name The name of the function to call. * @param args The readonly arguments(while there should be only one tensor as result). * @param attrs The readonly attrs. */ Expr CallExtern(const std::string &func_name, const std::vector &args, const std::map &attrs = {}); } // namespace lang } // namespace cinn