未验证 提交 896d7cfa 编写于 作者: W winter-wang 提交者: GitHub

[IR] finetune the StrAttribute interface. (#55439)

上级 5464b5c4
......@@ -47,7 +47,7 @@ OpFuncType AnalyseOpFuncType(ir::Operation* op, const platform::Place& place) {
// and so that they would be dispatched to host thread.
auto op_attributes = op->attributes();
auto op_name =
op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().data();
op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
if (op_name == kCoalesceTensor &&
(!platform::is_xpu_place(place) ||
op->attribute<ir::BoolAttribute>("persist_output").data() == false) &&
......@@ -77,7 +77,7 @@ PhiKernelInstruction::PhiKernelInstruction(
: InstructionBase(id, place) {
auto op_attributes = op->attributes();
auto op_name =
op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().data();
op_attributes.at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
ir::OpInfo op_info = ir::IrContext::Instance()->GetRegisteredOpInfo(op_name);
phi_op_name_ = op_name;
......@@ -142,7 +142,7 @@ PhiKernelInstruction::PhiKernelInstruction(
VLOG(6) << "finish process infer meta context";
auto kernel_name =
op_attributes.at("kernel_name").dyn_cast<ir::StrAttribute>().data();
op_attributes.at("kernel_name").dyn_cast<ir::StrAttribute>().AsString();
auto kernel_key = op_attributes.at("kernel_key")
.dyn_cast<paddle::dialect::KernelAttribute>()
.data();
......
......@@ -952,7 +952,8 @@ void BuildOpFuncList(
OpFuncNode op_func_node;
auto attr_map = (*it)->attributes();
auto op_name = attr_map.at("op_name").dyn_cast<::ir::StrAttribute>().data();
auto op_name =
attr_map.at("op_name").dyn_cast<::ir::StrAttribute>().AsString();
op_func_node.phi_op_name_ = op_name;
if (op_name == "builtin.combine" || op_name == "pd.feed" ||
......@@ -986,7 +987,7 @@ void BuildOpFuncList(
&(op_func_node.infer_meta_context_));
auto kernel_name =
attr_map.at("kernel_name").dyn_cast<ir::StrAttribute>().data();
attr_map.at("kernel_name").dyn_cast<ir::StrAttribute>().AsString();
auto kernel_key = attr_map.at("kernel_key")
.dyn_cast<paddle::dialect::KernelAttribute>()
.data();
......
......@@ -45,10 +45,10 @@ void PhiKernelOp::Verify() {
}
std::string PhiKernelOp::op_name() {
return attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
return attributes().at("op_name").dyn_cast<ir::StrAttribute>().AsString();
}
std::string PhiKernelOp::kernel_name() {
return attributes().at("kernel_name").dyn_cast<ir::StrAttribute>().data();
return attributes().at("kernel_name").dyn_cast<ir::StrAttribute>().AsString();
}
phi::KernelKey PhiKernelOp::kernel_key() {
return attributes().at("kernel_key").dyn_cast<KernelAttribute>().data();
......
......@@ -542,11 +542,14 @@ def gen_build_func_str(
GET_ATTRIBUTES_FROM_MAP_TEMPLATE = """
{attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<{attr_ir_type}>().data();
"""
GET_STR_ATTRIBUTES_FROM_MAP_TEMPLATE = """
{attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<ir::StrAttribute>().AsString();
"""
GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """
{attr_type} {attribute_name};
for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().size(); i++) {{
{attribute_name}.push_back(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>()[i].dyn_cast<{inner_type}>().data());
{attribute_name}.push_back(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().at(i).dyn_cast<{inner_type}>().{data_name}());
}}
"""
GET_INTARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """
......@@ -566,11 +569,15 @@ def gen_build_func_str(
# attr_type = "std::vector<int>"
if "ir::ArrayAttribute" in op_attribute_type_list[idx]:
inner_type = op_attribute_type_list[idx][19:-1]
data_name = "data"
if inner_type == "ir::StrAttribute":
data_name = "AsString"
get_attributes_str += (
GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE.format(
attr_type=attr_type,
attribute_name=op_attribute_name_list[idx],
inner_type=inner_type,
data_name=data_name,
)
)
elif (
......@@ -593,6 +600,14 @@ def gen_build_func_str(
attribute_name=op_attribute_name_list[idx],
)
)
elif "ir::StrAttribute" in op_attribute_type_list[idx]:
get_attributes_str += (
GET_STR_ATTRIBUTES_FROM_MAP_TEMPLATE.format(
attr_type=attr_type,
attribute_name=op_attribute_name_list[idx],
attr_ir_type=op_attribute_type_list[idx],
)
)
else:
get_attributes_str += GET_ATTRIBUTES_FROM_MAP_TEMPLATE.format(
attr_type=attr_type,
......
......@@ -78,7 +78,7 @@ ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """
PADDLE_ENFORCE(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<ir::ArrayAttribute>(),
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().size(); i++) {{
PADDLE_ENFORCE(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>()[i].isa<{standard}>(),
PADDLE_ENFORCE(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().at(i).isa<{standard}>(),
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
}}"""
OUTPUT_TYPE_CHECK_TEMPLATE = """
......
......@@ -76,7 +76,8 @@ class PhiKernelAdaptor {
for (auto it = block->begin(); it != block->end(); ++it) {
auto attr_map = (*it)->attributes();
auto op_name = attr_map.at("op_name").dyn_cast<ir::StrAttribute>().data();
auto op_name =
attr_map.at("op_name").dyn_cast<ir::StrAttribute>().AsString();
ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op_name);
......@@ -104,7 +105,7 @@ class PhiKernelAdaptor {
infer_meta_impl->infer_meta_(&ctx);
auto kernel_name =
attr_map.at("kernel_name").dyn_cast<ir::StrAttribute>().data();
attr_map.at("kernel_name").dyn_cast<ir::StrAttribute>().AsString();
auto kernel_key = attr_map.at("kernel_key")
.dyn_cast<paddle::dialect::KernelAttribute>()
.data();
......
......@@ -171,7 +171,7 @@ void HandleForSpecialOp(
std::string op_name = op->name();
if (op->attributes().count("op_name")) {
op_name =
op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().AsString();
}
if (op_name == "pd.fetch") {
......@@ -244,7 +244,7 @@ void HandleForSpecialOp(
auto param_name = op->attributes()
.at("parameter_name")
.dyn_cast<ir::StrAttribute>()
.data();
.AsString();
auto value = op->operand(0);
// change opreand name to param_name
......@@ -262,7 +262,7 @@ void HandleForSpecialOp(
auto param_name = op->attributes()
.at("parameter_name")
.dyn_cast<ir::StrAttribute>()
.data();
.AsString();
auto value = op->result(0);
value_2_var_name->emplace(value, param_name);
}
......@@ -306,7 +306,7 @@ void HandleForInplaceOp(
std::string op_name = op->name();
if (op->attributes().count("op_name")) {
op_name =
op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().AsString();
}
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_name);
......@@ -356,8 +356,10 @@ void BuildScope(const ir::Block& block,
std::string op_name = op->name();
if (op->attributes().count("op_name")) {
op_name =
op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data();
op_name = op->attributes()
.at("op_name")
.dyn_cast<ir::StrAttribute>()
.AsString();
}
VLOG(4) << "build op:" << op_name;
......
......@@ -188,10 +188,10 @@ void BuildPhiContext(
} else if (attr_type_name == "ir::BoolAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::BoolAttribute>().data());
} else if (attr_type_name == "ir::StrAttribute") {
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::StrAttribute>().data());
ctx->EmplaceBackAttr(attr_map[t].dyn_cast<ir::StrAttribute>().AsString());
} else if (attr_type_name ==
"ir::ArrayAttribute<paddle::dialect::ScalarAttribute>") {
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().AsVector();
std::vector<phi::Scalar> vec_res;
if (array_list.size() > 0) {
PADDLE_ENFORCE_EQ(
......@@ -207,7 +207,7 @@ void BuildPhiContext(
}
ctx->EmplaceBackAttr(vec_res);
} else if (attr_type_name == "ir::ArrayAttribute<ir::Int32Attribute>") {
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().AsVector();
std::vector<int32_t> vec_res;
if (array_list.size() > 0) {
PADDLE_ENFORCE_EQ(
......@@ -222,7 +222,7 @@ void BuildPhiContext(
}
ctx->EmplaceBackAttr(vec_res);
} else if (attr_type_name == "ir::ArrayAttribute<ir::FloatAttribute>") {
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().AsVector();
std::vector<float> vec_res;
if (array_list.size() > 0) {
if (array_list[0].isa<ir::FloatAttribute>()) {
......@@ -238,7 +238,7 @@ void BuildPhiContext(
}
ctx->EmplaceBackAttr(vec_res);
} else if (attr_type_name == "ir::ArrayAttribute<ir::Int64Attribute>") {
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().AsVector();
std::vector<int64_t> vec_res;
if (array_list.size() > 0) {
......@@ -255,7 +255,7 @@ void BuildPhiContext(
}
ctx->EmplaceBackAttr(vec_res);
} else if (attr_type_name == "ir::ArrayAttribute<ir::Int64Attribute>") {
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().data();
auto array_list = attr_map[t].dyn_cast<ir::ArrayAttribute>().AsVector();
std::vector<int64_t> vec_res;
if (array_list.size() > 0) {
......@@ -286,7 +286,7 @@ void BuildPhiContext(
// TODO(phlrain): use var type instead of op name
if (op->attributes().count("op_name") &&
(op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().data() ==
(op->attributes().at("op_name").dyn_cast<ir::StrAttribute>().AsString() ==
"pd.fetch")) {
// process fetch op
auto fetch_var = inner_scope->FindVar("fetch");
......
......@@ -34,7 +34,7 @@ std::pair<std::string, ir::Parameter*> GetParameterFromValue(ir::Value value) {
std::string name = op->attributes()
.at(op.attributes_name[0])
.dyn_cast<ir::StrAttribute>()
.data();
.AsString();
ir::Parameter* param = program->GetParameter(name);
PADDLE_ENFORCE_NOT_NULL(
param, phi::errors::InvalidArgument("Parameter should not be null."));
......
......@@ -223,7 +223,7 @@ void ProgramTranslator::SetStopGradientAttributeForAllValue(
if (defining_op->HasAttribute(kAttrStopGradients)) {
stop_gradients = defining_op->attribute(kAttrStopGradients)
.dyn_cast<ir::ArrayAttribute>()
.data();
.AsVector();
} else {
stop_gradients = std::vector<ir::Attribute>(
defining_op->num_results(), ir::BoolAttribute::get(ctx_, false));
......
......@@ -56,7 +56,7 @@ class DeadCodeEliminationPass : public ir::Pass {
get_parameter_op->attributes()
.at(get_parameter_op.attributes_name[0])
.dyn_cast<ir::StrAttribute>()
.data());
.AsString());
}
block->erase(*op);
}
......
......@@ -15,27 +15,69 @@
#include "paddle/ir/core/builtin_attribute.h"
namespace ir {
std::string StrAttribute::data() const { return storage()->GetAsKey(); }
uint32_t StrAttribute::size() const { return storage()->GetAsKey().size(); }
bool BoolAttribute::data() const { return storage()->data(); }
bool BoolAttribute::data() const { return storage()->GetAsKey(); }
float FloatAttribute::data() const { return storage()->data(); }
float FloatAttribute::data() const { return storage()->GetAsKey(); }
double DoubleAttribute::data() const { return storage()->data(); }
double DoubleAttribute::data() const { return storage()->GetAsKey(); }
int32_t Int32Attribute::data() const { return storage()->data(); }
int32_t Int32Attribute::data() const { return storage()->GetAsKey(); }
int64_t Int64Attribute::data() const { return storage()->data(); }
int64_t Int64Attribute::data() const { return storage()->GetAsKey(); }
void* PointerAttribute::data() const { return storage()->data(); }
std::vector<Attribute> ArrayAttribute::data() const {
return storage()->GetAsKey();
Type TypeAttribute::data() const { return storage()->data(); }
bool StrAttribute::operator<(const StrAttribute& right) const {
return storage() < right.storage();
}
std::string StrAttribute::AsString() const { return storage()->AsString(); }
size_t StrAttribute::size() const { return storage()->size(); }
StrAttribute StrAttribute::get(ir::IrContext* ctx, const std::string& value) {
return AttributeManager::get<StrAttribute>(ctx, value);
}
std::vector<Attribute> ArrayAttribute::AsVector() const {
return storage()->AsVector();
}
size_t ArrayAttribute::size() const { return storage()->size(); }
bool ArrayAttribute::empty() const { return storage()->empty(); }
Attribute ArrayAttribute::at(size_t index) const {
return storage()->at(index);
}
void* PointerAttribute::data() const { return storage()->GetAsKey(); }
ArrayAttribute ArrayAttribute::get(IrContext* ctx,
const std::vector<Attribute>& value) {
return AttributeManager::get<ArrayAttribute>(ctx, value);
}
Type TypeAttribute::data() const { return storage()->GetAsKey(); }
ArrayAttributeStorage::ArrayAttributeStorage(const ParamKey& key)
: size_(key.size()) {
constexpr size_t align = alignof(Attribute);
if (align > __STDCPP_DEFAULT_NEW_ALIGNMENT__) {
data_ = static_cast<Attribute*>(
::operator new(size_ * sizeof(Attribute), std::align_val_t(align)));
} else {
data_ = static_cast<Attribute*>(::operator new(size_ * sizeof(Attribute)));
}
memcpy(data_, key.data(), sizeof(Attribute) * size_);
}
ArrayAttributeStorage::~ArrayAttributeStorage() {
constexpr size_t align = alignof(Attribute);
if (align > __STDCPP_DEFAULT_NEW_ALIGNMENT__) {
::operator delete(data_, std::align_val_t(align));
} else {
::operator delete(data_);
}
}
} // namespace ir
......
......@@ -19,21 +19,6 @@
#include "paddle/ir/core/utils.h"
namespace ir {
class IR_API StrAttribute : public Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(StrAttribute, StrAttributeStorage);
bool operator<(const StrAttribute& right) const {
return storage() < right.storage();
}
std::string data() const;
uint32_t size() const;
};
class IR_API BoolAttribute : public Attribute {
public:
using Attribute::Attribute;
......@@ -79,37 +64,55 @@ class IR_API Int64Attribute : public Attribute {
int64_t data() const;
};
class IR_API ArrayAttribute : public Attribute {
class IR_API PointerAttribute : public Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(ArrayAttribute, ArrayAttributeStorage);
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(PointerAttribute, PointerAttributeStorage);
std::vector<Attribute> data() const;
void* data() const;
};
size_t size() const { return data().size(); }
class IR_API TypeAttribute : public Attribute {
public:
using Attribute::Attribute;
bool empty() const { return data().empty(); }
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(TypeAttribute, TypeAttributeStorage);
Attribute operator[](size_t index) const { return data()[index]; }
Type data() const;
};
class IR_API PointerAttribute : public Attribute {
class IR_API StrAttribute : public Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(PointerAttribute, PointerAttributeStorage);
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(StrAttribute, StrAttributeStorage);
void* data() const;
bool operator<(const StrAttribute& right) const;
std::string AsString() const;
size_t size() const;
static StrAttribute get(IrContext* ctx, const std::string& value);
};
class IR_API TypeAttribute : public Attribute {
class IR_API ArrayAttribute : public Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(TypeAttribute, TypeAttributeStorage);
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(ArrayAttribute, ArrayAttributeStorage);
Type data() const;
std::vector<Attribute> AsVector() const;
size_t size() const;
bool empty() const;
Attribute at(size_t index) const;
static ArrayAttribute get(IrContext* ctx,
const std::vector<Attribute>& value);
};
} // namespace ir
......
......@@ -20,32 +20,41 @@
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/attribute_base.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/type.h"
#include "paddle/ir/core/utils.h"
namespace ir {
#define DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(concrete_storage, base_type) \
struct concrete_storage : public ir::AttributeStorage { \
using ParamKey = base_type; \
\
explicit concrete_storage(const ParamKey &key) { data_ = key; } \
\
static concrete_storage *Construct(const ParamKey &key) { \
return new concrete_storage(key); \
} \
\
static std::size_t HashValue(const ParamKey &key) { \
return std::hash<base_type>()(key); \
} \
\
bool operator==(const ParamKey &key) const { return data_ == key; } \
\
ParamKey GetAsKey() const { return data_; } \
\
private: \
ParamKey data_; \
};
#define DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(ConcreteStorage, BaseType) \
struct ConcreteStorage : public AttributeStorage { \
using ParamKey = BaseType; \
\
explicit ConcreteStorage(ParamKey key) { data_ = key; } \
\
static ConcreteStorage *Construct(ParamKey key) { \
return new ConcreteStorage(key); \
} \
\
static size_t HashValue(ParamKey key) { \
return std::hash<ParamKey>{}(key); \
} \
\
bool operator==(ParamKey key) const { return data_ == key; } \
\
BaseType data() const { return data_; } \
\
private: \
BaseType data_; \
}
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(BoolAttributeStorage, bool);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(FloatAttributeStorage, float);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(DoubleAttributeStorage, double);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int32AttributeStorage, int32_t);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int64AttributeStorage, int64_t);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(PointerAttributeStorage, void *);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(TypeAttributeStorage, Type);
///
/// \brief Define Parametric AttributeStorage for StrAttribute.
......@@ -53,53 +62,53 @@ namespace ir {
struct StrAttributeStorage : public AttributeStorage {
using ParamKey = std::string;
explicit StrAttributeStorage(const ParamKey &key) {
data_ = reinterpret_cast<char *>(malloc(key.size()));
memcpy(data_, key.c_str(), key.size());
size_ = key.size();
explicit StrAttributeStorage(const ParamKey &key) : size_(key.size()) {
if (size_ > kLocalSize) {
data_ = static_cast<char *>(::operator new(size_));
memcpy(data_, key.c_str(), size_);
} else {
memcpy(buff_, key.c_str(), size_);
}
}
~StrAttributeStorage() { free(data_); }
~StrAttributeStorage() {
if (size_ > kLocalSize) ::operator delete(data_);
}
static StrAttributeStorage *Construct(const ParamKey &key) {
return new StrAttributeStorage(key);
}
static std::size_t HashValue(const ParamKey &key) {
return std::hash<std::string>()(key);
static size_t HashValue(const ParamKey &key) {
return std::hash<std::string>{}(key);
}
bool operator==(const ParamKey &key) const {
return std::equal(data_, data_ + size_, key.c_str());
if (size_ != key.size()) return false;
const char *data = size_ > kLocalSize ? data_ : buff_;
return std::equal(data, data + size_, key.c_str());
}
ParamKey GetAsKey() const { return ParamKey(data_, size_); }
// Note: The const char* is not end with '\0'.
const char *data() const { return size_ > kLocalSize ? data_ : buff_; }
size_t size() const { return size_; }
std::string AsString() const { return std::string(data(), size_); }
private:
char *data_;
uint32_t size_;
static constexpr size_t kLocalSize = sizeof(void *) / sizeof(char);
union {
char *data_;
char buff_[kLocalSize];
};
const size_t size_;
};
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(BoolAttributeStorage, bool);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(FloatAttributeStorage, float);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(DoubleAttributeStorage, double);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int32AttributeStorage, int32_t);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(Int64AttributeStorage, int64_t);
DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(PointerAttributeStorage, void *);
struct ArrayAttributeStorage : public AttributeStorage {
using ParamKey = std::vector<Attribute>;
explicit ArrayAttributeStorage(const ParamKey &key) {
data_ =
reinterpret_cast<Attribute *>(malloc(sizeof(Attribute) * key.size()));
memcpy(reinterpret_cast<void *>(data_),
reinterpret_cast<const void *>(key.data()),
sizeof(Attribute) * key.size());
length_ = key.size();
}
explicit ArrayAttributeStorage(const ParamKey &key);
~ArrayAttributeStorage() { free(reinterpret_cast<void *>(data_)); }
~ArrayAttributeStorage();
static ArrayAttributeStorage *Construct(const ParamKey &key) {
return new ArrayAttributeStorage(key);
......@@ -114,43 +123,25 @@ struct ArrayAttributeStorage : public AttributeStorage {
}
bool operator==(const ParamKey &key) const {
if (key.size() != length_) {
return false;
}
for (size_t i = 0; i < length_; ++i) {
if (data_[i] != key[i]) {
return false;
}
}
return true;
return key.size() == size_ && std::equal(key.begin(), key.end(), data_);
}
ParamKey GetAsKey() const { return ParamKey(data_, data_ + length_); }
private:
Attribute *data_ = nullptr;
size_t length_ = 0;
};
struct TypeAttributeStorage : public AttributeStorage {
using ParamKey = Type;
explicit TypeAttributeStorage(const ParamKey &key) : value_(key) {}
static TypeAttributeStorage *Construct(ParamKey key) {
return new TypeAttributeStorage(key);
std::vector<Attribute> AsVector() const {
return std::vector<Attribute>(data_, data_ + size_);
}
static std::size_t HashValue(const ParamKey &key) {
return std::hash<Type>()(key);
}
size_t size() const { return size_; }
bool operator==(const ParamKey &key) const { return value_ == key; }
bool empty() const { return size_ == 0u; }
ParamKey GetAsKey() const { return value_; }
Attribute at(size_t index) const {
IR_ENFORCE(index < size_, "Invalid index");
return data_[index];
}
private:
Type value_;
Attribute *data_;
const size_t size_;
};
} // namespace ir
......@@ -85,7 +85,7 @@ void BasicIrPrinter::PrintAttribute(Attribute attr) {
}
if (auto s = attr.dyn_cast<StrAttribute>()) {
os << s.data();
os << s.AsString();
} else if (auto b = attr.dyn_cast<BoolAttribute>()) {
os << b.data();
} else if (auto f = attr.dyn_cast<FloatAttribute>()) {
......@@ -99,7 +99,7 @@ void BasicIrPrinter::PrintAttribute(Attribute attr) {
} else if (auto p = attr.dyn_cast<PointerAttribute>()) {
os << p.data();
} else if (auto arr = attr.dyn_cast<ArrayAttribute>()) {
const auto& vec = arr.data();
const auto& vec = arr.AsVector();
os << "array[";
PrintInterleave(
vec.begin(),
......
......@@ -232,7 +232,7 @@ class RedundantTransposeFusePattern
private:
std::vector<int> GetAxis(paddle::dialect::TransposeOp op) const {
auto array_attr = op.attribute<ir::ArrayAttribute>("perm").data();
auto array_attr = op.attribute<ir::ArrayAttribute>("perm").AsVector();
std::vector<int> axis(array_attr.size());
for (size_t i = 0; i < array_attr.size(); ++i) {
axis[i] = array_attr[i].dyn_cast<ir::Int32Attribute>().data();
......@@ -333,7 +333,7 @@ class Conv2dBnFusePattern
phi::DDim new_conv2d_out_shape = ir::GetShapeFromValue(new_conv2d_op.out());
std::vector<int64_t> new_bias_new_shape(new_conv2d_out_shape.size(), 1);
std::string data_format =
new_conv2d_op.attribute<ir::StrAttribute>("data_format").data();
new_conv2d_op.attribute<ir::StrAttribute>("data_format").AsString();
IR_ENFORCE(data_format == "NCHW", "Only support NCHW now.");
new_bias_new_shape[1] = new_conv2d_out_shape[1];
paddle::dialect::ReshapeOp reshape_bias_op =
......@@ -503,7 +503,8 @@ void Conv2dFusionOpTest::Build(ir::Builder &builder,
i < attributes.at("strides").dyn_cast<ir::ArrayAttribute>().size();
i++) {
strides.push_back(attributes.at("strides")
.dyn_cast<ir::ArrayAttribute>()[i]
.dyn_cast<ir::ArrayAttribute>()
.at(i)
.dyn_cast<ir::Int32Attribute>()
.data());
}
......@@ -513,27 +514,30 @@ void Conv2dFusionOpTest::Build(ir::Builder &builder,
i < attributes.at("paddings_t").dyn_cast<ir::ArrayAttribute>().size();
i++) {
paddings_t.push_back(attributes.at("paddings_t")
.dyn_cast<ir::ArrayAttribute>()[i]
.dyn_cast<ir::ArrayAttribute>()
.at(i)
.dyn_cast<ir::Int32Attribute>()
.data());
}
std::string padding_algorithm =
attributes.at("padding_algorithm").dyn_cast<ir::StrAttribute>().data();
std::string padding_algorithm = attributes.at("padding_algorithm")
.dyn_cast<ir::StrAttribute>()
.AsString();
std::vector<int> dilations_t;
for (size_t i = 0;
i < attributes.at("dilations_t").dyn_cast<ir::ArrayAttribute>().size();
i++) {
dilations_t.push_back(attributes.at("dilations_t")
.dyn_cast<ir::ArrayAttribute>()[i]
.dyn_cast<ir::ArrayAttribute>()
.at(i)
.dyn_cast<ir::Int32Attribute>()
.data());
}
int groups = attributes.at("groups").dyn_cast<ir::Int32Attribute>().data();
std::string data_format =
attributes.at("data_format").dyn_cast<ir::StrAttribute>().data();
attributes.at("data_format").dyn_cast<ir::StrAttribute>().AsString();
std::string activation =
attributes.at("activation").dyn_cast<ir::StrAttribute>().data();
attributes.at("activation").dyn_cast<ir::StrAttribute>().AsString();
bool exhaustive_search =
attributes.at("exhaustive_search").dyn_cast<ir::BoolAttribute>().data();
std::vector<int> channels;
......@@ -541,7 +545,8 @@ void Conv2dFusionOpTest::Build(ir::Builder &builder,
i < attributes.at("channels").dyn_cast<ir::ArrayAttribute>().size();
i++) {
channels.push_back(attributes.at("channels")
.dyn_cast<ir::ArrayAttribute>()[i]
.dyn_cast<ir::ArrayAttribute>()
.at(i)
.dyn_cast<ir::Int32Attribute>()
.data());
}
......@@ -776,7 +781,8 @@ void Conv2dFusionOpTest::Verify() {
i < attributes.at("strides").dyn_cast<ir::ArrayAttribute>().size();
i++) {
PADDLE_ENFORCE(attributes.at("strides")
.dyn_cast<ir::ArrayAttribute>()[i]
.dyn_cast<ir::ArrayAttribute>()
.at(i)
.isa<ir::Int32Attribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: strides is not right."));
......@@ -789,7 +795,8 @@ void Conv2dFusionOpTest::Verify() {
i < attributes.at("paddings_t").dyn_cast<ir::ArrayAttribute>().size();
i++) {
PADDLE_ENFORCE(attributes.at("paddings_t")
.dyn_cast<ir::ArrayAttribute>()[i]
.dyn_cast<ir::ArrayAttribute>()
.at(i)
.isa<ir::Int32Attribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: paddings_t is not right."));
......@@ -807,7 +814,8 @@ void Conv2dFusionOpTest::Verify() {
i < attributes.at("dilations_t").dyn_cast<ir::ArrayAttribute>().size();
i++) {
PADDLE_ENFORCE(attributes.at("dilations_t")
.dyn_cast<ir::ArrayAttribute>()[i]
.dyn_cast<ir::ArrayAttribute>()
.at(i)
.isa<ir::Int32Attribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: dilations_t is not right."));
......@@ -837,7 +845,8 @@ void Conv2dFusionOpTest::Verify() {
i < attributes.at("channels").dyn_cast<ir::ArrayAttribute>().size();
i++) {
PADDLE_ENFORCE(attributes.at("channels")
.dyn_cast<ir::ArrayAttribute>()[i]
.dyn_cast<ir::ArrayAttribute>()
.at(i)
.isa<ir::Int32Attribute>(),
phi::errors::PreconditionNotMet(
"Type of attribute: channels is not right."));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册