未验证 提交 51ca74bb 编写于 作者: K kangguangli 提交者: GitHub

[IR] IR attribute printer and support mutable attribute (#54369)

* add vector type support for program translator

* polish

* support basic attribute type

* resolve conflicts

* add verify for combine/slice and unittests

* polish

* support more type in attribute translator

* modify by reviews

* fix merge mistakes

* refine code

* refine code

* add interface

* fix: op name normalization

* fix typo

* refactor input translator

* fix merge conflicts

* fix op normalizer bug

* refactor attribute translator

* fix bug

* refactor output translator

* fix typo

* fix

* fix approval error

* fix coverage

* fix op_compat parser

* fix merge conflicts

* fix merge conflicts

* fix merge conflicts

* fix merge conflicts

* fix merge conflicts

* refactor scalar attribute

* draft

* fix

* fix op build

* fix op build

* temporarily save

* adpat mutable attribute

* refine op_comat_gen process

* fix merge conflicts

* fix merge conflicts

* fix merge conflicts

* complete dialect attribute printer and refine ir_throw

* polish code

---------
Co-authored-by: Nzhangbo9674 <zhangbo54@baidu.com>
上级 fc880209
...@@ -1193,8 +1193,8 @@ def OpGenerator( ...@@ -1193,8 +1193,8 @@ def OpGenerator(
# generate get op info funciton: inputs # generate get op info funciton: inputs
inputs_info_str = "" inputs_info_str = ""
if len(op_input_name_list) > 0:
input_info_list = [] input_info_list = []
if len(op_input_name_list) > 0:
for idx in range(len(op_input_name_list)): for idx in range(len(op_input_name_list)):
input_info_list.append( input_info_list.append(
CONSTRUCT_INPUT_INFO_TEMPLATE.format( CONSTRUCT_INPUT_INFO_TEMPLATE.format(
...@@ -1204,6 +1204,18 @@ def OpGenerator( ...@@ -1204,6 +1204,18 @@ def OpGenerator(
no_need_buffer=op_input_no_need_buffer_list[idx], no_need_buffer=op_input_no_need_buffer_list[idx],
) )
) )
# add mutable attribute as input
if len(op_mutable_attribute_name_list) > 0:
for idx in range(len(op_mutable_attribute_name_list)):
input_info_list.append(
CONSTRUCT_INPUT_INFO_TEMPLATE.format(
name=op_mutable_attribute_name_list[idx],
typename=op_mutable_attribute_type_list[idx],
optional='false',
no_need_buffer='false',
)
)
inputs_info_str = ", ".join(input_info_list) inputs_info_str = ", ".join(input_info_list)
# generate get op info funciton: outputs # generate get op info funciton: outputs
...@@ -1223,12 +1235,16 @@ def OpGenerator( ...@@ -1223,12 +1235,16 @@ def OpGenerator(
# generate get op info funciton: attributes # generate get op info funciton: attributes
attribute_info_str = "" attribute_info_str = ""
op_mutable_attribute_name_set = set(op_mutable_attribute_name_list)
if len(op_attribute_name_list) > 0: if len(op_attribute_name_list) > 0:
attribute_info_list = [] attribute_info_list = []
for idx in range(len(op_attribute_name_list)): for idx in range(len(op_attribute_name_list)):
attribute_name = op_attribute_name_list[idx]
if attribute_name in op_mutable_attribute_name_set:
continue
attribute_info_list.append( attribute_info_list.append(
CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE.format( CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE.format(
name=op_attribute_name_list[idx], name=attribute_name,
typename=op_attribute_type_list[idx], typename=op_attribute_type_list[idx],
data_type=op_attribute_data_type_list[idx], data_type=op_attribute_data_type_list[idx],
) )
......
...@@ -23,6 +23,8 @@ ...@@ -23,6 +23,8 @@
#include "paddle/fluid/ir/dialect/pd_type_storage.h" #include "paddle/fluid/ir/dialect/pd_type_storage.h"
#include "paddle/fluid/ir/dialect/utils.h" #include "paddle/fluid/ir/dialect/utils.h"
#include "paddle/ir/core/dialect_interface.h" #include "paddle/ir/core/dialect_interface.h"
#include "paddle/ir/core/utils.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
namespace paddle { namespace paddle {
...@@ -107,7 +109,7 @@ void PaddleDialect::initialize() { ...@@ -107,7 +109,7 @@ void PaddleDialect::initialize() {
RegisterInterfaces<ParameterConvertInterface>(); RegisterInterfaces<ParameterConvertInterface>();
} }
void PaddleDialect::PrintType(ir::Type type, std::ostream &os) { void PaddleDialect::PrintType(ir::Type type, std::ostream &os) const {
DenseTensorType tensor_type = type.dyn_cast<DenseTensorType>(); DenseTensorType tensor_type = type.dyn_cast<DenseTensorType>();
os << "tensor<"; os << "tensor<";
...@@ -119,5 +121,27 @@ void PaddleDialect::PrintType(ir::Type type, std::ostream &os) { ...@@ -119,5 +121,27 @@ void PaddleDialect::PrintType(ir::Type type, std::ostream &os) {
os << ">"; os << ">";
} }
void PaddleDialect::PrintAttribute(ir::Attribute attr, std::ostream &os) const {
if (auto int_array_attr = attr.dyn_cast<IntArrayAttribute>()) {
phi::IntArray data = int_array_attr.data();
os << "IntArray[";
const auto &inner_data = data.GetData();
ir::PrintInterleave(
inner_data.begin(),
inner_data.end(),
[&os](int64_t i) { os << i; },
[&os]() { os << ","; });
os << "]";
} else if (auto data_type_attr = attr.dyn_cast<DataTypeAttribute>()) {
os << data_type_attr.data();
} else if (auto place_type_attr = attr.dyn_cast<PlaceAttribute>()) {
os << place_type_attr.data();
} else if (auto data_layout_attr = attr.dyn_cast<DataLayoutAttribute>()) {
os << data_layout_attr.data();
} else {
os << "<#AttrNotImplemented>";
}
}
} // namespace dialect } // namespace dialect
} // namespace paddle } // namespace paddle
...@@ -39,7 +39,8 @@ class PaddleDialect : public ir::Dialect { ...@@ -39,7 +39,8 @@ class PaddleDialect : public ir::Dialect {
static const char* name() { return "pd"; } static const char* name() { return "pd"; }
void PrintType(ir::Type type, std::ostream& os); void PrintType(ir::Type type, std::ostream& os) const;
void PrintAttribute(ir::Attribute type, std::ostream& os) const;
private: private:
void initialize(); void initialize();
......
...@@ -5,12 +5,14 @@ set(PD_PROGRAM_TRANSLATOR_BINARY_DIR ...@@ -5,12 +5,14 @@ set(PD_PROGRAM_TRANSLATOR_BINARY_DIR
set(op_gen_file ${PD_PROGRAM_TRANSLATOR_SOURCE_DIR}/op_compat_gen.py) set(op_gen_file ${PD_PROGRAM_TRANSLATOR_SOURCE_DIR}/op_compat_gen.py)
set(op_compat_yaml_file ${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml) set(op_compat_yaml_file ${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml)
set(op_compat_source_file ${PD_PROGRAM_TRANSLATOR_SOURCE_DIR}/op_compat_info.cc) set(op_compat_source_file ${PD_PROGRAM_TRANSLATOR_SOURCE_DIR}/op_compat_info.cc)
set(op_compat_templat_file
${PD_PROGRAM_TRANSLATOR_SOURCE_DIR}/op_compat_info.cc.j2)
add_custom_command( add_custom_command(
OUTPUT ${op_compat_source_file} OUTPUT ${op_compat_source_file}
COMMAND ${PYTHON_EXECUTABLE} ${op_gen_file} --op_compat_yaml_file COMMAND ${PYTHON_EXECUTABLE} ${op_gen_file} --op_compat_yaml_file
${op_compat_yaml_file} --output_source_file ${op_compat_source_file} ${op_compat_yaml_file} --output_source_file ${op_compat_source_file}
DEPENDS ${op_gen_file} ${op_compat_yaml_file} DEPENDS ${op_gen_file} ${op_compat_yaml_file} ${op_compat_templat_file}
VERBATIM) VERBATIM)
file(GLOB PD_PROGRAM_TRANSLATOR_SRCS "*.cc") file(GLOB PD_PROGRAM_TRANSLATOR_SRCS "*.cc")
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import argparse import argparse
from pathlib import Path from pathlib import Path
from typing import Dict from typing import Dict, List, Set
import yaml import yaml
from jinja2 import Environment, FileSystemLoader, StrictUndefined from jinja2 import Environment, FileSystemLoader, StrictUndefined
...@@ -46,8 +46,11 @@ def OpNameNormalizerInitialization( ...@@ -46,8 +46,11 @@ def OpNameNormalizerInitialization(
with open(op_compat_yaml_file, "r") as f: with open(op_compat_yaml_file, "r") as f:
op_compat_infos = yaml.safe_load(f) op_compat_infos = yaml.safe_load(f)
op_name_mappings = {} op_name_mappings: Dict[str, str] = {}
op_arg_name_mappings = {} op_arg_name_mappings: Dict[str, Dict[str, str]] = {}
op_mutable_attribues: Dict[str, Set[str]] = {}
op_mutable_attribute_infos: Dict[str, Dict[str, List[str]]] = {}
for op_compat_item in op_compat_infos: for op_compat_item in op_compat_infos:
def insert_new_mappings(op_name_str: str) -> str: def insert_new_mappings(op_name_str: str) -> str:
...@@ -64,6 +67,23 @@ def OpNameNormalizerInitialization( ...@@ -64,6 +67,23 @@ def OpNameNormalizerInitialization(
op_arg_name_mappings[op_name] = {} op_arg_name_mappings[op_name] = {}
op_arg_name_mappings[op_name].update(arg_mapping) op_arg_name_mappings[op_name].update(arg_mapping)
def insert_new_mutable_attributes(
op_name: str, mutable_attribute_infos: Dict[str, Dict[str, str]]
):
op_mutable_attribues[op_name] = set()
op_mutable_attribute_infos[op_name] = {}
for (
attribute_name,
mutable_attribute_info,
) in mutable_attribute_infos.items():
op_mutable_attribues[op_name].add(attribute_name)
op_mutable_attribute_infos[op_name][attribute_name] = []
for k, v in mutable_attribute_info.items():
if k == 'tensor_name' or k == 'tensors_name':
op_mutable_attribute_infos[op_name][
attribute_name
].append(v)
_, legacy_name = insert_new_mappings(op_compat_item["op"]) _, legacy_name = insert_new_mappings(op_compat_item["op"])
legacy_backward_op_names = [] legacy_backward_op_names = []
if "backward" in op_compat_item: if "backward" in op_compat_item:
...@@ -88,6 +108,14 @@ def OpNameNormalizerInitialization( ...@@ -88,6 +108,14 @@ def OpNameNormalizerInitialization(
for backward_op in legacy_backward_op_names: for backward_op in legacy_backward_op_names:
insert_new_arg_mappings(backward_op, op_compat_item["outputs"]) insert_new_arg_mappings(backward_op, op_compat_item["outputs"])
if "int_array" in op_compat_item:
insert_new_mutable_attributes(
legacy_name, op_compat_item["int_array"]
)
if "scalar" in op_compat_item:
insert_new_mutable_attributes(legacy_name, op_compat_item["scalar"])
# special op mappings # special op mappings
op_name_mappings["fetch_v2"] = "fetch" op_name_mappings["fetch_v2"] = "fetch"
...@@ -96,6 +124,8 @@ def OpNameNormalizerInitialization( ...@@ -96,6 +124,8 @@ def OpNameNormalizerInitialization(
op_compat_definition = op_name_normailzer_template.render( op_compat_definition = op_name_normailzer_template.render(
op_name_pairs=op_name_mappings, op_name_pairs=op_name_mappings,
op_arg_name_pairs=op_arg_name_mappings, op_arg_name_pairs=op_arg_name_mappings,
op_mutable_attributes=op_mutable_attribues,
op_mutable_attribute_infos=op_mutable_attribute_infos,
) )
f.write(op_compat_definition) f.write(op_compat_definition)
......
...@@ -21,6 +21,37 @@ OpNameNormalizer::OpNameNormalizer() { ...@@ -21,6 +21,37 @@ OpNameNormalizer::OpNameNormalizer() {
}, },
{% endfor %} {% endfor %}
}; };
op_mutable_attributes = {
{% for op_name, mutable_attributes in op_mutable_attributes.items() %}
{
"{{op_name}}",
{
{% for attribute_name in mutable_attributes %}
"{{attribute_name}}",
{% endfor %}
},
},
{% endfor %}
};
op_mutable_attribute_infos = {
{% for op_name, mutable_attribute_infos in op_mutable_attribute_infos.items() %}
{
"{{op_name}}",
{
{% for attribute_name, attribute_info in mutable_attribute_infos.items() %}
{
"{{attribute_name}}",
{
{% for candidate_var_name in attribute_info %}
"{{candidate_var_name}}",
{% endfor %}
},
},
{% endfor %}
},
},
{% endfor %}
};
} }
} // namespace translator } // namespace translator
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include <functional> #include <functional>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include "glog/logging.h" #include "glog/logging.h"
...@@ -25,6 +26,8 @@ ...@@ -25,6 +26,8 @@
namespace paddle { namespace paddle {
namespace translator { namespace translator {
using MutableAttributeInfo = std::vector<std::string>;
class OpNameNormalizer { class OpNameNormalizer {
private: private:
OpNameNormalizer(); // Disallow instantiation outside of the class. OpNameNormalizer(); // Disallow instantiation outside of the class.
...@@ -32,6 +35,12 @@ class OpNameNormalizer { ...@@ -32,6 +35,12 @@ class OpNameNormalizer {
std::unordered_map<std::string, std::unordered_map<std::string, std::string>> std::unordered_map<std::string, std::unordered_map<std::string, std::string>>
op_arg_name_mappings; op_arg_name_mappings;
std::unordered_map<std::string,
std::unordered_map<std::string, MutableAttributeInfo>>
op_mutable_attribute_infos;
std::unordered_map<std::string, std::unordered_set<std::string>>
op_mutable_attributes;
public: public:
OpNameNormalizer(const OpNameNormalizer&) = delete; OpNameNormalizer(const OpNameNormalizer&) = delete;
OpNameNormalizer& operator=(const OpNameNormalizer&) = delete; OpNameNormalizer& operator=(const OpNameNormalizer&) = delete;
...@@ -50,6 +59,21 @@ class OpNameNormalizer { ...@@ -50,6 +59,21 @@ class OpNameNormalizer {
return op_name_mappings.at(op_type); return op_name_mappings.at(op_type);
} }
bool HasMutableAttribute(const std::string& op_type) {
return (op_mutable_attributes.find(op_type) != op_mutable_attributes.end());
}
const std::unordered_set<std::string>* GetMutableAttributes(
const std::string& op_type) {
if (!HasMutableAttribute(op_type)) return nullptr;
return &op_mutable_attributes.at(op_type);
}
const MutableAttributeInfo& GetMutableAttributeInfos(
const std::string& op_type, const std::string& arg_name) {
return op_mutable_attribute_infos.at(op_type).at(arg_name);
}
std::string GetLegacyArgName(const std::string& op_type, std::string GetLegacyArgName(const std::string& op_type,
const std::string& arg_name) { const std::string& arg_name) {
bool is_grad_op = (op_type.find("grad") != std::string::npos); bool is_grad_op = (op_type.find("grad") != std::string::npos);
......
...@@ -23,17 +23,23 @@ ...@@ -23,17 +23,23 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h" #include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/fluid/ir_adaptor/translator/attribute_translator.h" #include "paddle/fluid/ir_adaptor/translator/attribute_translator.h"
#include "paddle/fluid/ir_adaptor/translator/op_compat_info.h" #include "paddle/fluid/ir_adaptor/translator/op_compat_info.h"
#include "paddle/fluid/ir_adaptor/translator/program_translator.h" #include "paddle/fluid/ir_adaptor/translator/program_translator.h"
#include "paddle/fluid/ir_adaptor/translator/type_translator.h" #include "paddle/fluid/ir_adaptor/translator/type_translator.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_op.h" #include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/operation.h" #include "paddle/ir/core/operation.h"
#include "paddle/ir/core/value.h" #include "paddle/ir/core/value.h"
#include "paddle/phi/core/enforce.h"
// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/ir/dialect/CMakeLists.txt.
#include "paddle/fluid/ir/dialect/pd_op.h"
namespace paddle { namespace paddle {
namespace translator { namespace translator {
...@@ -66,8 +72,13 @@ inline bool IsInplace(const OpDesc& op_desc) { ...@@ -66,8 +72,13 @@ inline bool IsInplace(const OpDesc& op_desc) {
} }
auto input_names = op_desc.InputArgumentNames(); auto input_names = op_desc.InputArgumentNames();
auto output_names = op_desc.OutputArgumentNames(); auto output_names = op_desc.OutputArgumentNames();
if (input_names.size() == 0 || output_names.size() == 0) {
return inplace;
}
std::vector<std::string> name_intersection; std::vector<std::string> name_intersection;
std::sort(input_names.begin(), input_names.end());
std::sort(output_names.begin(), output_names.end());
std::set_intersection(input_names.begin(), std::set_intersection(input_names.begin(),
input_names.end(), input_names.end(),
output_names.begin(), output_names.begin(),
...@@ -103,10 +114,9 @@ inline ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) { ...@@ -103,10 +114,9 @@ inline ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) {
<< target_op_name; << target_op_name;
auto op_info = ctx->GetRegisteredOpInfo(target_op_name); auto op_info = ctx->GetRegisteredOpInfo(target_op_name);
if (!op_info) { if (!op_info) {
PADDLE_THROW(platform::errors::PreconditionNotMet( IR_THROW("Op %d should have corresponding OpInfo %d",
"Op %d should have corresponding OpInfo %d",
op_desc.Type(), op_desc.Type(),
target_op_name)); target_op_name);
} }
return op_info; return op_info;
...@@ -158,18 +168,86 @@ inline ir::Operation* InsertCombineOperationForTarget( ...@@ -158,18 +168,86 @@ inline ir::Operation* InsertCombineOperationForTarget(
return operation; return operation;
} }
inline ir::Operation* InsertConstantOperationForOptionalArg( inline ir::Operation* InsertFullOperationForAttributeInput(ir::IrContext* ctx,
ir::IrContext* ctx, ir::Program* program) { ir::Program* program,
ir::Attribute attr) {
float data = 0.0f;
phi::DataType dtype = phi::DataType::UNDEFINED;
if (attr.isa<ir::FloatAttribute>()) {
data = attr.dyn_cast<ir::FloatAttribute>().data();
dtype = phi::DataType::FLOAT32;
} else if (attr.isa<ir::DoubleAttribute>()) {
data = static_cast<float>(attr.dyn_cast<ir::DoubleAttribute>().data());
dtype = phi::DataType::FLOAT64;
} else if (attr.isa<ir::Int32_tAttribute>()) {
data = static_cast<float>(attr.dyn_cast<ir::Int32_tAttribute>().data());
dtype = phi::DataType::INT32;
} else if (attr.isa<ir::Int64_tAttribute>()) {
data = static_cast<float>(attr.dyn_cast<ir::Int64_tAttribute>().data());
dtype = phi::DataType::INT64;
} else if (attr.isa<ir::BoolAttribute>()) {
data = static_cast<float>(attr.dyn_cast<ir::BoolAttribute>().data());
dtype = phi::DataType::BOOL;
}
ir::Builder builder = ir::Builder::AtBlockEnd(ctx, program->block());
paddle::dialect::FullOp full_op = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1}, data, dtype, phi::CPUPlace());
return full_op.operation();
}
inline ir::Operation* InsertFullArrayOperationForAttributeInput(
ir::IrContext* ctx, ir::Program* program, ir::Attribute attr) {
std::string constant_op_name(ir::ConstantOp::name()); std::string constant_op_name(ir::ConstantOp::name());
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(constant_op_name); ir::OpInfo op_info = ctx->GetRegisteredOpInfo(constant_op_name);
ir::Type null_type = ir::Type(nullptr); ir::Type null_type = paddle::dialect::DenseTensorType::get(
ctx,
ir::Type(nullptr),
phi::DDim{},
paddle::dialect::DenseTensorTypeStorage::DataLayout::UNDEFINED,
phi::LoD{},
0); // TODO(lyk): to be done
ir::Operation* operation = ir::Operation* operation =
ir::Operation::Create({}, {}, {null_type}, op_info); ir::Operation::Create({}, {{"value", attr}}, {null_type}, op_info);
program->block()->push_back(operation); program->block()->push_back(operation);
return operation; return operation;
} }
inline ir::OpResult GetAttributeAsInput(ir::IrContext* ctx,
ir::Program* program,
const OpDesc& op_desc,
const OpInputInfo& input_info) {
auto& attribute_translator = AttributeTranslator::instance();
auto& op_normalizer = OpNameNormalizer::instance();
auto legacy_attr_name =
op_normalizer.GetLegacyAttrName(op_desc.Type(), input_info.name);
if (!op_desc.HasAttr(legacy_attr_name)) {
IR_THROW("Op %s arg %s should not be zero size",
op_desc.Type(),
legacy_attr_name);
}
paddle::framework::Attribute legacy_attr = op_desc.GetAttr(legacy_attr_name);
VLOG(10) << "[" << op_desc.Type() << "][attribute]"
<< " name: " << legacy_attr_name << " " << legacy_attr.index();
ir::Attribute new_attr =
attribute_translator(input_info.type_name, legacy_attr);
ir::Operation* defining_op = nullptr;
bool is_int_array = (input_info.type_name.find("IntArrayAttribute") !=
input_info.type_name.npos);
if (is_int_array) {
defining_op =
InsertFullArrayOperationForAttributeInput(ctx, program, new_attr);
} else {
defining_op = InsertFullOperationForAttributeInput(ctx, program, new_attr);
}
return defining_op->GetResultByIndex(0);
}
inline std::vector<ir::OpResult> GenerateOperationInput( inline std::vector<ir::OpResult> GenerateOperationInput(
ir::IrContext* ctx, ir::IrContext* ctx,
TranslationContext* param_map, TranslationContext* param_map,
...@@ -184,14 +262,11 @@ inline std::vector<ir::OpResult> GenerateOperationInput( ...@@ -184,14 +262,11 @@ inline std::vector<ir::OpResult> GenerateOperationInput(
auto& args = n.second; auto& args = n.second;
for (const auto& arg_name : args) { for (const auto& arg_name : args) {
PADDLE_ENFORCE_NE( IR_ENFORCE(param_map->count(arg_name) != 0,
param_map->count(arg_name),
0,
platform::errors::PreconditionNotMet(
"arg %s.%s as input should be exists before prasing %s", "arg %s.%s as input should be exists before prasing %s",
name, name,
arg_name, arg_name,
op_desc.Type())); op_desc.Type());
auto defining_info = (*param_map)[arg_name]; auto defining_info = (*param_map)[arg_name];
if (defining_info.generated_by_vector) { if (defining_info.generated_by_vector) {
InsertSliceOperationForTarget( InsertSliceOperationForTarget(
...@@ -202,25 +277,59 @@ inline std::vector<ir::OpResult> GenerateOperationInput( ...@@ -202,25 +277,59 @@ inline std::vector<ir::OpResult> GenerateOperationInput(
std::vector<ir::OpResult> op_inputs; std::vector<ir::OpResult> op_inputs;
auto& op_normalizer = OpNameNormalizer::instance(); auto& op_normalizer = OpNameNormalizer::instance();
const auto* mutable_attributes =
op_normalizer.GetMutableAttributes(op_desc.Type());
for (const auto& info : input_infos) { for (const auto& info : input_infos) {
std::string legacy_input_name = std::string legacy_input_name =
op_normalizer.GetLegacyArgName(op_desc.Type(), info.name); op_normalizer.GetLegacyArgName(op_desc.Type(), info.name);
VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " "
<< legacy_input_name;
std::vector<std::string> legacy_input_vars;
// return empty OpResult if this arg is optional and not shown in OpDesc // return empty OpResult if this arg is optional and not shown in OpDesc
// TODO(lyk): HasInput doesnot consider variadic attribute // TODO(lyk): HasInput doesnot consider variadic attribute
if (!op_desc.HasInput(legacy_input_name)) { if (op_desc.HasInput(legacy_input_name)) {
PADDLE_ENFORCE(info.optional, legacy_input_vars = op_desc.Input(legacy_input_name, true);
platform::errors::PreconditionNotMet( }
"Op %s arg %s should be optional if it can be empty",
op_desc.Type(), if (legacy_input_vars.size() == 0) {
legacy_input_name)); if (info.optional) {
op_inputs.push_back(ir::OpResult(nullptr)); op_inputs.push_back(ir::OpResult(nullptr));
continue; continue;
} }
}
VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " "
<< legacy_input_name << " " << legacy_input_vars.size();
if (legacy_input_vars.size() == 0 && mutable_attributes != nullptr &&
mutable_attributes->count(info.name) != 0) {
const auto& candidate_var_names =
op_normalizer.GetMutableAttributeInfos(op_desc.Type(), info.name);
bool found_candidate_var = false;
for (const auto& var_name : candidate_var_names) {
VLOG(10) << "[handle mutable attribute][" << info.name << "]["
<< var_name << "]";
if (op_desc.HasInput(var_name)) {
legacy_input_vars = op_desc.Input(var_name, true);
if (legacy_input_vars.size() == 0) continue;
found_candidate_var = true;
break;
}
}
if (!found_candidate_var) {
auto attribute_input = GetAttributeAsInput(ctx, program, op_desc, info);
op_inputs.push_back(attribute_input);
continue;
}
}
const auto& legacy_input_vars = op_desc.Input(legacy_input_name, true);
bool is_vector = (info.type_name.find("VectorType") != std::string::npos); bool is_vector = (info.type_name.find("VectorType") != std::string::npos);
VLOG(10) << "[op:" << op_desc.Type() << "][input]" << info.name << " "
<< is_vector << " " << info.type_name;
// if src type is Tensor // if src type is Tensor
if (!is_vector) { if (!is_vector) {
...@@ -262,11 +371,10 @@ inline std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput( ...@@ -262,11 +371,10 @@ inline std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput(
VLOG(10) << "[output translating]" VLOG(10) << "[output translating]"
<< "[" << op_desc.Type() << "] optional " << info.name << " :" << "[" << op_desc.Type() << "] optional " << info.name << " :"
<< info.type_name << " " << legacy_output_name; << info.type_name << " " << legacy_output_name;
PADDLE_ENFORCE(info.optional, IR_ENFORCE(info.optional,
platform::errors::PreconditionNotMet(
"Op %s arg %s should be optional if it can be empty", "Op %s arg %s should be optional if it can be empty",
op_desc.Type(), op_desc.Type(),
legacy_output_name)); legacy_output_name);
op_output_types.push_back(ir::Type(nullptr)); op_output_types.push_back(ir::Type(nullptr));
continue; continue;
} }
......
...@@ -60,6 +60,10 @@ class Attribute { ...@@ -60,6 +60,10 @@ class Attribute {
IrContext *ir_context() const; IrContext *ir_context() const;
/// @brief print attribute
/// @param os
void Print(std::ostream &os) const;
/// ///
/// \brief Methods for type judgment and cast. /// \brief Methods for type judgment and cast.
/// ///
...@@ -80,6 +84,8 @@ class Attribute { ...@@ -80,6 +84,8 @@ class Attribute {
protected: protected:
const Storage *storage_{nullptr}; const Storage *storage_{nullptr};
}; };
std::ostream &operator<<(std::ostream &os, Attribute attr);
} // namespace ir } // namespace ir
namespace std { namespace std {
......
...@@ -16,8 +16,10 @@ ...@@ -16,8 +16,10 @@
#include <ostream> #include <ostream>
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/attribute_base.h" #include "paddle/ir/core/attribute_base.h"
#include "paddle/ir/core/dialect_interface.h" #include "paddle/ir/core/dialect_interface.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/ir_context.h" #include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/op_base.h" #include "paddle/ir/core/op_base.h"
#include "paddle/ir/core/type_base.h" #include "paddle/ir/core/type_base.h"
...@@ -33,15 +35,15 @@ class DialectInterface; ...@@ -33,15 +35,15 @@ class DialectInterface;
/// ///
class Dialect { class Dialect {
public: public:
Dialect(std::string name, ir::IrContext *context, ir::TypeId id); Dialect(std::string name, IrContext *context, TypeId id);
virtual ~Dialect(); virtual ~Dialect();
const std::string &name() const { return name_; } const std::string &name() const { return name_; }
ir::IrContext *ir_context() const { return context_; } IrContext *ir_context() const { return context_; }
ir::TypeId id() const { return id_; } TypeId id() const { return id_; }
/// ///
/// \brief Register all types contained in the template parameter Args. /// \brief Register all types contained in the template parameter Args.
...@@ -130,8 +132,12 @@ class Dialect { ...@@ -130,8 +132,12 @@ class Dialect {
return *interface; return *interface;
} }
virtual void PrintType(ir::Type type, std::ostream &os) { virtual void PrintType(Type type, std::ostream &os) const {
throw std::logic_error("dialect has no registered type printing hook"); IR_THROW("dialect has no registered type printing hook");
}
virtual void PrintAttribute(Attribute type, std::ostream &os) const {
IR_THROW("dialect has no registered attribute printing hook");
} }
private: private:
...@@ -141,9 +147,9 @@ class Dialect { ...@@ -141,9 +147,9 @@ class Dialect {
std::string name_; std::string name_;
ir::IrContext *context_; // not owned IrContext *context_; // not owned
ir::TypeId id_; TypeId id_;
std::unordered_map<TypeId, std::unique_ptr<DialectInterface>> std::unordered_map<TypeId, std::unique_ptr<DialectInterface>>
registered_interfaces_; registered_interfaces_;
......
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
#include <exception> #include <exception>
#include <string> #include <string>
#include "paddle/utils/string/printf.h"
#if !defined(_WIN32) #if !defined(_WIN32)
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0) #define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)
#else #else
...@@ -40,7 +42,11 @@ class IrNotMetException : public std::exception { ...@@ -40,7 +42,11 @@ class IrNotMetException : public std::exception {
#define IR_THROW(...) \ #define IR_THROW(...) \
do { \ do { \
try { \ try { \
throw ir::IrNotMetException(__VA_ARGS__); \ throw ir::IrNotMetException( \
paddle::string::Sprintf("Error occured at: %s:%d :\n%s", \
__FILE__, \
__LINE__, \
paddle::string::Sprintf(__VA_ARGS__))); \
} catch (const std::exception& e) { \ } catch (const std::exception& e) { \
std::cout << e.what() << std::endl; \ std::cout << e.what() << std::endl; \
throw; \ throw; \
...@@ -52,7 +58,11 @@ class IrNotMetException : public std::exception { ...@@ -52,7 +58,11 @@ class IrNotMetException : public std::exception {
auto __cond__ = (COND); \ auto __cond__ = (COND); \
if (UNLIKELY(is_error(__cond__))) { \ if (UNLIKELY(is_error(__cond__))) { \
try { \ try { \
throw ir::IrNotMetException(__VA_ARGS__); \ throw ir::IrNotMetException( \
paddle::string::Sprintf("Error occured at: %s:%d :\n%s", \
__FILE__, \
__LINE__, \
paddle::string::Sprintf(__VA_ARGS__))); \
} catch (const std::exception& e) { \ } catch (const std::exception& e) { \
std::cout << e.what() << std::endl; \ std::cout << e.what() << std::endl; \
throw; \ throw; \
......
...@@ -23,69 +23,86 @@ ...@@ -23,69 +23,86 @@
#include "paddle/ir/core/dialect.h" #include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/operation.h" #include "paddle/ir/core/operation.h"
#include "paddle/ir/core/program.h" #include "paddle/ir/core/program.h"
#include "paddle/ir/core/utils.h"
#include "paddle/ir/core/value.h" #include "paddle/ir/core/value.h"
namespace ir { namespace ir {
namespace { namespace {
constexpr char newline[] = "\n"; constexpr char newline[] = "\n";
template <typename ForwardIterator, typename UnaryFunctor, typename NullFunctor>
void PrintInterleave(ForwardIterator begin,
ForwardIterator end,
UnaryFunctor print_func,
NullFunctor between_func) {
if (begin == end) return;
print_func(*begin);
begin++;
for (; begin != end; begin++) {
between_func();
print_func(*begin);
}
}
} // namespace } // namespace
class BasicIRPrinter { class BasicIRPrinter {
public: public:
explicit BasicIRPrinter(std::ostream& os) : os(os) {} explicit BasicIRPrinter(std::ostream& os) : os(os) {}
void PrintType(ir::Type type) { void PrintType(Type type) {
if (!type) { if (!type) {
os << "<<NULL TYPE>>"; os << "<<NULL TYPE>>";
return; return;
} }
if (type.isa<ir::Float16Type>()) { if (type.isa<Float16Type>()) {
os << "f16"; os << "f16";
} else if (type.isa<ir::Float32Type>()) { } else if (type.isa<Float32Type>()) {
os << "f32"; os << "f32";
} else if (type.isa<ir::Float64Type>()) { } else if (type.isa<Float64Type>()) {
os << "f64"; os << "f64";
} else if (type.isa<ir::Int16Type>()) { } else if (type.isa<Int16Type>()) {
os << "i16"; os << "i16";
} else if (type.isa<ir::Int32Type>()) { } else if (type.isa<Int32Type>()) {
os << "i32"; os << "i32";
} else if (type.isa<ir::Int64Type>()) { } else if (type.isa<Int64Type>()) {
os << "i64"; os << "i64";
} else if (type.isa<ir::VectorType>()) { } else if (type.isa<VectorType>()) {
os << "vec<"; os << "vec[";
auto inner_types = type.dyn_cast<ir::VectorType>().data(); auto inner_types = type.dyn_cast<VectorType>().data();
PrintInterleave( PrintInterleave(
inner_types.begin(), inner_types.begin(),
inner_types.end(), inner_types.end(),
[this](ir::Type v) { this->PrintType(v); }, [this](Type v) { this->PrintType(v); },
[this]() { this->os << ", "; }); [this]() { this->os << ","; });
os << ">"; os << "]";
} else { } else {
auto& dialect = type.dialect(); auto& dialect = type.dialect();
dialect.PrintType(type, os); dialect.PrintType(type, os);
} }
} }
void PrintAttribute(ir::Operation* op) { os << " { ATTRIBUTE }"; } void PrintAttribute(const Attribute& attr) {
if (!attr) {
os << "<#AttrNull>";
return;
}
protected: if (auto s = attr.dyn_cast<StrAttribute>()) {
os << s.data();
} else if (auto b = attr.dyn_cast<BoolAttribute>()) {
os << b.data();
} else if (auto f = attr.dyn_cast<FloatAttribute>()) {
os << f.data();
} else if (auto d = attr.dyn_cast<DoubleAttribute>()) {
os << d.data();
} else if (auto i = attr.dyn_cast<Int32_tAttribute>()) {
os << i.data();
} else if (auto i = attr.dyn_cast<Int64_tAttribute>()) {
os << i.data();
} else if (auto arr = attr.dyn_cast<ArrayAttribute>()) {
const auto& vec = arr.data();
os << "array[";
PrintInterleave(
vec.begin(),
vec.end(),
[this](Attribute v) { this->PrintAttribute(v); },
[this]() { this->os << ","; });
os << "]";
} else {
auto& dialect = attr.dialect();
dialect.PrintAttribute(attr, os);
}
}
public:
std::ostream& os; std::ostream& os;
}; };
...@@ -96,14 +113,12 @@ class IRPrinter : public BasicIRPrinter { ...@@ -96,14 +113,12 @@ class IRPrinter : public BasicIRPrinter {
/// @brief print program /// @brief print program
/// @param program /// @param program
/// @example /// @example
void PrintProgram(ir::Program* program) { void PrintProgram(Program* program) { PrintOperation(program->module_op()); }
PrintOperation(program->module_op());
}
/// @brief print operation /// @brief print operation
/// @param op /// @param op
/// @example /// @example
void PrintOperation(ir::Operation* op) { void PrintOperation(Operation* op) {
for (size_t i = 0; i < op->num_regions(); ++i) { for (size_t i = 0; i < op->num_regions(); ++i) {
auto& region = op->GetRegion(i); auto& region = op->GetRegion(i);
for (auto it = region.begin(); it != region.end(); ++it) { for (auto it = region.begin(); it != region.end(); ++it) {
...@@ -120,7 +135,7 @@ class IRPrinter : public BasicIRPrinter { ...@@ -120,7 +135,7 @@ class IRPrinter : public BasicIRPrinter {
// TODO(lyk): add API to get operands directly // TODO(lyk): add API to get operands directly
PrintOpOperands(op); PrintOpOperands(op);
PrintAttribute(op); PrintAttributeMap(op);
os << " :"; os << " :";
// PrintOpSingature // PrintOpSingature
...@@ -138,7 +153,7 @@ class IRPrinter : public BasicIRPrinter { ...@@ -138,7 +153,7 @@ class IRPrinter : public BasicIRPrinter {
} }
private: private:
void PrintValue(ir::Value v) { void PrintValue(Value v) {
if (!v) { if (!v) {
os << "<<NULL VALUE>>"; os << "<<NULL VALUE>>";
return; return;
...@@ -156,10 +171,10 @@ class IRPrinter : public BasicIRPrinter { ...@@ -156,10 +171,10 @@ class IRPrinter : public BasicIRPrinter {
os << new_name; os << new_name;
} }
void PrintOpResult(ir::Operation* op) { void PrintOpResult(Operation* op) {
os << " ("; os << " (";
auto num_op_result = op->num_results(); auto num_op_result = op->num_results();
std::vector<ir::OpResult> op_results; std::vector<OpResult> op_results;
op_results.reserve(num_op_result); op_results.reserve(num_op_result);
for (size_t idx = 0; idx < num_op_result; idx++) { for (size_t idx = 0; idx < num_op_result; idx++) {
op_results.push_back(op->GetResultByIndex(idx)); op_results.push_back(op->GetResultByIndex(idx));
...@@ -167,15 +182,31 @@ class IRPrinter : public BasicIRPrinter { ...@@ -167,15 +182,31 @@ class IRPrinter : public BasicIRPrinter {
PrintInterleave( PrintInterleave(
op_results.begin(), op_results.begin(),
op_results.end(), op_results.end(),
[this](ir::Value v) { this->PrintValue(v); }, [this](Value v) { this->PrintValue(v); },
[this]() { this->os << ", "; }); [this]() { this->os << ", "; });
os << ")"; os << ")";
} }
void PrintOpOperands(ir::Operation* op) { void PrintAttributeMap(Operation* op) {
os << " {";
PrintInterleave(
op->attributes().begin(),
op->attributes().end(),
[this](std::pair<std::string, Attribute> it) {
this->os << it.first;
this->os << ":";
this->PrintAttribute(it.second);
},
[this]() { this->os << ","; });
os << "}";
}
void PrintOpOperands(Operation* op) {
os << " ("; os << " (";
auto num_op_operands = op->num_operands(); auto num_op_operands = op->num_operands();
std::vector<ir::Value> op_operands; std::vector<Value> op_operands;
op_operands.reserve(num_op_operands); op_operands.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) { for (size_t idx = 0; idx < num_op_operands; idx++) {
op_operands.push_back(op->GetOperandByIndex(idx).source()); op_operands.push_back(op->GetOperandByIndex(idx).source());
...@@ -183,48 +214,48 @@ class IRPrinter : public BasicIRPrinter { ...@@ -183,48 +214,48 @@ class IRPrinter : public BasicIRPrinter {
PrintInterleave( PrintInterleave(
op_operands.begin(), op_operands.begin(),
op_operands.end(), op_operands.end(),
[this](ir::Value v) { this->PrintValue(v); }, [this](Value v) { this->PrintValue(v); },
[this]() { this->os << ", "; }); [this]() { this->os << ", "; });
os << ")"; os << ")";
} }
void PrintOperandsType(ir::Operation* op) { void PrintOperandsType(Operation* op) {
auto num_op_operands = op->num_operands(); auto num_op_operands = op->num_operands();
std::vector<ir::Type> op_operand_types; std::vector<Type> op_operand_types;
op_operand_types.reserve(num_op_operands); op_operand_types.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) { for (size_t idx = 0; idx < num_op_operands; idx++) {
auto op_operand = op->GetOperandByIndex(idx); auto op_operand = op->GetOperandByIndex(idx);
if (op_operand) { if (op_operand) {
op_operand_types.push_back(op->GetOperandByIndex(idx).source().type()); op_operand_types.push_back(op->GetOperandByIndex(idx).source().type());
} else { } else {
op_operand_types.push_back(ir::Type(nullptr)); op_operand_types.push_back(Type(nullptr));
} }
} }
os << " ("; os << " (";
PrintInterleave( PrintInterleave(
op_operand_types.begin(), op_operand_types.begin(),
op_operand_types.end(), op_operand_types.end(),
[this](ir::Type t) { this->PrintType(t); }, [this](Type t) { this->PrintType(t); },
[this]() { this->os << ", "; }); [this]() { this->os << ", "; });
os << ")"; os << ")";
} }
void PrintOpReturnType(ir::Operation* op) { void PrintOpReturnType(Operation* op) {
auto num_op_result = op->num_results(); auto num_op_result = op->num_results();
std::vector<ir::Type> op_result_types; std::vector<Type> op_result_types;
op_result_types.reserve(num_op_result); op_result_types.reserve(num_op_result);
for (size_t idx = 0; idx < num_op_result; idx++) { for (size_t idx = 0; idx < num_op_result; idx++) {
auto op_result = op->GetResultByIndex(idx); auto op_result = op->GetResultByIndex(idx);
if (op_result) { if (op_result) {
op_result_types.push_back(op_result.type()); op_result_types.push_back(op_result.type());
} else { } else {
op_result_types.push_back(ir::Type(nullptr)); op_result_types.push_back(Type(nullptr));
} }
} }
PrintInterleave( PrintInterleave(
op_result_types.begin(), op_result_types.begin(),
op_result_types.end(), op_result_types.end(),
[this](ir::Type t) { this->PrintType(t); }, [this](Type t) { this->PrintType(t); },
[this]() { this->os << ", "; }); [this]() { this->os << ", "; });
} }
...@@ -248,4 +279,19 @@ void Type::Print(std::ostream& os) const { ...@@ -248,4 +279,19 @@ void Type::Print(std::ostream& os) const {
printer.PrintType(*this); printer.PrintType(*this);
} }
void Attribute::Print(std::ostream& os) const {
BasicIRPrinter printer(os);
printer.PrintAttribute(*this);
}
std::ostream& operator<<(std::ostream& os, Type type) {
type.Print(os);
return os;
}
std::ostream& operator<<(std::ostream& os, Attribute attr) {
attr.Print(os);
return os;
}
} // namespace ir } // namespace ir
...@@ -17,10 +17,4 @@ ...@@ -17,10 +17,4 @@
namespace ir { namespace ir {
IrContext* Type::ir_context() const { return dialect().ir_context(); } IrContext* Type::ir_context() const { return dialect().ir_context(); }
std::ostream& operator<<(std::ostream& os, Type type) {
type.Print(os);
return os;
}
} // namespace ir } // namespace ir
...@@ -120,4 +120,18 @@ struct Filter<BaseT, Tuple, true> { ...@@ -120,4 +120,18 @@ struct Filter<BaseT, Tuple, true> {
using Type = std::tuple<>; using Type = std::tuple<>;
}; };
template <typename ForwardIterator, typename UnaryFunctor, typename NullFunctor>
void PrintInterleave(ForwardIterator begin,
ForwardIterator end,
UnaryFunctor print_func,
NullFunctor between_func) {
if (begin == end) return;
print_func(*begin);
begin++;
for (; begin != end; begin++) {
between_func();
print_func(*begin);
}
}
} // namespace ir } // namespace ir
...@@ -53,11 +53,13 @@ TEST(PaddleDialectTest, Translator) { ...@@ -53,11 +53,13 @@ TEST(PaddleDialectTest, Translator) {
ir::IrContext *ctx = ir::IrContext::Instance(); ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<PaddleDialect>(); ctx->GetOrRegisterDialect<PaddleDialect>();
ctx->GetOrRegisterDialect<ir::BuiltinDialect>(); ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
// auto program = paddle::TranslateLegacyProgramToProgram(p); auto program = paddle::TranslateLegacyProgramToProgram(p);
// size_t op_size = program->block()->size(); size_t op_size = program->block()->size();
// // ops.size() = op size in BlockDesc + get_parameter_op + combine op // ops.size() = op size in BlockDesc + get_parameter_op + combine op + int
// EXPECT_EQ(op_size, p.Block(0).OpSize() + program->parameters_num() + 21); // array op + full op
EXPECT_EQ(op_size,
p.Block(0).OpSize() + program->parameters_num() + 20 + 3 + 8);
// program->Print(std::cout); program->Print(std::cout);
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册