未验证 提交 51ca74bb 编写于 作者: K kangguangli 提交者: GitHub

[IR] IR attribute printer and support mutable attribute (#54369)

* add vector type support for program translator

* polish

* support basic attribute type

* resolve conflicts

* add verify for combine/slice and unittests

* polish

* support more type in attribute translator

* modify by reviews

* fix merge mistakes

* refine code

* refine code

* add interface

* fix: op name normalization

* fix typo

* refactor input translator

* fix merge conflicts

* fix op normalizer bug

* refactor attribute translator

* fix bug

* refactor output translator

* fix typo

* fix

* fix approval error

* fix coverage

* fix op_compat parser

* fix merge conflicts

* fix merge conflicts

* fix merge conflicts

* fix merge conflicts

* fix merge conflicts

* refactor scalar attribute

* draft

* fix

* fix op build

* fix op build

* temporarily save

* adpat mutable attribute

* refine op_comat_gen process

* fix merge conflicts

* fix merge conflicts

* fix merge conflicts

* complete dialect attribute printer and refine ir_throw

* polish code

---------
Co-authored-by: Nzhangbo9674 <zhangbo54@baidu.com>
上级 fc880209
......@@ -1193,8 +1193,8 @@ def OpGenerator(
# generate get op info funciton: inputs
inputs_info_str = ""
if len(op_input_name_list) > 0:
input_info_list = []
if len(op_input_name_list) > 0:
for idx in range(len(op_input_name_list)):
input_info_list.append(
CONSTRUCT_INPUT_INFO_TEMPLATE.format(
......@@ -1204,6 +1204,18 @@ def OpGenerator(
no_need_buffer=op_input_no_need_buffer_list[idx],
)
)
# add mutable attribute as input
if len(op_mutable_attribute_name_list) > 0:
for idx in range(len(op_mutable_attribute_name_list)):
input_info_list.append(
CONSTRUCT_INPUT_INFO_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx],
typename=op_mutable_attribute_type_list[idx],
optional='false',
no_need_buffer='false',
)
)
inputs_info_str = ", ".join(input_info_list)
# generate get op info funciton: outputs
......@@ -1223,12 +1235,16 @@ def OpGenerator(
# generate get op info funciton: attributes
attribute_info_str = ""
op_mutable_attribute_name_set = set(op_mutable_attribute_name_list)
if len(op_attribute_name_list) > 0:
attribute_info_list = []
for idx in range(len(op_attribute_name_list)):
attribute_name = op_attribute_name_list[idx]
if attribute_name in op_mutable_attribute_name_set:
continue
attribute_info_list.append(
CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE.format(
name=op_attribute_name_list[idx],
name=attribute_name,
typename=op_attribute_type_list[idx],
data_type=op_attribute_data_type_list[idx],
)
......
......@@ -23,6 +23,8 @@
#include "paddle/fluid/ir/dialect/pd_type_storage.h"
#include "paddle/fluid/ir/dialect/utils.h"
#include "paddle/ir/core/dialect_interface.h"
#include "paddle/ir/core/utils.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h"
namespace paddle {
......@@ -107,7 +109,7 @@ void PaddleDialect::initialize() {
RegisterInterfaces<ParameterConvertInterface>();
}
void PaddleDialect::PrintType(ir::Type type, std::ostream &os) {
void PaddleDialect::PrintType(ir::Type type, std::ostream &os) const {
DenseTensorType tensor_type = type.dyn_cast<DenseTensorType>();
os << "tensor<";
......@@ -119,5 +121,27 @@ void PaddleDialect::PrintType(ir::Type type, std::ostream &os) {
os << ">";
}
void PaddleDialect::PrintAttribute(ir::Attribute attr, std::ostream &os) const {
if (auto int_array_attr = attr.dyn_cast<IntArrayAttribute>()) {
phi::IntArray data = int_array_attr.data();
os << "IntArray[";
const auto &inner_data = data.GetData();
ir::PrintInterleave(
inner_data.begin(),
inner_data.end(),
[&os](int64_t i) { os << i; },
[&os]() { os << ","; });
os << "]";
} else if (auto data_type_attr = attr.dyn_cast<DataTypeAttribute>()) {
os << data_type_attr.data();
} else if (auto place_type_attr = attr.dyn_cast<PlaceAttribute>()) {
os << place_type_attr.data();
} else if (auto data_layout_attr = attr.dyn_cast<DataLayoutAttribute>()) {
os << data_layout_attr.data();
} else {
os << "<#AttrNotImplemented>";
}
}
} // namespace dialect
} // namespace paddle
......@@ -39,7 +39,8 @@ class PaddleDialect : public ir::Dialect {
static const char* name() { return "pd"; }
void PrintType(ir::Type type, std::ostream& os);
void PrintType(ir::Type type, std::ostream& os) const;
void PrintAttribute(ir::Attribute type, std::ostream& os) const;
private:
void initialize();
......
......@@ -5,12 +5,14 @@ set(PD_PROGRAM_TRANSLATOR_BINARY_DIR
set(op_gen_file ${PD_PROGRAM_TRANSLATOR_SOURCE_DIR}/op_compat_gen.py)
set(op_compat_yaml_file ${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml)
set(op_compat_source_file ${PD_PROGRAM_TRANSLATOR_SOURCE_DIR}/op_compat_info.cc)
set(op_compat_templat_file
${PD_PROGRAM_TRANSLATOR_SOURCE_DIR}/op_compat_info.cc.j2)
add_custom_command(
OUTPUT ${op_compat_source_file}
COMMAND ${PYTHON_EXECUTABLE} ${op_gen_file} --op_compat_yaml_file
${op_compat_yaml_file} --output_source_file ${op_compat_source_file}
DEPENDS ${op_gen_file} ${op_compat_yaml_file}
DEPENDS ${op_gen_file} ${op_compat_yaml_file} ${op_compat_templat_file}
VERBATIM)
file(GLOB PD_PROGRAM_TRANSLATOR_SRCS "*.cc")
......
......@@ -14,7 +14,7 @@
import argparse
from pathlib import Path
from typing import Dict
from typing import Dict, List, Set
import yaml
from jinja2 import Environment, FileSystemLoader, StrictUndefined
......@@ -46,8 +46,11 @@ def OpNameNormalizerInitialization(
with open(op_compat_yaml_file, "r") as f:
op_compat_infos = yaml.safe_load(f)
op_name_mappings = {}
op_arg_name_mappings = {}
op_name_mappings: Dict[str, str] = {}
op_arg_name_mappings: Dict[str, Dict[str, str]] = {}
op_mutable_attribues: Dict[str, Set[str]] = {}
op_mutable_attribute_infos: Dict[str, Dict[str, List[str]]] = {}
for op_compat_item in op_compat_infos:
def insert_new_mappings(op_name_str: str) -> str:
......@@ -64,6 +67,23 @@ def OpNameNormalizerInitialization(
op_arg_name_mappings[op_name] = {}
op_arg_name_mappings[op_name].update(arg_mapping)
def insert_new_mutable_attributes(
op_name: str, mutable_attribute_infos: Dict[str, Dict[str, str]]
):
op_mutable_attribues[op_name] = set()
op_mutable_attribute_infos[op_name] = {}
for (
attribute_name,
mutable_attribute_info,
) in mutable_attribute_infos.items():
op_mutable_attribues[op_name].add(attribute_name)
op_mutable_attribute_infos[op_name][attribute_name] = []
for k, v in mutable_attribute_info.items():
if k == 'tensor_name' or k == 'tensors_name':
op_mutable_attribute_infos[op_name][
attribute_name
].append(v)
_, legacy_name = insert_new_mappings(op_compat_item["op"])
legacy_backward_op_names = []
if "backward" in op_compat_item:
......@@ -88,6 +108,14 @@ def OpNameNormalizerInitialization(
for backward_op in legacy_backward_op_names:
insert_new_arg_mappings(backward_op, op_compat_item["outputs"])
if "int_array" in op_compat_item:
insert_new_mutable_attributes(
legacy_name, op_compat_item["int_array"]
)
if "scalar" in op_compat_item:
insert_new_mutable_attributes(legacy_name, op_compat_item["scalar"])
# special op mappings
op_name_mappings["fetch_v2"] = "fetch"
......@@ -96,6 +124,8 @@ def OpNameNormalizerInitialization(
op_compat_definition = op_name_normailzer_template.render(
op_name_pairs=op_name_mappings,
op_arg_name_pairs=op_arg_name_mappings,
op_mutable_attributes=op_mutable_attribues,
op_mutable_attribute_infos=op_mutable_attribute_infos,
)
f.write(op_compat_definition)
......
......@@ -21,6 +21,37 @@ OpNameNormalizer::OpNameNormalizer() {
},
{% endfor %}
};
op_mutable_attributes = {
{% for op_name, mutable_attributes in op_mutable_attributes.items() %}
{
"{{op_name}}",
{
{% for attribute_name in mutable_attributes %}
"{{attribute_name}}",
{% endfor %}
},
},
{% endfor %}
};
op_mutable_attribute_infos = {
{% for op_name, mutable_attribute_infos in op_mutable_attribute_infos.items() %}
{
"{{op_name}}",
{
{% for attribute_name, attribute_info in mutable_attribute_infos.items() %}
{
"{{attribute_name}}",
{
{% for candidate_var_name in attribute_info %}
"{{candidate_var_name}}",
{% endfor %}
},
},
{% endfor %}
},
},
{% endfor %}
};
}
} // namespace translator
......
......@@ -15,6 +15,7 @@
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "glog/logging.h"
......@@ -25,6 +26,8 @@
namespace paddle {
namespace translator {
using MutableAttributeInfo = std::vector<std::string>;
class OpNameNormalizer {
private:
OpNameNormalizer(); // Disallow instantiation outside of the class.
......@@ -32,6 +35,12 @@ class OpNameNormalizer {
std::unordered_map<std::string, std::unordered_map<std::string, std::string>>
op_arg_name_mappings;
std::unordered_map<std::string,
std::unordered_map<std::string, MutableAttributeInfo>>
op_mutable_attribute_infos;
std::unordered_map<std::string, std::unordered_set<std::string>>
op_mutable_attributes;
public:
OpNameNormalizer(const OpNameNormalizer&) = delete;
OpNameNormalizer& operator=(const OpNameNormalizer&) = delete;
......@@ -50,6 +59,21 @@ class OpNameNormalizer {
return op_name_mappings.at(op_type);
}
bool HasMutableAttribute(const std::string& op_type) {
return (op_mutable_attributes.find(op_type) != op_mutable_attributes.end());
}
const std::unordered_set<std::string>* GetMutableAttributes(
const std::string& op_type) {
if (!HasMutableAttribute(op_type)) return nullptr;
return &op_mutable_attributes.at(op_type);
}
const MutableAttributeInfo& GetMutableAttributeInfos(
const std::string& op_type, const std::string& arg_name) {
return op_mutable_attribute_infos.at(op_type).at(arg_name);
}
std::string GetLegacyArgName(const std::string& op_type,
const std::string& arg_name) {
bool is_grad_op = (op_type.find("grad") != std::string::npos);
......
......@@ -23,17 +23,23 @@
#include <vector>
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/fluid/ir_adaptor/translator/attribute_translator.h"
#include "paddle/fluid/ir_adaptor/translator/op_compat_info.h"
#include "paddle/fluid/ir_adaptor/translator/program_translator.h"
#include "paddle/fluid/ir_adaptor/translator/type_translator.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/value.h"
#include "paddle/phi/core/enforce.h"
// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/ir/dialect/CMakeLists.txt.
#include "paddle/fluid/ir/dialect/pd_op.h"
namespace paddle {
namespace translator {
......@@ -66,8 +72,13 @@ inline bool IsInplace(const OpDesc& op_desc) {
}
auto input_names = op_desc.InputArgumentNames();
auto output_names = op_desc.OutputArgumentNames();
if (input_names.size() == 0 || output_names.size() == 0) {
return inplace;
}
std::vector<std::string> name_intersection;
std::sort(input_names.begin(), input_names.end());
std::sort(output_names.begin(), output_names.end());
std::set_intersection(input_names.begin(),
input_names.end(),
output_names.begin(),
......@@ -103,10 +114,9 @@ inline ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) {
<< target_op_name;
auto op_info = ctx->GetRegisteredOpInfo(target_op_name);
if (!op_info) {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"Op %d should have corresponding OpInfo %d",
IR_THROW("Op %d should have corresponding OpInfo %d",
op_desc.Type(),
target_op_name));
target_op_name);
}
return op_info;
......@@ -158,18 +168,86 @@ inline ir::Operation* InsertCombineOperationForTarget(
return operation;
}
inline ir::Operation* InsertConstantOperationForOptionalArg(
ir::IrContext* ctx, ir::Program* program) {
inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx,
ir::Program* program,
ir::Attribute attr) {
float data = 0.0f;
phi::DataType dtype = phi::DataType::UNDEFINED;
if (attr.isa<ir::FloatAttribute>()) {
data = attr.dyn_cast<ir::FloatAttribute>().data();
dtype = phi::DataType::FLOAT32;
} else if (attr.isa<ir::DoubleAttribute>()) {
data = static_cast<float>(attr.dyn_cast<ir::DoubleAttribute>().data());
dtype = phi::DataType::FLOAT64;
} else if (attr.isa<ir::Int32_tAttribute>()) {
data = static_cast<float>(attr.dyn_cast<ir::Int32_tAttribute>().data());
dtype = phi::DataType::INT32;
} else if (attr.isa<ir::Int64_tAttribute>()) {
data = static_cast<float>(attr.dyn_cast<ir::Int64_tAttribute>().data());
dtype = phi::DataType::INT64;
} else if (attr.isa<ir::BoolAttribute>()) {
data = static_cast<float>(attr.dyn_cast<ir::BoolAttribute>().data());
dtype = phi::DataType::BOOL;
}
ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program->block());
paddle::dialect::FullOp full_op = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1}, data, dtype, phi::CPUPlace());
return full_op.operation();
}
inline ir::Operation* InsertFullArrayOperationForAttributeInput(
ir::IrContext* ctx, ir::Program* program, ir::Attribute attr) {
std::string constant_op_name(ir::ConstantOp::name());
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(constant_op_name);
ir::Type null_type = ir::Type(nullptr);
ir::Type null_type = paddle::dialect::DenseTensorType::get(
ctx,
ir::Type(nullptr),
phi::DDim{},
paddle::dialect::DenseTensorTypeStorage::DataLayout::UNDEFINED,
phi::LoD{},
0); // TODO(lyk): to be done
ir::Operation* operation =
ir::Operation::Create({}, {}, {null_type}, op_info);
ir::Operation::Create({}, {{"value", attr}}, {null_type}, op_info);
program->block()->push_back(operation);
return operation;
}
inline ir::OpResult GetAttributeAsInput(ir::IrContext* ctx,
ir::Program* program,
const OpDesc& op_desc,
const OpInputInfo& input_info) {
auto& attribute_translator = AttributeTranslator::instance();
auto& op_normalizer = OpNameNormalizer::instance();
auto legacy_attr_name =
op_normalizer.GetLegacyAttrName(op_desc.Type(), input_info.name);
if (!op_desc.HasAttr(legacy_attr_name)) {
IR_THROW("Op %s arg %s should not be zero size",
op_desc.Type(),
legacy_attr_name);
}
paddle::framework::Attribute legacy_attr = op_desc.GetAttr(legacy_attr_name);
VLOG(10) << "[" << op_desc.Type() << "][attribute]"
<< " name: " << legacy_attr_name << " " << legacy_attr.index();
ir::Attribute new_attr =
attribute_translator(input_info.type_name, legacy_attr);
ir::Operation* defining_op = nullptr;
bool is_int_array = (input_info.type_name.find("IntArrayAttribute") !=
input_info.type_name.npos);
if (is_int_array) {
defining_op =
InsertFullArrayOperationForAttributeInput(ctx, program, new_attr);
} else {
defining_op = InsertFullOperationForAttributeInput(ctx, program, new_attr);
}
return defining_op->GetResultByIndex(0);
}
inline std::vector<ir::OpResult> GenerateOperationInput(
ir::IrContext* ctx,
TranslationContext* param_map,
......@@ -184,14 +262,11 @@ inline std::vector<ir::OpResult> GenerateOperationInput(
auto& args = n.second;
for (const auto& arg_name : args) {
PADDLE_ENFORCE_NE(
param_map->count(arg_name),
0,
platform::errors::PreconditionNotMet(
IR_ENFORCE(param_map->count(arg_name) != 0,
"arg %s.%s as input should be exists before prasing %s",
name,
arg_name,
op_desc.Type()));
op_desc.Type());
auto defining_info = (*param_map)[arg_name];
if (defining_info.generated_by_vector) {
InsertSliceOperationForTarget(
......@@ -202,25 +277,59 @@ inline std::vector<ir::OpResult> GenerateOperationInput(
std::vector<ir::OpResult> op_inputs;
auto& op_normalizer = OpNameNormalizer::instance();
const auto* mutable_attributes =
op_normalizer.GetMutableAttributes(op_desc.Type());
for (const auto& info : input_infos) {
std::string legacy_input_name =
op_normalizer.GetLegacyArgName(op_desc.Type(), info.name);
VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " "
<< legacy_input_name;
std::vector<std::string> legacy_input_vars;
// return empty OpResult if this arg is optional and not shown in OpDesc
// TODO(lyk): HasInput doesnot consider variadic attribute
if (!op_desc.HasInput(legacy_input_name)) {
PADDLE_ENFORCE(info.optional,
platform::errors::PreconditionNotMet(
"Op %s arg %s should be optional if it can be empty",
op_desc.Type(),
legacy_input_name));
if (op_desc.HasInput(legacy_input_name)) {
legacy_input_vars = op_desc.Input(legacy_input_name, true);
}
if (legacy_input_vars.size() == 0) {
if (info.optional) {
op_inputs.push_back(ir::OpResult(nullptr));
continue;
}
}
VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " "
<< legacy_input_name << " " << legacy_input_vars.size();
if (legacy_input_vars.size() == 0 && mutable_attributes != nullptr &&
mutable_attributes->count(info.name) != 0) {
const auto& candidate_var_names =
op_normalizer.GetMutableAttributeInfos(op_desc.Type(), info.name);
bool found_candidate_var = false;
for (const auto& var_name : candidate_var_names) {
VLOG(10) << "[handle mutable attribute][" << info.name << "]["
<< var_name << "]";
if (op_desc.HasInput(var_name)) {
legacy_input_vars = op_desc.Input(var_name, true);
if (legacy_input_vars.size() == 0) continue;
found_candidate_var = true;
break;
}
}
if (!found_candidate_var) {
auto attribute_input = GetAttributeAsInput(ctx, program, op_desc, info);
op_inputs.push_back(attribute_input);
continue;
}
}
const auto& legacy_input_vars = op_desc.Input(legacy_input_name, true);
bool is_vector = (info.type_name.find("VectorType") != std::string::npos);
VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " "
<< is_vector << " " << info.type_name;
// if src type is Tensor
if (!is_vector) {
......@@ -262,11 +371,10 @@ inline std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput(
VLOG(10) << "[output translating]"
<< "[" << op_desc.Type() << "] optional " << info.name << " :"
<< info.type_name << " " << legacy_output_name;
PADDLE_ENFORCE(info.optional,
platform::errors::PreconditionNotMet(
IR_ENFORCE(info.optional,
"Op %s arg %s should be optional if it can be empty",
op_desc.Type(),
legacy_output_name));
legacy_output_name);
op_output_types.push_back(ir::Type(nullptr));
continue;
}
......
......@@ -60,6 +60,10 @@ class Attribute {
IrContext *ir_context() const;
/// @brief print attribute
/// @param os
void Print(std::ostream &os) const;
///
/// \brief Methods for type judgment and cast.
///
......@@ -80,6 +84,8 @@ class Attribute {
protected:
const Storage *storage_{nullptr};
};
std::ostream &operator<<(std::ostream &os, Attribute attr);
} // namespace ir
namespace std {
......
......@@ -16,8 +16,10 @@
#include <ostream>
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/attribute_base.h"
#include "paddle/ir/core/dialect_interface.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/ir/core/type_base.h"
......@@ -33,15 +35,15 @@ class DialectInterface;
///
class Dialect {
public:
Dialect(std::string name, ir::IrContext *context, ir::TypeId id);
Dialect(std::string name, IrContext *context, TypeId id);
virtual ~Dialect();
const std::string &name() const { return name_; }
ir::IrContext *ir_context() const { return context_; }
IrContext *ir_context() const { return context_; }
ir::TypeId id() const { return id_; }
TypeId id() const { return id_; }
///
/// \brief Register all types contained in the template parameter Args.
......@@ -130,8 +132,12 @@ class Dialect {
return *interface;
}
virtual void PrintType(ir::Type type, std::ostream &os) {
throw std::logic_error("dialect has no registered type printing hook");
virtual void PrintType(Type type, std::ostream &os) const {
IR_THROW("dialect has no registered type printing hook");
}
virtual void PrintAttribute(Attribute type, std::ostream &os) const {
IR_THROW("dialect has no registered attribute printing hook");
}
private:
......@@ -141,9 +147,9 @@ class Dialect {
std::string name_;
ir::IrContext *context_; // not owned
IrContext *context_; // not owned
ir::TypeId id_;
TypeId id_;
std::unordered_map<TypeId, std::unique_ptr<DialectInterface>>
registered_interfaces_;
......
......@@ -17,6 +17,8 @@
#include <exception>
#include <string>
#include "paddle/utils/string/printf.h"
#if !defined(_WIN32)
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#else
......@@ -40,7 +42,11 @@ class IrNotMetException : public std::exception {
#define IR_THROW(...) \
do { \
try { \
throw ir::IrNotMetException(__VA_ARGS__); \
throw ir::IrNotMetException( \
paddle::string::Sprintf("Error occured at: %s:%d :\n%s", \
__FILE__, \
__LINE__, \
paddle::string::Sprintf(__VA_ARGS__))); \
} catch (const std::exception& e) { \
std::cout << e.what() << std::endl; \
throw; \
......@@ -52,7 +58,11 @@ class IrNotMetException : public std::exception {
auto __cond__ = (COND); \
if (UNLIKELY(is_error(__cond__))) { \
try { \
throw ir::IrNotMetException(__VA_ARGS__); \
throw ir::IrNotMetException( \
paddle::string::Sprintf("Error occured at: %s:%d :\n%s", \
__FILE__, \
__LINE__, \
paddle::string::Sprintf(__VA_ARGS__))); \
} catch (const std::exception& e) { \
std::cout << e.what() << std::endl; \
throw; \
......
......@@ -23,69 +23,86 @@
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/utils.h"
#include "paddle/ir/core/value.h"
namespace ir {
namespace {
constexpr char newline[] = "\n";
template <typename ForwardIterator, typename UnaryFunctor, typename NullFunctor>
void PrintInterleave(ForwardIterator begin,
ForwardIterator end,
UnaryFunctor print_func,
NullFunctor between_func) {
if (begin == end) return;
print_func(*begin);
begin++;
for (; begin != end; begin++) {
between_func();
print_func(*begin);
}
}
} // namespace
class BasicIRPrinter {
public:
explicit BasicIRPrinter(std::ostream& os) : os(os) {}
void PrintType(ir::Type type) {
void PrintType(Type type) {
if (!type) {
os << "<<NULL TYPE>>";
return;
}
if (type.isa<ir::Float16Type>()) {
if (type.isa<Float16Type>()) {
os << "f16";
} else if (type.isa<ir::Float32Type>()) {
} else if (type.isa<Float32Type>()) {
os << "f32";
} else if (type.isa<ir::Float64Type>()) {
} else if (type.isa<Float64Type>()) {
os << "f64";
} else if (type.isa<ir::Int16Type>()) {
} else if (type.isa<Int16Type>()) {
os << "i16";
} else if (type.isa<ir::Int32Type>()) {
} else if (type.isa<Int32Type>()) {
os << "i32";
} else if (type.isa<ir::Int64Type>()) {
} else if (type.isa<Int64Type>()) {
os << "i64";
} else if (type.isa<ir::VectorType>()) {
os << "vec<";
auto inner_types = type.dyn_cast<ir::VectorType>().data();
} else if (type.isa<VectorType>()) {
os << "vec[";
auto inner_types = type.dyn_cast<VectorType>().data();
PrintInterleave(
inner_types.begin(),
inner_types.end(),
[this](ir::Type v) { this->PrintType(v); },
[this]() { this->os << ", "; });
os << ">";
[this](Type v) { this->PrintType(v); },
[this]() { this->os << ","; });
os << "]";
} else {
auto& dialect = type.dialect();
dialect.PrintType(type, os);
}
}
void PrintAttribute(ir::Operation* op) { os << " { ATTRIBUTE }"; }
void PrintAttribute(const Attribute& attr) {
if (!attr) {
os << "<#AttrNull>";
return;
}
protected:
if (auto s = attr.dyn_cast<StrAttribute>()) {
os << s.data();
} else if (auto b = attr.dyn_cast<BoolAttribute>()) {
os << b.data();
} else if (auto f = attr.dyn_cast<FloatAttribute>()) {
os << f.data();
} else if (auto d = attr.dyn_cast<DoubleAttribute>()) {
os << d.data();
} else if (auto i = attr.dyn_cast<Int32_tAttribute>()) {
os << i.data();
} else if (auto i = attr.dyn_cast<Int64_tAttribute>()) {
os << i.data();
} else if (auto arr = attr.dyn_cast<ArrayAttribute>()) {
const auto& vec = arr.data();
os << "array[";
PrintInterleave(
vec.begin(),
vec.end(),
[this](Attribute v) { this->PrintAttribute(v); },
[this]() { this->os << ","; });
os << "]";
} else {
auto& dialect = attr.dialect();
dialect.PrintAttribute(attr, os);
}
}
public:
std::ostream& os;
};
......@@ -96,14 +113,12 @@ class IRPrinter : public BasicIRPrinter {
/// @brief print program
/// @param program
/// @example
void PrintProgram(ir::Program* program) {
PrintOperation(program->module_op());
}
void PrintProgram(Program* program) { PrintOperation(program->module_op()); }
/// @brief print operation
/// @param op
/// @example
void PrintOperation(ir::Operation* op) {
void PrintOperation(Operation* op) {
for (size_t i = 0; i < op->num_regions(); ++i) {
auto& region = op->GetRegion(i);
for (auto it = region.begin(); it != region.end(); ++it) {
......@@ -120,7 +135,7 @@ class IRPrinter : public BasicIRPrinter {
// TODO(lyk): add API to get operands directly
PrintOpOperands(op);
PrintAttribute(op);
PrintAttributeMap(op);
os << " :";
// PrintOpSingature
......@@ -138,7 +153,7 @@ class IRPrinter : public BasicIRPrinter {
}
private:
void PrintValue(ir::Value v) {
void PrintValue(Value v) {
if (!v) {
os << "<<NULL VALUE>>";
return;
......@@ -156,10 +171,10 @@ class IRPrinter : public BasicIRPrinter {
os << new_name;
}
void PrintOpResult(ir::Operation* op) {
void PrintOpResult(Operation* op) {
os << " (";
auto num_op_result = op->num_results();
std::vector<ir::OpResult> op_results;
std::vector<OpResult> op_results;
op_results.reserve(num_op_result);
for (size_t idx = 0; idx < num_op_result; idx++) {
op_results.push_back(op->GetResultByIndex(idx));
......@@ -167,15 +182,31 @@ class IRPrinter : public BasicIRPrinter {
PrintInterleave(
op_results.begin(),
op_results.end(),
[this](ir::Value v) { this->PrintValue(v); },
[this](Value v) { this->PrintValue(v); },
[this]() { this->os << ", "; });
os << ")";
}
void PrintOpOperands(ir::Operation* op) {
void PrintAttributeMap(Operation* op) {
os << " {";
PrintInterleave(
op->attributes().begin(),
op->attributes().end(),
[this](std::pair<std::string, Attribute> it) {
this->os << it.first;
this->os << ":";
this->PrintAttribute(it.second);
},
[this]() { this->os << ","; });
os << "}";
}
void PrintOpOperands(Operation* op) {
os << " (";
auto num_op_operands = op->num_operands();
std::vector<ir::Value> op_operands;
std::vector<Value> op_operands;
op_operands.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) {
op_operands.push_back(op->GetOperandByIndex(idx).source());
......@@ -183,48 +214,48 @@ class IRPrinter : public BasicIRPrinter {
PrintInterleave(
op_operands.begin(),
op_operands.end(),
[this](ir::Value v) { this->PrintValue(v); },
[this](Value v) { this->PrintValue(v); },
[this]() { this->os << ", "; });
os << ")";
}
void PrintOperandsType(ir::Operation* op) {
void PrintOperandsType(Operation* op) {
auto num_op_operands = op->num_operands();
std::vector<ir::Type> op_operand_types;
std::vector<Type> op_operand_types;
op_operand_types.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) {
auto op_operand = op->GetOperandByIndex(idx);
if (op_operand) {
op_operand_types.push_back(op->GetOperandByIndex(idx).source().type());
} else {
op_operand_types.push_back(ir::Type(nullptr));
op_operand_types.push_back(Type(nullptr));
}
}
os << " (";
PrintInterleave(
op_operand_types.begin(),
op_operand_types.end(),
[this](ir::Type t) { this->PrintType(t); },
[this](Type t) { this->PrintType(t); },
[this]() { this->os << ", "; });
os << ")";
}
void PrintOpReturnType(ir::Operation* op) {
void PrintOpReturnType(Operation* op) {
auto num_op_result = op->num_results();
std::vector<ir::Type> op_result_types;
std::vector<Type> op_result_types;
op_result_types.reserve(num_op_result);
for (size_t idx = 0; idx < num_op_result; idx++) {
auto op_result = op->GetResultByIndex(idx);
if (op_result) {
op_result_types.push_back(op_result.type());
} else {
op_result_types.push_back(ir::Type(nullptr));
op_result_types.push_back(Type(nullptr));
}
}
PrintInterleave(
op_result_types.begin(),
op_result_types.end(),
[this](ir::Type t) { this->PrintType(t); },
[this](Type t) { this->PrintType(t); },
[this]() { this->os << ", "; });
}
......@@ -248,4 +279,19 @@ void Type::Print(std::ostream& os) const {
printer.PrintType(*this);
}
void Attribute::Print(std::ostream& os) const {
BasicIRPrinter printer(os);
printer.PrintAttribute(*this);
}
std::ostream& operator<<(std::ostream& os, Type type) {
type.Print(os);
return os;
}
std::ostream& operator<<(std::ostream& os, Attribute attr) {
attr.Print(os);
return os;
}
} // namespace ir
......@@ -17,10 +17,4 @@
namespace ir {
IrContext* Type::ir_context() const { return dialect().ir_context(); }
std::ostream& operator<<(std::ostream& os, Type type) {
type.Print(os);
return os;
}
} // namespace ir
......@@ -120,4 +120,18 @@ struct Filter<BaseT, Tuple, true> {
using Type = std::tuple<>;
};
template <typename ForwardIterator, typename UnaryFunctor, typename NullFunctor>
void PrintInterleave(ForwardIterator begin,
ForwardIterator end,
UnaryFunctor print_func,
NullFunctor between_func) {
if (begin == end) return;
print_func(*begin);
begin++;
for (; begin != end; begin++) {
between_func();
print_func(*begin);
}
}
} // namespace ir
......@@ -53,11 +53,13 @@ TEST(PaddleDialectTest, Translator) {
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<PaddleDialect>();
ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
// auto program = paddle::TranslateLegacyProgramToProgram(p);
auto program = paddle::TranslateLegacyProgramToProgram(p);
// size_t op_size = program->block()->size();
// // ops.size() = op size in BlockDesc + get_parameter_op + combine op
// EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 21);
size_t op_size = program->block()->size();
// ops.size() = op size in BlockDesc + get_parameter_op + combine op + int
// array op + full op
EXPECT_EQ(op_size,
p.Block(0).OpSize() + program->parameters_num() + 20 + 3 + 8);
// program->Print(std::cout);
program->Print(std::cout);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册