Unverified Commit 9793fc5a authored by pangyoki, committed by GitHub

support inplace in dygraph eager_final state (#40695)

* support inplace in eager_final state

* little change

* little bug
Parent dd9d7206
......@@ -56,6 +56,14 @@ def ParseArguments():
#################
### Helpers ###
#################
def RecoverBaseNameOfInplaceFunction(function_name):
return function_name[:-1]
def GetInplacedFunctionName(function_name):
return function_name + "_"
def FindGradName(string):
return string + "_grad"
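# Taken together, the three helpers above define the inplace naming scheme
# used throughout this generator, e.g.:
#   GetInplacedFunctionName("relu")            -> "relu_"
#   RecoverBaseNameOfInplaceFunction("relu_")  -> "relu"
#   FindGradName("relu")                       -> "relu_grad"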
......@@ -149,6 +157,24 @@ def ReadBwdFile(filepath):
######################
### Yaml Parsers ###
######################
def ParseInplaceInfo(string):
# string: "(x -> out0), (y -> out2)"
inplace_map = {}
for pair in string.split(","):
pair = pair.strip()
if pair.startswith("("):
pair = pair[1:]
if pair.endswith(")"):
pair = pair[:-1]
key = pair.split("->")[0].strip()
val = pair.split("->")[1].strip()
inplace_map[key] = val
return inplace_map
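# Example: ParseInplaceInfo("(x -> out0), (y -> out2)") returns
# {"x": "out0", "y": "out2"}; the single-pair spec "(x -> out)" used for
# relu in the yaml further below parses to {"x": "out"}.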
def RemoveSpecialSymbolsInName(string):
# Remove any name after '@'
ret = string.split("@")[0]
......@@ -683,9 +709,10 @@ std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(const std:
def GenerateNodeCreationCodes(
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list, forward_call_str,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, optional_inputs,
inplace_map):
# fwd_api_name = ""
# forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] }
......@@ -722,19 +749,19 @@ def GenerateNodeCreationCodes(
output_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
if num_fwd_outputs == 1:
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);"
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&api_result);"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result);\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
else:
# Tuple api_result
if IsPlainTensorType(rtype):
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
output_autograd_meta = f" egr::AutogradMeta* {output_autograd_meta_name} = egr::EagerUtils::autograd_meta(&std::get<{pos}>(api_result));"
else:
assert IsVectorTensorType(rtype)
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result[{pos}]);\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
output_autograd_meta = f" std::vector<egr::AutogradMeta*> {output_autograd_meta_vec_name} = egr::EagerUtils::autograd_meta(&api_result[{pos}]);\n"
output_autograd_meta += f" std::vector<egr::AutogradMeta*>* {output_autograd_meta_name} = &{output_autograd_meta_vec_name};"
outputs_autograd_meta_list.append(output_autograd_meta)
pass_stop_gradient_args_list.append(output_autograd_meta_name)
......@@ -743,16 +770,34 @@ def GenerateNodeCreationCodes(
outputs_autograd_meta_str = "\n".join(outputs_autograd_meta_list)
pass_stop_gradient_args_str = ",".join(pass_stop_gradient_args_list)
# Check Inplace
check_inplace_str = ""
bump_inplace_version_str = ""
for inplace_name in inplace_map.keys():
inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name)
check_inplace_str += f"""
// Check Inplace
egr::EagerUtils::CheckInplace({inplace_name}, {inplace_autograd_meta_name}, require_any_grad);\n
"""
bump_inplace_version_str += f"""
// Bump Inplace Version
{inplace_name}.bump_inplace_version();
VLOG(3) << \"Tensor(\" << {inplace_name}.name() << \") uses Inplace Strategy.\";\n
"""
# Node Construction
num_bwd_inputs = len(backward_grad_input_map.keys())
num_bwd_outputs = len(backward_grad_output_map.keys())
grad_node_name = GetGradNodeName(
RecoverBaseNameOfInplaceFunction(
fwd_api_name)) if inplace_map else GetGradNodeName(fwd_api_name)
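    # For an inplaced api such as "relu_", a non-empty inplace_map strips the
    # trailing "_" here, so the inplace forward reuses the grad node class of
    # its out-of-place counterpart instead of generating a duplicate one.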
node_construction_str = f" auto grad_node = std::make_shared<{grad_node_name}>({num_bwd_inputs}, {num_bwd_outputs});"
# SetAttributes
set_attributes_list = []
for name, _, _, _ in backward_attrs_list:
set_attributes = f" grad_node->SetAttribute{name}({name});"
set_attributes = f" grad_node->SetAttribute{name}({name});"
set_attributes_list.append(set_attributes)
set_attributes_str = "\n".join(set_attributes_list)
......@@ -763,9 +808,9 @@ def GenerateNodeCreationCodes(
if is_fwd_input:
if is_optional:
set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, true);"
set_tensor_wrappers = f" if({name}.is_initialized()) grad_node->SetTensorWrapper{name}({name}, true);"
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);"
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({name}, true);"
else:
if num_fwd_outputs > 1:
# Aligned with forward output position
......@@ -776,9 +821,9 @@ def GenerateNodeCreationCodes(
tw_name = f"api_result"
if is_optional:
set_tensor_wrappers = f" if({tw_name}.is_initialized()) grad_node->SetTensorWrapper{name}({tw_name}, false);"
set_tensor_wrappers = f" if({tw_name}.is_initialized()) grad_node->SetTensorWrapper{name}({tw_name}, false);"
else:
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({tw_name}, false);"
set_tensor_wrappers = f" grad_node->SetTensorWrapper{name}({tw_name}, false);"
set_tensor_wrappers_list.append(set_tensor_wrappers)
set_tensor_wrappers_str = "\n".join(set_tensor_wrappers_list)
......@@ -787,8 +832,8 @@ def GenerateNodeCreationCodes(
set_edges_list = []
for name, (_, pos) in forward_inputs_position_map.items():
input_autograd_meta_name = GetAutoGradMetaName(name)
set_grad_out_meta = f" grad_node->SetGradOutMeta({name}, {pos});"
set_edges = f" grad_node->AddEdges({input_autograd_meta_name}, {pos});"
set_grad_out_meta = f" grad_node->SetGradOutMeta({name}, {pos});"
set_edges = f" grad_node->AddEdges({input_autograd_meta_name}, {pos});"
set_grad_out_meta_list.append(set_grad_out_meta)
set_edges_list.append(set_edges)
set_grad_out_meta_str = "\n".join(set_grad_out_meta_list)
......@@ -802,14 +847,14 @@ def GenerateNodeCreationCodes(
num_outputs = len(forward_outputs_position_map.keys())
for name, (_, pos) in forward_outputs_position_map.items():
output_autograd_meta_name = GetAutoGradMetaName(name)
set_out_rank = f" egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});"
set_history = f" egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);"
set_out_rank = f" egr::EagerUtils::SetOutRankWithSlot({output_autograd_meta_name}, {pos});"
set_history = f" egr::EagerUtils::SetHistory({output_autograd_meta_name}, grad_node);"
if num_outputs == 1:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result);"
set_grad_in_meta = f" grad_node->SetGradInMeta(api_result, {pos});"
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(api_result);"
set_grad_in_meta = f" grad_node->SetGradInMeta(api_result, {pos});"
else:
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(std::get<{pos}>(api_result));"
set_grad_in_meta = f" grad_node->SetGradInMeta(std::get<{pos}>(api_result), {pos});"
set_retain_grad = f" egr::EagerUtils::CheckAndRetainGrad(std::get<{pos}>(api_result));"
set_grad_in_meta = f" grad_node->SetGradInMeta(std::get<{pos}>(api_result), {pos});"
set_out_rank_list.append(set_out_rank)
set_history_list.append(set_history)
......@@ -821,55 +866,64 @@ def GenerateNodeCreationCodes(
set_grad_in_meta_str = "\n".join(set_grad_in_meta_list)
set_retain_grad_str = "\n".join(set_retain_grad_list)
node_event_name = fwd_api_name + " node_creation"
NODE_CREATION_TEMPLATE = """
paddle::platform::RecordEvent node_creation_record_event(\"{}\", paddle::platform::TracerEventType::Operator, 1);\n
"""
node_creation_event_str = NODE_CREATION_TEMPLATE.format(node_event_name)
NODE_CREATION_TEMPLATE = """
// Get AutoGradMeta
{}
{}
bool trace_backward = egr::Controller::Instance().HasGrad();
bool require_any_grad = egr::EagerUtils::ComputeRequireGrad({});
if(require_any_grad) {{
egr::EagerUtils::PassStopGradient({});
// Node Construction
{}
// SetAttributes
// Forward API Call
{}
{}
// SetTensorWrappers
{{
{}
// SetGradOutMeta & SetEdges
{}
if(require_any_grad) {{
egr::EagerUtils::PassStopGradient({});
// Node Construction
{}
// SetOutRank & SetHistory & SetGradInMeta & RetainGrad
// SetAttributes
{}
// SetTensorWrappers
{}
// SetGradOutMeta & SetEdges
{}
{}
// SetOutRank & SetHistory & SetGradInMeta & RetainGrad
{}
{}
{}
{}
}}
}}
"""
node_creation_str = NODE_CREATION_TEMPLATE.format(
inputs_autograd_meta_str, outputs_autograd_meta_str,
compute_require_grad_args_str, pass_stop_gradient_args_str,
node_construction_str, set_attributes_str, set_tensor_wrappers_str,
set_grad_out_meta_str, set_edges_str, set_out_rank_str, set_history_str,
set_grad_in_meta_str, set_retain_grad_str)
inputs_autograd_meta_str, compute_require_grad_args_str,
check_inplace_str, forward_call_str, bump_inplace_version_str,
node_creation_event_str, outputs_autograd_meta_str,
pass_stop_gradient_args_str, node_construction_str, set_attributes_str,
set_tensor_wrappers_str, set_grad_out_meta_str, set_edges_str,
set_out_rank_str, set_history_str, set_grad_in_meta_str,
set_retain_grad_str)
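    # Ordering note: per the format arguments above, the generated code runs
    # CheckInplace before the Forward API Call, bumps the inplace version
    # right after it, and only then records the node-creation event and
    # constructs the grad node inside if(require_any_grad).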
return node_creation_str
def GenerateForwardDefinition(
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, optional_inputs,
intermediate_outputs, inplace_map):
# fwd_api_name = ""
# forward_inputs_position_map = { "name" : [type, fwd_position] }
# forward_outputs_position_map = { "name" : [type, fwd_position] }
......@@ -893,7 +947,10 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
if is_optional:
arg_str = f"const paddle::optional<paddle::experimental::Tensor>& {name}"
else:
arg_str = f"const paddle::experimental::Tensor& {name}"
if inplace_map and name in inplace_map.keys():
arg_str = f"paddle::experimental::Tensor& {name}"
else:
arg_str = f"const paddle::experimental::Tensor& {name}"
else:
assert IsVectorTensorType(ttype)
arg_str = f"const std::vector<paddle::experimental::Tensor>& {name}"
......@@ -956,26 +1013,16 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
node_creation_str = GenerateNodeCreationCodes(
fwd_api_name, bwd_api_name, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list, forward_call_str,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, optional_inputs,
inplace_map)
dygraph_event_str = f"paddle::platform::RecordEvent dygraph_entrance_record_event(\"{fwd_api_name} dygraph\", paddle::platform::TracerEventType::Operator, 1);"
FORWARD_FUNCTION_TEMPLATE = """
{} {}({}) {{
{}
{}
......@@ -987,7 +1034,7 @@ def GenerateForwardDefinition(fwd_api_name, bwd_api_name,
forward_function_name = GetForwardFunctionName(fwd_api_name)
forward_function_str = FORWARD_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name, inputs_args_definition_str,
dygraph_event_str, node_creation_str, returns_str)
forward_function_declaration_str = f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});"
return forward_function_str, forward_function_declaration_str
......@@ -1189,6 +1236,10 @@ if __name__ == "__main__":
fwd_args_str = fwd_api['args']
fwd_returns_str = fwd_api['output']
inplace_map = {}
if 'inplace' in fwd_api.keys():
inplace_map = ParseInplaceInfo(fwd_api['inplace'])
bwd_api_name = fwd_api['backward']
assert bwd_api_name in grad_api_dict.keys()
bwd_api = grad_api_dict[bwd_api_name]
......@@ -1285,7 +1336,7 @@ if __name__ == "__main__":
forward_outputs_position_map, orig_forward_attrs_list,
backward_fwd_input_map, backward_grad_input_map,
backward_grad_output_map, backward_attrs_list, optional_inputs,
intermediate_outputs, {})
print("Generated Forward Definition: ", forward_definition_str)
print("Generated Forward Declaration: ", forward_declaration_str)
yaml_forward_definition_str += definition_declaration_pair[0]
......@@ -1296,6 +1347,30 @@ if __name__ == "__main__":
forward_outputs_position_map,
orig_forward_attrs_list)
# Inplaced Version Dygraph Function Generation
if fwd_api_name != "sum" and "inplace" in fwd_api.keys():
fwd_api_name_inplaced = GetInplacedFunctionName(fwd_api_name)
# Node Definition Generation
definition_declaration_pair = GenerateForwardDefinition(
fwd_api_name_inplaced, bwd_api_name,
forward_inputs_position_map, forward_outputs_position_map,
forward_attrs_list, backward_fwd_input_map,
backward_grad_input_map, backward_grad_output_map,
backward_attrs_list, optional_inputs, intermediate_outputs,
inplace_map)
print("Generated Inplaced Forward Definition: ",
forward_definition_str)
print("Generated Inplaced Forward Declaration: ",
forward_declaration_str)
forward_definition_str += definition_declaration_pair[0]
forward_declaration_str += definition_declaration_pair[1]
# For python-level API dispatch
CollectCoreOpsInformation(
fwd_api_name_inplaced, forward_inputs_position_map,
forward_outputs_position_map, forward_attrs_list)
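            # Net effect for an op like relu: in addition to the out-of-place
            # definition generated earlier with an empty inplace_map ({}),
            # this block emits a second forward function under the inplaced
            # name "relu_" that shares relu's grad node and takes its inplace
            # input by non-const reference.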
if len(namespace) > 0:
forward_definition_str += f"""namespace {namespace} {{
{yaml_forward_definition_str}
......
......@@ -15,7 +15,7 @@
import os
import argparse
import logging
from eager_gen import namespace, yaml_types_mapping, ReadFwdFile, ParseDispensable, IsVectorTensorType, GetForwardFunctionName, ParseYamlForward, DetermineForwardPositionMap, GetInplacedFunctionName, ParseInplaceInfo
###########################
## Global Configurations ##
......@@ -71,6 +71,14 @@ RECORD_EVENT_TEMPLATE = \
" paddle::platform::RecordEvent {}(\"{} {}\", paddle::platform::TracerEventType::Operator, 1);"
RETURN_INPLACE_PYOBJECT_TEMPLATE = \
"""
ssize_t arg_id = GetIdxFromCoreOpsInfoMap(core_ops_final_state_args_info, \"final_state_{}\", \"{}\");
ssize_t return_id = GetIdxFromCoreOpsInfoMap(core_ops_final_state_returns_info, \"final_state_{}\", \"{}\");
return ToPyObject(out, return_id, args, arg_id);
"""
PYTHON_C_FUNCTION_TEMPLATE = \
"""
static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObject *kwargs)
......@@ -94,7 +102,7 @@ static PyObject * eager_final_state_api_{}(PyObject *self, PyObject *args, PyObj
PyEval_RestoreThread(tstate);
tstate = nullptr;
{}
}}
catch(...) {{
if (tstate) {{
......@@ -287,9 +295,10 @@ class PythonCSingleFunctionGenerator:
self.forward_inputs_position_map, self.forward_outputs_position_map = DetermineForwardPositionMap(
forward_inputs_list, forward_returns_list)
def GeneratePythonCFunction(self, inplace_map):
namespace = self.namespace
forward_api_name = GetInplacedFunctionName(
self.forward_api_name) if inplace_map else self.forward_api_name
forward_attrs_list = self.forward_attrs_list
forward_inputs_position_map = self.forward_inputs_position_map
forward_outputs_position_map = self.forward_outputs_position_map
......@@ -339,19 +348,31 @@ class PythonCSingleFunctionGenerator:
fwd_function_name = FUNCTION_NAME_TEMPLATE.format(
"::", namespace, GetForwardFunctionName(forward_api_name))
if inplace_map:
assert len(
inplace_map
) == 1, f"size of inplace_map must be 1, but inplace_map of \"{forward_api_name}\" op got {len(inplace_map)}"
for inplace_input, inplace_output in inplace_map.items():
return_str = RETURN_INPLACE_PYOBJECT_TEMPLATE.format(
forward_api_name, inplace_input, forward_api_name,
inplace_output)
break
else:
return_str = " return ToPyObject(out);"
# Generate Record Event for performance profiling
pythonc_record_event_str = RECORD_EVENT_TEMPLATE.format(
"pythonc_record_event", forward_api_name, "pybind_imperative_func")
self.python_c_function_str = PYTHON_C_FUNCTION_TEMPLATE.format(
forward_api_name, pythonc_record_event_str, forward_api_name,
get_eager_tensor_str, parse_attributes_str, fwd_function_name,
dygraph_function_call_str)
dygraph_function_call_str, return_str)
# Generate Python-C Function Registration
self.python_c_function_reg_str = PYTHON_C_FUNCTION_REG_TEMPLATE.format(
forward_api_name, namespace, forward_api_name, forward_api_name)
def run(self, inplace_map):
# Initialized is_forward_only
self.CollectIsForwardOnly()
......@@ -382,7 +403,7 @@ class PythonCSingleFunctionGenerator:
)
# Code Generation
self.GeneratePythonCFunction(inplace_map)
logging.info(
f"Generated Python-C Function: {self.python_c_function_str}")
logging.info(
......@@ -414,12 +435,23 @@ class PythonCYamlGenerator:
for forward_api_content in forward_api_list:
f_generator = PythonCSingleFunctionGenerator(forward_api_content,
namespace)
status = f_generator.run({})
if status == True:
self.python_c_functions_reg_str += f_generator.python_c_function_reg_str + ",\n"
self.python_c_functions_str += f_generator.python_c_function_str + "\n"
if 'inplace' in forward_api_content.keys():
inplace_map = ParseInplaceInfo(forward_api_content['inplace'])
f_generator_inplace = PythonCSingleFunctionGenerator(
forward_api_content, namespace)
status = f_generator_inplace.run(inplace_map)
if status == True:
self.python_c_functions_reg_str += f_generator_inplace.python_c_function_reg_str + ",\n"
self.python_c_functions_str += f_generator_inplace.python_c_function_str + "\n"
def InferNameSpace(self):
yaml_path = self.yaml_path
if "sparse" in yaml_path:
......
......@@ -22,7 +22,7 @@ from ...tensor.math import multiply
import warnings
from ...fluid.layer_helper import LayerHelper
from ...fluid.framework import convert_np_dtype_to_dtype_, _in_eager_mode
from ...fluid.data_feeder import check_variable_and_dtype, check_dtype
import paddle
from paddle import _C_ops, in_dynamic_mode
......@@ -576,6 +576,8 @@ def relu_(x, name=None):
Inplace version of ``relu`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_nn_cn_relu`.
"""
if _in_eager_mode():
return _C_ops.final_state_relu_(x)
return _C_ops.relu_(x)
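# Minimal usage sketch (illustrative; assumes eager mode is enabled so the
# final-state branch above is taken):
#   import paddle
#   import paddle.nn.functional as F
#   x = paddle.to_tensor([-1.0, 2.0])
#   y = F.relu_(x)   # dispatches to _C_ops.final_state_relu_; y aliases x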
......
......@@ -167,6 +167,7 @@
kernel :
func : relu
inplace : (x -> out)
backward : relu_grad
- api : scale
args : (Tensor x, Scalar scale, float bias, bool bias_after_scale)
......
......@@ -56,6 +56,16 @@
kernel :
func : abs_grad
- backward_api : relu_grad
forward : relu (Tensor x) -> Tensor(out)
args : (Tensor x, Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : relu_grad
- backward_api : trunc_grad
forward : trunc (Tensor x) -> Tensor(out)
args : (Tensor out_grad)
......