Unverified commit ca11a0e5, authored by Zhanlue Yang, committed via GitHub

Support dispensable inputs for eager final state codegen (#39743)

Parent commit: 96d530c1
...@@ -143,6 +143,11 @@ def IntermediateValidationCheck(intermediate_outputs, forward_returns_list): ...@@ -143,6 +143,11 @@ def IntermediateValidationCheck(intermediate_outputs, forward_returns_list):
assert pos in intermediate_positions assert pos in intermediate_positions
def ParseDispensable(string):
    """Parse the 'optional' field of an api yaml entry into a list of names.

    Args:
        string: comma-separated dispensable input names, e.g. "X, Y".

    Returns:
        List of whitespace-stripped names. Blank entries — produced by an
        empty string or a stray/trailing comma — are dropped, so "" yields []
        instead of [''].
    """
    # string: "X, Y"
    return [v.strip() for v in string.split(",") if v.strip()]
def ParseIntermediate(string):
    """Split a comma-separated list of intermediate output names.

    Args:
        string: e.g. "Out, XShape".

    Returns:
        List of the comma-separated fields with surrounding whitespace
        stripped from each.
    """
    fields = string.split(",")
    return list(map(str.strip, fields))
...@@ -596,11 +601,11 @@ std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(const std: ...@@ -596,11 +601,11 @@ std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(const std:
return node_definition_str return node_definition_str
def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name, def GenerateNodeCreationCodes(
forward_inputs_position_map, fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list, forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map, backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list): backward_grad_output_map, backward_attrs_list, optional_inputs):
# fwd_api_name = "" # fwd_api_name = ""
# forward_inputs_position_map = { "name" : [type, fwd_position] } # forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] } # forward_outputs_position_map = { "name" : [type, fwd_position] }
...@@ -674,8 +679,15 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name, ...@@ -674,8 +679,15 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name,
# SetTensorWrappers # SetTensorWrappers
set_tensor_wrappers_list = [] set_tensor_wrappers_list = []
for name, (_, is_fwd_input, _) in backward_fwd_input_map.items(): for name, (_, is_fwd_input, _) in backward_fwd_input_map.items():
is_optional = (name in optional_inputs)
if is_fwd_input: if is_fwd_input:
if is_optional:
set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, true);"
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);" set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);"
else:
if is_optional:
set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, false);"
else: else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, false);" set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, false);"
set_tensor_wrappers_list.append(set_tensor_wrappers) set_tensor_wrappers_list.append(set_tensor_wrappers)
...@@ -762,11 +774,12 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name, ...@@ -762,11 +774,12 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name,
return node_creation_str return node_creation_str
def GenerateForwardDefinition( def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
fwd_api_name, bwd_api_name, forward_inputs_position_map, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list, forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map, backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, intermediate_outputs): backward_grad_output_map, backward_attrs_list,
optional_inputs, intermediate_outputs):
# fwd_api_name = "" # fwd_api_name = ""
# forward_inputs_position_map = { "name" : [type, fwd_position] } # forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] } # forward_outputs_position_map = { "name" : [type, fwd_position] }
...@@ -775,6 +788,7 @@ def GenerateForwardDefinition( ...@@ -775,6 +788,7 @@ def GenerateForwardDefinition(
# backward_grad_input_map = { "name" : [type, fwd_position, orig_position] ...} # backward_grad_input_map = { "name" : [type, fwd_position, orig_position] ...}
# backward_grad_output_map = { "name" : [type, fwd_position, orig_position] ...} # backward_grad_output_map = { "name" : [type, fwd_position, orig_position] ...}
# backward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...] # backward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...]
# optional_inputs = ["name0", ...]
# Get Function Args # Get Function Args
num_inputs = len(forward_attrs_list) + len(forward_inputs_position_map.keys( num_inputs = len(forward_attrs_list) + len(forward_inputs_position_map.keys(
...@@ -784,17 +798,18 @@ def GenerateForwardDefinition( ...@@ -784,17 +798,18 @@ def GenerateForwardDefinition(
inputs_call_list = ["" for i in range(num_inputs)] inputs_call_list = ["" for i in range(num_inputs)]
for name, (ttype, pos) in forward_inputs_position_map.items(): for name, (ttype, pos) in forward_inputs_position_map.items():
inputs_call_list[pos] = f"{name}" inputs_call_list[pos] = f"{name}"
is_optional = (name in optional_inputs)
if IsPlainTensorType(ttype): if IsPlainTensorType(ttype):
inputs_args_definition_list[ if is_optional:
pos] = f"const paddle::experimental::Tensor& {name}" arg_str = f"const paddle::optional<paddle::experimental::Tensor>& {name}"
inputs_args_declaration_list[ else:
pos] = f"const paddle::experimental::Tensor& {name}" arg_str = f"const paddle::experimental::Tensor& {name}"
else: else:
assert IsVectorTensorType(ttype) assert IsVectorTensorType(ttype)
inputs_args_definition_list[ arg_str = f"const std::vector<paddle::experimental::Tensor>& {name}"
pos] = f"const std::vector<paddle::experimental::Tensor>& {name}"
inputs_args_declaration_list[ inputs_args_definition_list[pos] = arg_str
pos] = f"const std::vector<paddle::experimental::Tensor>& {name}" inputs_args_declaration_list[pos] = arg_str
for name, atype, default_val, pos in forward_attrs_list: for name, atype, default_val, pos in forward_attrs_list:
inputs_call_list[pos] = name inputs_call_list[pos] = name
...@@ -849,7 +864,7 @@ def GenerateForwardDefinition( ...@@ -849,7 +864,7 @@ def GenerateForwardDefinition(
fwd_api_name, bwd_api_name, forward_inputs_position_map, fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list, forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map, backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list) backward_grad_output_map, backward_attrs_list, optional_inputs)
FORWARD_FUNCTION_TEMPLATE = """ FORWARD_FUNCTION_TEMPLATE = """
{} {}({}) {{ {} {}({}) {{
...@@ -1053,6 +1068,12 @@ if __name__ == "__main__": ...@@ -1053,6 +1068,12 @@ if __name__ == "__main__":
assert 'args' in bwd_api.keys() assert 'args' in bwd_api.keys()
assert 'output' in bwd_api.keys() assert 'output' in bwd_api.keys()
assert 'forward' in bwd_api.keys() assert 'forward' in bwd_api.keys()
# Parse Dispensable Inputs
optional_inputs = []
if 'optional' in fwd_api.keys():
optional_inputs = ParseDispensable(fwd_api['optional'])
bwd_forward_str = bwd_api['forward'] bwd_forward_str = bwd_api['forward']
bwd_args_str = bwd_api['args'] bwd_args_str = bwd_api['args']
bwd_returns_str = bwd_api['output'] bwd_returns_str = bwd_api['output']
...@@ -1128,7 +1149,8 @@ if __name__ == "__main__": ...@@ -1128,7 +1149,8 @@ if __name__ == "__main__":
fwd_api_name, bwd_api_name, forward_inputs_position_map, fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list, forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map, backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, intermediate_outputs) backward_grad_output_map, backward_attrs_list, optional_inputs,
intermediate_outputs)
print("Generated Forward Definition: ", forward_definition_str) print("Generated Forward Definition: ", forward_definition_str)
print("Generated Forward Declaration: ", forward_declaration_str) print("Generated Forward Declaration: ", forward_declaration_str)
forward_definition_str += definition_declaration_pair[0] forward_definition_str += definition_declaration_pair[0]
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import os import os
import argparse import argparse
from eager_gen import ReadFwdFile, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap from eager_gen import ReadFwdFile, ParseDispensable, IsVectorTensorType, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap
atype_to_parsing_function = { atype_to_parsing_function = {
"bool": "CastPyArg2Boolean", "bool": "CastPyArg2Boolean",
...@@ -70,10 +70,12 @@ def FindParsingFunctionFromAttributeType(atype): ...@@ -70,10 +70,12 @@ def FindParsingFunctionFromAttributeType(atype):
def GeneratePythonCFunction(fwd_api_name, forward_inputs_position_map, def GeneratePythonCFunction(fwd_api_name, forward_inputs_position_map,
forward_attrs_list, forward_outputs_position_map): forward_attrs_list, forward_outputs_position_map,
optional_inputs):
# forward_inputs_position_map = { "name" : [type, fwd_position] } # forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] } # forward_outputs_position_map = { "name" : [type, fwd_position] }
# forward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...] # forward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...]
# optional_inputs = [name0, ...]
# Get EagerTensor from args # Get EagerTensor from args
# Get dygraph function call args # Get dygraph function call args
...@@ -82,7 +84,14 @@ def GeneratePythonCFunction(fwd_api_name, forward_inputs_position_map, ...@@ -82,7 +84,14 @@ def GeneratePythonCFunction(fwd_api_name, forward_inputs_position_map,
dygraph_function_call_list = ["" for i in range(num_args)] dygraph_function_call_list = ["" for i in range(num_args)]
get_eager_tensor_str = "" get_eager_tensor_str = ""
for name, (ttype, pos) in forward_inputs_position_map.items(): for name, (ttype, pos) in forward_inputs_position_map.items():
get_eager_tensor_str += f" auto& {name} = GetTensorFromArgs(\"{fwd_api_name}\", \"{name}\", args, {pos}, false);\n" is_optional = (name in optional_inputs)
if IsVectorTensorType(ttype):
get_eager_tensor_str += f" auto {name} = GetTensorListFromArgs(\"{fwd_api_name}\", \"{name}\", args, {pos}, false);\n"
else:
if is_optional:
get_eager_tensor_str += f" auto {name} = GetOptionalTensorFromArgs(\"{fwd_api_name}\", \"{name}\", args, {pos}, false);\n"
else:
get_eager_tensor_str += f" auto {name} = GetTensorFromArgs(\"{fwd_api_name}\", \"{name}\", args, {pos}, false);\n"
dygraph_function_call_list[pos] = f"{name}" dygraph_function_call_list[pos] = f"{name}"
parse_attributes_str = "" parse_attributes_str = ""
...@@ -267,6 +276,11 @@ if __name__ == "__main__": ...@@ -267,6 +276,11 @@ if __name__ == "__main__":
fwd_args_str = fwd_api['args'] fwd_args_str = fwd_api['args']
fwd_returns_str = fwd_api['output'] fwd_returns_str = fwd_api['output']
# Parse Dispensable Inputs
optional_inputs = []
if 'optional' in fwd_api.keys():
optional_inputs = ParseDispensable(fwd_api['optional'])
# Collect Original Forward Inputs/Outputs and then perform validation checks # Collect Original Forward Inputs/Outputs and then perform validation checks
forward_inputs_list, forward_attrs_list, forward_returns_list = ParseYamlForward( forward_inputs_list, forward_attrs_list, forward_returns_list = ParseYamlForward(
fwd_args_str, fwd_returns_str) fwd_args_str, fwd_returns_str)
...@@ -283,7 +297,7 @@ if __name__ == "__main__": ...@@ -283,7 +297,7 @@ if __name__ == "__main__":
python_c_function_str, python_c_function_reg_str = GeneratePythonCFunction( python_c_function_str, python_c_function_reg_str = GeneratePythonCFunction(
fwd_api_name, forward_inputs_position_map, forward_attrs_list, fwd_api_name, forward_inputs_position_map, forward_attrs_list,
forward_outputs_position_map) forward_outputs_position_map, optional_inputs)
python_c_function_list.append(python_c_function_str) python_c_function_list.append(python_c_function_str)
python_c_function_reg_list.append(python_c_function_reg_str) python_c_function_reg_list.append(python_c_function_reg_str)
print("Generated Python-C Function: ", python_c_function_str) print("Generated Python-C Function: ", python_c_function_str)
......
...@@ -555,6 +555,32 @@ PyObject* ToPyObject( ...@@ -555,6 +555,32 @@ PyObject* ToPyObject(
return dict; return dict;
} }
// For Final State Dygraph,
// We directly use paddle::optional(Tensor) as dispensable Tensor
paddle::optional<paddle::experimental::Tensor> GetOptionalTensorFromArgs(
    const std::string& op_type, const std::string& arg_name, PyObject* args,
    ssize_t arg_idx, bool dispensable) {
  // Fetch the positional argument; if it arrived wrapped in a tuple,
  // unwrap its first element.
  PyObject* item = PyTuple_GET_ITEM(args, arg_idx);
  if (PyTuple_Check(item)) {
    item = PyTuple_GET_ITEM(item, 0);
  }

  // Present tensor argument: wrap the underlying tensor into an optional.
  if (item != nullptr && item != Py_None) {
    return paddle::make_optional<paddle::experimental::Tensor>(
        reinterpret_cast<TensorObject*>(item)->tensor);
  }

  // Absent argument: only dispensable inputs may legally be omitted;
  // otherwise report the missing tensor to the caller.
  if (!dispensable) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "%s(): argument '%s' (position %d) must be Tensor, but got None",
        op_type, arg_name, arg_idx));
  }
  return {};
}
// For Intermediate State Dygraph,
// we use an uninitialized Tensor to represent dispensable Tensor
paddle::experimental::Tensor& GetTensorFromArgs(const std::string& op_type, paddle::experimental::Tensor& GetTensorFromArgs(const std::string& op_type,
const std::string& arg_name, const std::string& arg_name,
PyObject* args, ssize_t arg_idx, PyObject* args, ssize_t arg_idx,
......
...@@ -89,10 +89,15 @@ PyObject* ToPyObject(const std::tuple<Args...>& out) { ...@@ -89,10 +89,15 @@ PyObject* ToPyObject(const std::tuple<Args...>& out) {
return result; return result;
} }
paddle::optional<paddle::experimental::Tensor> GetOptionalTensorFromArgs(
const std::string& op_type, const std::string& arg_name, PyObject* args,
ssize_t arg_idx, bool dispensable = false);
paddle::experimental::Tensor& GetTensorFromArgs(const std::string& op_type, paddle::experimental::Tensor& GetTensorFromArgs(const std::string& op_type,
const std::string& arg_name, const std::string& arg_name,
PyObject* args, ssize_t arg_idx, PyObject* args, ssize_t arg_idx,
bool dispensable = false); bool dispensable = false);
std::vector<paddle::experimental::Tensor> GetTensorListFromArgs( std::vector<paddle::experimental::Tensor> GetTensorListFromArgs(
const std::string& op_type, const std::string& arg_name, PyObject* args, const std::string& op_type, const std::string& arg_name, PyObject* args,
ssize_t arg_idx, bool dispensable = false); ssize_t arg_idx, bool dispensable = false);
...@@ -102,6 +107,7 @@ paddle::experimental::Tensor* GetTensorPtrFromArgs(const std::string& op_type, ...@@ -102,6 +107,7 @@ paddle::experimental::Tensor* GetTensorPtrFromArgs(const std::string& op_type,
PyObject* args, PyObject* args,
ssize_t arg_idx, ssize_t arg_idx,
bool dispensable = false); bool dispensable = false);
std::vector<paddle::experimental::Tensor*> GetTensorPtrListFromArgs( std::vector<paddle::experimental::Tensor*> GetTensorPtrListFromArgs(
const std::string& op_type, const std::string& arg_name, PyObject* args, const std::string& op_type, const std::string& arg_name, PyObject* args,
ssize_t arg_idx, bool dispensable = false); ssize_t arg_idx, bool dispensable = false);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.