Unverified commit 1f82bc37, authored by kangguangli, committed by GitHub

[IR] Refine some IR code (#54303)

* add vector type support for program translator

* polish

* support basic attribute type

* resolve conflicts

* add verify for combine/slice and unittests

* polish

* support more type in attribute translator

* modify by reviews

* fix merge mistakes

* refine code

* refine code

* add interface

* fix: op name normalization

* fix typo

* refactor input translator

* fix merge conflicts

* fix op normalizer bug

* refactor attribute translator

* fix bug

* refactor output translator

* fix typo

* fix

* fix approval error

* fix coverage

* fix op_compat parser

* fix merge conflicts

* fix merge conflicts

* fix merge conflicts

* fix merge conflicts

* fix merge conflicts

* revert some changes

---------
Co-authored-by: zhangbo9674 <zhangbo54@baidu.com>
Parent 279ac753
@@ -18,7 +18,9 @@ namespace paddle {
 namespace dialect {
 phi::IntArray IntArrayAttribute::data() const { return storage()->GetAsKey(); }
-phi::Scalar ScalarAttribute::data() const { return storage()->GetAsKey(); }
+paddle::experimental::Scalar ScalarAttribute::data() const {
+  return storage()->GetAsKey();
+}
 phi::DataType DataTypeAttribute::data() const { return storage()->GetAsKey(); }
@@ -43,7 +43,7 @@ class ScalarAttribute : public ir::Attribute {
     return storage() < right.storage();
   }
-  phi::Scalar data() const;
+  paddle::experimental::Scalar data() const;
 };
 class DataTypeAttribute : public ir::Attribute {
@@ -55,7 +55,7 @@ struct IntArrayAttributeStorage : public ir::AttributeStorage {
 };
 struct ScalarAttributeStorage : public ir::AttributeStorage {
-  using ParamKey = phi::Scalar;
+  using ParamKey = paddle::experimental::Scalar;
   explicit ScalarAttributeStorage(const ParamKey &key) { data_ = key; }
@@ -73,7 +73,7 @@ struct ScalarAttributeStorage : public ir::AttributeStorage {
   ParamKey GetAsKey() const { return ParamKey(data_); }
  private:
-  phi::Scalar data_;
+  paddle::experimental::Scalar data_;
 };
 struct DataTypeAttributeStorage : public ir::AttributeStorage {
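The three hunks above change the stored scalar type from `phi::Scalar` to `paddle::experimental::Scalar`; the `data()` accessor keeps the same shape. A minimal usage sketch, not part of the diff — the header path and the `isa`/`dyn_cast` helpers are assumed from the surrounding codebase:

```cpp
// Hedged sketch: reading a ScalarAttribute after the type change.
// Header path assumed; adjust to wherever ScalarAttribute is declared.
#include "paddle/fluid/ir/dialect/pd_attribute.h"

// `attr` is assumed to come from elsewhere, e.g. an op's AttributeMap.
void InspectScalar(ir::Attribute attr) {
  if (attr.isa<paddle::dialect::ScalarAttribute>()) {
    paddle::experimental::Scalar s =
        attr.dyn_cast<paddle::dialect::ScalarAttribute>().data();
    // Scalar carries its own dtype tag; to<T>() converts to a concrete type.
    float v = s.to<float>();
    (void)v;
  }
}
```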
 - name: feed
-  inputs:
-  - typename: Tensor[]
-    name: x
-    optional: false
-    no_need_buffer: false
-    data_transform: {}
+  inputs: []
   attrs:
-  - {typename: int, name: col}
+  - {typename: str, name: name}
   outputs:
   - {typename: Tensor, name: out, optional: false, intermediate: false}
   no_need_buffer: null
@@ -21,9 +16,8 @@
   no_need_buffer: false
   data_transform: {}
   attrs:
-  - {typename: int, name: col}
-  outputs:
-  - {typename: 'Tensor[]', name: out, optional: false, intermediate: false}
+  - {typename: str, name: name}
+  outputs: []
   no_need_buffer: null
   data_transform: null
   inplace: null
@@ -18,38 +18,39 @@
 #include "paddle/ir/core/builtin_type.h"
 namespace ir {
-BuiltinDialect::BuiltinDialect(ir::IrContext *context)
-    : ir::Dialect(name(), context, ir::TypeId::get<BuiltinDialect>()) {
+BuiltinDialect::BuiltinDialect(IrContext *context)
+    : Dialect(name(), context, TypeId::get<BuiltinDialect>()) {
   initialize();
 }
 void BuiltinDialect::initialize() {
   // Register all built-in types defined in builtin_type.h.
-  RegisterTypes<ir::BFloat16Type,
-                ir::Float16Type,
-                ir::Float32Type,
-                ir::Float64Type,
-                ir::Int8Type,
-                ir::Int16Type,
-                ir::Int32Type,
-                ir::Int64Type,
-                ir::BoolType,
-                ir::VectorType>();
+  RegisterTypes<BFloat16Type,
+                Float16Type,
+                Float32Type,
+                Float64Type,
+                Int8Type,
+                Int16Type,
+                Int32Type,
+                Int64Type,
+                BoolType,
+                VectorType>();
-  RegisterAttributes<ir::StrAttribute,
-                     ir::BoolAttribute,
-                     ir::FloatAttribute,
-                     ir::DoubleAttribute,
-                     ir::PointerAttribute,
-                     ir::Int32_tAttribute,
-                     ir::Int64_tAttribute,
-                     ir::ArrayAttribute>();
+  RegisterAttributes<StrAttribute,
+                     BoolAttribute,
+                     FloatAttribute,
+                     DoubleAttribute,
+                     PointerAttribute,
+                     Int32_tAttribute,
+                     Int64_tAttribute,
+                     ArrayAttribute>();
-  RegisterOps<ir::ModuleOp,
-              ir::GetParameterOp,
-              ir::SetParameterOp,
-              ir::CombineOp,
-              ir::SliceOp>();
+  RegisterOps<ModuleOp,
+              GetParameterOp,
+              SetParameterOp,
+              CombineOp,
+              SliceOp,
+              ConstantOp>();
 }
 } // namespace ir
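The rewrite only drops `ir::` qualifiers that are redundant inside `namespace ir`; the behavioral change is the new `ConstantOp` entry in `RegisterOps`. A hedged sketch of what that registration buys, mirroring the updated test at the end of this diff (include paths assumed):

```cpp
#include "paddle/ir/core/builtin_op.h"  // path assumed
#include "paddle/ir/core/ir_context.h"  // path assumed

// Sketch: once BuiltinDialect::initialize() registers ConstantOp, its
// OpInfo is discoverable by name from the IrContext.
void LookupConstant() {
  ir::IrContext *ctx = ir::IrContext::Instance();
  ir::OpInfo info = ctx->GetRegisteredOpInfo(ir::ConstantOp::name());
  (void)info;  // valid only because ConstantOp is now in RegisterOps<...>()
}
```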
@@ -214,4 +214,21 @@ void SliceOp::verify(const std::vector<ir::OpResult> &inputs,
                      outputs[0]));
 }
+
+void ConstantOp::verify(const std::vector<ir::OpResult> &inputs,
+                        const std::vector<ir::Type> &outputs,
+                        const ir::AttributeMap &attributes) {
+  // outputs.size() == 1
+  PADDLE_ENFORCE_EQ(
+      outputs.size(),
+      1,
+      phi::errors::PreconditionNotMet(
+          "The size %d of outputs must be equal to 1.", outputs.size()));
+  // inputs.size() == 0
+  PADDLE_ENFORCE_EQ(
+      inputs.size(),
+      0,
+      phi::errors::PreconditionNotMet(
+          "The size %d of inputs must be equal to 0.", inputs.size()));
+}
 } // namespace ir
@@ -107,4 +107,18 @@ class SliceOp : public ir::Op<SliceOp> {
                      const ir::AttributeMap &attributes);
 };
+
+class ConstantOp : public ir::Op<ConstantOp> {
+ public:
+  using Op::Op;
+
+  static const char *name() { return "builtin.constant"; }
+
+  static constexpr uint32_t attributes_num = 0;
+  static constexpr const char **attributes_name = nullptr;
+
+  static void verify(const std::vector<ir::OpResult> &inputs,
+                     const std::vector<ir::Type> &outputs,
+                     const ir::AttributeMap &attributes);
+};
 } // namespace ir
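With `attributes_num = 0`, constructing the op needs no attribute map. A hedged construction sketch matching the `verify()` contract, using the same `Operation::create` overload as the updated test below:

```cpp
// Sketch: build a builtin.constant -- no operands, empty AttributeMap,
// exactly one result type, which is what ConstantOp::verify enforces.
ir::Operation *MakeConstant(ir::IrContext *ctx) {
  ir::OpInfo info = ctx->GetRegisteredOpInfo(ir::ConstantOp::name());
  ir::Type fp32 = ir::Float32Type::get(ctx);
  return ir::Operation::create({}, {}, {fp32}, info);
}
```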
@@ -51,6 +51,11 @@ class Printer {
   explicit Printer(std::ostream& os) : os(os) {}
   void PrintType(ir::Type type) {
+    if (!type) {
+      os << "<!TypeNull>";
+      return;
+    }
     if (type.isa<ir::Float16Type>()) {
       os << "f16";
     } else if (type.isa<ir::Float32Type>()) {
@@ -83,10 +88,6 @@
 };
 void Type::print(std::ostream& os) const {
-  if (!*this) {
-    os << "<!TypeNull>";
-    return;
-  }
   Printer p(os);
   p.PrintType(*this);
 }
@@ -105,6 +106,10 @@ class ProgramPrinter : public Printer {
   }
   void PrintValue(ir::Value v) {
+    if (!v) {
+      os << "<<NULL VALUE>>";
+      return;
+    }
     const void* key = static_cast<const void*>(v.impl());
     auto ret = aliases.find(key);
     if (ret != aliases.end()) {
@@ -176,7 +181,12 @@
     std::vector<ir::Type> op_operand_types;
     op_operand_types.reserve(num_op_operands);
     for (size_t idx = 0; idx < num_op_operands; idx++) {
-      op_operand_types.push_back(op->GetOperandByIndex(idx).source().type());
+      auto op_operand = op->GetOperandByIndex(idx);
+      if (op_operand) {
+        op_operand_types.push_back(op->GetOperandByIndex(idx).source().type());
+      } else {
+        op_operand_types.push_back(ir::Type(nullptr));
+      }
     }
     PrintInterleave(
         op_operand_types.begin(),
@@ -190,7 +200,12 @@
     std::vector<ir::Type> op_result_types;
     op_result_types.reserve(num_op_result);
     for (size_t idx = 0; idx < num_op_result; idx++) {
-      op_result_types.push_back(op->GetResultByIndex(idx).type());
+      auto op_result = op->GetResultByIndex(idx);
+      if (op_result) {
+        op_result_types.push_back(op_result.type());
+      } else {
+        op_result_types.push_back(ir::Type(nullptr));
+      }
     }
     PrintInterleave(
         op_result_types.begin(),
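Net effect of the printer hunks: the null check moves from `Type::print` into `Printer::PrintType`, so every caller — including the operand/result loops above, which can now push `ir::Type(nullptr)` — gets the `<!TypeNull>` fallback. A behavior sketch, assuming `ir::Type` default-constructs to a null handle as `ir::Type(nullptr)` suggests:

```cpp
#include <sstream>

// Sketch: printing a null type after this change hits the new fallback
// branch in Printer::PrintType instead of dereferencing null storage.
void PrintNullType() {
  ir::Type null_type;  // null handle, same as ir::Type(nullptr)
  std::ostringstream oss;
  null_type.print(oss);
  // Expected: oss.str() == "<!TypeNull>"
}
```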
@@ -31,6 +31,7 @@ OpOperand &OpOperand::operator=(const detail::OpOperandImpl *impl) {
   impl_ = const_cast<detail::OpOperandImpl *>(impl);
   return *this;
 }
+OpOperand::operator bool() const { return impl_ && impl_->source(); }
 OpOperand OpOperand::next_use() const { return impl_->next_use(); }
@@ -38,8 +39,6 @@ Value OpOperand::source() const { return impl_->source(); }
 Operation *OpOperand::owner() const { return impl_->owner(); }
-// detail::OpOperandImpl *OpOperand::impl() const { return impl_; }
-
 // Value
 Value::Value(const detail::ValueImpl *impl)
     : impl_(const_cast<detail::ValueImpl *>(impl)) {}
@@ -108,6 +107,9 @@ void OpOperandImpl::release_source() { source_ = nullptr; }
 OpOperandImpl::OpOperandImpl(ir::Value source, ir::Operation *owner)
     : source_(source), owner_(owner) {
+  if (!source) {
+    return;
+  }
   prev_use_addr_ = source.impl()->first_use_addr();
   next_use_ = source.impl()->first_use();
   if (next_use_) {
@@ -117,6 +119,7 @@ OpOperandImpl::OpOperandImpl(ir::Value source, ir::Operation *owner)
 }
 void OpOperandImpl::remove_from_ud_chain() {
+  if (!source_) return;
   if (!prev_use_addr_) return;
   if (prev_use_addr_ == source_.impl()->first_use_addr()) {
     /// NOTE: In ValueImpl, first_use_offseted_by_index_ use lower three bits
@@ -49,7 +49,7 @@ class OpOperand {
   bool operator!() const { return impl_ == nullptr; }
-  operator bool() const { return impl_; }
+  operator bool() const;
   OpOperand next_use() const;
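The declaration loses its inline body because the new out-of-line definition in value.cc also checks `impl_->source()`: an operand slot can exist while its source value has been released. A hedged sketch of the defensive pattern this enables, same shape as the printer loops above:

```cpp
#include <cstddef>
#include <vector>

// Sketch: collect operand types without touching released sources.
std::vector<ir::Type> OperandTypes(ir::Operation *op, size_t n) {
  std::vector<ir::Type> types;
  types.reserve(n);
  for (size_t i = 0; i < n; ++i) {
    auto operand = op->GetOperandByIndex(i);
    // operator bool: true only if the slot exists AND still has a source.
    types.push_back(operand ? operand.source().type() : ir::Type(nullptr));
  }
  return types;
}
```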
@@ -1617,7 +1617,7 @@
     out : Out
 - op : mean (reduce_mean)
-  backward : reduce_mean_grad
+  backward : mean_grad (reduce_mean_grad)
   inputs :
     x : X
   outputs :
@@ -1860,6 +1860,8 @@
 - op : pool2d
   backward : pool2d_grad
+  attrs:
+    kernel_size: ksize
   extra :
     attrs : [bool use_mkldnn = false, bool use_quantizer = false,
              str mkldnn_data_type = "float32", bool is_test = false]
@@ -2000,6 +2002,7 @@
     x : X
   outputs:
     out : Out
+    xshape: XShape
   int_array:
     shape :
       data_type : int
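These op_compat entries feed the translator's name normalization (the "fix: op name normalization" and "fix op_compat parser" commits above): `reduce_mean`'s backward now maps to `mean_grad`, pool2d's legacy `ksize` attribute maps to `kernel_size`, and an `xshape: XShape` output mapping is added. A hypothetical, hard-coded illustration of the lookup such a mapping drives — the real normalizer parses op_compat.yaml, and the map and function name here are invented for the sketch:

```cpp
#include <map>
#include <string>
#include <utility>

// Hypothetical sketch of legacy->new attribute-name normalization; the
// single entry mirrors the pool2d hunk above (kernel_size: ksize).
std::string NormalizeAttrName(const std::string &op_name,
                              const std::string &legacy_name) {
  static const std::map<std::pair<std::string, std::string>, std::string>
      kAttrNameMap = {{{"pool2d", "ksize"}, "kernel_size"}};
  auto it = kAttrNameMap.find({op_name, legacy_name});
  return it == kAttrNameMap.end() ? legacy_name : it->second;
}
```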
@@ -254,13 +254,10 @@ TEST(program_test, slice_combine_test) {
       ir::Operation::create({}, op1_attribute, {fp32_dtype}, op1_info);
   program.block()->push_back(op1);
-  // (5) Def b = GetParameterOp("b")
-  std::string op2_name = std::string(ir::GetParameterOp::name());
+  // (5) Def b = Constant("b")
+  std::string op2_name = std::string(ir::ConstantOp::name());
   ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name);
-  std::unordered_map<std::string, ir::Attribute> op2_attribute{
-      {"parameter_name", ir::StrAttribute::get(ctx, "b")}};
-  ir::Operation *op2 =
-      ir::Operation::create({}, op2_attribute, {fp32_dtype}, op2_info);
+  ir::Operation *op2 = ir::Operation::create({}, {}, {fp32_dtype}, op2_info);
   program.block()->push_back(op2);
   // (6) Def combine_op = CombineOp("a", "b")