未验证 提交 59b2ad39 编写于 作者: W WangZhen 提交者: GitHub

[NewIR]Gen ops_api.cc for static mode (#56653)

上级 1692af99
......@@ -98,3 +98,4 @@ paddle/phi/kernels/fusion/cutlass/conv2d/generated/*
python/paddle/fluid/incubate/fleet/parameter_server/pslib/ps_pb2.py
paddle/fluid/ir_adaptor/translator/op_compat_info.cc
paddle/fluid/pybind/static_op_function.*
paddle/fluid/pybind/ops_api.cc
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from api_gen import NAMESPACE_TEMPLATE, PD_MANUAL_OP_LIST, CodeGen
CPP_FILE_TEMPLATE = """
#include <pybind11/pybind11.h>
#include "paddle/fluid/pybind/static_op_function.h"
#include "paddle/phi/core/enforce.h"
{body}
"""
NAMESPACE_INNER_TEMPLATE = """
{function_impl}
static PyMethodDef OpsAPI[] = {{
{ops_api}
{{nullptr, nullptr, 0, nullptr}}
}};
void BindOpsAPI(pybind11::module *module) {{
if (PyModule_AddFunctions(module->ptr(), OpsAPI) < 0) {{
PADDLE_THROW(phi::errors::Fatal("Add C++ api to core.ops failed!"));
}}
}}
"""
FUNCTION_IMPL_TEMPLATE = """
static PyObject *{name}(PyObject *self, PyObject *args, PyObject *kwargs) {{
return static_api_{name}(self, args, kwargs);
}}"""
OPS_API_TEMPLATE = """
{{"{name}", (PyCFunction)(void (*)(void)){name}, METH_VARARGS | METH_KEYWORDS, "C++ interface function for {name}."}},"""
class OpsAPIGen(CodeGen):
def __init__(self) -> None:
super().__init__()
def _gen_one_function_impl(self, name):
return FUNCTION_IMPL_TEMPLATE.format(name=name)
def _gen_one_ops_api(self, name):
return OPS_API_TEMPLATE.format(name=name)
def gen_cpp_file(
self, op_yaml_files, op_compat_yaml_file, namespaces, cpp_file_path
):
if os.path.exists(cpp_file_path):
os.remove(cpp_file_path)
op_info_items = self._parse_yaml(op_yaml_files, op_compat_yaml_file)
function_impl_str = ''
ops_api_str = ''
for op_info in op_info_items:
for op_name in op_info.op_phi_name:
if (
op_info.infer_meta_func is None
and op_name not in PD_MANUAL_OP_LIST
):
continue
function_impl_str += self._gen_one_function_impl(op_name)
ops_api_str += self._gen_one_ops_api(op_name)
inner_body = NAMESPACE_INNER_TEMPLATE.format(
function_impl=function_impl_str, ops_api=ops_api_str
)
body = inner_body
for namespace in reversed(namespaces):
body = NAMESPACE_TEMPLATE.format(namespace=namespace, body=body)
with open(cpp_file_path, 'w') as f:
f.write(CPP_FILE_TEMPLATE.format(body=body))
def ParseArguments():
parser = argparse.ArgumentParser(
description='Generate Dialect Python C Files By Yaml'
)
parser.add_argument('--op_yaml_files', type=str)
parser.add_argument('--op_compat_yaml_file', type=str)
parser.add_argument('--namespaces', type=str)
parser.add_argument('--ops_api_file', type=str)
return parser.parse_args()
if __name__ == '__main__':
args = ParseArguments()
op_yaml_files = args.op_yaml_files.split(",")
op_compat_yaml_file = args.op_compat_yaml_file
if args.namespaces is not None:
namespaces = args.namespaces.split(",")
ops_api_file = args.ops_api_file
code_gen = OpsAPIGen()
code_gen.gen_cpp_file(
op_yaml_files, op_compat_yaml_file, namespaces, ops_api_file
)
......@@ -113,6 +113,32 @@ add_custom_command(
add_custom_target(static_op_function_gen ALL DEPENDS ${python_c_header_file}
${python_c_source_file})
set(ops_api_gen_file
${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/op_generator/ops_api_gen.py)
set(ops_api_source_file ${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/ops_api.cc)
set(ops_api_source_file_tmp ${ops_api_source_file}.tmp)
add_custom_command(
OUTPUT ${ops_api_source_file}
COMMAND
${PYTHON_EXECUTABLE} ${ops_api_gen_file} --op_yaml_files ${op_yaml_files}
--op_compat_yaml_file ${op_compat_yaml_file} --namespaces "paddle,pybind"
--ops_api_file ${ops_api_source_file_tmp}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${ops_api_source_file_tmp}
${ops_api_source_file}
COMMENT "copy_if_different ${ops_api_source_file}"
DEPENDS ${ops_api_gen_file}
${op_forward_yaml_file1}
${op_forward_yaml_file2}
${op_backward_yaml_file1}
${op_backward_yaml_file2}
${op_compat_yaml_file}
${python_c_header_file}
${python_c_source_file}
VERBATIM)
add_custom_target(ops_api_gen ALL DEPENDS ${ops_api_source_file})
cc_library(
pd_dialect_core
SRCS pd_attribute.cc pd_type.cc
......
......@@ -536,6 +536,7 @@ if(WITH_PYTHON)
# PROPERTY WINDOWS_EXPORT_ALL_SYMBOLS ON)
# endif()
add_dependencies(${SHARD_LIB_NAME} ops_api_gen)
if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
add_dependencies(${SHARD_LIB_NAME} legacy_eager_codegen)
add_dependencies(${SHARD_LIB_NAME} eager_legacy_op_function_generator_cmd)
......
......@@ -664,7 +664,7 @@ paddle::DataType CastPyArg2DataTypeDirectly(PyObject* obj,
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s: argument (position %d) must be "
"one of core.VarDesc.VarType, "
"one of paddle::DataType, "
"but got %s",
op_type,
arg_pos + 1,
......
......@@ -23,6 +23,7 @@ if __name__ == "__main__":
empty_files = [os.path.join(pybind_dir, "eager_legacy_op_function.cc")]
empty_files.append(os.path.join(pybind_dir, "eager_op_function.cc"))
empty_files.append(os.path.join(pybind_dir, "static_op_function.cc"))
empty_files.append(os.path.join(pybind_dir, "ops_api.cc"))
for path in empty_files:
if not os.path.exists(path):
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <pybind11/pybind11.h>
#include "paddle/fluid/pybind/static_op_function.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace pybind {
static PyObject *add_n(PyObject *self, PyObject *args, PyObject *kwargs) {
return static_api_add_n(self, args, kwargs);
}
static PyObject *mean(PyObject *self, PyObject *args, PyObject *kwargs) {
return static_api_mean(self, args, kwargs);
}
static PyObject *sum(PyObject *self, PyObject *args, PyObject *kwargs) {
return static_api_sum(self, args, kwargs);
}
static PyObject *full(PyObject *self, PyObject *args, PyObject *kwargs) {
return static_api_full(self, args, kwargs);
}
static PyObject *divide(PyObject *self, PyObject *args, PyObject *kwargs) {
return static_api_divide(self, args, kwargs);
}
static PyObject *data(PyObject *self, PyObject *args, PyObject *kwargs) {
return static_api_data(self, args, kwargs);
}
static PyObject *fetch(PyObject *self, PyObject *args, PyObject *kwargs) {
return static_api_fetch(self, args, kwargs);
}
static PyObject *concat(PyObject *self, PyObject *args, PyObject *kwargs) {
return static_api_concat(self, args, kwargs);
}
static PyObject *split(PyObject *self, PyObject *args, PyObject *kwargs) {
return static_api_split(self, args, kwargs);
}
static PyMethodDef OpsAPI[] = {{"add_n",
(PyCFunction)(void (*)(void))add_n,
METH_VARARGS | METH_KEYWORDS,
"C++ interface function for add_n."},
{"mean",
(PyCFunction)(void (*)(void))mean,
METH_VARARGS | METH_KEYWORDS,
"C++ interface function for mean."},
{"sum",
(PyCFunction)(void (*)(void))sum,
METH_VARARGS | METH_KEYWORDS,
"C++ interface function for sum."},
{"divide",
(PyCFunction)(void (*)(void))divide,
METH_VARARGS | METH_KEYWORDS,
"C++ interface function for divide."},
{"concat",
(PyCFunction)(void (*)(void))concat,
METH_VARARGS | METH_KEYWORDS,
"C++ interface function for concat."},
{"full",
(PyCFunction)(void (*)(void))full,
METH_VARARGS | METH_KEYWORDS,
"C++ interface function for full."},
{"split",
(PyCFunction)(void (*)(void))split,
METH_VARARGS | METH_KEYWORDS,
"C++ interface function for split."},
{"data",
(PyCFunction)(void (*)(void))data,
METH_VARARGS | METH_KEYWORDS,
"C++ interface function for data."},
{"fetch",
(PyCFunction)(void (*)(void))fetch,
METH_VARARGS | METH_KEYWORDS,
"C++ interface function for fetch."},
{nullptr, nullptr, 0, nullptr}};
void BindOpsAPI(pybind11::module *module) {
if (PyModule_AddFunctions(module->ptr(), OpsAPI) < 0) {
PADDLE_THROW(phi::errors::Fatal("Add C++ api to core.ops failed!"));
}
}
} // namespace pybind
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册