Commit 62b15566 — Author: jim19930609

Added python-c code generation for final state Eager Dygraph

Parent commit: ca743508
# Eager dygraph node/function libraries. nodes.cc and dygraph_functions.cc are
# generated sources, so both libraries must build after the code generator
# target (add_dependencies orders the build; it does not link).
cc_library(scale_node SRCS scale_node.cc DEPS global_utils pten pten_api grad_node_info)

cc_library(final_dygraph_node SRCS nodes.cc DEPS ${eager_deps})
add_dependencies(final_dygraph_node eager_final_state_codegen)

cc_library(eager_scale SRCS scale.cc DEPS pten_api pten autograd_meta scale_node)

cc_library(final_dygraph_function SRCS dygraph_functions.cc DEPS ${eager_deps})
add_dependencies(final_dygraph_function eager_final_state_codegen)

add_subdirectory(final_state_generator)
set(EAGER_GENERETOR_DEPS ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} pybind proto_desc executor layer tracer engine imperative_profiler imperative_flag)
......
......@@ -24,3 +24,13 @@ add_custom_target(eager_final_state_codegen
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_nodes_h_path} ${nodes_h_path}
VERBATIM
)
# Paths for the generated Python-C binding header. The generator writes to a
# temporary file first; copy_if_different then updates the real header only
# when the content actually changed, avoiding needless rebuilds of dependents.
set(tmp_python_c_output_path "${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/tmp_eager_final_state_op_function_impl.h")
set(python_c_output_path "${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/eager_final_state_op_function_impl.h")

# Always-stale target that regenerates the Python-C bindings from the API yaml
# (api_yaml_path is defined earlier in this file, outside this excerpt).
add_custom_target(eager_final_state_python_c_codegen
    COMMAND "${PYTHON_EXECUTABLE}" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py"
            "--api_yaml_path=${api_yaml_path}"
            "--output_path=${tmp_python_c_output_path}"
    COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_python_c_output_path} ${python_c_output_path}
    VERBATIM
)
......@@ -82,6 +82,14 @@ def RemoveConstAndReference(string):
return ret
def GetGradNodeName(string):
    # Class name of the generated autograd node for a forward API.
    return "FinalGradNode" + string
def GetForwardFunctionName(string):
    # Name of the generated final-state dygraph forward function.
    return "%s_final_state_dygraph_function" % string
def GetAutoGradMetaName(string):
    # Variable name of the autograd meta associated with a tensor name.
    return "{}_autograd_meta".format(string)
......@@ -145,13 +153,13 @@ def ParseYamlArgs(string):
def ParseYamlReturns(string):
    """Parse a yaml returns string such as "Tensor, Tensor".

    Returns:
        list of [returns_name, ret_type, orig_position] triples. The name
        slot is left empty ("") because a plain returns string carries no
        per-return name.
    """
    # NOTE: the diff residue in the previous version kept both the old
    # two-element and new three-element append, duplicating every entry;
    # only the three-element form is correct.
    returns_list = []
    returns = [x.strip() for x in string.strip().split(",")]
    for i in range(len(returns)):
        ret = returns[i]
        returns_list.append(["", ret, i])
    return returns_list
