Unverified · Commit ef29468e, authored by xiaoguoguo626807, committed by GitHub

【new ir】add ir pybind api (#55745)

* add ir core

* add test

* modify name

* merge

* add test for __eq__

* shield test for __eq__

* --amend

* Update new_ir_compiler.cc
Parent 683287ba
@@ -79,7 +79,7 @@ std::vector<ir::LoweredFunc> NewIRCompiler::GetOpFunc(const ::ir::Operation& op,
   VLOG(4) << "GetOpFunc for op: " << op_name;
   // step 1: Deal with Oprands
   for (int i = 0; i < op.num_operands(); ++i) {
-    auto in_value = op.operand(i);
+    auto in_value = op.operand_source(i);
     // TODO(Aurelius84): For now, use addr as name but it's not wise.
     std::string input_id = CompatibleInfo::kInputPrefix +
                            std::to_string(std::hash<::ir::Value>()(in_value));
@@ -215,7 +215,7 @@ std::vector<std::string> NewIRCompiler::OpGetInputNames(
   std::vector<std::string> names;
   std::unordered_set<std::string> repeat;
   for (int i = 0; i < op.num_operands(); ++i) {
-    auto value = op.operand(i);
+    auto value = op.operand_source(i);
     std::string name = CompatibleInfo::kInputPrefix +
                        std::to_string(std::hash<::ir::Value>()(value));
     if (repeat.count(name)) {
@@ -264,7 +264,7 @@ std::shared_ptr<Scope> BuildScope(const Target& target,
   for (auto it = program.block()->begin(); it != program.block()->end(); ++it) {
     for (auto i = 0; i < (*it)->num_operands(); ++i) {
-      auto in_value = (*it)->operand(i);
+      auto in_value = (*it)->operand_source(i);
       create_var(CompatibleInfo::kInputPrefix, in_value);
     }
......
@@ -266,7 +266,7 @@ PhiKernelInstruction::PhiKernelInstruction(
   auto& no_need_buffer_ids = yaml_info_parser.NoNeedBufferIds();
   std::unordered_set<::ir::Value> no_need_buffer_values;
   for (size_t id = 0; id < no_need_buffer_ids.size(); id++) {
-    no_need_buffer_values.insert(op->operand(no_need_buffer_ids[id]));
+    no_need_buffer_values.insert(op->operand_source(no_need_buffer_ids[id]));
   }
   SetNoNeedBuffer(no_need_buffer_values);
   VLOG(6) << "finish process no need buffer";
@@ -302,7 +302,7 @@ void PhiKernelInstruction::InitInputsOutputsIds(
         variable_2_var_name) {
   std::unordered_map<ir::Value, std::vector<int>> inputs;
   for (size_t i = 0; i < op->num_operands(); i++) {
-    ir::Value value = op->operand(i);
+    ir::Value value = op->operand_source(i);
     if (value) {
       PADDLE_ENFORCE_NE(
           value_2_var_name.find(value),
......
@@ -14,7 +14,7 @@
 # generator op member function
-OP_GET_INPUT_TEMPLATE = """  ir::Value {input_name}() {{ return operand({input_index}); }}
+OP_GET_INPUT_TEMPLATE = """  ir::Value {input_name}() {{ return operand_source({input_index}); }}
 """
 OP_GET_OUTPUT_TEMPLATE = """  ir::OpResult {output_name}() {{ return result({output_index}); }}
 """
......
@@ -40,26 +40,26 @@ void {op_name}::Verify() {{}}
 """
 INPUT_TYPE_CHECK_TEMPLATE = """
-  PADDLE_ENFORCE((*this)->operand({index}).type().isa<{standard}>(),
+  PADDLE_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(),
                  phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));"""
 INPUT_VECTORTYPE_CHECK_TEMPLATE = """
-  if (auto vec_type = (*this)->operand({index}).type().dyn_cast<ir::VectorType>()) {{
+  if (auto vec_type = (*this)->operand_source({index}).type().dyn_cast<ir::VectorType>()) {{
     for (size_t i = 0; i < vec_type.size(); ++i) {{
       PADDLE_ENFORCE(vec_type[i].isa<{standard}>(),
                      phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
    }}
  }}
  else {{
-    PADDLE_ENFORCE((*this)->operand({index}).type().isa<{standard}>(),
+    PADDLE_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(),
                    phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
  }}"""
 INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """
-  if (auto val = (*this)->op_operand({index})) {{
+  if (auto val = (*this)->operand({index})) {{
     PADDLE_ENFORCE(val.type().isa<{standard}>(),
                    phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
  }}"""
 INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """
-  if (auto val = (*this)->op_operand({index})) {{
+  if (auto val = (*this)->operand({index})) {{
    if (auto vec_type = val.type().dyn_cast<ir::VectorType>()) {{
      for (size_t i = 0; i < vec_type.size(); i++) {{
        PADDLE_ENFORCE(vec_type[i].isa<{standard}>(),
......
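For context, these generator templates are expanded with Python's str.format, which is why the doubled braces survive as literal C++ braces. A minimal sketch of the expansion follows; the index and standard values are hypothetical, chosen only for illustration, and are not part of the diff.

# Sketch: expanding the (renamed) input type-check template with str.format.
# The placeholder values below are hypothetical examples, not from the diff.
INPUT_TYPE_CHECK_TEMPLATE = """
  PADDLE_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(),
                 phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));"""

print(INPUT_TYPE_CHECK_TEMPLATE.format(
    index=0, standard="paddle::dialect::DenseTensorType"))
# Emits a C++ check that reads input 0 through the renamed operand_source() accessor.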
@@ -140,7 +140,7 @@ void CheckInputVars(
   size_t input_num = op->num_operands();
   if (input_num > 0) {
     for (size_t i = 0; i < input_num; ++i) {
-      auto value = op->operand(i);
+      auto value = op->operand_source(i);
       if (value) {
         PADDLE_ENFORCE_NE(
             value_2_var_name.find(value),
@@ -298,7 +298,7 @@ void HandleForSpecialOp(
     tensor_array->clear();
     size_t input_num = op->num_operands();
     for (size_t i = 0; i < input_num; ++i) {
-      auto value = op->operand(i);
+      auto value = op->operand_source(i);
       PADDLE_ENFORCE_EQ(
           value_2_var_name->count(value),
           true,
@@ -315,7 +315,7 @@ void HandleForSpecialOp(
             .dyn_cast<ir::StrAttribute>()
             .AsString();
-    auto value = op->operand(0);
+    auto value = op->operand_source(0);
     // change opreand name to param_name
     auto orig_name = value_2_var_name->at(value);
@@ -336,7 +336,7 @@ void HandleForSpecialOp(
     auto var_name =
         op->attributes().at("name").dyn_cast<ir::StrAttribute>().AsString();
-    auto value = op->operand(0);
+    auto value = op->operand_source(0);
     // change opreand name to param_name
     auto orig_name = value_2_var_name->at(value);
@@ -372,7 +372,7 @@ void HandleForSpecialOp(
   if (op_name == "builtin.slice") {
     VLOG(6) << "Handle for builtin.slice";
     auto out_value = op->result(0);
-    auto in_value = op->operand(0);
+    auto in_value = op->operand_source(0);
     PADDLE_ENFORCE_EQ(value_2_var_name->count(in_value),
                       true,
                       phi::errors::PreconditionNotMet(
@@ -426,7 +426,7 @@ void HandleForInplaceOp(
     if (yaml_parser.HasInplace(value_name)) {
       std::string inplace_name = yaml_parser.InplaceName(value_name);
       ir::Value inplace_value =
-          op->operand(yaml_parser.InputName2Id().at(inplace_name));
+          op->operand_source(yaml_parser.InputName2Id().at(inplace_name));
       std::string var_name = value_2_var_name->at(inplace_value);
       VLOG(4) << "inplace: " << value_name << " -> " << inplace_name
               << " (var: " << var_name << ")";
@@ -547,7 +547,7 @@ void BuildRuntimeContext(
         true,
         phi::errors::NotFound("param [%s] MUST in name2id map", name));
     auto index = op_yaml_info.InputName2Id().at(name);
-    ir::Value ptr = op->operand(index);
+    ir::Value ptr = op->operand_source(index);
     auto in_var_name = name_map.at(ptr);
     VLOG(6) << "ctx->EmplaceBackInput: " << name << "\t" << in_var_name;
@@ -603,7 +603,7 @@ std::shared_ptr<paddle::framework::OperatorBase> BuildOperatorBase(
         true,
         phi::errors::NotFound("param [%s] MUST in name2id map", name));
     auto index = op_yaml_info.InputName2Id().at(name);
-    ir::Value ptr = op->operand(index);
+    ir::Value ptr = op->operand_source(index);
     auto in_var_name = name_map.at(ptr);
......
@@ -92,7 +92,7 @@ void BuildPhiContext(ir::Operation* op,
         true,
         phi::errors::NotFound("param [%s] MUST in name2id map", t));
     auto index = op_yaml_info.InputName2Id().at(t);
-    ir::Value ptr = op->operand(index);
+    ir::Value ptr = op->operand_source(index);
     if (!ptr) {
       phi::DenseTensor* ptr = nullptr;
       OutType in_ptr(ptr);
@@ -128,7 +128,7 @@ void BuildPhiContext(ir::Operation* op,
   for (auto& t : vec_kernel_fn_attr_params) {
     if (name2id.count(t)) {
       // tensor attribute, get information from input
-      ir::Value ptr = op->operand(name2id.at(t));
+      ir::Value ptr = op->operand_source(name2id.at(t));
       auto in_var_name = name_map.at(ptr);
......
@@ -114,12 +114,13 @@ class ConstantFoldingPattern : public ir::RewritePattern {
     std::vector<ir::OpResult> op_inputs;
     for (uint32_t i = 0; i < op->num_operands(); i++) {
       PADDLE_ENFORCE_EQ(
-          op->operand(i).type().isa<paddle::dialect::DenseTensorType>(),
+          op->operand_source(i).type().isa<paddle::dialect::DenseTensorType>(),
           true,
           phi::errors::InvalidArgument(
               "Op's input must be a dense tensor type."));
-      auto [param_name, param] = ir::GetParameterFromValue(op->operand(i));
+      auto [param_name, param] =
+          ir::GetParameterFromValue(op->operand_source(i));
       program->SetParameter(param_name,
                             std::make_unique<ir::Parameter>(*param));
@@ -128,8 +129,8 @@ class ConstantFoldingPattern : public ir::RewritePattern {
           param_var,
           phi::errors::InvalidArgument("Parameter var not in scope."));
-      auto get_parameter_op =
-          builder.Build<ir::GetParameterOp>(param_name, op->operand(i).type());
+      auto get_parameter_op = builder.Build<ir::GetParameterOp>(
+          param_name, op->operand_source(i).type());
       op_inputs.push_back(get_parameter_op->result(0));
     }
......
@@ -97,7 +97,7 @@ phi::KernelKey GetKernelKey(
   } else if (input_map.count(slot_name)) {
     // parse from input
     int in_index = input_map.at(slot_name);
-    auto type = map_value_pair.at(op->operand(in_index)).type();
+    auto type = map_value_pair.at(op->operand_source(in_index)).type();
     if (type.isa<paddle::dialect::AllocatedDenseTensorType>()) {
       kernel_data_type = TransToPhiDataType(
@@ -151,7 +151,7 @@ phi::KernelKey GetKernelKey(
   if (op->name() == "pd.uniform") {
     // try to process uniform, use shape to determin backend
    // TODO(phlrain): shuold support other initilize op
-    auto define_op = op->operand(0).GetDefiningOp();
+    auto define_op = op->operand_source(0).GetDefiningOp();
     if (define_op->name() == "pd.full_int_array") {
       auto shape = define_op->attributes()
                        .at("value")
@@ -183,7 +183,7 @@ phi::KernelKey GetKernelKey(
     if (op_info_parser != nullptr && op_info_parser->IsTensorAttribute(i)) {
       continue;
     }
-    auto input_tmp = op->operand(i);
+    auto input_tmp = op->operand_source(i);
     // NOTE: if not input_tmp, it's an optional input
     if (!input_tmp) {
       continue;
@@ -341,7 +341,7 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
     if ((*it)->num_operands() > 0) {
       for (size_t i = 0; i < (*it)->num_operands(); ++i) {
-        auto cur_in = (*it)->operand(i);
+        auto cur_in = (*it)->operand_source(i);
         if (!cur_in) {
           vec_inputs.push_back(ir::OpResult());
           continue;
......
@@ -64,7 +64,7 @@ Operation* GetDefiningOpForInput(Operation* op, uint32_t index) {
       index < op->num_operands(),
       true,
       phi::errors::InvalidArgument("Intput operand's index must be valid."));
-  return op->operand(index).GetDefiningOp();
+  return op->operand_source(index).GetDefiningOp();
 }

 Operation* GetFirstUseOperationForOutput(Operation* op, uint32_t index) {
......
@@ -77,6 +77,8 @@ void BindProgram(py::module *m) {
 void BindBlock(py::module *m) {
   py::class_<Block> block(*m, "Block");
   block.def("front", &Block::front, return_value_policy::reference)
+      .def("get_parent_program",
+           [](Block &self) { return self.GetParentOp()->GetParentProgram(); })
       .def("get_ops",
            [](Block &self) -> py::list {
              py::list op_list;
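For reference, a minimal Python sketch of what the new Block binding enables; it assumes the get_ir_program() helper defined in the test file at the end of this diff, and the names here are illustrative rather than part of the change.

# Sketch: navigate from a Block back to its owning Program via the new binding.
newir_program = get_ir_program()          # helper from the test file below
block = newir_program.block()
assert block.get_parent_program() == newir_program   # Block -> parent op -> Program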
@@ -94,19 +96,22 @@ void BindBlock(py::module *m) {
 void BindOperation(py::module *m) {
   py::class_<Operation> op(*m, "Operation");
   op.def("name", &Operation::name)
-      .def("get_parent",
+      .def("get_parent_block",
           py::overload_cast<>(&Operation::GetParent),
           return_value_policy::reference)
-      .def("get_parent",
+      .def("get_parent_block",
           py::overload_cast<>(&Operation::GetParent, py::const_),
           return_value_policy::reference)
+      .def("num_operands", &Operation::num_operands)
       .def("num_results", &Operation::num_results)
-      .def("operand", &Operation::operand)
       .def("result", &Operation::result)
+      .def("operand_source", &Operation::operand_source)
       .def("operands",
           [](Operation &self) -> py::list {
             py::list op_list;
             for (uint32_t i = 0; i < self.num_operands(); i++) {
-              op_list.append(self.op_operand(i));
+              op_list.append(self.operand(i));
             }
             return op_list;
           })
@@ -118,6 +123,14 @@ void BindOperation(py::module *m) {
             }
             return op_list;
           })
+      .def("operands_source",
+           [](Operation &self) -> py::list {
+             py::list op_list;
+             for (uint32_t i = 0; i < self.num_operands(); i++) {
+               op_list.append(self.operand_source(i));
+             }
+             return op_list;
+           })
       .def("get_input_names",
           [](Operation &self) -> py::list {
             py::list op_list;
@@ -159,8 +172,11 @@ void BindOperation(py::module *m) {
 void BindValue(py::module *m) {
   py::class_<Value> value(*m, "Value");
-  value.def(
-      "get_defining_op", &Value::GetDefiningOp, return_value_policy::reference);
+  value
+      .def("get_defining_op",
+           &Value::GetDefiningOp,
+           return_value_policy::reference)
+      .def("__eq__", &Value::operator==);
 }

 void BindOpOperand(py::module *m) {
......
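The new __eq__ binding makes Values comparable from Python, which is what the currently shielded assertions in the test file below exercise. A minimal sketch, assuming the same program setup as the tests:

# Sketch: Value.__eq__ compares the underlying value, so a use edge's source
# can be checked against the value returned by operands_source().
add_op = newir_program.block().get_ops()[2]   # e.g. the add op from the test below
assert add_op.operands()[0].source() == add_op.operands_source()[0]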
@@ -239,7 +239,7 @@ void IrPrinter::PrintOpOperands(const Operation* op) {
   std::vector<Value> op_operands;
   op_operands.reserve(num_op_operands);
   for (size_t idx = 0; idx < num_op_operands; idx++) {
-    op_operands.push_back(op->operand(idx));
+    op_operands.push_back(op->operand_source(idx));
   }
   PrintInterleave(
       op_operands.begin(),
@@ -254,7 +254,7 @@ void IrPrinter::PrintOperandsType(const Operation* op) {
   std::vector<Type> op_operand_types;
   op_operand_types.reserve(num_op_operands);
   for (size_t idx = 0; idx < num_op_operands; idx++) {
-    auto op_operand = op->op_operand(idx);
+    auto op_operand = op->operand(idx);
     if (op_operand) {
       op_operand_types.push_back(op_operand.type());
     } else {
......
@@ -88,7 +88,9 @@ class IR_API OpBase {
   const AttributeMap &attributes() const { return operation()->attributes(); }

-  Value operand(uint32_t index) const { return operation()->operand(index); }
+  Value operand_source(uint32_t index) const {
+    return operation()->operand_source(index);
+  }

   OpResult result(uint32_t index) const { return operation()->result(index); }
......
@@ -40,7 +40,7 @@ Operation *Operation::Create(OperationArgument &&argument) {
 // Allocate the required memory based on the size and number of inputs, outputs,
 // and operators, and construct it in the order of: OpOutlineResult,
-// OpInlineResult, Operation, Operand.
+// OpInlineResult, Operation, operand.
 Operation *Operation::Create(const std::vector<ir::OpResult> &inputs,
                              const AttributeMap &attributes,
                              const std::vector<ir::Type> &output_types,
@@ -132,7 +132,7 @@ void Operation::Destroy() {
   // 4. Deconstruct OpOperand.
   for (size_t idx = 0; idx < num_operands_; idx++) {
-    op_operand(idx).impl()->~OpOperandImpl();
+    operand(idx).impl()->~OpOperandImpl();
   }

   // 5. Free memory.
   uint32_t max_inline_result_num =
@@ -186,7 +186,7 @@ ir::OpResult Operation::result(uint32_t index) const {
   }
 }

-OpOperand Operation::op_operand(uint32_t index) const {
+OpOperand Operation::operand(uint32_t index) const {
   if (index >= num_operands_) {
     IR_THROW("index exceeds OP input range.");
   }
@@ -195,8 +195,8 @@ OpOperand Operation::op_operand(uint32_t index) const {
   return OpOperand(reinterpret_cast<const detail::OpOperandImpl *>(ptr));
 }

-Value Operation::operand(uint32_t index) const {
-  OpOperand val = op_operand(index);
+Value Operation::operand_source(uint32_t index) const {
+  OpOperand val = operand(index);
   return val ? val.source() : Value();
 }
......
@@ -55,9 +55,9 @@ class IR_API alignas(8) Operation final {
   OpResult result(uint32_t index) const;

-  OpOperand op_operand(uint32_t index) const;
+  OpOperand operand(uint32_t index) const;

-  Value operand(uint32_t index) const;
+  Value operand_source(uint32_t index) const;

   /// Returns the region held by this operation at position 'index'.
   Region &region(unsigned index);
@@ -131,7 +131,7 @@ class GreedyPatternRewriteDriver : public ir::PatternRewriter {
   void NotifyOperationRemoved(ir::Operation* op) override {
     for (uint32_t i = 0; i < op->num_operands(); ++i) {
-      AddOperandToWorklist(op->operand(i));
+      AddOperandToWorklist(op->operand_source(i));
     }
     for (uint32_t i = 0; i < op->num_regions(); ++i) {
       auto& region = op->region(i);
......
@@ -174,9 +174,9 @@ TEST(program_test, program) {
   // (8) Def SetParameterOp(c, "c")
   auto op4 = builder.Build<ir::SetParameterOp>(op3->result(0), "c");

-  EXPECT_EQ(op4->op_operand(0).type().dialect().id(), paddle_dialect->id());
+  EXPECT_EQ(op4->operand(0).type().dialect().id(), paddle_dialect->id());
   Interface *c_interface =
-      op4->op_operand(0).type().dialect().GetRegisteredInterface<Interface>();
+      op4->operand(0).type().dialect().GetRegisteredInterface<Interface>();
   // ir::Parameter *parameter_c =
   //     c_interface->VariableToParameter(variable_c.get());
   std::unique_ptr<ir::Parameter> parameter_c =
......
@@ -91,10 +91,10 @@ TEST(value_test, value_test) {
   // Test 2: op1_first_output -> op4_first_input
   ir::OpResult op1_first_output = op1->result(0);
-  ir::OpOperand op4_first_input = op4->op_operand(0);
+  ir::OpOperand op4_first_input = op4->operand(0);

   EXPECT_EQ(op1_first_output.first_use(), op4_first_input);
-  ir::OpOperand op3_first_input = op3->op_operand(0);
+  ir::OpOperand op3_first_input = op3->operand(0);

   EXPECT_EQ(op4_first_input.next_use(), op3_first_input);
   EXPECT_EQ(op3_first_input.next_use(), nullptr);
@@ -110,11 +110,11 @@ TEST(value_test, value_test) {
   //   a = OP1(); b = OP2(); c = OP3(a, b); d, e, f, g, h, i, j = OP4(a, c);
   //
   c.ReplaceUsesWithIf(b, [](ir::OpOperand) { return true; });
-  EXPECT_EQ(op4->operand(1), b);
+  EXPECT_EQ(op4->operand_source(1), b);
   EXPECT_TRUE(c.use_empty());

   b.ReplaceAllUsesWith(a);
-  EXPECT_EQ(op4->operand(1), a);
+  EXPECT_EQ(op4->operand_source(1), a);
   EXPECT_TRUE(b.use_empty());

   // destroy
......
@@ -386,10 +386,10 @@ class Conv2dFusionOpTest : public ir::Op<Conv2dFusionOpTest,
                      ir::OpResult residual_,
                      ir::AttributeMap attributes);
   void Verify();
-  ir::Value input() { return operand(0); }
-  ir::Value filter() { return operand(1); }
-  ir::Value bias() { return operand(2); }
-  ir::Value residual() { return operand(3); }
+  ir::Value input() { return operand_source(0); }
+  ir::Value filter() { return operand_source(1); }
+  ir::Value bias() { return operand_source(2); }
+  ir::Value residual() { return operand_source(3); }
   ir::OpResult output() { return result(0); }
   ir::OpResult outputs() { return result(1); }
   ir::Attribute attribute(const std::string &name) {
@@ -752,19 +752,25 @@ void Conv2dFusionOpTest::Verify() {
       4u,
       phi::errors::PreconditionNotMet(
           "The size %d of inputs must be equal to 4.", input_size));
-  PADDLE_ENFORCE(
-      (*this)->operand(0).type().isa<paddle::dialect::DenseTensorType>(),
-      phi::errors::PreconditionNotMet(
-          "Type validation failed for the 0th input."));
-  PADDLE_ENFORCE(
-      (*this)->operand(1).type().isa<paddle::dialect::DenseTensorType>(),
-      phi::errors::PreconditionNotMet(
-          "Type validation failed for the 1th input."));
-  PADDLE_ENFORCE(
-      (*this)->operand(2).type().isa<paddle::dialect::DenseTensorType>(),
-      phi::errors::PreconditionNotMet(
-          "Type validation failed for the 2th input."));
-  if (auto val = (*this)->op_operand(3)) {
+  PADDLE_ENFORCE((*this)
+                     ->operand_source(0)
+                     .type()
+                     .isa<paddle::dialect::DenseTensorType>(),
+                 phi::errors::PreconditionNotMet(
+                     "Type validation failed for the 0th input."));
+  PADDLE_ENFORCE((*this)
+                     ->operand_source(1)
+                     .type()
+                     .isa<paddle::dialect::DenseTensorType>(),
+                 phi::errors::PreconditionNotMet(
+                     "Type validation failed for the 1th input."));
+  PADDLE_ENFORCE((*this)
+                     ->operand_source(2)
+                     .type()
+                     .isa<paddle::dialect::DenseTensorType>(),
+                 phi::errors::PreconditionNotMet(
+                     "Type validation failed for the 2th input."));
+  if (auto val = (*this)->operand(3)) {
     PADDLE_ENFORCE(val.type().isa<paddle::dialect::DenseTensorType>(),
                    phi::errors::PreconditionNotMet(
                        "Type validation failed for the 3th input."));
......
@@ -30,8 +30,8 @@ def get_ir_program():
         x_s = paddle.static.data('x', [4, 4], x.dtype)
         x_s.stop_gradient = False
         y_s = paddle.matmul(x_s, x_s)
-        y_s = paddle.add(x_s, y_s)
-        y_s = paddle.tanh(y_s)
+        z_s = paddle.add(y_s, y_s)
+        k_s = paddle.tanh(z_s)
     newir_program = ir.translate_to_new_ir(main_program.desc)
     return newir_program
@@ -41,6 +41,11 @@ class TestPybind(unittest.TestCase):
         newir_program = get_ir_program()
         print(newir_program)
+        block = newir_program.block()
+        program = block.get_parent_program()
+        self.assertEqual(newir_program, program)
+
     def test_block(self):
         newir_program = get_ir_program()
         block = newir_program.block()
@@ -57,7 +62,7 @@ class TestPybind(unittest.TestCase):
         matmul_op = newir_program.block().get_ops()[1]
         add_op = newir_program.block().get_ops()[2]
         tanh_op = newir_program.block().get_ops()[3]
-        parent_block = tanh_op.get_parent()
+        parent_block = tanh_op.get_parent_block()
         parent_ops_num = len(parent_block.get_ops())
         self.assertEqual(parent_ops_num, 4)
         self.assertEqual(tanh_op.num_results(), 1)
@@ -79,6 +84,13 @@ class TestPybind(unittest.TestCase):
         matmul_op.result(0).set_stop_gradient(True)
         self.assertEqual(matmul_op.result(0).get_stop_gradient(), True)
+
+        result_set = set()
+        for opresult in matmul_op.results():
+            result_set.add(opresult)
+        # self.assertTrue(add_op.operands()[0].source() in result_set)
+        # self.assertEqual(add_op.operands_source()[0] , matmul_op.results()[0],)
+
         self.assertEqual(
             tanh_op.operands()[0].source().get_defining_op().name(), "pd.add"
         )
@@ -87,6 +99,11 @@ class TestPybind(unittest.TestCase):
         self.assertEqual(
             tanh_op.operands()[0].source().get_defining_op().name(), "pd.matmul"
         )
+        self.assertEqual(
+            tanh_op.operands()[0].source().get_defining_op(),
+            tanh_op.operands_source()[0].get_defining_op(),
+        )
+
         self.assertEqual(add_op.result(0).use_empty(), True)

     def test_type(self):
......