Unverified commit d6ebbd49 authored by WangZhen, committed by GitHub

[NewIR]Gen new ir api for paddle::dialect::xxx (#56241)

Parent 0a0f3ef9
......@@ -46,12 +46,36 @@ add_custom_command(
${op_compat_yaml_file}
VERBATIM)
set(api_gen_file
    ${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/op_generator/api_gen.py)
set(api_header_file ${PD_DIALECT_BINARY_DIR}/pd_api.h)
set(api_source_file ${PD_DIALECT_BINARY_DIR}/pd_api.cc)
set(api_header_file_tmp ${api_header_file}.tmp)
set(api_source_file_tmp ${api_source_file}.tmp)

add_custom_command(
  OUTPUT ${api_header_file} ${api_source_file}
  COMMAND
    ${PYTHON_EXECUTABLE} ${api_gen_file} --op_yaml_files ${op_yaml_files}
    --op_compat_yaml_file ${op_compat_yaml_file} --namespaces ${op_namespace}
    --api_def_h_file ${api_header_file_tmp} --api_def_cc_file
    ${api_source_file_tmp}
  COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_header_file_tmp}
          ${api_header_file}
  COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_source_file_tmp}
          ${api_source_file}
  COMMENT "copy_if_different ${api_header_file} ${api_source_file}"
  DEPENDS ${api_gen_file} ${op_forward_yaml_file1} ${op_forward_yaml_file2}
          ${op_backward_yaml_file1} ${op_backward_yaml_file2}
          ${op_compat_yaml_file}
  VERBATIM)
# All source files of pd_dialect, except the generated op and API source files, which are generated into the build directory.
file(GLOB PD_DIALECT_SRCS "*.cc")
cc_library(
pd_dialect
SRCS ${PD_DIALECT_SRCS} ${op_source_file}
SRCS ${PD_DIALECT_SRCS} ${op_source_file} ${api_source_file}
DEPS phi
phi_utils
pd_interface
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import yaml
from op_gen import OpCompatParser, OpInfoParser, to_pascal_case
H_FILE_TEMPLATE = """
#pragma once
#include <vector>
#include "paddle/ir/core/value.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
{body}
"""
CPP_FILE_TEMPLATE = """
#include "paddle/fluid/ir/dialect/pd_api.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_op.h"
{body}
"""
NAMESPACE_TEMPLATE = """
namespace {namespace} {{
{body}
}} // namespace {namespace}
"""
API_DECLARE_TEMPLATE = """
{ret_type} {api_name}({args});
"""
API_IMPL_TEMPLATE = """
{ret_type} {api_name}({args}){{
{in_combine}
{compute_op}
{out_slice}
{out_combine}
{return_result}
}}
"""
COMBINE_OP_TEMPLATE = """auto {op_name} = APIBuilder::Instance().GetBuilder()->Build<ir::CombineOp>({in_name});"""
COMPUTE_OP_TEMPLATE = """paddle::dialect::{op_class_name} {op_inst_name} = APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::{op_class_name}>({args});"""
API_LIST = ['add_n', 'mean', 'sum', 'divide', 'full', 'tanh_grad', 'mean_grad']
OP_RESULT = 'ir::OpResult'
VECTOR_TYPE = 'ir::VectorType'


def get_op_class_name(op_name):
    return to_pascal_case(op_name) + 'Op'


class CodeGen:
    def __init__(self) -> None:
        self._type_map = {
            'paddle::dialect::DenseTensorType': 'ir::OpResult',
            'ir::VectorType<paddle::dialect::DenseTensorType>': 'std::vector<ir::OpResult>',
        }

    def _parse_yaml(self, op_yaml_files, op_compat_yaml_file):
        op_compat_parser = OpCompatParser(op_compat_yaml_file)

        op_yaml_items = []
        for yaml_file in op_yaml_files:
            with open(yaml_file, "r") as f:
                ops = yaml.safe_load(f)
                op_yaml_items = op_yaml_items + ops
        op_info_items = []
        for op in op_yaml_items:
            op_info_items.append(
                OpInfoParser(op, op_compat_parser.get_compat(op['name']))
            )
        return op_info_items

    # =====================================
    # Gen declare functions
    # =====================================
    def _gen_api_inputs(self, op_info):
        name_list = op_info.input_name_list
        type_list = op_info.input_type_list
        assert len(name_list) == len(type_list)
        ret = []
        for name, type in zip(name_list, type_list):
            ret.append(f'{self._type_map[type]} {name}')
        return ', '.join(ret)

    def _gen_api_attrs(self, op_info, with_default):
        name_list = op_info.attribute_name_list
        type_list = op_info.attribute_build_arg_type_list
        default_value_list = op_info.attribute_default_value_list
        assert len(name_list) == len(type_list) == len(default_value_list)
        ret = []
        for name, type, default_value in zip(
            name_list, type_list, default_value_list
        ):
            if with_default and default_value is not None:
                ret.append(
                    '{type} {name} = {default_value}'.format(
                        type=type, name=name, default_value=default_value
                    )
                )
            else:
                ret.append(f'{type} {name}')
        return ', '.join(ret)

    def _gen_api_args(self, op_info, with_default_attr):
        inputs = self._gen_api_inputs(op_info)
        attrs = self._gen_api_attrs(op_info, with_default_attr)
        return (inputs + ', ' + attrs).strip(', ')

    def _gen_one_declare(self, op_info, op_name):
        return API_DECLARE_TEMPLATE.format(
            ret_type=OP_RESULT,
            api_name=op_name,
            args=self._gen_api_args(op_info, True),
        )

    def _gen_h_file(self, op_info_items, namespaces, h_file_path):
        declare_str = ''
        for op_info in op_info_items:
            for op_name in op_info.op_phi_name:
                if op_name not in API_LIST:
                    continue
                declare_str += self._gen_one_declare(op_info, op_name)

        body = declare_str
        for namespace in reversed(namespaces):
            body = NAMESPACE_TEMPLATE.format(namespace=namespace, body=body)
        with open(h_file_path, 'w') as f:
            f.write(H_FILE_TEMPLATE.format(body=body))

    # =====================================
    # Gen impl functions
    # =====================================
    def _gen_in_combine(self, op_info):
        name_list = op_info.input_name_list
        type_list = op_info.input_type_list
        assert len(name_list) == len(type_list)
        combine_op = ''
        combine_op_list = []
        for name, type in zip(name_list, type_list):
            if VECTOR_TYPE in type:
                op_name = f'{name}_combine_op'
                combine_op += COMBINE_OP_TEMPLATE.format(
                    op_name=op_name, in_name=name
                )
                combine_op_list.append(op_name)
            else:
                combine_op_list.append(None)
        return combine_op, combine_op_list

    def _gen_compute_op_args(self, op_info, in_combine_op_list):
        input_name_list = op_info.input_name_list
        attribute_name_list = op_info.attribute_name_list
        assert len(input_name_list) == len(in_combine_op_list)
        ret = []
        for input_name, combine_op in zip(input_name_list, in_combine_op_list):
            if combine_op is None:
                ret.append(input_name)
            else:
                ret.append(f'{combine_op}.out()')
        ret += list(attribute_name_list)
        return ', '.join(ret)

    def _gen_compute_op(self, op_info, op_name, in_combine_op_list):
        op_class_name = to_pascal_case(op_name) + 'Op'
        op_inst_name = op_name + '_op'
        return (
            COMPUTE_OP_TEMPLATE.format(
                op_class_name=op_class_name,
                op_inst_name=op_inst_name,
                args=self._gen_compute_op_args(op_info, in_combine_op_list),
            ),
            op_inst_name,
        )

    def _gen_out_slice(self):
        return ''

    def _gen_out_combine(self):
        return ''

    def _gen_return_result(self, op_info, op_inst_name):
        output_name_list = op_info.output_name_list
        assert len(output_name_list) == 1
        return f'return {op_inst_name}.result(0);'

    def _gen_one_impl(self, op_info, op_name):
        in_combine, in_combine_op_list = self._gen_in_combine(op_info)
        compute_op, op_inst_name = self._gen_compute_op(
            op_info, op_name, in_combine_op_list
        )

        return API_IMPL_TEMPLATE.format(
            ret_type=OP_RESULT,
            api_name=op_name,
            args=self._gen_api_args(op_info, False),
            in_combine=in_combine,
            compute_op=compute_op,
            out_slice=self._gen_out_slice(),
            out_combine=self._gen_out_combine(),
            return_result=self._gen_return_result(op_info, op_inst_name),
        ).replace(' \n', '')

    def _gen_cpp_file(self, op_info_items, namespaces, cpp_file_path):
        impl_str = ''
        for op_info in op_info_items:
            for op_name in op_info.op_phi_name:
                if op_name not in API_LIST:
                    continue
                impl_str += self._gen_one_impl(op_info, op_name)

        body = impl_str
        for namespace in reversed(namespaces):
            body = NAMESPACE_TEMPLATE.format(namespace=namespace, body=body)
        with open(cpp_file_path, 'w') as f:
            f.write(CPP_FILE_TEMPLATE.format(body=body))

    def gen_h_and_cpp_file(
        self,
        op_yaml_files,
        op_compat_yaml_file,
        namespaces,
        h_file_path,
        cpp_file_path,
    ):
        if os.path.exists(h_file_path):
            os.remove(h_file_path)
        if os.path.exists(cpp_file_path):
            os.remove(cpp_file_path)

        op_info_items = self._parse_yaml(op_yaml_files, op_compat_yaml_file)

        self._gen_h_file(op_info_items, namespaces, h_file_path)
        self._gen_cpp_file(op_info_items, namespaces, cpp_file_path)


def ParseArguments():
    parser = argparse.ArgumentParser(
        description='Generate Dialect API Files By Yaml'
    )
    parser.add_argument('--op_yaml_files', type=str)
    parser.add_argument('--op_compat_yaml_file', type=str)
    parser.add_argument('--namespaces', type=str)
    parser.add_argument('--api_def_h_file', type=str)
    parser.add_argument('--api_def_cc_file', type=str)
    return parser.parse_args()


if __name__ == '__main__':
    args = ParseArguments()

    op_yaml_files = args.op_yaml_files.split(",")
    op_compat_yaml_file = args.op_compat_yaml_file
    # Default to no namespaces when --namespaces is not supplied,
    # so the variable is always defined before it is used below.
    namespaces = []
    if args.namespaces is not None:
        namespaces = args.namespaces.split(",")
    api_def_h_file = args.api_def_h_file
    api_def_cc_file = args.api_def_cc_file

    code_gen = CodeGen()
    code_gen.gen_h_and_cpp_file(
        op_yaml_files,
        op_compat_yaml_file,
        namespaces,
        api_def_h_file,
        api_def_cc_file,
    )
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/ir/dialect/pd_api.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_op.h"
namespace paddle {
namespace dialect {

ir::OpResult add_n(std::vector<ir::OpResult> x) {
  auto combine_op =
      APIBuilder::Instance().GetBuilder()->Build<ir::CombineOp>(x);
  paddle::dialect::AddNOp add_n_op =
      APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::AddNOp>(
          combine_op.out());
  return add_n_op.out();
}

ir::OpResult mean(ir::OpResult x,
                  const std::vector<int64_t>& axis,
                  bool keepdim) {
  paddle::dialect::MeanOp mean_op =
      APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::MeanOp>(
          x, axis, keepdim);
  return mean_op.out();
}

ir::OpResult sum(ir::OpResult x,
                 const std::vector<int64_t>& axis,
                 phi::DataType dtype,
                 bool keepdim) {
  auto sum_op =
      APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::SumOp>(
          x, axis, dtype, keepdim);
  return sum_op.out();
}

ir::OpResult divide(ir::OpResult x, ir::OpResult y) {
  auto divide_op =
      APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::DivideOp>(
          x, y);
  return divide_op.out();
}

ir::OpResult full(const std::vector<int64_t>& shape,
                  float value,
                  phi::DataType dtype,
                  const phi::Place& place) {
  auto full_op =
      APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::FullOp>(
          shape, value, dtype, place);
  return full_op.out();
}

ir::OpResult tanh_grad(ir::OpResult out, ir::OpResult grad_out) {
  paddle::dialect::TanhGradOp tanh_grad_op =
      APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::TanhGradOp>(
          out, grad_out);
  return tanh_grad_op.result(0);
}

ir::OpResult mean_grad(ir::OpResult x,
                       ir::OpResult out_grad,
                       const std::vector<int64_t>& axis,
                       bool keepdim,
                       bool reduce_all) {
  paddle::dialect::MeanGradOp mean_grad_op =
      APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::MeanGradOp>(
          x, out_grad, axis, keepdim, reduce_all);
  return mean_grad_op.result(0);
}

}  // namespace dialect
}  // namespace paddle
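
For illustration, a minimal usage sketch of the generated API, assuming APIBuilder::Instance() has already been pointed at a live ir::Program elsewhere; the BuildMeanOfRatio helper below is hypothetical and not part of this commit.

// Usage sketch only: assumes APIBuilder::Instance() already targets a valid
// ir::Program; only the generated free functions from pd_api.h are used.
#include "paddle/fluid/ir/dialect/pd_api.h"

void BuildMeanOfRatio() {
  // full() falls back to the defaults declared in pd_api.h
  // (dtype = phi::DataType::FLOAT32, place = phi::CPUPlace()).
  ir::OpResult x = paddle::dialect::full({2, 3}, 1.0f);
  ir::OpResult y = paddle::dialect::full({2, 3}, 2.0f);
  // Each call builds the corresponding paddle::dialect op through the
  // shared builder and returns that op's OpResult.
  ir::OpResult ratio = paddle::dialect::divide(x, y);
  ir::OpResult m = paddle::dialect::mean(ratio);  // axis/keepdim use defaults
  (void)m;
}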
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/ir/core/value.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
namespace paddle {
namespace dialect {

ir::OpResult add_n(std::vector<ir::OpResult> x);

ir::OpResult mean(ir::OpResult x,
                  const std::vector<int64_t>& axis = {},
                  bool keepdim = false);

ir::OpResult sum(ir::OpResult x,
                 const std::vector<int64_t>& axis = {},
                 phi::DataType dtype = phi::DataType::UNDEFINED,
                 bool keepdim = false);

ir::OpResult divide(ir::OpResult x, ir::OpResult y);

ir::OpResult full(const std::vector<int64_t>& shape,
                  float value,
                  phi::DataType dtype = phi::DataType::FLOAT32,
                  const phi::Place& place = phi::CPUPlace());

ir::OpResult tanh_grad(ir::OpResult out, ir::OpResult grad_out);

ir::OpResult mean_grad(ir::OpResult x,
                       ir::OpResult out_grad,
                       const std::vector<int64_t>& axis = {},
                       bool keepdim = false,
                       bool reduce_all = false);

}  // namespace dialect
}  // namespace paddle
......@@ -4,3 +4,4 @@ cc_library(
primitive_vjp_experimental
SRCS ${VJP_SRCS}
DEPS primitive_backend_static_experimental)
add_dependencies(primitive_vjp_experimental pd_dialect)