......@@ -260,8 +268,8 @@ def ForwardsValidationCheck(forward_inputs_list, forward_attrs_list,
assert orig_attr_pos == forward_attr_pos
for i in range(len(forward_returns_list)):
orig_return_type = orig_forward_returns_list[i][0]
orig_return_pos = orig_forward_returns_list[i][1]
orig_return_type = orig_forward_returns_list[i][1]
orig_return_pos = orig_forward_returns_list[i][2]
forward_return_type = forward_returns_list[i][1]
forward_return_pos = forward_returns_list[i][2]
......@@ -452,13 +460,14 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
RemoveConstAndReference(atype), saved_attr_name, default_val)
# End: SetAttributes & Attribute Members
grad_node_name = GetGradNodeName(fwd_api_name)
NODE_DECLARATION_TEMPLATE = """
class GradNode{} : public egr::GradNodeBase {{
class {} : public egr::GradNodeBase {{
public:
GradNode{}() : egr::GradNodeBase() {{}}
GradNode{}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) :
{}() : egr::GradNodeBase() {{}}
{}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) :
egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {{}}
~GradNode{}() override = default;
~{}() override = default;
virtual std::vector<std::vector<egr::EagerTensor>> operator()(
const std::vector<std::vector<egr::EagerTensor>>& grads) override;
......@@ -476,7 +485,7 @@ class GradNode{} : public egr::GradNodeBase {{
}};
"""
node_declaration_str = NODE_DECLARATION_TEMPLATE.format(
forward_op_name, forward_op_name, forward_op_name, forward_op_name,
grad_node_name, grad_node_name, grad_node_name, grad_node_name,
set_tensor_wrapper_methods_str, set_attribute_methods_str,
tensor_wrapper_members_str, attribute_members_str)
......@@ -503,8 +512,13 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
grad_api_args[
grad_api_position] = f"egr::EagerUtils::SyncToPtenTensors( egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr) )"
for _, (_, fwd_position,
for _, (ttype, fwd_position,
grad_api_position) in backward_grad_input_map.items():
if IsPlainTensorType(ttype):
grad_api_args[
grad_api_position] = f"egr::EagerUtils::SyncToPtenTensors( grads[{fwd_position}][0] )"
else:
assert IsVectorTensorType(ttype)
grad_api_args[
grad_api_position] = f"egr::EagerUtils::SyncToPtenTensors( grads[{fwd_position}] )"
......@@ -531,8 +545,9 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
returns_str += f"returns[{fwd_position}] = egr::EagerUtils::CreateEagerTensorFromTensor( grad_api_returns[{grad_api_position}] );\n"
returns_str += f"return returns;\n"
grad_node_name = GetGradNodeName(fwd_api_name)
FUNCTION_TEMPLATE = """
std::vector<std::vector<egr::EagerTensor>> GradNode{}::operator()(const std::vector<std::vector<egr::EagerTensor>>& grads) {{
std::vector<std::vector<egr::EagerTensor>> {}::operator()(const std::vector<std::vector<egr::EagerTensor>>& grads) {{
// Call grad_api function
auto grad_api_returns = paddle::experimental::{}({});
{}
......@@ -540,7 +555,7 @@ std::vector<std::vector<egr::EagerTensor>> GradNode{}::operator()(const std::vec
"""
node_definition_str = FUNCTION_TEMPLATE.format(
fwd_api_name, bwd_api_name, grad_api_args_str, returns_str)
grad_node_name, bwd_api_name, grad_api_args_str, returns_str)
return node_definition_str
......@@ -610,7 +625,8 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name,
# Node Construction
num_bwd_inputs = len(backward_grad_input_map.keys())
num_bwd_outputs = len(backward_grad_output_map.keys())
node_construction_str = f" auto grad_node = std::make_shared<GradNode{fwd_api_name}>({num_bwd_inputs}, {num_bwd_outputs});"
grad_node_name = GetGradNodeName(fwd_api_name)
node_construction_str = f" auto grad_node = std::make_shared<{grad_node_name}>({num_bwd_inputs}, {num_bwd_outputs});"
# SetAttributes
set_attributes_list = []
......@@ -786,7 +802,7 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
backward_grad_output_map, backward_attrs_list)
FORWARD_FUNCTION_TEMPLATE = """
{} {}_dygraph_function({}) {{
{} {}({}) {{
// Forward API Call
{}
......@@ -799,15 +815,34 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
}}
"""
forward_function_name = GetForwardFunctionName(fwd_api_name)
forward_function_str = FORWARD_FUNCTION_TEMPLATE.format(
returns_type_str, fwd_api_name, inputs_args_str, forward_call_str,
returns_str, node_creation_str)
forward_function_declaration_str = f"{returns_type_str} {fwd_api_name}_dygraph_function({inputs_args_str});"
returns_type_str, forward_function_name, inputs_args_str,
forward_call_str, returns_str, node_creation_str)
forward_function_declaration_str = f"{returns_type_str} {forward_function_name}({inputs_args_str});"
return forward_function_str, forward_function_declaration_str
def FakeMatmulGradAPI():
    # Returns a C++ stub definition of paddle::experimental::matmul_grad so
    # the generated nodes.cc can compile/link before the real grad API
    # exists; the stub just returns an empty result.
    # NOTE(review): presumably temporary scaffolding — confirm it is removed
    # once the real matmul_grad kernel lands.
    fake_matmul_grad_str = """
namespace paddle {
namespace experimental {
std::vector<std::vector<Tensor>> matmul_grad(const Tensor& x,
                                             const Tensor& y,
                                             const Tensor& out_grad,
                                             bool transpose_x,
                                             bool transpose_y) {
  std::vector<std::vector<Tensor>> ret;
  return ret;
}
}
}
"""
    return fake_matmul_grad_str
