提交 62b15566 编写于 作者: J jim19930609

Added python-c code generation for final state Eager Dygraph

上级 ca743508
cc_library(scale_node SRCS scale_node.cc DEPS global_utils pten pten_api grad_node_info) cc_library(scale_node SRCS scale_node.cc DEPS global_utils pten pten_api grad_node_info)
#cc_library(final_dygraph_node SRCS nodes.cc DEPS ${eager_deps}) cc_library(final_dygraph_node SRCS nodes.cc DEPS ${eager_deps})
#add_dependencies(final_dygraph_node eager_final_state_codegen) add_dependencies(final_dygraph_node eager_final_state_codegen)
cc_library(eager_scale SRCS scale.cc DEPS pten_api pten autograd_meta scale_node) cc_library(eager_scale SRCS scale.cc DEPS pten_api pten autograd_meta scale_node)
#cc_library(final_dygraph_function SRCS dygraph_functions.cc DEPS ${eager_deps}) cc_library(final_dygraph_function SRCS dygraph_functions.cc DEPS ${eager_deps})
#add_dependencies(final_dygraph_function eager_final_state_codegen) add_dependencies(final_dygraph_function eager_final_state_codegen)
#add_subdirectory(final_state_generator) add_subdirectory(final_state_generator)
set(EAGER_GENERETOR_DEPS ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} pybind proto_desc executor layer tracer engine imperative_profiler imperative_flag) set(EAGER_GENERETOR_DEPS ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} pybind proto_desc executor layer tracer engine imperative_profiler imperative_flag)
......
...@@ -24,3 +24,13 @@ add_custom_target(eager_final_state_codegen ...@@ -24,3 +24,13 @@ add_custom_target(eager_final_state_codegen
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_nodes_h_path} ${nodes_h_path} COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_nodes_h_path} ${nodes_h_path}
VERBATIM VERBATIM
) )
set(tmp_python_c_output_path "${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/tmp_eager_final_state_op_function_impl.h")
set(python_c_output_path "${PADDLE_SOURCE_DIR}/paddle/fluid/pybind/eager_final_state_op_function_impl.h")
add_custom_target(eager_final_state_python_c_codegen
COMMAND "${PYTHON_EXECUTABLE}" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py"
"--api_yaml_path=${api_yaml_path}"
"--output_path=${tmp_python_c_output_path}"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_python_c_output_path} ${python_c_output_path}
VERBATIM
)
...@@ -82,6 +82,14 @@ def RemoveConstAndReference(string): ...@@ -82,6 +82,14 @@ def RemoveConstAndReference(string):
return ret return ret
def GetGradNodeName(string):
return f"FinalGradNode{string}"
def GetForwardFunctionName(string):
return f"{string}_final_state_dygraph_function"
def GetAutoGradMetaName(string): def GetAutoGradMetaName(string):
return f"{string}_autograd_meta" return f"{string}_autograd_meta"
...@@ -145,13 +153,13 @@ def ParseYamlArgs(string): ...@@ -145,13 +153,13 @@ def ParseYamlArgs(string):
def ParseYamlReturns(string): def ParseYamlReturns(string):
# Example: Tensor, Tensor # Example: Tensor, Tensor
# list = [ [ret_type, orig_position], ...] # list = [ ["", ret_type, orig_position], ...]
returns_list = [] returns_list = []
returns = [x.strip() for x in string.strip().split(",")] returns = [x.strip() for x in string.strip().split(",")]
for i in range(len(returns)): for i in range(len(returns)):
ret = returns[i] ret = returns[i]
returns_list.append([ret, i]) returns_list.append(["", ret, i])
return returns_list return returns_list
...@@ -260,8 +268,8 @@ def ForwardsValidationCheck(forward_inputs_list, forward_attrs_list, ...@@ -260,8 +268,8 @@ def ForwardsValidationCheck(forward_inputs_list, forward_attrs_list,
assert orig_attr_pos == forward_attr_pos assert orig_attr_pos == forward_attr_pos
for i in range(len(forward_returns_list)): for i in range(len(forward_returns_list)):
orig_return_type = orig_forward_returns_list[i][0] orig_return_type = orig_forward_returns_list[i][1]
orig_return_pos = orig_forward_returns_list[i][1] orig_return_pos = orig_forward_returns_list[i][2]
forward_return_type = forward_returns_list[i][1] forward_return_type = forward_returns_list[i][1]
forward_return_pos = forward_returns_list[i][2] forward_return_pos = forward_returns_list[i][2]
...@@ -452,13 +460,14 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map, ...@@ -452,13 +460,14 @@ def GenerateNodeDeclaration(fwd_api_name, backward_fwd_input_map,
RemoveConstAndReference(atype), saved_attr_name, default_val) RemoveConstAndReference(atype), saved_attr_name, default_val)
# End: SetAttributes & Attribute Members # End: SetAttributes & Attribute Members
grad_node_name = GetGradNodeName(fwd_api_name)
NODE_DECLARATION_TEMPLATE = """ NODE_DECLARATION_TEMPLATE = """
class GradNode{} : public egr::GradNodeBase {{ class {} : public egr::GradNodeBase {{
public: public:
GradNode{}() : egr::GradNodeBase() {{}} {}() : egr::GradNodeBase() {{}}
GradNode{}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) : {}(size_t bwd_in_slot_num, size_t bwd_out_slot_num) :
egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {{}} egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {{}}
~GradNode{}() override = default; ~{}() override = default;
virtual std::vector<std::vector<egr::EagerTensor>> operator()( virtual std::vector<std::vector<egr::EagerTensor>> operator()(
const std::vector<std::vector<egr::EagerTensor>>& grads) override; const std::vector<std::vector<egr::EagerTensor>>& grads) override;
...@@ -476,7 +485,7 @@ class GradNode{} : public egr::GradNodeBase {{ ...@@ -476,7 +485,7 @@ class GradNode{} : public egr::GradNodeBase {{
}}; }};
""" """
node_declaration_str = NODE_DECLARATION_TEMPLATE.format( node_declaration_str = NODE_DECLARATION_TEMPLATE.format(
forward_op_name, forward_op_name, forward_op_name, forward_op_name, grad_node_name, grad_node_name, grad_node_name, grad_node_name,
set_tensor_wrapper_methods_str, set_attribute_methods_str, set_tensor_wrapper_methods_str, set_attribute_methods_str,
tensor_wrapper_members_str, attribute_members_str) tensor_wrapper_members_str, attribute_members_str)
...@@ -503,10 +512,15 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map, ...@@ -503,10 +512,15 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
grad_api_args[ grad_api_args[
grad_api_position] = f"egr::EagerUtils::SyncToPtenTensors( egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr) )" grad_api_position] = f"egr::EagerUtils::SyncToPtenTensors( egr::EagerUtils::RecoverTensorWrapper(&this->{tensor_wrapper_name}, nullptr) )"
for _, (_, fwd_position, for _, (ttype, fwd_position,
grad_api_position) in backward_grad_input_map.items(): grad_api_position) in backward_grad_input_map.items():
grad_api_args[ if IsPlainTensorType(ttype):
grad_api_position] = f"egr::EagerUtils::SyncToPtenTensors( grads[{fwd_position}] )" grad_api_args[
grad_api_position] = f"egr::EagerUtils::SyncToPtenTensors( grads[{fwd_position}][0] )"
else:
assert IsVectorTensorType(ttype)
grad_api_args[
grad_api_position] = f"egr::EagerUtils::SyncToPtenTensors( grads[{fwd_position}] )"
for name, _, _, grad_api_position in backward_attrs_list: for name, _, _, grad_api_position in backward_attrs_list:
saved_attribute_name = GetSavedName(name) saved_attribute_name = GetSavedName(name)
...@@ -531,8 +545,9 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map, ...@@ -531,8 +545,9 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
returns_str += f"returns[{fwd_position}] = egr::EagerUtils::CreateEagerTensorFromTensor( grad_api_returns[{grad_api_position}] );\n" returns_str += f"returns[{fwd_position}] = egr::EagerUtils::CreateEagerTensorFromTensor( grad_api_returns[{grad_api_position}] );\n"
returns_str += f"return returns;\n" returns_str += f"return returns;\n"
grad_node_name = GetGradNodeName(fwd_api_name)
FUNCTION_TEMPLATE = """ FUNCTION_TEMPLATE = """
std::vector<std::vector<egr::EagerTensor>> GradNode{}::operator()(const std::vector<std::vector<egr::EagerTensor>>& grads) {{ std::vector<std::vector<egr::EagerTensor>> {}::operator()(const std::vector<std::vector<egr::EagerTensor>>& grads) {{
// Call grad_api function // Call grad_api function
auto grad_api_returns = paddle::experimental::{}({}); auto grad_api_returns = paddle::experimental::{}({});
{} {}
...@@ -540,7 +555,7 @@ std::vector<std::vector<egr::EagerTensor>> GradNode{}::operator()(const std::vec ...@@ -540,7 +555,7 @@ std::vector<std::vector<egr::EagerTensor>> GradNode{}::operator()(const std::vec
""" """
node_definition_str = FUNCTION_TEMPLATE.format( node_definition_str = FUNCTION_TEMPLATE.format(
fwd_api_name, bwd_api_name, grad_api_args_str, returns_str) grad_node_name, bwd_api_name, grad_api_args_str, returns_str)
return node_definition_str return node_definition_str
...@@ -610,7 +625,8 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name, ...@@ -610,7 +625,8 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name,
# Node Construction # Node Construction
num_bwd_inputs = len(backward_grad_input_map.keys()) num_bwd_inputs = len(backward_grad_input_map.keys())
num_bwd_outputs = len(backward_grad_output_map.keys()) num_bwd_outputs = len(backward_grad_output_map.keys())
node_construction_str = f" auto grad_node = std::make_shared<GradNode{fwd_api_name}>({num_bwd_inputs}, {num_bwd_outputs});" grad_node_name = GetGradNodeName(fwd_api_name)
node_construction_str = f" auto grad_node = std::make_shared<{grad_node_name}>({num_bwd_inputs}, {num_bwd_outputs});"
# SetAttributes # SetAttributes
set_attributes_list = [] set_attributes_list = []
...@@ -786,7 +802,7 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name, ...@@ -786,7 +802,7 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
backward_grad_output_map, backward_attrs_list) backward_grad_output_map, backward_attrs_list)
FORWARD_FUNCTION_TEMPLATE = """ FORWARD_FUNCTION_TEMPLATE = """
{} {}_dygraph_function({}) {{ {} {}({}) {{
// Forward API Call // Forward API Call
{} {}
...@@ -799,15 +815,34 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name, ...@@ -799,15 +815,34 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
}} }}
""" """
forward_function_name = GetForwardFunctionName(fwd_api_name)
forward_function_str = FORWARD_FUNCTION_TEMPLATE.format( forward_function_str = FORWARD_FUNCTION_TEMPLATE.format(
returns_type_str, fwd_api_name, inputs_args_str, forward_call_str, returns_type_str, forward_function_name, inputs_args_str,
returns_str, node_creation_str) forward_call_str, returns_str, node_creation_str)
forward_function_declaration_str = f"{returns_type_str} {forward_function_name}({inputs_args_str});"
forward_function_declaration_str = f"{returns_type_str} {fwd_api_name}_dygraph_function({inputs_args_str});"
return forward_function_str, forward_function_declaration_str return forward_function_str, forward_function_declaration_str
def FakeMatmulGradAPI():
fake_matmul_grad_str = """
namespace paddle {
namespace experimental {
std::vector<std::vector<Tensor>> matmul_grad(const Tensor& x,
const Tensor& y,
const Tensor& out_grad,
bool transpose_x,
bool transpose_y) {
std::vector<std::vector<Tensor>> ret;
return ret;
}
}
}
"""
return fake_matmul_grad_str
def GenerateNodeCCFile(filepath, node_definition_str): def GenerateNodeCCFile(filepath, node_definition_str):
file_contents = """ file_contents = """
#include "glog/logging.h" #include "glog/logging.h"
...@@ -819,6 +854,7 @@ def GenerateNodeCCFile(filepath, node_definition_str): ...@@ -819,6 +854,7 @@ def GenerateNodeCCFile(filepath, node_definition_str):
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h" #include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
""" """
file_contents += FakeMatmulGradAPI()
file_contents += node_definition_str file_contents += node_definition_str
with open(filepath, 'a') as f: with open(filepath, 'a') as f:
f.write(file_contents) f.write(file_contents)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
from eager_gen import ReadFwdFile, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap
atype_to_parsing_function = {
"bool": "CastPyArg2Boolean",
"int": "CastPyArg2Int",
"long": "CastPyArg2Long",
"float": "CastPyArg2Float",
"string": "CastPyArg2String",
"bool[]": "CastPyArg2Booleans",
"int[]": "CastPyArg2Ints",
"long[]": "CastPyArg2Longs",
"float[]": "CastPyArg2Floats",
"double[]": "CastPyArg2Float64s",
"string[]": "CastPyArg2Strings"
}
atype_to_cxx_type = {
"bool": "bool",
"int": "int",
"long": "long",
"float": "float",
"string": "std::string",
"bool[]": "std::vector<bool>",
"int[]": "std::vector<int>",
"long[]": "std::vector<long>",
"float[]": "std::vector<float>",
"double[]": "std::vector<double>",
"string[]": "std::vector<std::string>"
}
def ParseArguments():
parser = argparse.ArgumentParser(
description='Eager Code Generator Args Parser')
parser.add_argument('--api_yaml_path', type=str)
parser.add_argument('--output_path', type=str)
args = parser.parse_args()
return args
def GetCxxType(atype):
if atype not in atype_to_cxx_type.keys():
assert False
return atype_to_cxx_type[atype]
def FindParsingFunctionFromAttributeType(atype):
if atype not in atype_to_parsing_function.keys():
assert False
return atype_to_parsing_function[atype]
def GeneratePythonCFunction(fwd_api_name, forward_inputs_position_map,
forward_attrs_list, forward_outputs_position_map):
# forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] }
# forward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...]
# Get EagerTensor from args
# Get dygraph function call args
num_args = len(forward_inputs_position_map.keys()) + len(forward_attrs_list)
num_input_tensors = len(forward_inputs_position_map.keys())
dygraph_function_call_list = ["" for i in range(num_args)]
get_eager_tensor_str = ""
for name, (ttype, pos) in forward_inputs_position_map.items():
get_eager_tensor_str += f" auto& {name} = GetEagerTensorFromArgs(\"{fwd_api_name}\", \"{name}\", args, {pos}, false);\n"
dygraph_function_call_list[pos] = f"{name}"
parse_attributes_str = " paddle::framework::AttributeMap attrs;\n"
# Get Attributes
for name, atype, _, pos in forward_attrs_list:
parsing_function = FindParsingFunctionFromAttributeType(atype)
cxx_type = GetCxxType(atype)
key = f"{name}"
parse_attributes_str += f" PyObject* {name}_obj = PyTuple_GET_ITEM(args, {pos});\n"
parse_attributes_str += f" {cxx_type} {name} = {parsing_function}({name}_obj, \"{fwd_api_name}\", {pos});\n"
dygraph_function_call_list[pos] = f"{name}"
dygraph_function_call_str = ",".join(dygraph_function_call_list)
PYTHON_C_FUNCTION_TEMPLATE = """
static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObject *kwargs)
{{
PyThreadState *tstate = nullptr;
try
{{
// Get EagerTensors from args
{}
// Parse Attributes
{}
tstate = PyEval_SaveThread();
auto out = {}({});
PyEval_RestoreThread(tstate);
tstate = nullptr;
return ToPyObject(out);
}}
catch(...) {{
if (tstate) {{
PyEval_RestoreThread(tstate);
}}
ThrowExceptionToPython(std::current_exception());
return nullptr;
}}
}}
"""
python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
fwd_api_name, get_eager_tensor_str, parse_attributes_str,
GetForwardFunctionName(fwd_api_name), dygraph_function_call_str)
python_c_function_reg_str = f"{{\"final_state_{fwd_api_name}\", (PyCFunction)(void(*)(void))eager_final_state_api_{fwd_api_name}, METH_VARARGS | METH_KEYWORDS, \"C++ interface function for {fwd_api_name} in dygraph.\"}}"
return python_c_function_str, python_c_function_reg_str
def GeneratePythonCWrappers(python_c_function_str, python_c_function_reg_str):
PYTHON_C_WRAPPER_TEMPLATE = """
#pragma once
#include "pybind11/detail/common.h"
#include "paddle/fluid/pybind/op_function_common.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/pybind/exception.h"
#include <Python.h>
namespace paddle {{
namespace pybind {{
{}
static PyMethodDef EagerFinalStateMethods[] = {{
{}
}};
}} // namespace pybind
}} // namespace paddle
"""
python_c_str = PYTHON_C_WRAPPER_TEMPLATE.format(python_c_function_str,
python_c_function_reg_str)
return python_c_str
def GeneratePythonCFile(filepath, python_c_str):
with open(filepath, 'a') as f:
f.write(python_c_str)
if __name__ == "__main__":
args = ParseArguments()
api_yaml_path = args.api_yaml_path
fwd_api_list = ReadFwdFile(api_yaml_path)
python_c_function_list = []
python_c_function_reg_list = []
for fwd_api in fwd_api_list:
# We only generate Ops with grad
if 'backward' not in fwd_api.keys():
continue
assert 'api' in fwd_api.keys()
assert 'args' in fwd_api.keys()
assert 'output' in fwd_api.keys()
assert 'backward' in fwd_api.keys()
fwd_api_name = fwd_api['api']
fwd_args_str = fwd_api['args']
fwd_returns_str = fwd_api['output']
# Collect Original Forward Inputs/Outputs and then perform validation checks
forward_inputs_list, forward_attrs_list, forward_returns_list = ParseYamlForward(
fwd_args_str, fwd_returns_str)
print("Parsed Original Forward Inputs List: ", forward_inputs_list)
print("Prased Original Forward Attrs List: ", forward_attrs_list)
print("Parsed Original Forward Returns List: ", forward_returns_list)
forward_inputs_position_map, forward_outputs_position_map = DetermineForwardPositionMap(
forward_inputs_list, forward_returns_list)
print("Generated Forward Input Position Map: ",
forward_inputs_position_map)
print("Generated Forward Output Position Map: ",
forward_outputs_position_map)
python_c_function_str, python_c_function_reg_str = GeneratePythonCFunction(
fwd_api_name, forward_inputs_position_map, forward_attrs_list,
forward_outputs_position_map)
python_c_function_list.append(python_c_function_str)
python_c_function_reg_list.append(python_c_function_reg_str)
print("Generated Python-C Function: ", python_c_function_str)
python_c_functions_str = "\n".join(python_c_function_list)
python_c_functions_reg_str = ",\n".join(python_c_function_reg_list)
python_c_str = GeneratePythonCWrappers(python_c_functions_str,
python_c_functions_reg_str)
print("Generated Python-C Codes: ", python_c_str)
output_path = args.output_path
for path in [output_path]:
if os.path.exists(path):
os.remove(path)
GeneratePythonCFile(output_path, python_c_str)
...@@ -288,7 +288,9 @@ void EagerUtils::CheckAndRetainGrad( ...@@ -288,7 +288,9 @@ void EagerUtils::CheckAndRetainGrad(
paddle::experimental::Tensor EagerUtils::SyncToPtenTensors( paddle::experimental::Tensor EagerUtils::SyncToPtenTensors(
const egr::EagerTensor& tensor) { const egr::EagerTensor& tensor) {
const_cast<EagerTensor*>(&tensor)->SyncToTensor(); if (!tensor.initialized()) {
const_cast<EagerTensor*>(&tensor)->SyncToTensor();
}
return *tensor.Tensor().get(); return *tensor.Tensor().get();
} }
...@@ -298,7 +300,9 @@ std::vector<paddle::experimental::Tensor> EagerUtils::SyncToPtenTensors( ...@@ -298,7 +300,9 @@ std::vector<paddle::experimental::Tensor> EagerUtils::SyncToPtenTensors(
size_t num = tensors.size(); size_t num = tensors.size();
res.reserve(num); res.reserve(num);
for (size_t i = 0; i < num; i++) { for (size_t i = 0; i < num; i++) {
const_cast<EagerTensor*>(&(tensors[i]))->SyncToTensor(); if (!tensors[i].initialized()) {
const_cast<EagerTensor*>(&(tensors[i]))->SyncToTensor();
}
res.push_back(*tensors[i].Tensor().get()); res.push_back(*tensors[i].Tensor().get());
} }
return res; return res;
......
...@@ -151,7 +151,7 @@ if(WITH_PYTHON) ...@@ -151,7 +151,7 @@ if(WITH_PYTHON)
set(tmp_eager_impl_file ${eager_impl_file}.tmp) set(tmp_eager_impl_file ${eager_impl_file}.tmp)
set(OP_IMPL_DEPS op_function_generator) set(OP_IMPL_DEPS op_function_generator)
set(EAGER_OP_IMPL_DEPS eager_op_function_generator) set(EAGER_OP_IMPL_DEPS eager_op_function_generator eager_final_state_python_c_codegen)
if(WIN32) if(WIN32)
if("${CMAKE_GENERATOR}" STREQUAL "Ninja") if("${CMAKE_GENERATOR}" STREQUAL "Ninja")
...@@ -275,7 +275,7 @@ if(WITH_PYTHON) ...@@ -275,7 +275,7 @@ if(WITH_PYTHON)
if(NOT ON_INFER) if(NOT ON_INFER)
cc_library(paddle_eager cc_library(paddle_eager
SRCS eager.cc eager_functions.cc eager_method.cc eager_properties.cc eager_utils.cc SRCS eager.cc eager_functions.cc eager_method.cc eager_properties.cc eager_utils.cc
DEPS eager_api autograd_meta backward grad_node_info pten op_function_common dygraph_function dygraph_node accumulation_node global_utils utils python) DEPS eager_api autograd_meta backward grad_node_info pten op_function_common final_dygraph_function final_dygraph_node dygraph_function dygraph_node accumulation_node global_utils utils python)
add_dependencies(paddle_eager eager_codegen) add_dependencies(paddle_eager eager_codegen)
add_dependencies(paddle_eager eager_op_function_generator_cmd) add_dependencies(paddle_eager eager_op_function_generator_cmd)
list(APPEND PYBIND_DEPS paddle_eager) list(APPEND PYBIND_DEPS paddle_eager)
......
...@@ -393,6 +393,7 @@ int main(int argc, char* argv[]) { ...@@ -393,6 +393,7 @@ int main(int argc, char* argv[]) {
std::vector<std::string> headers{ std::vector<std::string> headers{
"\"pybind11/detail/common.h\"", "\"pybind11/detail/common.h\"",
"\"paddle/fluid/pybind/eager_final_state_op_function_impl.h\"",
"\"paddle/fluid/pybind/op_function_common.h\"", "\"paddle/fluid/pybind/op_function_common.h\"",
"\"paddle/fluid/eager/api/generated/fluid_generated/" "\"paddle/fluid/eager/api/generated/fluid_generated/"
"dygraph_forward_api.h\"", "dygraph_forward_api.h\"",
...@@ -441,6 +442,10 @@ int main(int argc, char* argv[]) { ...@@ -441,6 +442,10 @@ int main(int argc, char* argv[]) {
<< " PADDLE_THROW(platform::errors::Fatal (\"Add functions to " << " PADDLE_THROW(platform::errors::Fatal (\"Add functions to "
"core.eager.ops failed!\"));\n" "core.eager.ops failed!\"));\n"
<< " }\n\n" << " }\n\n"
<< " if (PyModule_AddFunctions(m.ptr(), EagerFinalStateMethods) < 0) {\n"
<< " PADDLE_THROW(platform::errors::Fatal (\"Add functions to "
"core.eager.ops failed!\"));\n"
<< " }\n\n"
<< "}\n\n" << "}\n\n"
<< "} // namespace pybind\n" << "} // namespace pybind\n"
<< "} // namespace paddle\n"; << "} // namespace paddle\n";
......
...@@ -100,17 +100,15 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) { ...@@ -100,17 +100,15 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) {
bool PyObject_CheckString(PyObject* obj) { return PyUnicode_Check(obj); } bool PyObject_CheckString(PyObject* obj) { return PyUnicode_Check(obj); }
void CastPyArg2AttrBoolean(PyObject* obj, bool CastPyArg2Boolean(PyObject* obj, const std::string& op_type,
paddle::framework::AttributeMap& attrs, // NOLINT ssize_t arg_pos) {
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
if (obj == Py_None) { if (obj == Py_None) {
attrs[key] = false; // To be compatible with QA integration testing. Some return false; // To be compatible with QA integration testing. Some
// test case pass in None. // test case pass in None.
} else if (obj == Py_True) { } else if (obj == Py_True) {
attrs[key] = true; return true;
} else if (obj == Py_False) { } else if (obj == Py_False) {
attrs[key] = false; return false;
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be " "%s(): argument (position %d) must be "
...@@ -118,62 +116,89 @@ void CastPyArg2AttrBoolean(PyObject* obj, ...@@ -118,62 +116,89 @@ void CastPyArg2AttrBoolean(PyObject* obj,
op_type, arg_pos + 1, op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
} }
return false;
}
void CastPyArg2AttrBoolean(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Boolean(obj, op_type, arg_pos);
}
int CastPyArg2Int(PyObject* obj, const std::string& op_type, ssize_t arg_pos) {
if (PyObject_CheckLongOrToLong(&obj)) {
return (int)PyLong_AsLong(obj); // NOLINT
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"int, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
return 0;
} }
void CastPyArg2AttrInt(PyObject* obj, void CastPyArg2AttrInt(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, const std::string& key, const std::string& op_type,
ssize_t arg_pos) { ssize_t arg_pos) {
attrs[key] = CastPyArg2Int(obj, op_type, arg_pos);
}
int64_t CastPyArg2Long(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
if (PyObject_CheckLongOrToLong(&obj)) { if (PyObject_CheckLongOrToLong(&obj)) {
attrs[key] = (int)PyLong_AsLong(obj); // NOLINT return (int64_t)PyLong_AsLong(obj); // NOLINT
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be " "%s(): argument (position %d) must be "
"int, but got %s", "long, but got %s",
op_type, arg_pos + 1, op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
} }
return 0;
} }
void CastPyArg2AttrLong(PyObject* obj, void CastPyArg2AttrLong(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, const std::string& key, const std::string& op_type,
ssize_t arg_pos) { ssize_t arg_pos) {
if (PyObject_CheckLongOrToLong(&obj)) { attrs[key] = CastPyArg2Long(obj, op_type, arg_pos);
attrs[key] = (int64_t)PyLong_AsLong(obj); // NOLINT }
float CastPyArg2Float(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
if (PyObject_CheckFloatOrToFloat(&obj)) {
return (float)PyFloat_AsDouble(obj); // NOLINT
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be " "%s(): argument (position %d) must be "
"long, but got %s", "float, but got %s",
op_type, arg_pos + 1, op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
} }
return 0.0;
} }
void CastPyArg2AttrFloat(PyObject* obj, void CastPyArg2AttrFloat(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, const std::string& key, const std::string& op_type,
ssize_t arg_pos) { ssize_t arg_pos) {
if (PyObject_CheckFloatOrToFloat(&obj)) { attrs[key] = CastPyArg2Float(obj, op_type, arg_pos);
attrs[key] = (float)PyFloat_AsDouble(obj); // NOLINT
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be "
"float, but got %s",
op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
}
} }
void CastPyArg2AttrString(PyObject* obj, std::string CastPyArg2String(PyObject* obj, const std::string& op_type,
paddle::framework::AttributeMap& attrs, // NOLINT ssize_t arg_pos) {
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
if (PyObject_CheckString(obj)) { if (PyObject_CheckString(obj)) {
Py_ssize_t size; Py_ssize_t size;
const char* data; const char* data;
data = PyUnicode_AsUTF8AndSize(obj, &size); data = PyUnicode_AsUTF8AndSize(obj, &size);
attrs[key] = std::string(data, (size_t)size); // NOLINT return std::string(data, (size_t)size); // NOLINT
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be " "%s(): argument (position %d) must be "
...@@ -181,16 +206,23 @@ void CastPyArg2AttrString(PyObject* obj, ...@@ -181,16 +206,23 @@ void CastPyArg2AttrString(PyObject* obj,
op_type, arg_pos + 1, op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
} }
return "";
} }
void CastPyArg2AttrBooleans(PyObject* obj, void CastPyArg2AttrString(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, const std::string& key, const std::string& op_type,
ssize_t arg_pos) { ssize_t arg_pos) {
attrs[key] = CastPyArg2String(obj, op_type, arg_pos);
}
std::vector<bool> CastPyArg2Booleans(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
std::vector<bool> value;
if (PyList_Check(obj)) { if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj); Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<bool> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i); item = PyList_GetItem(obj, i);
if (PyObject_CheckBool(&item)) { if (PyObject_CheckBool(&item)) {
...@@ -204,11 +236,9 @@ void CastPyArg2AttrBooleans(PyObject* obj, ...@@ -204,11 +236,9 @@ void CastPyArg2AttrBooleans(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else if (PyTuple_Check(obj)) { } else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj); Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<bool> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i); item = PyTuple_GetItem(obj, i);
if (PyObject_CheckBool(&item)) { if (PyObject_CheckBool(&item)) {
...@@ -222,7 +252,6 @@ void CastPyArg2AttrBooleans(PyObject* obj, ...@@ -222,7 +252,6 @@ void CastPyArg2AttrBooleans(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be " "%s(): argument (position %d) must be "
...@@ -230,16 +259,23 @@ void CastPyArg2AttrBooleans(PyObject* obj, ...@@ -230,16 +259,23 @@ void CastPyArg2AttrBooleans(PyObject* obj,
op_type, arg_pos + 1, op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
} }
return value;
} }
void CastPyArg2AttrInts(PyObject* obj, void CastPyArg2AttrBooleans(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, const std::string& key, const std::string& op_type,
ssize_t arg_pos) { ssize_t arg_pos) {
attrs[key] = CastPyArg2Booleans(obj, op_type, arg_pos);
}
std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
std::vector<int> value;
if (PyList_Check(obj)) { if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj); Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<int> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i); item = PyList_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) { if (PyObject_CheckLongOrToLong(&item)) {
...@@ -253,11 +289,9 @@ void CastPyArg2AttrInts(PyObject* obj, ...@@ -253,11 +289,9 @@ void CastPyArg2AttrInts(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else if (PyTuple_Check(obj)) { } else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj); Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<int> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i); item = PyTuple_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) { if (PyObject_CheckLongOrToLong(&item)) {
...@@ -271,11 +305,9 @@ void CastPyArg2AttrInts(PyObject* obj, ...@@ -271,11 +305,9 @@ void CastPyArg2AttrInts(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else if (PySequence_Check(obj)) { } else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj); Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<int> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i); item = PySequence_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) { if (PyObject_CheckLongOrToLong(&item)) {
...@@ -289,7 +321,6 @@ void CastPyArg2AttrInts(PyObject* obj, ...@@ -289,7 +321,6 @@ void CastPyArg2AttrInts(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be " "%s(): argument (position %d) must be "
...@@ -297,16 +328,23 @@ void CastPyArg2AttrInts(PyObject* obj, ...@@ -297,16 +328,23 @@ void CastPyArg2AttrInts(PyObject* obj,
op_type, arg_pos + 1, op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
} }
return value;
} }
void CastPyArg2AttrLongs(PyObject* obj, void CastPyArg2AttrInts(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, const std::string& key, const std::string& op_type,
ssize_t arg_pos) { ssize_t arg_pos) {
attrs[key] = CastPyArg2Ints(obj, op_type, arg_pos);
}
std::vector<int64_t> CastPyArg2Longs(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
std::vector<int64_t> value;
if (PyList_Check(obj)) { if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj); Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<int64_t> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i); item = PyList_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) { if (PyObject_CheckLongOrToLong(&item)) {
...@@ -320,11 +358,9 @@ void CastPyArg2AttrLongs(PyObject* obj, ...@@ -320,11 +358,9 @@ void CastPyArg2AttrLongs(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else if (PyTuple_Check(obj)) { } else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj); Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<int64_t> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i); item = PyTuple_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) { if (PyObject_CheckLongOrToLong(&item)) {
...@@ -338,11 +374,9 @@ void CastPyArg2AttrLongs(PyObject* obj, ...@@ -338,11 +374,9 @@ void CastPyArg2AttrLongs(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else if (PySequence_Check(obj)) { } else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj); Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<int64_t> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i); item = PySequence_GetItem(obj, i);
if (PyObject_CheckLongOrToLong(&item)) { if (PyObject_CheckLongOrToLong(&item)) {
...@@ -356,7 +390,6 @@ void CastPyArg2AttrLongs(PyObject* obj, ...@@ -356,7 +390,6 @@ void CastPyArg2AttrLongs(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be " "%s(): argument (position %d) must be "
...@@ -364,16 +397,23 @@ void CastPyArg2AttrLongs(PyObject* obj, ...@@ -364,16 +397,23 @@ void CastPyArg2AttrLongs(PyObject* obj,
op_type, arg_pos + 1, op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
} }
return value;
} }
void CastPyArg2AttrFloats(PyObject* obj, void CastPyArg2AttrLongs(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, const std::string& key, const std::string& op_type,
ssize_t arg_pos) { ssize_t arg_pos) {
attrs[key] = CastPyArg2Longs(obj, op_type, arg_pos);
}
std::vector<float> CastPyArg2Floats(PyObject* obj, const std::string& op_type,
ssize_t arg_pos) {
std::vector<float> value;
if (PyList_Check(obj)) { if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj); Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<float> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i); item = PyList_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) { if (PyObject_CheckFloatOrToFloat(&item)) {
...@@ -387,11 +427,9 @@ void CastPyArg2AttrFloats(PyObject* obj, ...@@ -387,11 +427,9 @@ void CastPyArg2AttrFloats(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else if (PyTuple_Check(obj)) { } else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj); Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<float> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i); item = PyTuple_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) { if (PyObject_CheckFloatOrToFloat(&item)) {
...@@ -405,11 +443,9 @@ void CastPyArg2AttrFloats(PyObject* obj, ...@@ -405,11 +443,9 @@ void CastPyArg2AttrFloats(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else if (PySequence_Check(obj)) { } else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj); Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<float> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i); item = PySequence_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) { if (PyObject_CheckFloatOrToFloat(&item)) {
...@@ -423,7 +459,6 @@ void CastPyArg2AttrFloats(PyObject* obj, ...@@ -423,7 +459,6 @@ void CastPyArg2AttrFloats(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be " "%s(): argument (position %d) must be "
...@@ -431,16 +466,24 @@ void CastPyArg2AttrFloats(PyObject* obj, ...@@ -431,16 +466,24 @@ void CastPyArg2AttrFloats(PyObject* obj,
op_type, arg_pos + 1, op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
} }
return value;
} }
void CastPyArg2AttrFloat64s(PyObject* obj, void CastPyArg2AttrFloats(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, const std::string& key, const std::string& op_type,
ssize_t arg_pos) { ssize_t arg_pos) {
attrs[key] = CastPyArg2Floats(obj, op_type, arg_pos);
}
std::vector<double> CastPyArg2Float64s(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos) {
std::vector<double> value;
if (PyList_Check(obj)) { if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj); Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<double> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i); item = PyList_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) { if (PyObject_CheckFloatOrToFloat(&item)) {
...@@ -454,11 +497,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj, ...@@ -454,11 +497,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else if (PyTuple_Check(obj)) { } else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj); Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<double> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i); item = PyTuple_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) { if (PyObject_CheckFloatOrToFloat(&item)) {
...@@ -472,11 +513,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj, ...@@ -472,11 +513,9 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else if (PySequence_Check(obj)) { } else if (PySequence_Check(obj)) {
Py_ssize_t len = PySequence_Size(obj); Py_ssize_t len = PySequence_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<double> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PySequence_GetItem(obj, i); item = PySequence_GetItem(obj, i);
if (PyObject_CheckFloatOrToFloat(&item)) { if (PyObject_CheckFloatOrToFloat(&item)) {
...@@ -490,7 +529,6 @@ void CastPyArg2AttrFloat64s(PyObject* obj, ...@@ -490,7 +529,6 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be " "%s(): argument (position %d) must be "
...@@ -498,16 +536,24 @@ void CastPyArg2AttrFloat64s(PyObject* obj, ...@@ -498,16 +536,24 @@ void CastPyArg2AttrFloat64s(PyObject* obj,
op_type, arg_pos + 1, op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
} }
return value;
} }
void CastPyArg2AttrStrings(PyObject* obj, void CastPyArg2AttrFloat64s(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, const std::string& key, const std::string& op_type,
ssize_t arg_pos) { ssize_t arg_pos) {
attrs[key] = CastPyArg2Float64s(obj, op_type, arg_pos);
}
std::vector<std::string> CastPyArg2Strings(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos) {
std::vector<std::string> value;
if (PyList_Check(obj)) { if (PyList_Check(obj)) {
Py_ssize_t len = PyList_Size(obj); Py_ssize_t len = PyList_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<std::string> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyList_GetItem(obj, i); item = PyList_GetItem(obj, i);
if (PyObject_CheckString(item)) { if (PyObject_CheckString(item)) {
...@@ -524,11 +570,9 @@ void CastPyArg2AttrStrings(PyObject* obj, ...@@ -524,11 +570,9 @@ void CastPyArg2AttrStrings(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else if (PyTuple_Check(obj)) { } else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj); Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr; PyObject* item = nullptr;
std::vector<std::string> value;
for (Py_ssize_t i = 0; i < len; i++) { for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i); item = PyTuple_GetItem(obj, i);
if (PyObject_CheckString(item)) { if (PyObject_CheckString(item)) {
...@@ -545,7 +589,6 @@ void CastPyArg2AttrStrings(PyObject* obj, ...@@ -545,7 +589,6 @@ void CastPyArg2AttrStrings(PyObject* obj,
i)); i));
} }
} }
attrs[key] = value;
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument (position %d) must be " "%s(): argument (position %d) must be "
...@@ -553,6 +596,15 @@ void CastPyArg2AttrStrings(PyObject* obj, ...@@ -553,6 +596,15 @@ void CastPyArg2AttrStrings(PyObject* obj,
op_type, arg_pos + 1, op_type, arg_pos + 1,
((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT ((PyTypeObject*)obj->ob_type)->tp_name)); // NOLINT
} }
return value;
}
void CastPyArg2AttrStrings(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type,
ssize_t arg_pos) {
attrs[key] = CastPyArg2Strings(obj, op_type, arg_pos);
} }
void CastPyArg2AttrBlock(PyObject* obj, void CastPyArg2AttrBlock(PyObject* obj,
......
...@@ -43,6 +43,30 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj); ...@@ -43,6 +43,30 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj);
bool PyObject_CheckString(PyObject* obj); bool PyObject_CheckString(PyObject* obj);
bool CastPyArg2Boolean(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
int CastPyArg2Int(PyObject* obj, const std::string& op_type, ssize_t arg_pos);
int64_t CastPyArg2Long(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
float CastPyArg2Float(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::string CastPyArg2String(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::vector<bool> CastPyArg2Booleans(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::vector<int> CastPyArg2Ints(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::vector<int64_t> CastPyArg2Longs(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::vector<float> CastPyArg2Floats(PyObject* obj, const std::string& op_type,
ssize_t arg_pos);
std::vector<double> CastPyArg2Float64s(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);
std::vector<std::string> CastPyArg2Strings(PyObject* obj,
const std::string& op_type,
ssize_t arg_pos);
void CastPyArg2AttrBoolean(PyObject* obj, void CastPyArg2AttrBoolean(PyObject* obj,
paddle::framework::AttributeMap& attrs, // NOLINT paddle::framework::AttributeMap& attrs, // NOLINT
const std::string& key, const std::string& op_type, const std::string& key, const std::string& op_type,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册