Unverified commit 14425c06, authored by zhangbo9674, committed by GitHub

[IR] Refine OP auto code gen (#54186)

* refine auto gen

* refine code

* refine code

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug

* fix bug
Parent 4bcb5cc4
......@@ -8,16 +8,17 @@ set(op_forward_yaml_file1
${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/ops.parsed.yaml
)
set(op_forward_yaml_file2
${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/static_ops.parsed.yaml
${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_ops.parsed.yaml
)
set(op_backward_yaml_file1
${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/backward_ops.parsed.yaml
)
set(op_backward_yaml_file2
${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/static_backward.parsed.yaml
${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_backward_ops.parsed.yaml
)
set(op_yaml_file3 ${PADDLE_SOURCE_DIR}/paddle/fluid/dialect/pd_op.yaml)
set(op_yaml_files
${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2}
${op_forward_yaml_file1},${op_forward_yaml_file2},${op_backward_yaml_file1},${op_backward_yaml_file2},${op_yaml_file3}
)
set(op_namespace paddle,dialect)
set(dialect_name pd)
......
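For context: CMake joins the YAML paths into one comma-separated string, which the generator splits back into a path list on the Python side. A minimal sketch of that handoff, assuming an argparse flag named --op_yaml_files (the real flag name may differ):

# Sketch: consuming the comma-joined ${op_yaml_files} value in op_gen.py.
# The flag name here is an assumption for illustration.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--op_yaml_files', type=str)
args = parser.parse_args()  # e.g. --op_yaml_files=a.yaml,b.yaml,pd_op.yaml
op_yaml_files = args.op_yaml_files.split(",")  # back to a list of paths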
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/core/op_base.h"
namespace paddle {
namespace dialect {
#define OPNAME(op_name) "pd." #op_name
#define REIGSTER_EMPTY_OP(op_name, className) \
class className : public ir::Op<className> { \
public: \
static const char *name() { return OPNAME(op_name); } \
static constexpr const char **attributes_name = nullptr; \
static constexpr uint32_t attributes_num = 0; \
static void verify(const std::vector<ir::OpResult> &inputs, \
const std::vector<ir::Type> &outputs, \
const ir::AttributeMap &attributes) { \
LOG(WARNING) << "This is a fake verify"; \
} \
};
// TODO(zhangbo): These placeholder ops are removed gradually as the real
// operator definitions are supplemented.
REIGSTER_EMPTY_OP(conv2d, Conv2DOp); // To be customized: conv2d
REIGSTER_EMPTY_OP(feed, FeedOp); // To be customized: feed
REIGSTER_EMPTY_OP(batch_norm, BatchNormOp); // To be customized: batch_norm
REIGSTER_EMPTY_OP(batch_norm_, BatchNormOp_); // To be customized: batch_norm_
REIGSTER_EMPTY_OP(elementwise_add,
ElementwiseAddOp); // To be customized: add (elementwise_add)
REIGSTER_EMPTY_OP(pool2d, Pool2DOp); // To be customized: pool2d
REIGSTER_EMPTY_OP(
flatten_contiguous_range,
FlattenContiguousRangeOp); // flatten (flatten_contiguous_range)
REIGSTER_EMPTY_OP(matmul_v2,
MatmulV2Op); // To be customized: matmul (matmul_v2)
REIGSTER_EMPTY_OP(reshape2, Reshape2Op); // To be customized: reshape
REIGSTER_EMPTY_OP(softmax_with_cross_entropy,
SoftmaxWithCrossEntropyOp); // cross_entropy_with_softmax
// (softmax_with_cross_entropy)
REIGSTER_EMPTY_OP(reduce_mean,
ReduceMeanOp); // To be customized: mean (reduce_mean)
REIGSTER_EMPTY_OP(top_k_v2, TopKV2Op); // topk (top_k_v2)
REIGSTER_EMPTY_OP(fill_constant,
FillConstantOp); // To be customized: full (fill_constant)
REIGSTER_EMPTY_OP(reduce_mean_grad,
ReduceMeanGradOp); // To be customized: reduce_mean_grad
REIGSTER_EMPTY_OP(
softmax_with_cross_entropy_grad,
SoftmaxWithCrossEntropyGradOp); // cross_entropy_with_softmax_grad
// (softmax_with_cross_entropy_grad)
REIGSTER_EMPTY_OP(
elementwise_add_grad,
ElementwiseAddGradOp); // To be customized: add_grad (elementwise_add_grad)
REIGSTER_EMPTY_OP(
matmul_v2_grad,
MatmulV2GradOp); // To be customized: matmul_grad (matmul_v2_grad)
REIGSTER_EMPTY_OP(
flatten_contiguous_range_grad,
FlattenContiguousRangeGradOp); // flatten_grad
// (flatten_contiguous_range_grad)
REIGSTER_EMPTY_OP(pool2d_grad, Pool2DGradOp); // To be customized: pool2d_grad
REIGSTER_EMPTY_OP(batch_norm_grad,
BatchNormGradOp); // To be customized: batch_norm_grad
REIGSTER_EMPTY_OP(conv2d_grad, Conv2DGradOp); // To be customized: conv2d_grad
REIGSTER_EMPTY_OP(sum, SumOp); // To be customized: sum(reduce_sum)
REIGSTER_EMPTY_OP(fetch_v2, FetchV2Op); // To be customized: fetch_v2
REIGSTER_EMPTY_OP(add, AddOp);
REIGSTER_EMPTY_OP(add_grad, AddGradOp);
REIGSTER_EMPTY_OP(matmul, MatMulOp);
REIGSTER_EMPTY_OP(matmul_grad, MatMulGradOp);
REIGSTER_EMPTY_OP(reshape, ReshapeOp);
REIGSTER_EMPTY_OP(reshape_grad, ReshapeGradOp);
REIGSTER_EMPTY_OP(mean, MeanOp);
REIGSTER_EMPTY_OP(cross_entropy_with_softmax, CrossEntropyOp);
REIGSTER_EMPTY_OP(cross_entropy_with_softmax_grad, CrossEntropyGradOp);
REIGSTER_EMPTY_OP(topk, TopKOp);
REIGSTER_EMPTY_OP(topk_grad, TopKGradOp);
REIGSTER_EMPTY_OP(full, FullOp);
REIGSTER_EMPTY_OP(add_n, AddNOp);
} // namespace dialect
} // namespace paddle
......@@ -29,7 +29,11 @@ H_FILE_TEMPLATE = """#ifdef GET_OP_LIST
{op_declare}
#else
#include <vector>
#include "paddle/ir/core/op_base.h"
#include "paddle/fluid/dialect/utils.h"
#include "paddle/fluid/dialect/pd_interface.h"
{input}
#endif
......@@ -45,6 +49,7 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{
static const char *name() {{ return "{dialect_op_name}"; }}
{attribute_declare}
static constexpr uint32_t attributes_num = {attribute_num};
static OpInfoTuple GetOpInfo();
static void verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes);
{get_inputs_and_outputs}
}};
......@@ -79,6 +84,46 @@ OP_N_ATTRIBUTE_DEFINED_TEMPLATE = """
const char *{op_name}::attributes_name[{attribute_num}] = {{ {attribute_names} }};
"""
# get op input info
OP_INFO_TEMPLATE = """
OpInfoTuple {op_name}::GetOpInfo() {{
std::vector<paddle::dialect::OpInputInfo> inputs = {{ {inputs} }};
std::vector<paddle::dialect::OpAttributeInfo> attributes = {{ {attributes} }};
std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }};
return std::make_tuple(inputs, attributes, outputs);
}}
"""
OP_INPUT_INFO_TEMPLATE = """
std::vector<paddle::dialect::OpInputInfo> {op_name}::inputs_info() {{
return {{ {impl} }};
}}
"""
CONSTRUCT_INPUT_INFO_TEMPLATE = (
"""OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer})"""
)
# get op output info
OP_OUTPUT_INFO_TEMPLATE = """
std::vector<paddle::dialect::OpOutputInfo> {op_name}::outputs_info() {{
return {{ {impl} }};
}}
"""
CONSTRUCT_OUTPUT_INFO_TEMPLATE = (
"""OpOutputInfo("{name}", "{typename}", {optional}, {intermediate})"""
)
# get op attribute info
OP_ATTRIBUTE_INFO_TEMPLATE = """
std::vector<paddle::dialect::OpAttributeInfo> {op_name}::attributes_info() {{
return {{ {impl} }};
}}
"""
CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE = (
"""OpAttributeInfo("{name}", "{typename}", "{data_type}")"""
)
# verify
OP_VERIFY_TEMPLATE = """
void {op_name}::verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes) {{
VLOG(4) << "Verifying inputs, outputs and attributes for: {op_name}.";
......@@ -158,10 +203,14 @@ OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """if (outputs[{index}]) {{
}}
"""
ATTRIBUTE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").isa<{standard}>(), true,
ATTRIBUTE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0, true,
phi::errors::PreconditionNotMet("The AttributeMap miss mandatory attributes of: {attribute_name}."));
PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
"""
ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").isa<ir::ArrayAttribute>(), true,
ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0, true,
phi::errors::PreconditionNotMet("The AttributeMap miss mandatory attributes of: {attribute_name}."));
PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").isa<ir::ArrayAttribute>(), true,
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().size(); i++) {{
PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>()[i].isa<{standard}>(), true,
......@@ -170,32 +219,65 @@ ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.at("{attribute
"""
def to_phi_and_fluid_op_name(op_item):
# Template: - op : phi_name (fluid_name)
names = op_item.split('(')
if len(names) == 1:
phi_fluid_name = names[0].strip()
return phi_fluid_name, phi_fluid_name
else:
phi_name = names[0].strip()
fluid_name = names[1].split(')')[0].strip()
return phi_name, fluid_name
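Usage of the helper above, covering both spellings the compat file allows:

# Sketch: the two shapes the compat 'op' field can take.
assert to_phi_and_fluid_op_name("mean (reduce_mean)") == ("mean", "reduce_mean")
assert to_phi_and_fluid_op_name("relu") == ("relu", "relu")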
# =====================================
# Parse Op Compat From Yaml
# =====================================
# Parse Op information from Yaml item
class OpCompatParser:
def __init__(self, ops_compat_yaml_file):
self.ops_compat_yaml_file = ops_compat_yaml_file
with open(self.ops_compat_yaml_file, "r") as f:
self.ops_compat = yaml.safe_load(f)
def get_compat(self, op_name):
for compat in self.ops_compat:
phi_name, fluid_name = to_phi_and_fluid_op_name(compat['op'])
if op_name == phi_name:
return compat
return None
# =====================================
# Parse Op Information From Yaml
# =====================================
class OpInfoParser:
def __init__(self, op_yaml_item):
def __init__(self, op_yaml_item, op_compat_item):
self.op_yaml_item = op_yaml_item
self.op_compat_item = op_compat_item
self.op_phi_name = self.parse_op_phi_name()
# parse inputs
self.input_name_list = self.parse_input_name_list()
self.input_type_list = self.parse_input_type_list()
self.input_optional_list = self.parse_input_optional_list()
self.input_no_need_buffer_list = self.parse_input_no_need_buffer_list()
self.cross_check(
self.input_name_list, self.input_type_list, self.input_optional_list
)
# parse outputs
self.output_name_list = self.parse_output_name_list()
self.output_type_list = self.parse_output_type_list()
self.output_optional_list = self.parse_output_optional_list()
self.output_intermediate_list = self.parse_output_intermediate_list()
self.cross_check(
self.output_name_list,
self.output_type_list,
self.output_optional_list,
)
# parse attributes
self.attribute_name_list = self.parse_attribute_name_list()
self.attribute_type_list = self.parse_attribute_type_list()
self.attribute_data_type_list = self.parse_attribute_data_type_list()
self.cross_check(self.attribute_name_list, self.attribute_type_list)
def cross_check(self, name_list, type_list, optional_list=None):
......@@ -229,9 +311,21 @@ class OpInfoParser:
def parse_input_optional_list(self):
optional_list = []
for input_info in self.op_yaml_item['inputs']:
optional_list.append(input_info['optional'])
if input_info['optional']:
optional_list.append("true")
else:
optional_list.append("false")
return optional_list
def parse_input_no_need_buffer_list(self):
no_need_buffer_list = []
for input_info in self.op_yaml_item['inputs']:
if input_info['no_need_buffer']:
no_need_buffer_list.append("true")
else:
no_need_buffer_list.append("false")
return no_need_buffer_list
def parse_output_name_list(self):
name_list = []
for output_info in self.op_yaml_item['outputs']:
......@@ -255,11 +349,26 @@ class OpInfoParser:
optional_list = []
for output_info in self.op_yaml_item['outputs']:
if 'optional' in output_info:
optional_list.append(output_info['optional'])
if output_info['optional']:
optional_list.append("true")
else:
optional_list.append(False)
optional_list.append("false")
else:
optional_list.append("false")
return optional_list
def parse_output_intermediate_list(self):
intermediate_list = []
for output_info in self.op_yaml_item['outputs']:
if 'intermediate' in output_info:
if output_info['intermediate']:
intermediate_list.append("true")
else:
intermediate_list.append("false")
else:
intermediate_list.append("false")
return intermediate_list
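Note that the parsers above deliberately return the strings "true"/"false" rather than Python booleans: the values are spliced verbatim into generated C++ argument lists, where Python's True/False would not compile. A minimal illustration:

# Sketch: why booleans become the strings "true"/"false" during parsing.
tmpl = 'OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer})'
bad = tmpl.format(name="x", typename="Tensor", optional=True, no_need_buffer=False)
# -> OpInputInfo("x", "Tensor", True, False)   # not valid C++
ok = tmpl.format(name="x", typename="Tensor", optional="true", no_need_buffer="false")
# -> OpInputInfo("x", "Tensor", true, false)   # valid C++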
def parse_attribute_name_list(self):
name_list = []
for attribute_info in self.op_yaml_item['attrs']:
......@@ -301,8 +410,31 @@ class OpInfoParser:
type_list.append(attr_types_map[attribute_info['typename']])
return type_list
def parse_attribute_data_type_list(self):
data_type_list = []
for attribute_info in self.op_yaml_item['attrs']:
if 'data_type' in attribute_info:
data_type_list.append(attribute_info['data_type'])
else:
data_type_list.append("")
return data_type_list
def parse_op_phi_name(self):
return self.op_yaml_item['name']
if self.parse_op_inplace_info() is None:
return [self.op_yaml_item['name']]
else:
if self.op_yaml_item['name'][-1] == "_":
return [self.op_yaml_item['name']]
else:
return [
self.op_yaml_item['name'],
self.op_yaml_item['name'] + "_",
]
def parse_op_inplace_info(self):
if 'inplace' in self.op_yaml_item:
return self.op_yaml_item['inplace']
return None
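With this change, an op that declares an inplace variant expands into two generated names. The expansion logic of parse_op_phi_name, restated as a standalone sketch:

# Sketch: mirror of parse_op_phi_name above, for illustration.
def expand_phi_names(op_yaml_item):
    if 'inplace' not in op_yaml_item:       # no inplace variant declared
        return [op_yaml_item['name']]
    if op_yaml_item['name'].endswith("_"):  # already the inplace spelling
        return [op_yaml_item['name']]
    return [op_yaml_item['name'], op_yaml_item['name'] + "_"]

assert expand_phi_names({'name': 'mean'}) == ['mean']
assert expand_phi_names({'name': 'relu', 'inplace': {'out': 'x'}}) == ['relu', 'relu_']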
def to_pascal_case(s):
......@@ -314,10 +446,11 @@ def to_pascal_case(s):
# =====================================
# Generate op definition files
# Generate Op Definition Files
# =====================================
def OpGenerator(
op_yaml_files,
op_compat_yaml_file,
namespaces,
dialect_name,
op_def_h_file,
......@@ -330,6 +463,8 @@ def OpGenerator(
os.remove(op_def_cc_file)
# (2) Prepare: Get all op items in all op_yaml_files
op_compat_parser = OpCompatParser(op_compat_yaml_file)
op_yaml_items = []
for yaml_file in op_yaml_files:
with open(yaml_file, "r") as f:
......@@ -337,7 +472,9 @@ def OpGenerator(
op_yaml_items = op_yaml_items + ops
op_info_items = []
for op in op_yaml_items:
op_info_items.append(OpInfoParser(op))
op_info_items.append(
OpInfoParser(op, op_compat_parser.get_compat(op['name']))
)
# (3) CodeGen: Traverse op_info_items and generate
ops_name_list = [] # all op class name store in this list
......@@ -345,36 +482,43 @@ def OpGenerator(
ops_defined_list = [] # all op class defined store in this list
for op_info in op_info_items:
# get op info
op_name = op_info.op_phi_name
op_class_name = to_pascal_case(op_name) + "Op"
op_dialect_name = dialect_name + "." + op_name
op_input_name_list = op_info.input_name_list
op_input_type_list = op_info.input_type_list
op_input_optional_list = op_info.input_optional_list
op_input_no_need_buffer_list = op_info.input_no_need_buffer_list
op_output_name_list = op_info.output_name_list
op_output_type_list = op_info.output_type_list
op_output_optional_list = op_info.output_optional_list
op_output_intermediate_list = op_info.output_intermediate_list
op_attribute_name_list = op_info.attribute_name_list
op_attribute_type_list = op_info.attribute_type_list
op_interfaces = []
op_attribute_data_type_list = op_info.attribute_data_type_list
op_interfaces = ["GetOpInfoInterface"]
op_traits = []
# If an op has inplace info, generate both the inplace and non-inplace ops.
for op_name in op_info.op_phi_name:
op_class_name = to_pascal_case(op_name) + "Op"
op_dialect_name = dialect_name + "." + op_name
# gen interface/trait str
op_interfaces_str = ""
if len(op_interfaces) > 0:
op_interfaces_str = "," + ",".join(op_interfaces)
op_traits_str = ""
if len(op_interfaces) > 0:
if len(op_traits) > 0:
op_traits_str = "," + ",".join(op_traits)
op_get_inputs_outputs_str = ""
for idx in range(len(op_input_name_list)):
op_get_inputs_outputs_str += OP_GET_INPUT_TEMPLATE.format(
input_name=op_input_name_list[idx], input_index=idx
input_name=op_input_name_list[idx],
input_index=idx,
)
for idx in range(len(op_output_name_list)):
op_get_inputs_outputs_str += OP_GET_OUTPUT_TEMPLATE.format(
output_name=op_output_name_list[idx], output_index=idx
output_name=op_output_name_list[idx],
output_index=idx,
)
# gen op_declare_str/op_defined_str
......@@ -410,6 +554,57 @@ def OpGenerator(
attribute_names=attribute_names_str,
)
# generate get op info function: inputs
inputs_info_str = ""
if len(op_input_name_list) > 0:
input_info_list = []
for idx in range(len(op_input_name_list)):
input_info_list.append(
CONSTRUCT_INPUT_INFO_TEMPLATE.format(
name=op_input_name_list[idx],
typename=op_input_type_list[idx],
optional=op_input_optional_list[idx],
no_need_buffer=op_input_no_need_buffer_list[idx],
)
)
inputs_info_str = ", ".join(input_info_list)
# generate get op info function: outputs
outputs_info_str = ""
if len(op_output_name_list) > 0:
output_info_list = []
for idx in range(len(op_output_name_list)):
output_info_list.append(
CONSTRUCT_OUTPUT_INFO_TEMPLATE.format(
name=op_output_name_list[idx],
typename=op_output_type_list[idx],
optional=op_output_optional_list[idx],
intermediate=op_output_intermediate_list[idx],
)
)
outputs_info_str = ", ".join(output_info_list)
# generate get op info function: attributes
attribute_info_str = ""
if len(op_attribute_name_list) > 0:
attribute_info_list = []
for idx in range(len(op_attribute_name_list)):
attribute_info_list.append(
CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE.format(
name=op_attribute_name_list[idx],
typename=op_attribute_type_list[idx],
data_type=op_attribute_data_type_list[idx],
)
)
attribute_info_str = ", ".join(attribute_info_list)
op_info_func_str = OP_INFO_TEMPLATE.format(
op_name=op_class_name,
inputs=inputs_info_str,
attributes=attribute_info_str,
outputs=outputs_info_str,
)
# generate op verify function: inputs_type_check_str
if len(op_input_type_list) == 0:
inputs_type_check_str = (
......@@ -425,11 +620,13 @@ def OpGenerator(
is_vector = True
input_type = input_type[15:-1]
check_str = ""
if is_optional:
if is_optional == "true":
if is_vector:
check_str = INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format(
check_str = (
INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format(
index=idx, standard=input_type
)
)
else:
check_str = INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE.format(
index=idx, standard=input_type
......@@ -460,7 +657,7 @@ def OpGenerator(
is_vector = True
output_type = output_type[15:-1]
check_str = ""
if is_optional:
if is_optional == "true":
if is_vector:
check_str = (
OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE.format(
......@@ -494,8 +691,11 @@ def OpGenerator(
attribute_type = op_attribute_type_list[idx]
if attribute_type.startswith("ir::ArrayAttribute<"):
attribute_type = attribute_type[19:-1]
attributes_check_str += ATTRIBUTE_VECTOR_CHECK_TEMPLATE.format(
attribute_name=attribute_name, standard=attribute_type
attributes_check_str += (
ATTRIBUTE_VECTOR_CHECK_TEMPLATE.format(
attribute_name=attribute_name,
standard=attribute_type,
)
)
else:
attributes_check_str += ATTRIBUTE_CHECK_TEMPLATE.format(
......@@ -515,6 +715,7 @@ def OpGenerator(
ops_name_list.append(op_class_name)
ops_declare_list.append(op_declare_str)
ops_defined_list.append(op_defined_str)
ops_defined_list.append(op_info_func_str)
ops_defined_list.append(op_verify_str)
# (4) Generate head file str
......@@ -588,6 +789,7 @@ if __name__ == "__main__":
# auto code generate
OpGenerator(
op_yaml_files,
op_compat_yaml_file,
namespaces,
dialect_name,
op_def_h_file,
......
......@@ -16,7 +16,6 @@
#include "paddle/fluid/dialect/pd_attribute.h"
// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/dialect/CMakeLists.txt.
#include "paddle/fluid/dialect/legacy_pd_op.h"
#include "paddle/fluid/dialect/pd_op.h"
#include "paddle/fluid/dialect/pd_type.h"
#include "paddle/fluid/dialect/pd_type_storage.h"
......@@ -111,42 +110,6 @@ void PaddleDialect::initialize() {
>();
RegisterInterfaces<ParameterConvertInterface>();
RegisterOps<Conv2DOp,
FeedOp,
BatchNormOp,
BatchNormOp_,
ElementwiseAddOp,
Pool2DOp,
FlattenContiguousRangeOp,
MatmulV2Op,
Reshape2Op,
SoftmaxWithCrossEntropyOp,
ReduceMeanOp,
TopKV2Op,
FillConstantOp,
ReduceMeanGradOp,
SoftmaxWithCrossEntropyGradOp,
ElementwiseAddGradOp,
MatmulV2GradOp,
FlattenContiguousRangeGradOp,
Pool2DGradOp,
BatchNormGradOp,
Conv2DGradOp,
SumOp,
FetchV2Op,
AddOp,
MatMulOp,
ReshapeOp,
CrossEntropyOp,
TopKOp,
FullOp,
MeanOp,
AddNOp,
AddGradOp,
MatMulGradOp,
ReshapeGradOp,
CrossEntropyGradOp,
TopKGradOp>();
}
void PaddleDialect::PrintType(ir::Type type, std::ostream &os) {
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/dialect/utils.h"
#include "paddle/ir/core/op_base.h"
using OpInfoTuple = std::tuple<std::vector<paddle::dialect::OpInputInfo>,
std::vector<paddle::dialect::OpAttributeInfo>,
std::vector<paddle::dialect::OpOutputInfo>>;
namespace paddle {
namespace dialect {
class GetOpInfoInterface : public ir::OpInterfaceBase<GetOpInfoInterface> {
public:
struct Concept {
explicit Concept(OpInfoTuple (*get_op_info)(ir::Operation *))
: get_op_info_(get_op_info) {}
OpInfoTuple (*get_op_info_)(ir::Operation *);
};
template <class ConcreteOp>
struct Model : public Concept {
static OpInfoTuple GetOpInfo(ir::Operation *op) {
ConcreteOp concret_op = op->dyn_cast<ConcreteOp>();
if (concret_op == nullptr) throw("concret_op is nullptr");
return concret_op.GetOpInfo();
}
Model() : Concept(GetOpInfo) {}
};
GetOpInfoInterface(ir::Operation *op, Concept *impl)
: ir::OpInterfaceBase<GetOpInfoInterface>(op), impl_(impl) {}
OpInfoTuple GetOpInfo() { return impl_->get_op_info_(operation()); }
private:
Concept *impl_;
};
} // namespace dialect
} // namespace paddle
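GetOpInfoInterface follows the Concept/Model type-erasure idiom: Model<ConcreteOp> captures the concrete op's static GetOpInfo behind a plain function pointer at registration time, and the interface object dispatches through it without knowing the concrete type. A rough Python analogue of the dispatch, for intuition only (not part of the codebase):

# Sketch (illustrative analogue only): the Concept/Model dispatch pattern.
class Concept:
    def __init__(self, get_op_info):
        self.get_op_info = get_op_info          # the type-erased entry point

def make_model(concrete_op_cls):
    # Model<ConcreteOp>: bind the concrete op's GetOpInfo once, up front.
    return Concept(lambda op: concrete_op_cls.get_op_info(op))

class GetOpInfoInterface:
    def __init__(self, operation, impl):
        self.operation, self.impl = operation, impl
    def get_op_info(self):
        # Dispatch via the stored pointer; the concrete type stays hidden.
        return self.impl.get_op_info(self.operation)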
- name: feed
inputs:
- typename: Tensor[]
name: x
optional: false
no_need_buffer: false
data_transform: {}
attrs:
- {typename: int, name: col}
outputs:
- {typename: Tensor, name: out, optional: false, intermediate: false}
no_need_buffer: null
data_transform: null
infer_meta:
func: null
param: null
kernel:
func: null
param: null
backend: null
layout: null
data_type: null
dispatch: null
force_backend: null
inplace: null
backward: null
- name: fetch
inputs:
- typename: Tensor
name: x
optional: false
no_need_buffer: false
data_transform: {}
attrs:
- {typename: int, name: col}
outputs:
- {typename: 'Tensor[]', name: out, optional: false, intermediate: false}
no_need_buffer: null
data_transform: null
infer_meta:
func: null
param: null
kernel:
func: null
param: null
backend: null
layout: null
data_type: null
dispatch: null
force_backend: null
inplace: null
backward: null
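Entries like the two above are exactly what OpInfoParser walks. A minimal sketch of loading one item and reading the fields the parser touches:

# Sketch: loading a parsed-ops style YAML item and reading parser fields.
import yaml

item = yaml.safe_load("""
- name: feed
  inputs:
  - {typename: 'Tensor[]', name: x, optional: false, no_need_buffer: false}
  attrs:
  - {typename: int, name: col}
  outputs:
  - {typename: Tensor, name: out, optional: false, intermediate: false}
""")[0]

print([i['name'] for i in item['inputs']])     # ['x']
print([a['typename'] for a in item['attrs']])  # ['int']
print([o['name'] for o in item['outputs']])    # ['out']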
......@@ -132,5 +132,45 @@ inline DenseTensorTypeStorage::DataLayout TransToIrDataLayout(
}
}
struct OpInputInfo {
std::string name;
std::string type_name;
bool optional = false;
bool no_need_buffer = false;
OpInputInfo(std::string name,
std::string type_name,
bool optional,
bool no_need_buffer)
: name(name),
type_name(type_name),
optional(optional),
no_need_buffer(no_need_buffer) {}
};
struct OpOutputInfo {
std::string name;
std::string type_name;
bool optional = false;
bool intermediate = false;
OpOutputInfo(std::string name,
std::string type_name,
bool optional,
bool intermediate)
: name(name),
type_name(type_name),
optional(optional),
intermediate(intermediate) {}
};
struct OpAttributeInfo {
std::string name;
std::string type_name;
std::string data_type;
OpAttributeInfo(std::string name,
std::string type_name,
std::string data_type)
: name(name), type_name(type_name), data_type(data_type) {}
};
} // namespace dialect
} // namespace paddle
......@@ -296,7 +296,7 @@
- op : einsum
args : (Tensor[] x, str equation)
output : Tensor, Tensor[]{x.size()}, Tensor[]{x.size()}
output : Tensor(out), Tensor[](inner_cache){x.size()}, Tensor[](xshape){x.size()}
infer_meta :
func : EinsumRawInferMeta
param : [x, equation]
......
......@@ -50,10 +50,7 @@ class InferShapeInterface : public ir::OpInterfaceBase<InferShapeInterface> {
concret_op.InferShape();
}
Model() : Concept(InferShape) {
static_assert(sizeof(Model) == sizeof(Concept),
"sizeof(Model) != sizeof(Concept)");
}
Model() : Concept(InferShape) {}
};
InferShapeInterface(ir::Operation *op, Concept *impl)
......
......@@ -15,6 +15,7 @@
#include <gtest/gtest.h>
#include "paddle/fluid/dialect/pd_dialect.h"
#include "paddle/fluid/dialect/pd_interface.h"
#include "paddle/fluid/dialect/pd_type.h"
#include "paddle/fluid/dialect/utils.h"
#include "paddle/ir/core/builtin_attribute.h"
......@@ -177,7 +178,21 @@ TEST(program_test, program) {
EXPECT_EQ(*(dst_tensor->data<float>() + i), data_a[i] + data_b[i]);
}
// (7) Def SetParameterOp(c, "c")
// (7) Def AbsOp(b)
ir::OpInfo abs_info = ctx->GetRegisteredOpInfo("pd.abs");
std::vector<ir::OpResult> operands = {op1->GetResultByIndex(0)};
std::unordered_map<std::string, ir::Attribute> abs_op_attribute;
std::vector<ir::Type> output_types = {dense_tensor_dtype};
ir::OperationArgument abs_argument(abs_info);
abs_argument.addOperands(operands.begin(), operands.end());
abs_argument.addAttributes(abs_op_attribute.begin(), abs_op_attribute.end());
abs_argument.addTypes(output_types.begin(), output_types.end());
ir::Operation *abs_op = ir::Operation::create(std::move(abs_argument));
paddle::dialect::GetOpInfoInterface interface =
abs_op->dyn_cast<paddle::dialect::GetOpInfoInterface>();
EXPECT_EQ(std::get<0>(interface.GetOpInfo())[0].name == "x", true);
// (8) Def SetParameterOp(c, "c")
std::string op4_name =
builtin_dialect->name() + "." + std::string(ir::SetParameterOp::name());
ir::OpInfo op4_info = ctx->GetRegisteredOpInfo(op4_name);
......