Unverified commit 9793fc5a authored by pangyoki, committed by GitHub

support inplace in dygraph eager_final state (#40695)

* support inplace in eager_final state

* little change

* little bug
Parent dd9d7206
...@@ -56,6 +56,14 @@ def ParseArguments():
#################
### Helpers ###
#################
def RecoverBaseNameOfInplaceFunction(function_name):
return function_name[:-1]
def GetInplacedFunctionName(function_name):
return function_name + "_"
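# Illustrative round-trip check of the two inplace-name helpers above:
#   GetInplacedFunctionName("relu")           == "relu_"
#   RecoverBaseNameOfInplaceFunction("relu_") == "relu"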
def FindGradName(string):
    return string + "_grad"
...@@ -149,6 +157,24 @@ def ReadBwdFile(filepath):
######################
### Yaml Parsers ###
######################
def ParseInplaceInfo(string):
# string: "(x -> out0), (y -> out2)"
inplace_map = {}
for pair in string.split(","):
pair = pair.strip()
if pair.startswith("("):
pair = pair[1:]
if pair.endswith(")"):
pair = pair[:-1]
key = pair.split("->")[0].strip()
val = pair.split("->")[1].strip()
inplace_map[key] = val
return inplace_map
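# Sketch of what the parser yields for the example string in the comment:
#   ParseInplaceInfo("(x -> out0), (y -> out2)") == {"x": "out0", "y": "out2"}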
def RemoveSpecialSymbolsInName(string):
    # Remove any name after '@'
    ret = string.split("@")[0]
...@@ -683,9 +709,10 @@ std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(const std:
def GenerateNodeCreationCodes(
        fwd_api_name, bwd_api_name, forward_inputs_position_map,
        forward_outputs_position_map, forward_attrs_list, forward_call_str,
        backward_fwd_input_map, backward_grad_input_map,
        backward_grad_output_map, backward_attrs_list, optional_inputs,
        inplace_map):
    # fwd_api_name = ""
    # forward_inputs_position_map = { "name" : [type, fwd_position] }
    # forward_outputs_position_map = { "name" : [type, fwd_position] }
...@@ -722,19 +749,19 @@ def GenerateNodeCreationCodes(
        output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
        if num_fwd_outputs == 1:
            if IsPlainTensorType(rtype):
                output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);"
            else:
                assert IsVectorTensorType(rtype)
                output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n"
                output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
        else:
            # Tuple api_result
            if IsPlainTensorType(rtype):
                output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
            else:
                assert IsVectorTensorType(rtype)
                output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result[{pos}]);\n"
                output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"

        outputs_autograd_meta_list.append(output_autograd_meta)
        pass_stop_gradient_args_list.append(output_autograd_meta_name)
...@@ -743,16 +770,34 @@ def GenerateNodeCreationCodes(
    outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
    pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list)
# Check Inplace
check_inplace_str = ""
bump_inplace_version_str = ""
for inplace_name in inplace_map.keys():
inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name)
check_inplace_str += f"""
// Check Inplace
egr::EagerUtils::CheckInplace({inplace_name}, {inplace_autograd_meta_name}, require_any_grad);\n
"""
bump_inplace_version_str += f"""
// Bump Inplace Version
{inplace_name}.bump_inplace_version();
VLOG(3) << \"Tensor(\" << {inplace_name}.name() << \") uses Inplace Strategy.\";\n
"""
    # Node Construction
    num_bwd_inputs = len(backward_grad_input_map.keys())
    num_bwd_outputs = len(backward_grad_output_map.keys())
    grad_node_name = GetGradNodeName(
        RecoverBaseNameOfInplaceFunction(
            fwd_api_name)) if inplace_map else GetGradNodeName(fwd_api_name)
    node_construction_str = f" auto grad_node = std::make_shared<{grad_node_name}>({num_bwd_inputs}, {num_bwd_outputs});"

    # SetAttributes
    set_attributes_list = []
    for name, _, _, _ in backward_attrs_list:
        set_attributes = f" grad_node->SetAttribute{name}({name});"
        set_attributes_list.append(set_attributes)
    set_attributes_str = "\n".join(set_attributes_list)
...@@ -763,9 +808,9 @@ def GenerateNodeCreationCodes(
        if is_fwd_input:
            if is_optional:
                set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, true);"
            else:
                set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);"
        else:
            if num_fwd_outputs > 1:
                # Aligned with forward output position
...@@ -776,9 +821,9 @@ def GenerateNodeCreationCodes(
                tw_name = f"api_result"

            if is_optional:
                set_tensor_wrappers = f" if({tw_name}.is_initialized()) grad_node->SetTensorWrapper{name}({tw_name}, false);"
            else:
                set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({tw_name}, false);"
        set_tensor_wrappers_list.append(set_tensor_wrappers)
    set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list)
...@@ -787,8 +832,8 @@ def GenerateNodeCreationCodes(
    set_edges_list = []
    for name, (_, pos) in forward_inputs_position_map.items():
        input_autograd_meta_name = GetAutoGradMetaName(name)
        set_grad_out_meta = f" grad_node->SetGradOutMeta({name}, {pos});"
        set_edges = f" grad_node->AddEdges({input_autograd_meta_name}, {pos});"
        set_grad_out_meta_list.append(set_grad_out_meta)
        set_edges_list.append(set_edges)
    set_grad_out_meta_str = "\n".join(set_grad_out_meta_list)
...@@ -802,14 +847,14 @@ def GenerateNodeCreationCodes(
    num_outputs = len(forward_outputs_position_map.keys())
    for name, (_, pos) in forward_outputs_position_map.items():
        output_autograd_meta_name = GetAutoGradMetaName(name)
        set_out_rank = f" egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});"
        set_history = f" egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);"
        if num_outputs == 1:
            set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result);"
            set_grad_in_meta = f" grad_node->SetGradInMeta(api_result, {pos});"
        else:
            set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(std::get<{pos}>(api_result));"
            set_grad_in_meta = f" grad_node->SetGradInMeta(std::get<{pos}>(api_result), {pos});"
        set_out_rank_list.append(set_out_rank)
        set_history_list.append(set_history)
...@@ -821,55 +866,64 @@ def GenerateNodeCreationCodes(
    set_grad_in_meta_str = "\n".join(set_grad_in_meta_list)
    set_retain_grad_str = "\n".join(set_retain_grad_list)
node_event_name = fwd_api_name + " node_creation"
NODE_CREATION_TEMPLATE = """
paddle::platform::RecordEvent node_creation_record_event(\"{}\", paddle::platform::TracerEventType::Operator, 1);\n
"""
node_creation_event_str = NODE_CREATION_TEMPLATE.format(node_event_name)
    NODE_CREATION_TEMPLATE = """
    // Get AutoGradMeta
{}
    bool trace_backward = egr::Controller::Instance().HasGrad();
    bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});
{}
    // Forward API Call
{}
{}
    {{
{}
{}
        if(require_any_grad) {{
            egr::EagerUtils::PassStopGradient({});

            // Node Construction
{}
            // SetAttributes
{}
            // SetTensorWrappers
{}
            // SetGradOutMeta & SetEdges
{}
{}
            // SetOutRank & SetHistory & SetGradInMeta & RetainGrad
{}
{}
{}
{}
        }}
    }}
"""
    node_creation_str = NODE_CREATION_TEMPLATE.format(
        inputs_autograd_meta_str, compute_require_grad_args_str,
        check_inplace_str, forward_call_str, bump_inplace_version_str,
        node_creation_event_str, outputs_autograd_meta_str,
        pass_stop_gradient_args_str, node_construction_str, set_attributes_str,
        set_tensor_wrappers_str, set_grad_out_meta_str, set_edges_str,
        set_out_rank_str, set_history_str, set_grad_in_meta_str,
        set_retain_grad_str)

    return node_creation_str
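# Sketch of the grad-node selection above: an inplaced API reuses the grad node
# generated for its base API. E.g. with fwd_api_name == "relu_" and a non-empty
# inplace_map, RecoverBaseNameOfInplaceFunction("relu_") == "relu", so "relu_"
# shares the grad node of "relu".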
def GenerateForwardDefinition(
        fwd_api_name, bwd_api_name, forward_inputs_position_map,
        forward_outputs_position_map, forward_attrs_list,
        backward_fwd_input_map, backward_grad_input_map,
        backward_grad_output_map, backward_attrs_list, optional_inputs,
        intermediate_outputs, inplace_map):
    # fwd_api_name = ""
    # forward_inputs_position_map = { "name" : [type, fwd_position] }
    # forward_outputs_position_map = { "name" : [type, fwd_position] }
...@@ -893,7 +947,10 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
            if is_optional:
                arg_str = f"const paddle::optional<paddle::experimental::Tensor>& {name}"
            else:
                if inplace_map and name in inplace_map.keys():
                    arg_str = f"paddle::experimental::Tensor& {name}"
                else:
                    arg_str = f"const paddle::experimental::Tensor& {name}"
        else:
            assert IsVectorTensorType(ttype)
            arg_str = f"const std::vector<paddle::experimental::Tensor>& {name}"
...@@ -956,26 +1013,16 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
    node_creation_str = GenerateNodeCreationCodes(
        fwd_api_name, bwd_api_name, forward_inputs_position_map,
        forward_outputs_position_map, forward_attrs_list, forward_call_str,
        backward_fwd_input_map, backward_grad_input_map,
        backward_grad_output_map, backward_attrs_list, optional_inputs,
        inplace_map)
    dygraph_event_str = f"paddle::platform::RecordEvent dygraph_entrance_record_event(\"{fwd_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);"
    FORWARD_FUNCTION_TEMPLATE = """
{} {}({}) {{
{}
{}
...@@ -987,7 +1034,7 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
    forward_function_name = GetForwardFunctionName(fwd_api_name)
    forward_function_str = FORWARD_FUNCTION_TEMPLATE.format(
        returns_type_str, forward_function_name, inputs_args_definition_str,
        dygraph_event_str, node_creation_str, returns_str)
    forward_function_declaration_str = f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});"

    return forward_function_str, forward_function_declaration_str
...@@ -1189,6 +1236,10 @@ if __name__ == "__main__":
        fwd_args_str = fwd_api['args']
        fwd_returns_str = fwd_api['output']
inplace_map = {}
if 'inplace' in fwd_api.keys():
inplace_map = ParseInplaceInfo(fwd_api['inplace'])
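        # Sketch: for the relu entry in api.yaml below, fwd_api['inplace'] is
        # "(x -> out)", so inplace_map == {"x": "out"} at this point.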
        bwd_api_name = fwd_api['backward']
        assert bwd_api_name in grad_api_dict.keys()
        bwd_api = grad_api_dict[bwd_api_name]
...@@ -1285,7 +1336,7 @@ if __name__ == "__main__":
            forward_outputs_position_map, orig_forward_attrs_list,
            backward_fwd_input_map, backward_grad_input_map,
            backward_grad_output_map, backward_attrs_list, optional_inputs,
            intermediate_outputs, {})
        print("Generated Forward Definition: ", forward_definition_str)
        print("Generated Forward Declaration: ", forward_declaration_str)
        yaml_forward_definition_str += definition_declaration_pair[0]
...@@ -1296,6 +1347,30 @@ if __name__ == "__main__":
            forward_outputs_position_map,
            orig_forward_attrs_list)
# Inplaced Version Dygraph Function Generation
if fwd_api_name != "sum" and "inplace" in fwd_api.keys():
fwd_api_name_inplaced = GetInplacedFunctionName(fwd_api_name)
# Node Definition Generation
definition_declaration_pair = GenerateForwardDefinition(
fwd_api_name_inplaced, bwd_api_name,
forward_inputs_position_map, forward_outputs_position_map,
forward_attrs_list, backward_fwd_input_map,
backward_grad_input_map, backward_grad_output_map,
backward_attrs_list, optional_inputs, intermediate_outputs,
inplace_map)
print("Generated Inplaced Forward Definition: ",
forward_definition_str)
print("Generated Inplaced Forward Declaration: ",
forward_declaration_str)
forward_definition_str += definition_declaration_pair[0]
forward_declaration_str += definition_declaration_pair[1]
# For python-level API dispatch
CollectCoreOpsInformation(
fwd_api_name_inplaced, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list)
    if len(namespace) > 0:
        forward_definition_str += f"""namespace {namespace} {{
{yaml_forward_definition_str}
......
...@@ -15,7 +15,7 @@
import os
import argparse
import logging
from eager_gen import namespace, yaml_types_mapping, ReadFwdFile, ParseDispensable, IsVectorTensorType, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap, GetInplacedFunctionName, ParseInplaceInfo

###########################
## Global Configurations ##
...@@ -71,6 +71,14 @@ RECORD_EVENT_TEMPLATE = \
"    paddle::platform::RecordEvent {}(\"{} {}\", paddle::platform::TracerEventType::Operator, 1);"
RETURN_INPLACE_PYOBJECT_TEMPLATE = \
"""
ssize_t arg_id = GetIdxFromCoreOpsInfoMap(core_ops_final_state_args_info, \"final_state_{}\", \"{}\");
ssize_t return_id = GetIdxFromCoreOpsInfoMap(core_ops_final_state_returns_info, \"final_state_{}\", \"{}\");
return ToPyObject(out, return_id, args, arg_id);
"""
PYTHON_C_FUNCTION_TEMPLATE = \
"""
static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObject *kwargs)
...@@ -94,7 +102,7 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
    PyEval_RestoreThread(tstate);
    tstate = nullptr;
{}
  }}
  catch(...) {{
    if (tstate) {{
...@@ -287,9 +295,10 @@ class PythonCSingleFunctionGenerator:
        self.forward_inputs_position_map, self.forward_outputs_position_map = DetermineForwardPositionMap(
            forward_inputs_list, forward_returns_list)
    def GeneratePythonCFunction(self, inplace_map):
        namespace = self.namespace
        forward_api_name = GetInplacedFunctionName(
            self.forward_api_name) if inplace_map else self.forward_api_name
        forward_attrs_list = self.forward_attrs_list
        forward_inputs_position_map = self.forward_inputs_position_map
        forward_outputs_position_map = self.forward_outputs_position_map
...@@ -339,19 +348,31 @@ class PythonCSingleFunctionGenerator:
        fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
            "::", namespace, GetForwardFunctionName(forward_api_name))
if inplace_map:
assert len(
inplace_map
) == 1, f"size of inplace_map must be 1, but inplace_map of \"{forward_api_name}\" op got {len(inplace_map)}"
for inplace_input, inplace_output in inplace_map.items():
return_str = RETURN_INPLACE_PYOBJECT_TEMPLATE.format(
forward_api_name, inplace_input, forward_api_name,
inplace_output)
break
else:
return_str = " return ToPyObject(out);"
        # Generate Record Event for performance profiling
        pythonc_record_event_str = RECORD_EVENT_TEMPLATE.format(
            "pythonc_record_event", forward_api_name, "pybind_imperative_func")
        self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
            forward_api_name, pythonc_record_event_str, forward_api_name,
            get_eager_tensor_str, parse_attributes_str, fwd_function_name,
            dygraph_function_call_str, return_str)

        # Generate Python-C Function Registration
        self.python_c_function_reg_str = PYTHON_C_FUNCTION_REG_TEMPLATE.format(
            forward_api_name, namespace, forward_api_name, forward_api_name)
    def run(self, inplace_map):
        # Initialized is_forward_only
        self.CollectIsForwardOnly()
...@@ -382,7 +403,7 @@ class PythonCSingleFunctionGenerator:
        )

        # Code Generation
        self.GeneratePythonCFunction(inplace_map)
        logging.info(
            f"Generated Python-C Function: {self.python_c_function_str}")
        logging.info(
...@@ -414,12 +435,23 @@ class PythonCYamlGenerator:
        for forward_api_content in forward_api_list:
            f_generator = PythonCSingleFunctionGenerator(forward_api_content,
                                                         namespace)
            status = f_generator.run({})

            if status == True:
                self.python_c_functions_reg_str += f_generator.python_c_function_reg_str + ",\n"
                self.python_c_functions_str += f_generator.python_c_function_str + "\n"
if 'inplace' in forward_api_content.keys():
inplace_map = ParseInplaceInfo(forward_api_content['inplace'])
f_generator_inplace = PythonCSingleFunctionGenerator(
forward_api_content, namespace)
status = f_generator_inplace.run(inplace_map)
if status == True:
self.python_c_functions_reg_str += f_generator_inplace.python_c_function_reg_str + ",\n"
self.python_c_functions_str += f_generator_inplace.python_c_function_str + "\n"
    def InferNameSpace(self):
        yaml_path = self.yaml_path
        if "sparse" in yaml_path:
......
...@@ -22,7 +22,7 @@ from ...tensor.math import multiply
import warnings
from ...fluid.layer_helper import LayerHelper
from ...fluid.framework import convert_np_dtype_to_dtype_, _in_eager_mode
from ...fluid.data_feeder import check_variable_and_dtype, check_dtype
import paddle
from paddle import _C_ops, in_dynamic_mode
...@@ -576,6 +576,8 @@ def relu_(x, name=None):
    Inplace version of ``relu`` API, the output Tensor will be inplaced with input ``x``.
    Please refer to :ref:`api_nn_cn_relu`.
    """
if _in_eager_mode():
return _C_ops.final_state_relu_(x)
    return _C_ops.relu_(x)
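# Minimal usage sketch of the new eager-mode path (assumes eager mode is on):
#   x = paddle.to_tensor([-1.0, 0.0, 2.0])
#   y = paddle.nn.functional.relu_(x)  # dispatches to _C_ops.final_state_relu_
#   # x is modified in place; x and y both hold [0., 0., 2.]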
......
...@@ -167,6 +167,7 @@
  kernel :
    func : relu
  inplace : (x -> out)
backward: relu_grad
- api : scale
  args : (Tensor x, Scalar scale, float bias, bool bias_after_scale)
......
...@@ -56,6 +56,16 @@
  kernel :
    func : abs_grad
- backward_api : relu_grad
forward : relu (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : relu_grad
- backward_api : trunc_grad
  forward : trunc (Tensor x) -> Tensor(out)
  args : (Tensor out_grad)
......