// Copyright (c) 2021 CINN Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. /** * This file contains all the internal representations used in CINN project. */ #pragma once #include #include #include #include #include #include #include #include "paddle/cinn/common/shared.h" #include "paddle/cinn/common/type.h" #include "paddle/cinn/ir/function_base.h" #include "paddle/cinn/ir/ir_base.h" #include "paddle/cinn/utils/small_vector.h" namespace cinn { namespace poly { class Stage; } // namespace poly namespace ir { class Buffer; class BufferRange; struct LoweredFunc; class Module; using common::Object; using common::Shared; // NOTE attr_t only support POD, can not contain Expr or other IR nodes, or the IRVisitor or IRCopy on PrimitiveNode // will result in undefined behavior. using attr_t = absl::variant; /** * Cast a node to another type, can't change the width. */ struct Cast : public ExprNode { Cast() : ExprNode(1) {} static Expr Make(Type t, Expr v); template static Expr Make(Type t, T v) { return Make(t, Expr(v)); } Expr& v() { return operand(0); } const Expr& v() const { return operand(0); } void Verify() const override; static const IrNodeTy _node_type_ = IrNodeTy::Cast; std::vector expr_fields() override { return {&operand(0)}; } std::vector expr_fields() const override { return {&operand(0)}; } }; /** * The sum of two expressions. */ struct Add : public BinaryOpNode { Add(Expr a, Expr b); static Expr Make(Expr a, Expr b); void Verify() const override; static const IrNodeTy _node_type_ = IrNodeTy::Add; }; /** * The difference of two expressions. */ struct Sub : public BinaryOpNode { Sub(Expr a, Expr b) : BinaryOpNode(a.type(), a, b) {} static Expr Make(Expr a, Expr b); void Verify() const override; static const IrNodeTy _node_type_ = IrNodeTy::Sub; }; /** * The product of two expressions. */ struct Mul : public BinaryOpNode { Mul(Expr a, Expr b) : BinaryOpNode(a.type(), a, b) {} static Expr Make(Expr a, Expr b); void Verify() const override; static const IrNodeTy _node_type_ = IrNodeTy::Mul; }; /** * The ratio of two expressions. */ struct Div : public BinaryOpNode
{ Div(Expr a, Expr b) : BinaryOpNode
(a.type(), a, b) {} static Expr Make(Expr a, Expr b); void Verify() const override; static const IrNodeTy _node_type_ = IrNodeTy::Div; }; /** * The mod of two expressions. */ struct Mod : public BinaryOpNode { Mod(Expr a, Expr b) : BinaryOpNode(a.type(), a, b) {} static Expr Make(Expr a, Expr b); void Verify() const override; static const IrNodeTy _node_type_ = IrNodeTy::Mod; }; /** * The lesser of two expressions. */ struct Min : public BinaryOpNode { Min(Expr a, Expr b) : BinaryOpNode(a.type(), a, b) {} static Expr Make(Expr a, Expr b); void Verify() const override; static const IrNodeTy _node_type_ = IrNodeTy::Min; }; /** * The larger of two expressions. */ struct Max : public BinaryOpNode { Max(Expr a, Expr b) : BinaryOpNode(a.type(), a, b) {} static Expr Make(Expr a, Expr b); void Verify() const override; static const IrNodeTy _node_type_ = IrNodeTy::Max; }; /** * Tell whether the first expression equals to the second expression. */ struct EQ : public BinaryOpNode { EQ(Expr a, Expr b) : BinaryOpNode(a.type(), a, b) {} Type type() const { return Bool(a()->type().lanes()); } static Expr Make(Expr a, Expr b); void Verify() const override; static const IrNodeTy _node_type_ = IrNodeTy::EQ; }; /** * Tell whether the first expression not equals to the second expression. */ struct NE : public BinaryOpNode { NE(Expr a, Expr b) : BinaryOpNode(a.type(), a, b) {} Type type() const { return Bool(a()->type().lanes()); } static Expr Make(Expr a, Expr b); void Verify() const override; static const IrNodeTy _node_type_ = IrNodeTy::NE; }; /** * Tell whether the first expression is lower than the second expression. */ struct LT : public BinaryOpNode { LT(Expr a, Expr b) : BinaryOpNode(a.type(), a, b) {} Type type() const { return Bool(a()->type().lanes()); } static Expr Make(Expr a, Expr b); void Verify() const override; static const IrNodeTy _node_type_ = IrNodeTy::LT; }; /** * Tell whether the first expression is no larger than the second expression. */ struct LE : public BinaryOpNode { LE(Expr a, Expr b) : BinaryOpNode(a.type(), a, b) {} Type type() const { return Bool(a()->type().lanes()); } static Expr Make(Expr a, Expr b); void Verify() const override; static const IrNodeTy _node_type_ = IrNodeTy::LE; }; /** * Tell whether the first expression is larger than the second expression. */ struct GT : public BinaryOpNode { GT(Expr a, Expr b) : BinaryOpNode(a.type(), a, b) {} Type type() const { return Bool(a()->type().lanes()); } static Expr Make(Expr a, Expr b); void Verify() const override; static const IrNodeTy _node_type_ = IrNodeTy::GT; }; /** * Tell whether the first expression is not less than the second expression. */ struct GE : public BinaryOpNode { GE(Expr a, Expr b) : BinaryOpNode(a.type(), a, b) {} Type type() const { return Bool(a()->type().lanes()); } static Expr Make(Expr a, Expr b); void Verify() const override; static const IrNodeTy _node_type_ = IrNodeTy::GE; }; /** * Logical and. */ struct And : public BinaryOpNode { And(Expr a, Expr b) : BinaryOpNode(a.type(), a, b) { CHECK(a->type().is_bool()); CHECK(b->type().is_bool()); } Type type() const { return Bool(a()->type().lanes()); } static Expr Make(Expr a, Expr b); void Verify() const override; static const IrNodeTy _node_type_ = IrNodeTy::And; }; /** * -x */ struct Minus : public UnaryOpNode { explicit Minus(Expr x) : UnaryOpNode(x.type(), x) {} static Expr Make(Expr a); void Verify() const override; static const IrNodeTy _node_type_ = IrNodeTy::Minus; }; /** * Logical or. */ struct Or : public BinaryOpNode { Or(Expr a, Expr b) : BinaryOpNode(Bool(), a, b) { CHECK(a->type().is_bool()); CHECK(b->type().is_bool()); } static Expr Make(Expr a, Expr b); Type type() const override; void Verify() const override; static const IrNodeTy _node_type_ = IrNodeTy::Or; }; /** * Logical not. */ struct Not : public UnaryOpNode { explicit Not(Expr v) : UnaryOpNode(Bool(), v) {} static Expr Make(Expr v); Type type() const override; void Verify() const override; static const IrNodeTy _node_type_ = IrNodeTy::Not; }; struct Let : public ExprNode { Expr symbol; Expr body; static Expr Make(Expr symbol, Expr body); Type type() const override; void Verify() const override; static const IrNodeTy _node_type_ = IrNodeTy::Let; std::vector expr_fields() override { if (!body.defined()) return {&symbol}; return {&symbol, &body}; } std::vector expr_fields() const override { if (!body.defined()) return {&symbol}; return {&symbol, &body}; } }; enum CallType : int { //! Extern "C" function. Extern = 0, //! CINN-style call, call a CINN function. CINN, //! Intrinsic functions. Intrinsic, //! Generated from ISL Ast. ISL, }; struct Call : public ExprNode { explicit Call(Type t) : ExprNode(t) {} //! The name of the function/intrinsic. std::string name; //! The arguments. std::vector read_args; std::vector write_args; //! the attribute of this CallNode. std::map attrs; //! Type of calls. CallType call_type; //! The function to be called. FunctionRef func; //! The output value index if func's value is a tuple. int value_index{-1}; static Expr Make(Type type, const std::string& name, const std::vector& read_args, const std::vector& write_args, CallType call_type, FunctionRef func = FunctionRef(), int value_index = 0, const std::map& attrs = {}); void Verify() const override; inline size_t total_args_count() const { return read_args.size() + write_args.size(); } inline bool is_extern_call() const { return call_type == CallType::Extern; } inline bool is_cinn_call() const { return call_type == CallType::CINN; } inline bool is_intrinsic_call() const { return call_type == CallType::Intrinsic; } inline bool is_isl_call() const { return call_type == CallType::ISL; } std::vector expr_fields() override; std::vector expr_fields() const override; static const IrNodeTy _node_type_ = IrNodeTy::Call; }; /** * Variable used as iterator value or bound definition. */ struct _Var_ : public ExprNode<_Var_> { std::string name; bool is_reduce_axis{false}; //! Lower bound and upper bound of a axis. // @{ Expr lower_bound; Expr upper_bound; // @} // ! Extra tag of this variable/axis. std::string tag; _Var_() = default; _Var_(const std::string& name, Type type) : ExprNode<_Var_>(type), name(name) {} static Expr Make(const std::string& name, const Type& type); //! Make a reduce axis. static Expr Make(Expr lower_bound, Expr upper_bound, const std::string& name, bool is_reduce); void Verify() const override; Expr Copy() const override; static const IrNodeTy _node_type_ = IrNodeTy::_Var_; }; //! A named variable. struct Var : public IrNodeRef { Var() = default; explicit Var(IrNode* n) : IrNodeRef(n) {} explicit Var(const std::string& name_hint, Type t = type_of()) : Var(_Var_::Make(name_hint, t).ptr()) {} Var(Expr lower_bound, Expr upper_bound, const std::string& name, bool is_reduce = false) : Var(_Var_::Make(lower_bound, upper_bound, name, is_reduce)) {} Var(int upper_bound, const std::string& name) : Var(_Var_::Make(Expr(0), Expr(upper_bound), name, false)) {} Var(Expr upper_bound, const std::string& name) : Var(_Var_::Make(Expr(0), upper_bound, name, false)) {} operator Expr() { return Expr(get()); } operator Expr() const { Var v = *this; return Expr(v); } bool operator==(const Var& o) const; bool operator!=(const Var& o) const; Var& operator=(_Var_* x); Var& operator=(const _Var_* x); const _Var_* operator->() const { return get(); } _Var_* operator->() { return get(); } const _Var_* get() const { return static_cast(ptr()); } _Var_* get() { return static_cast<_Var_*>(ptr()); } }; struct Reduce : public ExprNode { enum ReduceType { kSum = 0, kSub, kMul, kDiv, kMax, kMin, kAll, kAny, }; //! The initial value. Expr init; // ! The body. Expr body; utils::SmallVector reduce_axis; //! The type of the reduce operation. ReduceType reduce_type; static Expr Make(ReduceType reduce_type, Expr init, Expr body, const std::vector& reduce_aixs); Type type() const override { return body.type().ElementOf(); } std::vector expr_fields() override; std::vector expr_fields() const override; void Verify() const override; static const IrNodeTy _node_type_ = IrNodeTy::Reduce; }; /** * Evaluates `true_value` and `false_value` then selects between them based on `condition`. */ struct Select : public ExprNode(true_value.type()), condition(condition), true_value(true_value), false_value(false_value) { CHECK_EQ(true_value.type(), false_value.type()); CHECK(condition.type().is_bool()); } static Expr Make(Expr condition, Expr true_value, Expr false_value) { auto node = make_shared