def GenerateNodeCCFile(filepath, node_definition_str):
file_contents = """
#include "glog/logging.h"
......@@ -819,6 +854,7 @@ def GenerateNodeCCFile(filepath, node_definition_str):
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
"""
file_contents += FakeMatmulGradAPI()
file_contents += node_definition_str
with open(filepath, 'a') as f:
f.write(file_contents)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
from eager_gen import ReadFwdFile, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap
# Maps an attribute type name from the api yaml to the C++ CastPyArg2* helper
# (declared in paddle/fluid/pybind/op_function_common.h) that converts a
# PyObject* argument into the corresponding C++ value.
atype_to_parsing_function = {
    "bool": "CastPyArg2Boolean",
    "int": "CastPyArg2Int",
    "long": "CastPyArg2Long",
    "float": "CastPyArg2Float",
    "string": "CastPyArg2String",
    "bool[]": "CastPyArg2Booleans",
    "int[]": "CastPyArg2Ints",
    "long[]": "CastPyArg2Longs",
    "float[]": "CastPyArg2Floats",
    "double[]": "CastPyArg2Float64s",
    "string[]": "CastPyArg2Strings"
}
# Maps an attribute type name from the api yaml to the C++ type spelling used
# for the local variable that holds the parsed attribute value.
atype_to_cxx_type = {
    "bool": "bool",
    "int": "int",
    "long": "long",
    "float": "float",
    "string": "std::string",
    "bool[]": "std::vector<bool>",
    "int[]": "std::vector<int>",
    "long[]": "std::vector<long>",
    "float[]": "std::vector<float>",
    "double[]": "std::vector<double>",
    "string[]": "std::vector<std::string>"
}
def ParseArguments():
    # Parse the command-line options of the python-c code generator:
    # --api_yaml_path: path to the forward api yaml file.
    # --output_path:   path of the generated Python-C header.
    parser = argparse.ArgumentParser(
        description='Eager Code Generator Args Parser')
    parser.add_argument('--api_yaml_path', type=str)
    parser.add_argument('--output_path', type=str)
    return parser.parse_args()
def GetCxxType(atype):
    """Return the C++ type spelling for an attribute type from the api yaml.

    Raises:
        AssertionError: if `atype` is not a supported attribute type; the
        message names the offending type instead of failing bare.
    """
    assert atype in atype_to_cxx_type, f"Unsupported attribute type: {atype}"
    return atype_to_cxx_type[atype]
def FindParsingFunctionFromAttributeType(atype):
    """Return the CastPyArg2* parsing helper name for an attribute type.

    Raises:
        AssertionError: if `atype` is not a supported attribute type; the
        message names the offending type instead of failing bare.
    """
    assert atype in atype_to_parsing_function, f"Unsupported attribute type: {atype}"
    return atype_to_parsing_function[atype]
