未验证 提交 ef29468e 编写于 作者: X xiaoguoguo626807 提交者: GitHub

【new ir】add ir pybind api (#55745)

* add ir core

* add test

* modify name

* merge

* add test for __eq__

* shield  test for __eq__

* --amend

* Update new_ir_compiler.cc
上级 683287ba
......@@ -79,7 +79,7 @@ std::vector<ir::LoweredFunc> NewIRCompiler::GetOpFunc(const ::ir::Operation& op,
VLOG(4) << "GetOpFunc for op: " << op_name;
// step 1: Deal with Oprands
for (int i = 0; i < op.num_operands(); ++i) {
auto in_value = op.operand(i);
auto in_value = op.operand_source(i);
// TODO(Aurelius84): For now, use addr as name but it's not wise.
std::string input_id = CompatibleInfo::kInputPrefix +
std::to_string(std::hash<::ir::Value>()(in_value));
......@@ -215,7 +215,7 @@ std::vector<std::string> NewIRCompiler::OpGetInputNames(
std::vector<std::string> names;
std::unordered_set<std::string> repeat;
for (int i = 0; i < op.num_operands(); ++i) {
auto value = op.operand(i);
auto value = op.operand_source(i);
std::string name = CompatibleInfo::kInputPrefix +
std::to_string(std::hash<::ir::Value>()(value));
if (repeat.count(name)) {
......@@ -264,7 +264,7 @@ std::shared_ptr<Scope> BuildScope(const Target& target,
for (auto it = program.block()->begin(); it != program.block()->end(); ++it) {
for (auto i = 0; i < (*it)->num_operands(); ++i) {
auto in_value = (*it)->operand(i);
auto in_value = (*it)->operand_source(i);
create_var(CompatibleInfo::kInputPrefix, in_value);
}
......
......@@ -266,7 +266,7 @@ PhiKernelInstruction::PhiKernelInstruction(
auto& no_need_buffer_ids = yaml_info_parser.NoNeedBufferIds();
std::unordered_set<::ir::Value> no_need_buffer_values;
for (size_t id = 0; id < no_need_buffer_ids.size(); id++) {
no_need_buffer_values.insert(op->operand(no_need_buffer_ids[id]));
no_need_buffer_values.insert(op->operand_source(no_need_buffer_ids[id]));
}
SetNoNeedBuffer(no_need_buffer_values);
VLOG(6) << "finish process no need buffer";
......@@ -302,7 +302,7 @@ void PhiKernelInstruction::InitInputsOutputsIds(
variable_2_var_name) {
std::unordered_map<ir::Value, std::vector<int>> inputs;
for (size_t i = 0; i < op->num_operands(); i++) {
ir::Value value = op->operand(i);
ir::Value value = op->operand_source(i);
if (value) {
PADDLE_ENFORCE_NE(
value_2_var_name.find(value),
......
......@@ -14,7 +14,7 @@
# generator op member function
OP_GET_INPUT_TEMPLATE = """ ir::Value {input_name}() {{ return operand({input_index}); }}
OP_GET_INPUT_TEMPLATE = """ ir::Value {input_name}() {{ return operand_source({input_index}); }}
"""
OP_GET_OUTPUT_TEMPLATE = """ ir::OpResult {output_name}() {{ return result({output_index}); }}
"""
......
......@@ -40,26 +40,26 @@ void {op_name}::Verify() {{}}
"""
INPUT_TYPE_CHECK_TEMPLATE = """
PADDLE_ENFORCE((*this)->operand({index}).type().isa<{standard}>(),
PADDLE_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));"""
INPUT_VECTORTYPE_CHECK_TEMPLATE = """
if (auto vec_type = (*this)->operand({index}).type().dyn_cast<ir::VectorType>()) {{
if (auto vec_type = (*this)->operand_source({index}).type().dyn_cast<ir::VectorType>()) {{
for (size_t i = 0; i < vec_type.size(); ++i) {{
PADDLE_ENFORCE(vec_type[i].isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}
}}
else {{
PADDLE_ENFORCE((*this)->operand({index}).type().isa<{standard}>(),
PADDLE_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}"""
INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """
if (auto val = (*this)->op_operand({index})) {{
if (auto val = (*this)->operand({index})) {{
PADDLE_ENFORCE(val.type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}"""
INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """
if (auto val = (*this)->op_operand({index})) {{
if (auto val = (*this)->operand({index})) {{
if (auto vec_type = val.type().dyn_cast<ir::VectorType>()) {{
for (size_t i = 0; i < vec_type.size(); i++) {{
PADDLE_ENFORCE(vec_type[i].isa<{standard}>(),
......
......@@ -140,7 +140,7 @@ void CheckInputVars(
size_t input_num = op->num_operands();
if (input_num > 0) {
for (size_t i = 0; i < input_num; ++i) {
auto value = op->operand(i);
auto value = op->operand_source(i);
if (value) {
PADDLE_ENFORCE_NE(
value_2_var_name.find(value),
......@@ -298,7 +298,7 @@ void HandleForSpecialOp(
tensor_array->clear();
size_t input_num = op->num_operands();
for (size_t i = 0; i < input_num; ++i) {
auto value = op->operand(i);
auto value = op->operand_source(i);
PADDLE_ENFORCE_EQ(
value_2_var_name->count(value),
true,
......@@ -315,7 +315,7 @@ void HandleForSpecialOp(
.dyn_cast<ir::StrAttribute>()
.AsString();
auto value = op->operand(0);
auto value = op->operand_source(0);
// change opreand name to param_name
auto orig_name = value_2_var_name->at(value);
......@@ -336,7 +336,7 @@ void HandleForSpecialOp(
auto var_name =
op->attributes().at("name").dyn_cast<ir::StrAttribute>().AsString();
auto value = op->operand(0);
auto value = op->operand_source(0);
// change opreand name to param_name
auto orig_name = value_2_var_name->at(value);
......@@ -372,7 +372,7 @@ void HandleForSpecialOp(
if (op_name == "builtin.slice") {
VLOG(6) << "Handle for builtin.slice";
auto out_value = op->result(0);
auto in_value = op->operand(0);
auto in_value = op->operand_source(0);
PADDLE_ENFORCE_EQ(value_2_var_name->count(in_value),
true,
phi::errors::PreconditionNotMet(
......@@ -426,7 +426,7 @@ void HandleForInplaceOp(
if (yaml_parser.HasInplace(value_name)) {
std::string inplace_name = yaml_parser.InplaceName(value_name);
ir::Value inplace_value =
op->operand(yaml_parser.InputName2Id().at(inplace_name));
op->operand_source(yaml_parser.InputName2Id().at(inplace_name));
std::string var_name = value_2_var_name->at(inplace_value);
VLOG(4) << "inplace: " << value_name << " -> " << inplace_name
<< " (var: " << var_name << ")";
......@@ -547,7 +547,7 @@ void BuildRuntimeContext(
true,
phi::errors::NotFound("param [%s] MUST in name2id map", name));
auto index = op_yaml_info.InputName2Id().at(name);
ir::Value ptr = op->operand(index);
ir::Value ptr = op->operand_source(index);
auto in_var_name = name_map.at(ptr);
VLOG(6) << "ctx->EmplaceBackInput: " << name << "\t" << in_var_name;
......@@ -603,7 +603,7 @@ std::shared_ptr<paddle::framework::OperatorBase> BuildOperatorBase(
true,
phi::errors::NotFound("param [%s] MUST in name2id map", name));
auto index = op_yaml_info.InputName2Id().at(name);
ir::Value ptr = op->operand(index);
ir::Value ptr = op->operand_source(index);
auto in_var_name = name_map.at(ptr);
......
......@@ -92,7 +92,7 @@ void BuildPhiContext(ir::Operation* op,
true,
phi::errors::NotFound("param [%s] MUST in name2id map", t));
auto index = op_yaml_info.InputName2Id().at(t);
ir::Value ptr = op->operand(index);
ir::Value ptr = op->operand_source(index);
if (!ptr) {
phi::DenseTensor* ptr = nullptr;
OutType in_ptr(ptr);
......@@ -128,7 +128,7 @@ void BuildPhiContext(ir::Operation* op,
for (auto& t : vec_kernel_fn_attr_params) {
if (name2id.count(t)) {
// tensor attribute, get information from input
ir::Value ptr = op->operand(name2id.at(t));
ir::Value ptr = op->operand_source(name2id.at(t));
auto in_var_name = name_map.at(ptr);
......
......@@ -114,12 +114,13 @@ class ConstantFoldingPattern : public ir::RewritePattern {
std::vector<ir::OpResult> op_inputs;
for (uint32_t i = 0; i < op->num_operands(); i++) {
PADDLE_ENFORCE_EQ(
op->operand(i).type().isa<paddle::dialect::DenseTensorType>(),
op->operand_source(i).type().isa<paddle::dialect::DenseTensorType>(),
true,
phi::errors::InvalidArgument(
"Op's input must be a dense tensor type."));
auto [param_name, param] = ir::GetParameterFromValue(op->operand(i));
auto [param_name, param] =
ir::GetParameterFromValue(op->operand_source(i));
program->SetParameter(param_name,
std::make_unique<ir::Parameter>(*param));
......@@ -128,8 +129,8 @@ class ConstantFoldingPattern : public ir::RewritePattern {
param_var,
phi::errors::InvalidArgument("Parameter var not in scope."));
auto get_parameter_op =
builder.Build<ir::GetParameterOp>(param_name, op->operand(i).type());
auto get_parameter_op = builder.Build<ir::GetParameterOp>(
param_name, op->operand_source(i).type());
op_inputs.push_back(get_parameter_op->result(0));
}
......
......@@ -97,7 +97,7 @@ phi::KernelKey GetKernelKey(
} else if (input_map.count(slot_name)) {
// parse from input
int in_index = input_map.at(slot_name);
auto type = map_value_pair.at(op->operand(in_index)).type();
auto type = map_value_pair.at(op->operand_source(in_index)).type();
if (type.isa<paddle::dialect::AllocatedDenseTensorType>()) {
kernel_data_type = TransToPhiDataType(
......@@ -151,7 +151,7 @@ phi::KernelKey GetKernelKey(
if (op->name() == "pd.uniform") {
// try to process uniform, use shape to determin backend
// TODO(phlrain): shuold support other initilize op
auto define_op = op->operand(0).GetDefiningOp();
auto define_op = op->operand_source(0).GetDefiningOp();
if (define_op->name() == "pd.full_int_array") {
auto shape = define_op->attributes()
.at("value")
......@@ -183,7 +183,7 @@ phi::KernelKey GetKernelKey(
if (op_info_parser != nullptr && op_info_parser->IsTensorAttribute(i)) {
continue;
}
auto input_tmp = op->operand(i);
auto input_tmp = op->operand_source(i);
// NOTE: if not input_tmp, it's an optional input
if (!input_tmp) {
continue;
......@@ -341,7 +341,7 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
if ((*it)->num_operands() > 0) {
for (size_t i = 0; i < (*it)->num_operands(); ++i) {
auto cur_in = (*it)->operand(i);
auto cur_in = (*it)->operand_source(i);
if (!cur_in) {
vec_inputs.push_back(ir::OpResult());
continue;
......
......@@ -64,7 +64,7 @@ Operation* GetDefiningOpForInput(Operation* op, uint32_t index) {
index < op->num_operands(),
true,
phi::errors::InvalidArgument("Intput operand's index must be valid."));
return op->operand(index).GetDefiningOp();
return op->operand_source(index).GetDefiningOp();
}
Operation* GetFirstUseOperationForOutput(Operation* op, uint32_t index) {
......
......@@ -77,6 +77,8 @@ void BindProgram(py::module *m) {
void BindBlock(py::module *m) {
py::class_<Block> block(*m, "Block");
block.def("front", &Block::front, return_value_policy::reference)
.def("get_parent_program",
[](Block &self) { return self.GetParentOp()->GetParentProgram(); })
.def("get_ops",
[](Block &self) -> py::list {
py::list op_list;
......@@ -94,19 +96,22 @@ void BindBlock(py::module *m) {
void BindOperation(py::module *m) {
py::class_<Operation> op(*m, "Operation");
op.def("name", &Operation::name)
.def("get_parent",
.def("get_parent_block",
py::overload_cast<>(&Operation::GetParent),
return_value_policy::reference)
.def("get_parent",
.def("get_parent_block",
py::overload_cast<>(&Operation::GetParent, py::const_),
return_value_policy::reference)
.def("num_operands", &Operation::num_operands)
.def("num_results", &Operation::num_results)
.def("operand", &Operation::operand)
.def("result", &Operation::result)
.def("operand_source", &Operation::operand_source)
.def("operands",
[](Operation &self) -> py::list {
py::list op_list;
for (uint32_t i = 0; i < self.num_operands(); i++) {
op_list.append(self.op_operand(i));
op_list.append(self.operand(i));
}
return op_list;
})
......@@ -118,6 +123,14 @@ void BindOperation(py::module *m) {
}
return op_list;
})
.def("operands_source",
[](Operation &self) -> py::list {
py::list op_list;
for (uint32_t i = 0; i < self.num_operands(); i++) {
op_list.append(self.operand_source(i));
}
return op_list;
})
.def("get_input_names",
[](Operation &self) -> py::list {
py::list op_list;
......@@ -159,8 +172,11 @@ void BindOperation(py::module *m) {
void BindValue(py::module *m) {
py::class_<Value> value(*m, "Value");
value.def(
"get_defining_op", &Value::GetDefiningOp, return_value_policy::reference);
value
.def("get_defining_op",
&Value::GetDefiningOp,
return_value_policy::reference)
.def("__eq__", &Value::operator==);
}
void BindOpOperand(py::module *m) {
......
......@@ -239,7 +239,7 @@ void IrPrinter::PrintOpOperands(const Operation* op) {
std::vector<Value> op_operands;
op_operands.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) {
op_operands.push_back(op->operand(idx));
op_operands.push_back(op->operand_source(idx));
}
PrintInterleave(
op_operands.begin(),
......@@ -254,7 +254,7 @@ void IrPrinter::PrintOperandsType(const Operation* op) {
std::vector<Type> op_operand_types;
op_operand_types.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) {
auto op_operand = op->op_operand(idx);
auto op_operand = op->operand(idx);
if (op_operand) {
op_operand_types.push_back(op_operand.type());
} else {
......
......@@ -88,7 +88,9 @@ class IR_API OpBase {
const AttributeMap &attributes() const { return operation()->attributes(); }
Value operand(uint32_t index) const { return operation()->operand(index); }
Value operand_source(uint32_t index) const {
return operation()->operand_source(index);
}
OpResult result(uint32_t index) const { return operation()->result(index); }
......
......@@ -40,7 +40,7 @@ Operation *Operation::Create(OperationArgument &&argument) {
// Allocate the required memory based on the size and number of inputs, outputs,
// and operators, and construct it in the order of: OpOutlineResult,
// OpInlineResult, Operation, Operand.
// OpInlineResult, Operation, operand.
Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
const AttributeMap &attributes,
const std::vector<ir::Type> &output_types,
......@@ -132,7 +132,7 @@ void Operation::Destroy() {
// 4. Deconstruct OpOperand.
for (size_t idx = 0; idx < num_operands_; idx++) {
op_operand(idx).impl()->~OpOperandImpl();
operand(idx).impl()->~OpOperandImpl();
}
// 5. Free memory.
uint32_t max_inline_result_num =
......@@ -186,7 +186,7 @@ ir::OpResult Operation::result(uint32_t index) const {
}
}
OpOperand Operation::op_operand(uint32_t index) const {
OpOperand Operation::operand(uint32_t index) const {
if (index >= num_operands_) {
IR_THROW("index exceeds OP input range.");
}
......@@ -195,8 +195,8 @@ OpOperand Operation::op_operand(uint32_t index) const {
return OpOperand(reinterpret_cast<const detail::OpOperandImpl *>(ptr));
}
Value Operation::operand(uint32_t index) const {
OpOperand val = op_operand(index);
Value Operation::operand_source(uint32_t index) const {
OpOperand val = operand(index);
return val ? val.source() : Value();
}
......
......@@ -55,9 +55,9 @@ class IR_API alignas(8) Operation final {
OpResult result(uint32_t index) const;
OpOperand op_operand(uint32_t index) const;
OpOperand operand(uint32_t index) const;
Value operand(uint32_t index) const;
Value operand_source(uint32_t index) const;
/// Returns the region held by this operation at position 'index'.
Region &region(unsigned index);
......
......@@ -131,7 +131,7 @@ class GreedyPatternRewriteDriver : public ir::PatternRewriter {
void NotifyOperationRemoved(ir::Operation* op) override {
for (uint32_t i = 0; i < op->num_operands(); ++i) {
AddOperandToWorklist(op->operand(i));
AddOperandToWorklist(op->operand_source(i));
}
for (uint32_t i = 0; i < op->num_regions(); ++i) {
auto& region = op->region(i);
......
......@@ -174,9 +174,9 @@ TEST(program_test, program) {
// (8) Def SetParameterOp(c, "c")
auto op4 = builder.Build<ir::SetParameterOp>(op3->result(0), "c");
EXPECT_EQ(op4->op_operand(0).type().dialect().id(), paddle_dialect->id());
EXPECT_EQ(op4->operand(0).type().dialect().id(), paddle_dialect->id());
Interface *c_interface =
op4->op_operand(0).type().dialect().GetRegisteredInterface<Interface>();
op4->operand(0).type().dialect().GetRegisteredInterface<Interface>();
// ir::Parameter *parameter_c =
// c_interface->VariableToParameter(variable_c.get());
std::unique_ptr<ir::Parameter> parameter_c =
......
......@@ -91,10 +91,10 @@ TEST(value_test, value_test) {
// Test 2: op1_first_output -> op4_first_input
ir::OpResult op1_first_output = op1->result(0);
ir::OpOperand op4_first_input = op4->op_operand(0);
ir::OpOperand op4_first_input = op4->operand(0);
EXPECT_EQ(op1_first_output.first_use(), op4_first_input);
ir::OpOperand op3_first_input = op3->op_operand(0);
ir::OpOperand op3_first_input = op3->operand(0);
EXPECT_EQ(op4_first_input.next_use(), op3_first_input);
EXPECT_EQ(op3_first_input.next_use(), nullptr);
......@@ -110,11 +110,11 @@ TEST(value_test, value_test) {
// a = OP1(); b = OP2(); c = OP3(a, b); d, e, f, g, h, i, j = OP4(a, c);
//
c.ReplaceUsesWithIf(b, [](ir::OpOperand) { return true; });
EXPECT_EQ(op4->operand(1), b);
EXPECT_EQ(op4->operand_source(1), b);
EXPECT_TRUE(c.use_empty());
b.ReplaceAllUsesWith(a);
EXPECT_EQ(op4->operand(1), a);
EXPECT_EQ(op4->operand_source(1), a);
EXPECT_TRUE(b.use_empty());
// destroy
......
......@@ -386,10 +386,10 @@ class Conv2dFusionOpTest : public ir::Op<Conv2dFusionOpTest,
ir::OpResult residual_,
ir::AttributeMap attributes);
void Verify();
ir::Value input() { return operand(0); }
ir::Value filter() { return operand(1); }
ir::Value bias() { return operand(2); }
ir::Value residual() { return operand(3); }
ir::Value input() { return operand_source(0); }
ir::Value filter() { return operand_source(1); }
ir::Value bias() { return operand_source(2); }
ir::Value residual() { return operand_source(3); }
ir::OpResult output() { return result(0); }
ir::OpResult outputs() { return result(1); }
ir::Attribute attribute(const std::string &name) {
......@@ -752,19 +752,25 @@ void Conv2dFusionOpTest::Verify() {
4u,
phi::errors::PreconditionNotMet(
"The size %d of inputs must be equal to 4.", input_size));
PADDLE_ENFORCE(
(*this)->operand(0).type().isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 0th input."));
PADDLE_ENFORCE(
(*this)->operand(1).type().isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 1th input."));
PADDLE_ENFORCE(
(*this)->operand(2).type().isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 2th input."));
if (auto val = (*this)->op_operand(3)) {
PADDLE_ENFORCE((*this)
->operand_source(0)
.type()
.isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 0th input."));
PADDLE_ENFORCE((*this)
->operand_source(1)
.type()
.isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 1th input."));
PADDLE_ENFORCE((*this)
->operand_source(2)
.type()
.isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 2th input."));
if (auto val = (*this)->operand(3)) {
PADDLE_ENFORCE(val.type().isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 3th input."));
......
......@@ -30,8 +30,8 @@ def get_ir_program():
x_s = paddle.static.data('x', [4, 4], x.dtype)
x_s.stop_gradient = False
y_s = paddle.matmul(x_s, x_s)
y_s = paddle.add(x_s, y_s)
y_s = paddle.tanh(y_s)
z_s = paddle.add(y_s, y_s)
k_s = paddle.tanh(z_s)
newir_program = ir.translate_to_new_ir(main_program.desc)
return newir_program
......@@ -41,6 +41,11 @@ class TestPybind(unittest.TestCase):
newir_program = get_ir_program()
print(newir_program)
block = newir_program.block()
program = block.get_parent_program()
self.assertEqual(newir_program, program)
def test_block(self):
newir_program = get_ir_program()
block = newir_program.block()
......@@ -57,7 +62,7 @@ class TestPybind(unittest.TestCase):
matmul_op = newir_program.block().get_ops()[1]
add_op = newir_program.block().get_ops()[2]
tanh_op = newir_program.block().get_ops()[3]
parent_block = tanh_op.get_parent()
parent_block = tanh_op.get_parent_block()
parent_ops_num = len(parent_block.get_ops())
self.assertEqual(parent_ops_num, 4)
self.assertEqual(tanh_op.num_results(), 1)
......@@ -79,6 +84,13 @@ class TestPybind(unittest.TestCase):
matmul_op.result(0).set_stop_gradient(True)
self.assertEqual(matmul_op.result(0).get_stop_gradient(), True)
result_set = set()
for opresult in matmul_op.results():
result_set.add(opresult)
# self.assertTrue(add_op.operands()[0].source() in result_set)
# self.assertEqual(add_op.operands_source()[0] , matmul_op.results()[0],)
self.assertEqual(
tanh_op.operands()[0].source().get_defining_op().name(), "pd.add"
)
......@@ -87,6 +99,11 @@ class TestPybind(unittest.TestCase):
self.assertEqual(
tanh_op.operands()[0].source().get_defining_op().name(), "pd.matmul"
)
self.assertEqual(
tanh_op.operands()[0].source().get_defining_op(),
tanh_op.operands_source()[0].get_defining_op(),
)
self.assertEqual(add_op.result(0).use_empty(), True)
def test_type(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册