未验证 提交 933db9d4 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] Forward only add dygraph func (#45153)

* [Eager draft] forward_only interface migrate to autograd_api

* strings api add dygraph forward function

* rm useless comments

* draft version for check CI

* fix ci

* forward-only no need compute_require_grad and pass stop_gradient, rm useless comments

* polish yaml and using CPUPlace = phi::CPUPlace

* rm useless comments

* polish yaml and update some test case

* rm useless funcs

* polish eager_gen code

* polish code
上级 f706d95d
......@@ -38,7 +38,7 @@ add_custom_target(
COMMAND
"${PYTHON_EXECUTABLE}"
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py"
"--api_yaml_path=${api_yaml_path}"
"--api_yaml_path=${api_yaml_path},${fwd_api_yaml_path}"
"--backward_yaml_path=${backward_yaml_path}"
"--forwards_cc_path=${tmp_forwards_cc_path}"
"--forwards_h_path=${tmp_forwards_h_path}"
......
......@@ -353,6 +353,9 @@ class FunctionGeneratorBase:
self.forward_api_contents = forward_api_contents
self.namespace = namespace
self.is_forward_only = False if 'backward' in forward_api_contents.keys(
) else True
self.forward_api_name = ""
self.orig_forward_inputs_list = [
......
......@@ -209,6 +209,26 @@ FORWARD_FUNCTION_TEMPLATE = \
}}
"""
# C++ code template for a generated forward-only dygraph function (an api with
# no backward entry): no autograd-meta collection and no grad-node creation.
# Format slots, in order: return type, function name, argument list,
# record-event code, AMP logic, function name (for the VLOG), forward API
# call, output-extraction code, and the return expression.
FORWARD_ONLY_FUNCTION_TEMPLATE = \
"""
{} {}({}) {{
// Dygraph Record Event
{}
// AMP Logic
{}
// Forward API Call
VLOG(3) << \"Final State Running: \" << \"{}\";
{}
// Get Outputs
{}
// Returns
return {};
}}
"""
FORWARD_BODY_TEMPLATE = \
""" if(require_any_grad) {{
{}
......@@ -297,6 +317,7 @@ FORWARD_CC_FILE_TEMPLATE = \
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
#include "paddle/phi/api/include/strings_api.h"
#include "paddle/phi/api/include/sparse_api.h"
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
......@@ -321,6 +342,7 @@ FORWARD_H_FILE_TEMPLATE = \
#include "paddle/fluid/eager/to_static/run_program_op_func.h"
#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"
using CPUPlace = phi::CPUPlace;
{}
{}
"""
......@@ -406,6 +428,27 @@ CHECK_NAN_AND_INF_TEMPLATE = \
""" if (FLAGS_check_nan_inf) {{ egr::CheckTensorHasNanOrInf("{}", {}); }}
"""
# This list contains ops that do not need to generate amp logic
# All optimizer ops in this list
# Ops that must NOT get auto-generated AMP (auto mixed precision) logic.
# All optimizer ops belong in this list; both the inplace ("op_") and
# non-inplace ("op") spellings are listed explicitly.
# Fix: the original list contained 'lamb_'/'lamb' twice — duplicates removed
# (membership tests are unaffected, the list is just no longer redundant).
no_amp_list = [
    'adam_', 'adam', 'adamw_', 'adamw', 'average_accumulates',
    'average_accumulates_', 'decayed_adagrad_', 'decayed_adagrad',
    'dgc_momentum_', 'dgc_momentum', 'distributed_fused_lamb_',
    'distributed_fused_lamb', 'dpsgd_', 'dpsgd', 'ftrl_', 'ftrl', 'lamb_',
    'lamb', 'lars_momentum_', 'lars_momentum', 'merged_adam_', 'merged_adam',
    'merged_momentum_', 'merged_momentum', 'momentum_', 'momentum',
    'proximal_adagrad_', 'proximal_adagrad', 'proximal_gd_', 'proximal_gd',
    'rmsprop_', 'rmsprop', 'sgd_', 'sgd', 'assign_value_',
    'sparse_momentum_', 'sparse_momentum', 'full_'
]
# Maps a plain return type to its paddle::optional-wrapped reference type,
# used for the return of an inplace op whose paired input is optional.
inplace_optional_out_type_map = {
    "Tensor": "paddle::optional<paddle::experimental::Tensor>&",
    "std::vector<Tensor>":
    "paddle::optional<std::vector<paddle::experimental::Tensor>>&",
}
#######################
## Generator Helpers ##
......@@ -513,15 +556,16 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
), "Unable to find \"args\" in api.yaml"
assert 'output' in forward_api_contents.keys(
), "Unable to find \"output\" in api.yaml"
assert 'backward' in forward_api_contents.keys(
), "Unable to find \"backward\" in api.yaml"
assert 'args' in grad_api_contents.keys(
), "Unable to find \"args\" in backward.yaml"
assert 'output' in grad_api_contents.keys(
), "Unable to find \"output\" in backward.yaml"
assert 'forward' in grad_api_contents.keys(
), "Unable to find \"forward\" in backward.yaml"
if grad_api_contents is not None:
assert 'backward' in forward_api_contents.keys(
), "Unable to find \"backward\" in api.yaml"
assert 'args' in grad_api_contents.keys(
), "Unable to find \"args\" in backward.yaml"
assert 'output' in grad_api_contents.keys(
), "Unable to find \"output\" in backward.yaml"
assert 'forward' in grad_api_contents.keys(
), "Unable to find \"forward\" in backward.yaml"
def ForwardsValidationCheck(self):
forward_inputs_list = self.forward_inputs_list
......@@ -629,6 +673,11 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
self.forward_inputs_list, self.forward_attrs_list, self.forward_returns_list = ParseYamlForwardFromBackward(
backward_forward_str)
def CollectForwardInfoFromYamlForward(self):
    """Parse inputs/attrs/returns for a forward-only api directly from its
    own yaml entry (there is no backward entry to derive them from)."""
    api = self.forward_api_contents
    forward_str = f"{api['args']} -> {api['output']}"
    (self.forward_inputs_list, self.forward_attrs_list,
     self.forward_returns_list) = ParseYamlForwardFromBackward(forward_str)
def SlotNameMatching(self):
backward_inputs_list = self.backward_inputs_list
backward_returns_list = self.backward_returns_list
......@@ -694,6 +743,14 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
backward_output_pos
]
def GetPassStopGradientArgsList(self, forward_outputs_position_map):
    """Build the comma-joined argument string for pass_stop_gradient: a
    leading "false" followed by each output's autograd-meta variable name."""
    args = ["false"] + [
        GetAutoGradMetaName(name) for name in forward_outputs_position_map
    ]
    return ",".join(args)
def GenerateNodeCreationCodes(self, for_backward=False):
forward_api_name = self.forward_api_name
forward_inputs_position_map = self.forward_inputs_position_map
......@@ -706,11 +763,8 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
optional_inputs = self.optional_inputs
# Pass Stop Gradient Args
pass_stop_gradient_args_list = ["false"]
for name, (_, _) in forward_outputs_position_map.items():
output_autograd_meta_name = GetAutoGradMetaName(name)
pass_stop_gradient_args_list.append(output_autograd_meta_name)
pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list)
pass_stop_gradient_args_str = self.GetPassStopGradientArgsList(
forward_outputs_position_map)
# Node Construction
num_backward_inputs = len(forward_outputs_position_map.keys())
......@@ -851,10 +905,10 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
##########################
# Parse forward and backward inplace_map
self.ParseForwardInplaceInfo()
self.ParseBackwardInplaceInfo()
# Parse no_need_buffer
self.ParseNoNeedBuffer()
if self.grad_api_contents is not None:
self.ParseBackwardInplaceInfo()
# Parse no_need_buffer
self.ParseNoNeedBuffer()
# Parse optional_inputs
self.ParseDispensable()
......@@ -863,11 +917,15 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
self.ParseIntermediate()
self.IntermediateValidationCheck()
# Initialize backward_forward_str, backward_inputs_list, backward_attrs_list, backward_returns_list
self.CollectBackwardInfo()
if self.grad_api_contents is not None:
# Initialize backward_forward_str, backward_inputs_list, backward_attrs_list, backward_returns_list
self.CollectBackwardInfo()
# Initialize forward_inputs_list, forward_attrs_list, forward_returns_list
self.CollectForwardInfoFromBackwardContents()
# Initialize forward_inputs_list, forward_attrs_list, forward_returns_list
self.CollectForwardInfoFromBackwardContents()
if self.is_forward_only:
self.CollectForwardInfoFromYamlForward()
# Initialize orig_forward_inputs_list, orig_forward_attrs_list, orig_forward_returns_list
self.CollectOriginalForwardInfo()
......@@ -882,11 +940,11 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
self.DetermineForwardPositionMap(self.forward_inputs_list,
self.forward_returns_list)
# Initialize backward_forward_inputs_map, backward_grad_inputs_map, backward_grad_outputs_map
self.SlotNameMatching()
# Backward Validation Check
self.BackwardValidationCheck()
if self.grad_api_contents is not None:
# Initialize backward_forward_inputs_map, backward_grad_inputs_map, backward_grad_outputs_map
self.SlotNameMatching()
# Backward Validation Check
self.BackwardValidationCheck()
class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
......@@ -909,7 +967,8 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
forward_inputs_position_map = self.forward_inputs_position_map
forward_outputs_position_map = self.forward_outputs_position_map
forward_attrs_list = self.forward_attrs_list
backward_grad_outputs_map = self.backward_grad_outputs_map
if not self.is_forward_only:
backward_grad_outputs_map = self.backward_grad_outputs_map
optional_inputs = self.optional_inputs
intermediate_outputs = self.intermediate_outputs
......@@ -934,7 +993,11 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
is_optional = (name in optional_inputs)
if IsPlainTensorType(ttype):
if is_optional:
arg_str = f"const paddle::optional<paddle::experimental::Tensor>& {name}"
if self.is_forward_only and is_inplaced and forward_inplace_map and name in forward_inplace_map.keys(
):
arg_str = f"paddle::optional<paddle::experimental::Tensor>& {name}"
else:
arg_str = f"const paddle::optional<paddle::experimental::Tensor>& {name}"
amp_tensors_vector_optional_list.append(
f"if ({name}) amp_tensors_vector.push_back({{ *{name} }});\n"
)
......@@ -1028,15 +1091,27 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
if IsPlainTensorType(rtype):
if is_inplaced and forward_inplace_map and name in forward_inplace_map.values(
):
returns_type_list[pos] = "paddle::experimental::Tensor&"
ind = list(forward_inplace_map.values()).index(name)
if list(forward_inplace_map.keys()
)[ind] in self.optional_inputs:
returns_type_list[pos] = inplace_optional_out_type_map[
rtype]
else:
returns_type_list[pos] = "paddle::experimental::Tensor&"
else:
returns_type_list[pos] = "paddle::experimental::Tensor"
else:
assert IsVectorTensorType(rtype)
if is_inplaced and forward_inplace_map and name in forward_inplace_map.values(
):
returns_type_list[
pos] = "std::vector<paddle::experimental::Tensor>&"
ind = list(forward_inplace_map.values()).index(name)
if list(forward_inplace_map.keys()
)[ind] in self.optional_inputs:
returns_type_list[pos] = inplace_optional_out_type_map[
rtype]
else:
returns_type_list[
pos] = "std::vector<paddle::experimental::Tensor>&"
else:
returns_type_list[
pos] = "std::vector<paddle::experimental::Tensor>"
......@@ -1052,56 +1127,64 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
# Node Creation Pre-Processing
# 1. Get Input AutoGradMeta
inputs_autograd_meta_list = []
compute_require_grad_args_list = ["trace_backward"]
for name, (ttype, pos) in forward_inputs_position_map.items():
# Has corresponding grad output
has_corresponding_grad_output = False
for _, (_, corresponding_pos,
_) in backward_grad_outputs_map.items():
if pos == corresponding_pos:
has_corresponding_grad_output = True
if has_corresponding_grad_output or (
name in forward_inplace_map
and forward_api_name not in inplace_check_blacklist):
input_autograd_meta_name = GetAutoGradMetaName(name)
if IsPlainTensorType(ttype):
input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
else:
assert IsVectorTensorType(ttype)
input_autograd_meta_vec_name = GetAutoGradMetaVectorName(
name)
input_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
input_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
inputs_autograd_meta_list.append(input_autograd_meta)
compute_require_grad_args_list.append(input_autograd_meta_name)
inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
compute_require_grad_args_str = ",".join(compute_require_grad_args_list)
if not self.is_forward_only:
inputs_autograd_meta_list = []
compute_require_grad_args_list = ["trace_backward"]
for name, (ttype, pos) in forward_inputs_position_map.items():
# Has corresponding grad output
has_corresponding_grad_output = False
if not self.is_forward_only:
for _, (_, corresponding_pos,
_) in backward_grad_outputs_map.items():
if pos == corresponding_pos:
has_corresponding_grad_output = True
if has_corresponding_grad_output or (
name in forward_inplace_map and forward_api_name
not in inplace_check_blacklist) or self.is_forward_only:
input_autograd_meta_name = GetAutoGradMetaName(name)
if IsPlainTensorType(ttype):
input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
else:
assert IsVectorTensorType(ttype)
input_autograd_meta_vec_name = GetAutoGradMetaVectorName(
name)
input_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
input_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
inputs_autograd_meta_list.append(input_autograd_meta)
compute_require_grad_args_list.append(
input_autograd_meta_name)
inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
compute_require_grad_args_str = ",".join(
compute_require_grad_args_list)
# 2. Get Output AutoGradMeta
outputs_autograd_meta_list = []
num_fwd_outputs = len(forward_outputs_position_map.keys())
for name, (rtype, pos) in forward_outputs_position_map.items():
output_autograd_meta_name = GetAutoGradMetaName(name)
output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
if num_fwd_outputs == 1:
if IsPlainTensorType(rtype):
output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n"
output_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
else:
# Tuple api_result
if IsPlainTensorType(rtype):
output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});"
if not self.is_forward_only:
outputs_autograd_meta_list = []
num_fwd_outputs = len(forward_outputs_position_map.keys())
for name, (rtype, pos) in forward_outputs_position_map.items():
output_autograd_meta_name = GetAutoGradMetaName(name)
output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
if num_fwd_outputs == 1:
if IsPlainTensorType(rtype):
output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n"
output_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n"
output_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
# Tuple api_result
if IsPlainTensorType(rtype):
output_autograd_meta = f"{indent}egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&{name});"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&{name});\n"
output_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
outputs_autograd_meta_list.append(output_autograd_meta)
outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
outputs_autograd_meta_list.append(output_autograd_meta)
outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
# 3. Check Inplace
check_inplace_str = ""
......@@ -1117,8 +1200,11 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
inplace_name, inplace_name)
# Node Creation
self.GenerateNodeCreationCodes()
node_creation_str = self.node_creation_str
if not self.is_forward_only:
self.GenerateNodeCreationCodes()
node_creation_str = self.node_creation_str
else:
node_creation_str = ""
dygraph_event_str = f"{indent}paddle::platform::RecordEvent dygraph_entrance_record_event(\"{forward_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);\n"
forward_function_name = GetDygraphForwardFunctionName(forward_api_name)
......@@ -1144,13 +1230,30 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
amp_autocast_list_str, amp_call_str)
# Generate forward_definition_str and forward_declaration_str
self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name, inputs_args_definition_str,
dygraph_event_str, amp_logic_str, inputs_autograd_meta_str,
forward_function_name, forward_call_str, check_nan_inf_str,
get_outputs_str, outputs_autograd_meta_str,
compute_require_grad_args_str, check_inplace_str,
bump_inplace_version_str, node_creation_str, returns_str)
if not self.is_forward_only:
self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name,
inputs_args_definition_str, dygraph_event_str, amp_logic_str,
inputs_autograd_meta_str, forward_function_name,
forward_call_str, check_nan_inf_str, get_outputs_str,
outputs_autograd_meta_str, compute_require_grad_args_str,
check_inplace_str, bump_inplace_version_str, node_creation_str,
returns_str)
else:
if (len(amp_tensors_vector_list) > 0) and (self.forward_api_name
not in no_amp_list):
self.forward_definition_str += FORWARD_ONLY_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name,
inputs_args_definition_str, dygraph_event_str,
amp_logic_str, forward_function_name, forward_call_str,
get_outputs_str, returns_str)
else:
self.forward_definition_str += FORWARD_ONLY_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name,
inputs_args_definition_str, dygraph_event_str, " ",
forward_function_name, forward_call_str, get_outputs_str,
returns_str)
self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n"
def GenerateInplacedForwardDygraphFunctions(self):
......@@ -1648,11 +1751,18 @@ class DygraphForwardAndNodesGenerator(GeneratorBase):
self.node_declaration_str = ""
self.node_definition_str = ""
def CollectIsForwardOnly(self, forward_api_contents):
    """Mark this api as forward-only when its yaml entry declares no
    'backward' op (e.g. every strings api, which has no backward yaml).

    Idiom fix: replaces the `False if ... else True` construct with a
    direct `not in` membership test; dict membership needs no `.keys()`.
    """
    self.is_forward_only = 'backward' not in forward_api_contents
