Unverified commit c3077ec1, authored by winter-wang, committed by GitHub

[IR] add op_operand api for ir::Operation. (#54875)

Parent 96652265
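In short, this commit splits `ir::Operation`'s operand accessors: the new `op_operand(index)` returns the `ir::OpOperand` slot, while `operand(index)` now returns the underlying `ir::Value` directly, which removes the `.source()` unwrapping at call sites throughout the diff below; `OperationArgument::AddTypes` is also renamed to `AddOutputs`. A minimal sketch of how a call site changes; the include paths and the wrapper function are assumptions for illustration, not code from this patch.

```cpp
#include "paddle/ir/core/operation.h"  // assumed header locations
#include "paddle/ir/core/value.h"

// Hypothetical helper, only to contrast the two accessors.
void Demo(ir::Operation *op) {
  // Before this commit, operand(0) returned an ir::OpOperand, so callers
  // unwrapped the value explicitly:
  //   ir::Value in = op->operand(0).source();

  // After this commit, operand(0) yields the ir::Value directly,
  ir::Value in = op->operand(0);
  // and op_operand(0) is the accessor for the OpOperand slot itself.
  ir::OpOperand use = op->op_operand(0);

  (void)in;
  (void)use;
}
```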
@@ -78,7 +78,7 @@ op_n_attribute_declare_str = (
"static const char *attributes_name[{attribute_num}];"
)
OP_GET_INPUT_TEMPLATE = """ ir::OpOperand {input_name}() {{ return operation()->operand({input_index}); }}
OP_GET_INPUT_TEMPLATE = """ ir::OpOperand {input_name}() {{ return operation()->op_operand({input_index}); }}
"""
OP_GET_OUTPUT_TEMPLATE = """ ir::OpResult {output_name}() {{ return operation()->result({output_index}); }}
"""
@@ -1046,7 +1046,7 @@ def GenBuildOutputs(
name=op_output_name_list[idx]
)
build_output_str += " argument.AddTypes(argument_outputs.begin(), argument_outputs.end());\n"
build_output_str += " argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());\n"
return build_output_str
......
@@ -54,12 +54,12 @@ INPUT_VECTORTYPE_CHECK_TEMPLATE = """
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}"""
INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """
if (auto val = (*this)->operand({index})) {{
if (auto val = (*this)->op_operand({index})) {{
PADDLE_ENFORCE(val.type().isa<{standard}>(),
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
}}"""
INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """
if (auto val = (*this)->operand({index})) {{
if (auto val = (*this)->op_operand({index})) {{
if (auto vec_type = val.type().dyn_cast<ir::VectorType>()) {{
for (size_t i = 0; i < vec_type.size(); i++) {{
PADDLE_ENFORCE(vec_type[i].isa<{standard}>(),
......
@@ -86,7 +86,6 @@ phi::KernelKey GetKernelKey(
dialect::DenseTensorType type =
op->operand(in_index)
.source()
.type()
.dyn_cast<paddle::dialect::DenseTensorType>();
kernel_data_type = TransToPhiDataType(type.dtype());
@@ -108,7 +107,7 @@ phi::KernelKey GetKernelKey(
if (op->name() == "pd.uniform") {
// try to process uniform, use shape to determine backend
// TODO(phlrain): should support other initialize ops
auto define_op = op->operand(0).source().GetDefiningOp();
auto define_op = op->operand(0).GetDefiningOp();
if (define_op->name() == "pd.full_int_array") {
auto shape = define_op->attributes()
.at("value")
@@ -140,8 +139,7 @@ phi::KernelKey GetKernelKey(
if ((input_info.size() > i) && input_info[i].is_mutable_attribute) {
continue;
}
auto input_tmp = op->operand(i).source();
auto input_tmp = op->operand(i);
auto new_input_tmp = map_value_pair.at(input_tmp);
auto input_type = new_input_tmp.type();
@@ -262,7 +260,7 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) {
if ((*it)->num_operands() > 0) {
for (size_t i = 0; i < (*it)->num_operands(); ++i) {
auto cur_in = (*it)->operand(i).source();
auto cur_in = (*it)->operand(i);
auto new_in = map_value_pair.at(cur_in);
auto new_in_type = new_in.type();
......
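The kernel-lowering hunks above all follow the same idiom: look up each operand's value from the original program in `map_value_pair` to find its counterpart in the program being built. A sketch of that idiom with the new accessor follows; the helper name and the exact map type are assumptions (the pass maintains such a mapping while cloning ops).

```cpp
#include <vector>

#include "paddle/ir/core/operation.h"  // assumed header locations
#include "paddle/ir/core/value.h"

// Collect the remapped inputs of `op`. ValueMap is expected to behave like a
// map from ir::Value (old program) to ir::OpResult (new program).
template <typename ValueMap>
std::vector<ir::OpResult> RemapInputs(ir::Operation *op,
                                      const ValueMap &map_value_pair) {
  std::vector<ir::OpResult> new_inputs;
  new_inputs.reserve(op->num_operands());
  for (size_t i = 0; i < op->num_operands(); ++i) {
    // operand(i) is now the ir::Value itself, no .source() unwrapping needed.
    new_inputs.push_back(map_value_pair.at(op->operand(i)));
  }
  return new_inputs;
}
```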
@@ -103,7 +103,7 @@ void BuildScope(ir::Block* block,
auto tensor_array = var->GetMutable<paddle::framework::TensorRefArray>();
for (size_t i = 0; i < input_num; ++i) {
auto ptr = (*it)->operand(i).source();
auto ptr = (*it)->operand(i);
PADDLE_ENFORCE_EQ(name_map->count(ptr),
true,
@@ -119,7 +119,7 @@ void BuildScope(ir::Block* block,
if (input_num > 0) {
for (size_t i = 0; i < input_num; ++i) {
auto ptr = (*it)->operand(i).source();
auto ptr = (*it)->operand(i);
std::string name;
if (name_map->find(ptr) != name_map->end()) {
name = name_map->at(ptr);
@@ -191,7 +191,7 @@ void BuildInferMetaContext(
auto& t = vec_param_list[input_index];
if (input_index_map.count(t)) {
// get information from input
ir::Value ptr = op->operand(input_index_map[t]).source();
ir::Value ptr = op->operand(input_index_map[t]);
auto in_var_name = name_map.at(ptr);
if (mutable_attr_type_map.count(t)) {
@@ -316,7 +316,7 @@ void BuildPhiKernelContext(
for (auto& t : vec_param_list) {
if (input_index_map.count(t)) {
// get information from input
ir::Value ptr = op->operand(input_index_map[t]).source();
ir::Value ptr = op->operand(input_index_map[t]);
auto in_var_name = name_map.at(ptr);
if (input_map != nullptr) {
// only deal with a single input for now, [todo] need to support multiple inputs
......
@@ -230,7 +230,7 @@ void IrPrinter::PrintOpOperands(Operation* op) {
std::vector<Value> op_operands;
op_operands.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) {
op_operands.push_back(op->operand(idx).source());
op_operands.push_back(op->operand(idx));
}
PrintInterleave(
op_operands.begin(),
@@ -245,11 +245,11 @@ void IrPrinter::PrintOperandsType(Operation* op) {
std::vector<Type> op_operand_types;
op_operand_types.reserve(num_op_operands);
for (size_t idx = 0; idx < num_op_operands; idx++) {
auto op_operand = op->operand(idx);
auto op_operand = op->op_operand(idx);
if (op_operand) {
op_operand_types.push_back(op->operand(idx).source().type());
op_operand_types.push_back(op_operand.type());
} else {
op_operand_types.push_back(Type(nullptr));
op_operand_types.push_back(Type());
}
}
os << " (";
......
@@ -130,7 +130,7 @@ void Operation::Destroy() {
// 4. Deconstruct OpOperand.
for (size_t idx = 0; idx < num_operands_; idx++) {
operand(idx).impl()->~OpOperandImpl();
op_operand(idx).impl()->~OpOperandImpl();
}
// 5. Free memory.
uint32_t max_inline_result_num =
@@ -184,13 +184,18 @@ ir::OpResult Operation::result(uint32_t index) const {
}
}
ir::OpOperand Operation::operand(uint32_t index) const {
OpOperand Operation::op_operand(uint32_t index) const {
if (index >= num_operands_) {
IR_THROW("index exceeds OP input range.");
}
const char *ptr = reinterpret_cast<const char *>(this) + sizeof(Operation) +
(index) * sizeof(detail::OpOperandImpl);
return ir::OpOperand(reinterpret_cast<const detail::OpOperandImpl *>(ptr));
return OpOperand(reinterpret_cast<const detail::OpOperandImpl *>(ptr));
}
Value Operation::operand(uint32_t index) const {
OpOperand val = op_operand(index);
return val ? val.source() : Value();
}
std::string Operation::name() const {
......
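One behavioural detail in the new implementation above: `operand(index)` checks the slot and returns a default-constructed `ir::Value` when it is unset, whereas reaching through a null `OpOperand` with `.source()` would hit the `IR_ENFORCE` in `OpOperand::impl()` shown further down. A small illustrative fragment, assuming `op` and a valid (possibly optional) input index `idx` are in scope:

```cpp
ir::OpOperand slot = op->op_operand(idx);
if (slot) {
  // The slot is set; slot.source() is the same value op->operand(idx) returns.
  ir::Value src = slot.source();
  (void)src;
} else {
  // Optional input left unset: op->operand(idx) now returns a null ir::Value
  // instead of failing inside OpOperand.
}
```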
@@ -53,7 +53,9 @@ class IR_API alignas(8) Operation final {
OpResult result(uint32_t index) const;
OpOperand operand(uint32_t index) const;
OpOperand op_operand(uint32_t index) const;
Value operand(uint32_t index) const;
/// Returns the region held by this operation at position 'index'.
Region &region(unsigned index);
......
@@ -61,7 +61,7 @@ struct OperationArgument {
void AddOutput(Type type) { output_types.emplace_back(type); }
template <class InputIt>
void AddTypes(InputIt first, InputIt last);
void AddOutputs(InputIt first, InputIt last);
/// Add an attribute with the specified name.
void AddAttribute(const std::string& name, Attribute attr) {
@@ -86,7 +86,7 @@ void OperationArgument::AddOperands(InputIt first, InputIt last) {
}
}
template <class InputIt>
void OperationArgument::AddTypes(InputIt first, InputIt last) {
void OperationArgument::AddOutputs(InputIt first, InputIt last) {
while (first != last) {
output_types.emplace_back(*first++);
}
......
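With `AddTypes(first, last)` renamed to `AddOutputs(first, last)` here (and the generated builders calling `argument.AddOutputs(...)` accordingly), a hand-written `Build` helper would look roughly like the sketch below. The function signature, header paths, and the `Float32Type` output are assumptions for illustration.

```cpp
#include <vector>

#include "paddle/ir/core/builtin_type.h"     // assumed header locations
#include "paddle/ir/core/operation_utils.h"

void BuildExample(ir::IrContext *ctx,
                  ir::OperationArgument &argument,  // NOLINT
                  const std::vector<ir::OpResult> &inputs) {
  argument.AddOperands(inputs.begin(), inputs.end());
  std::vector<ir::Type> output_types{ir::Float32Type::get(ctx)};
  // Renamed in this commit: AddTypes(first, last) -> AddOutputs(first, last).
  argument.AddOutputs(output_types.begin(), output_types.end());
}
```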
@@ -47,7 +47,7 @@ Operation *OpOperand::owner() const { return impl()->owner(); }
void OpOperand::RemoveFromUdChain() { return impl()->RemoveFromUdChain(); }
detail::OpOperandImpl *OpOperand::impl() const {
IR_ENFORCE(impl_, "Can't use impl() interface while operand is null.");
IR_ENFORCE(impl_, "Can't use impl() interface while op_operand is null.");
return impl_;
}
// Value
......
@@ -28,8 +28,8 @@ class OpResultImpl;
} // namespace detail
///
/// \brief OpOperand class represents the operand of operation. This class only
/// provides interfaces, for specific implementation, see Impl class.
/// \brief OpOperand class represents the op_operand of operation. This class
/// only provides interfaces, for specific implementation, see Impl class.
///
class IR_API OpOperand {
public:
......
@@ -35,7 +35,7 @@ class OpOperandImpl {
void set_source(Value value);
/// Remove this operand from the current use list.
/// Remove this op_operand from the current use list.
void RemoveFromUdChain();
~OpOperandImpl();
@@ -62,7 +62,7 @@ class OpOperandImpl {
/// \brief ValueImpl is the base class of all derived Value classes such as
/// OpResultImpl. This class defines all the information and usage interface in
/// the IR Value. Each Value includes three attributes:
/// (1) type: ir::Type; (2) UD-chain of value: OpOperandImpl*, first operand
/// (1) type: ir::Type; (2) UD-chain of value: OpOperandImpl*, first op_operand
/// address with offset of this value; (3) index: the position of this value in
/// the output list of the parent operator.
///
......
@@ -131,7 +131,7 @@ class GreedyPatternRewriteDriver : public ir::PatternRewriter {
void NotifyOperationRemoved(ir::Operation* op) override {
for (uint32_t i = 0; i < op->num_operands(); ++i) {
AddOperandToWorklist(op->operand(i).source());
AddOperandToWorklist(op->operand(i));
}
for (uint32_t i = 0; i < op->num_regions(); ++i) {
auto& region = op->region(i);
......
@@ -109,13 +109,9 @@ class Operation1 : public ir::Op<Operation1> {
std::unordered_map<std::string, ir::Attribute> attributes =
CreateAttributeMap({"op1_attr1", "op1_attr2"},
{"op1_attr1", "op1_attr2"});
argument.AddOperands<std::vector<ir::OpResult>::iterator>(inputs.begin(),
inputs.end());
argument.AddTypes<std::vector<ir::Type>::iterator>(output_types.begin(),
output_types.end());
argument.AddAttributes<
std::unordered_map<std::string, ir::Attribute>::iterator>(
attributes.begin(), attributes.end());
argument.AddOperands(inputs.begin(), inputs.end());
argument.AddOutputs(output_types.begin(), output_types.end());
argument.AddAttributes(attributes.begin(), attributes.end());
}
};
const char *Operation1::attributes_name[attributes_num] = {"op1_attr1",
......
@@ -174,9 +174,9 @@ TEST(program_test, program) {
// (8) Def SetParameterOp(c, "c")
auto op4 = builder.Build<ir::SetParameterOp>(op3->result(0), "c");
EXPECT_EQ(op4->operand(0).type().dialect().id(), paddle_dialect->id());
EXPECT_EQ(op4->op_operand(0).type().dialect().id(), paddle_dialect->id());
Interface *c_interface =
op4->operand(0).type().dialect().GetRegisteredInterface<Interface>();
op4->op_operand(0).type().dialect().GetRegisteredInterface<Interface>();
// ir::Parameter *parameter_c =
// c_interface->VariableToParameter(variable_c.get());
std::unique_ptr<ir::Parameter> parameter_c =
......
@@ -91,10 +91,10 @@ TEST(value_test, value_test) {
// Test 2: op1_first_output -> op4_first_input
ir::OpResult op1_first_output = op1->result(0);
ir::OpOperand op4_first_input = op4->operand(0);
ir::OpOperand op4_first_input = op4->op_operand(0);
EXPECT_EQ(op1_first_output.first_use(), op4_first_input);
ir::OpOperand op3_first_input = op3->operand(0);
ir::OpOperand op3_first_input = op3->op_operand(0);
EXPECT_EQ(op4_first_input.next_use(), op3_first_input);
EXPECT_EQ(op3_first_input.next_use(), nullptr);
@@ -110,11 +110,11 @@ TEST(value_test, value_test) {
// a = OP1(); b = OP2(); c = OP3(a, b); d, e, f, g, h, i, j = OP4(a, c);
//
c.ReplaceUsesWithIf(b, [](ir::OpOperand) { return true; });
EXPECT_EQ(op4->operand(1).source(), b);
EXPECT_EQ(op4->operand(1), b);
EXPECT_TRUE(c.use_empty());
b.ReplaceAllUsesWith(a);
EXPECT_EQ(op4->operand(1).source(), a);
EXPECT_EQ(op4->operand(1), a);
EXPECT_TRUE(b.use_empty());
// destroy
......
@@ -56,7 +56,7 @@ void BuildScope(ir::Block* block,
int input = (*it)->num_operands();
if (input > 0) {
for (int i = 0; i < input; ++i) {
auto ptr = (*it)->operand(i).source();
auto ptr = (*it)->operand(i);
std::string name;
if (name_map->find(ptr) != name_map->end()) {
name = name_map->at(ptr);
@@ -130,7 +130,7 @@ void build_context(ir::Operation* op,
for (auto& t : vec_param_list) {
if (input_index_map.count(t)) {
// get information from input
ir::Value ptr = op->operand(input_index_map[t]).source();
ir::Value ptr = op->operand(input_index_map[t]);
auto in_var_name = name_map.at(ptr);
if (mutable_attr_type_map.count(t)) {
......
@@ -247,10 +247,9 @@ TEST(pass_manager, PassManager) {
// (7) Def SetParameterOp(c, "c")
auto op4 = builder.Build<ir::SetParameterOp>(op3->result(0), "c");
EXPECT_EQ(op4->operand(0).source().type().dialect().id(),
paddle_dialect->id());
EXPECT_EQ(op4->operand(0).type().dialect().id(), paddle_dialect->id());
Interface *c_interface =
op4->operand(0).type().dialect().GetRegisteredInterface<Interface>();
op4->op_operand(0).type().dialect().GetRegisteredInterface<Interface>();
// ir::Parameter *parameter_c =
// c_interface->VariableToParameter(variable_c.get());
......
@@ -181,7 +181,7 @@ class TransposePatternRewrite
bool MatchAndRewrite(paddle::dialect::TransposeOp op,
ir::PatternRewriter &rewriter) const override {
auto prev_op = op->operand(0).source().GetDefiningOp();
auto prev_op = op->operand(0).GetDefiningOp();
std::vector<int> axis_last = GetAxis(op);
auto prev_trans_op = prev_op->dyn_cast<paddle::dialect::TransposeOp>();
if (prev_trans_op) {
@@ -191,7 +191,7 @@ class TransposePatternRewrite
auto new_perm = GetPerm(axis_first, axis_last);
rewriter.SetInsertionPoint(op);
auto new_op = rewriter.Build<paddle::dialect::TransposeOp>(
prev_op->operand(0).source().GetDefiningOp()->result(0), new_perm);
prev_op->operand(0).GetDefiningOp()->result(0), new_perm);
rewriter.ReplaceOp(op, {new_op.out()});
return true;
}
......
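The rewrite-pattern hunk above also shows the shorter def-use walk: `op->operand(0)` is already a value, so `GetDefiningOp()` can be called on it directly. A hedged fragment of the same idea; the surrounding pattern class and the null check are assumptions.

```cpp
// Inside a MatchAndRewrite-style method; `op` is the matched operation.
ir::Value in = op->operand(0);                 // formerly op->operand(0).source()
ir::Operation *producer = in.GetDefiningOp();  // assumed to be null-checkable
if (producer != nullptr) {
  if (auto prev_transpose =
          producer->dyn_cast<paddle::dialect::TransposeOp>()) {
    // `producer` is another transpose; a folding candidate, as in the pattern.
    (void)prev_transpose;
  }
}
```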