diff --git a/paddle/fluid/eager/api/generated/eager_generated/backwards/CMakeLists.txt b/paddle/fluid/eager/api/generated/eager_generated/backwards/CMakeLists.txt index e04d282748c0a6061ccb7d9429f0066c7fabf3ca..8f4c2c3660326012b49f04f1a6177f38efa92230 100644 --- a/paddle/fluid/eager/api/generated/eager_generated/backwards/CMakeLists.txt +++ b/paddle/fluid/eager/api/generated/eager_generated/backwards/CMakeLists.txt @@ -1,3 +1,3 @@ cc_library(scale_node SRCS scale_node.cc DEPS global_utils pten pten_api grad_node_info) -#cc_library(final_dygraph_node SRCS nodes.cc DEPS ${eager_deps}) -#add_dependencies(final_dygraph_node eager_final_state_codegen) +cc_library(final_dygraph_node SRCS nodes.cc DEPS ${eager_deps}) +add_dependencies(final_dygraph_node eager_final_state_codegen) diff --git a/paddle/fluid/eager/api/generated/eager_generated/forwards/CMakeLists.txt b/paddle/fluid/eager/api/generated/eager_generated/forwards/CMakeLists.txt index f682c27992db15e81f28afe0bb9c3b30454a9d88..1187136526589015fcb941ceb1ce79ffda399acd 100644 --- a/paddle/fluid/eager/api/generated/eager_generated/forwards/CMakeLists.txt +++ b/paddle/fluid/eager/api/generated/eager_generated/forwards/CMakeLists.txt @@ -1,3 +1,3 @@ cc_library(eager_scale SRCS scale.cc DEPS pten_api pten autograd_meta scale_node) -#cc_library(final_dygraph_function SRCS dygraph_functions.cc DEPS ${eager_deps}) -#add_dependencies(final_dygraph_function eager_final_state_codegen) +cc_library(final_dygraph_function SRCS dygraph_functions.cc DEPS ${eager_deps}) +add_dependencies(final_dygraph_function eager_final_state_codegen) diff --git a/paddle/fluid/eager/auto_code_generator/CMakeLists.txt b/paddle/fluid/eager/auto_code_generator/CMakeLists.txt index c504a126ddecaebfcb55313573d6bc490007feef..668e60d857b9ca371243891db686421810fda0bb 100644 --- a/paddle/fluid/eager/auto_code_generator/CMakeLists.txt +++ b/paddle/fluid/eager/auto_code_generator/CMakeLists.txt @@ -1,4 +1,4 @@ -#add_subdirectory(final_state_generator) +add_subdirectory(final_state_generator) set(EAGER_GENERETOR_DEPS ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} pybind proto_desc executor layer tracer engine imperative_profiler imperative_flag) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/CMakeLists.txt b/paddle/fluid/eager/auto_code_generator/final_state_generator/CMakeLists.txt index 0a96cbc9c970ca776e19cef74e18fe66016804e2..c6bca01205e19c58d5924f4e9d60bb76164fee2b 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/CMakeLists.txt +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/CMakeLists.txt @@ -24,3 +24,13 @@ add_custom_target(eager_final_state_codegen COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_nodes_h_path} ${nodes_h_path} VERBATIM ) + +set(tmp_python_c_output_path "${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/tmp_eager_final_state_op_function_impl.h") +set(python_c_output_path "${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/eager_final_state_op_function_impl.h") +add_custom_target(eager_final_state_python_c_codegen + COMMAND "${PYTHON_EXECUTABLE}" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py" + "--api_yaml_path=${api_yaml_path}" + "--output_path=${tmp_python_c_output_path}" + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_python_c_output_path} ${python_c_output_path} + VERBATIM +) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index 4782ca6b3b0e5f07130e84e1501a4985e364ed95..4c9372a0b6c888d1fcf678a5b2a07317aa2dc907 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -82,6 +82,14 @@ def RemoveConstAndReference(string): return ret +def GetGradNodeName(string): + return f"FinalGradNode{string}" + + +def GetForwardFunctionName(string): + return f"{string}_final_state_dygraph_function" + + def GetAutoGradMetaName(string): return f"{string}_autograd_meta" @@ -145,13 +153,13 @@ def ParseYamlArgs(string): def ParseYamlReturns(string): # Example: Tensor, Tensor - # list = [ [ret_type, orig_position], ...] + # list = [ ["", ret_type, orig_position], ...] returns_list = [] returns = [x.strip() for x in string.strip().split(",")] for i in range(len(returns)): ret = returns[i] - returns_list.append([ret, i]) + returns_list.append(["", ret, i]) return returns_list @@ -260,8 +268,8 @@ def ForwardsValidationCheck(forward_inputs_list, forward_attrs_list, assert orig_attr_pos == forward_attr_pos for i in range(len(forward_returns_list)): - orig_return_type = orig_forward_returns_list[i][0] - orig_return_pos = orig_forward_returns_list[i][1] + orig_return_type = orig_forward_returns_list[i][1] + orig_return_pos = orig_forward_returns_list[i][2] forward_return_type = forward_returns_list[i][1] forward_return_pos = forward_returns_list[i][2] @@ -452,13 +460,14 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map, RemoveConstAndReference(atype), saved_attr_name, default_val) # End: SetAttributes & Attribute Members + grad_node_name = GetGradNodeName(fwd_api_name) NODE_DECLARATION_TEMPLATE = """ -class GradNode{} : public egr::GradNodeBase {{ +class {} : public egr::GradNodeBase {{ public: - GradNode{}() : egr::GradNodeBase() {{}} - GradNode{}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) : + {}() : egr::GradNodeBase() {{}} + {}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) : egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {{}} - ~GradNode{}() override = default; + ~{}() override = default; virtual std::vector> operator()( const std::vector>& grads) override; @@ -476,7 +485,7 @@ class GradNode{} : public egr::GradNodeBase {{ }}; """ node_declaration_str = NODE_DECLARATION_TEMPLATE.format( - forward_op_name, forward_op_name, forward_op_name, forward_op_name, + grad_node_name, grad_node_name, grad_node_name, grad_node_name, set_tensor_wrapper_methods_str, set_attribute_methods_str, tensor_wrapper_members_str, attribute_members_str) @@ -503,10 +512,15 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map, grad_api_args[ grad_api_position] = f"egr::EagerUtils::SyncToPtenTensors( egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr) )" - for _, (_, fwd_position, + for _, (ttype, fwd_position, grad_api_position) in backward_grad_input_map.items(): - grad_api_args[ - grad_api_position] = f"egr::EagerUtils::SyncToPtenTensors( grads[{fwd_position}] )" + if IsPlainTensorType(ttype): + grad_api_args[ + grad_api_position] = f"egr::EagerUtils::SyncToPtenTensors( grads[{fwd_position}][0] )" + else: + assert IsVectorTensorType(ttype) + grad_api_args[ + grad_api_position] = f"egr::EagerUtils::SyncToPtenTensors( grads[{fwd_position}] )" for name, _, _, grad_api_position in backward_attrs_list: saved_attribute_name = GetSavedName(name) @@ -531,8 +545,9 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map, returns_str += f"returns[{fwd_position}] = egr::EagerUtils::CreateEagerTensorFromTensor( grad_api_returns[{grad_api_position}] );\n" returns_str += f"return returns;\n" + grad_node_name = GetGradNodeName(fwd_api_name) FUNCTION_TEMPLATE = """ -std::vector> GradNode{}::operator()(const std::vector>& grads) {{ +std::vector> {}::operator()(const std::vector>& grads) {{ // Call grad_api function auto grad_api_returns = paddle::experimental::{}({}); {} @@ -540,7 +555,7 @@ std::vector> GradNode{}::operator()(const std::vec """ node_definition_str = FUNCTION_TEMPLATE.format( - fwd_api_name, bwd_api_name, grad_api_args_str, returns_str) + grad_node_name, bwd_api_name, grad_api_args_str, returns_str) return node_definition_str @@ -610,7 +625,8 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name, # Node Construction num_bwd_inputs = len(backward_grad_input_map.keys()) num_bwd_outputs = len(backward_grad_output_map.keys()) - node_construction_str = f" auto grad_node = std::make_shared({num_bwd_inputs}, {num_bwd_outputs});" + grad_node_name = GetGradNodeName(fwd_api_name) + node_construction_str = f" auto grad_node = std::make_shared<{grad_node_name}>({num_bwd_inputs}, {num_bwd_outputs});" # SetAttributes set_attributes_list = [] @@ -786,7 +802,7 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name, backward_grad_output_map, backward_attrs_list) FORWARD_FUNCTION_TEMPLATE = """ -{} {}_dygraph_function({}) {{ +{} {}({}) {{ // Forward API Call {} @@ -799,15 +815,34 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name, }} """ + forward_function_name = GetForwardFunctionName(fwd_api_name) forward_function_str = FORWARD_FUNCTION_TEMPLATE.format( - returns_type_str, fwd_api_name, inputs_args_str, forward_call_str, - returns_str, node_creation_str) - - forward_function_declaration_str = f"{returns_type_str} {fwd_api_name}_dygraph_function({inputs_args_str});" + returns_type_str, forward_function_name, inputs_args_str, + forward_call_str, returns_str, node_creation_str) + forward_function_declaration_str = f"{returns_type_str} {forward_function_name}({inputs_args_str});" return forward_function_str, forward_function_declaration_str +def FakeMatmulGradAPI(): + fake_matmul_grad_str = """ +namespace paddle { +namespace experimental { + std::vector> matmul_grad(const Tensor& x, + const Tensor& y, + const Tensor& out_grad, + bool transpose_x, + bool transpose_y) { + std::vector> ret; + return ret; + } +} +} + +""" + return fake_matmul_grad_str + + def GenerateNodeCCFile(filepath, node_definition_str): file_contents = """ #include "glog/logging.h" @@ -819,6 +854,7 @@ def GenerateNodeCCFile(filepath, node_definition_str): #include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h" """ + file_contents += FakeMatmulGradAPI() file_contents += node_definition_str with open(filepath, 'a') as f: f.write(file_contents) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..ea3837914d529219799cbcbc14d922c32c900b07 --- /dev/null +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py @@ -0,0 +1,230 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse +from eager_gen import ReadFwdFile, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap + +atype_to_parsing_function = { + "bool": "CastPyArg2Boolean", + "int": "CastPyArg2Int", + "long": "CastPyArg2Long", + "float": "CastPyArg2Float", + "string": "CastPyArg2String", + "bool[]": "CastPyArg2Booleans", + "int[]": "CastPyArg2Ints", + "long[]": "CastPyArg2Longs", + "float[]": "CastPyArg2Floats", + "double[]": "CastPyArg2Float64s", + "string[]": "CastPyArg2Strings" +} + +atype_to_cxx_type = { + "bool": "bool", + "int": "int", + "long": "long", + "float": "float", + "string": "std::string", + "bool[]": "std::vector", + "int[]": "std::vector", + "long[]": "std::vector", + "float[]": "std::vector", + "double[]": "std::vector", + "string[]": "std::vector" +} + + +def ParseArguments(): + parser = argparse.ArgumentParser( + description='Eager Code Generator Args Parser') + parser.add_argument('--api_yaml_path', type=str) + parser.add_argument('--output_path', type=str) + + args = parser.parse_args() + return args + + +def GetCxxType(atype): + if atype not in atype_to_cxx_type.keys(): + assert False + + return atype_to_cxx_type[atype] + + +def FindParsingFunctionFromAttributeType(atype): + if atype not in atype_to_parsing_function.keys(): + assert False + + return atype_to_parsing_function[atype] + + +def GeneratePythonCFunction(fwd_api_name, forward_inputs_position_map, + forward_attrs_list, forward_outputs_position_map): + # forward_inputs_position_map = { "name" : [type, fwd_position] } + # forward_outputs_position_map = { "name" : [type, fwd_position] } + # forward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...] + + # Get EagerTensor from args + # Get dygraph function call args + num_args = len(forward_inputs_position_map.keys()) + len(forward_attrs_list) + num_input_tensors = len(forward_inputs_position_map.keys()) + dygraph_function_call_list = ["" for i in range(num_args)] + get_eager_tensor_str = "" + for name, (ttype, pos) in forward_inputs_position_map.items(): + get_eager_tensor_str += f" auto& {name} = GetEagerTensorFromArgs(\"{fwd_api_name}\", \"{name}\", args, {pos}, false);\n" + dygraph_function_call_list[pos] = f"{name}" + + parse_attributes_str = " paddle::framework::AttributeMap attrs;\n" + # Get Attributes + for name, atype, _, pos in forward_attrs_list: + parsing_function = FindParsingFunctionFromAttributeType(atype) + cxx_type = GetCxxType(atype) + key = f"{name}" + + parse_attributes_str += f" PyObject* {name}_obj = PyTuple_GET_ITEM(args, {pos});\n" + parse_attributes_str += f" {cxx_type} {name} = {parsing_function}({name}_obj, \"{fwd_api_name}\", {pos});\n" + + dygraph_function_call_list[pos] = f"{name}" + dygraph_function_call_str = ",".join(dygraph_function_call_list) + + PYTHON_C_FUNCTION_TEMPLATE = """ +static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObject *kwargs) +{{ + PyThreadState *tstate = nullptr; + try + {{ + // Get EagerTensors from args +{} + + // Parse Attributes +{} + + tstate = PyEval_SaveThread(); + + auto out = {}({}); + + PyEval_RestoreThread(tstate); + tstate = nullptr; + return ToPyObject(out); + }} + catch(...) {{ + if (tstate) {{ + PyEval_RestoreThread(tstate); + }} + ThrowExceptionToPython(std::current_exception()); + return nullptr; + }} +}} + +""" + python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format( + fwd_api_name, get_eager_tensor_str, parse_attributes_str, + GetForwardFunctionName(fwd_api_name), dygraph_function_call_str) + + python_c_function_reg_str = f"{{\"final_state_{fwd_api_name}\", (PyCFunction)(void(*)(void))eager_final_state_api_{fwd_api_name}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {fwd_api_name} in dygraph.\"}}" + + return python_c_function_str, python_c_function_reg_str + + +def GeneratePythonCWrappers(python_c_function_str, python_c_function_reg_str): + + PYTHON_C_WRAPPER_TEMPLATE = """ +#pragma once + +#include "pybind11/detail/common.h" +#include "paddle/fluid/pybind/op_function_common.h" +#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" +#include "paddle/fluid/pybind/exception.h" +#include + +namespace paddle {{ +namespace pybind {{ + +{} + +static PyMethodDef EagerFinalStateMethods[] = {{ + {} +}}; + +}} // namespace pybind +}} // namespace paddle + +""" + python_c_str = PYTHON_C_WRAPPER_TEMPLATE.format(python_c_function_str, + python_c_function_reg_str) + + return python_c_str + + +def GeneratePythonCFile(filepath, python_c_str): + with open(filepath, 'a') as f: + f.write(python_c_str) + + +if __name__ == "__main__": + args = ParseArguments() + + api_yaml_path = args.api_yaml_path + fwd_api_list = ReadFwdFile(api_yaml_path) + + python_c_function_list = [] + python_c_function_reg_list = [] + for fwd_api in fwd_api_list: + # We only generate Ops with grad + if 'backward' not in fwd_api.keys(): + continue + + assert 'api' in fwd_api.keys() + assert 'args' in fwd_api.keys() + assert 'output' in fwd_api.keys() + assert 'backward' in fwd_api.keys() + + fwd_api_name = fwd_api['api'] + fwd_args_str = fwd_api['args'] + fwd_returns_str = fwd_api['output'] + + # Collect Original Forward Inputs/Outputs and then perform validation checks + forward_inputs_list, forward_attrs_list, forward_returns_list = ParseYamlForward( + fwd_args_str, fwd_returns_str) + print("Parsed Original Forward Inputs List: ", forward_inputs_list) + print("Prased Original Forward Attrs List: ", forward_attrs_list) + print("Parsed Original Forward Returns List: ", forward_returns_list) + + forward_inputs_position_map, forward_outputs_position_map = DetermineForwardPositionMap( + forward_inputs_list, forward_returns_list) + print("Generated Forward Input Position Map: ", + forward_inputs_position_map) + print("Generated Forward Output Position Map: ", + forward_outputs_position_map) + + python_c_function_str, python_c_function_reg_str = GeneratePythonCFunction( + fwd_api_name, forward_inputs_position_map, forward_attrs_list, + forward_outputs_position_map) + python_c_function_list.append(python_c_function_str) + python_c_function_reg_list.append(python_c_function_reg_str) + print("Generated Python-C Function: ", python_c_function_str) + + python_c_functions_str = "\n".join(python_c_function_list) + python_c_functions_reg_str = ",\n".join(python_c_function_reg_list) + + python_c_str = GeneratePythonCWrappers(python_c_functions_str, + python_c_functions_reg_str) + print("Generated Python-C Codes: ", python_c_str) + + output_path = args.output_path + for path in [output_path]: + if os.path.exists(path): + os.remove(path) + + GeneratePythonCFile(output_path, python_c_str) diff --git a/paddle/fluid/eager/utils.cc b/paddle/fluid/eager/utils.cc index 962f866456579f71fc6df5a72fc070a266f28291..349a9d18474e1b725fa54d401133019c49896582 100644 --- a/paddle/fluid/eager/utils.cc +++ b/paddle/fluid/eager/utils.cc @@ -288,7 +288,9 @@ void EagerUtils::CheckAndRetainGrad( paddle::experimental::Tensor EagerUtils::SyncToPtenTensors( const egr::EagerTensor& tensor) { - const_cast(&tensor)->SyncToTensor(); + if (!tensor.initialized()) { + const_cast(&tensor)->SyncToTensor(); + } return *tensor.Tensor().get(); } @@ -298,7 +300,9 @@ std::vector EagerUtils::SyncToPtenTensors( size_t num = tensors.size(); res.reserve(num); for (size_t i = 0; i < num; i++) { - const_cast(&(tensors[i]))->SyncToTensor(); + if (!tensors[i].initialized()) { + const_cast(&(tensors[i]))->SyncToTensor(); + } res.push_back(*tensors[i].Tensor().get()); } return res; diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 1df77c78a419bb5c99a06a327b3309fdf3c7e6f2..4feba4ab19b785491bc611b00b1749f253433b29 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -151,7 +151,7 @@ if(WITH_PYTHON) set(tmp_eager_impl_file ${eager_impl_file}.tmp) set(OP_IMPL_DEPS op_function_generator) - set(EAGER_OP_IMPL_DEPS eager_op_function_generator) + set(EAGER_OP_IMPL_DEPS eager_op_function_generator eager_final_state_python_c_codegen) if(WIN32) if("${CMAKE_GENERATOR}" STREQUAL "Ninja") @@ -275,7 +275,7 @@ if(WITH_PYTHON) if(NOT ON_INFER) cc_library(paddle_eager SRCS eager.cc eager_functions.cc eager_method.cc eager_properties.cc eager_utils.cc - DEPS eager_api autograd_meta backward grad_node_info pten op_function_common dygraph_function dygraph_node accumulation_node global_utils utils python) + DEPS eager_api autograd_meta backward grad_node_info pten op_function_common final_dygraph_function final_dygraph_node dygraph_function dygraph_node accumulation_node global_utils utils python) add_dependencies(paddle_eager eager_codegen) add_dependencies(paddle_eager eager_op_function_generator_cmd) list(APPEND PYBIND_DEPS paddle_eager) diff --git a/paddle/fluid/pybind/eager_op_function_generator.cc b/paddle/fluid/pybind/eager_op_function_generator.cc index 090604ab4ee1a15cfca9928513250c8fe0325ea9..34acff7efd19d7fd2e81cdc31a05b5a32156ad4d 100644 --- a/paddle/fluid/pybind/eager_op_function_generator.cc +++ b/paddle/fluid/pybind/eager_op_function_generator.cc @@ -393,6 +393,7 @@ int main(int argc, char* argv[]) { std::vector headers{ "\"pybind11/detail/common.h\"", + "\"paddle/fluid/pybind/eager_final_state_op_function_impl.h\"", "\"paddle/fluid/pybind/op_function_common.h\"", "\"paddle/fluid/eager/api/generated/fluid_generated/" "dygraph_forward_api.h\"", @@ -441,6 +442,10 @@ int main(int argc, char* argv[]) { << " PADDLE_THROW(platform::errors::Fatal (\"Add functions to " "core.eager.ops failed!\"));\n" << " }\n\n" + << " if (PyModule_AddFunctions(m.ptr(), EagerFinalStateMethods) < 0) {\n" + << " PADDLE_THROW(platform::errors::Fatal (\"Add functions to " + "core.eager.ops failed!\"));\n" + << " }\n\n" << "}\n\n" << "} // namespace pybind\n" << "} // namespace paddle\n"; diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc index 3ad4994a590f7284da25287c6ce1fa90840ee1aa..09c3cea398b2aec4d7cf0953ffb0aed75de37601 100644 --- a/paddle/fluid/pybind/op_function_common.cc +++ b/paddle/fluid/pybind/op_function_common.cc @@ -100,17 +100,15 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) { bool PyObject_CheckString(PyObject* obj) { return PyUnicode_Check(obj); } -void CastPyArg2AttrBoolean(PyObject* obj, - paddle::framework::AttributeMap& attrs, // NOLINT - const std::string& key, const std::string& op_type, - ssize_t arg_pos) { +bool CastPyArg2Boolean(PyObject* obj, const std::string& op_type, + ssize_t arg_pos) { if (obj == Py_None) { - attrs[key] = false; // To be compatible with QA integration testing. Some - // test case pass in None. + return false; // To be compatible with QA integration testing. Some + // test case pass in None. } else if (obj == Py_True) { - attrs[key] = true; + return true; } else if (obj == Py_False) { - attrs[key] = false; + return false; } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " @@ -118,62 +116,89 @@ void CastPyArg2AttrBoolean(PyObject* obj, op_type, arg_pos + 1, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT } + + return false; +} + +void CastPyArg2AttrBoolean(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, const std::string& op_type, + ssize_t arg_pos) { + attrs[key] = CastPyArg2Boolean(obj, op_type, arg_pos); +} + +int CastPyArg2Int(PyObject* obj, const std::string& op_type, ssize_t arg_pos) { + if (PyObject_CheckLongOrToLong(&obj)) { + return (int)PyLong_AsLong(obj); // NOLINT + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument (position %d) must be " + "int, but got %s", + op_type, arg_pos + 1, + ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT + } + + return 0; } void CastPyArg2AttrInt(PyObject* obj, paddle::framework::AttributeMap& attrs, // NOLINT const std::string& key, const std::string& op_type, ssize_t arg_pos) { + attrs[key] = CastPyArg2Int(obj, op_type, arg_pos); +} + +int64_t CastPyArg2Long(PyObject* obj, const std::string& op_type, + ssize_t arg_pos) { if (PyObject_CheckLongOrToLong(&obj)) { - attrs[key] = (int)PyLong_AsLong(obj); // NOLINT + return (int64_t)PyLong_AsLong(obj); // NOLINT } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " - "int, but got %s", + "long, but got %s", op_type, arg_pos + 1, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT } + + return 0; } void CastPyArg2AttrLong(PyObject* obj, paddle::framework::AttributeMap& attrs, // NOLINT const std::string& key, const std::string& op_type, ssize_t arg_pos) { - if (PyObject_CheckLongOrToLong(&obj)) { - attrs[key] = (int64_t)PyLong_AsLong(obj); // NOLINT + attrs[key] = CastPyArg2Long(obj, op_type, arg_pos); +} + +float CastPyArg2Float(PyObject* obj, const std::string& op_type, + ssize_t arg_pos) { + if (PyObject_CheckFloatOrToFloat(&obj)) { + return (float)PyFloat_AsDouble(obj); // NOLINT } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " - "long, but got %s", + "float, but got %s", op_type, arg_pos + 1, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT } + + return 0.0; } void CastPyArg2AttrFloat(PyObject* obj, paddle::framework::AttributeMap& attrs, // NOLINT const std::string& key, const std::string& op_type, ssize_t arg_pos) { - if (PyObject_CheckFloatOrToFloat(&obj)) { - attrs[key] = (float)PyFloat_AsDouble(obj); // NOLINT - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "%s(): argument (position %d) must be " - "float, but got %s", - op_type, arg_pos + 1, - ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT - } + attrs[key] = CastPyArg2Float(obj, op_type, arg_pos); } -void CastPyArg2AttrString(PyObject* obj, - paddle::framework::AttributeMap& attrs, // NOLINT - const std::string& key, const std::string& op_type, - ssize_t arg_pos) { +std::string CastPyArg2String(PyObject* obj, const std::string& op_type, + ssize_t arg_pos) { if (PyObject_CheckString(obj)) { Py_ssize_t size; const char* data; data = PyUnicode_AsUTF8AndSize(obj, &size); - attrs[key] = std::string(data, (size_t)size); // NOLINT + return std::string(data, (size_t)size); // NOLINT } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " @@ -181,16 +206,23 @@ void CastPyArg2AttrString(PyObject* obj, op_type, arg_pos + 1, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT } + + return ""; } -void CastPyArg2AttrBooleans(PyObject* obj, - paddle::framework::AttributeMap& attrs, // NOLINT - const std::string& key, const std::string& op_type, - ssize_t arg_pos) { +void CastPyArg2AttrString(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, const std::string& op_type, + ssize_t arg_pos) { + attrs[key] = CastPyArg2String(obj, op_type, arg_pos); +} + +std::vector CastPyArg2Booleans(PyObject* obj, const std::string& op_type, + ssize_t arg_pos) { + std::vector value; if (PyList_Check(obj)) { Py_ssize_t len = PyList_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyList_GetItem(obj, i); if (PyObject_CheckBool(&item)) { @@ -204,11 +236,9 @@ void CastPyArg2AttrBooleans(PyObject* obj, i)); } } - attrs[key] = value; } else if (PyTuple_Check(obj)) { Py_ssize_t len = PyTuple_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyTuple_GetItem(obj, i); if (PyObject_CheckBool(&item)) { @@ -222,7 +252,6 @@ void CastPyArg2AttrBooleans(PyObject* obj, i)); } } - attrs[key] = value; } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " @@ -230,16 +259,23 @@ void CastPyArg2AttrBooleans(PyObject* obj, op_type, arg_pos + 1, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT } + + return value; } -void CastPyArg2AttrInts(PyObject* obj, - paddle::framework::AttributeMap& attrs, // NOLINT - const std::string& key, const std::string& op_type, - ssize_t arg_pos) { +void CastPyArg2AttrBooleans(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, const std::string& op_type, + ssize_t arg_pos) { + attrs[key] = CastPyArg2Booleans(obj, op_type, arg_pos); +} + +std::vector CastPyArg2Ints(PyObject* obj, const std::string& op_type, + ssize_t arg_pos) { + std::vector value; if (PyList_Check(obj)) { Py_ssize_t len = PyList_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyList_GetItem(obj, i); if (PyObject_CheckLongOrToLong(&item)) { @@ -253,11 +289,9 @@ void CastPyArg2AttrInts(PyObject* obj, i)); } } - attrs[key] = value; } else if (PyTuple_Check(obj)) { Py_ssize_t len = PyTuple_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyTuple_GetItem(obj, i); if (PyObject_CheckLongOrToLong(&item)) { @@ -271,11 +305,9 @@ void CastPyArg2AttrInts(PyObject* obj, i)); } } - attrs[key] = value; } else if (PySequence_Check(obj)) { Py_ssize_t len = PySequence_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PySequence_GetItem(obj, i); if (PyObject_CheckLongOrToLong(&item)) { @@ -289,7 +321,6 @@ void CastPyArg2AttrInts(PyObject* obj, i)); } } - attrs[key] = value; } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " @@ -297,16 +328,23 @@ void CastPyArg2AttrInts(PyObject* obj, op_type, arg_pos + 1, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT } + + return value; } -void CastPyArg2AttrLongs(PyObject* obj, - paddle::framework::AttributeMap& attrs, // NOLINT - const std::string& key, const std::string& op_type, - ssize_t arg_pos) { +void CastPyArg2AttrInts(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, const std::string& op_type, + ssize_t arg_pos) { + attrs[key] = CastPyArg2Ints(obj, op_type, arg_pos); +} + +std::vector CastPyArg2Longs(PyObject* obj, const std::string& op_type, + ssize_t arg_pos) { + std::vector value; if (PyList_Check(obj)) { Py_ssize_t len = PyList_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyList_GetItem(obj, i); if (PyObject_CheckLongOrToLong(&item)) { @@ -320,11 +358,9 @@ void CastPyArg2AttrLongs(PyObject* obj, i)); } } - attrs[key] = value; } else if (PyTuple_Check(obj)) { Py_ssize_t len = PyTuple_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyTuple_GetItem(obj, i); if (PyObject_CheckLongOrToLong(&item)) { @@ -338,11 +374,9 @@ void CastPyArg2AttrLongs(PyObject* obj, i)); } } - attrs[key] = value; } else if (PySequence_Check(obj)) { Py_ssize_t len = PySequence_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PySequence_GetItem(obj, i); if (PyObject_CheckLongOrToLong(&item)) { @@ -356,7 +390,6 @@ void CastPyArg2AttrLongs(PyObject* obj, i)); } } - attrs[key] = value; } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " @@ -364,16 +397,23 @@ void CastPyArg2AttrLongs(PyObject* obj, op_type, arg_pos + 1, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT } + + return value; } -void CastPyArg2AttrFloats(PyObject* obj, - paddle::framework::AttributeMap& attrs, // NOLINT - const std::string& key, const std::string& op_type, - ssize_t arg_pos) { +void CastPyArg2AttrLongs(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, const std::string& op_type, + ssize_t arg_pos) { + attrs[key] = CastPyArg2Longs(obj, op_type, arg_pos); +} + +std::vector CastPyArg2Floats(PyObject* obj, const std::string& op_type, + ssize_t arg_pos) { + std::vector value; if (PyList_Check(obj)) { Py_ssize_t len = PyList_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyList_GetItem(obj, i); if (PyObject_CheckFloatOrToFloat(&item)) { @@ -387,11 +427,9 @@ void CastPyArg2AttrFloats(PyObject* obj, i)); } } - attrs[key] = value; } else if (PyTuple_Check(obj)) { Py_ssize_t len = PyTuple_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyTuple_GetItem(obj, i); if (PyObject_CheckFloatOrToFloat(&item)) { @@ -405,11 +443,9 @@ void CastPyArg2AttrFloats(PyObject* obj, i)); } } - attrs[key] = value; } else if (PySequence_Check(obj)) { Py_ssize_t len = PySequence_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PySequence_GetItem(obj, i); if (PyObject_CheckFloatOrToFloat(&item)) { @@ -423,7 +459,6 @@ void CastPyArg2AttrFloats(PyObject* obj, i)); } } - attrs[key] = value; } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " @@ -431,16 +466,24 @@ void CastPyArg2AttrFloats(PyObject* obj, op_type, arg_pos + 1, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT } + + return value; } -void CastPyArg2AttrFloat64s(PyObject* obj, - paddle::framework::AttributeMap& attrs, // NOLINT - const std::string& key, const std::string& op_type, - ssize_t arg_pos) { +void CastPyArg2AttrFloats(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, const std::string& op_type, + ssize_t arg_pos) { + attrs[key] = CastPyArg2Floats(obj, op_type, arg_pos); +} + +std::vector CastPyArg2Float64s(PyObject* obj, + const std::string& op_type, + ssize_t arg_pos) { + std::vector value; if (PyList_Check(obj)) { Py_ssize_t len = PyList_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyList_GetItem(obj, i); if (PyObject_CheckFloatOrToFloat(&item)) { @@ -454,11 +497,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj, i)); } } - attrs[key] = value; } else if (PyTuple_Check(obj)) { Py_ssize_t len = PyTuple_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyTuple_GetItem(obj, i); if (PyObject_CheckFloatOrToFloat(&item)) { @@ -472,11 +513,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj, i)); } } - attrs[key] = value; } else if (PySequence_Check(obj)) { Py_ssize_t len = PySequence_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PySequence_GetItem(obj, i); if (PyObject_CheckFloatOrToFloat(&item)) { @@ -490,7 +529,6 @@ void CastPyArg2AttrFloat64s(PyObject* obj, i)); } } - attrs[key] = value; } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " @@ -498,16 +536,24 @@ void CastPyArg2AttrFloat64s(PyObject* obj, op_type, arg_pos + 1, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT } + + return value; } -void CastPyArg2AttrStrings(PyObject* obj, - paddle::framework::AttributeMap& attrs, // NOLINT - const std::string& key, const std::string& op_type, - ssize_t arg_pos) { +void CastPyArg2AttrFloat64s(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, const std::string& op_type, + ssize_t arg_pos) { + attrs[key] = CastPyArg2Float64s(obj, op_type, arg_pos); +} + +std::vector CastPyArg2Strings(PyObject* obj, + const std::string& op_type, + ssize_t arg_pos) { + std::vector value; if (PyList_Check(obj)) { Py_ssize_t len = PyList_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyList_GetItem(obj, i); if (PyObject_CheckString(item)) { @@ -524,11 +570,9 @@ void CastPyArg2AttrStrings(PyObject* obj, i)); } } - attrs[key] = value; } else if (PyTuple_Check(obj)) { Py_ssize_t len = PyTuple_Size(obj); PyObject* item = nullptr; - std::vector value; for (Py_ssize_t i = 0; i < len; i++) { item = PyTuple_GetItem(obj, i); if (PyObject_CheckString(item)) { @@ -545,7 +589,6 @@ void CastPyArg2AttrStrings(PyObject* obj, i)); } } - attrs[key] = value; } else { PADDLE_THROW(platform::errors::InvalidArgument( "%s(): argument (position %d) must be " @@ -553,6 +596,15 @@ void CastPyArg2AttrStrings(PyObject* obj, op_type, arg_pos + 1, ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT } + + return value; +} + +void CastPyArg2AttrStrings(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, const std::string& op_type, + ssize_t arg_pos) { + attrs[key] = CastPyArg2Strings(obj, op_type, arg_pos); } void CastPyArg2AttrBlock(PyObject* obj, diff --git a/paddle/fluid/pybind/op_function_common.h b/paddle/fluid/pybind/op_function_common.h index 9dc3a71a6ccf9433555516f3fa4e43b221c2a24a..7ead9852667252d189b1fcdecc6b4ac7b86d785f 100644 --- a/paddle/fluid/pybind/op_function_common.h +++ b/paddle/fluid/pybind/op_function_common.h @@ -43,6 +43,30 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj); bool PyObject_CheckString(PyObject* obj); +bool CastPyArg2Boolean(PyObject* obj, const std::string& op_type, + ssize_t arg_pos); +int CastPyArg2Int(PyObject* obj, const std::string& op_type, ssize_t arg_pos); +int64_t CastPyArg2Long(PyObject* obj, const std::string& op_type, + ssize_t arg_pos); +float CastPyArg2Float(PyObject* obj, const std::string& op_type, + ssize_t arg_pos); +std::string CastPyArg2String(PyObject* obj, const std::string& op_type, + ssize_t arg_pos); +std::vector CastPyArg2Booleans(PyObject* obj, const std::string& op_type, + ssize_t arg_pos); +std::vector CastPyArg2Ints(PyObject* obj, const std::string& op_type, + ssize_t arg_pos); +std::vector CastPyArg2Longs(PyObject* obj, const std::string& op_type, + ssize_t arg_pos); +std::vector CastPyArg2Floats(PyObject* obj, const std::string& op_type, + ssize_t arg_pos); +std::vector CastPyArg2Float64s(PyObject* obj, + const std::string& op_type, + ssize_t arg_pos); +std::vector CastPyArg2Strings(PyObject* obj, + const std::string& op_type, + ssize_t arg_pos); + void CastPyArg2AttrBoolean(PyObject* obj, paddle::framework::AttributeMap& attrs, // NOLINT const std::string& key, const std::string& op_type,