未验证 提交 048ce0c4 编写于 作者: W WangZhen 提交者: GitHub

[NewIR]Gen python c apis for new ir (#56571)

* Gen all Apis

* Gen python c apis

* Add empty file

* Fix cast data type

* Fix None dtype
上级 e9bb3126
......@@ -97,3 +97,4 @@ python/paddle/incubate/fleet/parameter_server/pslib/ps_pb2.py
paddle/phi/kernels/fusion/cutlass/conv2d/generated/*
python/paddle/fluid/incubate/fleet/parameter_server/pslib/ps_pb2.py
paddle/fluid/ir_adaptor/translator/op_compat_info.cc
paddle/fluid/pybind/static_op_function.*
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import re
from api_gen import (
NAMESPACE_TEMPLATE,
OP_RESULT,
PD_MANUAL_OP_LIST,
VECTOR_TYPE,
CodeGen,
)
H_FILE_TEMPLATE = """
#pragma once
#include <Python.h>
// Avoid a problem with copysign defined in pyconfig.h on Windows.
#ifdef copysign
#undef copysign
#endif
{body}
"""
API_DECLARE_TEMPLATE = """
PyObject *static_api_{name}(PyObject *self, PyObject *args, PyObject *kwargs);
"""
CPP_FILE_TEMPLATE = """
#include "paddle/fluid/pybind/static_op_function.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_api.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/op_function_common.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/enforce.h"
{body}
"""
NO_MUTABLE_ATTR_API_IMPL_TEMPLATE = """
PyObject *static_api_{api_name}(PyObject *self, PyObject *args, PyObject *kwargs) {{
try {{
VLOG(6) << "Add {api_name} op into program";
VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);
// Get OpResult from args
{inputs}
// Parse Attributes
{attrs}
// Call ir static api
auto static_api_out = paddle::dialect::{api_name}({args});
return ToPyObject(static_api_out);
}} catch (...) {{
ThrowExceptionToPython(std::current_exception());
return nullptr;
}}
}}
"""
INPUT_TEMPLATE = """
PyObject *{name}_obj = PyTuple_GET_ITEM(args, {index});
auto {name} = {cast_func}({name}_obj, "{api_name}", {index});"""
NO_MUTABLE_ATTR_CAST_TEMPLATE = """
PyObject *{name}_obj = PyTuple_GET_ITEM(args, {index});
{type} {name} = {cast_func}({name}_obj, "{api_name}", {index});"""
MUTABLE_ATTR_API_IMPL_TEMPLATE = """
PyObject *static_api_{api_name}(PyObject *self, PyObject *args, PyObject *kwargs) {{
try {{
VLOG(6) << "Add {api_name} op into program";
VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);
// Get OpResult from args
{inputs}
// Parse Attributes
{attrs_py_obj}
// Check for mutable attrs
bool has_mutable_attr = false;
{check_mutable_attrs}
if (has_mutable_attr){{
{cast_attrs_with_mutable}
// Call ir static api
auto static_api_out = paddle::dialect::{api_name}({args_with_mutable_attrs});
return ToPyObject(static_api_out);
}} else {{
{cast_attrs_without_mutable}
// Call ir static api
auto static_api_out = paddle::dialect::{api_name}({args_without_mutable_attrs});
return ToPyObject(static_api_out);
}}
}} catch (...) {{
ThrowExceptionToPython(std::current_exception());
return nullptr;
}}
}}
"""
CHECK_MUTABLE_ATTR_TEMPLATE = """
if (PyObject_CheckIROpResult({name}_obj)){{
has_mutable_attr = true;
}}"""
MUTABLE_ATTR_OBJ_TEMPLATE = """
PyObject *{name}_obj = PyTuple_GET_ITEM(args, {index});"""
MUTABLE_ATTR_CAST_TEMPLATE = """
{type} {name} = {cast_func}({name}_obj, "{api_name}", {index});"""
TYPE_TO_FUNC_MAP = {
"bool": "CastPyArg2Boolean",
"int": "CastPyArg2Int",
"long": "CastPyArg2Long",
"int64_t": "CastPyArg2Long",
"float": "CastPyArg2Float",
"double": "CastPyArg2Double",
"std::string": "CastPyArg2String",
"std::vector<bool>": "CastPyArg2Booleans",
"std::vector<int>": "CastPyArg2Ints",
"std::vector<long>": "CastPyArg2Longs",
"std::vector<int64_t>": "CastPyArg2Longs",
"std::vector<float>": "CastPyArg2Floats",
"std::vector<double>": "CastPyArg2Float64s",
"std::vector<std::string>": "CastPyArg2Strings",
"paddle::experimental::Scalar": "CastPyArg2Scalar",
"std::vector<phi::Scalar>": "CastPyArg2ScalarArray",
"paddle::experimental::IntArray": "CastPyArg2IntArray",
"paddle::Place": "CastPyArg2Place",
"Place": "CastPyArg2Place",
"phi::DataType": "CastPyArg2DataTypeDirectly",
}
class PythonCCodeGen(CodeGen):
def __init__(self) -> None:
super().__init__()
def _gen_one_declare(self, op_name):
return API_DECLARE_TEMPLATE.format(name=op_name)
def _gen_h_file(self, op_info_items, namespaces, h_file_path):
declare_str = ''
for op_info in op_info_items:
for op_name in op_info.op_phi_name:
# NOTE:When infer_meta_func is None, the Build() function generated in pd_op
# is wrong, so temporarily skip the automatic generation of these APIs
if (
op_info.infer_meta_func is None
and op_name not in PD_MANUAL_OP_LIST
):
continue
declare_str += self._gen_one_declare(op_name)
body = declare_str
for namespace in reversed(namespaces):
body = NAMESPACE_TEMPLATE.format(namespace=namespace, body=body)
with open(h_file_path, 'w') as f:
f.write(H_FILE_TEMPLATE.format(body=body))
def _gen_inputs(self, op_info, op_name):
name_list = op_info.input_name_list
type_list = op_info.input_type_list
assert len(name_list) == len(type_list)
ret = ''
for i, (name, type) in enumerate(zip(name_list, type_list)):
cast_func = (
'CastPyArg2VectorOfOpResult'
if VECTOR_TYPE in type
else 'CastPyArg2OpResult'
)
ret += INPUT_TEMPLATE.format(
name=name, index=i, cast_func=cast_func, api_name=op_name
)
return ret
def _gen_attrs_without_mutable(self, op_info, op_name):
input_size = len(op_info.input_name_list)
name_list = op_info.attribute_name_list
type_list = op_info.attribute_build_arg_type_list
assert len(name_list) == len(type_list)
ret = ''
for i, (name, type) in enumerate(zip(name_list, type_list)):
type = type.replace('const ', '').replace('&', '')
cast_func = TYPE_TO_FUNC_MAP[type]
ret += NO_MUTABLE_ATTR_CAST_TEMPLATE.format(
name=name,
index=input_size + i,
type=type,
cast_func=cast_func,
api_name=op_name,
)
return ret
def _gen_attrs_py_obj_with_mutable(self, op_info):
input_size = len(op_info.input_name_list)
name_list = op_info.attribute_name_list
ret = ''
for i, name in enumerate(name_list):
ret += MUTABLE_ATTR_OBJ_TEMPLATE.format(
name=name, index=input_size + i
)
return ret
def _gen_check_mutable_attrs(self, op_info):
name_list = op_info.mutable_attribute_name_list
ret = ''
for name in name_list:
ret += CHECK_MUTABLE_ATTR_TEMPLATE.format(name=name)
return ret
def _gen_cast_attrs(self, op_info, op_name, with_mutable):
input_size = len(op_info.input_name_list)
attr_name_list = op_info.attribute_name_list
attr_type_list = op_info.attribute_build_arg_type_list
mutable_attr_name_list = op_info.mutable_attribute_name_list
assert len(attr_name_list) == len(attr_type_list)
ret = ''
for i, (name, type) in enumerate(zip(attr_name_list, attr_type_list)):
type = type.replace('const ', '').replace('&', '')
cast_func = TYPE_TO_FUNC_MAP[type]
if with_mutable and name in mutable_attr_name_list:
type = OP_RESULT
cast_func = 'CastPyArg2OpResult'
ret += MUTABLE_ATTR_CAST_TEMPLATE.format(
type=type,
name=name,
cast_func=cast_func,
api_name=op_name,
index=input_size + i,
)
return ret
def _gen_one_impl(self, op_info, op_name):
input_name_list = op_info.input_name_list
attr_name_list = op_info.attribute_name_list
mutable_attr_name_list = op_info.mutable_attribute_name_list
no_mutable_attr_name_list = op_info.non_mutable_attribute_name_list
if len(mutable_attr_name_list) > 0:
ret = MUTABLE_ATTR_API_IMPL_TEMPLATE.format(
api_name=op_name,
inputs=self._gen_inputs(op_info, op_name),
attrs_py_obj=self._gen_attrs_py_obj_with_mutable(op_info),
check_mutable_attrs=self._gen_check_mutable_attrs(op_info),
cast_attrs_with_mutable=self._gen_cast_attrs(
op_info, op_name, True
),
args_with_mutable_attrs=', '.join(
input_name_list
+ mutable_attr_name_list
+ no_mutable_attr_name_list
),
cast_attrs_without_mutable=self._gen_cast_attrs(
op_info, op_name, False
),
args_without_mutable_attrs=', '.join(
input_name_list + attr_name_list
),
)
else:
ret = NO_MUTABLE_ATTR_API_IMPL_TEMPLATE.format(
api_name=op_name,
inputs=self._gen_inputs(op_info, op_name),
attrs=self._gen_attrs_without_mutable(op_info, op_name),
args=', '.join(input_name_list + attr_name_list),
)
ret = re.sub(r' +\n', '', ret)
return ret
def _gen_cpp_file(self, op_info_items, namespaces, cpp_file_path):
impl_str = ''
for op_info in op_info_items:
for op_name in op_info.op_phi_name:
# NOTE:When infer_meta_func is None, the Build() function generated in pd_op
# is wrong, so temporarily skip the automatic generation of these APIs
if (
op_info.infer_meta_func is None
and op_name not in PD_MANUAL_OP_LIST
):
continue
impl_str += self._gen_one_impl(op_info, op_name)
body = impl_str
for namespace in reversed(namespaces):
body = NAMESPACE_TEMPLATE.format(namespace=namespace, body=body)
with open(cpp_file_path, 'w') as f:
f.write(CPP_FILE_TEMPLATE.format(body=body))
def ParseArguments():
parser = argparse.ArgumentParser(
description='Generate Dialect Python C Files By Yaml'
)
parser.add_argument('--op_yaml_files', type=str)
parser.add_argument('--op_compat_yaml_file', type=str)
parser.add_argument('--namespaces', type=str)
parser.add_argument('--python_c_def_h_file', type=str)
parser.add_argument('--python_c_def_cc_file', type=str)
return parser.parse_args()
if __name__ == '__main__':
args = ParseArguments()
op_yaml_files = args.op_yaml_files.split(",")
op_compat_yaml_file = args.op_compat_yaml_file
if args.namespaces is not None:
namespaces = args.namespaces.split(",")
python_c_def_h_file = args.python_c_def_h_file
python_c_def_cc_file = args.python_c_def_cc_file
code_gen = PythonCCodeGen()
code_gen.gen_h_and_cpp_file(
op_yaml_files,
op_compat_yaml_file,
namespaces,
python_c_def_h_file,
python_c_def_cc_file,
)
......@@ -84,6 +84,35 @@ add_custom_command(
${op_compat_yaml_file}
VERBATIM)
set(python_c_gen_file
${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/op_generator/python_c_gen.py)
set(python_c_header_file
${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/static_op_function.h)
set(python_c_source_file
${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/static_op_function.cc)
set(python_c_header_file_tmp ${python_c_header_file}.tmp)
set(python_c_source_file_tmp ${python_c_source_file}.tmp)
add_custom_command(
OUTPUT ${python_c_header_file} ${python_c_source_file}
COMMAND
${PYTHON_EXECUTABLE} ${python_c_gen_file} --op_yaml_files ${op_yaml_files}
--op_compat_yaml_file ${op_compat_yaml_file} --namespaces "paddle,pybind"
--python_c_def_h_file ${python_c_header_file_tmp} --python_c_def_cc_file
${python_c_source_file_tmp}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${python_c_header_file_tmp}
${python_c_header_file}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${python_c_source_file_tmp}
${python_c_source_file}
COMMENT "copy_if_different ${python_c_header_file} ${python_c_source_file}"
DEPENDS ${python_c_gen_file} ${op_forward_yaml_file1} ${op_forward_yaml_file2}
${op_backward_yaml_file1} ${op_backward_yaml_file2}
${op_compat_yaml_file}
VERBATIM)
add_custom_target(static_op_function_gen ALL DEPENDS ${python_c_header_file}
${python_c_source_file})
cc_library(
pd_dialect_core
SRCS pd_attribute.cc pd_type.cc
......
......@@ -136,6 +136,10 @@ bool PyObject_CheckFloatOrConvertToFloat(PyObject** obj) {
bool PyObject_CheckStr(PyObject* obj) { return PyUnicode_Check(obj); }
bool PyObject_CheckIROpResult(PyObject* obj) {
return PyObject_TypeCheck(obj, g_ir_opresult_pytype);
}
bool CastPyArg2AttrBoolean(PyObject* obj, ssize_t arg_pos) {
if (obj == Py_None) {
return false; // To be compatible with QA integration testing. Some
......@@ -647,6 +651,10 @@ paddle::framework::proto::VarType::Type CastPyArg2ProtoType(PyObject* obj,
paddle::DataType CastPyArg2DataTypeDirectly(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos) {
if (obj == Py_None) {
return phi::DataType::UNDEFINED;
}
paddle::DataType dtype;
if (PyObject_TypeCheck(obj, g_data_type_pytype)) {
dtype = ::pybind11::handle(obj).cast<paddle::DataType>();
......@@ -883,6 +891,16 @@ PyObject* ToPyObject(const ir::OpResult& value) {
return obj.ptr();
}
PyObject* ToPyObject(const std::vector<ir::OpResult>& value) {
PyObject* result = PyList_New((Py_ssize_t)value.size());
for (size_t i = 0; i < value.size(); i++) {
PyList_SET_ITEM(result, static_cast<Py_ssize_t>(i), ToPyObject(value[i]));
}
return result;
}
#ifdef PADDLE_WITH_DISTRIBUTE
PyObject* ToPyObject(const phi::distributed::DistTensor* value) {
auto obj = ::pybind11::cast(value, py::return_value_policy::reference);
......@@ -1452,8 +1470,8 @@ paddle::experimental::Scalar CastNumpy2Scalar(PyObject* obj,
}
}
ir::OpResult CastPyArg2OpResult(const std::string& op_type,
PyObject* obj,
ir::OpResult CastPyArg2OpResult(PyObject* obj,
const std::string& op_type,
size_t arg_pos) {
if (PyObject_TypeCheck(obj, g_ir_opresult_pytype)) {
return ::pybind11::handle(obj).cast<ir::OpResult>();
......@@ -1467,8 +1485,8 @@ ir::OpResult CastPyArg2OpResult(const std::string& op_type,
}
}
std::vector<ir::OpResult> CastPyArg2VectorOfOpResult(const std::string& op_type,
PyObject* obj,
std::vector<ir::OpResult> CastPyArg2VectorOfOpResult(PyObject* obj,
const std::string& op_type,
size_t arg_pos) {
std::vector<ir::OpResult> result_list;
if (PyList_Check(obj)) {
......
......@@ -58,6 +58,7 @@ int TensorDtype2NumpyDtype(phi::DataType dtype);
bool PyObject_CheckLongOrConvertToLong(PyObject** obj);
bool PyObject_CheckFloatOrConvertToFloat(PyObject** obj);
bool PyObject_CheckStr(PyObject* obj);
bool PyObject_CheckIROpResult(PyObject* obj);
bool CastPyArg2AttrBoolean(PyObject* obj, ssize_t arg_pos);
int CastPyArg2AttrInt(PyObject* obj, ssize_t arg_pos);
int64_t CastPyArg2AttrLong(PyObject* obj, ssize_t arg_pos);
......@@ -76,11 +77,11 @@ std::vector<int> CastPyArg2VectorOfInt(PyObject* obj, size_t arg_pos);
std::vector<int64_t> CastPyArg2VectorOfInt64(PyObject* obj, size_t arg_pos);
std::vector<size_t> CastPyArg2VectorOfSize_t(PyObject* obj, size_t arg_pos);
std::vector<float> CastPyArg2VectorOfFloat(PyObject* obj, size_t arg_pos);
ir::OpResult CastPyArg2OpResult(const std::string& op_type,
PyObject* obj,
ir::OpResult CastPyArg2OpResult(PyObject* obj,
const std::string& op_type,
size_t arg_pos);
std::vector<ir::OpResult> CastPyArg2VectorOfOpResult(const std::string& op_type,
PyObject* obj,
std::vector<ir::OpResult> CastPyArg2VectorOfOpResult(PyObject* obj,
const std::string& op_type,
size_t arg_pos);
std::vector<std::vector<size_t>> CastPyArg2VectorOfVectorOfSize_t(
PyObject* obj, size_t arg_pos);
......@@ -135,6 +136,7 @@ PyObject* ToPyObject(const paddle::framework::Vocab& value);
PyObject* ToPyObject(std::shared_ptr<egr::GradNodeBase> grad_node);
PyObject* ToPyObject(const ir::OpResult& value);
PyObject* ToPyObject(const std::vector<ir::OpResult>& value);
class PyTensorHook : public egr::TensorHook {
public:
......
......@@ -22,6 +22,7 @@ if __name__ == "__main__":
empty_files = [os.path.join(pybind_dir, "eager_legacy_op_function.cc")]
empty_files.append(os.path.join(pybind_dir, "eager_op_function.cc"))
empty_files.append(os.path.join(pybind_dir, "static_op_function.cc"))
for path in empty_files:
if not os.path.exists(path):
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/pybind/static_op_function.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_api.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/op_function_common.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/core/enforce.h"
namespace paddle {
namespace pybind {
PyObject *static_api_add_n(PyObject *self, PyObject *args, PyObject *kwargs) {
try {
VLOG(6) << "Add add_n op into program";
VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);
// Get OpResult from args
PyObject *x_obj = PyTuple_GET_ITEM(args, 0);
auto x = CastPyArg2VectorOfOpResult("add_n", x_obj, 0);
// Parse Attributes if needed
// Call ir static api
auto out = paddle::dialect::add_n(x);
return ToPyObject(out);
} catch (...) {
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
}
PyObject *static_api_mean(PyObject *self, PyObject *args, PyObject *kwargs) {
try {
VLOG(6) << "Add mean op into program";
VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);
// Get OpResult from args
PyObject *x_obj = PyTuple_GET_ITEM(args, 0);
auto x = CastPyArg2OpResult("mean", x_obj, 0);
// Parse Attributes if needed
PyObject *axis_obj = PyTuple_GET_ITEM(args, 1);
paddle::experimental::IntArray axis =
CastPyArg2IntArray(axis_obj, "mean", 1);
PyObject *keepdim_obj = PyTuple_GET_ITEM(args, 2);
bool keepdim = CastPyArg2Boolean(keepdim_obj, "mean", 2);
// Call ir static api
auto out = paddle::dialect::mean(x, axis.GetData(), keepdim);
return ToPyObject(out);
} catch (...) {
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
}
PyObject *static_api_sum(PyObject *self, PyObject *args, PyObject *kwargs) {
try {
VLOG(6) << "Add sum op into program";
VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);
// Get OpResult from args
PyObject *x_obj = PyTuple_GET_ITEM(args, 0);
auto x = CastPyArg2OpResult("sum", x_obj, 0);
// Parse Attributes if needed
PyObject *axis_obj = PyTuple_GET_ITEM(args, 1);
paddle::experimental::IntArray axis =
CastPyArg2IntArray(axis_obj, "sum", 1);
PyObject *dtype_obj = PyTuple_GET_ITEM(args, 2);
phi::DataType dtype = CastPyArg2DataType(dtype_obj, "sum", 2);
PyObject *keepdim_obj = PyTuple_GET_ITEM(args, 3);
bool keepdim = CastPyArg2Boolean(keepdim_obj, "sum", 3);
// Call ir static api
auto out = paddle::dialect::sum(x, axis.GetData(), dtype, keepdim);
return ToPyObject(out);
} catch (...) {
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
}
PyObject *static_api_divide(PyObject *self, PyObject *args, PyObject *kwargs) {
try {
VLOG(6) << "Add divide op into program";
VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);
// Get OpResult from args
PyObject *x_obj = PyTuple_GET_ITEM(args, 0);
auto x = CastPyArg2OpResult("divide", x_obj, 0);
PyObject *y_obj = PyTuple_GET_ITEM(args, 1);
auto y = CastPyArg2OpResult("divide", y_obj, 1);
// Call ir static api
auto out = paddle::dialect::divide(x, y);
return ToPyObject(out);
} catch (...) {
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
}
PyObject *static_api_concat(PyObject *self, PyObject *args, PyObject *kwargs) {
try {
VLOG(6) << "Add concat op into program";
VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);
// Get OpResult from args
PyObject *x_obj = PyTuple_GET_ITEM(args, 0);
auto x = CastPyArg2VectorOfOpResult("concat", x_obj, 0);
PyObject *axis_obj = PyTuple_GET_ITEM(args, 1);
paddle::experimental::Scalar axis = CastPyArg2Scalar(axis_obj, "concat", 1);
// Call ir static api
auto out = paddle::dialect::concat(x, axis.to<float>());
return ToPyObject(out);
} catch (...) {
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
}
PyObject *static_api_full(PyObject *self, PyObject *args, PyObject *kwargs) {
try {
VLOG(6) << "Add full op into program";
VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);
// Parse Attributes if needed
PyObject *shape_obj = PyTuple_GET_ITEM(args, 0);
paddle::experimental::IntArray shape =
CastPyArg2IntArray(shape_obj, "full", 0);
PyObject *value_obj = PyTuple_GET_ITEM(args, 1);
paddle::experimental::Scalar value = CastPyArg2Scalar(value_obj, "full", 1);
PyObject *dtype_obj = PyTuple_GET_ITEM(args, 2);
phi::DataType dtype = CastPyArg2DataTypeDirectly(dtype_obj, "full", 2);
PyObject *place_obj = PyTuple_GET_ITEM(args, 3);
paddle::Place place = CastPyArg2Place(place_obj, "full", 3);
// Call ir static api
auto out =
paddle::dialect::full(shape.GetData(), value.to<float>(), dtype, place);
return ToPyObject(out);
} catch (...) {
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
}
PyObject *static_api_data(PyObject *self, PyObject *args, PyObject *kwargs) {
try {
VLOG(6) << "Add data op into program";
VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);
// Parse Attributes if needed
PyObject *name_obj = PyTuple_GET_ITEM(args, 0);
std::string name = CastPyArg2String(name_obj, "data", 0);
PyObject *shape_obj = PyTuple_GET_ITEM(args, 1);
paddle::experimental::IntArray shape =
CastPyArg2IntArray(shape_obj, "data", 1);
PyObject *dtype_obj = PyTuple_GET_ITEM(args, 2);
phi::DataType dtype = CastPyArg2DataTypeDirectly(dtype_obj, "data", 2);
PyObject *place_obj = PyTuple_GET_ITEM(args, 3);
paddle::Place place = CastPyArg2Place(place_obj, "data", 3);
// Call ir static api
auto out = paddle::dialect::data(name, shape.GetData(), dtype, place);
return ToPyObject(out);
} catch (...) {
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
}
PyObject *static_api_fetch(PyObject *self, PyObject *args, PyObject *kwargs) {
try {
VLOG(6) << "Add fetch op into program";
VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);
// Get OpResult from args
PyObject *x_obj = PyTuple_GET_ITEM(args, 0);
auto x = CastPyArg2OpResult("fetch", x_obj, 0);
// Parse Attributes if needed
PyObject *name_obj = PyTuple_GET_ITEM(args, 1);
std::string name = CastPyArg2String(name_obj, "fetch", 1);
PyObject *col_obj = PyTuple_GET_ITEM(args, 2);
int col = CastPyArg2Int(col_obj, "fetch", 2);
// Call ir static api
auto out = paddle::dialect::fetch(x, name, col);
return ToPyObject(out);
} catch (...) {
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
}
} // namespace pybind
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <Python.h>
// Avoid a problem with copysign defined in pyconfig.h on Windows.
#ifdef copysign
#undef copysign
#endif
namespace paddle {
namespace pybind {
PyObject *static_api_add_n(PyObject *self, PyObject *args, PyObject *kwargs);
PyObject *static_api_mean(PyObject *self, PyObject *args, PyObject *kwargs);
PyObject *static_api_sum(PyObject *self, PyObject *args, PyObject *kwargs);
PyObject *static_api_divide(PyObject *self, PyObject *args, PyObject *kwargs);
PyObject *static_api_concat(PyObject *self, PyObject *args, PyObject *kwargs);
PyObject *static_api_full(PyObject *self, PyObject *args, PyObject *kwargs);
PyObject *static_api_data(PyObject *self, PyObject *args, PyObject *kwargs);
PyObject *static_api_fetch(PyObject *self, PyObject *args, PyObject *kwargs);
} // namespace pybind
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册