未验证 提交 ca11a0e5 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Support dispensable inputs for eager final state codegen (#39743)

上级 96d530c1
......@@ -143,6 +143,11 @@ def IntermediateValidationCheck(intermediate_outputs, forward_returns_list):
assert pos in intermediate_positions
def ParseDispensable(string):
# string: "X, Y"
return [v.strip() for v in string.split(",")]
def ParseIntermediate(string):
return [v.strip() for v in string.split(",")]
......@@ -596,11 +601,11 @@ std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(const std:
return node_definition_str
def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name,
forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list):
def GenerateNodeCreationCodes(
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, optional_inputs):
# fwd_api_name = ""
# forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] }
......@@ -674,10 +679,17 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name,
# SetTensorWrappers
set_tensor_wrappers_list = []
for name, (_, is_fwd_input, _) in backward_fwd_input_map.items():
is_optional = (name in optional_inputs)
if is_fwd_input:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);"
if is_optional:
set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, true);"
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);"
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, false);"
if is_optional:
set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, false);"
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, false);"
set_tensor_wrappers_list.append(set_tensor_wrappers)
set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list)
......@@ -762,11 +774,12 @@ def GenerateNodeCreationCodes(fwd_api_name, bwd_api_name,
return node_creation_str
def GenerateForwardDefinition(
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, intermediate_outputs):
def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list,
optional_inputs, intermediate_outputs):
# fwd_api_name = ""
# forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] }
......@@ -775,6 +788,7 @@ def GenerateForwardDefinition(
# backward_grad_input_map = { "name" : [type, fwd_position, orig_position] ...}
# backward_grad_output_map = { "name" : [type, fwd_position, orig_position] ...}
# backward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...]
# optional_inputs = ["name0", ...]
# Get Function Args
num_inputs = len(forward_attrs_list) + len(forward_inputs_position_map.keys(
......@@ -784,17 +798,18 @@ def GenerateForwardDefinition(
inputs_call_list = ["" for i in range(num_inputs)]
for name, (ttype, pos) in forward_inputs_position_map.items():
inputs_call_list[pos] = f"{name}"
is_optional = (name in optional_inputs)
if IsPlainTensorType(ttype):
inputs_args_definition_list[
pos] = f"const paddle::experimental::Tensor& {name}"
inputs_args_declaration_list[
pos] = f"const paddle::experimental::Tensor& {name}"
if is_optional:
arg_str = f"const paddle::optional<paddle::experimental::Tensor>& {name}"
else:
arg_str = f"const paddle::experimental::Tensor& {name}"
else:
assert IsVectorTensorType(ttype)
inputs_args_definition_list[
pos] = f"const std::vector<paddle::experimental::Tensor>& {name}"
inputs_args_declaration_list[
pos] = f"const std::vector<paddle::experimental::Tensor>& {name}"
arg_str = f"const std::vector<paddle::experimental::Tensor>& {name}"
inputs_args_definition_list[pos] = arg_str
inputs_args_declaration_list[pos] = arg_str
for name, atype, default_val, pos in forward_attrs_list:
inputs_call_list[pos] = name
......@@ -849,7 +864,7 @@ def GenerateForwardDefinition(
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list)
backward_grad_output_map, backward_attrs_list, optional_inputs)
FORWARD_FUNCTION_TEMPLATE = """
{} {}({}) {{
......@@ -1053,6 +1068,12 @@ if __name__ == "__main__":
assert 'args' in bwd_api.keys()
assert 'output' in bwd_api.keys()
assert 'forward' in bwd_api.keys()
# Parse Dispensable Inputs
optional_inputs = []
if 'optional' in fwd_api.keys():
optional_inputs = ParseDispensable(fwd_api['optional'])
bwd_forward_str = bwd_api['forward']
bwd_args_str = bwd_api['args']
bwd_returns_str = bwd_api['output']
......@@ -1128,7 +1149,8 @@ if __name__ == "__main__":
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, intermediate_outputs)
backward_grad_output_map, backward_attrs_list, optional_inputs,
intermediate_outputs)
print("Generated Forward Definition: ", forward_definition_str)
print("Generated Forward Declaration: ", forward_declaration_str)
forward_definition_str += definition_declaration_pair[0]
......
......@@ -14,7 +14,7 @@
import os
import argparse
from eager_gen import ReadFwdFile, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap
from eager_gen import ReadFwdFile, ParseDispensable, IsVectorTensorType, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap
atype_to_parsing_function = {
"bool": "CastPyArg2Boolean",
......@@ -70,10 +70,12 @@ def FindParsingFunctionFromAttributeType(atype):
def GeneratePythonCFunction(fwd_api_name, forward_inputs_position_map,
forward_attrs_list, forward_outputs_position_map):
forward_attrs_list, forward_outputs_position_map,
optional_inputs):
# forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] }
# forward_attrs_list = [ [attr_name, attr_type, default_value, orig_position], ...]
# optional_inputs = [name0, ...]
# Get EagerTensor from args
# Get dygraph function call args
......@@ -82,7 +84,14 @@ def GeneratePythonCFunction(fwd_api_name, forward_inputs_position_map,
dygraph_function_call_list = ["" for i in range(num_args)]
get_eager_tensor_str = ""
for name, (ttype, pos) in forward_inputs_position_map.items():
get_eager_tensor_str += f" auto& {name} = GetTensorFromArgs(\"{fwd_api_name}\", \"{name}\", args, {pos}, false);\n"
is_optional = (name in optional_inputs)
if IsVectorTensorType(ttype):
get_eager_tensor_str += f" auto {name} = GetTensorListFromArgs(\"{fwd_api_name}\", \"{name}\", args, {pos}, false);\n"
else:
if is_optional:
get_eager_tensor_str += f" auto {name} = GetOptionalTensorFromArgs(\"{fwd_api_name}\", \"{name}\", args, {pos}, false);\n"
else:
get_eager_tensor_str += f" auto {name} = GetTensorFromArgs(\"{fwd_api_name}\", \"{name}\", args, {pos}, false);\n"
dygraph_function_call_list[pos] = f"{name}"
parse_attributes_str = ""
......@@ -267,6 +276,11 @@ if __name__ == "__main__":
fwd_args_str = fwd_api['args']
fwd_returns_str = fwd_api['output']
# Parse Dispensable Inputs
optional_inputs = []
if 'optional' in fwd_api.keys():
optional_inputs = ParseDispensable(fwd_api['optional'])
# Collect Original Forward Inputs/Outputs and then perform validation checks
forward_inputs_list, forward_attrs_list, forward_returns_list = ParseYamlForward(
fwd_args_str, fwd_returns_str)
......@@ -283,7 +297,7 @@ if __name__ == "__main__":
python_c_function_str, python_c_function_reg_str = GeneratePythonCFunction(
fwd_api_name, forward_inputs_position_map, forward_attrs_list,
forward_outputs_position_map)
forward_outputs_position_map, optional_inputs)
python_c_function_list.append(python_c_function_str)
python_c_function_reg_list.append(python_c_function_reg_str)
print("Generated Python-C Function: ", python_c_function_str)
......
......@@ -555,6 +555,32 @@ PyObject* ToPyObject(
return dict;
}
// For Final State Dygraph,
// We directly use paddle::optional(Tensor) as dispensable Tensor
paddle::optional<paddle::experimental::Tensor> GetOptionalTensorFromArgs(
const std::string& op_type, const std::string& arg_name, PyObject* args,
ssize_t arg_idx, bool dispensable) {
PyObject* obj = PyTuple_GET_ITEM(args, arg_idx);
if (PyTuple_Check(obj)) {
obj = PyTuple_GET_ITEM(obj, 0);
}
if (obj == nullptr || obj == Py_None) {
if (!dispensable) {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s(): argument '%s' (position %d) must be Tensor, but got None",
op_type, arg_name, arg_idx));
}
return {};
}
return paddle::make_optional<paddle::experimental::Tensor>(
reinterpret_cast<TensorObject*>(obj)->tensor);
}
// For Intermediate State Dygraph,
// we use an uninitialized Tensor to represent dispensable Tensor
paddle::experimental::Tensor& GetTensorFromArgs(const std::string& op_type,
const std::string& arg_name,
PyObject* args, ssize_t arg_idx,
......
......@@ -89,10 +89,15 @@ PyObject* ToPyObject(const std::tuple<Args...>& out) {
return result;
}
paddle::optional<paddle::experimental::Tensor> GetOptionalTensorFromArgs(
const std::string& op_type, const std::string& arg_name, PyObject* args,
ssize_t arg_idx, bool dispensable = false);
paddle::experimental::Tensor& GetTensorFromArgs(const std::string& op_type,
const std::string& arg_name,
PyObject* args, ssize_t arg_idx,
bool dispensable = false);
std::vector<paddle::experimental::Tensor> GetTensorListFromArgs(
const std::string& op_type, const std::string& arg_name, PyObject* args,
ssize_t arg_idx, bool dispensable = false);
......@@ -102,6 +107,7 @@ paddle::experimental::Tensor* GetTensorPtrFromArgs(const std::string& op_type,
PyObject* args,
ssize_t arg_idx,
bool dispensable = false);
std::vector<paddle::experimental::Tensor*> GetTensorPtrListFromArgs(
const std::string& op_type, const std::string& arg_name, PyObject* args,
ssize_t arg_idx, bool dispensable = false);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册