未验证 提交 4bd5b695 编写于 作者: Z zhangbo9674 提交者: GitHub

[IR] Support static build function for op builder (#54197)

* add build

* add build

* refine code

* refine code

* refine code

* refine code

* refine interface

* fix bug

* fix bug

* fix bug

* refine yaml
上级 4f25604e
......@@ -31,6 +31,8 @@ H_FILE_TEMPLATE = """#ifdef GET_OP_LIST
#include <vector>
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/operation_utils.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/fluid/dialect/utils.h"
#include "paddle/fluid/dialect/pd_interface.h"
......@@ -54,6 +56,7 @@ class {op_name} : public ir::Op<{op_name}{interfaces}{traits}> {{
{attribute_declare}
static constexpr uint32_t attributes_num = {attribute_num};
static OpInfoTuple GetOpInfo();
static void build({build_args});
static void verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes);
{get_inputs_and_outputs}
{exclusive_interface}
......@@ -81,6 +84,14 @@ CC_FILE_TEMPLATE = """#include "{h_file}"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/nullary.h"
......@@ -97,45 +108,35 @@ OP_N_ATTRIBUTE_DEFINED_TEMPLATE = """
const char *{op_name}::attributes_name[{attribute_num}] = {{ {attribute_names} }};
"""
# get op input info
# get op info
OP_INFO_TEMPLATE = """
OpInfoTuple {op_name}::GetOpInfo() {{
std::vector<paddle::dialect::OpInputInfo> inputs = {{ {inputs} }};
std::vector<paddle::dialect::OpAttributeInfo> attributes = {{ {attributes} }};
std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }};
return std::make_tuple(inputs, attributes, outputs);
}}
"""
OP_INPUT_INFO_TEMPLATE = """
std::vector<paddle::dialect::OpInputInfo> {op_name}::inputs_info() {{
return {{ {impl} }};
std::vector<paddle::dialect::OpInputInfo> inputs = {{ {inputs} }};
std::vector<paddle::dialect::OpAttributeInfo> attributes = {{ {attributes} }};
std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }};
paddle::dialect::OpRunTimeInfo run_time_info = OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, {{"{kernel_func}"}}, {{"{kernel_param}"}});
return std::make_tuple(inputs, attributes, outputs, run_time_info);
}}
"""
CONSTRUCT_INPUT_INFO_TEMPLATE = (
"""OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer})"""
)
# get op output info
OP_OUTPUT_INFO_TEMPLATE = """
std::vector<paddle::dialect::OpOutputInfo> {op_name}::outputs_info() {{
return {{ {impl} }};
}}
"""
CONSTRUCT_OUTPUT_INFO_TEMPLATE = (
"""OpOutputInfo("{name}", "{typename}", {optional}, {intermediate})"""
)
# get op attribute info
OP_ATTRIBUTE_INFO_TEMPLATE = """
std::vector<paddle::dialect::OpAttributeInfo> {op_name}::attributes_info() {{
return {{ {impl} }};
}}
"""
CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE = (
"""OpAttributeInfo("{name}", "{typename}", "{data_type}")"""
)
# build
OP_BUILD_TEMPLATE = """
void {op_name}::build({build_args}) {{
{build_inputs}
{build_attributes}
{build_outputs}
}}
"""
# verify
OP_VERIFY_TEMPLATE = """
void {op_name}::verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes) {{
......@@ -154,6 +155,14 @@ void {op_name}::verify(const std::vector<ir::OpResult> &inputs, const std::vecto
}}
"""
GRAD_OP_VERIFY_TEMPLATE = """
void {op_name}::verify(const std::vector<ir::OpResult> &inputs, const std::vector<ir::Type> &outputs, const ir::AttributeMap &attributes) {{
(void)inputs;
(void)outputs;
(void)attributes;
}}
"""
INPUT_TYPE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(inputs[{index}].type().isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type validation failed for the {index}th input."));
"""
......@@ -216,14 +225,10 @@ OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """if (outputs[{index}]) {{
}}
"""
ATTRIBUTE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0, true,
phi::errors::PreconditionNotMet("The AttributeMap miss mandatory attributes of: {attribute_name}."));
PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").isa<{standard}>(), true,
ATTRIBUTE_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<{standard}>(), true,
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
"""
ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0, true,
phi::errors::PreconditionNotMet("The AttributeMap miss mandatory attributes of: {attribute_name}."));
PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").isa<ir::ArrayAttribute>(), true,
ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """PADDLE_ENFORCE_EQ(attributes.count("{attribute_name}")>0 && attributes.at("{attribute_name}").isa<ir::ArrayAttribute>(), true,
phi::errors::PreconditionNotMet("Type of attribute: {attribute_name} is not right."));
for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>().size(); i++) {{
PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").dyn_cast<ir::ArrayAttribute>()[i].isa<{standard}>(), true,
......@@ -286,6 +291,7 @@ class OpInfoParser:
# parse outputs
self.output_name_list = self.parse_output_name_list()
self.output_type_list = self.parse_output_type_list()
self.output_size_list = self.parse_output_size_list()
self.output_optional_list = self.parse_output_optional_list()
self.output_intermediate_list = self.parse_output_intermediate_list()
self.cross_check(
......@@ -294,11 +300,67 @@ class OpInfoParser:
self.output_optional_list,
)
# parse attributes
self.attr_types_map = {
'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'],
'Scalar': ['paddle::dialect::ScalarAttribute', 'Scalar'],
'Scalar(int)': ['paddle::dialect::ScalarAttribute', 'int'],
'Scalar(int64_t)': ['paddle::dialect::ScalarAttribute', 'int64_t'],
'Scalar(float)': ['paddle::dialect::ScalarAttribute', 'float'],
'Scalar(dobule)': ['paddle::dialect::ScalarAttribute', 'dobule'],
'Scalar[]': [
'ir::ArrayAttribute<paddle::dialect::ScalarAttribute>',
'std::vector<Scalar>',
],
'int': ['ir::Int32_tAttribute', 'int'],
'int32_t': ['ir::Int32_tAttribute', 'int32_t'],
'int64_t': ['ir::Int64_tAttribute', 'int64_t'],
'long': ['ir::LongAttribute', 'long'],
'size_t': ['ir::Size_tAttribute', 'size_t'],
'float': ['ir::FloatAttribute', 'float'],
'float[]': [
'ir::ArrayAttribute<ir::FloatAttribute>',
'std::vector<float>',
],
'double': ['ir::DoubleAttribute', 'double'],
'bool': ['ir::BoolAttribute', 'bool'],
'bool[]': [
'ir::ArrayAttribute<ir::BoolAttribute>',
'std::vecot<bool>',
],
'str': ['ir::StrAttribute', 'std::string'],
'str[]': [
'ir::ArrayAttribute<ir::StrAttribute>',
'std::vector<std::string>',
],
'Place': ['paddle::dialect::PlaceAttribute', 'Place'],
'DataLayout': [
'paddle::dialect::DataLayoutAttribute',
'DataLayout',
],
'DataType': ['paddle::dialect::DataTypeAttribute', 'DataType'],
'int64_t[]': [
'ir::ArrayAttribute<ir::Int64_tAttribute>',
'std::vector<int64_t>',
],
'int[]': [
'ir::ArrayAttribute<ir::Int32_tAttribute>',
'std::vector<int>',
],
}
self.attribute_name_list = self.parse_attribute_name_list()
self.attribute_type_list = self.parse_attribute_type_list()
self.attribute_build_arg_type_list = (
self.parse_attribute_build_arg_type_list()
)
self.attribute_data_type_list = self.parse_attribute_data_type_list()
self.attribute_default_value_list = (
self.parse_attribute_default_value_list()
)
self.cross_check(self.attribute_name_list, self.attribute_type_list)
# parse infermeta && kernel
self.infer_meta_map = self.parse_infer_meta_map()
self.kernel_map = self.parse_kernel_map()
if 'infer_meta' in self.op_yaml_item:
self.infer_shape_func = self.op_yaml_item['infer_meta']["func"]
else:
......@@ -313,6 +375,23 @@ class OpInfoParser:
optional_list
), "type list size != optional list size."
def parse_op_phi_name(self):
if self.parse_op_inplace_info() is None:
return [self.op_yaml_item['name']]
else:
if self.op_yaml_item['name'][-1] == "_":
return [self.op_yaml_item['name']]
else:
return [
self.op_yaml_item['name'],
self.op_yaml_item['name'] + "_",
]
def parse_op_inplace_info(self):
if 'inplace' in self.op_yaml_item:
return self.op_yaml_item['inplace']
return None
def parse_input_name_list(self):
name_list = []
for input_info in self.op_yaml_item['inputs']:
......@@ -369,6 +448,15 @@ class OpInfoParser:
type_list.append(output_type_map[output_info['typename']])
return type_list
def parse_output_size_list(self):
size_list = []
for output_info in self.op_yaml_item['outputs']:
if 'size' in output_info:
size_list.append(output_info['size'])
else:
size_list.append(None)
return size_list
def parse_output_optional_list(self):
optional_list = []
for output_info in self.op_yaml_item['outputs']:
......@@ -399,39 +487,31 @@ class OpInfoParser:
name_list.append(attribute_info['name'])
return name_list
def parse_attribute_build_arg_type_list(self):
type_list = []
for attribute_info in self.op_yaml_item['attrs']:
assert (
attribute_info['typename'] in self.attr_types_map
), f"{self.op_phi_name} : Attr type error."
# Scalar & IntArray has data_type
temp_type = self.attr_types_map[attribute_info['typename']][1]
if 'Scalar' in temp_type:
if 'data_type' in attribute_info:
temp_type = attribute_info['data_type']
if 'IntArray' in temp_type:
if 'data_type' in attribute_info:
temp_type = attribute_info['data_type']
type_list.append(self.get_phi_dtype_name(temp_type))
return type_list
def parse_attribute_type_list(self):
attr_types_map = {
'IntArray': 'paddle::dialect::IntArrayAttribute',
'Scalar': 'paddle::dialect::ScalarAttribute',
'Scalar(int)': 'paddle::dialect::ScalarAttribute',
'Scalar(int64_t)': 'paddle::dialect::ScalarAttribute',
'Scalar(float)': 'paddle::dialect::ScalarAttribute',
'Scalar(dobule)': 'paddle::dialect::ScalarAttribute',
'Scalar[]': 'ir::ArrayAttribute<paddle::dialect::ScalarAttribute>',
'int': 'ir::Int32_tAttribute',
'int32_t': 'ir::Int32_tAttribute',
'int64_t': 'ir::Int64_tAttribute',
'long': 'ir::LongAttribute',
'size_t': 'ir::Size_tAttribute',
'float': 'ir::FloatAttribute',
'float[]': 'ir::ArrayAttribute<ir::FloatAttribute>',
'double': 'ir::DoubleAttribute',
'bool': 'ir::BoolAttribute',
'bool[]': 'ir::ArrayAttribute<ir::BoolAttribute>',
'str': 'ir::StrAttribute',
'str[]': 'ir::ArrayAttribute<ir::StrAttribute>',
'Place': 'paddle::dialect::PlaceAttribute',
'DataLayout': 'paddle::dialect::DataLayoutAttribute',
'DataType': 'paddle::dialect::DataTypeAttribute',
'int64_t[]': 'ir::ArrayAttribute<ir::Int64_tAttribute>',
'int[]': 'ir::ArrayAttribute<ir::Int32_tAttribute>',
}
type_list = []
for attribute_info in self.op_yaml_item['attrs']:
assert (
attribute_info['typename'] in attr_types_map
attribute_info['typename'] in self.attr_types_map
), f"{self.op_phi_name} : Attr type error."
type_list.append(attr_types_map[attribute_info['typename']])
type_list.append(self.attr_types_map[attribute_info['typename']][0])
return type_list
def parse_attribute_data_type_list(self):
......@@ -443,22 +523,48 @@ class OpInfoParser:
data_type_list.append("")
return data_type_list
def parse_op_phi_name(self):
if self.parse_op_inplace_info() is None:
return [self.op_yaml_item['name']]
else:
if self.op_yaml_item['name'][-1] == "_":
return [self.op_yaml_item['name']]
def parse_attribute_default_value_list(self):
default_value_list = []
for attribute_info in self.op_yaml_item['attrs']:
if 'default_value' in attribute_info:
default_value = attribute_info['default_value']
default_value_list.append(
self.get_phi_dtype_name(default_value)
)
else:
return [
self.op_yaml_item['name'],
self.op_yaml_item['name'] + "_",
]
default_value_list.append(None)
return default_value_list
def parse_op_inplace_info(self):
if 'inplace' in self.op_yaml_item:
return self.op_yaml_item['inplace']
return None
def parse_infer_meta_map(self):
if 'infer_meta' in self.op_yaml_item:
return self.op_yaml_item['infer_meta']
else:
return None
def parse_kernel_map(self):
if 'kernel' in self.op_yaml_item:
return self.op_yaml_item['kernel']
else:
return None
def get_phi_dtype_name(self, name):
name = name.replace('Scalar', 'phi::Scalar')
name = name.replace('IntArray', 'phi::IntArray')
name = name.replace('DataLayout', 'phi::DataLayout')
name = name.replace('DataType', 'phi::DataType')
if name.startswith(
(
"Place",
"CPUPlace",
"GPUPlace",
"GPUPinnedPlace",
"XPUPlace",
"IPUPlace",
"CustomPlace",
)
):
return "phi::" + name
return name
def to_pascal_case(s):
......@@ -472,6 +578,280 @@ def to_pascal_case(s):
# =====================================
# Generate Op Definition Files
# =====================================
def GenBuildInputArgsStr(
op_input_name_list,
op_attribute_name_list,
op_attribute_build_arg_type_list,
op_attribute_default_value_list,
for_func_define=True,
):
'''
Example: ir::Builder &builder, ir::OperationArgument &argument, ir::OpResult x_, phi::DataType dtype=phi::DataType::UNDEFINED, phi::Place place={}
'''
build_args_str = "ir::Builder &builder, ir::OperationArgument &argument"
if len(op_input_name_list) > 0:
for input_name in op_input_name_list:
build_args_str += ", ir::OpResult " + input_name + "_"
for attr_idx in range(len(op_attribute_name_list)):
build_args_str += (
", "
+ op_attribute_build_arg_type_list[attr_idx]
+ " "
+ op_attribute_name_list[attr_idx]
)
if for_func_define:
if op_attribute_default_value_list[attr_idx] is not None:
default_value = op_attribute_default_value_list[attr_idx]
if op_attribute_build_arg_type_list[attr_idx] != "std::string":
if default_value[0] == "'" or default_value[0] == '"':
default_value = default_value[1:]
if default_value[-1] == "'" or default_value[-1] == '"':
default_value = default_value[0:-1]
build_args_str += "=" + default_value
return build_args_str
def GenBuildInputs(op_input_name_list):
BUILD_INPUT_TEMPLATE = """ std::vector<ir::OpResult> argument_inputs = {{{inputs_args}}};
argument.addOperands(argument_inputs.begin(), argument_inputs.end());
"""
build_input_str = ""
if len(op_input_name_list) > 0:
inputs_args_str = "_, ".join(op_input_name_list) + "_"
build_input_str = BUILD_INPUT_TEMPLATE.format(
inputs_args=inputs_args_str
)
return build_input_str
def GenBuildAttributes(op_attribute_name_list, op_attribute_type_list):
INTARRAY_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), phi::IntArray({attr}));
"""
SCALAR_STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), phi::Scalar({attr}));
"""
STR_TEMPLATE = """ ir::Attribute attr_{attr_name} = {op_attribute_type}::get(ir::IrContext::Instance(), {attr});
"""
ARRAY_ATTRIBUTE_TEMPLATE = """ std::vector<ir::Attribute> vec_{attr_name};
for (size_t i = 0; i < static_cast<size_t>({attr_size}); i++) {{
{create_attribute}
vec_{attr_name}.push_back(attr_{attr_name});
}}
ir::Attribute attr_{attr_name} = ir::ArrayAttribute::get(ir::IrContext::Instance(), vec_{attr_name});
"""
attr_str = ""
for idx in range(len(op_attribute_name_list)):
if "ir::ArrayAttribute<" in op_attribute_type_list[idx]:
inner_attribute_type = op_attribute_type_list[idx][19:-1]
if inner_attribute_type == "paddle::dialect::IntArrayAttribute":
attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format(
attr_name=op_attribute_name_list[idx],
attr_size=op_attribute_name_list[idx] + ".size()",
create_attribute=INTARRAY_STR_TEMPLATE.format(
attr_name=op_attribute_name_list[idx],
op_attribute_type=inner_attribute_type,
attr=op_attribute_name_list[idx] + "[i]",
),
)
elif inner_attribute_type == "paddle::dialect::ScalarAttribute":
attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format(
attr_name=op_attribute_name_list[idx],
attr_size=op_attribute_name_list[idx] + ".size()",
create_attribute=SCALAR_STR_TEMPLATE.format(
attr_name=op_attribute_name_list[idx],
op_attribute_type=inner_attribute_type,
attr=op_attribute_name_list[idx] + "[i]",
),
)
else:
attr_str += ARRAY_ATTRIBUTE_TEMPLATE.format(
attr_name=op_attribute_name_list[idx],
attr_size=op_attribute_name_list[idx] + ".size()",
create_attribute=STR_TEMPLATE.format(
attr_name=op_attribute_name_list[idx],
op_attribute_type=inner_attribute_type,
attr=op_attribute_name_list[idx] + "[i]",
),
)
elif (
op_attribute_type_list[idx] == "paddle::dialect::IntArrayAttribute"
):
attr_str += INTARRAY_STR_TEMPLATE.format(
attr_name=op_attribute_name_list[idx],
op_attribute_type=op_attribute_type_list[idx],
attr=op_attribute_name_list[idx],
)
elif op_attribute_type_list[idx] == "paddle::dialect::ScalarAttribute":
attr_str += SCALAR_STR_TEMPLATE.format(
attr_name=op_attribute_name_list[idx],
op_attribute_type=op_attribute_type_list[idx],
attr=op_attribute_name_list[idx],
)
else:
attr_str += STR_TEMPLATE.format(
attr_name=op_attribute_name_list[idx],
op_attribute_type=op_attribute_type_list[idx],
attr=op_attribute_name_list[idx],
)
attr_str += """ argument.addAttribute("{attr_name}", attr_{attr_name});\n""".format(
attr_name=op_attribute_name_list[idx]
)
return attr_str
def GenBuildOutputs(
op_input_name_list,
op_input_type_list,
op_output_name_list,
op_output_type_list,
op_output_size_list,
op_infer_meta_map,
):
build_output_str = ""
CREATE_INPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name};
dense_{name}.set_meta(
phi::DenseTensorMeta(TransToPhiDataType({name}.dtype()),
{name}.dims(),
{name}.data_layout(),
{name}.lod(),
{name}.offset())
);
phi::MetaTensor meta_{name}(&dense_{name});
"""
CREATE_INPUT_VEC_METATENSOR_TEMPLATE = """ std::vector<phi::DenseTensor> vec_dense_{name}({name}.size(), phi::DenseTensor());
std::vector<phi::MetaTensor> vec_meta_{name};
for (size_t i=0; i < static_cast<size_t>({name}.size()); i++) {{
vec_dense_{name}[i].set_meta(
phi::DenseTensorMeta(TransToPhiDataType({name}[i].dyn_cast<paddle::dialect::DenseTensorType>().dtype()),
{name}[i].dyn_cast<paddle::dialect::DenseTensorType>().dims(),
{name}[i].dyn_cast<paddle::dialect::DenseTensorType>().data_layout(),
{name}[i].dyn_cast<paddle::dialect::DenseTensorType>().lod(),
{name}[i].dyn_cast<paddle::dialect::DenseTensorType>().offset())
);
vec_meta_{name}.push_back(phi::MetaTensor(&vec_dense_{name}[i]));
}}
std::vector<const phi::MetaTensor*> meta_{name};
for (size_t i=0; i < static_cast<size_t>(vec_meta_{name}.size()); i++) {{
meta_{name}.push_back(&vec_meta_{name}[i]);
}}
"""
CREATE_OUTPUT_METATENSOR_TEMPLATE = """ phi::DenseTensor dense_{name};
phi::MetaTensor meta_{name}(&dense_{name});
"""
CREATE_OUTPUT_VEC_METATENSOR_TEMPLATE = """ std::vector<phi::DenseTensor> vec_dense_{name}(({output_size}), phi::DenseTensor());
std::vector<phi::MetaTensor> vec_meta_{name};
for (size_t i=0; i < static_cast<size_t>({output_size}); i++) {{
vec_meta_{name}.push_back(phi::MetaTensor(&vec_dense_{name}[i]));
}}
std::vector<phi::MetaTensor*> meta_{name};
for (size_t i=0; i < static_cast<size_t>(vec_meta_{name}.size()); i++) {{
meta_{name}.push_back(&vec_meta_{name}[i]);
}}
"""
# Prepar input type
for idx in range(len(op_input_name_list)):
# is a vector<Tensor>
if 'ir::VectorType' in op_input_type_list[idx]:
build_output_str += " ir::VectorType {name} = {name}_.type().dyn_cast<ir::VectorType>(); (void){name};\n".format(
name=op_input_name_list[idx]
)
# is a Tensor
else:
build_output_str += " paddle::dialect::DenseTensorType {name} = {name}_.type().dyn_cast<paddle::dialect::DenseTensorType>(); (void){name};\n".format(
name=op_input_name_list[idx]
)
# Prepare inputs for infer meta
infer_meta_args = []
for idx in range(len(op_infer_meta_map['param'])):
# is input
if op_infer_meta_map['param'][idx] in op_input_name_list:
if (
"meta_" + op_infer_meta_map['param'][idx]
) not in infer_meta_args:
# is a vector<Tensor>
if (
'ir::VectorType'
in op_input_type_list[
op_input_name_list.index(
op_infer_meta_map['param'][idx]
)
]
):
build_output_str += (
CREATE_INPUT_VEC_METATENSOR_TEMPLATE.format(
name=op_infer_meta_map['param'][idx]
)
)
# is a Tensor
else:
build_output_str += CREATE_INPUT_METATENSOR_TEMPLATE.format(
name=op_infer_meta_map['param'][idx]
)
infer_meta_args.append("meta_" + op_infer_meta_map['param'][idx])
# is attribute
else:
infer_meta_args.append(op_infer_meta_map['param'][idx])
# Prepare outputs for infer meta
for idx in range(len(op_output_name_list)):
# is a vector<Tensor>
if 'ir::VectorType' in op_output_type_list[idx]:
build_output_str += CREATE_OUTPUT_VEC_METATENSOR_TEMPLATE.format(
name=op_output_name_list[idx],
output_size=op_output_size_list[idx],
)
infer_meta_args.append(f"meta_{op_output_name_list[idx]}")
# is a Tensor
else:
build_output_str += CREATE_OUTPUT_METATENSOR_TEMPLATE.format(
name=op_output_name_list[idx]
)
infer_meta_args.append(f"&meta_{op_output_name_list[idx]}")
# Execute infer meta function
CREATE_INFER_META_FUNC_TEMPLATE = """
phi::{func}({args});
"""
build_output_str += CREATE_INFER_META_FUNC_TEMPLATE.format(
func=op_infer_meta_map['func'], args=", ".join(infer_meta_args)
)
# use dense_{name} or vec_dense_{name} to create Outputs type
build_output_str += "\n std::vector<ir::Type> argument_outputs;"
CREATE_OUTPUT_DENSE_TENSOR_TEMPLATE = """
ir::Type {name}_dense_tensor_type = paddle::dialect::DenseTensorType::get(ir::IrContext::Instance(), TransToIrDataType(dense_{name}.dtype()), dense_{name}.dims(), dense_{name}.layout(), dense_{name}.lod(), dense_{name}.offset());
argument_outputs.push_back({name}_dense_tensor_type);
"""
CREATE_OUTPUT_VEC_DENSE_TENSOR_TEMPLATE = """
std::vector<ir::Type> {name}_types;
for (size_t i=0; i < static_cast<size_t>({output_size}); i++) {{
{name}_types.push_back(paddle::dialect::DenseTensorType::get(ir::IrContext::Instance(), TransToIrDataType(vec_dense_{name}[i].dtype()), vec_dense_{name}[i].dims(), vec_dense_{name}[i].layout(), vec_dense_{name}[i].lod(), vec_dense_{name}[i].offset()));
}}
ir::Type {name}_vector_type = ir::VectorType::get(ir::IrContext::Instance(), {name}_types);
argument_outputs.push_back({name}_vector_type);
"""
for idx in range(len(op_output_name_list)):
# is a vector<Tensor>
if 'ir::VectorType' in op_output_type_list[idx]:
build_output_str += CREATE_OUTPUT_VEC_DENSE_TENSOR_TEMPLATE.format(
name=op_output_name_list[idx],
output_size=op_output_size_list[idx],
)
# is a Tensor
else:
build_output_str += CREATE_OUTPUT_DENSE_TENSOR_TEMPLATE.format(
name=op_output_name_list[idx]
)
build_output_str += " argument.addTypes(argument_outputs.begin(), argument_outputs.end());\n"
return build_output_str
def OpGenerator(
op_yaml_files,
op_compat_yaml_file,
......@@ -512,11 +892,16 @@ def OpGenerator(
op_input_no_need_buffer_list = op_info.input_no_need_buffer_list
op_output_name_list = op_info.output_name_list
op_output_type_list = op_info.output_type_list
op_output_size_list = op_info.output_size_list
op_output_optional_list = op_info.output_optional_list
op_output_intermediate_list = op_info.output_intermediate_list
op_attribute_name_list = op_info.attribute_name_list
op_attribute_type_list = op_info.attribute_type_list
op_attribute_data_type_list = op_info.attribute_data_type_list
op_attribute_build_arg_type_list = op_info.attribute_build_arg_type_list
op_attribute_default_value_list = op_info.attribute_default_value_list
op_infer_meta_map = op_info.infer_meta_map
op_kernel_map = op_info.kernel_map
op_interfaces = ["GetOpInfoInterface"]
op_traits = []
......@@ -552,6 +937,53 @@ def OpGenerator(
output_index=idx,
)
# gen build str
build_define_input_args_str = ""
build_declare_input_args_str = ""
build_func_declare_str = ""
if op_infer_meta_map is not None:
build_define_input_args_str = GenBuildInputArgsStr(
op_input_name_list,
op_attribute_name_list,
op_attribute_build_arg_type_list,
op_attribute_default_value_list,
True,
)
build_declare_input_args_str = GenBuildInputArgsStr(
op_input_name_list,
op_attribute_name_list,
op_attribute_build_arg_type_list,
op_attribute_default_value_list,
False,
)
build_inputs_str = GenBuildInputs(op_input_name_list)
build_attributes_str = GenBuildAttributes(
op_attribute_name_list, op_attribute_type_list
)
build_outputs_str = GenBuildOutputs(
op_input_name_list,
op_input_type_list,
op_output_name_list,
op_output_type_list,
op_output_size_list,
op_infer_meta_map,
)
build_func_declare_str = OP_BUILD_TEMPLATE.format(
op_name=op_class_name,
build_args=build_declare_input_args_str,
build_inputs=build_inputs_str,
build_attributes=build_attributes_str,
build_outputs=build_outputs_str,
)
else:
build_func_declare_str = OP_BUILD_TEMPLATE.format(
op_name=op_class_name,
build_args=build_declare_input_args_str,
build_inputs="",
build_attributes="",
build_outputs="",
)
# gen op_declare_str/op_defined_str
if len(op_attribute_name_list) == 0:
op_declare_str = OP_DECLARE_TEMPLATE.format(
......@@ -561,6 +993,7 @@ def OpGenerator(
traits=op_traits_str,
attribute_declare=op_0_attribute_declare_str,
attribute_num=0,
build_args=build_define_input_args_str,
get_inputs_and_outputs=op_get_inputs_outputs_str,
exclusive_interface=exclusive_interface_str,
)
......@@ -575,6 +1008,7 @@ def OpGenerator(
attribute_num=len(op_attribute_name_list)
),
attribute_num=len(op_attribute_name_list),
build_args=build_define_input_args_str,
get_inputs_and_outputs=op_get_inputs_outputs_str,
exclusive_interface=exclusive_interface_str,
)
......@@ -631,11 +1065,27 @@ def OpGenerator(
)
attribute_info_str = ", ".join(attribute_info_list)
# generate runtiem info
infer_meta_func_str = ""
infer_meta_param_str = ""
if op_infer_meta_map is not None:
infer_meta_func_str = op_infer_meta_map['func']
infer_meta_param_str = '", "'.join(op_infer_meta_map['param'])
kernel_func_str = ""
kernel_param_str = ""
if op_kernel_map is not None:
kernel_func_str = '", "'.join(op_kernel_map['func'])
kernel_param_str = '", "'.join(op_kernel_map['param'])
op_info_func_str = OP_INFO_TEMPLATE.format(
op_name=op_class_name,
inputs=inputs_info_str,
attributes=attribute_info_str,
outputs=outputs_info_str,
infer_meta_func=infer_meta_func_str,
infer_meta_param=infer_meta_param_str,
kernel_func=kernel_func_str,
kernel_param=kernel_param_str,
)
# generate op verify function: inputs_type_check_str
......@@ -736,14 +1186,19 @@ def OpGenerator(
)
# generate op verify function
op_verify_str = OP_VERIFY_TEMPLATE.format(
op_name=op_class_name,
inputs_size=len(op_input_type_list),
outputs_size=len(op_output_type_list),
inputs_type_check=inputs_type_check_str,
outputs_type_check=outputs_type_check_str,
attributes_check=attributes_check_str,
)
if "GradOp" in op_class_name or "Grad_Op" in op_class_name:
op_verify_str = GRAD_OP_VERIFY_TEMPLATE.format(
op_name=op_class_name,
)
else:
op_verify_str = OP_VERIFY_TEMPLATE.format(
op_name=op_class_name,
inputs_size=len(op_input_type_list),
outputs_size=len(op_output_type_list),
inputs_type_check=inputs_type_check_str,
outputs_type_check=outputs_type_check_str,
attributes_check=attributes_check_str,
)
op_infer_shape_str = ""
if op_info.infer_shape_func:
......@@ -756,6 +1211,7 @@ def OpGenerator(
ops_declare_list.append(op_declare_str)
ops_defined_list.append(op_defined_str)
ops_defined_list.append(op_info_func_str)
ops_defined_list.append(build_func_declare_str)
ops_defined_list.append(op_verify_str)
ops_defined_list.append(op_infer_shape_str)
......@@ -786,7 +1242,7 @@ def OpGenerator(
namespace=name, input=source_file_str
) # Add namespaces
source_file_str = CC_FILE_TEMPLATE.format(
h_file=op_def_h_file, input=source_file_str
h_file=op_def_h_file[:-4], input=source_file_str
) # Add head
# (5) Generate pd_op.h.tmp, pd_op.cc.tmp
......@@ -817,6 +1273,7 @@ def ParseArguments():
# =====================================
if __name__ == "__main__":
# parse arguments
print("auto gen op")
args = ParseArguments()
op_yaml_files = args.op_yaml_files.split(",")
op_compat_yaml_file = args.op_compat_yaml_file
......
......@@ -35,13 +35,13 @@ ParameterConvertInterface::ParameterToVariable(ir::Parameter *parameter) {
std::make_shared<paddle::framework::Variable>();
phi::DenseTensor *tensor = var->GetMutable<phi::DenseTensor>();
// Init DenseTensor
auto dim = parameter->type().dyn_cast<DenseTensorType>().dim();
auto dim = parameter->type().dyn_cast<DenseTensorType>().dims();
phi::DenseTensorMeta meta(
TransToPhiDataType(
parameter->type().dyn_cast<DenseTensorType>().dtype()),
phi::DDim(dim.data(), dim.size()),
TransToPhiDataLayout(
parameter->type().dyn_cast<DenseTensorType>().data_layout()),
dim,
parameter->type().dyn_cast<DenseTensorType>().data_layout(),
parameter->type().dyn_cast<DenseTensorType>().lod(),
parameter->type().dyn_cast<DenseTensorType>().offset());
tensor->set_meta(meta);
......@@ -67,17 +67,13 @@ std::unique_ptr<ir::Parameter> ParameterConvertInterface::VariableToParameter(
// Get Meta
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Type data_type = TransToIrDataType(tensor->dtype(), ctx);
DenseTensorTypeStorage::Dim dims(tensor->dims().size());
std::copy(tensor->dims().Get(),
tensor->dims().Get() + tensor->dims().size(),
dims.data());
DenseTensorTypeStorage::DataLayout data_layout =
TransToIrDataLayout(tensor->layout());
DenseTensorTypeStorage::LoD lod = tensor->lod();
size_t offset = tensor->meta().offset;
void *data = tensor->data();
ir::Type dense_tensor_type =
DenseTensorType::get(ctx, data_type, dims, data_layout, lod, offset);
ir::Type dense_tensor_type = DenseTensorType::get(ctx,
data_type,
tensor->dims(),
tensor->layout(),
tensor->lod(),
tensor->meta().offset);
return std::make_unique<ir::Parameter>(
data,
tensor->numel() * phi::SizeOf(tensor->dtype()),
......@@ -116,8 +112,7 @@ void PaddleDialect::PrintType(ir::Type type, std::ostream &os) {
DenseTensorType tensor_type = type.dyn_cast<DenseTensorType>();
os << "tensor<";
auto &dims = tensor_type.dim();
for (auto d : dims) {
for (auto d : phi::vectorize(tensor_type.dims())) {
os << d;
os << "x";
}
......
......@@ -19,25 +19,22 @@
using OpInfoTuple = std::tuple<std::vector<paddle::dialect::OpInputInfo>,
std::vector<paddle::dialect::OpAttributeInfo>,
std::vector<paddle::dialect::OpOutputInfo>>;
std::vector<paddle::dialect::OpOutputInfo>,
paddle::dialect::OpRunTimeInfo>;
namespace paddle {
namespace dialect {
class GetOpInfoInterface : public ir::OpInterfaceBase<GetOpInfoInterface> {
public:
struct Concept {
explicit Concept(OpInfoTuple (*get_op_info)(ir::Operation *))
explicit Concept(OpInfoTuple (*get_op_info)())
: get_op_info_(get_op_info) {}
OpInfoTuple (*get_op_info_)(ir::Operation *);
OpInfoTuple (*get_op_info_)();
};
template <class ConcreteOp>
struct Model : public Concept {
static OpInfoTuple GetOpInfo(ir::Operation *op) {
ConcreteOp concret_op = op->dyn_cast<ConcreteOp>();
if (concret_op == nullptr) throw("concret_op is nullptr");
return concret_op.GetOpInfo();
}
static OpInfoTuple GetOpInfo() { return ConcreteOp::GetOpInfo(); }
Model() : Concept(GetOpInfo) {}
};
......@@ -45,7 +42,7 @@ class GetOpInfoInterface : public ir::OpInterfaceBase<GetOpInfoInterface> {
GetOpInfoInterface(ir::Operation *op, Concept *impl)
: ir::OpInterfaceBase<GetOpInfoInterface>(op), impl_(impl) {}
OpInfoTuple GetOpInfo() { return impl_->get_op_info_(operation()); }
OpInfoTuple GetOpInfo() { return impl_->get_op_info_(); }
private:
Concept *impl_;
......
......@@ -11,17 +11,6 @@
- {typename: Tensor, name: out, optional: false, intermediate: false}
no_need_buffer: null
data_transform: null
infer_meta:
func: null
param: null
kernel:
func: null
param: null
backend: null
layout: null
data_type: null
dispatch: null
force_backend: null
inplace: null
backward: null
- name: fetch
......@@ -37,16 +26,5 @@
- {typename: 'Tensor[]', name: out, optional: false, intermediate: false}
no_need_buffer: null
data_transform: null
infer_meta:
func: null
param: null
kernel:
func: null
param: null
backend: null
layout: null
data_type: null
dispatch: null
force_backend: null
inplace: null
backward: null
......@@ -18,20 +18,13 @@ namespace paddle {
namespace dialect {
const ir::Type& DenseTensorType::dtype() const { return storage()->dtype_; }
const paddle::dialect::DenseTensorTypeStorage::Dim& DenseTensorType::dim()
const {
return storage()->dims_;
}
const phi::DDim& DenseTensorType::dims() const { return storage()->dims_; }
const paddle::dialect::DenseTensorTypeStorage::DataLayout&
DenseTensorType::data_layout() const {
const phi::DataLayout& DenseTensorType::data_layout() const {
return storage()->layout_;
}
const paddle::dialect::DenseTensorTypeStorage::LoD& DenseTensorType::lod()
const {
return storage()->lod_;
}
const phi::LoD& DenseTensorType::lod() const { return storage()->lod_; }
const size_t& DenseTensorType::offset() const { return storage()->offset_; }
......
......@@ -30,12 +30,11 @@ class DenseTensorType : public ir::Type {
const ir::Type &dtype() const;
const paddle::dialect::DenseTensorTypeStorage::Dim &dim() const;
const phi::DDim &dims() const;
const paddle::dialect::DenseTensorTypeStorage::DataLayout &data_layout()
const;
const phi::DataLayout &data_layout() const;
const paddle::dialect::DenseTensorTypeStorage::LoD &lod() const;
const phi::LoD &lod() const;
const size_t &offset() const;
};
......
......@@ -18,6 +18,7 @@
#include "paddle/ir/core/type.h"
#include "paddle/ir/core/utils.h"
#include "paddle/phi/core/tensor_meta.h"
namespace std {
///
......@@ -46,46 +47,20 @@ namespace dialect {
/// (3)define HashValue method, (4)overload operator==.
///
struct DenseTensorTypeStorage : public ir::TypeStorage {
///
/// \brief It is consistent with the DataLayout defined by Phi operator
/// library. See the file for details: paddle/phi/common/layout.h.
///
enum class DataLayout : unsigned int {
UNDEFINED = 0,
NHWC,
NCHW,
NCDHW,
NDHWC,
ONEDNN,
SPARSE_COO,
SPARSE_CSR,
PSTRING_UNION,
NUM_DATA_LAYOUTS,
// See Note [ Why we need ALL in basic kernel key member? ]
ALL_LAYOUT = UNDEFINED,
// Note: Unify phi DataLayout and fluid::framework::DataLayout,
// for compatible with fluid DataLayout, here need prefix `k`
kNHWC = NHWC,
kNCHW = NCHW,
kMKLDNN = ONEDNN, // all layouts supported by ONEDNN internally
kNDHWC = NDHWC,
kNCDHW = NCDHW,
};
using Dim = std::vector<int64_t>;
using DataLayout = phi::DataLayout;
using Dim = phi::DDim;
using LoD = std::vector<std::vector<size_t>>;
///
/// \brief Declare ParamKey according to parameter type.
///
using ParamKey = std::tuple<ir::Type, Dim, DataLayout, LoD, size_t>;
DenseTensorTypeStorage(
ir::Type dtype, Dim dims, DataLayout layout, LoD lod, size_t offset)
using ParamKey =
std::tuple<ir::Type, phi::DDim, phi::DataLayout, phi::LoD, size_t>;
DenseTensorTypeStorage(ir::Type dtype,
phi::DDim dims,
phi::DataLayout layout,
phi::LoD lod,
size_t offset)
: dtype_(dtype),
dims_(dims),
layout_(layout),
......@@ -114,16 +89,16 @@ struct DenseTensorTypeStorage : public ir::TypeStorage {
ir::hash_combine(hash_value, std::hash<ir::Type>()(std::get<0>(key)));
// hash dims
hash_value =
ir::hash_combine(hash_value, std::hash<Dim>()(std::get<1>(key)));
ir::hash_combine(hash_value, std::hash<phi::DDim>()(std::get<1>(key)));
// hash layout
hash_value = ir::hash_combine(
hash_value,
std::hash<std::underlying_type<DataLayout>::type>()(
static_cast<std::underlying_type<DataLayout>::type>(
std::hash<std::underlying_type<phi::DataLayout>::type>()(
static_cast<std::underlying_type<phi::DataLayout>::type>(
std::get<2>(key))));
// hash lod
hash_value =
ir::hash_combine(hash_value, std::hash<LoD>()(std::get<3>(key)));
ir::hash_combine(hash_value, std::hash<phi::LoD>()(std::get<3>(key)));
// hash offset
hash_value =
ir::hash_combine(hash_value, std::hash<size_t>()(std::get<4>(key)));
......@@ -146,9 +121,9 @@ struct DenseTensorTypeStorage : public ir::TypeStorage {
/// layout, lod, offset.
///
ir::Type dtype_;
Dim dims_;
DataLayout layout_;
LoD lod_;
phi::DDim dims_;
phi::DataLayout layout_;
phi::LoD lod_;
size_t offset_;
};
......
......@@ -70,67 +70,76 @@ inline ir::Type TransToIrDataType(phi::DataType dtype,
}
}
inline phi::DataLayout TransToPhiDataLayout(
DenseTensorTypeStorage::DataLayout data_layout) {
switch (data_layout) {
case DenseTensorTypeStorage::DataLayout::NHWC:
return phi::DataLayout::NHWC;
case DenseTensorTypeStorage::DataLayout::NCHW:
return phi::DataLayout::NCHW;
case DenseTensorTypeStorage::DataLayout::NCDHW:
return phi::DataLayout::NCDHW;
case DenseTensorTypeStorage::DataLayout::NDHWC:
return phi::DataLayout::NDHWC;
case DenseTensorTypeStorage::DataLayout::ONEDNN:
return phi::DataLayout::ONEDNN;
case DenseTensorTypeStorage::DataLayout::SPARSE_COO:
return phi::DataLayout::SPARSE_COO;
case DenseTensorTypeStorage::DataLayout::SPARSE_CSR:
return phi::DataLayout::SPARSE_CSR;
case DenseTensorTypeStorage::DataLayout::PSTRING_UNION:
return phi::DataLayout::PSTRING_UNION;
case DenseTensorTypeStorage::DataLayout::NUM_DATA_LAYOUTS:
return phi::DataLayout::NUM_DATA_LAYOUTS;
case DenseTensorTypeStorage::DataLayout::ALL_LAYOUT:
return phi::DataLayout::ALL_LAYOUT;
default:
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported ir data layout `%s` when casting it into "
"phi data type.",
static_cast<int>(data_layout)));
}
}
// inline phi::DataLayout TransToPhiDataLayout(
// DenseTensorTypeStorage::DataLayout data_layout) {
// switch (data_layout) {
// case DenseTensorTypeStorage::DataLayout::NHWC:
// return phi::DataLayout::NHWC;
// case DenseTensorTypeStorage::DataLayout::NCHW:
// return phi::DataLayout::NCHW;
// case DenseTensorTypeStorage::DataLayout::NCDHW:
// return phi::DataLayout::NCDHW;
// case DenseTensorTypeStorage::DataLayout::NDHWC:
// return phi::DataLayout::NDHWC;
// case DenseTensorTypeStorage::DataLayout::ONEDNN:
// return phi::DataLayout::ONEDNN;
// case DenseTensorTypeStorage::DataLayout::SPARSE_COO:
// return phi::DataLayout::SPARSE_COO;
// case DenseTensorTypeStorage::DataLayout::SPARSE_CSR:
// return phi::DataLayout::SPARSE_CSR;
// case DenseTensorTypeStorage::DataLayout::PSTRING_UNION:
// return phi::DataLayout::PSTRING_UNION;
// case DenseTensorTypeStorage::DataLayout::NUM_DATA_LAYOUTS:
// return phi::DataLayout::NUM_DATA_LAYOUTS;
// case DenseTensorTypeStorage::DataLayout::ALL_LAYOUT:
// return phi::DataLayout::ALL_LAYOUT;
// default:
// PADDLE_THROW(phi::errors::Unimplemented(
// "Unsupported ir data layout `%s` when casting it into "
// "phi data type.",
// static_cast<int>(data_layout)));
// }
// }
inline DenseTensorTypeStorage::DataLayout TransToIrDataLayout(
phi::DataLayout data_layout) {
switch (data_layout) {
case phi::DataLayout::NHWC:
return DenseTensorTypeStorage::DataLayout::NHWC;
case phi::DataLayout::NCHW:
return DenseTensorTypeStorage::DataLayout::NCHW;
case phi::DataLayout::NCDHW:
return DenseTensorTypeStorage::DataLayout::NCDHW;
case phi::DataLayout::NDHWC:
return DenseTensorTypeStorage::DataLayout::NDHWC;
case phi::DataLayout::ONEDNN:
return DenseTensorTypeStorage::DataLayout::ONEDNN;
case phi::DataLayout::SPARSE_COO:
return DenseTensorTypeStorage::DataLayout::SPARSE_COO;
case phi::DataLayout::SPARSE_CSR:
return DenseTensorTypeStorage::DataLayout::SPARSE_CSR;
case phi::DataLayout::PSTRING_UNION:
return DenseTensorTypeStorage::DataLayout::PSTRING_UNION;
case phi::DataLayout::NUM_DATA_LAYOUTS:
return DenseTensorTypeStorage::DataLayout::NUM_DATA_LAYOUTS;
case phi::DataLayout::ALL_LAYOUT:
return DenseTensorTypeStorage::DataLayout::ALL_LAYOUT;
default:
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported phi data layout `%s` when casting it into "
"ir data type.",
static_cast<int>(data_layout)));
}
}
// inline DenseTensorTypeStorage::DataLayout TransToIrDataLayout(
// phi::DataLayout data_layout) {
// switch (data_layout) {
// case phi::DataLayout::NHWC:
// return DenseTensorTypeStorage::DataLayout::NHWC;
// case phi::DataLayout::NCHW:
// return DenseTensorTypeStorage::DataLayout::NCHW;
// case phi::DataLayout::NCDHW:
// return DenseTensorTypeStorage::DataLayout::NCDHW;
// case phi::DataLayout::NDHWC:
// return DenseTensorTypeStorage::DataLayout::NDHWC;
// case phi::DataLayout::ONEDNN:
// return DenseTensorTypeStorage::DataLayout::ONEDNN;
// case phi::DataLayout::SPARSE_COO:
// return DenseTensorTypeStorage::DataLayout::SPARSE_COO;
// case phi::DataLayout::SPARSE_CSR:
// return DenseTensorTypeStorage::DataLayout::SPARSE_CSR;
// case phi::DataLayout::PSTRING_UNION:
// return DenseTensorTypeStorage::DataLayout::PSTRING_UNION;
// case phi::DataLayout::NUM_DATA_LAYOUTS:
// return DenseTensorTypeStorage::DataLayout::NUM_DATA_LAYOUTS;
// case phi::DataLayout::ALL_LAYOUT:
// return DenseTensorTypeStorage::DataLayout::ALL_LAYOUT;
// default:
// PADDLE_THROW(phi::errors::Unimplemented(
// "Unsupported phi data layout `%s` when casting it into "
// "ir data type.",
// static_cast<int>(data_layout)));
// }
// }
// inline phi::DenseTensorMeta TransToDenseTensorMeta(
// paddle::dialect::DenseTensorType type) {
// return phi::DenseTensorMeta(TransToPhiDataType(type.dtype()),
// type.dim(),
// type.data_layout(),
// type.lod(),
// type.offset());
// }
struct OpInputInfo {
std::string name;
......@@ -172,5 +181,20 @@ struct OpAttributeInfo {
: name(name), type_name(type_name), data_type(data_type) {}
};
struct OpRunTimeInfo {
std::string infer_meta_func;
std::vector<std::string> infer_meta_param;
std::vector<std::string> kernel_func;
std::vector<std::string> kernel_param;
OpRunTimeInfo(std::string infer_meta_func,
std::vector<std::string> infer_meta_param,
std::vector<std::string> kernel_func,
std::vector<std::string> kernel_param)
: infer_meta_func(infer_meta_func),
infer_meta_param(infer_meta_param),
kernel_func(kernel_func),
kernel_param(kernel_param) {}
};
} // namespace dialect
} // namespace paddle
......@@ -50,7 +50,7 @@ TypeTranslator::TypeTranslator() {
ir::Type dtype =
this->operator[](var_desc.GetDataType())(ctx, var_desc);
DenseTensorTypeStorage::Dim dim = var_desc.GetShape();
DenseTensorTypeStorage::Dim dim = phi::make_ddim(var_desc.GetShape());
DenseTensorTypeStorage::DataLayout layout =
DenseTensorTypeStorage::DataLayout::UNDEFINED;
DenseTensorTypeStorage::LoD lod = {};
......
......@@ -25,7 +25,7 @@ namespace ir {
#define DECLARE_BASE_TYPE_ATTRIBUTE_STORAGE(concrete_storage, base_type) \
struct concrete_storage : public ir::AttributeStorage { \
using ParamKey = bool; \
using ParamKey = base_type; \
\
explicit concrete_storage(const ParamKey &key) { data_ = key; } \
\
......
......@@ -221,7 +221,7 @@
- backward_op : broadcast_tensors_grad
forward : broadcast_tensors (Tensor[] input) -> Tensor[](out)
args : (Tensor[] input, Tensor[] out_grad)
output : Tensor[](input_grad)
output : Tensor[](input_grad){input.size()}
infer_meta :
func : UnchangedMultiInferMeta
param : [input]
......
......@@ -235,7 +235,7 @@
- backward_op : einsum_grad
forward : einsum (Tensor[] x, str equation) -> Tensor(out), Tensor[](inner_cache), Tensor[](x_shape)
args : (Tensor[] x_shape, Tensor[] inner_cache, Tensor out_grad, str equation)
output : Tensor[](x_grad){x.size()}
output : Tensor[](x_grad){x_shape.size()}
infer_meta :
func : UnchangedMultiInferMeta
param : [x_shape]
......
......@@ -107,10 +107,9 @@ TEST(program_test, program) {
a_interface->ParameterToVariable(program.GetParameter("a"));
const phi::DenseTensor &a_tensor = a_var->Get<phi::DenseTensor>();
EXPECT_EQ(a_tensor.numel(), 4);
EXPECT_EQ(a_tensor.dims(), phi::DDim(dims.data(), dims.size()));
EXPECT_EQ(a_tensor.dims(), dims);
EXPECT_EQ(a_tensor.dtype(), paddle::dialect::TransToPhiDataType(fp32_dtype));
EXPECT_EQ(a_tensor.layout(),
paddle::dialect::TransToPhiDataLayout(data_layout));
EXPECT_EQ(a_tensor.layout(), data_layout);
EXPECT_EQ(a_tensor.lod(), lod);
EXPECT_EQ(a_tensor.offset(), offset);
for (int64_t i = 0; i < a_tensor.numel(); i++) {
......@@ -137,10 +136,9 @@ TEST(program_test, program) {
b_interface->ParameterToVariable(program.GetParameter("b"));
const phi::DenseTensor &b_tensor = b_var->Get<phi::DenseTensor>();
EXPECT_EQ(b_tensor.numel(), 4);
EXPECT_EQ(b_tensor.dims(), phi::DDim(dims.data(), dims.size()));
EXPECT_EQ(b_tensor.dims(), dims);
EXPECT_EQ(b_tensor.dtype(), paddle::dialect::TransToPhiDataType(fp32_dtype));
EXPECT_EQ(b_tensor.layout(),
paddle::dialect::TransToPhiDataLayout(data_layout));
EXPECT_EQ(b_tensor.layout(), data_layout);
EXPECT_EQ(b_tensor.lod(), lod);
EXPECT_EQ(b_tensor.offset(), offset);
for (int64_t i = 0; i < b_tensor.numel(); i++) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册