Unverified commit 4bd5b695, authored by zhangbo9674 and committed by GitHub

[IR] Support static build function for op builder (#54197)

* add build

* add build

* refine code

* refine code

* refine code

* refine code

* refine interface

* fix bug

* fix bug

* fix bug

* refine yaml
Parent 4f25604e
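For reviewers, a minimal sketch (an illustration, not output of this PR) of the kind of C++ the new OP_BUILD_TEMPLATE expands to, for a hypothetical single-input, single-output op XOp whose infer_meta func is assumed to be phi::UnchangedInferMeta; every call used below (addOperands, MetaTensor, DenseTensorType::get, TransToPhiDataType / TransToIrDataType, addTypes) is taken from the templates added in this diff:

void XOp::build(ir::Builder &builder, ir::OperationArgument &argument, ir::OpResult x_) {
  // inputs: forward the ir::OpResult operands into the OperationArgument
  std::vector<ir::OpResult> argument_inputs = {x_};
  argument.addOperands(argument_inputs.begin(), argument_inputs.end());
  // outputs: build MetaTensor views of the input type and run the phi infer-meta function
  paddle::dialect::DenseTensorType x = x_.type().dyn_cast<paddle::dialect::DenseTensorType>();
  phi::DenseTensor dense_x;
  dense_x.set_meta(phi::DenseTensorMeta(TransToPhiDataType(x.dtype()), x.dims(), x.data_layout(), x.lod(), x.offset()));
  phi::MetaTensor meta_x(&dense_x);
  phi::DenseTensor dense_out;
  phi::MetaTensor meta_out(&dense_out);
  phi::UnchangedInferMeta(meta_x, &meta_out);  // assumed infer_meta func for this example
  // translate the inferred DenseTensor meta back into an ir output type
  std::vector<ir::Type> argument_outputs;
  argument_outputs.push_back(paddle::dialect::DenseTensorType::get(ir::IrContext::Instance(), TransToIrDataType(dense_out.dtype()), dense_out.dims(), dense_out.layout(), dense_out.lod(), dense_out.offset()));
  argument.addTypes(argument_outputs.begin(), argument_outputs.end());
}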
@@ -31,6 +31,8 @@ H_FILE_TEMPLATE = """#ifdef GET_OP_LIST
#include <vector>
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/operation_utils.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/fluid/dialect/utils.h"
#include "paddle/fluid/dialect/pd_interface.h"
@@ -54,6 +56,7 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{
{attribute_declare}
static constexpr uint32_t attributes_num = {attribute_num};
static OpInfoTuple GetOpInfo();
static void build({build_args});
static void verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes);
{get_inputs_and_outputs}
{exclusive_interface}
@@ -81,6 +84,14 @@ CC_FILE_TEMPLATE = """#include "{h_file}"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/nullary.h"
@@ -97,45 +108,35 @@ OP_N_ATTRIBUTE_DEFINED_TEMPLATE = """
const char *{op_name}::attributes_name[{attribute_num}] = {{ {attribute_names} }};
"""
# get op info
OP_INFO_TEMPLATE = """
OpInfoTuple {op_name}::GetOpInfo() {{
  std::vector<paddle::dialect::OpInputInfo> inputs = {{ {inputs} }};
  std::vector<paddle::dialect::OpAttributeInfo> attributes = {{ {attributes} }};
  std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }};
  paddle::dialect::OpRunTimeInfo run_time_info = OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}});
  return std::make_tuple(inputs, attributes, outputs, run_time_info);
"""
OP_INPUT_INFO_TEMPLATE = """
std::vector<paddle::dialect::OpInputInfo> {op_name}::inputs_info() {{
return {{ {impl} }};
}}
"""
CONSTRUCT_INPUT_INFO_TEMPLATE = (
"""OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer})"""
)
# get op output info
OP_OUTPUT_INFO_TEMPLATE = """
std::vector<paddle::dialect::OpOutputInfo> {op_name}::outputs_info() {{
return {{ {impl} }};
}}
"""
CONSTRUCT_OUTPUT_INFO_TEMPLATE = (
"""OpOutputInfo("{name}", "{typename}", {optional}, {intermediate})"""
)
# get op attribute info
OP_ATTRIBUTE_INFO_TEMPLATE = """
std::vector<paddle::dialect::OpAttributeInfo> {op_name}::attributes_info() {{
return {{ {impl} }};
}}
"""
CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE = (
"""OpAttributeInfo("{name}", "{typename}", "{data_type}")"""
)
# build
OP_BUILD_TEMPLATE = """
void {op_name}::build({build_args}) {{
{build_inputs}
{build_attributes}
{build_outputs}
}}
"""
# verify
OP_VERIFY_TEMPLATE = """
void {op_name}::verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes) {{
@@ -154,6 +155,14 @@ void {op_name}::verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes) {{
}}
"""
GRAD_OP_VERIFY_TEMPLATE = """
void {op_name}::verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes) {{
(void)inputs;
(void)outputs;
(void)attributes;
}}
"""
INPUT_TYPE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
"""
@@ -216,14 +225,10 @@ OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """if (outputs[{index}]) {{
}}
"""
ATTRIBUTE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<{standard}>(), true,
phi::errors::PreconditionNotMet("The AttributeMap miss mandatory attributes of: {attribute_name}."));
PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
"""
ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<ir::ArrayAttribute>(), true,
phi::errors::PreconditionNotMet("The AttributeMap miss mandatory attributes of: {attribute_name}."));
PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").isa<ir::ArrayAttribute>(), true,
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().size(); i++) {{
PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>()[i].isa<{standard}>(), true,
@@ -286,6 +291,7 @@ class OpInfoParser:
# parse outputs
self.output_name_list = self.parse_output_name_list()
self.output_type_list = self.parse_output_type_list()
self.output_size_list = self.parse_output_size_list()
self.output_optional_list = self.parse_output_optional_list()
self.output_intermediate_list = self.parse_output_intermediate_list()
self.cross_check(
@@ -294,11 +300,67 @@ class OpInfoParser:
self.output_optional_list,
)
# parse attributes
self.attr_types_map = {
'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'],
'Scalar': ['paddle::dialect::ScalarAttribute', 'Scalar'],
'Scalar(int)': ['paddle::dialect::ScalarAttribute', 'int'],
'Scalar(int64_t)': ['paddle::dialect::ScalarAttribute', 'int64_t'],
'Scalar(float)': ['paddle::dialect::ScalarAttribute', 'float'],
'Scalar(dobule)': ['paddle::dialect::ScalarAttribute', 'double'],
'Scalar[]': [
'ir::ArrayAttribute<paddle::dialect::ScalarAttribute>',
'std::vector<Scalar>',
],
'int': ['ir::Int32_tAttribute', 'int'],
'int32_t': ['ir::Int32_tAttribute', 'int32_t'],
'int64_t': ['ir::Int64_tAttribute', 'int64_t'],
'long': ['ir::LongAttribute', 'long'],
'size_t': ['ir::Size_tAttribute', 'size_t'],
'float': ['ir::FloatAttribute', 'float'],
'float[]': [
'ir::ArrayAttribute<ir::FloatAttribute>',
'std::vector<float>',
],
'double': ['ir::DoubleAttribute', 'double'],
'bool': ['ir::BoolAttribute', 'bool'],
'bool[]': [
'ir::ArrayAttribute<ir::BoolAttribute>',
'std::vector<bool>',
],
'str': ['ir::StrAttribute', 'std::string'],
'str[]': [
'ir::ArrayAttribute<ir::StrAttribute>',
'std::vector<std::string>',
],
'Place': ['paddle::dialect::PlaceAttribute', 'Place'],
'DataLayout': [
'paddle::dialect::DataLayoutAttribute',
'DataLayout',
],
'DataType': ['paddle::dialect::DataTypeAttribute', 'DataType'],
'int64_t[]': [
'ir::ArrayAttribute<ir::Int64_tAttribute>',
'std::vector<int64_t>',
],
'int[]': [
'ir::ArrayAttribute<ir::Int32_tAttribute>',
'std::vector<int>',
],
}
self.attribute_name_list = self.parse_attribute_name_list()
self.attribute_type_list = self.parse_attribute_type_list()
self.attribute_build_arg_type_list = (
self.parse_attribute_build_arg_type_list()
)
self.attribute_data_type_list = self.parse_attribute_data_type_list()
self.attribute_default_value_list = (
self.parse_attribute_default_value_list()
)
self.cross_check(self.attribute_name_list, self.attribute_type_list)
# parse infermeta && kernel
self.infer_meta_map = self.parse_infer_meta_map()
self.kernel_map = self.parse_kernel_map()
if 'infer_meta' in self.op_yaml_item:
self.infer_shape_func = self.op_yaml_item['infer_meta']["func"]
else:
@@ -313,6 +375,23 @@ class OpInfoParser:
optional_list
), "type list size != optional list size."
def parse_op_phi_name(self):
if self.parse_op_inplace_info() is None:
return [self.op_yaml_item['name']]
else:
if self.op_yaml_item['name'][-1] == "_":
return [self.op_yaml_item['name']]
else:
return [
self.op_yaml_item['name'],
self.op_yaml_item['name'] + "_",
]
def parse_op_inplace_info(self):
if 'inplace' in self.op_yaml_item:
return self.op_yaml_item['inplace']
return None
def parse_input_name_list(self):
name_list = []
for input_info in self.op_yaml_item['inputs']:
@@ -369,6 +448,15 @@ class OpInfoParser:
type_list.append(output_type_map[output_info['typename']])
return type_list
def parse_output_size_list(self):
size_list = []
for output_info in self.op_yaml_item['outputs']:
if 'size' in output_info:
size_list.append(output_info['size'])
else:
size_list.append(None)
return size_list
def parse_output_optional_list(self):
optional_list = []
for output_info in self.op_yaml_item['outputs']:
@@ -399,39 +487,31 @@ class OpInfoParser:
name_list.append(attribute_info['name'])
return name_list
def parse_attribute_build_arg_type_list(self):
type_list = []
for attribute_info in self.op_yaml_item['attrs']:
assert (
attribute_info['typename'] in self.attr_types_map
), f"{self.op_phi_name} : Attr type error."
# Scalar & IntArray has data_type
temp_type = self.attr_types_map[attribute_info['typename']][1]
if 'Scalar' in temp_type:
if 'data_type' in attribute_info:
temp_type = attribute_info['data_type']
if 'IntArray' in temp_type:
if 'data_type' in attribute_info:
temp_type = attribute_info['data_type']
type_list.append(self.get_phi_dtype_name(temp_type))
return type_list
def parse_attribute_type_list(self):
attr_types_map = {
'IntArray': 'paddle::dialect::IntArrayAttribute',
'Scalar': 'paddle::dialect::ScalarAttribute',
'Scalar(int)': 'paddle::dialect::ScalarAttribute',
'Scalar(int64_t)': 'paddle::dialect::ScalarAttribute',
'Scalar(float)': 'paddle::dialect::ScalarAttribute',
'Scalar(dobule)': 'paddle::dialect::ScalarAttribute',
'Scalar[]': 'ir::ArrayAttribute<paddle::dialect::ScalarAttribute>',
'int': 'ir::Int32_tAttribute',
'int32_t': 'ir::Int32_tAttribute',
'int64_t': 'ir::Int64_tAttribute',
'long': 'ir::LongAttribute',
'size_t': 'ir::Size_tAttribute',
'float': 'ir::FloatAttribute',
'float[]': 'ir::ArrayAttribute<ir::FloatAttribute>',
'double': 'ir::DoubleAttribute',
'bool': 'ir::BoolAttribute',
'bool[]': 'ir::ArrayAttribute<ir::BoolAttribute>',
'str': 'ir::StrAttribute',
'str[]': 'ir::ArrayAttribute<ir::StrAttribute>',
'Place': 'paddle::dialect::PlaceAttribute',
'DataLayout': 'paddle::dialect::DataLayoutAttribute',
'DataType': 'paddle::dialect::DataTypeAttribute',
'int64_t[]': 'ir::ArrayAttribute<ir::Int64_tAttribute>',
'int[]': 'ir::ArrayAttribute<ir::Int32_tAttribute>',
}
type_list = []
for attribute_info in self.op_yaml_item['attrs']:
assert (
attribute_info['typename'] in self.attr_types_map
), f"{self.op_phi_name} : Attr type error."
type_list.append(self.attr_types_map[attribute_info['typename']][0])
return type_list
def parse_attribute_data_type_list(self):
@@ -443,22 +523,48 @@
data_type_list.append("")
return data_type_list
def parse_attribute_default_value_list(self):
default_value_list = []
for attribute_info in self.op_yaml_item['attrs']:
if 'default_value' in attribute_info:
default_value = attribute_info['default_value']
default_value_list.append(
self.get_phi_dtype_name(default_value)
)
else:
default_value_list.append(None)
return default_value_list
def parse_infer_meta_map(self):
if 'infer_meta' in self.op_yaml_item:
return self.op_yaml_item['infer_meta']
else:
return None
def parse_kernel_map(self):
if 'kernel' in self.op_yaml_item:
return self.op_yaml_item['kernel']
else:
return None
def get_phi_dtype_name(self, name):
name = name.replace('Scalar', 'phi::Scalar')
name = name.replace('IntArray', 'phi::IntArray')
name = name.replace('DataLayout', 'phi::DataLayout')
name = name.replace('DataType', 'phi::DataType')
if name.startswith(
(
"Place",
"CPUPlace",
"GPUPlace",
"GPUPinnedPlace",
"XPUPlace",
"IPUPlace",
"CustomPlace",
)
):
return "phi::" + name
return name
def to_pascal_case(s):
@@ -472,6 +578,280 @@ def to_pascal_case(s):
# =====================================
# Generate Op Definition Files
# =====================================
def GenBuildInputArgsStr(
op_input_name_list,
op_attribute_name_list,
op_attribute_build_arg_type_list,
op_attribute_default_value_list,
for_func_define=True,
):
'''
Example: ir::Builder &builder, ir::OperationArgument &argument, ir::OpResult x_, phi::DataType dtype=phi::DataType::UNDEFINED, phi::Place place={}
'''
build_args_str = "ir::Builder &builder, ir::OperationArgument &argument"
if len(op_input_name_list) > 0:
for input_name in op_input_name_list:
build_args_str += ", ir::OpResult " + input_name + "_"
for attr_idx in range(len(op_attribute_name_list)):
build_args_str += (
", "
+ op_attribute_build_arg_type_list[attr_idx]
+ " "
+ op_attribute_name_list[attr_idx]
)
if for_func_define:
if op_attribute_default_value_list[attr_idx] is not None:
default_value = op_attribute_default_value_list[attr_idx]
if op_attribute_build_arg_type_list[attr_idx] != "std::string":
if default_value[0] == "'" or default_value[0] == '"':
default_value = default_value[1:]
if default_value[-1] == "'" or default_value[-1] == '"':
default_value = default_value[0:-1]
build_args_str += "=" + default_value
return build_args_str
def GenBuildInputs(op_input_name_list):
BUILD_INPUT_TEMPLATE = """ std::vector<ir::OpResult> argument_inputs = {{{inputs_args}}};
argument.addOperands(argument_inputs.begin(), argument_inputs.end());
"""
build_input_str = ""
if len(op_input_name_list) > 0:
inputs_args_str = "_, ".join(op_input_name_list) + "_"
build_input_str = BUILD_INPUT_TEMPLATE.format(
inputs_args=inputs_args_str
)
return build_input_str
def GenBuildAttributes(op_attribute_name_list, op_attribute_type_list):
INTARRAY_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), phi::IntArray({attr}));
"""
SCALAR_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), phi::Scalar({attr}));
"""
STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), {attr});
"""
ARRAY_ATTRIBUTE_TEMPLATE = """ std::vector<ir::Attribute> vec_{attr_name};
for (size_t i = 0; i < static_cast<size_t>({attr_size}); i++) {{
{create_attribute}
vec_{attr_name}.push_back(attr_{attr_name});
}}
ir::Attribute attr_{attr_name} = ir::ArrayAttribute::get(ir::IrContext::Instance(), vec_{attr_name});
"""
attr_str = ""
for idx in range(len(op_attribute_name_list)):
if "ir::ArrayAttribute<" in op_attribute_type_list[idx]:
inner_attribute_type = op_attribute_type_list[idx][19:-1]
if inner_attribute_type == "paddle::dialect::IntArrayAttribute":
attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format(
attr_name=op_attribute_name_list[idx],
attr_size=op_attribute_name_list[idx] + ".size()",
create_attribute=INTARRAY_STR_TEMPLATE.format(
attr_name=op_attribute_name_list[idx],
op_attribute_type=inner_attribute_type,
attr=op_attribute_name_list[idx] + "[i]",
),
)
elif inner_attribute_type == "paddle::dialect::ScalarAttribute":
attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format(
attr_name=op_attribute_name_list[idx],
attr_size=op_attribute_name_list[idx] + ".size()",
create_attribute=SCALAR_STR_TEMPLATE.format(
attr_name=op_attribute_name_list[idx],
op_attribute_type=inner_attribute_type,
attr=op_attribute_name_list[idx] + "[i]",
),
)
else:
attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format(
attr_name=op_attribute_name_list[idx],
attr_size=op_attribute_name_list[idx] + ".size()",
create_attribute=STR_TEMPLATE.format(
attr_name=op_attribute_name_list[idx],
op_attribute_type=inner_attribute_type,
attr=op_attribute_name_list[idx] + "[i]",
),
)
elif (
op_attribute_type_list[idx] == "paddle::dialect::IntArrayAttribute"
):
attr_str += INTARRAY_STR_TEMPLATE.format(
attr_name=op_attribute_name_list[idx],
op_attribute_type=op_attribute_type_list[idx],
attr=op_attribute_name_list[idx],
)
elif op_attribute_type_list[idx] == "paddle::dialect::ScalarAttribute":
attr_str += SCALAR_STR_TEMPLATE.format(
attr_name=op_attribute_name_list[idx],
op_attribute_type=op_attribute_type_list[idx],
attr=op_attribute_name_list[idx],
)
else:
attr_str += STR_TEMPLATE.format(
attr_name=op_attribute_name_list[idx],
op_attribute_type=op_attribute_type_list[idx],
attr=op_attribute_name_list[idx],
)
attr_str += """ argument.addAttribute("{attr_name}", attr_{attr_name});\n""".format(
attr_name=op_attribute_name_list[idx]
)
return attr_str
def GenBuildOutputs(
op_input_name_list,
op_input_type_list,
op_output_name_list,
op_output_type_list,
op_output_size_list,
op_infer_meta_map,
):
build_output_str = ""
CREATE_INPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name};
dense_{name}.set_meta(
phi::DenseTensorMeta(TransToPhiDataType({name}.dtype()),
{name}.dims(),
{name}.data_layout(),
{name}.lod(),
{name}.offset())
);
phi::MetaTensor meta_{name}(&dense_{name});
"""
CREATE_INPUT_VEC_METATENSOR_TEMPLATE = """ std::vector<phi::DenseTensor> vec_dense_{name}({name}.size(), phi::DenseTensor());
std::vector<phi::MetaTensor> vec_meta_{name};
for (size_t i=0; i < static_cast<size_t>({name}.size()); i++) {{
vec_dense_{name}[i].set_meta(
phi::DenseTensorMeta(TransToPhiDataType({name}[i].dyn_cast<paddle::dialect::DenseTensorType>().dtype()),
{name}[i].dyn_cast<paddle::dialect::DenseTensorType>().dims(),
{name}[i].dyn_cast<paddle::dialect::DenseTensorType>().data_layout(),
{name}[i].dyn_cast<paddle::dialect::DenseTensorType>().lod(),
{name}[i].dyn_cast<paddle::dialect::DenseTensorType>().offset())
);
vec_meta_{name}.push_back(phi::MetaTensor(&vec_dense_{name}[i]));
}}
std::vector<const phi::MetaTensor*> meta_{name};
for (size_t i=0; i < static_cast<size_t>(vec_meta_{name}.size()); i++) {{
meta_{name}.push_back(&vec_meta_{name}[i]);
}}
"""
CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name};
phi::MetaTensor meta_{name}(&dense_{name});
"""
CREATE_OUTPUT_VEC_METATENSOR_TEMPLATE = """ std::vector<phi::DenseTensor> vec_dense_{name}(({output_size}), phi::DenseTensor());
std::vector<phi::MetaTensor> vec_meta_{name};
for (size_t i=0; i < static_cast<size_t>({output_size}); i++) {{
vec_meta_{name}.push_back(phi::MetaTensor(&vec_dense_{name}[i]));
}}
std::vector<phi::MetaTensor*> meta_{name};
for (size_t i=0; i < static_cast<size_t>(vec_meta_{name}.size()); i++) {{
meta_{name}.push_back(&vec_meta_{name}[i]);
}}
"""
# Prepare input types
for idx in range(len(op_input_name_list)):
# is a vector<Tensor>
if 'ir::VectorType' in op_input_type_list[idx]:
build_output_str += " ir::VectorType {name} = {name}_.type().dyn_cast<ir::VectorType>(); (void){name};\n".format(
name=op_input_name_list[idx]
)
# is a Tensor
else:
build_output_str += " paddle::dialect::DenseTensorType {name} = {name}_.type().dyn_cast<paddle::dialect::DenseTensorType>(); (void){name};\n".format(
name=op_input_name_list[idx]
)
# Prepare inputs for infer meta
infer_meta_args = []
for idx in range(len(op_infer_meta_map['param'])):
# is input
if op_infer_meta_map['param'][idx] in op_input_name_list:
if (
"meta_" + op_infer_meta_map['param'][idx]
) not in infer_meta_args:
# is a vector<Tensor>
if (
'ir::VectorType'
in op_input_type_list[
op_input_name_list.index(
op_infer_meta_map['param'][idx]
)
]
):
build_output_str += (
CREATE_INPUT_VEC_METATENSOR_TEMPLATE.format(
name=op_infer_meta_map['param'][idx]
)
)
# is a Tensor
else:
build_output_str += CREATE_INPUT_METATENSOR_TEMPLATE.format(
name=op_infer_meta_map['param'][idx]
)
infer_meta_args.append("meta_" + op_infer_meta_map['param'][idx])
# is attribute
else:
infer_meta_args.append(op_infer_meta_map['param'][idx])
# Prepare outputs for infer meta
for idx in range(len(op_output_name_list)):
# is a vector<Tensor>
if 'ir::VectorType' in op_output_type_list[idx]:
build_output_str += CREATE_OUTPUT_VEC_METATENSOR_TEMPLATE.format(
name=op_output_name_list[idx],
output_size=op_output_size_list[idx],
)
infer_meta_args.append(f"meta_{op_output_name_list[idx]}")
# is a Tensor
else:
build_output_str += CREATE_OUTPUT_METATENSOR_TEMPLATE.format(
name=op_output_name_list[idx]
)
infer_meta_args.append(f"&meta_{op_output_name_list[idx]}")
# Execute infer meta function
CREATE_INFER_META_FUNC_TEMPLATE = """
phi::{func}({args});
"""
build_output_str += CREATE_INFER_META_FUNC_TEMPLATE.format(
func=op_infer_meta_map['func'], args=", ".join(infer_meta_args)
)
# use dense_{name} or vec_dense_{name} to create Outputs type
build_output_str += "\n std::vector<ir::Type> argument_outputs;"
CREATE_OUTPUT_DENSE_TENSOR_TEMPLATE = """
ir::Type {name}_dense_tensor_type = paddle::dialect::DenseTensorType::get(ir::IrContext::Instance(), TransToIrDataType(dense_{name}.dtype()), dense_{name}.dims(), dense_{name}.layout(), dense_{name}.lod(), dense_{name}.offset());
argument_outputs.push_back({name}_dense_tensor_type);
"""
CREATE_OUTPUT_VEC_DENSE_TENSOR_TEMPLATE = """
std::vector<ir::Type> {name}_types;
for (size_t i=0; i < static_cast<size_t>({output_size}); i++) {{
{name}_types.push_back(paddle::dialect::DenseTensorType::get(ir::IrContext::Instance(), TransToIrDataType(vec_dense_{name}[i].dtype()), vec_dense_{name}[i].dims(), vec_dense_{name}[i].layout(), vec_dense_{name}[i].lod(), vec_dense_{name}[i].offset()));
}}
ir::Type {name}_vector_type = ir::VectorType::get(ir::IrContext::Instance(), {name}_types);
argument_outputs.push_back({name}_vector_type);
"""
for idx in range(len(op_output_name_list)):
# is a vector<Tensor>
if 'ir::VectorType' in op_output_type_list[idx]:
build_output_str += CREATE_OUTPUT_VEC_DENSE_TENSOR_TEMPLATE.format(
name=op_output_name_list[idx],
output_size=op_output_size_list[idx],
)
# is a Tensor
else:
build_output_str += CREATE_OUTPUT_DENSE_TENSOR_TEMPLATE.format(
name=op_output_name_list[idx]
)
build_output_str += " argument.addTypes(argument_outputs.begin(), argument_outputs.end());\n"
return build_output_str
def OpGenerator(
op_yaml_files,
op_compat_yaml_file,
@@ -512,11 +892,16 @@ def OpGenerator(
op_input_no_need_buffer_list = op_info.input_no_need_buffer_list
op_output_name_list = op_info.output_name_list
op_output_type_list = op_info.output_type_list
op_output_size_list = op_info.output_size_list
op_output_optional_list = op_info.output_optional_list
op_output_intermediate_list = op_info.output_intermediate_list
op_attribute_name_list = op_info.attribute_name_list
op_attribute_type_list = op_info.attribute_type_list
op_attribute_data_type_list = op_info.attribute_data_type_list
op_attribute_build_arg_type_list = op_info.attribute_build_arg_type_list
op_attribute_default_value_list = op_info.attribute_default_value_list
op_infer_meta_map = op_info.infer_meta_map
op_kernel_map = op_info.kernel_map
op_interfaces = ["GetOpInfoInterface"]
op_traits = []
@@ -552,6 +937,53 @@ def OpGenerator(
output_index=idx,
)
# gen build str
build_define_input_args_str = ""
build_declare_input_args_str = ""
build_func_declare_str = ""
if op_infer_meta_map is not None:
build_define_input_args_str = GenBuildInputArgsStr(
op_input_name_list,
op_attribute_name_list,
op_attribute_build_arg_type_list,
op_attribute_default_value_list,
True,
)
build_declare_input_args_str = GenBuildInputArgsStr(
op_input_name_list,
op_attribute_name_list,
op_attribute_build_arg_type_list,
op_attribute_default_value_list,
False,
)
build_inputs_str = GenBuildInputs(op_input_name_list)
build_attributes_str = GenBuildAttributes(
op_attribute_name_list, op_attribute_type_list
)
build_outputs_str = GenBuildOutputs(
op_input_name_list,
op_input_type_list,
op_output_name_list,
op_output_type_list,
op_output_size_list,
op_infer_meta_map,
)
build_func_declare_str = OP_BUILD_TEMPLATE.format(
op_name=op_class_name,
build_args=build_declare_input_args_str,
build_inputs=build_inputs_str,
build_attributes=build_attributes_str,
build_outputs=build_outputs_str,
)
else:
build_func_declare_str = OP_BUILD_TEMPLATE.format(
op_name=op_class_name,
build_args=build_declare_input_args_str,
build_inputs="",
build_attributes="",
build_outputs="",
)
# gen op_declare_str/op_defined_str
if len(op_attribute_name_list) == 0:
op_declare_str = OP_DECLARE_TEMPLATE.format(
@@ -561,6 +993,7 @@ def OpGenerator(
traits=op_traits_str,
attribute_declare=op_0_attribute_declare_str,
attribute_num=0,
build_args=build_define_input_args_str,
get_inputs_and_outputs=op_get_inputs_outputs_str,
exclusive_interface=exclusive_interface_str,
)
@@ -575,6 +1008,7 @@ def OpGenerator(
attribute_num=len(op_attribute_name_list)
),
attribute_num=len(op_attribute_name_list),
build_args=build_define_input_args_str,
get_inputs_and_outputs=op_get_inputs_outputs_str,
exclusive_interface=exclusive_interface_str,
)
@@ -631,11 +1065,27 @@ def OpGenerator(
)
attribute_info_str = ", ".join(attribute_info_list)
# generate runtime info
infer_meta_func_str = ""
infer_meta_param_str = ""
if op_infer_meta_map is not None:
infer_meta_func_str = op_infer_meta_map['func']
infer_meta_param_str = '", "'.join(op_infer_meta_map['param'])
kernel_func_str = ""
kernel_param_str = ""
if op_kernel_map is not None:
kernel_func_str = '", "'.join(op_kernel_map['func'])
kernel_param_str = '", "'.join(op_kernel_map['param'])
op_info_func_str = OP_INFO_TEMPLATE.format(
op_name=op_class_name,
inputs=inputs_info_str,
attributes=attribute_info_str,
outputs=outputs_info_str,
infer_meta_func=infer_meta_func_str,
infer_meta_param=infer_meta_param_str,
kernel_func=kernel_func_str,
kernel_param=kernel_param_str,
)
# generate op verify function: inputs_type_check_str
@@ -736,14 +1186,19 @@ def OpGenerator(
)
# generate op verify function
if "GradOp" in op_class_name or "Grad_Op" in op_class_name:
    op_verify_str = GRAD_OP_VERIFY_TEMPLATE.format(
        op_name=op_class_name,
    )
else:
    op_verify_str = OP_VERIFY_TEMPLATE.format(
        op_name=op_class_name,
        inputs_size=len(op_input_type_list),
        outputs_size=len(op_output_type_list),
        inputs_type_check=inputs_type_check_str,
        outputs_type_check=outputs_type_check_str,
        attributes_check=attributes_check_str,
    )
op_infer_shape_str = ""
if op_info.infer_shape_func:
@@ -756,6 +1211,7 @@ def OpGenerator(
ops_declare_list.append(op_declare_str)
ops_defined_list.append(op_defined_str)
ops_defined_list.append(op_info_func_str)
ops_defined_list.append(build_func_declare_str)
ops_defined_list.append(op_verify_str)
ops_defined_list.append(op_infer_shape_str)
@@ -786,7 +1242,7 @@ def OpGenerator(
namespace=name, input=source_file_str
) # Add namespaces
source_file_str = CC_FILE_TEMPLATE.format(
h_file=op_def_h_file[:-4], input=source_file_str
) # Add head
# (5) Generate pd_op.h.tmp, pd_op.cc.tmp
@@ -817,6 +1273,7 @@ def ParseArguments():
# =====================================
if __name__ == "__main__":
# parse arguments
print("auto gen op")
args = ParseArguments()
op_yaml_files = args.op_yaml_files.split(",")
op_compat_yaml_file = args.op_compat_yaml_file
...
@@ -35,13 +35,13 @@ ParameterConvertInterface::ParameterToVariable(ir::Parameter *parameter) {
std::make_shared<paddle::framework::Variable>();
phi::DenseTensor *tensor = var->GetMutable<phi::DenseTensor>();
// Init DenseTensor
auto dim = parameter->type().dyn_cast<DenseTensorType>().dims();
phi::DenseTensorMeta meta(
TransToPhiDataType(
parameter->type().dyn_cast<DenseTensorType>().dtype()),
dim,
parameter->type().dyn_cast<DenseTensorType>().data_layout(),
parameter->type().dyn_cast<DenseTensorType>().lod(),
parameter->type().dyn_cast<DenseTensorType>().offset());
tensor->set_meta(meta);
@@ -67,17 +67,13 @@ std::unique_ptr<ir::Parameter> ParameterConvertInterface::VariableToParameter(
// Get Meta
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Type data_type = TransToIrDataType(tensor->dtype(), ctx);
DenseTensorTypeStorage::Dim dims(tensor->dims().size());
std::copy(tensor->dims().Get(),
tensor->dims().Get() + tensor->dims().size(),
dims.data());
DenseTensorTypeStorage::DataLayout data_layout =
TransToIrDataLayout(tensor->layout());
DenseTensorTypeStorage::LoD lod = tensor->lod();
size_t offset = tensor->meta().offset;
void *data = tensor->data();
ir::Type dense_tensor_type = DenseTensorType::get(ctx,
data_type,
tensor->dims(),
tensor->layout(),
tensor->lod(),
tensor->meta().offset);
return std::make_unique<ir::Parameter>(
data,
tensor->numel() * phi::SizeOf(tensor->dtype()),
@@ -116,8 +112,7 @@ void PaddleDialect::PrintType(ir::Type type, std::ostream &os) {
DenseTensorType tensor_type = type.dyn_cast<DenseTensorType>();
os << "tensor<";
for (auto d : phi::vectorize(tensor_type.dims())) {
os << d;
os << "x";
}
} }
......
@@ -19,25 +19,22 @@
using OpInfoTuple = std::tuple<std::vector<paddle::dialect::OpInputInfo>,
std::vector<paddle::dialect::OpAttributeInfo>,
std::vector<paddle::dialect::OpOutputInfo>,
paddle::dialect::OpRunTimeInfo>;
namespace paddle {
namespace dialect {
class GetOpInfoInterface : public ir::OpInterfaceBase<GetOpInfoInterface> {
public:
struct Concept {
explicit Concept(OpInfoTuple (*get_op_info)())
: get_op_info_(get_op_info) {}
OpInfoTuple (*get_op_info_)();
};
template <class ConcreteOp>
struct Model : public Concept {
static OpInfoTuple GetOpInfo() { return ConcreteOp::GetOpInfo(); }
ConcreteOp concret_op = op->dyn_cast<ConcreteOp>();
if (concret_op == nullptr) throw("concret_op is nullptr");
return concret_op.GetOpInfo();
}
Model() : Concept(GetOpInfo) {}
};
@@ -45,7 +42,7 @@ class GetOpInfoInterface : public ir::OpInterfaceBase<GetOpInfoInterface> {
GetOpInfoInterface(ir::Operation *op, Concept *impl)
: ir::OpInterfaceBase<GetOpInfoInterface>(op), impl_(impl) {}
OpInfoTuple GetOpInfo() { return impl_->get_op_info_(); }
private:
Concept *impl_;
......
@@ -11,17 +11,6 @@
- {typename: Tensor, name: out, optional: false, intermediate: false}
no_need_buffer: null
data_transform: null
infer_meta:
func: null
param: null
kernel:
func: null
param: null
backend: null
layout: null
data_type: null
dispatch: null
force_backend: null
inplace: null
backward: null
- name: fetch
@@ -37,16 +26,5 @@
- {typename: 'Tensor[]', name: out, optional: false, intermediate: false}
no_need_buffer: null
data_transform: null
infer_meta:
func: null
param: null
kernel:
func: null
param: null
backend: null
layout: null
data_type: null
dispatch: null
force_backend: null
inplace: null
backward: null
@@ -18,20 +18,13 @@ namespace paddle {
namespace dialect {
const ir::Type& DenseTensorType::dtype() const { return storage()->dtype_; }
const phi::DDim& DenseTensorType::dims() const { return storage()->dims_; }
const phi::DataLayout& DenseTensorType::data_layout() const {
return storage()->layout_;
}
const phi::LoD& DenseTensorType::lod() const { return storage()->lod_; }
const size_t& DenseTensorType::offset() const { return storage()->offset_; }
......
@@ -30,12 +30,11 @@ class DenseTensorType : public ir::Type {
const ir::Type &dtype() const;
const phi::DDim &dims() const;
const phi::DataLayout &data_layout() const;
const phi::LoD &lod() const;
const size_t &offset() const;
};
......
@@ -18,6 +18,7 @@
#include "paddle/ir/core/type.h"
#include "paddle/ir/core/utils.h"
#include "paddle/phi/core/tensor_meta.h"
namespace std {
///
@@ -46,46 +47,20 @@ namespace dialect {
/// (3)define HashValue method, (4)overload operator==.
///
struct DenseTensorTypeStorage : public ir::TypeStorage {
using DataLayout = phi::DataLayout;
using Dim = phi::DDim;
/// library. See the file for details: paddle/phi/common/layout.h.
///
enum class DataLayout : unsigned int {
UNDEFINED = 0,
NHWC,
NCHW,
NCDHW,
NDHWC,
ONEDNN,
SPARSE_COO,
SPARSE_CSR,
PSTRING_UNION,
NUM_DATA_LAYOUTS,
// See Note [ Why we need ALL in basic kernel key member? ]
ALL_LAYOUT = UNDEFINED,
// Note: Unify phi DataLayout and fluid::framework::DataLayout,
// for compatible with fluid DataLayout, here need prefix `k`
kNHWC = NHWC,
kNCHW = NCHW,
kMKLDNN = ONEDNN, // all layouts supported by ONEDNN internally
kNDHWC = NDHWC,
kNCDHW = NCDHW,
};
using Dim = std::vector<int64_t>;
using LoD = std::vector<std::vector<size_t>>;
///
/// \brief Declare ParamKey according to parameter type.
///
using ParamKey =
    std::tuple<ir::Type, phi::DDim, phi::DataLayout, phi::LoD, size_t>;
DenseTensorTypeStorage(ir::Type dtype,
                       phi::DDim dims,
                       phi::DataLayout layout,
                       phi::LoD lod,
                       size_t offset)
: dtype_(dtype),
dims_(dims),
layout_(layout),
@@ -114,16 +89,16 @@ struct DenseTensorTypeStorage : public ir::TypeStorage {
ir::hash_combine(hash_value, std::hash<ir::Type>()(std::get<0>(key)));
// hash dims
hash_value =
    ir::hash_combine(hash_value, std::hash<phi::DDim>()(std::get<1>(key)));
// hash layout
hash_value = ir::hash_combine(
    hash_value,
    std::hash<std::underlying_type<phi::DataLayout>::type>()(
        static_cast<std::underlying_type<phi::DataLayout>::type>(
            std::get<2>(key))));
// hash lod
hash_value =
    ir::hash_combine(hash_value, std::hash<phi::LoD>()(std::get<3>(key)));
// hash offset
hash_value =
    ir::hash_combine(hash_value, std::hash<size_t>()(std::get<4>(key)));
@@ -146,9 +121,9 @@ struct DenseTensorTypeStorage : public ir::TypeStorage {
/// layout, lod, offset.
///
ir::Type dtype_;
phi::DDim dims_;
phi::DataLayout layout_;
phi::LoD lod_;
size_t offset_;
};
......
@@ -70,67 +70,76 @@ inline ir::Type TransToIrDataType(phi::DataType dtype,
  }
}

// inline phi::DataLayout TransToPhiDataLayout(
//     DenseTensorTypeStorage::DataLayout data_layout) {
//   switch (data_layout) {
//     case DenseTensorTypeStorage::DataLayout::NHWC:
//       return phi::DataLayout::NHWC;
//     case DenseTensorTypeStorage::DataLayout::NCHW:
//       return phi::DataLayout::NCHW;
//     case DenseTensorTypeStorage::DataLayout::NCDHW:
//       return phi::DataLayout::NCDHW;
//     case DenseTensorTypeStorage::DataLayout::NDHWC:
//       return phi::DataLayout::NDHWC;
//     case DenseTensorTypeStorage::DataLayout::ONEDNN:
//       return phi::DataLayout::ONEDNN;
//     case DenseTensorTypeStorage::DataLayout::SPARSE_COO:
//       return phi::DataLayout::SPARSE_COO;
//     case DenseTensorTypeStorage::DataLayout::SPARSE_CSR:
//       return phi::DataLayout::SPARSE_CSR;
//     case DenseTensorTypeStorage::DataLayout::PSTRING_UNION:
//       return phi::DataLayout::PSTRING_UNION;
//     case DenseTensorTypeStorage::DataLayout::NUM_DATA_LAYOUTS:
//       return phi::DataLayout::NUM_DATA_LAYOUTS;
//     case DenseTensorTypeStorage::DataLayout::ALL_LAYOUT:
//       return phi::DataLayout::ALL_LAYOUT;
//     default:
//       PADDLE_THROW(phi::errors::Unimplemented(
//           "Unsupported ir data layout `%s` when casting it into "
//           "phi data type.",
//           static_cast<int>(data_layout)));
//   }
// }

// inline DenseTensorTypeStorage::DataLayout TransToIrDataLayout(
//     phi::DataLayout data_layout) {
//   switch (data_layout) {
//     case phi::DataLayout::NHWC:
//       return DenseTensorTypeStorage::DataLayout::NHWC;
//     case phi::DataLayout::NCHW:
//       return DenseTensorTypeStorage::DataLayout::NCHW;
//     case phi::DataLayout::NCDHW:
//       return DenseTensorTypeStorage::DataLayout::NCDHW;
//     case phi::DataLayout::NDHWC:
//       return DenseTensorTypeStorage::DataLayout::NDHWC;
//     case phi::DataLayout::ONEDNN:
//       return DenseTensorTypeStorage::DataLayout::ONEDNN;
//     case phi::DataLayout::SPARSE_COO:
//       return DenseTensorTypeStorage::DataLayout::SPARSE_COO;
//     case phi::DataLayout::SPARSE_CSR:
//       return DenseTensorTypeStorage::DataLayout::SPARSE_CSR;
//     case phi::DataLayout::PSTRING_UNION:
//       return DenseTensorTypeStorage::DataLayout::PSTRING_UNION;
//     case phi::DataLayout::NUM_DATA_LAYOUTS:
//       return DenseTensorTypeStorage::DataLayout::NUM_DATA_LAYOUTS;
//     case phi::DataLayout::ALL_LAYOUT:
//       return DenseTensorTypeStorage::DataLayout::ALL_LAYOUT;
//     default:
//       PADDLE_THROW(phi::errors::Unimplemented(
//           "Unsupported phi data layout `%s` when casting it into "
//           "ir data type.",
//           static_cast<int>(data_layout)));
//   }
// }
// inline phi::DenseTensorMeta TransToDenseTensorMeta(
// paddle::dialect::DenseTensorType type) {
// return phi::DenseTensorMeta(TransToPhiDataType(type.dtype()),
// type.dim(),
// type.data_layout(),
// type.lod(),
// type.offset());
// }
struct OpInputInfo {
std::string name;
@@ -172,5 +181,20 @@ struct OpAttributeInfo {
: name(name), type_name(type_name), data_type(data_type) {}
};
struct OpRunTimeInfo {
std::string infer_meta_func;
std::vector<std::string> infer_meta_param;
std::vector<std::string> kernel_func;
std::vector<std::string> kernel_param;
OpRunTimeInfo(std::string infer_meta_func,
std::vector<std::string> infer_meta_param,
std::vector<std::string> kernel_func,
std::vector<std::string> kernel_param)
: infer_meta_func(infer_meta_func),
infer_meta_param(infer_meta_param),
kernel_func(kernel_func),
kernel_param(kernel_param) {}
};
} // namespace dialect
} // namespace paddle
@@ -50,7 +50,7 @@ TypeTranslator::TypeTranslator() {
ir::Type dtype =
this->operator[](var_desc.GetDataType())(ctx, var_desc);
DenseTensorTypeStorage::Dim dim = phi::make_ddim(var_desc.GetShape());
DenseTensorTypeStorage::DataLayout layout =
DenseTensorTypeStorage::DataLayout::UNDEFINED;
DenseTensorTypeStorage::LoD lod = {};
......
@@ -25,7 +25,7 @@ namespace ir {
#define DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(concrete_storage, base_type) \
struct concrete_storage : public ir::AttributeStorage { \
using ParamKey = base_type; \
\
explicit concrete_storage(const ParamKey &key) { data_ = key; } \
\
......
@@ -221,7 +221,7 @@
- backward_op : broadcast_tensors_grad
forward : broadcast_tensors (Tensor[] input) -> Tensor[](out)
args : (Tensor[] input, Tensor[] out_grad)
output : Tensor[](input_grad){input.size()}
infer_meta :
func : UnchangedMultiInferMeta
param : [input]
......
@@ -235,7 +235,7 @@
- backward_op : einsum_grad
forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache), Tensor[](x_shape)
args : (Tensor[] x_shape, Tensor[] inner_cache, Tensor out_grad, str equation)
output : Tensor[](x_grad){x_shape.size()}
infer_meta :
func : UnchangedMultiInferMeta
param : [x_shape]
......
@@ -107,10 +107,9 @@ TEST(program_test, program) {
a_interface->ParameterToVariable(program.GetParameter("a"));
const phi::DenseTensor &a_tensor = a_var->Get<phi::DenseTensor>();
EXPECT_EQ(a_tensor.numel(), 4);
EXPECT_EQ(a_tensor.dims(), dims);
EXPECT_EQ(a_tensor.dtype(), paddle::dialect::TransToPhiDataType(fp32_dtype));
EXPECT_EQ(a_tensor.layout(), data_layout);
EXPECT_EQ(a_tensor.lod(), lod);
EXPECT_EQ(a_tensor.offset(), offset);
for (int64_t i = 0; i < a_tensor.numel(); i++) {
@@ -137,10 +136,9 @@ TEST(program_test, program) {
b_interface->ParameterToVariable(program.GetParameter("b"));
const phi::DenseTensor &b_tensor = b_var->Get<phi::DenseTensor>();
EXPECT_EQ(b_tensor.numel(), 4);
EXPECT_EQ(b_tensor.dims(), dims);
EXPECT_EQ(b_tensor.dtype(), paddle::dialect::TransToPhiDataType(fp32_dtype));
EXPECT_EQ(b_tensor.layout(), data_layout);
EXPECT_EQ(b_tensor.lod(), lod);
EXPECT_EQ(b_tensor.offset(), offset);
for (int64_t i = 0; i < b_tensor.numel(); i++) {
......
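As a usage note, a sketch under the assumption of a generated op class named XOp (not code from this diff): the extended OpInfoTuple can be unpacked to reach the runtime information that the generator now records for static build.

auto [inputs, attrs, outputs, run_time_info] = paddle::dialect::XOp::GetOpInfo();
// run_time_info.infer_meta_func / infer_meta_param and kernel_func / kernel_param
// hold the phi function names and arguments taken from the op yaml.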