未验证 提交 1b8a1a98 编写于 作者: H hong 提交者: GitHub

[IR] Refactor op yaml info parser (#54790)

* update

* update

* polish code

* polish code
上级 8e7c8117
......@@ -22,6 +22,7 @@
#include "paddle/fluid/ir/dialect/kernel_attribute.h"
#include "paddle/fluid/ir/dialect/kernel_type.h"
#include "paddle/fluid/ir/dialect/kernel_type_storage.h"
#include "paddle/fluid/ir/dialect/op_yaml_info_util.h"
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/fluid/ir/dialect/utils.h"
#include "paddle/fluid/platform/init_phi.h"
......
......@@ -36,6 +36,7 @@ H_FILE_TEMPLATE = """#ifdef GET_OP_LIST
#include "paddle/ir/core/operation_utils.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/fluid/ir/dialect/utils.h"
#include "paddle/fluid/ir/dialect/op_yaml_info_util.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/fluid/ir/interface/infershape.h"
#include "paddle/fluid/framework/infershape_utils.h"
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/ir/dialect/pd_type_storage.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h"
namespace paddle {
namespace dialect {
struct OpInputInfo {
std::string name;
std::string type_name;
bool optional = false;
bool no_need_buffer = false;
bool is_mutable_attribute = false;
OpInputInfo() = default;
OpInputInfo(const OpInputInfo& input_info) = default;
OpInputInfo(const std::string& name,
const std::string& type_name,
bool optional,
bool no_need_buffer,
bool is_mutable_attribute)
: name(name),
type_name(type_name),
optional(optional),
no_need_buffer(no_need_buffer),
is_mutable_attribute(is_mutable_attribute) {}
};
struct OpOutputInfo {
std::string name;
std::string type_name;
bool optional = false;
bool intermediate = false;
OpOutputInfo() = default;
OpOutputInfo(const OpOutputInfo& output_info) = default;
OpOutputInfo(const std::string& name,
const std::string& type_name,
bool optional,
bool intermediate)
: name(name),
type_name(type_name),
optional(optional),
intermediate(intermediate) {}
};
struct OpAttributeInfo {
std::string name;
std::string type_name;
std::string data_type;
OpAttributeInfo() = default;
OpAttributeInfo(const OpAttributeInfo& attr_info) = default;
OpAttributeInfo(const std::string& name,
const std::string& type_name,
const std::string& data_type)
: name(name), type_name(type_name), data_type(data_type) {}
};
struct OpRunTimeInfo {
std::string infer_meta_func;
std::vector<std::string> infer_meta_param;
std::vector<std::string> kernel_func;
std::vector<std::string> kernel_param;
std::vector<std::string> kernel_key_dtype;
std::vector<std::pair<std::string, std::string>> inplace;
std::vector<std::pair<std::string, std::string>> view;
OpRunTimeInfo(const std::string& infer_meta_func,
const std::vector<std::string>& infer_meta_param,
const std::vector<std::string>& kernel_func,
const std::vector<std::string>& kernel_param,
const std::vector<std::string>& dtype,
const std::vector<std::pair<std::string, std::string>>& inplace,
const std::vector<std::pair<std::string, std::string>>& view)
: infer_meta_func(infer_meta_func),
infer_meta_param(infer_meta_param),
kernel_func(kernel_func),
kernel_param(kernel_param),
kernel_key_dtype(dtype),
inplace(inplace),
view(view) {}
};
} // namespace dialect
} // namespace paddle
......@@ -20,7 +20,6 @@
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace paddle {
namespace dialect {
......@@ -96,72 +95,5 @@ static inline ir::Attribute TransToIrAttribute(phi::Scalar scalar,
}
}
struct OpInputInfo {
std::string name;
std::string type_name;
bool optional = false;
bool no_need_buffer = false;
bool is_mutable_attribute = false;
OpInputInfo(std::string name,
std::string type_name,
bool optional,
bool no_need_buffer,
bool is_mutable_attribute)
: name(name),
type_name(type_name),
optional(optional),
no_need_buffer(no_need_buffer),
is_mutable_attribute(is_mutable_attribute) {}
};
struct OpOutputInfo {
std::string name;
std::string type_name;
bool optional = false;
bool intermediate = false;
OpOutputInfo(std::string name,
std::string type_name,
bool optional,
bool intermediate)
: name(name),
type_name(type_name),
optional(optional),
intermediate(intermediate) {}
};
struct OpAttributeInfo {
std::string name;
std::string type_name;
std::string data_type;
OpAttributeInfo(std::string name,
std::string type_name,
std::string data_type)
: name(name), type_name(type_name), data_type(data_type) {}
};
struct OpRunTimeInfo {
std::string infer_meta_func;
std::vector<std::string> infer_meta_param;
std::vector<std::string> kernel_func;
std::vector<std::string> kernel_param;
std::vector<std::string> kernel_key_dtype;
std::vector<std::pair<std::string, std::string>> inplace;
std::vector<std::pair<std::string, std::string>> view;
OpRunTimeInfo(std::string infer_meta_func,
std::vector<std::string> infer_meta_param,
std::vector<std::string> kernel_func,
std::vector<std::string> kernel_param,
std::vector<std::string> dtype,
std::vector<std::pair<std::string, std::string>> inplace,
std::vector<std::pair<std::string, std::string>> view)
: infer_meta_func(infer_meta_func),
infer_meta_param(infer_meta_param),
kernel_func(kernel_func),
kernel_param(kernel_param),
kernel_key_dtype(dtype),
inplace(inplace),
view(view) {}
};
} // namespace dialect
} // namespace paddle
......@@ -14,7 +14,7 @@
#pragma once
#include "paddle/fluid/ir/dialect/utils.h"
#include "paddle/fluid/ir/dialect/op_yaml_info_util.h"
#include "paddle/ir/core/op_base.h"
using OpInfoTuple = std::tuple<std::vector<paddle::dialect::OpInputInfo>,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/ir/interface/op_yaml_info_parser.h"
namespace paddle {
namespace dialect {
OpYamlInfoParser::OpYamlInfoParser(const OpInfoTuple& op_info_tuple)
: op_info_tuple_(op_info_tuple) {
parse();
}
bool OpYamlInfoParser::IsTensorArrtibute(size_t index) const {
PADDLE_ENFORCE_LT(
index,
InputInfo().size(),
phi::errors::OutOfRange("Input index [%d] large than op input size [d]",
index,
InputInfo().size()));
return InputInfo()[index].is_mutable_attribute;
}
size_t OpYamlInfoParser::InputTensorNumber() const {
return input_tensor_number_;
}
const std::string& OpYamlInfoParser::AttrTypeName(
const std::string& name) const {
auto it = map_attr_info_.find(name);
PADDLE_ENFORCE_NE(
it,
map_attr_info_.end(),
phi::errors::NotFound("Not found [%s] in attribute map", name));
return it->second.type_name;
}
const std::vector<std::string>& OpYamlInfoParser::InferMetaTensorParams()
const {
return vec_infer_meta_tensor_params_;
}
const std::vector<std::string>& OpYamlInfoParser::InferMetaAttrParams() const {
return vec_infer_meta_attr_params_;
}
const std::vector<std::string>& OpYamlInfoParser::KernelFnTensorParams() const {
return vec_kernel_fn_tensor_params_;
}
const std::vector<std::string>& OpYamlInfoParser::KernelFnAttrParams() const {
return vec_kernel_fn_attr_params_;
}
void OpYamlInfoParser::parse() {
auto input_info = std::get<0>(op_info_tuple_);
int start_index = 0;
for (size_t i = 0; i < input_info.size(); ++i) {
map_name2id_[input_info[i].name] = start_index++;
if (!input_info[i].is_mutable_attribute) {
input_tensor_number_++;
}
map_input_info_[input_info[i].name] = input_info[i];
}
auto attribute_info = std::get<1>(op_info_tuple_);
for (size_t i = 0; i < attribute_info.size(); ++i) {
map_attr_info_[attribute_info[i].name] = attribute_info[i];
}
auto output_info = std::get<2>(op_info_tuple_);
for (size_t i = 0; i < output_info.size(); ++i) {
map_output_info_[output_info[i].name] = output_info[i];
}
auto runtime_info = std::get<3>(op_info_tuple_);
for (auto& name : runtime_info.infer_meta_param) {
if (map_name2id_.count(name)) {
vec_infer_meta_tensor_params_.push_back(name);
} else {
vec_infer_meta_attr_params_.push_back(name);
}
}
for (auto& name : runtime_info.kernel_param) {
if (map_name2id_.count(name)) {
vec_kernel_fn_tensor_params_.push_back(name);
} else {
vec_kernel_fn_attr_params_.push_back(name);
}
}
}
} // namespace dialect
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/ir/interface/op_yaml_info.h"
namespace paddle {
namespace dialect {
class OpYamlInfoParser {
public:
OpYamlInfoParser() = delete;
explicit OpYamlInfoParser(const OpInfoTuple& op_info_tuple);
bool IsTensorArrtibute(size_t index) const;
size_t InputTensorNumber() const;
const std::string& AttrTypeName(const std::string& name) const;
const std::vector<std::string>& InferMetaTensorParams() const;
const std::vector<std::string>& InferMetaAttrParams() const;
const std::vector<std::string>& KernelFnTensorParams() const;
const std::vector<std::string>& KernelFnAttrParams() const;
private:
void parse();
inline const std::vector<OpInputInfo>& InputInfo() const {
return std::get<0>(op_info_tuple_);
}
const OpInfoTuple& op_info_tuple_;
std::map<std::string, int> map_name2id_;
std::map<std::string, OpInputInfo> map_input_info_;
std::map<std::string, OpAttributeInfo> map_attr_info_;
std::map<std::string, OpOutputInfo> map_output_info_;
std::vector<std::string> vec_infer_meta_tensor_params_;
std::vector<std::string> vec_infer_meta_attr_params_;
std::vector<std::string> vec_kernel_fn_tensor_params_;
std::vector<std::string> vec_kernel_fn_attr_params_;
int input_tensor_number_{0};
};
} // namespace dialect
} // namespace paddle
......@@ -20,6 +20,7 @@
#include "paddle/fluid/ir/dialect/kernel_dialect.h"
#include "paddle/fluid/ir/dialect/kernel_op.h"
#include "paddle/fluid/ir/dialect/kernel_type.h"
#include "paddle/fluid/ir/dialect/op_yaml_info_util.h"
#include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/fluid/ir/dialect/utils.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h"
......
......@@ -14,6 +14,7 @@
#pragma once
#include "paddle/fluid/ir/dialect/op_yaml_info_util.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
......
......@@ -14,6 +14,7 @@
#include "paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h"
#include "paddle/fluid/ir/dialect/op_yaml_info_util.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/dialect/utils.h"
......
......@@ -14,6 +14,7 @@
#pragma once
#include "paddle/fluid/ir/dialect/op_yaml_info_util.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/dialect/utils.h"
......
......@@ -74,3 +74,12 @@ cc_test_old(
ir)
cc_test_old(ir_op_info_test SRCS op_info_test.cc DEPS gtest ir)
cc_test_old(
ir_op_yaml_info_parser_test
SRCS
op_yaml_info_parser_test.cc
DEPS
gtest
pd_dialect
pd_interface
ir)
......@@ -14,6 +14,7 @@
#include <gtest/gtest.h>
#include "paddle/fluid/ir/dialect/op_yaml_info_util.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/dialect/utils.h"
......
......@@ -18,6 +18,7 @@
#include "paddle/fluid/ir/dialect/kernel_dialect.h"
#include "paddle/fluid/ir/dialect/kernel_op.h"
#include "paddle/fluid/ir/dialect/kernel_type.h"
#include "paddle/fluid/ir/dialect/op_yaml_info_util.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/utils.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h"
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/ir/dialect/pd_attribute.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_type.h"
#include "paddle/fluid/ir/interface/op_yaml_info_parser.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/program.h"
#include "paddle/fluid/ir/interface/op_yaml_info.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_dialect.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/utils.h"
#include "paddle/fluid/ir/dialect/pd_op.h"
TEST(ir_op_info_test, op_op_info_test) {
ir::IrContext* ctx = ir::IrContext::Instance();
ir::Program program(ctx);
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
ir::Builder builder(ctx, program.block());
auto uniform1 =
builder.Build<paddle::dialect::UniformOp>(std::vector<int64_t>{2, 2},
phi::DataType::FLOAT32,
0.0,
1.0,
2,
phi::CPUPlace());
uniform1->num_operands();
paddle::dialect::OpYamlInfoInterface op_info_interface =
uniform1->dyn_cast<paddle::dialect::OpYamlInfoInterface>();
auto op_info_res = op_info_interface.GetOpInfo();
paddle::dialect::OpYamlInfoParser op_yaml_info_parser(op_info_res);
auto infer_meta_tensor_param = op_yaml_info_parser.InferMetaTensorParams();
auto infer_meta_attr_param = op_yaml_info_parser.InferMetaAttrParams();
auto kernel_fn_tensor_param = op_yaml_info_parser.KernelFnTensorParams();
auto kernel_fn_attr_param = op_yaml_info_parser.KernelFnAttrParams();
EXPECT_EQ(infer_meta_tensor_param.size(), 1u);
EXPECT_EQ(infer_meta_attr_param.size(), 1u);
EXPECT_EQ(kernel_fn_tensor_param.size(), 3u);
EXPECT_EQ(kernel_fn_attr_param.size(), 2u);
EXPECT_EQ((op_yaml_info_parser.AttrTypeName("seed") == "ir::Int32Attribute"),
true);
EXPECT_EQ(op_yaml_info_parser.IsTensorArrtibute(0), true);
EXPECT_EQ(op_yaml_info_parser.InputTensorNumber(), 0u);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册