// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "paddle/ir/builder.h" #include "paddle/ir/builtin_attribute.h" #include "paddle/ir/builtin_type.h" #include "paddle/ir/dialect.h" #include "paddle/ir/ir_context.h" #include "paddle/ir/op_base.h" /// \brief Define built-in Trait, derived from OpTraitBase. class ReadOnlyTrait : public ir::OpTraitBase { public: explicit ReadOnlyTrait(ir::Operation *op) : ir::OpTraitBase(op) {} }; /// \brief Define built-in Interface, derived from OpInterfaceBase. Concepts and /// Models need to be defined within the class. Concept defines abstract /// interface functions, and Model is a template class that defines the specific /// implementation of interface functions based on template parameters. class InferShapeInterface : public ir::OpInterfaceBase { public: struct Concept { explicit Concept(void (*infer_shape)(ir::Operation *)) : infer_shape_(infer_shape) {} void (*infer_shape_)(ir::Operation *); }; template struct Model : public Concept { static void InferShape(ir::Operation *op) { ConcreteOp concret_op = ConcreteOp(op); if (concret_op == nullptr) throw("concret_op is nullptr"); concret_op.InferShape(); } Model() : Concept(InferShape) { static_assert(sizeof(Model) == sizeof(Concept), "sizeof(Model) != sizeof(Concept)"); } }; InferShapeInterface(ir::Operation *op, Concept *impl) : ir::OpInterfaceBase(op), impl_(impl) {} void InferShape() { impl_->infer_shape_(operation()); } private: Concept *impl_; }; ir::AttributeMap CreateAttributeMap(std::vector attribute_names, std::vector attributes) { ir::IrContext *ctx = ir::IrContext::Instance(); ir::AttributeMap attr_map; for (size_t i = 0; i < attribute_names.size(); i++) { ir::Attribute attr_value = ir::StrAttribute::get(ctx, attributes[i]); attr_map.insert( std::pair(attribute_names[i], attr_value)); } return attr_map; } // Define op1. class Operation1 : public ir::Op { public: using Op::Op; static const char *name() { return "test.operation1"; } static constexpr uint32_t attributes_num = 2; static const char *attributes_name[attributes_num]; static void verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes) { if (attributes.count("op1_attr1") == 0 || !attributes.at("op1_attr1").isa()) { throw("Type of attribute: parameter_name is not right."); } if (attributes.count("op1_attr2") == 0 || !attributes.at("op1_attr2").isa()) { throw("Type of attribute: parameter_name is not right."); } } static void build(const ir::Builder &builder, ir::OperationArgument &argument) { // NOLINT std::vector inputs = {}; std::vector output_types = { ir::Float32Type::get(builder.context())}; std::unordered_map attributes = CreateAttributeMap({"op1_attr1", "op1_attr2"}, {"op1_attr1", "op1_attr2"}); argument.addOperands::iterator>(inputs.begin(), inputs.end()); argument.addTypes::iterator>(output_types.begin(), output_types.end()); argument.addAttributes< std::unordered_map::iterator>( attributes.begin(), attributes.end()); } }; const char *Operation1::attributes_name[attributes_num] = {"op1_attr1", "op1_attr2"}; // Define op2. class Operation2 : public ir::Op { public: using Op::Op; static const char *name() { return "test.operation2"; } static constexpr uint32_t attributes_num = 2; static const char *attributes_name[attributes_num]; static void verify(const std::vector &inputs, const std::vector &outputs, const ir::AttributeMap &attributes) { if (attributes.count("op2_attr1") == 0 || (!attributes.at("op2_attr1").isa())) { throw("Type of attribute: parameter_name is not right."); } if (attributes.count("op2_attr2") == 0 || (!attributes.at("op2_attr2").isa())) { throw("Type of attribute: parameter_name is not right."); } } static void InferShape() { VLOG(0) << "This is op2's InferShape interface."; } }; const char *Operation2::attributes_name[attributes_num] = {"op2_attr1", "op2_attr2"}; // Define a dialect, op1 and op2 will be registered by this dialect. class TestDialect : public ir::Dialect { public: explicit TestDialect(ir::IrContext *context) : ir::Dialect(name(), context, ir::TypeId::get()) { initialize(); } static const char *name() { return "test"; } private: void initialize() { RegisterOps(); } }; TEST(op_test, op_test) { // (1) Register Dialect, Operation1, Operation2 into IrContext. ir::IrContext *ctx = ir::IrContext::Instance(); ir::Dialect *test_dialect = ctx->GetOrRegisterDialect(); EXPECT_EQ(test_dialect != nullptr, true); // (2) Get registered operations. std::string op1_name = Operation1::name(); ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name); EXPECT_EQ(op1_info != nullptr, true); std::string op2_name = Operation2::name(); ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name); EXPECT_EQ(op2_info != nullptr, true); EXPECT_EQ(op1_info.HasTrait(), false); EXPECT_EQ(op1_info.HasInterface(), false); EXPECT_EQ(op2_info.HasTrait(), true); EXPECT_EQ(op2_info.HasInterface(), true); // (3) Test uses for op. std::vector op_inputs = {}; std::vector op_output_types = {ir::Float32Type::get(ctx)}; ir::Operation *op2 = ir::Operation::create(op_inputs, op_output_types, CreateAttributeMap({"op2_attr1", "op2_attr2"}, {"op2_attr1", "op2_attr2"}), op2_info); ReadOnlyTrait trait = op2->dyn_cast(); EXPECT_EQ(trait.operation(), op2); InferShapeInterface interface = op2->dyn_cast(); interface.InferShape(); Operation2 Op2 = op2->dyn_cast(); EXPECT_EQ(Op2.operation(), op2); op2->destroy(); }