未验证 提交 1f82bc37 编写于 作者: K kangguangli 提交者: GitHub

[IR] Refine some IR code (#54303)

* add vector type support for program translator

* polish

* support basic attribute type

* resolve conflicts

* add verify for combine/slice and unittests

* polish

* support more type in attribute translator

* modify by reviews

* fix merge mistakes

* refine code

* refine code

* add interface

* fix: op name normalization

* fix typo

* refactor input translator

* fix merge conflicts

* fix op normalizer bug

* refactor attribute translator

* fix bug

* refactor output translator

* fix typo

* fix

* fix approval error

* fix coverage

* fix op_compat parser

* fix merge conflicts

* fix merge conflicts

* fix merge conflicts

* fix merge conflicts

* fix merge conflicts

* revert some changes

---------
Co-authored-by: Nzhangbo9674 <zhangbo54@baidu.com>
上级 279ac753
...@@ -18,7 +18,9 @@ namespace paddle { ...@@ -18,7 +18,9 @@ namespace paddle {
namespace dialect { namespace dialect {
phi::IntArray IntArrayAttribute::data() const { return storage()->GetAsKey(); } phi::IntArray IntArrayAttribute::data() const { return storage()->GetAsKey(); }
phi::Scalar ScalarAttribute::data() const { return storage()->GetAsKey(); } paddle::experimental::Scalar ScalarAttribute::data() const {
return storage()->GetAsKey();
}
phi::DataType DataTypeAttribute::data() const { return storage()->GetAsKey(); } phi::DataType DataTypeAttribute::data() const { return storage()->GetAsKey(); }
......
...@@ -43,7 +43,7 @@ class ScalarAttribute : public ir::Attribute { ...@@ -43,7 +43,7 @@ class ScalarAttribute : public ir::Attribute {
return storage() < right.storage(); return storage() < right.storage();
} }
phi::Scalar data() const; paddle::experimental::Scalar data() const;
}; };
class DataTypeAttribute : public ir::Attribute { class DataTypeAttribute : public ir::Attribute {
......
...@@ -55,7 +55,7 @@ struct IntArrayAttributeStorage : public ir::AttributeStorage { ...@@ -55,7 +55,7 @@ struct IntArrayAttributeStorage : public ir::AttributeStorage {
}; };
struct ScalarAttributeStorage : public ir::AttributeStorage { struct ScalarAttributeStorage : public ir::AttributeStorage {
using ParamKey = phi::Scalar; using ParamKey = paddle::experimental::Scalar;
explicit ScalarAttributeStorage(const ParamKey &key) { data_ = key; } explicit ScalarAttributeStorage(const ParamKey &key) { data_ = key; }
...@@ -73,7 +73,7 @@ struct ScalarAttributeStorage : public ir::AttributeStorage { ...@@ -73,7 +73,7 @@ struct ScalarAttributeStorage : public ir::AttributeStorage {
ParamKey GetAsKey() const { return ParamKey(data_); } ParamKey GetAsKey() const { return ParamKey(data_); }
private: private:
phi::Scalar data_; paddle::experimental::Scalar data_;
}; };
struct DataTypeAttributeStorage : public ir::AttributeStorage { struct DataTypeAttributeStorage : public ir::AttributeStorage {
......
- name: feed - name: feed
inputs: inputs: []
- typename: Tensor[]
name: x
optional: false
no_need_buffer: false
data_transform: {}
attrs: attrs:
- {typename: int, name: col} - {typename: str, name: name}
outputs: outputs:
- {typename: Tensor, name: out, optional: false, intermediate: false} - {typename: Tensor, name: out, optional: false, intermediate: false}
no_need_buffer: null no_need_buffer: null
...@@ -21,9 +16,8 @@ ...@@ -21,9 +16,8 @@
no_need_buffer: false no_need_buffer: false
data_transform: {} data_transform: {}
attrs: attrs:
- {typename: int, name: col} - {typename: str, name: name}
outputs: outputs: []
- {typename: 'Tensor[]', name: out, optional: false, intermediate: false}
no_need_buffer: null no_need_buffer: null
data_transform: null data_transform: null
inplace: null inplace: null
......
...@@ -18,38 +18,39 @@ ...@@ -18,38 +18,39 @@
#include "paddle/ir/core/builtin_type.h" #include "paddle/ir/core/builtin_type.h"
namespace ir { namespace ir {
BuiltinDialect::BuiltinDialect(ir::IrContext *context) BuiltinDialect::BuiltinDialect(IrContext *context)
: ir::Dialect(name(), context, ir::TypeId::get<BuiltinDialect>()) { : Dialect(name(), context, TypeId::get<BuiltinDialect>()) {
initialize(); initialize();
} }
void BuiltinDialect::initialize() { void BuiltinDialect::initialize() {
// Register all built-in types defined in builtin_type.h. // Register all built-in types defined in builtin_type.h.
RegisterTypes<ir::BFloat16Type, RegisterTypes<BFloat16Type,
ir::Float16Type, Float16Type,
ir::Float32Type, Float32Type,
ir::Float64Type, Float64Type,
ir::Int8Type, Int8Type,
ir::Int16Type, Int16Type,
ir::Int32Type, Int32Type,
ir::Int64Type, Int64Type,
ir::BoolType, BoolType,
ir::VectorType>(); VectorType>();
RegisterAttributes<ir::StrAttribute, RegisterAttributes<StrAttribute,
ir::BoolAttribute, BoolAttribute,
ir::FloatAttribute, FloatAttribute,
ir::DoubleAttribute, DoubleAttribute,
ir::PointerAttribute, PointerAttribute,
ir::Int32_tAttribute, Int32_tAttribute,
ir::Int64_tAttribute, Int64_tAttribute,
ir::ArrayAttribute>(); ArrayAttribute>();
RegisterOps<ir::ModuleOp, RegisterOps<ModuleOp,
ir::GetParameterOp, GetParameterOp,
ir::SetParameterOp, SetParameterOp,
ir::CombineOp, CombineOp,
ir::SliceOp>(); SliceOp,
ConstantOp>();
} }
} // namespace ir } // namespace ir
...@@ -214,4 +214,21 @@ void SliceOp::verify(const std::vector<ir::OpResult> &inputs, ...@@ -214,4 +214,21 @@ void SliceOp::verify(const std::vector<ir::OpResult> &inputs,
outputs[0])); outputs[0]));
} }
void ConstantOp::verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
// outputs.size() == 1
PADDLE_ENFORCE_EQ(
outputs.size(),
1,
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 1.", outputs.size()));
// inputs.size() == 0
PADDLE_ENFORCE_EQ(
inputs.size(),
0,
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 1.", outputs.size()));
}
} // namespace ir } // namespace ir
...@@ -107,4 +107,18 @@ class SliceOp : public ir::Op<SliceOp> { ...@@ -107,4 +107,18 @@ class SliceOp : public ir::Op<SliceOp> {
const ir::AttributeMap &attributes); const ir::AttributeMap &attributes);
}; };
class ConstantOp : public ir::Op<ConstantOp> {
public:
using Op::Op;
static const char *name() { return "builtin.constant"; }
static constexpr uint32_t attributes_num = 0;
static constexpr const char **attributes_name = nullptr;
static void verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes);
};
} // namespace ir } // namespace ir
...@@ -51,6 +51,11 @@ class Printer { ...@@ -51,6 +51,11 @@ class Printer {
explicit Printer(std::ostream& os) : os(os) {} explicit Printer(std::ostream& os) : os(os) {}
void PrintType(ir::Type type) { void PrintType(ir::Type type) {
if (!type) {
os << "<!TypeNull>";
return;
}
if (type.isa<ir::Float16Type>()) { if (type.isa<ir::Float16Type>()) {
os << "f16"; os << "f16";
} else if (type.isa<ir::Float32Type>()) { } else if (type.isa<ir::Float32Type>()) {
...@@ -83,10 +88,6 @@ class Printer { ...@@ -83,10 +88,6 @@ class Printer {
}; };
void Type::print(std::ostream& os) const { void Type::print(std::ostream& os) const {
if (!*this) {
os << "<!TypeNull>";
return;
}
Printer p(os); Printer p(os);
p.PrintType(*this); p.PrintType(*this);
} }
...@@ -105,6 +106,10 @@ class ProgramPrinter : public Printer { ...@@ -105,6 +106,10 @@ class ProgramPrinter : public Printer {
} }
void PrintValue(ir::Value v) { void PrintValue(ir::Value v) {
if (!v) {
os << "<<NULL VALUE>>";
return;
}
const void* key = static_cast<const void*>(v.impl()); const void* key = static_cast<const void*>(v.impl());
auto ret = aliases.find(key); auto ret = aliases.find(key);
if (ret != aliases.end()) { if (ret != aliases.end()) {
...@@ -176,7 +181,12 @@ class ProgramPrinter : public Printer { ...@@ -176,7 +181,12 @@ class ProgramPrinter : public Printer {
std::vector<ir::Type> op_operand_types; std::vector<ir::Type> op_operand_types;
op_operand_types.reserve(num_op_operands); op_operand_types.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) { for (size_t idx = 0; idx < num_op_operands; idx++) {
auto op_operand = op->GetOperandByIndex(idx);
if (op_operand) {
op_operand_types.push_back(op->GetOperandByIndex(idx).source().type()); op_operand_types.push_back(op->GetOperandByIndex(idx).source().type());
} else {
op_operand_types.push_back(ir::Type(nullptr));
}
} }
PrintInterleave( PrintInterleave(
op_operand_types.begin(), op_operand_types.begin(),
...@@ -190,7 +200,12 @@ class ProgramPrinter : public Printer { ...@@ -190,7 +200,12 @@ class ProgramPrinter : public Printer {
std::vector<ir::Type> op_result_types; std::vector<ir::Type> op_result_types;
op_result_types.reserve(num_op_result); op_result_types.reserve(num_op_result);
for (size_t idx = 0; idx < num_op_result; idx++) { for (size_t idx = 0; idx < num_op_result; idx++) {
op_result_types.push_back(op->GetResultByIndex(idx).type()); auto op_result = op->GetResultByIndex(idx);
if (op_result) {
op_result_types.push_back(op_result.type());
} else {
op_result_types.push_back(ir::Type(nullptr));
}
} }
PrintInterleave( PrintInterleave(
op_result_types.begin(), op_result_types.begin(),
......
...@@ -31,6 +31,7 @@ OpOperand &OpOperand::operator=(const detail::OpOperandImpl *impl) { ...@@ -31,6 +31,7 @@ OpOperand &OpOperand::operator=(const detail::OpOperandImpl *impl) {
impl_ = const_cast<detail::OpOperandImpl *>(impl); impl_ = const_cast<detail::OpOperandImpl *>(impl);
return *this; return *this;
} }
OpOperand::operator bool() const { return impl_ && impl_->source(); }
OpOperand OpOperand::next_use() const { return impl_->next_use(); } OpOperand OpOperand::next_use() const { return impl_->next_use(); }
...@@ -38,8 +39,6 @@ Value OpOperand::source() const { return impl_->source(); } ...@@ -38,8 +39,6 @@ Value OpOperand::source() const { return impl_->source(); }
Operation *OpOperand::owner() const { return impl_->owner(); } Operation *OpOperand::owner() const { return impl_->owner(); }
// detail::OpOperandImpl *OpOperand::impl() const { return impl_; }
// Value // Value
Value::Value(const detail::ValueImpl *impl) Value::Value(const detail::ValueImpl *impl)
: impl_(const_cast<detail::ValueImpl *>(impl)) {} : impl_(const_cast<detail::ValueImpl *>(impl)) {}
...@@ -108,6 +107,9 @@ void OpOperandImpl::release_source() { source_ = nullptr; } ...@@ -108,6 +107,9 @@ void OpOperandImpl::release_source() { source_ = nullptr; }
OpOperandImpl::OpOperandImpl(ir::Value source, ir::Operation *owner) OpOperandImpl::OpOperandImpl(ir::Value source, ir::Operation *owner)
: source_(source), owner_(owner) { : source_(source), owner_(owner) {
if (!source) {
return;
}
prev_use_addr_ = source.impl()->first_use_addr(); prev_use_addr_ = source.impl()->first_use_addr();
next_use_ = source.impl()->first_use(); next_use_ = source.impl()->first_use();
if (next_use_) { if (next_use_) {
...@@ -117,6 +119,7 @@ OpOperandImpl::OpOperandImpl(ir::Value source, ir::Operation *owner) ...@@ -117,6 +119,7 @@ OpOperandImpl::OpOperandImpl(ir::Value source, ir::Operation *owner)
} }
void OpOperandImpl::remove_from_ud_chain() { void OpOperandImpl::remove_from_ud_chain() {
if (!source_) return;
if (!prev_use_addr_) return; if (!prev_use_addr_) return;
if (prev_use_addr_ == source_.impl()->first_use_addr()) { if (prev_use_addr_ == source_.impl()->first_use_addr()) {
/// NOTE: In ValueImpl, first_use_offseted_by_index_ use lower three bits /// NOTE: In ValueImpl, first_use_offseted_by_index_ use lower three bits
......
...@@ -49,7 +49,7 @@ class OpOperand { ...@@ -49,7 +49,7 @@ class OpOperand {
bool operator!() const { return impl_ == nullptr; } bool operator!() const { return impl_ == nullptr; }
operator bool() const { return impl_; } operator bool() const;
OpOperand next_use() const; OpOperand next_use() const;
......
...@@ -1617,7 +1617,7 @@ ...@@ -1617,7 +1617,7 @@
out : Out out : Out
- op : mean (reduce_mean) - op : mean (reduce_mean)
backward : reduce_mean_grad backward : mean_grad (reduce_mean_grad)
inputs : inputs :
x : X x : X
outputs : outputs :
...@@ -1860,6 +1860,8 @@ ...@@ -1860,6 +1860,8 @@
- op : pool2d - op : pool2d
backward : pool2d_grad backward : pool2d_grad
attrs:
kernel_size: ksize
extra : extra :
attrs : [bool use_mkldnn = false, bool use_quantizer = false, attrs : [bool use_mkldnn = false, bool use_quantizer = false,
str mkldnn_data_type = "float32", bool is_test = false] str mkldnn_data_type = "float32", bool is_test = false]
...@@ -2000,6 +2002,7 @@ ...@@ -2000,6 +2002,7 @@
x : X x : X
outputs: outputs:
out : Out out : Out
xshape: XShape
int_array: int_array:
shape : shape :
data_type : int data_type : int
......
...@@ -254,13 +254,10 @@ TEST(program_test, slice_combine_test) { ...@@ -254,13 +254,10 @@ TEST(program_test, slice_combine_test) {
ir::Operation::create({}, op1_attribute, {fp32_dtype}, op1_info); ir::Operation::create({}, op1_attribute, {fp32_dtype}, op1_info);
program.block()->push_back(op1); program.block()->push_back(op1);
// (5) Def b = GetParameterOp("b") // (5) Def b = Constant("b")
std::string op2_name = std::string(ir::GetParameterOp::name()); std::string op2_name = std::string(ir::ConstantOp::name());
ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name); ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name);
std::unordered_map<std::string, ir::Attribute> op2_attribute{ ir::Operation *op2 = ir::Operation::create({}, {}, {fp32_dtype}, op2_info);
{"parameter_name", ir::StrAttribute::get(ctx, "b")}};
ir::Operation *op2 =
ir::Operation::create({}, op2_attribute, {fp32_dtype}, op2_info);
program.block()->push_back(op2); program.block()->push_back(op2);
// (6) Def combine_op = CombineOp("a", "b") // (6) Def combine_op = CombineOp("a", "b")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册