diff --git a/paddle/fluid/ir/dialect/op_gen.py b/paddle/fluid/ir/dialect/op_gen.py index 65eabda77470f1f0718e49d94c1a81800d013aa5..7aa49f583f45ea07df82acf6f3bdcbf803876b68 100644 --- a/paddle/fluid/ir/dialect/op_gen.py +++ b/paddle/fluid/ir/dialect/op_gen.py @@ -78,7 +78,7 @@ op_n_attribute_declare_str = ( "static const char *attributes_name[{attribute_num}];" ) -OP_GET_INPUT_TEMPLATE = """ ir::OpOperand {input_name}() {{ return operation()->operand({input_index}); }} +OP_GET_INPUT_TEMPLATE = """ ir::OpOperand {input_name}() {{ return operation()->op_operand({input_index}); }} """ OP_GET_OUTPUT_TEMPLATE = """ ir::OpResult {output_name}() {{ return operation()->result({output_index}); }} """ @@ -1046,7 +1046,7 @@ def GenBuildOutputs( name=op_output_name_list[idx] ) - build_output_str += " argument.AddTypes(argument_outputs.begin(), argument_outputs.end());\n" + build_output_str += " argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());\n" return build_output_str diff --git a/paddle/fluid/ir/dialect/op_verify_gen.py b/paddle/fluid/ir/dialect/op_verify_gen.py index 12714e4af4d3b1f58650ac7980f751c0ae5eba54..7b65e8dce9181e46f1050e28eb4e96423f32f453 100644 --- a/paddle/fluid/ir/dialect/op_verify_gen.py +++ b/paddle/fluid/ir/dialect/op_verify_gen.py @@ -54,12 +54,12 @@ INPUT_VECTORTYPE_CHECK_TEMPLATE = """ phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); }}""" INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """ - if (auto val = (*this)->operand({index})) {{ + if (auto val = (*this)->op_operand({index})) {{ PADDLE_ENFORCE(val.type().isa<{standard}>(), phi::errors::PreconditionNotMet("Type validation failed for the {index}th input.")); }}""" INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """ - if (auto val = (*this)->operand({index})) {{ + if (auto val = (*this)->op_operand({index})) {{ if (auto vec_type = val.type().dyn_cast()) {{ for (size_t i = 0; i < vec_type.size(); i++) {{ PADDLE_ENFORCE(vec_type[i].isa<{standard}>(), diff --git a/paddle/fluid/ir/pass/pd_op_to_kernel_pass.cc b/paddle/fluid/ir/pass/pd_op_to_kernel_pass.cc index 308c7bbb9feb178c6ca258c5e39ad5fe67b6748e..edbd4b4bc8e5eb75d54b4e862009e0b6a7dfbec8 100644 --- a/paddle/fluid/ir/pass/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/ir/pass/pd_op_to_kernel_pass.cc @@ -86,7 +86,6 @@ phi::KernelKey GetKernelKey( dialect::DenseTensorType type = op->operand(in_index) - .source() .type() .dyn_cast(); kernel_data_type = TransToPhiDataType(type.dtype()); @@ -108,7 +107,7 @@ phi::KernelKey GetKernelKey( if (op->name() == "pd.uniform") { // try to process uniform, use shape to determin backend // TODO(phlrain): shuold support other initilize op - auto define_op = op->operand(0).source().GetDefiningOp(); + auto define_op = op->operand(0).GetDefiningOp(); if (define_op->name() == "pd.full_int_array") { auto shape = define_op->attributes() .at("value") @@ -140,8 +139,7 @@ phi::KernelKey GetKernelKey( if ((input_info.size() > i) && input_info[i].is_mutable_attribute) { continue; } - auto input_tmp = op->operand(i).source(); - + auto input_tmp = op->operand(i); auto new_input_tmp = map_value_pair.at(input_tmp); auto input_type = new_input_tmp.type(); @@ -262,7 +260,7 @@ std::unique_ptr PdOpLowerToKernelPass(ir::Program* prog) { if ((*it)->num_operands() > 0) { for (size_t i = 0; i < (*it)->num_operands(); ++i) { - auto cur_in = (*it)->operand(i).source(); + auto cur_in = (*it)->operand(i); auto new_in = map_value_pair.at(cur_in); auto new_in_type = new_in.type(); diff --git a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc index 62eaa0e06682b0a3c9009d9abc88bd451be2c284..76c3848343d8f69ac08669d8131c0b82027f0bcd 100644 --- a/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc +++ b/paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.cc @@ -103,7 +103,7 @@ void BuildScope(ir::Block* block, auto tensor_array = var->GetMutable(); for (size_t i = 0; i < input_num; ++i) { - auto ptr = (*it)->operand(i).source(); + auto ptr = (*it)->operand(i); PADDLE_ENFORCE_EQ(name_map->count(ptr), true, @@ -119,7 +119,7 @@ void BuildScope(ir::Block* block, if (input_num > 0) { for (size_t i = 0; i < input_num; ++i) { - auto ptr = (*it)->operand(i).source(); + auto ptr = (*it)->operand(i); std::string name; if (name_map->find(ptr) != name_map->end()) { name = name_map->at(ptr); @@ -191,7 +191,7 @@ void BuildInferMetaContext( auto& t = vec_param_list[input_index]; if (input_index_map.count(t)) { // get information from input - ir::Value ptr = op->operand(input_index_map[t]).source(); + ir::Value ptr = op->operand(input_index_map[t]); auto in_var_name = name_map.at(ptr); if (mutable_attr_type_map.count(t)) { @@ -316,7 +316,7 @@ void BuildPhiKernelContext( for (auto& t : vec_param_list) { if (input_index_map.count(t)) { // get information from input - ir::Value ptr = op->operand(input_index_map[t]).source(); + ir::Value ptr = op->operand(input_index_map[t]); auto in_var_name = name_map.at(ptr); if (input_map != nullptr) { // only deal with single input for now, [todo] need support multi input diff --git a/paddle/ir/core/ir_printer.cc b/paddle/ir/core/ir_printer.cc index c87bba1c8b35626d19731f56e1b65712350511d9..bb7a0c9e825d233eae4e764663725c548e314004 100644 --- a/paddle/ir/core/ir_printer.cc +++ b/paddle/ir/core/ir_printer.cc @@ -230,7 +230,7 @@ void IrPrinter::PrintOpOperands(Operation* op) { std::vector op_operands; op_operands.reserve(num_op_operands); for (size_t idx = 0; idx < num_op_operands; idx++) { - op_operands.push_back(op->operand(idx).source()); + op_operands.push_back(op->operand(idx)); } PrintInterleave( op_operands.begin(), @@ -245,11 +245,11 @@ void IrPrinter::PrintOperandsType(Operation* op) { std::vector op_operand_types; op_operand_types.reserve(num_op_operands); for (size_t idx = 0; idx < num_op_operands; idx++) { - auto op_operand = op->operand(idx); + auto op_operand = op->op_operand(idx); if (op_operand) { - op_operand_types.push_back(op->operand(idx).source().type()); + op_operand_types.push_back(op_operand.type()); } else { - op_operand_types.push_back(Type(nullptr)); + op_operand_types.push_back(Type()); } } os << " ("; diff --git a/paddle/ir/core/operation.cc b/paddle/ir/core/operation.cc index 01cbafb5d59a8976cef9be4cd384af1ef6756a2f..caabae8530b27a960eba8566d38342c966a772b5 100644 --- a/paddle/ir/core/operation.cc +++ b/paddle/ir/core/operation.cc @@ -130,7 +130,7 @@ void Operation::Destroy() { // 4. Deconstruct OpOperand. for (size_t idx = 0; idx < num_operands_; idx++) { - operand(idx).impl()->~OpOperandImpl(); + op_operand(idx).impl()->~OpOperandImpl(); } // 5. Free memory. uint32_t max_inline_result_num = @@ -184,13 +184,18 @@ ir::OpResult Operation::result(uint32_t index) const { } } -ir::OpOperand Operation::operand(uint32_t index) const { +OpOperand Operation::op_operand(uint32_t index) const { if (index >= num_operands_) { IR_THROW("index exceeds OP input range."); } const char *ptr = reinterpret_cast(this) + sizeof(Operation) + (index) * sizeof(detail::OpOperandImpl); - return ir::OpOperand(reinterpret_cast(ptr)); + return OpOperand(reinterpret_cast(ptr)); +} + +Value Operation::operand(uint32_t index) const { + OpOperand val = op_operand(index); + return val ? val.source() : Value(); } std::string Operation::name() const { diff --git a/paddle/ir/core/operation.h b/paddle/ir/core/operation.h index bf223f2fdf966bb12a9630a10a954e518764ec9b..a7296efa4c84d777acd439bc44f4b9f9f2fd5582 100644 --- a/paddle/ir/core/operation.h +++ b/paddle/ir/core/operation.h @@ -53,7 +53,9 @@ class IR_API alignas(8) Operation final { OpResult result(uint32_t index) const; - OpOperand operand(uint32_t index) const; + OpOperand op_operand(uint32_t index) const; + + Value operand(uint32_t index) const; /// Returns the region held by this operation at position 'index'. Region ®ion(unsigned index); diff --git a/paddle/ir/core/operation_utils.h b/paddle/ir/core/operation_utils.h index cbf19a4bb74c7645ac2bbedff4ba62ae281a552d..3e4610b0f1dd2d7ab9ccff72eef0a26dd1d3f154 100644 --- a/paddle/ir/core/operation_utils.h +++ b/paddle/ir/core/operation_utils.h @@ -61,7 +61,7 @@ struct OperationArgument { void AddOutput(Type type) { output_types.emplace_back(type); } template - void AddTypes(InputIt first, InputIt last); + void AddOutputs(InputIt first, InputIt last); /// Add an attribute with the specified name. void AddAttribute(const std::string& name, Attribute attr) { @@ -86,7 +86,7 @@ void OperationArgument::AddOperands(InputIt first, InputIt last) { } } template -void OperationArgument::AddTypes(InputIt first, InputIt last) { +void OperationArgument::AddOutputs(InputIt first, InputIt last) { while (first != last) { output_types.emplace_back(*first++); } diff --git a/paddle/ir/core/value.cc b/paddle/ir/core/value.cc index a5ca59d19759b5fefe677002a61898470f3e4f3a..666be5481c41826d1aa0413b7e48e6e50238dcfb 100644 --- a/paddle/ir/core/value.cc +++ b/paddle/ir/core/value.cc @@ -47,7 +47,7 @@ Operation *OpOperand::owner() const { return impl()->owner(); } void OpOperand::RemoveFromUdChain() { return impl()->RemoveFromUdChain(); } detail::OpOperandImpl *OpOperand::impl() const { - IR_ENFORCE(impl_, "Can't use impl() interface while operand is null."); + IR_ENFORCE(impl_, "Can't use impl() interface while op_operand is null."); return impl_; } // Value diff --git a/paddle/ir/core/value.h b/paddle/ir/core/value.h index 429516acc4a6b395690f37c916ea7d93faf9fb1f..88f23cd1ee5177f3aaaaa0259b6c506cb981142d 100644 --- a/paddle/ir/core/value.h +++ b/paddle/ir/core/value.h @@ -28,8 +28,8 @@ class OpResultImpl; } // namespace detail /// -/// \brief OpOperand class represents the operand of operation. This class only -/// provides interfaces, for specific implementation, see Impl class. +/// \brief OpOperand class represents the op_operand of operation. This class +/// only provides interfaces, for specific implementation, see Impl class. /// class IR_API OpOperand { public: diff --git a/paddle/ir/core/value_impl.h b/paddle/ir/core/value_impl.h index 1e21e8f0d19c6bb24c49a09f2eeb53e6af168797..9c3c56cdefd387813c3888875816541b2f74f723 100644 --- a/paddle/ir/core/value_impl.h +++ b/paddle/ir/core/value_impl.h @@ -35,7 +35,7 @@ class OpOperandImpl { void set_source(Value value); - /// Remove this operand from the current use list. + /// Remove this op_operand from the current use list. void RemoveFromUdChain(); ~OpOperandImpl(); @@ -62,7 +62,7 @@ class OpOperandImpl { /// \brief ValueImpl is the base class of all derived Value classes such as /// OpResultImpl. This class defines all the information and usage interface in /// the IR Value. Each Value include three attributes: -/// (1) type: ir::Type; (2) UD-chain of value: OpOperandImpl*, first operand +/// (1) type: ir::Type; (2) UD-chain of value: OpOperandImpl*, first op_operand /// address with offset of this value; (3) index: the position where the output /// list of the parent operator. /// diff --git a/paddle/ir/pattern_rewrite/pattern_rewrite_driver.cc b/paddle/ir/pattern_rewrite/pattern_rewrite_driver.cc index 21a673e6b3a15c24fe58d89f5cba7ebc56697c06..8ee6c8886f60d81775a8b4814b2c8d27432c40ea 100644 --- a/paddle/ir/pattern_rewrite/pattern_rewrite_driver.cc +++ b/paddle/ir/pattern_rewrite/pattern_rewrite_driver.cc @@ -131,7 +131,7 @@ class GreedyPatternRewriteDriver : public ir::PatternRewriter { void NotifyOperationRemoved(ir::Operation* op) override { for (uint32_t i = 0; i < op->num_operands(); ++i) { - AddOperandToWorklist(op->operand(i).source()); + AddOperandToWorklist(op->operand(i)); } for (uint32_t i = 0; i < op->num_regions(); ++i) { auto& region = op->region(i); diff --git a/test/cpp/ir/core/ir_op_test.cc b/test/cpp/ir/core/ir_op_test.cc index 6ab59c6014d62d87221146121a13b6cd511a9f63..0e246af03cbe10360b3cbe568ee19658f3282ae0 100644 --- a/test/cpp/ir/core/ir_op_test.cc +++ b/test/cpp/ir/core/ir_op_test.cc @@ -109,13 +109,9 @@ class Operation1 : public ir::Op { std::unordered_map attributes = CreateAttributeMap({"op1_attr1", "op1_attr2"}, {"op1_attr1", "op1_attr2"}); - argument.AddOperands::iterator>(inputs.begin(), - inputs.end()); - argument.AddTypes::iterator>(output_types.begin(), - output_types.end()); - argument.AddAttributes< - std::unordered_map::iterator>( - attributes.begin(), attributes.end()); + argument.AddOperands(inputs.begin(), inputs.end()); + argument.AddOutputs(output_types.begin(), output_types.end()); + argument.AddAttributes(attributes.begin(), attributes.end()); } }; const char *Operation1::attributes_name[attributes_num] = {"op1_attr1", diff --git a/test/cpp/ir/core/ir_program_test.cc b/test/cpp/ir/core/ir_program_test.cc index 6e2a8e5acb97571541674b178f08587950878733..a6345829d07df84e6e1043bb8bf85600e5abb126 100644 --- a/test/cpp/ir/core/ir_program_test.cc +++ b/test/cpp/ir/core/ir_program_test.cc @@ -174,9 +174,9 @@ TEST(program_test, program) { // (8) Def SetParameterOp(c, "c") auto op4 = builder.Build(op3->result(0), "c"); - EXPECT_EQ(op4->operand(0).type().dialect().id(), paddle_dialect->id()); + EXPECT_EQ(op4->op_operand(0).type().dialect().id(), paddle_dialect->id()); Interface *c_interface = - op4->operand(0).type().dialect().GetRegisteredInterface(); + op4->op_operand(0).type().dialect().GetRegisteredInterface(); // ir::Parameter *parameter_c = // c_interface->VariableToParameter(variable_c.get()); std::unique_ptr parameter_c = diff --git a/test/cpp/ir/core/ir_value_test.cc b/test/cpp/ir/core/ir_value_test.cc index b77552122bfc19ae888573d23308b0b5e512e2f2..3ad5c501464621b91fa574bd70c33c4420d4ea41 100644 --- a/test/cpp/ir/core/ir_value_test.cc +++ b/test/cpp/ir/core/ir_value_test.cc @@ -91,10 +91,10 @@ TEST(value_test, value_test) { // Test 2: op1_first_output -> op4_first_input ir::OpResult op1_first_output = op1->result(0); - ir::OpOperand op4_first_input = op4->operand(0); + ir::OpOperand op4_first_input = op4->op_operand(0); EXPECT_EQ(op1_first_output.first_use(), op4_first_input); - ir::OpOperand op3_first_input = op3->operand(0); + ir::OpOperand op3_first_input = op3->op_operand(0); EXPECT_EQ(op4_first_input.next_use(), op3_first_input); EXPECT_EQ(op3_first_input.next_use(), nullptr); @@ -110,11 +110,11 @@ TEST(value_test, value_test) { // a = OP1(); b = OP2(); c = OP3(a, b); d, e, f, g, h, i, j = OP4(a, c); // c.ReplaceUsesWithIf(b, [](ir::OpOperand) { return true; }); - EXPECT_EQ(op4->operand(1).source(), b); + EXPECT_EQ(op4->operand(1), b); EXPECT_TRUE(c.use_empty()); b.ReplaceAllUsesWith(a); - EXPECT_EQ(op4->operand(1).source(), a); + EXPECT_EQ(op4->operand(1), a); EXPECT_TRUE(b.use_empty()); // destroy diff --git a/test/cpp/ir/core/phi_kernel_adaptor.h b/test/cpp/ir/core/phi_kernel_adaptor.h index e8847977bc4cc60af5d6208eae7aebb9acafbb7c..b82572a15f3510135f60f2f51eb62670c89fbe3f 100644 --- a/test/cpp/ir/core/phi_kernel_adaptor.h +++ b/test/cpp/ir/core/phi_kernel_adaptor.h @@ -56,7 +56,7 @@ void BuildScope(ir::Block* block, int input = (*it)->num_operands(); if (input > 0) { for (int i = 0; i < input; ++i) { - auto ptr = (*it)->operand(i).source(); + auto ptr = (*it)->operand(i); std::string name; if (name_map->find(ptr) != name_map->end()) { name = name_map->at(ptr); @@ -130,7 +130,7 @@ void build_context(ir::Operation* op, for (auto& t : vec_param_list) { if (input_index_map.count(t)) { // get information from input - ir::Value ptr = op->operand(input_index_map[t]).source(); + ir::Value ptr = op->operand(input_index_map[t]); auto in_var_name = name_map.at(ptr); if (mutable_attr_type_map.count(t)) { diff --git a/test/cpp/ir/pass/pass_manager_test.cc b/test/cpp/ir/pass/pass_manager_test.cc index 87a5abd64452e58499856d66f455ae5ed4d9293a..b77df8a092097d6b8afb86afe0b56c92cff5a77a 100644 --- a/test/cpp/ir/pass/pass_manager_test.cc +++ b/test/cpp/ir/pass/pass_manager_test.cc @@ -247,10 +247,9 @@ TEST(pass_manager, PassManager) { // (7) Def SetParameterOp(c, "c") auto op4 = builder.Build(op3->result(0), "c"); - EXPECT_EQ(op4->operand(0).source().type().dialect().id(), - paddle_dialect->id()); + EXPECT_EQ(op4->operand(0).type().dialect().id(), paddle_dialect->id()); Interface *c_interface = - op4->operand(0).type().dialect().GetRegisteredInterface(); + op4->op_operand(0).type().dialect().GetRegisteredInterface(); // ir::Parameter *parameter_c = // c_interface->VariableToParameter(variable_c.get()); diff --git a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc index bd1b83c066497198f99b4daf0cd0a4d784fbce63..dc3b716a4953f2d486da30f9c2fd51dcdbb1b93a 100644 --- a/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc +++ b/test/cpp/ir/pattern_rewrite/pattern_rewrite_test.cc @@ -181,7 +181,7 @@ class TransposePatternRewrite bool MatchAndRewrite(paddle::dialect::TransposeOp op, ir::PatternRewriter &rewriter) const override { - auto prev_op = op->operand(0).source().GetDefiningOp(); + auto prev_op = op->operand(0).GetDefiningOp(); std::vector axis_last = GetAxis(op); auto prev_trans_op = prev_op->dyn_cast(); if (prev_trans_op) { @@ -191,7 +191,7 @@ class TransposePatternRewrite auto new_perm = GetPerm(axis_first, axis_last); rewriter.SetInsertionPoint(op); auto new_op = rewriter.Build( - prev_op->operand(0).source().GetDefiningOp()->result(0), new_perm); + prev_op->operand(0).GetDefiningOp()->result(0), new_perm); rewriter.ReplaceOp(op, {new_op.out()}); return true; }