def GeneratePythonCFunction(fwd_api_name, forward_inputs_position_map,
                            forward_attrs_list, forward_outputs_position_map):
    """Generate the Python-C wrapper for one forward api.

    Args:
        fwd_api_name: name of the forward api, e.g. "matmul_v2".
        forward_inputs_position_map: { name : [type, fwd_position] }
        forward_attrs_list: [ [attr_name, attr_type, default_value, orig_position], ... ]
        forward_outputs_position_map: { name : [type, fwd_position] }
            (currently unused here; kept for signature symmetry with callers)

    Returns:
        (python_c_function_str, python_c_function_reg_str): the C function
        body and its PyMethodDef registration entry.
    """
    # Build the call-argument slot list: inputs and attributes each land at
    # their original argument position of the dygraph function.
    num_args = len(forward_inputs_position_map.keys()) + len(forward_attrs_list)
    dygraph_function_call_list = ["" for i in range(num_args)]

    # Fetch EagerTensor inputs from the Python args tuple.
    get_eager_tensor_str = ""
    for name, (ttype, pos) in forward_inputs_position_map.items():
        get_eager_tensor_str += f"    auto& {name} = GetEagerTensorFromArgs(\"{fwd_api_name}\", \"{name}\", args, {pos}, false);\n"
        dygraph_function_call_list[pos] = f"{name}"

    # NOTE(review): `attrs` is declared in the generated code but the parsed
    # attributes below go into typed locals — confirm the map is still needed.
    parse_attributes_str = "    paddle::framework::AttributeMap attrs;\n"

    # Parse each attribute from the Python args tuple into a typed C++ local.
    for name, atype, _, pos in forward_attrs_list:
        parsing_function = FindParsingFunctionFromAttributeType(atype)
        cxx_type = GetCxxType(atype)
        parse_attributes_str += f"    PyObject* {name}_obj = PyTuple_GET_ITEM(args, {pos});\n"
        parse_attributes_str += f"    {cxx_type} {name} = {parsing_function}({name}_obj, \"{fwd_api_name}\", {pos});\n"
        dygraph_function_call_list[pos] = f"{name}"
    dygraph_function_call_str = ",".join(dygraph_function_call_list)

    PYTHON_C_FUNCTION_TEMPLATE = """
static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObject *kwargs)
{{
  PyThreadState *tstate = nullptr;
  try
  {{
    // Get EagerTensors from args
{}
    // Parse Attributes
{}
    tstate = PyEval_SaveThread();

    auto out = {}({});

    PyEval_RestoreThread(tstate);
    tstate = nullptr;
    return ToPyObject(out);
  }}
  catch(...) {{
    if (tstate) {{
      PyEval_RestoreThread(tstate);
    }}
    ThrowExceptionToPython(std::current_exception());
    return nullptr;
  }}
}}
"""
    python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
        fwd_api_name, get_eager_tensor_str, parse_attributes_str,
        GetForwardFunctionName(fwd_api_name), dygraph_function_call_str)

    # PyMethodDef entry registering the wrapper as final_state_<api>.
    python_c_function_reg_str = f"{{\"final_state_{fwd_api_name}\", (PyCFunction)(void(*)(void))eager_final_state_api_{fwd_api_name}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {fwd_api_name} in dygraph.\"}}"

    return python_c_function_str, python_c_function_reg_str
def GeneratePythonCWrappers(python_c_function_str, python_c_function_reg_str):
    # Embed the generated wrapper functions and their PyMethodDef entries
    # into a complete, self-contained Python-C header.
    wrapper_template = """
#pragma once

#include "pybind11/detail/common.h"
#include "paddle/fluid/pybind/op_function_common.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/pybind/exception.h"
#include <Python.h>

namespace paddle {{
namespace pybind {{

{}

static PyMethodDef EagerFinalStateMethods[] = {{
{}
}};

}} // namespace pybind
}} // namespace paddle
"""
    return wrapper_template.format(python_c_function_str,
                                   python_c_function_reg_str)
def GeneratePythonCFile(filepath, python_c_str):
    # Append the generated Python-C code to `filepath` (the caller removes
    # any stale file beforehand, so append acts as a fresh write).
    output = open(filepath, 'a')
    try:
        output.write(python_c_str)
    finally:
        output.close()
