diff --git a/paddle/fluid/ir/dialect/op_gen.py b/paddle/fluid/ir/dialect/op_gen.py
index 89ac8b0157f636bc7eba7b1e86be14288a17c423..106229e141ecb0217c7cec824d74fc19e3075d44 100644
--- a/paddle/fluid/ir/dialect/op_gen.py
+++ b/paddle/fluid/ir/dialect/op_gen.py
@@ -1193,8 +1193,8 @@ def OpGenerator(
 
     # generate get op info funciton: inputs
     inputs_info_str = ""
+    input_info_list = []
     if len(op_input_name_list) > 0:
-        input_info_list = []
         for idx in range(len(op_input_name_list)):
             input_info_list.append(
                 CONSTRUCT_INPUT_INFO_TEMPLATE.format(
@@ -1204,7 +1204,19 @@ def OpGenerator(
                     no_need_buffer=op_input_no_need_buffer_list[idx],
                 )
             )
-        inputs_info_str = ", ".join(input_info_list)
+
+    # add mutable attributes as inputs
+    if len(op_mutable_attribute_name_list) > 0:
+        for idx in range(len(op_mutable_attribute_name_list)):
+            input_info_list.append(
+                CONSTRUCT_INPUT_INFO_TEMPLATE.format(
+                    name=op_mutable_attribute_name_list[idx],
+                    typename=op_mutable_attribute_type_list[idx],
+                    optional='false',
+                    no_need_buffer='false',
+                )
+            )
+    inputs_info_str = ", ".join(input_info_list)
 
     # generate get op info funciton: outputs
     outputs_info_str = ""
@@ -1223,12 +1235,16 @@ def OpGenerator(
 
     # generate get op info funciton: attributes
     attribute_info_str = ""
+    op_mutable_attribute_name_set = set(op_mutable_attribute_name_list)
     if len(op_attribute_name_list) > 0:
         attribute_info_list = []
         for idx in range(len(op_attribute_name_list)):
+            attribute_name = op_attribute_name_list[idx]
+            if attribute_name in op_mutable_attribute_name_set:
+                continue
             attribute_info_list.append(
                 CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE.format(
-                    name=op_attribute_name_list[idx],
+                    name=attribute_name,
                     typename=op_attribute_type_list[idx],
                     data_type=op_attribute_data_type_list[idx],
                 )
diff --git a/paddle/fluid/ir/dialect/pd_dialect.cc b/paddle/fluid/ir/dialect/pd_dialect.cc
index d7b4b599b55fe613962ddee955d7ff10c6fa4524..b347d85d2a1cf0c9789fb56f9721b61825651488 100644
--- a/paddle/fluid/ir/dialect/pd_dialect.cc
+++ b/paddle/fluid/ir/dialect/pd_dialect.cc
@@ -23,6 +23,8 @@
 #include "paddle/fluid/ir/dialect/pd_type_storage.h"
 #include "paddle/fluid/ir/dialect/utils.h"
 #include "paddle/ir/core/dialect_interface.h"
+#include "paddle/ir/core/utils.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/dense_tensor.h"
 
 namespace paddle {
@@ -107,7 +109,7 @@ void PaddleDialect::initialize() {
   RegisterInterfaces<ParameterConvertInterface>();
 }
 
-void PaddleDialect::PrintType(ir::Type type, std::ostream &os) {
+void PaddleDialect::PrintType(ir::Type type, std::ostream &os) const {
   DenseTensorType tensor_type = type.dyn_cast<DenseTensorType>();
 
   os << "tensor<";
@@ -119,5 +121,27 @@ void PaddleDialect::PrintType(ir::Type type, std::ostream &os) {
   os << ">";
 }
 
+void PaddleDialect::PrintAttribute(ir::Attribute attr, std::ostream &os) const {
+  if (auto int_array_attr = attr.dyn_cast<IntArrayAttribute>()) {
+    phi::IntArray data = int_array_attr.data();
+    os << "IntArray[";
+    const auto &inner_data = data.GetData();
+    ir::PrintInterleave(
+        inner_data.begin(),
+        inner_data.end(),
+        [&os](int64_t i) { os << i; },
+        [&os]() { os << ","; });
+    os << "]";
+  } else if (auto data_type_attr = attr.dyn_cast<DataTypeAttribute>()) {
+    os << data_type_attr.data();
+  } else if (auto place_type_attr = attr.dyn_cast<PlaceAttribute>()) {
+    os << place_type_attr.data();
+  } else if (auto data_layout_attr = attr.dyn_cast<DataLayoutAttribute>()) {
+    os << data_layout_attr.data();
+  } else {
+    os << "<#AttrNotImplemented>";
+  }
+}
+
 }  // namespace dialect
 }  // namespace paddle
diff --git a/paddle/fluid/ir/dialect/pd_dialect.h b/paddle/fluid/ir/dialect/pd_dialect.h
index 069827bedcf9a1b1a3ee64b7f8da946c2735631c..b8782c156d88517257bb1e7997714c29509d547e 100644
--- a/paddle/fluid/ir/dialect/pd_dialect.h
+++ b/paddle/fluid/ir/dialect/pd_dialect.h
@@ -39,7 +39,8 @@ class PaddleDialect : public ir::Dialect {
 
   static const char* name() { return "pd"; }
 
-  void PrintType(ir::Type type, std::ostream& os);
+  void PrintType(ir::Type type, std::ostream& os) const;
+  void PrintAttribute(ir::Attribute attr, std::ostream& os) const;
 
  private:
  void initialize();
diff --git a/paddle/fluid/ir_adaptor/translator/CMakeLists.txt b/paddle/fluid/ir_adaptor/translator/CMakeLists.txt
index c2a66c1e71318f30d73914b1eac059a5f59d0b6b..2f0014c69f74c42019e0674bde6f264200485ad6 100644
--- a/paddle/fluid/ir_adaptor/translator/CMakeLists.txt
+++ b/paddle/fluid/ir_adaptor/translator/CMakeLists.txt
@@ -5,12 +5,14 @@ set(PD_PROGRAM_TRANSLATOR_BINARY_DIR
 set(op_gen_file ${PD_PROGRAM_TRANSLATOR_SOURCE_DIR}/op_compat_gen.py)
 set(op_compat_yaml_file ${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml)
 set(op_compat_source_file ${PD_PROGRAM_TRANSLATOR_SOURCE_DIR}/op_compat_info.cc)
+set(op_compat_template_file
+    ${PD_PROGRAM_TRANSLATOR_SOURCE_DIR}/op_compat_info.cc.j2)
 
 add_custom_command(
   OUTPUT ${op_compat_source_file}
   COMMAND ${PYTHON_EXECUTABLE} ${op_gen_file} --op_compat_yaml_file
           ${op_compat_yaml_file} --output_source_file ${op_compat_source_file}
-  DEPENDS ${op_gen_file} ${op_compat_yaml_file}
+  DEPENDS ${op_gen_file} ${op_compat_yaml_file} ${op_compat_template_file}
   VERBATIM)
 
 file(GLOB PD_PROGRAM_TRANSLATOR_SRCS "*.cc")
diff --git a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py
index 5bc9df7ee8b34b40482492baae03e664ba14d788..5a852754aed1ec2a918a2102ea5cdb52c1f3e698 100644
--- a/paddle/fluid/ir_adaptor/translator/op_compat_gen.py
+++ b/paddle/fluid/ir_adaptor/translator/op_compat_gen.py
@@ -14,7 +14,7 @@
 
 import argparse
 from pathlib import Path
-from typing import Dict
+from typing import Dict, List, Set
 
 import yaml
 from jinja2 import Environment, FileSystemLoader, StrictUndefined
@@ -46,8 +46,11 @@ def OpNameNormalizerInitialization(
     with open(op_compat_yaml_file, "r") as f:
         op_compat_infos = yaml.safe_load(f)
 
-    op_name_mappings = {}
-    op_arg_name_mappings = {}
+    op_name_mappings: Dict[str, str] = {}
+    op_arg_name_mappings: Dict[str, Dict[str, str]] = {}
+    op_mutable_attributes: Dict[str, Set[str]] = {}
+    op_mutable_attribute_infos: Dict[str, Dict[str, List[str]]] = {}
+
     for op_compat_item in op_compat_infos:
 
         def insert_new_mappings(op_name_str: str) -> str:
@@ -64,6 +67,23 @@ def OpNameNormalizerInitialization(
                 op_arg_name_mappings[op_name] = {}
             op_arg_name_mappings[op_name].update(arg_mapping)
 
+        def insert_new_mutable_attributes(
+            op_name: str, mutable_attribute_infos: Dict[str, Dict[str, str]]
+        ):
+            op_mutable_attributes[op_name] = set()
+            op_mutable_attribute_infos[op_name] = {}
+            for (
+                attribute_name,
+                mutable_attribute_info,
+            ) in mutable_attribute_infos.items():
+                op_mutable_attributes[op_name].add(attribute_name)
+                op_mutable_attribute_infos[op_name][attribute_name] = []
+                for k, v in mutable_attribute_info.items():
+                    if k == 'tensor_name' or k == 'tensors_name':
+                        op_mutable_attribute_infos[op_name][
+                            attribute_name
+                        ].append(v)
+
         _, legacy_name = insert_new_mappings(op_compat_item["op"])
         legacy_backward_op_names = []
         if "backward" in op_compat_item:
@@ -88,6 +108,14 @@ def OpNameNormalizerInitialization(
         for backward_op in legacy_backward_op_names:
             insert_new_arg_mappings(backward_op, op_compat_item["outputs"])
"int_array" in op_compat_item: + insert_new_mutable_attributes( + legacy_name, op_compat_item["int_array"] + ) + + if "scalar" in op_compat_item: + insert_new_mutable_attributes(legacy_name, op_compat_item["scalar"]) + # special op mappings op_name_mappings["fetch_v2"] = "fetch" @@ -96,6 +124,8 @@ def OpNameNormalizerInitialization( op_compat_definition = op_name_normailzer_template.render( op_name_pairs=op_name_mappings, op_arg_name_pairs=op_arg_name_mappings, + op_mutable_attributes=op_mutable_attribues, + op_mutable_attribute_infos=op_mutable_attribute_infos, ) f.write(op_compat_definition) diff --git a/paddle/fluid/ir_adaptor/translator/op_compat_info.cc.j2 b/paddle/fluid/ir_adaptor/translator/op_compat_info.cc.j2 index bfc80986a34c988162494b105e47273fbc8577b6..e7b7812fe61bead9d0c29a602764c039f01d0083 100644 --- a/paddle/fluid/ir_adaptor/translator/op_compat_info.cc.j2 +++ b/paddle/fluid/ir_adaptor/translator/op_compat_info.cc.j2 @@ -21,6 +21,37 @@ OpNameNormalizer::OpNameNormalizer() { }, {% endfor %} }; + op_mutable_attributes = { + {% for op_name, mutable_attributes in op_mutable_attributes.items() %} + { + "{{op_name}}", + { + {% for attribute_name in mutable_attributes %} + "{{attribute_name}}", + {% endfor %} + }, + }, + {% endfor %} + }; + op_mutable_attribute_infos = { + {% for op_name, mutable_attribute_infos in op_mutable_attribute_infos.items() %} + { + "{{op_name}}", + { + {% for attribute_name, attribute_info in mutable_attribute_infos.items() %} + { + "{{attribute_name}}", + { + {% for candidate_var_name in attribute_info %} + "{{candidate_var_name}}", + {% endfor %} + }, + }, + {% endfor %} + }, + }, + {% endfor %} + }; } } // namespace translator diff --git a/paddle/fluid/ir_adaptor/translator/op_compat_info.h b/paddle/fluid/ir_adaptor/translator/op_compat_info.h index f2ccba28eb72d6c74a5ade754db3551864dfcde5..799e62c7544e3fc89435c66873f733122ae9deb9 100644 --- a/paddle/fluid/ir_adaptor/translator/op_compat_info.h +++ b/paddle/fluid/ir_adaptor/translator/op_compat_info.h @@ -15,6 +15,7 @@ #include #include #include +#include #include "glog/logging.h" @@ -25,6 +26,8 @@ namespace paddle { namespace translator { +using MutableAttributeInfo = std::vector; + class OpNameNormalizer { private: OpNameNormalizer(); // Disallow instantiation outside of the class. 
@@ -32,6 +35,12 @@ class OpNameNormalizer {
   std::unordered_map<std::string, std::unordered_map<std::string, std::string>>
       op_arg_name_mappings;
 
+  std::unordered_map<std::string,
+                     std::unordered_map<std::string, MutableAttributeInfo>>
+      op_mutable_attribute_infos;
+  std::unordered_map<std::string, std::unordered_set<std::string>>
+      op_mutable_attributes;
+
  public:
   OpNameNormalizer(const OpNameNormalizer&) = delete;
   OpNameNormalizer& operator=(const OpNameNormalizer&) = delete;
@@ -50,6 +59,21 @@ class OpNameNormalizer {
     return op_name_mappings.at(op_type);
   }
 
+  bool HasMutableAttribute(const std::string& op_type) {
+    return (op_mutable_attributes.find(op_type) != op_mutable_attributes.end());
+  }
+
+  const std::unordered_set<std::string>* GetMutableAttributes(
+      const std::string& op_type) {
+    if (!HasMutableAttribute(op_type)) return nullptr;
+    return &op_mutable_attributes.at(op_type);
+  }
+
+  const MutableAttributeInfo& GetMutableAttributeInfos(
+      const std::string& op_type, const std::string& arg_name) {
+    return op_mutable_attribute_infos.at(op_type).at(arg_name);
+  }
+
   std::string GetLegacyArgName(const std::string& op_type,
                                const std::string& arg_name) {
     bool is_grad_op = (op_type.find("grad") != std::string::npos);
diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc
index 12794e7579e5c0dbda58942242b75f62eb03e944..3e2c8117897ceea5fa039511be27a27bfef6fd85 100644
--- a/paddle/fluid/ir_adaptor/translator/op_translator.cc
+++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -23,17 +23,23 @@
 #include <unordered_map>
 #include <vector>
 
 #include "paddle/fluid/framework/op_desc.h"
+#include "paddle/fluid/ir/dialect/pd_type.h"
 #include "paddle/fluid/ir/interface/op_yaml_info.h"
 #include "paddle/fluid/ir_adaptor/translator/attribute_translator.h"
 #include "paddle/fluid/ir_adaptor/translator/op_compat_info.h"
 #include "paddle/fluid/ir_adaptor/translator/program_translator.h"
 #include "paddle/fluid/ir_adaptor/translator/type_translator.h"
+#include "paddle/ir/core/builder.h"
 #include "paddle/ir/core/builtin_op.h"
 #include "paddle/ir/core/builtin_type.h"
+#include "paddle/ir/core/enforce.h"
 #include "paddle/ir/core/ir_context.h"
 #include "paddle/ir/core/operation.h"
 #include "paddle/ir/core/value.h"
-#include "paddle/phi/core/enforce.h"
+
+// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
+// paddle/fluid/ir/dialect/CMakeLists.txt.
+#include "paddle/fluid/ir/dialect/pd_op.h" namespace paddle { namespace translator { @@ -66,8 +72,13 @@ inline bool IsInplace(const OpDesc& op_desc) { } auto input_names = op_desc.InputArgumentNames(); auto output_names = op_desc.OutputArgumentNames(); + if (input_names.size() == 0 || output_names.size() == 0) { + return inplace; + } std::vector name_intersection; + std::sort(input_names.begin(), input_names.end()); + std::sort(output_names.begin(), output_names.end()); std::set_intersection(input_names.begin(), input_names.end(), output_names.begin(), @@ -103,10 +114,9 @@ inline ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) { << target_op_name; auto op_info = ctx->GetRegisteredOpInfo(target_op_name); if (!op_info) { - PADDLE_THROW(platform::errors::PreconditionNotMet( - "Op %d should have corresponding OpInfo %d", - op_desc.Type(), - target_op_name)); + IR_THROW("Op %d should have corresponding OpInfo %d", + op_desc.Type(), + target_op_name); } return op_info; @@ -158,18 +168,86 @@ inline ir::Operation* InsertCombineOperationForTarget( return operation; } -inline ir::Operation* InsertConstantOperationForOptionalArg( - ir::IrContext* ctx, ir::Program* program) { +inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx, + ir::Program* program, + ir::Attribute attr) { + float data = 0.0f; + phi::DataType dtype = phi::DataType::UNDEFINED; + if (attr.isa()) { + data = attr.dyn_cast().data(); + dtype = phi::DataType::FLOAT32; + } else if (attr.isa()) { + data = static_cast(attr.dyn_cast().data()); + dtype = phi::DataType::FLOAT64; + } else if (attr.isa()) { + data = static_cast(attr.dyn_cast().data()); + dtype = phi::DataType::INT32; + } else if (attr.isa()) { + data = static_cast(attr.dyn_cast().data()); + dtype = phi::DataType::INT64; + } else if (attr.isa()) { + data = static_cast(attr.dyn_cast().data()); + dtype = phi::DataType::BOOL; + } + ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program->block()); + paddle::dialect::FullOp full_op = builder.Build( + std::vector{1}, data, dtype, phi::CPUPlace()); + + return full_op.operation(); +} + +inline ir::Operation* InsertFullArrayOperationForAttributeInput( + ir::IrContext* ctx, ir::Program* program, ir::Attribute attr) { std::string constant_op_name(ir::ConstantOp::name()); ir::OpInfo op_info = ctx->GetRegisteredOpInfo(constant_op_name); - ir::Type null_type = ir::Type(nullptr); + ir::Type null_type = paddle::dialect::DenseTensorType::get( + ctx, + ir::Type(nullptr), + phi::DDim{}, + paddle::dialect::DenseTensorTypeStorage::DataLayout::UNDEFINED, + phi::LoD{}, + 0); // TODO(lyk): to be done ir::Operation* operation = - ir::Operation::Create({}, {}, {null_type}, op_info); + ir::Operation::Create({}, {{"value", attr}}, {null_type}, op_info); program->block()->push_back(operation); return operation; } +inline ir::OpResult GetAttributeAsInput(ir::IrContext* ctx, + ir::Program* program, + const OpDesc& op_desc, + const OpInputInfo& input_info) { + auto& attribute_translator = AttributeTranslator::instance(); + auto& op_normalizer = OpNameNormalizer::instance(); + + auto legacy_attr_name = + op_normalizer.GetLegacyAttrName(op_desc.Type(), input_info.name); + + if (!op_desc.HasAttr(legacy_attr_name)) { + IR_THROW("Op %s arg %s should not be zero size", + op_desc.Type(), + legacy_attr_name); + } + paddle::framework::Attribute legacy_attr = op_desc.GetAttr(legacy_attr_name); + VLOG(10) << "[" << op_desc.Type() << "][attribute]" + << " name: " << legacy_attr_name << " " << legacy_attr.index(); + 
+  ir::Attribute new_attr =
+      attribute_translator(input_info.type_name, legacy_attr);
+
+  ir::Operation* defining_op = nullptr;
+  bool is_int_array = (input_info.type_name.find("IntArrayAttribute") !=
+                       input_info.type_name.npos);
+  if (is_int_array) {
+    defining_op =
+        InsertFullArrayOperationForAttributeInput(ctx, program, new_attr);
+  } else {
+    defining_op = InsertFullOperationForAttributeInput(ctx, program, new_attr);
+  }
+
+  return defining_op->GetResultByIndex(0);
+}
+
 inline std::vector<ir::OpResult> GenerateOperationInput(
     ir::IrContext* ctx,
     TranslationContext* param_map,
@@ -184,14 +262,11 @@ inline std::vector<ir::OpResult> GenerateOperationInput(
     auto& args = n.second;
 
     for (const auto& arg_name : args) {
-      PADDLE_ENFORCE_NE(
-          param_map->count(arg_name),
-          0,
-          platform::errors::PreconditionNotMet(
-              "arg %s.%s as input should be exists before prasing %s",
-              name,
-              arg_name,
-              op_desc.Type()));
+      IR_ENFORCE(param_map->count(arg_name) != 0,
+                 "arg %s.%s as input should exist before parsing %s",
+                 name,
+                 arg_name,
+                 op_desc.Type());
       auto defining_info = (*param_map)[arg_name];
       if (defining_info.generated_by_vector) {
         InsertSliceOperationForTarget(
@@ -202,25 +277,59 @@ inline std::vector<ir::OpResult> GenerateOperationInput(
 
   std::vector<ir::OpResult> op_inputs;
   auto& op_normalizer = OpNameNormalizer::instance();
+  const auto* mutable_attributes =
+      op_normalizer.GetMutableAttributes(op_desc.Type());
+
   for (const auto& info : input_infos) {
     std::string legacy_input_name =
         op_normalizer.GetLegacyArgName(op_desc.Type(), info.name);
 
+    VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " "
+             << legacy_input_name;
+
+    std::vector<std::string> legacy_input_vars;
+
     // return empty OpResult if this arg is optional and not shown in OpDesc
     // TODO(lyk): HasInput doesnot consider variadic attribute
-    if (!op_desc.HasInput(legacy_input_name)) {
-      PADDLE_ENFORCE(info.optional,
-                     platform::errors::PreconditionNotMet(
-                         "Op %s arg %s should be optional if it can be empty",
-                         op_desc.Type(),
-                         legacy_input_name));
-      op_inputs.push_back(ir::OpResult(nullptr));
-      continue;
+    if (op_desc.HasInput(legacy_input_name)) {
+      legacy_input_vars = op_desc.Input(legacy_input_name, true);
+    }
+
+    if (legacy_input_vars.size() == 0) {
+      if (info.optional) {
+        op_inputs.push_back(ir::OpResult(nullptr));
+        continue;
+      }
+    }
+
+    VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " "
+             << legacy_input_name << " " << legacy_input_vars.size();
+
+    if (legacy_input_vars.size() == 0 && mutable_attributes != nullptr &&
+        mutable_attributes->count(info.name) != 0) {
+      const auto& candidate_var_names =
+          op_normalizer.GetMutableAttributeInfos(op_desc.Type(), info.name);
+      bool found_candidate_var = false;
+      for (const auto& var_name : candidate_var_names) {
+        VLOG(10) << "[handle mutable attribute][" << info.name << "]["
+                 << var_name << "]";
+        if (op_desc.HasInput(var_name)) {
+          legacy_input_vars = op_desc.Input(var_name, true);
+          if (legacy_input_vars.size() == 0) continue;
+          found_candidate_var = true;
+          break;
+        }
+      }
+
+      if (!found_candidate_var) {
+        auto attribute_input = GetAttributeAsInput(ctx, program, op_desc, info);
+        op_inputs.push_back(attribute_input);
+        continue;
+      }
     }
 
-    const auto& legacy_input_vars = op_desc.Input(legacy_input_name, true);
     bool is_vector = (info.type_name.find("VectorType") != std::string::npos);
+    VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " "
+             << is_vector << " " << info.type_name;
 
     // if src type is Tensor
     if (!is_vector) {
@@ -262,11 +371,10 @@ inline std::tuple GenerateOperationOutput(
       VLOG(10)
<< "[output translating]" << "[" << op_desc.Type() << "] optional " << info.name << " :" << info.type_name << " " << legacy_output_name; - PADDLE_ENFORCE(info.optional, - platform::errors::PreconditionNotMet( - "Op %s arg %s should be optional if it can be empty", - op_desc.Type(), - legacy_output_name)); + IR_ENFORCE(info.optional, + "Op %s arg %s should be optional if it can be empty", + op_desc.Type(), + legacy_output_name); op_output_types.push_back(ir::Type(nullptr)); continue; } diff --git a/paddle/ir/core/attribute.h b/paddle/ir/core/attribute.h index 4f269187b751b3e8ab2ac42652dd975a68743cdf..ea7b0f5daae811fc47550a3fca2f102305c9b7b7 100644 --- a/paddle/ir/core/attribute.h +++ b/paddle/ir/core/attribute.h @@ -60,6 +60,10 @@ class Attribute { IrContext *ir_context() const; + /// @brief print attribute + /// @param os + void Print(std::ostream &os) const; + /// /// \brief Methods for type judgment and cast. /// @@ -80,6 +84,8 @@ class Attribute { protected: const Storage *storage_{nullptr}; }; + +std::ostream &operator<<(std::ostream &os, Attribute attr); } // namespace ir namespace std { diff --git a/paddle/ir/core/dialect.h b/paddle/ir/core/dialect.h index 3421b9d942f6ddbae7eb6ec8a1ef1b68a5e21797..1eabc8010d670dc71f93387b2d1f497d01c672cb 100644 --- a/paddle/ir/core/dialect.h +++ b/paddle/ir/core/dialect.h @@ -16,8 +16,10 @@ #include +#include "paddle/ir/core/attribute.h" #include "paddle/ir/core/attribute_base.h" #include "paddle/ir/core/dialect_interface.h" +#include "paddle/ir/core/enforce.h" #include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/op_base.h" #include "paddle/ir/core/type_base.h" @@ -33,15 +35,15 @@ class DialectInterface; /// class Dialect { public: - Dialect(std::string name, ir::IrContext *context, ir::TypeId id); + Dialect(std::string name, IrContext *context, TypeId id); virtual ~Dialect(); const std::string &name() const { return name_; } - ir::IrContext *ir_context() const { return context_; } + IrContext *ir_context() const { return context_; } - ir::TypeId id() const { return id_; } + TypeId id() const { return id_; } /// /// \brief Register all types contained in the template parameter Args. @@ -130,8 +132,12 @@ class Dialect { return *interface; } - virtual void PrintType(ir::Type type, std::ostream &os) { - throw std::logic_error("dialect has no registered type printing hook"); + virtual void PrintType(Type type, std::ostream &os) const { + IR_THROW("dialect has no registered type printing hook"); + } + + virtual void PrintAttribute(Attribute type, std::ostream &os) const { + IR_THROW("dialect has no registered attribute printing hook"); } private: @@ -141,9 +147,9 @@ class Dialect { std::string name_; - ir::IrContext *context_; // not owned + IrContext *context_; // not owned - ir::TypeId id_; + TypeId id_; std::unordered_map> registered_interfaces_; diff --git a/paddle/ir/core/enforce.h b/paddle/ir/core/enforce.h index b5c48c22a83dc91ddf4764b90390375879da2d35..e87ac0c41a07ee19558401858f5269587fe3c039 100644 --- a/paddle/ir/core/enforce.h +++ b/paddle/ir/core/enforce.h @@ -17,6 +17,8 @@ #include #include +#include "paddle/utils/string/printf.h" + #if !defined(_WIN32) #define UNLIKELY(condition) __builtin_expect(static_cast(condition), 0) #else @@ -37,27 +39,35 @@ class IrNotMetException : public std::exception { std::string err_str_; }; -#define IR_THROW(...) \ - do { \ - try { \ - throw ir::IrNotMetException(__VA_ARGS__); \ - } catch (const std::exception& e) { \ - std::cout << e.what() << std::endl; \ - throw; \ - } \ +#define IR_THROW(...) 
+  do {                                                                     \
+    try {                                                                  \
+      throw ir::IrNotMetException(                                         \
+          paddle::string::Sprintf("Error occurred at: %s:%d :\n%s",        \
+                                  __FILE__,                                \
+                                  __LINE__,                                \
+                                  paddle::string::Sprintf(__VA_ARGS__)));  \
+    } catch (const std::exception& e) {                                    \
+      std::cout << e.what() << std::endl;                                  \
+      throw;                                                               \
+    }                                                                      \
   } while (0)
 
-#define IR_ENFORCE(COND, ...)                        \
-  do {                                               \
-    auto __cond__ = (COND);                          \
-    if (UNLIKELY(is_error(__cond__))) {              \
-      try {                                          \
-        throw ir::IrNotMetException(__VA_ARGS__);    \
-      } catch (const std::exception& e) {            \
-        std::cout << e.what() << std::endl;          \
-        throw;                                       \
-      }                                              \
-    }                                                \
+#define IR_ENFORCE(COND, ...)                                               \
+  do {                                                                      \
+    auto __cond__ = (COND);                                                 \
+    if (UNLIKELY(is_error(__cond__))) {                                     \
+      try {                                                                 \
+        throw ir::IrNotMetException(                                        \
+            paddle::string::Sprintf("Error occurred at: %s:%d :\n%s",       \
+                                    __FILE__,                               \
+                                    __LINE__,                               \
+                                    paddle::string::Sprintf(__VA_ARGS__))); \
+      } catch (const std::exception& e) {                                   \
+        std::cout << e.what() << std::endl;                                 \
+        throw;                                                              \
+      }                                                                     \
+    }                                                                       \
   } while (0)
 
 }  // namespace ir
diff --git a/paddle/ir/core/ir_printer.cc b/paddle/ir/core/ir_printer.cc
index a7a962bd204294e9b3776aa9c94b7f07331aa88a..fd0a41fbdae3092fef9337072f55d48ff445ac10 100644
--- a/paddle/ir/core/ir_printer.cc
+++ b/paddle/ir/core/ir_printer.cc
@@ -23,69 +23,86 @@
 #include "paddle/ir/core/dialect.h"
 #include "paddle/ir/core/operation.h"
 #include "paddle/ir/core/program.h"
+#include "paddle/ir/core/utils.h"
 #include "paddle/ir/core/value.h"
 
 namespace ir {
 
 namespace {
 constexpr char newline[] = "\n";
-
-template <typename ForwardIterator, typename UnaryFunctor, typename NullFunctor>
-void PrintInterleave(ForwardIterator begin,
-                     ForwardIterator end,
-                     UnaryFunctor print_func,
-                     NullFunctor between_func) {
-  if (begin == end) return;
-  print_func(*begin);
-  begin++;
-  for (; begin != end; begin++) {
-    between_func();
-    print_func(*begin);
-  }
-}
-
 }  // namespace
 
 class BasicIRPrinter {
  public:
  explicit BasicIRPrinter(std::ostream& os) : os(os) {}
 
-  void PrintType(ir::Type type) {
+  void PrintType(Type type) {
     if (!type) {
       os << "<>";
       return;
     }
 
-    if (type.isa<ir::Float16Type>()) {
+    if (type.isa<Float16Type>()) {
       os << "f16";
-    } else if (type.isa<ir::Float32Type>()) {
+    } else if (type.isa<Float32Type>()) {
       os << "f32";
-    } else if (type.isa<ir::Float64Type>()) {
+    } else if (type.isa<Float64Type>()) {
       os << "f64";
-    } else if (type.isa<ir::Int16Type>()) {
+    } else if (type.isa<Int16Type>()) {
       os << "i16";
-    } else if (type.isa<ir::Int32Type>()) {
+    } else if (type.isa<Int32Type>()) {
       os << "i32";
-    } else if (type.isa<ir::Int64Type>()) {
+    } else if (type.isa<Int64Type>()) {
       os << "i64";
-    } else if (type.isa<ir::VectorType>()) {
-      os << "vec<";
-      auto inner_types = type.dyn_cast<ir::VectorType>().data();
+    } else if (type.isa<VectorType>()) {
+      os << "vec[";
+      auto inner_types = type.dyn_cast<VectorType>().data();
       PrintInterleave(
           inner_types.begin(),
           inner_types.end(),
-          [this](ir::Type v) { this->PrintType(v); },
-          [this]() { this->os << ", "; });
-      os << ">";
+          [this](Type v) { this->PrintType(v); },
+          [this]() { this->os << ","; });
+      os << "]";
     } else {
       auto& dialect = type.dialect();
       dialect.PrintType(type, os);
     }
   }
 
-  void PrintAttribute(ir::Operation* op) { os << " { ATTRIBUTE }"; }
+  void PrintAttribute(const Attribute& attr) {
+    if (!attr) {
+      os << "<#AttrNull>";
+      return;
+    }
 
- protected:
+    if (auto s = attr.dyn_cast<StrAttribute>()) {
+      os << s.data();
+    } else if (auto b = attr.dyn_cast<BoolAttribute>()) {
+      os << b.data();
+    } else if (auto f = attr.dyn_cast<FloatAttribute>()) {
+      os << f.data();
+    } else if (auto d = attr.dyn_cast<DoubleAttribute>()) {
+      os << d.data();
+    } else if (auto i = attr.dyn_cast<Int32_tAttribute>()) {
+      os << i.data();
+    } else if (auto i = attr.dyn_cast<Int64_tAttribute>()) {
+      os << i.data();
+    } else if (auto arr = attr.dyn_cast<ArrayAttribute>()) {
+      const auto& vec = arr.data();
+      os << "array[";
+      PrintInterleave(
+          vec.begin(),
+          vec.end(),
+          [this](Attribute v) { this->PrintAttribute(v); },
+          [this]() { this->os << ","; });
+      os << "]";
+    } else {
+      auto& dialect = attr.dialect();
+      dialect.PrintAttribute(attr, os);
+    }
+  }
+
+ public:
   std::ostream& os;
 };
 
@@ -96,14 +113,12 @@ class IRPrinter : public BasicIRPrinter {
   /// @brief print program
   /// @param program
   /// @example
-  void PrintProgram(ir::Program* program) {
-    PrintOperation(program->module_op());
-  }
+  void PrintProgram(Program* program) { PrintOperation(program->module_op()); }
 
   /// @brief print operation
   /// @param op
   /// @example
-  void PrintOperation(ir::Operation* op) {
+  void PrintOperation(Operation* op) {
     for (size_t i = 0; i < op->num_regions(); ++i) {
       auto& region = op->GetRegion(i);
       for (auto it = region.begin(); it != region.end(); ++it) {
@@ -120,7 +135,7 @@ class IRPrinter : public BasicIRPrinter {
 
     // TODO(lyk): add API to get operands directly
     PrintOpOperands(op);
-    PrintAttribute(op);
+    PrintAttributeMap(op);
     os << " :";
 
     // PrintOpSingature
@@ -138,7 +153,7 @@ class IRPrinter : public BasicIRPrinter {
   }
 
  private:
-  void PrintValue(ir::Value v) {
+  void PrintValue(Value v) {
     if (!v) {
       os << "<>";
       return;
@@ -156,10 +171,10 @@ class IRPrinter : public BasicIRPrinter {
     os << new_name;
   }
 
-  void PrintOpResult(ir::Operation* op) {
+  void PrintOpResult(Operation* op) {
     os << " (";
     auto num_op_result = op->num_results();
-    std::vector<ir::OpResult> op_results;
+    std::vector<OpResult> op_results;
     op_results.reserve(num_op_result);
     for (size_t idx = 0; idx < num_op_result; idx++) {
       op_results.push_back(op->GetResultByIndex(idx));
@@ -167,15 +182,31 @@ class IRPrinter : public BasicIRPrinter {
     PrintInterleave(
         op_results.begin(),
         op_results.end(),
-        [this](ir::Value v) { this->PrintValue(v); },
+        [this](Value v) { this->PrintValue(v); },
        [this]() { this->os << ", "; });
     os << ")";
   }
 
-  void PrintOpOperands(ir::Operation* op) {
+  void PrintAttributeMap(Operation* op) {
+    os << " {";
+
+    PrintInterleave(
+        op->attributes().begin(),
+        op->attributes().end(),
+        [this](std::pair<std::string, Attribute> it) {
+          this->os << it.first;
+          this->os << ":";
+          this->PrintAttribute(it.second);
+        },
+        [this]() { this->os << ","; });
+
+    os << "}";
+  }
+
+  void PrintOpOperands(Operation* op) {
     os << " (";
     auto num_op_operands = op->num_operands();
-    std::vector<ir::Value> op_operands;
+    std::vector<Value> op_operands;
     op_operands.reserve(num_op_operands);
     for (size_t idx = 0; idx < num_op_operands; idx++) {
       op_operands.push_back(op->GetOperandByIndex(idx).source());
@@ -183,48 +214,48 @@ class IRPrinter : public BasicIRPrinter {
     PrintInterleave(
         op_operands.begin(),
         op_operands.end(),
-        [this](ir::Value v) { this->PrintValue(v); },
+        [this](Value v) { this->PrintValue(v); },
         [this]() { this->os << ", "; });
     os << ")";
   }
 
-  void PrintOperandsType(ir::Operation* op) {
+  void PrintOperandsType(Operation* op) {
     auto num_op_operands = op->num_operands();
-    std::vector<ir::Type> op_operand_types;
+    std::vector<Type> op_operand_types;
     op_operand_types.reserve(num_op_operands);
     for (size_t idx = 0; idx < num_op_operands; idx++) {
       auto op_operand = op->GetOperandByIndex(idx);
       if (op_operand) {
         op_operand_types.push_back(op->GetOperandByIndex(idx).source().type());
       } else {
-        op_operand_types.push_back(ir::Type(nullptr));
+        op_operand_types.push_back(Type(nullptr));
       }
     }
     os << " (";
     PrintInterleave(
         op_operand_types.begin(),
         op_operand_types.end(),
-        [this](ir::Type t) { this->PrintType(t); },
+        [this](Type t) { this->PrintType(t); },
         [this]() { this->os << ", "; });
     os << ")";
   }
 
-  void PrintOpReturnType(ir::Operation* op) {
-    auto num_op_result = op->num_results();
-    std::vector<ir::Type> op_result_types;
+  void PrintOpReturnType(Operation* op) {
+    auto num_op_result = op->num_results();
+    std::vector<Type> op_result_types;
     op_result_types.reserve(num_op_result);
     for (size_t idx = 0; idx < num_op_result; idx++) {
       auto op_result = op->GetResultByIndex(idx);
       if (op_result) {
         op_result_types.push_back(op_result.type());
       } else {
-        op_result_types.push_back(ir::Type(nullptr));
+        op_result_types.push_back(Type(nullptr));
       }
     }
     PrintInterleave(
         op_result_types.begin(),
         op_result_types.end(),
-        [this](ir::Type t) { this->PrintType(t); },
+        [this](Type t) { this->PrintType(t); },
         [this]() { this->os << ", "; });
   }
 
@@ -248,4 +279,19 @@ void Type::Print(std::ostream& os) const {
   printer.PrintType(*this);
 }
 
+void Attribute::Print(std::ostream& os) const {
+  BasicIRPrinter printer(os);
+  printer.PrintAttribute(*this);
+}
+
+std::ostream& operator<<(std::ostream& os, Type type) {
+  type.Print(os);
+  return os;
+}
+
+std::ostream& operator<<(std::ostream& os, Attribute attr) {
+  attr.Print(os);
+  return os;
+}
+
 }  // namespace ir
diff --git a/paddle/ir/core/type.cc b/paddle/ir/core/type.cc
index e93d9f63e8c6f53a2ba54fe0a601874efebbfaf8..8b1451fa76fb72dca2182da98f5614e80cfe3916 100644
--- a/paddle/ir/core/type.cc
+++ b/paddle/ir/core/type.cc
@@ -17,10 +17,4 @@
 namespace ir {
 
 IrContext* Type::ir_context() const { return dialect().ir_context(); }
-
-std::ostream& operator<<(std::ostream& os, Type type) {
-  type.Print(os);
-  return os;
-}
-
 }  // namespace ir
diff --git a/paddle/ir/core/utils.h b/paddle/ir/core/utils.h
index f4316a7e57e446c3029c657e10f69a65149b3cd8..b619bc065fef57c4d6d1c65f7a04d16411ddb1b4 100644
--- a/paddle/ir/core/utils.h
+++ b/paddle/ir/core/utils.h
@@ -120,4 +120,18 @@ struct Filter {
   using Type = std::tuple<>;
 };
 
+template <typename ForwardIterator, typename UnaryFunctor, typename NullFunctor>
+void PrintInterleave(ForwardIterator begin,
+                     ForwardIterator end,
+                     UnaryFunctor print_func,
+                     NullFunctor between_func) {
+  if (begin == end) return;
+  print_func(*begin);
+  begin++;
+  for (; begin != end; begin++) {
+    between_func();
+    print_func(*begin);
+  }
+}
+
 }  // namespace ir
diff --git a/test/cpp/ir/core/program_translator_test.cc b/test/cpp/ir/core/program_translator_test.cc
index 811fcf4c8b68f9c87e4830a8d2a0a4030b18f344..c8824a5f7c89982bdc4d7339664322dd0642ab2a 100644
--- a/test/cpp/ir/core/program_translator_test.cc
+++ b/test/cpp/ir/core/program_translator_test.cc
@@ -53,11 +53,13 @@ TEST(PaddleDialectTest, Translator) {
   ir::IrContext *ctx = ir::IrContext::Instance();
   ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
   ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
-  // auto program = paddle::TranslateLegacyProgramToProgram(p);
+  auto program = paddle::TranslateLegacyProgramToProgram(p);
 
-  // size_t op_size = program->block()->size();
-  // // ops.size() = op size in BlockDesc + get_parameter_op + combine op
-  // EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 21);
+  size_t op_size = program->block()->size();
+  // ops.size() = op size in BlockDesc + get_parameter_op + combine op + int
+  // array op + full op
+  EXPECT_EQ(op_size,
+            p.Block(0).OpSize() + program->parameters_num() + 20 + 3 + 8);
 
-  // program->Print(std::cout);
+  program->Print(std::cout);
 }
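
Reviewer note (illustration only, not part of the patch): the `PrintInterleave` helper that this patch moves into paddle/ir/core/utils.h is the building block for both the new `PaddleDialect::PrintAttribute` and the builtin attribute printer. Below is a minimal standalone sketch of its behavior, mirroring the IntArray branch in pd_dialect.cc; the `main` harness and the assumption that the paddle include path is configured are mine, not part of the change.

    #include <cstdint>
    #include <iostream>
    #include <vector>

    #include "paddle/ir/core/utils.h"

    int main() {
      std::vector<int64_t> inner_data{1, 2, 3};
      std::cout << "IntArray[";
      // print_func emits one element; between_func runs only between
      // consecutive elements, so no trailing separator is produced.
      ir::PrintInterleave(
          inner_data.begin(),
          inner_data.end(),
          [](int64_t i) { std::cout << i; },
          []() { std::cout << ","; });
      std::cout << "]" << std::endl;  // prints: IntArray[1,2,3]
      return 0;
    }

With the new `operator<<` overloads added in ir_printer.cc, an `ir::Attribute` or `ir::Type` can also be streamed directly (e.g. `std::cout << attr;`), which falls back to the owning dialect's `PrintAttribute`/`PrintType` hook when the value is not a builtin.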