def ParseYamlContents(self):
    """Parse the forward yaml; parse the backward yaml only when one exists.

    The strings api is forward-only and has no backward yaml, in which case
    backward_yaml_path is None. Initialize grad_api_dict to an empty dict
    first so that the attribute always exists even on the forward-only path
    (GetBackwardAPIContents reads self.grad_api_dict unconditionally).
    """
    self.ParseForwardYamlContents()
    self.grad_api_dict = {}
    backward_yaml_path = self.backward_yaml_path
    if backward_yaml_path is not None:
        self.grad_api_dict = ReadBwdFile(backward_yaml_path)
def GetBackwardAPIContents(self, forward_api_contents):
grad_api_dict = self.grad_api_dict
......@@ -1674,9 +1784,13 @@ class DygraphForwardAndNodesGenerator(GeneratorBase):
for forward_api_contents in forward_api_list:
if forward_api_contents['api'] in black_ops_list: continue
backward_api_contents = self.GetBackwardAPIContents(
forward_api_contents)
if backward_api_contents is None: continue
self.CollectIsForwardOnly(forward_api_contents)
if self.is_forward_only:
backward_api_contents = None
else:
backward_api_contents = self.GetBackwardAPIContents(
forward_api_contents)
# Generate Dygraph Forward Function
function_generator = DygraphForwardFunctionGenerator(
......@@ -1688,6 +1802,8 @@ class DygraphForwardAndNodesGenerator(GeneratorBase):
# Generate Dygraph GradNode Function
while True:
if backward_api_contents is None:
break
next_grad_api_contents = self.GetBackwardAPIContents(
backward_api_contents)
......@@ -1787,7 +1903,12 @@ if __name__ == "__main__":
for i in range(len(api_yaml_paths)):
api_yaml_path = api_yaml_paths[i]
backward_yaml_path = backward_yaml_paths[i]
# string api is forward only
if not api_yaml_path.endswith('strings_api.yaml'):
backward_yaml_path = backward_yaml_paths[i]
else:
backward_yaml_path = None
generator = DygraphForwardAndNodesGenerator(api_yaml_path,
backward_yaml_path)
......
......@@ -51,20 +51,6 @@ atype_to_parsing_function = {
"paddle::experimental::DataType": "CastPyArg2DataType",
}
# This list contains ops that do not need to generate amp logic
# All optimizer ops in this list
no_amp_list = [
'adam_', 'adam', 'adamw_', 'adamw', 'average_accumulates',
'average_accumulates_', 'decayed_adagrad_', 'decayed_adagrad',
'dgc_momentum_', 'dgc_momentum', 'distributed_fused_lamb_',
'distributed_fused_lamb', 'dpsgd_', 'dpsgd', 'ftrl_', 'ftrl', 'lamb_',
'lamb', 'lars_momentum_', 'lars_momentum', 'merged_adam_', 'merged_adam',
'merged_momentum_', 'merged_momentum', 'momentum_', 'momentum',
'proximal_adagrad_', 'proximal_adagrad', 'proximal_gd_', 'proximal_gd',
'rmsprop_', 'rmsprop', 'sgd_', 'sgd', 'lamb_', 'lamb', 'assign_value_',
'sparse_momentum_', 'sparse_momentum', 'full_'
]
def FindParsingFunctionFromAttributeType(atype):
if atype not in atype_to_parsing_function.keys():
......@@ -131,41 +117,6 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
NOAMP_DYGRAPH_FUNCTION_TEMPLATE = "decltype({}({})) out = {}({});\n"
# Call template with AMP branching: when the AMP level is not O0, prepare
# the tensor vector, pick the destination dtype, autocast the inputs, and
# call the function with the casted ("NEW_") arguments; otherwise call it
# with the original arguments.
AMP_DYGRAPH_FUNCTION_TEMPLATE = \
"""
decltype({}({})) out;
// AMP Logic
if (egr::Controller::Instance().GetAMPLevel() != paddle::imperative::AmpLevel::O0) {{
VLOG(5) << "Check and Prepare For AMP";
{}
paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> amp_tensors_vector = {};
{}
{}
{}
out = {}({});
}} else {{
out = {}({});
}}
"""
INPLACE_AMP_DYGRAPH_FUNCTION_TEMPLATE = \
"""
using result_type = decltype({}({}));
std::unique_ptr<result_type> out_ptr;
// AMP Logic
if (egr::Controller::Instance().GetAMPLevel() != paddle::imperative::AmpLevel::O0) {{
VLOG(5) << "Check and Prepare For AMP";
{}
paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> amp_tensors_vector = {};
{}
{}
{}
out_ptr = std::make_unique<result_type>({}({}));
}} else {{
out_ptr = std::make_unique<result_type>({}({}));
}}
result_type& out = *out_ptr;
"""
FUNCTION_SET_DEVICE_TEMPLATE = \
"""{} if (paddle::platform::is_gpu_place(place)) {{
......@@ -405,23 +356,15 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
num_args = len(
forward_inputs_position_map.keys()) + len(orig_forward_attrs_list)
dygraph_function_call_list = ["" for i in range(num_args)]
amp_dygraph_function_call_list = ["" for i in range(num_args)]
for name, (_, pos) in forward_inputs_position_map.items():
dygraph_function_call_list[pos] = f"{name}"
amp_dygraph_function_call_list[pos] = f"NEW_{name}"
for name, _, _, pos in orig_forward_attrs_list:
dygraph_function_call_list[pos] = f"{name}"
amp_dygraph_function_call_list[pos] = f"{name}"
dygraph_function_call_str = ",".join(dygraph_function_call_list)
amp_dygraph_function_call_str = ",".join(amp_dygraph_function_call_list)
# Generate Python-C Function Definitions
if is_forward_only:
fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"paddle::experimental::", namespace, forward_api_name)
else:
fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"::", namespace, GetForwardFunctionName(forward_api_name))
fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"::", namespace, GetForwardFunctionName(forward_api_name))
return_str = " return ToPyObject(out);"
......@@ -429,82 +372,15 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
pythonc_record_event_str = RECORD_EVENT_TEMPLATE.format(
"pythonc_record_event", forward_api_name, "pybind_imperative_func")
# Forward amp logic
amp_tensors_vector_list = []
amp_tensors_vector_optional_list = []
amp_autocast_list = []
amp_autocast_optional_list = []
for name, (ttype, pos) in forward_inputs_position_map.items():
is_optional = (name in optional_inputs)
if IsVectorTensorType(ttype):
if is_optional:
amp_tensors_vector_optional_list.append(
f"if ({name}.is_initialized()) amp_tensors_vector.push_back({name}.get());\n"
)
amp_autocast_optional_list.append(
f"auto NEW_{name} = {name}.is_initialized() ? egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name, false) : {name};\n"
)
else:
amp_tensors_vector_list.append(f"{name}")
amp_autocast_list.append(
f"auto NEW_{name} = egr::EagerAmpAutoCasts(\"{name}\", {name}, amp_dst_dtype, op_name, false);\n"
)
else:
if is_optional:
amp_tensors_vector_optional_list.append(
f"if ({name}.is_initialized()) amp_tensors_vector.push_back({{{name}.get()}});\n"
)
amp_autocast_optional_list.append(
f"auto NEW_{name} = {name}.is_initialized() ? egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name, false) : {name};\n"
)
else:
if forward_inplace_map and name in forward_inplace_map.keys(
):
amp_tensors_vector_list.append(f"{{{name}}}")
amp_autocast_list.append(
f"auto NEW_{name} = egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name, false);\n"
)
else:
amp_tensors_vector_list.append(f"{{{name}}}")
amp_autocast_list.append(
f"auto NEW_{name} = egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name, false);\n"
)
amp_tensors_vector_list_str = "{ " + ",".join(
amp_tensors_vector_list) + " }"
amp_tensors_vector_optional_list_str = "".join(
amp_tensors_vector_optional_list)
amp_autocast_list_str = " ".join(
amp_autocast_list) + " " + " ".join(
amp_autocast_optional_list)
kernel_trans2_op_name_str = f"auto op_name = phi::TransToFluidOpName(\"{forward_api_name}\");"
amp_get_dst_dtype_str = f"auto amp_dst_dtype = egr::GetAmpDestDtype(op_name, amp_tensors_vector);\n"
noamp_dygraph_function_str = NOAMP_DYGRAPH_FUNCTION_TEMPLATE.format(
fwd_function_name, dygraph_function_call_str, fwd_function_name,
dygraph_function_call_str)
amp_dygraph_function_str = AMP_DYGRAPH_FUNCTION_TEMPLATE.format(
fwd_function_name, dygraph_function_call_str,
kernel_trans2_op_name_str, amp_tensors_vector_list_str,
amp_tensors_vector_optional_list_str, amp_get_dst_dtype_str,
amp_autocast_list_str, fwd_function_name,
amp_dygraph_function_call_str, fwd_function_name,
dygraph_function_call_str)
# Generate Python-C Function Definetion
if (is_forward_only) and (len(amp_tensors_vector_list) >
0) and (forward_api_name not in no_amp_list):
self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
forward_api_name, pythonc_record_event_str, forward_api_name,
get_eager_tensor_str, parse_attributes_str, set_device_str,
amp_dygraph_function_str, return_str)
else:
self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
forward_api_name, pythonc_record_event_str, forward_api_name,
get_eager_tensor_str, parse_attributes_str, set_device_str,
noamp_dygraph_function_str, return_str)
self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
forward_api_name, pythonc_record_event_str, forward_api_name,
get_eager_tensor_str, parse_attributes_str, set_device_str,
noamp_dygraph_function_str, return_str)
# Set prefix of forward_api_name to avoid conflicts
prefix = self.namespace.strip("::")
......@@ -518,27 +394,14 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
if forward_inplace_map:
inplaced_forward_api_name = GetInplacedFunctionName(
self.forward_api_name)
if is_forward_only:
inplaced_fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"paddle::experimental::", namespace,
inplaced_forward_api_name)
else:
inplaced_fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"::", namespace,
GetForwardFunctionName(inplaced_forward_api_name))
inplaced_fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"::", namespace,
GetForwardFunctionName(inplaced_forward_api_name))
inplace_noamp_dygraph_function_str = NOAMP_DYGRAPH_FUNCTION_TEMPLATE.format(
inplaced_fwd_function_name, dygraph_function_call_str,
inplaced_fwd_function_name, dygraph_function_call_str)
inplace_amp_dygraph_function_str = INPLACE_AMP_DYGRAPH_FUNCTION_TEMPLATE.format(
inplaced_fwd_function_name, dygraph_function_call_str,
kernel_trans2_op_name_str, amp_tensors_vector_list_str,
amp_tensors_vector_optional_list_str, amp_get_dst_dtype_str,
amp_autocast_list_str, inplaced_fwd_function_name,
amp_dygraph_function_call_str, inplaced_fwd_function_name,
dygraph_function_call_str)
return_str = " std::map<ssize_t, ssize_t> inplace_var_idx_map;"
for inplace_input, inplace_output in forward_inplace_map.items():
return_str += RETURN_INPLACE_PYOBJECT_TEMPLATE.format(
......@@ -547,19 +410,11 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
return_str += " return ToPyObject(out, args, inplace_var_idx_map);"
# Generate Python-C Function Definetion
if (is_forward_only) and (len(amp_tensors_vector_list) > 0) and (
inplaced_forward_api_name not in no_amp_list):
python_c_inplace_func_str = PYTHON_C_FUNCTION_TEMPLATE.format(
inplaced_forward_api_name, pythonc_record_event_str,
inplaced_forward_api_name, get_eager_tensor_str,
parse_attributes_str, set_device_str,
inplace_amp_dygraph_function_str, return_str)
else:
python_c_inplace_func_str = PYTHON_C_FUNCTION_TEMPLATE.format(
inplaced_forward_api_name, pythonc_record_event_str,
inplaced_forward_api_name, get_eager_tensor_str,
parse_attributes_str, set_device_str,
inplace_noamp_dygraph_function_str, return_str)
python_c_inplace_func_str = PYTHON_C_FUNCTION_TEMPLATE.format(
inplaced_forward_api_name, pythonc_record_event_str,
inplaced_forward_api_name, get_eager_tensor_str,
parse_attributes_str, set_device_str,
inplace_noamp_dygraph_function_str, return_str)
python_c_inplace_func_reg_str = PYTHON_C_FUNCTION_REG_TEMPLATE.format(
forward_api_name_prefix, inplaced_forward_api_name, namespace,
......
......@@ -9,7 +9,7 @@
- api : bernoulli
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
......
......@@ -184,7 +184,7 @@
- api : arange
args : (Tensor start, Tensor end, Tensor step, DataType dtype, Place place={})
output : Tensor
output : Tensor(out)
infer_meta :
func : ArangeInferMeta
param : [start, end, step]
......@@ -199,7 +199,7 @@
# arg_max
- api : argmax
args : (Tensor x, int64_t axis, bool keepdims, bool flatten, int dtype)
output : Tensor
output : Tensor(out)
infer_meta :
func : ArgMinMaxInferMeta
kernel :
......@@ -208,7 +208,7 @@
# arg_min
- api : argmin
args : (Tensor x, int64_t axis, bool keepdims, bool flatten, int dtype)
output : Tensor
output : Tensor(out)
infer_meta :
func : ArgMinMaxInferMeta
kernel :
......@@ -366,7 +366,7 @@
# bitwise_and
- api : bitwise_and
args : (Tensor x, Tensor y)
output : Tensor
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
......@@ -375,7 +375,7 @@
# bitwise_not
- api : bitwise_not
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
......@@ -384,7 +384,7 @@
# bitwise_or
- api : bitwise_or
args : (Tensor x, Tensor y)
output : Tensor
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
......@@ -393,7 +393,7 @@
# bitwise_xor
- api : bitwise_xor
args : (Tensor x, Tensor y)
output : Tensor
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
......@@ -557,7 +557,7 @@
- api : copy_to
args : (Tensor x, Place place, bool blocking)
output : Tensor
output : Tensor(out)
invoke : copy_to_impl(x, place, blocking)
# cos
......@@ -672,7 +672,7 @@
- api : diag_embed
args : (Tensor x, int offset, int dim1, int dim2)
output : Tensor
output : Tensor(out)
infer_meta :
func : DiagEmbedInferMeta
kernel :
......@@ -720,7 +720,7 @@
- api : eigvals
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : EigvalsInferMeta
kernel :
......@@ -773,7 +773,7 @@
- api : empty
args : (IntArray shape, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
output: Tensor
output: Tensor(out)
infer_meta :
func : CreateInferMeta
param : [shape, dtype]
......@@ -785,7 +785,7 @@
- api : empty_like
args : (Tensor x, DataType dtype = DataType::UNDEFINED, Place place = {})
output: Tensor
output: Tensor(out)
infer_meta :
func : CreateLikeInferMeta
param : [x, dtype]
......@@ -797,7 +797,7 @@
- api : equal
args : (Tensor x, Tensor y, int axis = -1)
output : Tensor
output : Tensor(out)
infer_meta :
func : CompareInferMeta
kernel :
......@@ -805,7 +805,7 @@
- api : equal_all
args : (Tensor x, Tensor y)
output : Tensor
output : Tensor(out)
infer_meta :
func : CompareAllInferMeta
kernel :
......@@ -986,7 +986,7 @@
- api : full
args : (IntArray shape, Scalar value, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
output: Tensor
output: Tensor(out)
infer_meta :
func : CreateInferMeta
param : [shape, dtype]
......@@ -1012,7 +1012,7 @@
- api : full_batch_size_like
args : (Tensor input, int[] shape, DataType dtype, Scalar value, int input_dim_idx, int output_dim_idx, Place place=CPUPlace())
output: Tensor
output: Tensor(out)
infer_meta :
func : FullBatchSizeLikeInferMeta
param : [input, shape, value, dtype, input_dim_idx, output_dim_idx]
......@@ -1024,7 +1024,7 @@
- api : full_like
args : (Tensor x, Scalar value, DataType dtype = DataType::UNDEFINED, Place place = {})
output: Tensor
output: Tensor(out)
infer_meta :
func : CreateLikeInferMeta
param : [x, dtype]
......@@ -1058,7 +1058,7 @@
- api : gather_tree
args : (Tensor ids, Tensor parents)
output : Tensor
output : Tensor(out)
infer_meta :
func : GatherTreeMeta
kernel :
......@@ -1066,7 +1066,7 @@
- api : gaussian_random
args : (IntArray shape, float mean, float std, int seed, DataType dtype, Place place={})
output: Tensor
output: Tensor(out)
infer_meta :
func : GaussianRandomInferMeta
param : [shape, mean, std, seed, dtype]
......@@ -1118,7 +1118,7 @@
- api : greater_equal
args : (Tensor x, Tensor y, int axis = -1)
output : Tensor
output : Tensor(out)
infer_meta :
func : CompareInferMeta
kernel :
......@@ -1126,7 +1126,7 @@
- api : greater_than
args : (Tensor x, Tensor y, int axis = -1)
output : Tensor
output : Tensor(out)
infer_meta :
func : CompareInferMeta
kernel :
......@@ -1211,7 +1211,7 @@
# histogram
- api : histogram
args : (Tensor x, int64_t bins, int min, int max)
output : Tensor
output : Tensor(out)
infer_meta :
func : HistogramInferMeta
kernel :
......@@ -1238,7 +1238,7 @@
# increment
- api : increment
args : (Tensor x, float value)
output : Tensor
output : Tensor(out)
infer_meta :
func : IncrementInferMeta
kernel :
......@@ -1288,7 +1288,7 @@
# is_empty
- api : is_empty
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : IsEmptyInferMeta
kernel :
......@@ -1306,7 +1306,7 @@
# isfinite
- api : isfinite
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : IsfiniteInferMeta
kernel :
......@@ -1316,7 +1316,7 @@
# isinf
- api : isinf
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : IsfiniteInferMeta
kernel :
......@@ -1326,7 +1326,7 @@
# isnan
- api : isnan
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : IsfiniteInferMeta
kernel :
......@@ -1419,7 +1419,7 @@
- api : less_equal
args : (Tensor x, Tensor y, int axis = -1)
output : Tensor
output : Tensor(out)
infer_meta :
func : CompareInferMeta
kernel :
......@@ -1427,7 +1427,7 @@
- api : less_than
args : (Tensor x, Tensor y, int axis = -1)
output : Tensor
output : Tensor(out)
infer_meta :
func : CompareInferMeta
kernel :
......@@ -1446,7 +1446,7 @@
- api : linspace
args : (Tensor start, Tensor stop, Tensor number, DataType dtype)
output : Tensor
output : Tensor(out)
infer_meta :
func : LinspaceInferMeta
kernel :
......@@ -1520,7 +1520,7 @@
# logical_and
- api : logical_and
args : (Tensor x, Tensor y)
output : Tensor
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
......@@ -1529,7 +1529,7 @@
# logical_not
- api : logical_not
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
......@@ -1538,7 +1538,7 @@
# logical_or
- api : logical_or
args : (Tensor x, Tensor y)
output : Tensor
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
......@@ -1547,7 +1547,7 @@
# logical_xor
- api : logical_xor
args : (Tensor x, Tensor y)
output : Tensor
output : Tensor(out)
infer_meta :
func : ElementwiseInferMeta
kernel :
......@@ -1827,7 +1827,7 @@
# multinomial
- api : multinomial
args : (Tensor x, int num_samples, bool replacement)
output : Tensor
output : Tensor(out)
infer_meta :
func : MultinomialInferMeta
kernel :
......@@ -1895,7 +1895,7 @@
- api : not_equal
args : (Tensor x, Tensor y, int axis = -1)
output : Tensor
output : Tensor(out)
infer_meta :
func : CompareInferMeta
kernel :
......@@ -1903,7 +1903,7 @@
- api : one_hot
args : (Tensor x, Scalar(int) num_classes)
output : Tensor
output : Tensor(out)
infer_meta :
func : OneHotInferMeta
kernel :
......@@ -1911,12 +1911,12 @@
- api : ones
args : (IntArray shape, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
output : Tensor
output : Tensor(out)
invoke : full(shape, 1, dtype, place)
- api : ones_like
args : (Tensor x, DataType dtype=DataType::UNDEFINED, Place place={})
output : Tensor
output : Tensor(out)
invoke : full_like(x, 1, dtype, place)
- api : p_norm
......@@ -2061,7 +2061,7 @@
- api : randperm
args : (int n, DataType dtype, Place place={})
output : Tensor
output : Tensor(out)
infer_meta :
func : RandpermInferMeta
param : [n, dtype]
......@@ -2322,7 +2322,7 @@
- api : shape
args : (Tensor input)
output : Tensor
output : Tensor(out)
infer_meta :
func : ShapeInferMeta
kernel :
......@@ -2334,7 +2334,7 @@
# shard_index
- api : shard_index
args : (Tensor in, int index_num, int nshards, int shard_id, int ignore_value)
output : Tensor
output : Tensor(out)
infer_meta :
func : ShardIndexInferMeta
kernel :
......@@ -2362,7 +2362,7 @@
- api : sign
args : (Tensor x)
output : Tensor
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
kernel :
......@@ -2401,7 +2401,7 @@
# size
- api : size
args : (Tensor x)
output : Tensor
output : Tensor(size)
infer_meta :
func : SizeInferMeta
kernel :
......@@ -2716,7 +2716,7 @@
# python API: paddle.nn.initializer.TruncatedNormal
- api : truncated_gaussian_random
args : (int[] shape, float mean, float std, int seed, DataType dtype=DataType::FLOAT32, Place place={})
output : Tensor
output : Tensor(out)
infer_meta :
func : TruncatedGaussianRandomInferMeta
param : [shape, mean, std, seed, dtype]
......@@ -2831,7 +2831,7 @@
# where_index
- api : where_index
args : (Tensor condition)
output : Tensor
output : Tensor(out)
infer_meta :
func : WhereIndexInferMeta
kernel :
......@@ -2861,12 +2861,12 @@
- api : zeros
args : (IntArray shape, DataType dtype=DataType::FLOAT32, Place place=CPUPlace())
output : Tensor
output : Tensor(out)
invoke : full(shape, 0, dtype, place)
- api : zeros_like
args : (Tensor x, DataType dtype=DataType::UNDEFINED, Place place = {})
output : Tensor
output : Tensor(out)
invoke : full_like(x, 0, dtype, place)
- api: broadcast_tensors
......@@ -2881,7 +2881,7 @@
# dirichlet
- api: dirichlet
args: (Tensor alpha)
output: Tensor
output: Tensor(out)
infer_meta:
func: DirichletInferMeta
kernel:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册