if __name__ == "__main__":
    # Entry point: read the forward api yaml, generate one Python-C wrapper
    # per api that has a backward definition, and write them all into a
    # single header file.
    args = ParseArguments()

    api_yaml_path = args.api_yaml_path
    fwd_api_list = ReadFwdFile(api_yaml_path)

    python_c_function_list = []
    python_c_function_reg_list = []
    for fwd_api in fwd_api_list:
        # We only generate Ops with grad
        if 'backward' not in fwd_api.keys():
            continue

        assert 'api' in fwd_api.keys()
        assert 'args' in fwd_api.keys()
        assert 'output' in fwd_api.keys()
        assert 'backward' in fwd_api.keys()

        fwd_api_name = fwd_api['api']
        fwd_args_str = fwd_api['args']
        fwd_returns_str = fwd_api['output']

        # Collect Original Forward Inputs/Outputs and then perform validation checks
        forward_inputs_list, forward_attrs_list, forward_returns_list = ParseYamlForward(
            fwd_args_str, fwd_returns_str)
        print("Parsed Original Forward Inputs List: ", forward_inputs_list)
        # Fixed typo in log message: "Prased" -> "Parsed".
        print("Parsed Original Forward Attrs List: ", forward_attrs_list)
        print("Parsed Original Forward Returns List: ", forward_returns_list)

        forward_inputs_position_map, forward_outputs_position_map = DetermineForwardPositionMap(
            forward_inputs_list, forward_returns_list)
        print("Generated Forward Input Position Map: ",
              forward_inputs_position_map)
        print("Generated Forward Output Position Map: ",
              forward_outputs_position_map)

        python_c_function_str, python_c_function_reg_str = GeneratePythonCFunction(
            fwd_api_name, forward_inputs_position_map, forward_attrs_list,
            forward_outputs_position_map)
        python_c_function_list.append(python_c_function_str)
        python_c_function_reg_list.append(python_c_function_reg_str)
        print("Generated Python-C Function: ", python_c_function_str)

    python_c_functions_str = "\n".join(python_c_function_list)
    python_c_functions_reg_str = ",\n".join(python_c_function_reg_list)

    python_c_str = GeneratePythonCWrappers(python_c_functions_str,
                                           python_c_functions_reg_str)
    print("Generated Python-C Codes: ", python_c_str)

    # Remove any stale output before appending the fresh contents.
    output_path = args.output_path
    for path in [output_path]:
        if os.path.exists(path):
            os.remove(path)

    GeneratePythonCFile(output_path, python_c_str)
