diff --git a/paddle/fluid/ir/dialect/op_generator/op_member_func_gen.py b/paddle/fluid/ir/dialect/op_generator/op_member_func_gen.py index 0f3411701c9fb5638851b4e8446a6f4710300eb6..30d35a5f6e7fb3548b1077200986d587c5f39488 100644 --- a/paddle/fluid/ir/dialect/op_generator/op_member_func_gen.py +++ b/paddle/fluid/ir/dialect/op_generator/op_member_func_gen.py @@ -18,18 +18,6 @@ OP_GET_INPUT_TEMPLATE = """ ir::Value {input_name}() {{ return operand({input_i """ OP_GET_OUTPUT_TEMPLATE = """ ir::OpResult {output_name}() {{ return result({output_index}); }} """ -OP_GET_ATTRIBUTE_TEMPLATE = """ ir::Attribute attribute(const std::string &name) { - PADDLE_ENFORCE(attributes().count(name) > 0, - phi::errors::PreconditionNotMet("Attribute is not exist.")); - return attributes().at(name); - } - template - T attribute(const std::string &name) { - PADDLE_ENFORCE(attributes().count(name) > 0 && attributes().at(name).isa(), - phi::errors::PreconditionNotMet("Attribute is not right.")); - return attributes().at(name).dyn_cast(); - } -""" def gen_op_get_inputs_outputs_str( @@ -51,5 +39,4 @@ def gen_op_get_inputs_outputs_str( output_name=op_output_name_list[idx], output_index=idx, ) - op_get_inputs_outputs_str += OP_GET_ATTRIBUTE_TEMPLATE return op_get_inputs_outputs_str diff --git a/paddle/ir/core/builtin_attribute_storage.h b/paddle/ir/core/builtin_attribute_storage.h index 7050555c1d0374b5e9c98b5fd1fe147761844cef..624abaf0047182a80df5bfc3abb6f6277527036a 100644 --- a/paddle/ir/core/builtin_attribute_storage.h +++ b/paddle/ir/core/builtin_attribute_storage.h @@ -135,7 +135,10 @@ struct ArrayAttributeStorage : public AttributeStorage { bool empty() const { return size_ == 0u; } Attribute at(size_t index) const { - IR_ENFORCE(index < size_, "Invalid index"); + IR_ENFORCE(index < size_, + "The index (%d) must be less than size (%d).", + index, + size_); return data_[index]; } diff --git a/paddle/ir/core/op_base.h b/paddle/ir/core/op_base.h index 16b2d3c9dec32063d051f69f8436da7445d9044d..f1a984a85f76f38faf1b634a2b1217a0e074728a 100644 --- a/paddle/ir/core/op_base.h +++ b/paddle/ir/core/op_base.h @@ -15,6 +15,7 @@ #pragma once #include +#include "paddle/ir/core/enforce.h" #include "paddle/ir/core/operation.h" #include "paddle/ir/core/utils.h" @@ -68,25 +69,37 @@ class IR_API OpBase { public: explicit OpBase(Operation *operation = nullptr) : operation_(operation) {} - Operation *operation() const { return operation_; } + Operation *operation() const { + IR_ENFORCE(operation_, "Can't use operation() in a null op."); + return operation_; + } + + explicit operator bool() const { return operation_ != nullptr; } + + operator Operation *() const { return operation(); } - explicit operator bool() const { return operation() != nullptr; } + Operation *operator->() const { return operation(); } - operator Operation *() const { return operation_; } + IrContext *ir_context() const { return operation()->ir_context(); } - Operation *operator->() const { return operation_; } + uint32_t num_results() const { return operation()->num_results(); } - IrContext *ir_context() const { return operation_->ir_context(); } + uint32_t num_operands() const { return operation()->num_operands(); } - uint32_t num_results() const { return operation_->num_results(); } + const AttributeMap &attributes() const { return operation()->attributes(); } - uint32_t num_operands() const { return operation_->num_operands(); } + Value operand(uint32_t index) const { return operation()->operand(index); } - const AttributeMap &attributes() const { return operation_->attributes(); } + OpResult result(uint32_t index) const { return operation()->result(index); } - Value operand(uint32_t index) const { return operation_->operand(index); } + ir::Attribute attribute(const std::string &name) { + return operation()->attribute(name); + } - OpResult result(uint32_t index) const { return operation_->result(index); } + template + T attribute(const std::string &name) { + return operation()->attribute(name); + } private: Operation *operation_; // Not owned diff --git a/paddle/ir/core/operation.h b/paddle/ir/core/operation.h index d6f06229efe12c08fa34f64331819fbc941370ae..2a7b65f6f0e833a99a89bb8459327fc90ae8a81c 100644 --- a/paddle/ir/core/operation.h +++ b/paddle/ir/core/operation.h @@ -69,9 +69,9 @@ class IR_API alignas(8) Operation final { template T attribute(const std::string &name) { - IR_ENFORCE(attributes().count(name) > 0 && attributes().at(name).isa(), - "Attribute is not right."); - return attributes().at(name).dyn_cast(); + Attribute attr = attribute(name); + IR_ENFORCE(attr.isa(), "Attribute (%s) type is not right.", name); + return attr.dyn_cast(); } void set_attribute(const std::string &key, Attribute value) {