未验证 提交 896d7cfa 编写于 作者: W winter-wang 提交者: GitHub

[IR] finetune the StrAttribute interface. (#55439)

上级 5464b5c4
...@@ -47,7 +47,7 @@ OpFuncType AnalyseOpFuncType(ir::Operation* op, const platform::Place& place) { ...@@ -47,7 +47,7 @@ OpFuncType AnalyseOpFuncType(ir::Operation* op, const platform::Place& place) {
// and so that they would be dispatched to host thread. // and so that they would be dispatched to host thread.
auto op_attributes = op->attributes(); auto op_attributes = op->attributes();
auto op_name = auto op_name =
op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().data(); op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
if (op_name == kCoalesceTensor && if (op_name == kCoalesceTensor &&
(!platform::is_xpu_place(place) || (!platform::is_xpu_place(place) ||
op->attribute<ir::BoolAttribute>("persist_output").data() == false) && op->attribute<ir::BoolAttribute>("persist_output").data() == false) &&
...@@ -77,7 +77,7 @@ PhiKernelInstruction::PhiKernelInstruction( ...@@ -77,7 +77,7 @@ PhiKernelInstruction::PhiKernelInstruction(
: InstructionBase(id, place) { : InstructionBase(id, place) {
auto op_attributes = op->attributes(); auto op_attributes = op->attributes();
auto op_name = auto op_name =
op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().data(); op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
ir::OpInfo op_info = ir::IrContext::Instance()->GetRegisteredOpInfo(op_name); ir::OpInfo op_info = ir::IrContext::Instance()->GetRegisteredOpInfo(op_name);
phi_op_name_ = op_name; phi_op_name_ = op_name;
...@@ -142,7 +142,7 @@ PhiKernelInstruction::PhiKernelInstruction( ...@@ -142,7 +142,7 @@ PhiKernelInstruction::PhiKernelInstruction(
VLOG(6) << "finish process infer meta context"; VLOG(6) << "finish process infer meta context";
auto kernel_name = auto kernel_name =
op_attributes.at("kernel_name").dyn_cast<ir::StrAttribute>().data(); op_attributes.at("kernel_name").dyn_cast<ir::StrAttribute>().AsString();
auto kernel_key = op_attributes.at("kernel_key") auto kernel_key = op_attributes.at("kernel_key")
.dyn_cast<paddle::dialect::KernelAttribute>() .dyn_cast<paddle::dialect::KernelAttribute>()
.data(); .data();
......
...@@ -952,7 +952,8 @@ void BuildOpFuncList( ...@@ -952,7 +952,8 @@ void BuildOpFuncList(
OpFuncNode op_func_node; OpFuncNode op_func_node;
auto attr_map = (*it)->attributes(); auto attr_map = (*it)->attributes();
auto op_name = attr_map.at("op_name").dyn_cast<::ir::StrAttribute>().data(); auto op_name =
attr_map.at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
op_func_node.phi_op_name_ = op_name; op_func_node.phi_op_name_ = op_name;
if (op_name == "builtin.combine" || op_name == "pd.feed" || if (op_name == "builtin.combine" || op_name == "pd.feed" ||
...@@ -986,7 +987,7 @@ void BuildOpFuncList( ...@@ -986,7 +987,7 @@ void BuildOpFuncList(
&(op_func_node.infer_meta_context_)); &(op_func_node.infer_meta_context_));
auto kernel_name = auto kernel_name =
attr_map.at("kernel_name").dyn_cast<ir::StrAttribute>().data(); attr_map.at("kernel_name").dyn_cast<ir::StrAttribute>().AsString();
auto kernel_key = attr_map.at("kernel_key") auto kernel_key = attr_map.at("kernel_key")
.dyn_cast<paddle::dialect::KernelAttribute>() .dyn_cast<paddle::dialect::KernelAttribute>()
.data(); .data();
......
...@@ -45,10 +45,10 @@ void PhiKernelOp::Verify() { ...@@ -45,10 +45,10 @@ void PhiKernelOp::Verify() {
} }
std::string PhiKernelOp::op_name() { std::string PhiKernelOp::op_name() {
return attributes().at("op_name").dyn_cast<ir::StrAttribute>().data(); return attributes().at("op_name").dyn_cast<ir::StrAttribute>().AsString();
} }
std::string PhiKernelOp::kernel_name() { std::string PhiKernelOp::kernel_name() {
return attributes().at("kernel_name").dyn_cast<ir::StrAttribute>().data(); return attributes().at("kernel_name").dyn_cast<ir::StrAttribute>().AsString();
} }
phi::KernelKey PhiKernelOp::kernel_key() { phi::KernelKey PhiKernelOp::kernel_key() {
return attributes().at("kernel_key").dyn_cast<KernelAttribute>().data(); return attributes().at("kernel_key").dyn_cast<KernelAttribute>().data();
......
...@@ -542,11 +542,14 @@ def gen_build_func_str( ...@@ -542,11 +542,14 @@ def gen_build_func_str(
GET_ATTRIBUTES_FROM_MAP_TEMPLATE = """ GET_ATTRIBUTES_FROM_MAP_TEMPLATE = """
{attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<{attr_ir_type}>().data(); {attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<{attr_ir_type}>().data();
"""
GET_STR_ATTRIBUTES_FROM_MAP_TEMPLATE = """
{attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<ir::StrAttribute>().AsString();
""" """
GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """ GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """
{attr_type} {attribute_name}; {attr_type} {attribute_name};
for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().size(); i++) {{ for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().size(); i++) {{
{attribute_name}.push_back(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>()[i].dyn_cast<{inner_type}>().data()); {attribute_name}.push_back(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().at(i).dyn_cast<{inner_type}>().{data_name}());
}} }}
""" """
GET_INTARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """ GET_INTARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """
...@@ -566,11 +569,15 @@ def gen_build_func_str( ...@@ -566,11 +569,15 @@ def gen_build_func_str(
# attr_type = "std::vector<int>" # attr_type = "std::vector<int>"
if "ir::ArrayAttribute" in op_attribute_type_list[idx]: if "ir::ArrayAttribute" in op_attribute_type_list[idx]:
inner_type = op_attribute_type_list[idx][19:-1] inner_type = op_attribute_type_list[idx][19:-1]
data_name = "data"
if inner_type == "ir::StrAttribute":
data_name = "AsString"
get_attributes_str += ( get_attributes_str += (
GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE.format( GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE.format(
attr_type=attr_type, attr_type=attr_type,
attribute_name=op_attribute_name_list[idx], attribute_name=op_attribute_name_list[idx],
inner_type=inner_type, inner_type=inner_type,
data_name=data_name,
) )
) )
elif ( elif (
...@@ -593,6 +600,14 @@ def gen_build_func_str( ...@@ -593,6 +600,14 @@ def gen_build_func_str(
attribute_name=op_attribute_name_list[idx], attribute_name=op_attribute_name_list[idx],
) )
) )
elif "ir::StrAttribute" in op_attribute_type_list[idx]:
get_attributes_str += (
GET_STR_ATTRIBUTES_FROM_MAP_TEMPLATE.format(
attr_type=attr_type,
attribute_name=op_attribute_name_list[idx],
attr_ir_type=op_attribute_type_list[idx],
)
)
else: else:
get_attributes_str += GET_ATTRIBUTES_FROM_MAP_TEMPLATE.format( get_attributes_str += GET_ATTRIBUTES_FROM_MAP_TEMPLATE.format(
attr_type=attr_type, attr_type=attr_type,
......
...@@ -78,7 +78,7 @@ ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """ ...@@ -78,7 +78,7 @@ ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """
PADDLE_ENFORCE(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<ir::ArrayAttribute>(), PADDLE_ENFORCE(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<ir::ArrayAttribute>(),
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right.")); phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().size(); i++) {{ for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().size(); i++) {{
PADDLE_ENFORCE(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>()[i].isa<{standard}>(), PADDLE_ENFORCE(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().at(i).isa<{standard}>(),
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right.")); phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
}}""" }}"""
OUTPUT_TYPE_CHECK_TEMPLATE = """ OUTPUT_TYPE_CHECK_TEMPLATE = """
......
...@@ -76,7 +76,8 @@ class PhiKernelAdaptor { ...@@ -76,7 +76,8 @@ class PhiKernelAdaptor {
for (auto it = block->begin(); it != block->end(); ++it) { for (auto it = block->begin(); it != block->end(); ++it) {
auto attr_map = (*it)->attributes(); auto attr_map = (*it)->attributes();
auto op_name = attr_map.at("op_name").dyn_cast<ir::StrAttribute>().data(); auto op_name =
attr_map.at("op_name").dyn_cast<ir::StrAttribute>().AsString();
ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op_name); ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op_name);
...@@ -104,7 +105,7 @@ class PhiKernelAdaptor { ...@@ -104,7 +105,7 @@ class PhiKernelAdaptor {
infer_meta_impl->infer_meta_(&ctx); infer_meta_impl->infer_meta_(&ctx);
auto kernel_name = auto kernel_name =
attr_map.at("kernel_name").dyn_cast<ir::StrAttribute>().data(); attr_map.at("kernel_name").dyn_cast<ir::StrAttribute>().AsString();
auto kernel_key = attr_map.at("kernel_key") auto kernel_key = attr_map.at("kernel_key")
.dyn_cast<paddle::dialect::KernelAttribute>() .dyn_cast<paddle::dialect::KernelAttribute>()
.data(); .data();
......
...@@ -171,7 +171,7 @@ void HandleForSpecialOp( ...@@ -171,7 +171,7 @@ void HandleForSpecialOp(
std::string op_name = op->name(); std::string op_name = op->name();
if (op->attributes().count("op_name")) { if (op->attributes().count("op_name")) {
op_name = op_name =
op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data(); op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().AsString();
} }
if (op_name == "pd.fetch") { if (op_name == "pd.fetch") {
...@@ -244,7 +244,7 @@ void HandleForSpecialOp( ...@@ -244,7 +244,7 @@ void HandleForSpecialOp(
auto param_name = op->attributes() auto param_name = op->attributes()
.at("parameter_name") .at("parameter_name")
.dyn_cast<ir::StrAttribute>() .dyn_cast<ir::StrAttribute>()
.data(); .AsString();
auto value = op->operand(0); auto value = op->operand(0);
// change opreand name to param_name // change opreand name to param_name
...@@ -262,7 +262,7 @@ void HandleForSpecialOp( ...@@ -262,7 +262,7 @@ void HandleForSpecialOp(
auto param_name = op->attributes() auto param_name = op->attributes()
.at("parameter_name") .at("parameter_name")
.dyn_cast<ir::StrAttribute>() .dyn_cast<ir::StrAttribute>()
.data(); .AsString();
auto value = op->result(0); auto value = op->result(0);
value_2_var_name->emplace(value, param_name); value_2_var_name->emplace(value, param_name);
} }
...@@ -306,7 +306,7 @@ void HandleForInplaceOp( ...@@ -306,7 +306,7 @@ void HandleForInplaceOp(
std::string op_name = op->name(); std::string op_name = op->name();
if (op->attributes().count("op_name")) { if (op->attributes().count("op_name")) {
op_name = op_name =
op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data(); op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().AsString();
} }
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name); ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);
...@@ -356,8 +356,10 @@ void BuildScope(const ir::Block& block, ...@@ -356,8 +356,10 @@ void BuildScope(const ir::Block& block,
std::string op_name = op->name(); std::string op_name = op->name();
if (op->attributes().count("op_name")) { if (op->attributes().count("op_name")) {
op_name = op_name = op->attributes()
op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data(); .at("op_name")
.dyn_cast<ir::StrAttribute>()
.AsString();
} }
VLOG(4) << "build op:" << op_name; VLOG(4) << "build op:" << op_name;
......
...@@ -188,10 +188,10 @@ void BuildPhiContext( ...@@ -188,10 +188,10 @@ void BuildPhiContext(
} else if (attr_type_name == "ir::BoolAttribute") { } else if (attr_type_name == "ir::BoolAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::BoolAttribute>().data()); ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::BoolAttribute>().data());
} else if (attr_type_name == "ir::StrAttribute") { } else if (attr_type_name == "ir::StrAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::StrAttribute>().data()); ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::StrAttribute>().AsString());
} else if (attr_type_name == } else if (attr_type_name ==
"ir::ArrayAttribute<paddle::dialect::ScalarAttribute>") { "ir::ArrayAttribute<paddle::dialect::ScalarAttribute>") {
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data(); auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().AsVector();
std::vector<phi::Scalar> vec_res; std::vector<phi::Scalar> vec_res;
if (array_list.size() > 0) { if (array_list.size() > 0) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -207,7 +207,7 @@ void BuildPhiContext( ...@@ -207,7 +207,7 @@ void BuildPhiContext(
} }
ctx->EmplaceBackAttr(vec_res); ctx->EmplaceBackAttr(vec_res);
} else if (attr_type_name == "ir::ArrayAttribute<ir::Int32Attribute>") { } else if (attr_type_name == "ir::ArrayAttribute<ir::Int32Attribute>") {
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data(); auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().AsVector();
std::vector<int32_t> vec_res; std::vector<int32_t> vec_res;
if (array_list.size() > 0) { if (array_list.size() > 0) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -222,7 +222,7 @@ void BuildPhiContext( ...@@ -222,7 +222,7 @@ void BuildPhiContext(
} }
ctx->EmplaceBackAttr(vec_res); ctx->EmplaceBackAttr(vec_res);
} else if (attr_type_name == "ir::ArrayAttribute<ir::FloatAttribute>") { } else if (attr_type_name == "ir::ArrayAttribute<ir::FloatAttribute>") {
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data(); auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().AsVector();
std::vector<float> vec_res; std::vector<float> vec_res;
if (array_list.size() > 0) { if (array_list.size() > 0) {
if (array_list[0].isa<ir::FloatAttribute>()) { if (array_list[0].isa<ir::FloatAttribute>()) {
...@@ -238,7 +238,7 @@ void BuildPhiContext( ...@@ -238,7 +238,7 @@ void BuildPhiContext(
} }
ctx->EmplaceBackAttr(vec_res); ctx->EmplaceBackAttr(vec_res);
} else if (attr_type_name == "ir::ArrayAttribute<ir::Int64Attribute>") { } else if (attr_type_name == "ir::ArrayAttribute<ir::Int64Attribute>") {
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data(); auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().AsVector();
std::vector<int64_t> vec_res; std::vector<int64_t> vec_res;
if (array_list.size() > 0) { if (array_list.size() > 0) {
...@@ -255,7 +255,7 @@ void BuildPhiContext( ...@@ -255,7 +255,7 @@ void BuildPhiContext(
} }
ctx->EmplaceBackAttr(vec_res); ctx->EmplaceBackAttr(vec_res);
} else if (attr_type_name == "ir::ArrayAttribute<ir::Int64Attribute>") { } else if (attr_type_name == "ir::ArrayAttribute<ir::Int64Attribute>") {
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data(); auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().AsVector();
std::vector<int64_t> vec_res; std::vector<int64_t> vec_res;
if (array_list.size() > 0) { if (array_list.size() > 0) {
...@@ -286,7 +286,7 @@ void BuildPhiContext( ...@@ -286,7 +286,7 @@ void BuildPhiContext(
// TODO(phlrain): use var type instead of op name // TODO(phlrain): use var type instead of op name
if (op->attributes().count("op_name") && if (op->attributes().count("op_name") &&
(op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data() == (op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().AsString() ==
"pd.fetch")) { "pd.fetch")) {
// process fetch op // process fetch op
auto fetch_var = inner_scope->FindVar("fetch"); auto fetch_var = inner_scope->FindVar("fetch");
......
...@@ -34,7 +34,7 @@ std::pair<std::string, ir::Parameter*> GetParameterFromValue(ir::Value value) { ...@@ -34,7 +34,7 @@ std::pair<std::string, ir::Parameter*> GetParameterFromValue(ir::Value value) {
std::string name = op->attributes() std::string name = op->attributes()
.at(op.attributes_name[0]) .at(op.attributes_name[0])
.dyn_cast<ir::StrAttribute>() .dyn_cast<ir::StrAttribute>()
.data(); .AsString();
ir::Parameter* param = program->GetParameter(name); ir::Parameter* param = program->GetParameter(name);
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
param, phi::errors::InvalidArgument("Parameter should not be null.")); param, phi::errors::InvalidArgument("Parameter should not be null."));
......
...@@ -223,7 +223,7 @@ void ProgramTranslator::SetStopGradientAttributeForAllValue( ...@@ -223,7 +223,7 @@ void ProgramTranslator::SetStopGradientAttributeForAllValue(
if (defining_op->HasAttribute(kAttrStopGradients)) { if (defining_op->HasAttribute(kAttrStopGradients)) {
stop_gradients = defining_op->attribute(kAttrStopGradients) stop_gradients = defining_op->attribute(kAttrStopGradients)
.dyn_cast<ir::ArrayAttribute>() .dyn_cast<ir::ArrayAttribute>()
.data(); .AsVector();
} else { } else {
stop_gradients = std::vector<ir::Attribute>( stop_gradients = std::vector<ir::Attribute>(
defining_op->num_results(), ir::BoolAttribute::get(ctx_, false)); defining_op->num_results(), ir::BoolAttribute::get(ctx_, false));
......
...@@ -56,7 +56,7 @@ class DeadCodeEliminationPass : public ir::Pass { ...@@ -56,7 +56,7 @@ class DeadCodeEliminationPass : public ir::Pass {
get_parameter_op->attributes() get_parameter_op->attributes()
.at(get_parameter_op.attributes_name[0]) .at(get_parameter_op.attributes_name[0])
.dyn_cast<ir::StrAttribute>() .dyn_cast<ir::StrAttribute>()
.data()); .AsString());
} }
block->erase(*op); block->erase(*op);
} }
......
...@@ -15,27 +15,69 @@ ...@@ -15,27 +15,69 @@
#include "paddle/ir/core/builtin_attribute.h" #include "paddle/ir/core/builtin_attribute.h"
namespace ir { namespace ir {
std::string StrAttribute::data() const { return storage()->GetAsKey(); }
uint32_t StrAttribute::size() const { return storage()->GetAsKey().size(); } bool BoolAttribute::data() const { return storage()->data(); }
bool BoolAttribute::data() const { return storage()->GetAsKey(); } float FloatAttribute::data() const { return storage()->data(); }
float FloatAttribute::data() const { return storage()->GetAsKey(); } double DoubleAttribute::data() const { return storage()->data(); }
double DoubleAttribute::data() const { return storage()->GetAsKey(); } int32_t Int32Attribute::data() const { return storage()->data(); }
int32_t Int32Attribute::data() const { return storage()->GetAsKey(); } int64_t Int64Attribute::data() const { return storage()->data(); }
int64_t Int64Attribute::data() const { return storage()->GetAsKey(); } void* PointerAttribute::data() const { return storage()->data(); }
std::vector<Attribute> ArrayAttribute::data() const { Type TypeAttribute::data() const { return storage()->data(); }
return storage()->GetAsKey();
bool StrAttribute::operator<(const StrAttribute& right) const {
return storage() < right.storage();
}
std::string StrAttribute::AsString() const { return storage()->AsString(); }
size_t StrAttribute::size() const { return storage()->size(); }
StrAttribute StrAttribute::get(ir::IrContext* ctx, const std::string& value) {
return AttributeManager::get<StrAttribute>(ctx, value);
}
std::vector<Attribute> ArrayAttribute::AsVector() const {
return storage()->AsVector();
}
size_t ArrayAttribute::size() const { return storage()->size(); }
bool ArrayAttribute::empty() const { return storage()->empty(); }
Attribute ArrayAttribute::at(size_t index) const {
return storage()->at(index);
} }
void* PointerAttribute::data() const { return storage()->GetAsKey(); } ArrayAttribute ArrayAttribute::get(IrContext* ctx,
const std::vector<Attribute>& value) {
return AttributeManager::get<ArrayAttribute>(ctx, value);
}
Type TypeAttribute::data() const { return storage()->GetAsKey(); } ArrayAttributeStorage::ArrayAttributeStorage(const ParamKey& key)
: size_(key.size()) {
constexpr size_t align = alignof(Attribute);
if (align > __STDCPP_DEFAULT_NEW_ALIGNMENT__) {
data_ = static_cast<Attribute*>(
::operator new(size_ * sizeof(Attribute), std::align_val_t(align)));
} else {
data_ = static_cast<Attribute*>(::operator new(size_ * sizeof(Attribute)));
}
memcpy(data_, key.data(), sizeof(Attribute) * size_);
}
ArrayAttributeStorage::~ArrayAttributeStorage() {
constexpr size_t align = alignof(Attribute);
if (align > __STDCPP_DEFAULT_NEW_ALIGNMENT__) {
::operator delete(data_, std::align_val_t(align));
} else {
::operator delete(data_);
}
}
} // namespace ir } // namespace ir
......
...@@ -19,21 +19,6 @@ ...@@ -19,21 +19,6 @@
#include "paddle/ir/core/utils.h" #include "paddle/ir/core/utils.h"
namespace ir { namespace ir {
class IR_API StrAttribute : public Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(StrAttribute, StrAttributeStorage);
bool operator<(const StrAttribute& right) const {
return storage() < right.storage();
}
std::string data() const;
uint32_t size() const;
};
class IR_API BoolAttribute : public Attribute { class IR_API BoolAttribute : public Attribute {
public: public:
using Attribute::Attribute; using Attribute::Attribute;
...@@ -79,37 +64,55 @@ class IR_API Int64Attribute : public Attribute { ...@@ -79,37 +64,55 @@ class IR_API Int64Attribute : public Attribute {
int64_t data() const; int64_t data() const;
}; };
class IR_API ArrayAttribute : public Attribute { class IR_API PointerAttribute : public Attribute {
public: public:
using Attribute::Attribute; using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(ArrayAttribute, ArrayAttributeStorage); DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(PointerAttribute, PointerAttributeStorage);
std::vector<Attribute> data() const; void* data() const;
};
size_t size() const { return data().size(); } class IR_API TypeAttribute : public Attribute {
public:
using Attribute::Attribute;
bool empty() const { return data().empty(); } DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(TypeAttribute, TypeAttributeStorage);
Attribute operator[](size_t index) const { return data()[index]; } Type data() const;
}; };
class IR_API PointerAttribute : public Attribute { class IR_API StrAttribute : public Attribute {
public: public:
using Attribute::Attribute; using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(PointerAttribute, PointerAttributeStorage); DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(StrAttribute, StrAttributeStorage);
void* data() const; bool operator<(const StrAttribute& right) const;
std::string AsString() const;
size_t size() const;
static StrAttribute get(IrContext* ctx, const std::string& value);
}; };
class IR_API TypeAttribute : public Attribute { class IR_API ArrayAttribute : public Attribute {
public: public:
using Attribute::Attribute; using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(TypeAttribute, TypeAttributeStorage); DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(ArrayAttribute, ArrayAttributeStorage);
Type data() const; std::vector<Attribute> AsVector() const;
size_t size() const;
bool empty() const;
Attribute at(size_t index) const;
static ArrayAttribute get(IrContext* ctx,
const std::vector<Attribute>& value);
}; };
} // namespace ir } // namespace ir
......
...@@ -20,32 +20,41 @@ ...@@ -20,32 +20,41 @@
#include "paddle/ir/core/attribute.h" #include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/attribute_base.h" #include "paddle/ir/core/attribute_base.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/type.h" #include "paddle/ir/core/type.h"
#include "paddle/ir/core/utils.h" #include "paddle/ir/core/utils.h"
namespace ir { namespace ir {
#define DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(concrete_storage, base_type) \ #define DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(ConcreteStorage, BaseType) \
struct concrete_storage : public ir::AttributeStorage { \ struct ConcreteStorage : public AttributeStorage { \
using ParamKey = base_type; \ using ParamKey = BaseType; \
\ \
explicit concrete_storage(const ParamKey &key) { data_ = key; } \ explicit ConcreteStorage(ParamKey key) { data_ = key; } \
\ \
static concrete_storage *Construct(const ParamKey &key) { \ static ConcreteStorage *Construct(ParamKey key) { \
return new concrete_storage(key); \ return new ConcreteStorage(key); \
} \ } \
\ \
static std::size_t HashValue(const ParamKey &key) { \ static size_t HashValue(ParamKey key) { \
return std::hash<base_type>()(key); \ return std::hash<ParamKey>{}(key); \
} \ } \
\ \
bool operator==(const ParamKey &key) const { return data_ == key; } \ bool operator==(ParamKey key) const { return data_ == key; } \
\ \
ParamKey GetAsKey() const { return data_; } \ BaseType data() const { return data_; } \
\ \
private: \ private: \
ParamKey data_; \ BaseType data_; \
}; }
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(BoolAttributeStorage, bool);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(FloatAttributeStorage, float);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(DoubleAttributeStorage, double);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int32AttributeStorage, int32_t);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int64AttributeStorage, int64_t);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(PointerAttributeStorage, void *);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(TypeAttributeStorage, Type);
/// ///
/// \brief Define Parametric AttributeStorage for StrAttribute. /// \brief Define Parametric AttributeStorage for StrAttribute.
...@@ -53,53 +62,53 @@ namespace ir { ...@@ -53,53 +62,53 @@ namespace ir {
struct StrAttributeStorage : public AttributeStorage { struct StrAttributeStorage : public AttributeStorage {
using ParamKey = std::string; using ParamKey = std::string;
explicit StrAttributeStorage(const ParamKey &key) { explicit StrAttributeStorage(const ParamKey &key) : size_(key.size()) {
data_ = reinterpret_cast<char *>(malloc(key.size())); if (size_ > kLocalSize) {
memcpy(data_, key.c_str(), key.size()); data_ = static_cast<char *>(::operator new(size_));
size_ = key.size(); memcpy(data_, key.c_str(), size_);
} else {
memcpy(buff_, key.c_str(), size_);
}
} }
~StrAttributeStorage() { free(data_); } ~StrAttributeStorage() {
if (size_ > kLocalSize) ::operator delete(data_);
}
static StrAttributeStorage *Construct(const ParamKey &key) { static StrAttributeStorage *Construct(const ParamKey &key) {
return new StrAttributeStorage(key); return new StrAttributeStorage(key);
} }
static std::size_t HashValue(const ParamKey &key) { static size_t HashValue(const ParamKey &key) {
return std::hash<std::string>()(key); return std::hash<std::string>{}(key);
} }
bool operator==(const ParamKey &key) const { bool operator==(const ParamKey &key) const {
return std::equal(data_, data_ + size_, key.c_str()); if (size_ != key.size()) return false;
const char *data = size_ > kLocalSize ? data_ : buff_;
return std::equal(data, data + size_, key.c_str());
} }
ParamKey GetAsKey() const { return ParamKey(data_, size_); } // Note: The const char* is not end with '\0'.
const char *data() const { return size_ > kLocalSize ? data_ : buff_; }
size_t size() const { return size_; }
std::string AsString() const { return std::string(data(), size_); }
private: private:
static constexpr size_t kLocalSize = sizeof(void *) / sizeof(char);
union {
char *data_; char *data_;
uint32_t size_; char buff_[kLocalSize];
};
const size_t size_;
}; };
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(BoolAttributeStorage, bool);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(FloatAttributeStorage, float);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(DoubleAttributeStorage, double);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int32AttributeStorage, int32_t);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int64AttributeStorage, int64_t);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(PointerAttributeStorage, void *);
struct ArrayAttributeStorage : public AttributeStorage { struct ArrayAttributeStorage : public AttributeStorage {
using ParamKey = std::vector<Attribute>; using ParamKey = std::vector<Attribute>;
explicit ArrayAttributeStorage(const ParamKey &key) { explicit ArrayAttributeStorage(const ParamKey &key);
data_ =
reinterpret_cast<Attribute *>(malloc(sizeof(Attribute) * key.size()));
memcpy(reinterpret_cast<void *>(data_),
reinterpret_cast<const void *>(key.data()),
sizeof(Attribute) * key.size());
length_ = key.size();
}
~ArrayAttributeStorage() { free(reinterpret_cast<void *>(data_)); } ~ArrayAttributeStorage();
static ArrayAttributeStorage *Construct(const ParamKey &key) { static ArrayAttributeStorage *Construct(const ParamKey &key) {
return new ArrayAttributeStorage(key); return new ArrayAttributeStorage(key);
...@@ -114,43 +123,25 @@ struct ArrayAttributeStorage : public AttributeStorage { ...@@ -114,43 +123,25 @@ struct ArrayAttributeStorage : public AttributeStorage {
} }
bool operator==(const ParamKey &key) const { bool operator==(const ParamKey &key) const {
if (key.size() != length_) { return key.size() == size_ && std::equal(key.begin(), key.end(), data_);
return false;
} }
for (size_t i = 0; i < length_; ++i) {
if (data_[i] != key[i]) {
return false;
}
}
return true;
}
ParamKey GetAsKey() const { return ParamKey(data_, data_ + length_); }
private: std::vector<Attribute> AsVector() const {
Attribute *data_ = nullptr; return std::vector<Attribute>(data_, data_ + size_);
size_t length_ = 0;
};
struct TypeAttributeStorage : public AttributeStorage {
using ParamKey = Type;
explicit TypeAttributeStorage(const ParamKey &key) : value_(key) {}
static TypeAttributeStorage *Construct(ParamKey key) {
return new TypeAttributeStorage(key);
} }
static std::size_t HashValue(const ParamKey &key) { size_t size() const { return size_; }
return std::hash<Type>()(key);
}
bool operator==(const ParamKey &key) const { return value_ == key; } bool empty() const { return size_ == 0u; }
ParamKey GetAsKey() const { return value_; } Attribute at(size_t index) const {
IR_ENFORCE(index < size_, "Invalid index");
return data_[index];
}
private: private:
Type value_; Attribute *data_;
const size_t size_;
}; };
} // namespace ir } // namespace ir
...@@ -85,7 +85,7 @@ void BasicIrPrinter::PrintAttribute(Attribute attr) { ...@@ -85,7 +85,7 @@ void BasicIrPrinter::PrintAttribute(Attribute attr) {
} }
if (auto s = attr.dyn_cast<StrAttribute>()) { if (auto s = attr.dyn_cast<StrAttribute>()) {
os << s.data(); os << s.AsString();
} else if (auto b = attr.dyn_cast<BoolAttribute>()) { } else if (auto b = attr.dyn_cast<BoolAttribute>()) {
os << b.data(); os << b.data();
} else if (auto f = attr.dyn_cast<FloatAttribute>()) { } else if (auto f = attr.dyn_cast<FloatAttribute>()) {
...@@ -99,7 +99,7 @@ void BasicIrPrinter::PrintAttribute(Attribute attr) { ...@@ -99,7 +99,7 @@ void BasicIrPrinter::PrintAttribute(Attribute attr) {
} else if (auto p = attr.dyn_cast<PointerAttribute>()) { } else if (auto p = attr.dyn_cast<PointerAttribute>()) {
os << p.data(); os << p.data();
} else if (auto arr = attr.dyn_cast<ArrayAttribute>()) { } else if (auto arr = attr.dyn_cast<ArrayAttribute>()) {
const auto& vec = arr.data(); const auto& vec = arr.AsVector();
os << "array["; os << "array[";
PrintInterleave( PrintInterleave(
vec.begin(), vec.begin(),
......
...@@ -232,7 +232,7 @@ class RedundantTransposeFusePattern ...@@ -232,7 +232,7 @@ class RedundantTransposeFusePattern
private: private:
std::vector<int> GetAxis(paddle::dialect::TransposeOp op) const { std::vector<int> GetAxis(paddle::dialect::TransposeOp op) const {
auto array_attr = op.attribute<ir::ArrayAttribute>("perm").data(); auto array_attr = op.attribute<ir::ArrayAttribute>("perm").AsVector();
std::vector<int> axis(array_attr.size()); std::vector<int> axis(array_attr.size());
for (size_t i = 0; i < array_attr.size(); ++i) { for (size_t i = 0; i < array_attr.size(); ++i) {
axis[i] = array_attr[i].dyn_cast<ir::Int32Attribute>().data(); axis[i] = array_attr[i].dyn_cast<ir::Int32Attribute>().data();
...@@ -333,7 +333,7 @@ class Conv2dBnFusePattern ...@@ -333,7 +333,7 @@ class Conv2dBnFusePattern
phi::DDim new_conv2d_out_shape = ir::GetShapeFromValue(new_conv2d_op.out()); phi::DDim new_conv2d_out_shape = ir::GetShapeFromValue(new_conv2d_op.out());
std::vector<int64_t> new_bias_new_shape(new_conv2d_out_shape.size(), 1); std::vector<int64_t> new_bias_new_shape(new_conv2d_out_shape.size(), 1);
std::string data_format = std::string data_format =
new_conv2d_op.attribute<ir::StrAttribute>("data_format").data(); new_conv2d_op.attribute<ir::StrAttribute>("data_format").AsString();
IR_ENFORCE(data_format == "NCHW", "Only support NCHW now."); IR_ENFORCE(data_format == "NCHW", "Only support NCHW now.");
new_bias_new_shape[1] = new_conv2d_out_shape[1]; new_bias_new_shape[1] = new_conv2d_out_shape[1];
paddle::dialect::ReshapeOp reshape_bias_op = paddle::dialect::ReshapeOp reshape_bias_op =
...@@ -503,7 +503,8 @@ void Conv2dFusionOpTest::Build(ir::Builder &builder, ...@@ -503,7 +503,8 @@ void Conv2dFusionOpTest::Build(ir::Builder &builder,
i < attributes.at("strides").dyn_cast<ir::ArrayAttribute>().size(); i < attributes.at("strides").dyn_cast<ir::ArrayAttribute>().size();
i++) { i++) {
strides.push_back(attributes.at("strides") strides.push_back(attributes.at("strides")
.dyn_cast<ir::ArrayAttribute>()[i] .dyn_cast<ir::ArrayAttribute>()
.at(i)
.dyn_cast<ir::Int32Attribute>() .dyn_cast<ir::Int32Attribute>()
.data()); .data());
} }
...@@ -513,27 +514,30 @@ void Conv2dFusionOpTest::Build(ir::Builder &builder, ...@@ -513,27 +514,30 @@ void Conv2dFusionOpTest::Build(ir::Builder &builder,
i < attributes.at("paddings_t").dyn_cast<ir::ArrayAttribute>().size(); i < attributes.at("paddings_t").dyn_cast<ir::ArrayAttribute>().size();
i++) { i++) {
paddings_t.push_back(attributes.at("paddings_t") paddings_t.push_back(attributes.at("paddings_t")
.dyn_cast<ir::ArrayAttribute>()[i] .dyn_cast<ir::ArrayAttribute>()
.at(i)
.dyn_cast<ir::Int32Attribute>() .dyn_cast<ir::Int32Attribute>()
.data()); .data());
} }
std::string padding_algorithm = std::string padding_algorithm = attributes.at("padding_algorithm")
attributes.at("padding_algorithm").dyn_cast<ir::StrAttribute>().data(); .dyn_cast<ir::StrAttribute>()
.AsString();
std::vector<int> dilations_t; std::vector<int> dilations_t;
for (size_t i = 0; for (size_t i = 0;
i < attributes.at("dilations_t").dyn_cast<ir::ArrayAttribute>().size(); i < attributes.at("dilations_t").dyn_cast<ir::ArrayAttribute>().size();
i++) { i++) {
dilations_t.push_back(attributes.at("dilations_t") dilations_t.push_back(attributes.at("dilations_t")
.dyn_cast<ir::ArrayAttribute>()[i] .dyn_cast<ir::ArrayAttribute>()
.at(i)
.dyn_cast<ir::Int32Attribute>() .dyn_cast<ir::Int32Attribute>()
.data()); .data());
} }
int groups = attributes.at("groups").dyn_cast<ir::Int32Attribute>().data(); int groups = attributes.at("groups").dyn_cast<ir::Int32Attribute>().data();
std::string data_format = std::string data_format =
attributes.at("data_format").dyn_cast<ir::StrAttribute>().data(); attributes.at("data_format").dyn_cast<ir::StrAttribute>().AsString();
std::string activation = std::string activation =
attributes.at("activation").dyn_cast<ir::StrAttribute>().data(); attributes.at("activation").dyn_cast<ir::StrAttribute>().AsString();
bool exhaustive_search = bool exhaustive_search =
attributes.at("exhaustive_search").dyn_cast<ir::BoolAttribute>().data(); attributes.at("exhaustive_search").dyn_cast<ir::BoolAttribute>().data();
std::vector<int> channels; std::vector<int> channels;
...@@ -541,7 +545,8 @@ void Conv2dFusionOpTest::Build(ir::Builder &builder, ...@@ -541,7 +545,8 @@ void Conv2dFusionOpTest::Build(ir::Builder &builder,
i < attributes.at("channels").dyn_cast<ir::ArrayAttribute>().size(); i < attributes.at("channels").dyn_cast<ir::ArrayAttribute>().size();
i++) { i++) {
channels.push_back(attributes.at("channels") channels.push_back(attributes.at("channels")
.dyn_cast<ir::ArrayAttribute>()[i] .dyn_cast<ir::ArrayAttribute>()
.at(i)
.dyn_cast<ir::Int32Attribute>() .dyn_cast<ir::Int32Attribute>()
.data()); .data());
} }
...@@ -776,7 +781,8 @@ void Conv2dFusionOpTest::Verify() { ...@@ -776,7 +781,8 @@ void Conv2dFusionOpTest::Verify() {
i < attributes.at("strides").dyn_cast<ir::ArrayAttribute>().size(); i < attributes.at("strides").dyn_cast<ir::ArrayAttribute>().size();
i++) { i++) {
PADDLE_ENFORCE(attributes.at("strides") PADDLE_ENFORCE(attributes.at("strides")
.dyn_cast<ir::ArrayAttribute>()[i] .dyn_cast<ir::ArrayAttribute>()
.at(i)
.isa<ir::Int32Attribute>(), .isa<ir::Int32Attribute>(),
phi::errors::PreconditionNotMet( phi::errors::PreconditionNotMet(
"Type of attribute: strides is not right.")); "Type of attribute: strides is not right."));
...@@ -789,7 +795,8 @@ void Conv2dFusionOpTest::Verify() { ...@@ -789,7 +795,8 @@ void Conv2dFusionOpTest::Verify() {
i < attributes.at("paddings_t").dyn_cast<ir::ArrayAttribute>().size(); i < attributes.at("paddings_t").dyn_cast<ir::ArrayAttribute>().size();
i++) { i++) {
PADDLE_ENFORCE(attributes.at("paddings_t") PADDLE_ENFORCE(attributes.at("paddings_t")
.dyn_cast<ir::ArrayAttribute>()[i] .dyn_cast<ir::ArrayAttribute>()
.at(i)
.isa<ir::Int32Attribute>(), .isa<ir::Int32Attribute>(),
phi::errors::PreconditionNotMet( phi::errors::PreconditionNotMet(
"Type of attribute: paddings_t is not right.")); "Type of attribute: paddings_t is not right."));
...@@ -807,7 +814,8 @@ void Conv2dFusionOpTest::Verify() { ...@@ -807,7 +814,8 @@ void Conv2dFusionOpTest::Verify() {
i < attributes.at("dilations_t").dyn_cast<ir::ArrayAttribute>().size(); i < attributes.at("dilations_t").dyn_cast<ir::ArrayAttribute>().size();
i++) { i++) {
PADDLE_ENFORCE(attributes.at("dilations_t") PADDLE_ENFORCE(attributes.at("dilations_t")
.dyn_cast<ir::ArrayAttribute>()[i] .dyn_cast<ir::ArrayAttribute>()
.at(i)
.isa<ir::Int32Attribute>(), .isa<ir::Int32Attribute>(),
phi::errors::PreconditionNotMet( phi::errors::PreconditionNotMet(
"Type of attribute: dilations_t is not right.")); "Type of attribute: dilations_t is not right."));
...@@ -837,7 +845,8 @@ void Conv2dFusionOpTest::Verify() { ...@@ -837,7 +845,8 @@ void Conv2dFusionOpTest::Verify() {
i < attributes.at("channels").dyn_cast<ir::ArrayAttribute>().size(); i < attributes.at("channels").dyn_cast<ir::ArrayAttribute>().size();
i++) { i++) {
PADDLE_ENFORCE(attributes.at("channels") PADDLE_ENFORCE(attributes.at("channels")
.dyn_cast<ir::ArrayAttribute>()[i] .dyn_cast<ir::ArrayAttribute>()
.at(i)
.isa<ir::Int32Attribute>(), .isa<ir::Int32Attribute>(),
phi::errors::PreconditionNotMet( phi::errors::PreconditionNotMet(
"Type of attribute: channels is not right.")); "Type of attribute: channels is not right."));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册