......@@ -288,7 +288,9 @@ void EagerUtils::CheckAndRetainGrad(
// Convert an egr::EagerTensor to a paddle::experimental (pten) Tensor.
// If the tensor reports uninitialized, force a sync first so that
// tensor.Tensor() is populated before dereferencing it.
// NOTE(review): const_cast is used to call the non-const SyncToTensor() on a
// const reference — the mutation of the caller's tensor appears intended,
// but confirm.
paddle::experimental::Tensor EagerUtils::SyncToPtenTensors(
    const egr::EagerTensor& tensor) {
  if (!tensor.initialized()) {
    const_cast<EagerTensor*>(&tensor)->SyncToTensor();
  }
  return *tensor.Tensor().get();
}
......@@ -298,7 +300,9 @@ std::vector<paddle::experimental::Tensor> EagerUtils::SyncToPtenTensors(
size_t num = tensors.size();
res.reserve(num);
for (size_t i = 0; i < num; i++) {
if (!tensors[i].initialized()) {
const_cast<EagerTensor*>(&(tensors[i]))->SyncToTensor();
}
res.push_back(*tensors[i].Tensor().get());
}
return res;
......
......@@ -151,7 +151,7 @@ if(WITH_PYTHON)
set(tmp_eager_impl_file ${eager_impl_file}.tmp)
set(OP_IMPL_DEPS op_function_generator)
set(EAGER_OP_IMPL_DEPS eager_op_function_generator)
set(EAGER_OP_IMPL_DEPS eager_op_function_generator eager_final_state_python_c_codegen)
if(WIN32)
if("${CMAKE_GENERATOR}" STREQUAL "Ninja")
......@@ -275,7 +275,7 @@ if(WITH_PYTHON)
if(NOT ON_INFER)
cc_library(paddle_eager
SRCS eager.cc eager_functions.cc eager_method.cc eager_properties.cc eager_utils.cc
DEPS eager_api autograd_meta backward grad_node_info pten op_function_common dygraph_function dygraph_node accumulation_node global_utils utils python)
DEPS eager_api autograd_meta backward grad_node_info pten op_function_common final_dygraph_function final_dygraph_node dygraph_function dygraph_node accumulation_node global_utils utils python)
add_dependencies(paddle_eager eager_codegen)
add_dependencies(paddle_eager eager_op_function_generator_cmd)
list(APPEND PYBIND_DEPS paddle_eager)
......
......@@ -393,6 +393,7 @@ int main(int argc, char* argv[]) {
std::vector<std::string> headers{
"\"pybind11/detail/common.h\"",
"\"paddle/fluid/pybind/eager_final_state_op_function_impl.h\"",
"\"paddle/fluid/pybind/op_function_common.h\"",
"\"paddle/fluid/eager/api/generated/fluid_generated/"
"dygraph_forward_api.h\"",
......@@ -441,6 +442,10 @@ int main(int argc, char* argv[]) {
<< " PADDLE_THROW(platform::errors::Fatal (\"Add functions to "
"core.eager.ops failed!\"));\n"
<< " }\n\n"
<< " if (PyModule_AddFunctions(m.ptr(), EagerFinalStateMethods) < 0) {\n"
<< " PADDLE_THROW(platform::errors::Fatal (\"Add functions to "
"core.eager.ops failed!\"));\n"
<< " }\n\n"
<< "}\n\n"
<< "} // namespace pybind\n"
<< "} // namespace paddle\n";
......
......@@ -100,17 +100,15 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) {
bool PyObject_CheckString(PyObject* obj) { return PyUnicode_Check(obj); }
void CastPyArg2AttrBoolean(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
bool CastPyArg2Boolean(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
if (obj == Py_None) {
attrs[key] = false; // To be compatible with QA integration testing. Some
return false; // To be compatible with QA integration testing. Some
// test case pass in None.
} else if (obj == Py_True) {
attrs[key] = true;
return true;
} else if (obj == Py_False) {
attrs[key] = false;
return false;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -118,14 +116,20 @@ void CastPyArg2AttrBoolean(PyObject* obj,
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return false;
}
void CastPyArg2AttrInt(PyObject* obj,
void CastPyArg2AttrBoolean(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Boolean(obj, op_type, arg_pos);
}
int CastPyArg2Int(PyObject* obj, const std::string& op_type, ssize_t arg_pos) {
if (PyObject_CheckLongOrToLong(&obj)) {
attrs[key] = (int)PyLong_AsLong(obj); // NOLINT
return (int)PyLong_AsLong(obj); // NOLINT
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -133,14 +137,21 @@ void CastPyArg2AttrInt(PyObject* obj,
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return 0;
}
void CastPyArg2AttrLong(PyObject* obj,
void CastPyArg2AttrInt(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Int(obj, op_type, arg_pos);
}
int64_t CastPyArg2Long(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
if (PyObject_CheckLongOrToLong(&obj)) {
attrs[key] = (int64_t)PyLong_AsLong(obj); // NOLINT
return (int64_t)PyLong_AsLong(obj); // NOLINT
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -148,14 +159,21 @@ void CastPyArg2AttrLong(PyObject* obj,
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return 0;
}
void CastPyArg2AttrFloat(PyObject* obj,
void CastPyArg2AttrLong(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Long(obj, op_type, arg_pos);
}
float CastPyArg2Float(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
if (PyObject_CheckFloatOrToFloat(&obj)) {
attrs[key] = (float)PyFloat_AsDouble(obj); // NOLINT
return (float)PyFloat_AsDouble(obj); // NOLINT
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -163,17 +181,24 @@ void CastPyArg2AttrFloat(PyObject* obj,
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return 0.0;
}
void CastPyArg2AttrString(PyObject* obj,
void CastPyArg2AttrFloat(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Float(obj, op_type, arg_pos);
}
std::string CastPyArg2String(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
if (PyObject_CheckString(obj)) {
Py_ssize_t size;
const char* data;
data = PyUnicode_AsUTF8AndSize(obj, &size);
attrs[key] = std::string(data, (size_t)size); // NOLINT
return std::string(data, (size_t)size); // NOLINT
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -181,16 +206,23 @@ void CastPyArg2AttrString(PyObject* obj,
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return "";
}
void CastPyArg2AttrBooleans(PyObject* obj,
void CastPyArg2AttrString(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2String(obj, op_type, arg_pos);
}
std::vector<bool> CastPyArg2Booleans(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
std::vector<bool> value;
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<bool> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckBool(&item)) {
......@@ -204,11 +236,9 @@ void CastPyArg2AttrBooleans(PyObject* obj,
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<bool> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckBool(&item)) {
......@@ -222,7 +252,6 @@ void CastPyArg2AttrBooleans(PyObject* obj,
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -230,16 +259,23 @@ void CastPyArg2AttrBooleans(PyObject* obj,
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return value;
}
void CastPyArg2AttrInts(PyObject* obj,
void CastPyArg2AttrBooleans(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Booleans(obj, op_type, arg_pos);
}
std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
std::vector<int> value;
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<int> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
......@@ -253,11 +289,9 @@ void CastPyArg2AttrInts(PyObject* obj,
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<int> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
......@@ -271,11 +305,9 @@ void CastPyArg2AttrInts(PyObject* obj,
i));
}
}
attrs[key] = value;
} else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr;
std::vector<int> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
......@@ -289,7 +321,6 @@ void CastPyArg2AttrInts(PyObject* obj,
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -297,16 +328,23 @@ void CastPyArg2AttrInts(PyObject* obj,
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return value;
}
void CastPyArg2AttrLongs(PyObject* obj,
void CastPyArg2AttrInts(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Ints(obj, op_type, arg_pos);
}
std::vector<int64_t> CastPyArg2Longs(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
std::vector<int64_t> value;
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<int64_t> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
......@@ -320,11 +358,9 @@ void CastPyArg2AttrLongs(PyObject* obj,
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<int64_t> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
......@@ -338,11 +374,9 @@ void CastPyArg2AttrLongs(PyObject* obj,
i));
}
}
attrs[key] = value;
} else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr;
std::vector<int64_t> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) {
......@@ -356,7 +390,6 @@ void CastPyArg2AttrLongs(PyObject* obj,
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -364,16 +397,23 @@ void CastPyArg2AttrLongs(PyObject* obj,
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return value;
}
void CastPyArg2AttrFloats(PyObject* obj,
void CastPyArg2AttrLongs(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Longs(obj, op_type, arg_pos);
}
std::vector<float> CastPyArg2Floats(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
std::vector<float> value;
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<float> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
......@@ -387,11 +427,9 @@ void CastPyArg2AttrFloats(PyObject* obj,
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<float> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
......@@ -405,11 +443,9 @@ void CastPyArg2AttrFloats(PyObject* obj,
i));
}
}
attrs[key] = value;
} else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr;
std::vector<float> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
......@@ -423,7 +459,6 @@ void CastPyArg2AttrFloats(PyObject* obj,
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -431,16 +466,24 @@ void CastPyArg2AttrFloats(PyObject* obj,
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return value;
}
void CastPyArg2AttrFloat64s(PyObject* obj,
void CastPyArg2AttrFloats(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Floats(obj, op_type, arg_pos);
}
std::vector<double> CastPyArg2Float64s(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos) {
std::vector<double> value;
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<double> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
......@@ -454,11 +497,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<double> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
......@@ -472,11 +513,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
i));
}
}
attrs[key] = value;
} else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr;
std::vector<double> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) {
......@@ -490,7 +529,6 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -498,16 +536,24 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return value;
}
void CastPyArg2AttrStrings(PyObject* obj,
void CastPyArg2AttrFloat64s(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Float64s(obj, op_type, arg_pos);
}
std::vector<std::string> CastPyArg2Strings(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos) {
std::vector<std::string> value;
if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr;
std::vector<std::string> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i);
if (PyObject_CheckString(item)) {
......@@ -524,11 +570,9 @@ void CastPyArg2AttrStrings(PyObject* obj,
i));
}
}
attrs[key] = value;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
std::vector<std::string> value;
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (PyObject_CheckString(item)) {
......@@ -545,7 +589,6 @@ void CastPyArg2AttrStrings(PyObject* obj,
i));
}
}
attrs[key] = value;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
......@@ -553,6 +596,15 @@ void CastPyArg2AttrStrings(PyObject* obj,
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return value;
}
void CastPyArg2AttrStrings(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Strings(obj, op_type, arg_pos);
}
void CastPyArg2AttrBlock(PyObject* obj,
......
......@@ -43,6 +43,30 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj);
bool PyObject_CheckString(PyObject* obj);
bool CastPyArg2Boolean(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
int CastPyArg2Int(PyObject* obj, const std::string& op_type, ssize_t arg_pos);
int64_t CastPyArg2Long(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
float CastPyArg2Float(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::string CastPyArg2String(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::vector<bool> CastPyArg2Booleans(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::vector<int64_t> CastPyArg2Longs(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::vector<float> CastPyArg2Floats(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::vector<double> CastPyArg2Float64s(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);
std::vector<std::string> CastPyArg2Strings(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);
void CastPyArg2AttrBoolean(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册