Unverified · Commit c60acca4 · Author: zyfncg · Committer: GitHub

Add assign_out_ yaml (#42833)

* add assign_out_ yaml

* fix final_state_assign

* fix inplace bug

* add inplace_check_blacklist for assign

* fix merge conflict
Parent: c921a812
@@ -174,7 +174,10 @@ def RecoverBaseNameOfInplaceFunction(function_name):

 def GetInplacedFunctionName(function_name):
-    return function_name + "_"
+    inplace_func_name = function_name
+    if inplace_func_name[-1] != '_':
+        inplace_func_name += '_'
+    return inplace_func_name


 def GetForwardFunctionName(string):
......
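For reference, a minimal standalone sketch (not the generator code itself) of what this change to GetInplacedFunctionName does: the old version always appended '_', so an already-inplace name such as assign_out_ would become assign_out__; the new version only appends the suffix when it is missing.

```python
def get_inplaced_function_name_old(function_name):
    # Old behavior: unconditionally append '_'.
    return function_name + "_"


def get_inplaced_function_name_new(function_name):
    # New behavior: append '_' only when the name is not already an inplace name.
    inplace_func_name = function_name
    if inplace_func_name[-1] != '_':
        inplace_func_name += '_'
    return inplace_func_name


assert get_inplaced_function_name_old("assign_out_") == "assign_out__"  # doubled suffix
assert get_inplaced_function_name_new("assign_out_") == "assign_out_"   # unchanged
assert get_inplaced_function_name_new("add") == "add_"                  # normal case
```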
@@ -34,6 +34,13 @@ from codegen_utils import FunctionGeneratorBase, GeneratorBase
 from codegen_utils import ops_to_fill_zero_for_empty_grads
 from codegen_utils import AssertMessage, GetIndent

+# Note: assign is an inplace api when parameter(output) isn't None,
+# so we should check parameter(output) with the inplace rule.
+# But because there is no such check in the old dygraph mode, we also skip
+# the inplace check in the new dygraph temporarily to keep the code
+# compatible; this will be fixed in the future.
+inplace_check_blacklist = set(["assign_out_"])
+
 ###########
 ## Utils ##
@@ -848,13 +855,15 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
     def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
         namespace = self.namespace
+        if self.forward_api_name[-1] == '_' and not is_inplaced:
+            return
         forward_api_name = GetInplacedFunctionName(
             self.forward_api_name) if is_inplaced else self.forward_api_name

         forward_inputs_position_map = self.forward_inputs_position_map
         forward_outputs_position_map = self.forward_outputs_position_map
         forward_attrs_list = self.forward_attrs_list
+        backward_grad_outputs_map = self.backward_grad_outputs_map
         optional_inputs = self.optional_inputs
         intermediate_outputs = self.intermediate_outputs
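A small sketch of the intent behind the new early return, with a hypothetical helper standing in for the real generator method: an API whose yaml name already ends with '_' (such as assign_out_) is inplace-only, so the non-inplace generation pass is skipped for it.

```python
def should_generate_forward(forward_api_name, is_inplaced):
    # Hypothetical stand-in for the guard added to GenerateForwardDefinitionAndDeclaration:
    # '_'-suffixed APIs only get an inplace forward function.
    if forward_api_name[-1] == '_' and not is_inplaced:
        return False
    return True


assert should_generate_forward("assign_out_", is_inplaced=False) is False
assert should_generate_forward("assign_out_", is_inplaced=True) is True
assert should_generate_forward("matmul", is_inplaced=False) is True
assert should_generate_forward("matmul", is_inplaced=True) is True
```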
@@ -994,17 +1003,26 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
         inputs_autograd_meta_list = []
         compute_require_grad_args_list = ["trace_backward"]
         for name, (ttype, pos) in forward_inputs_position_map.items():
-            input_autograd_meta_name = GetAutoGradMetaName(name)
-            if IsPlainTensorType(ttype):
-                input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
-            else:
-                assert IsVectorTensorType(ttype)
-                input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
-                input_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
-                input_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
-            inputs_autograd_meta_list.append(input_autograd_meta)
-            compute_require_grad_args_list.append(input_autograd_meta_name)
+            # Has corresponding grad output
+            has_corresponding_grad_output = False
+            for _, (_, corresponding_pos,
+                    _) in backward_grad_outputs_map.items():
+                if pos == corresponding_pos:
+                    has_corresponding_grad_output = True
+            if has_corresponding_grad_output or (
+                    name in forward_inplace_map and
+                    forward_api_name not in inplace_check_blacklist):
+                input_autograd_meta_name = GetAutoGradMetaName(name)
+                if IsPlainTensorType(ttype):
+                    input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
+                else:
+                    assert IsVectorTensorType(ttype)
+                    input_autograd_meta_vec_name = GetAutoGradMetaVectorName(
+                        name)
+                    input_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
+                    input_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
+                inputs_autograd_meta_list.append(input_autograd_meta)
+                compute_require_grad_args_list.append(input_autograd_meta_name)

         inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
         compute_require_grad_args_str = ",".join(compute_require_grad_args_list)
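The new filtering above can be read in isolation as follows; the sketch uses made-up toy maps shaped like the generator's backward_grad_outputs_map and forward_inplace_map to show when an input still receives AutogradMeta.

```python
inplace_check_blacklist = {"assign_out_"}


def needs_autograd_meta(name, pos, backward_grad_outputs_map,
                        forward_inplace_map, forward_api_name):
    # An input keeps its AutogradMeta when some backward grad output shares its
    # position, or when it is an inplace target of a non-blacklisted API.
    has_corresponding_grad_output = any(
        pos == corresponding_pos
        for _, corresponding_pos, _ in backward_grad_outputs_map.values())
    return has_corresponding_grad_output or (
        name in forward_inplace_map
        and forward_api_name not in inplace_check_blacklist)


# Toy data for assign_out_(x, output): only x (position 0) has a grad output.
backward_grad_outputs_map = {"x_grad": ("Tensor", 0, None)}
forward_inplace_map = {"output": "out"}

assert needs_autograd_meta("x", 0, backward_grad_outputs_map,
                           forward_inplace_map, "assign_out_")
# output is an inplace target, but assign_out_ is blacklisted, so it is skipped.
assert not needs_autograd_meta("output", 1, backward_grad_outputs_map,
                               forward_inplace_map, "assign_out_")
```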
@@ -1038,9 +1056,11 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
         bump_inplace_version_str = ""
         if is_inplaced:
             for inplace_name in forward_inplace_map.keys():
-                inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name)
-                check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
-                    inplace_name, inplace_autograd_meta_name)
+                if forward_api_name not in inplace_check_blacklist:
+                    inplace_autograd_meta_name = GetAutoGradMetaName(
+                        inplace_name)
+                    check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
+                        inplace_name, inplace_autograd_meta_name)
                 bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format(
                     inplace_name, inplace_name)
......
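The same blacklist also guards the generated inplace check. A minimal sketch with an assumed, simplified check template (the real CHECK_INPLACE_TEMPLATE emits C++ into the generated forward function and is not reproduced here):

```python
# Simplified, assumed template text for illustration only.
CHECK_INPLACE_TEMPLATE = "  egr::EagerUtils::CheckInplace({}, {}, require_any_grad);\n"
inplace_check_blacklist = {"assign_out_"}


def gen_check_inplace(forward_api_name, forward_inplace_map):
    check_inplace_str = ""
    for inplace_name in forward_inplace_map:
        # Blacklisted APIs (currently only assign_out_) skip the inplace check,
        # matching the old dygraph behavior; the version bump is still emitted.
        if forward_api_name not in inplace_check_blacklist:
            check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
                inplace_name, inplace_name + "_autograd_meta")
    return check_inplace_str


assert gen_check_inplace("assign_out_", {"output": "out"}) == ""
assert "CheckInplace(x" in gen_check_inplace("relu_", {"x": "out"})
```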
@@ -381,7 +381,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
                     break

            # Generate Python-C Function Definition
-            self.python_c_function_str += PYTHON_C_FUNCTION_TEMPLATE.format(
+            python_c_inplace_func_str = PYTHON_C_FUNCTION_TEMPLATE.format(
                 inplaced_forward_api_name, pythonc_record_event_str,
                 inplaced_forward_api_name, get_eager_tensor_str,
                 parse_attributes_str, set_device_str,
@@ -389,11 +389,20 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
                 inplaced_fwd_function_name, dygraph_function_call_str,
                 return_str)

-            # Generate Python-C Function Registration
-            self.python_c_function_reg_str += "\n," + PYTHON_C_FUNCTION_REG_TEMPLATE.format(
+            python_c_inplace_func_reg_str = PYTHON_C_FUNCTION_REG_TEMPLATE.format(
                 forward_api_name_prefix, inplaced_forward_api_name, namespace,
                 inplaced_forward_api_name, inplaced_forward_api_name)

+            # self.forward_api_name ending with '_' means it only has an inplace api
+            if self.forward_api_name[-1] == '_':
+                self.python_c_function_str = python_c_inplace_func_str
+                # Generate Python-C Function Registration
+                self.python_c_function_reg_str = python_c_inplace_func_reg_str
+            else:
+                self.python_c_function_str += python_c_inplace_func_str
+                # Generate Python-C Function Registration
+                self.python_c_function_reg_str += "\n," + python_c_inplace_func_reg_str
+
     def run(self):
         # Initialized is_forward_only
         self.CollectIsForwardOnly()
......
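A sketch of the new registration branching, using placeholder strings instead of the real Python-C templates: for a yaml name that already ends with '_', the inplace binding is the only one kept, so it replaces rather than extends the accumulated definition and registration strings.

```python
def combine_registrations(forward_api_name, plain_reg_str, inplace_reg_str):
    # Hypothetical helper mirroring the branch added above.
    if forward_api_name[-1] == '_':
        # Inplace-only API: the inplace registration is the whole result.
        return inplace_reg_str
    # Normal API: append the inplace registration after the plain one.
    return plain_reg_str + "\n," + inplace_reg_str


assert combine_registrations("assign_out_", "<plain>", "<inplace>") == "<inplace>"
assert combine_registrations("scale", "<plain>", "<inplace>") == "<plain>\n,<inplace>"
```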
@@ -1510,12 +1510,14 @@ def assign(x, output=None):
     # isinstance(VarBase, Variable) == False. It will cause return None
     # after this api.
     if isinstance(input, (Variable, core.VarBase)):
-        if _non_static_mode():
+        if in_dygraph_mode():
             if output is None:
-                if _in_legacy_dygraph():
-                    output = core.VarBase()
-                else:
-                    output = core.eager.Tensor()
+                output = _C_ops.final_state_assign(input)
+            else:
+                _C_ops.final_state_assign_out_(input, output)
+        elif _in_legacy_dygraph():
+            if output is None:
+                output = core.VarBase()
             _C_ops.assign(input, output)
         else:
             check_dtype(input.dtype, 'input', [
@@ -1575,7 +1577,7 @@ def assign(x, output=None):
                     value_name: values
                 })

-    if is_inplace and _non_static_mode():
+    if is_inplace and _in_legacy_dygraph():
         output._bump_inplace_version()

     return output
......
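From the user's side, the behavior of paddle.assign is unchanged; under the new dygraph mode the output=... form now routes to the final-state assign_out_ kernel. A usage sketch (assumes a Paddle build containing this change):

```python
import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0])
y = paddle.zeros([3])

# With output given, assign writes into y in place
# (final_state_assign_out_ under the new dygraph mode).
paddle.assign(x, output=y)
print(y.numpy())  # [1. 2. 3.]

# Without output, a fresh tensor is returned (final_state_assign).
z = paddle.assign(x)
print(z.numpy())  # [1. 2. 3.]
```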
@@ -189,6 +189,18 @@
     func : assign
   backward : assign_grad

+- api : assign_out_
+  args : (Tensor x, Tensor output)
+  output : Tensor(out)
+  infer_meta :
+    func : UnchangedInferMeta
+    param : [x]
+  kernel :
+    func : assign
+    param : [x]
+  inplace : (output -> out)
+  backward : assign_out__grad
+
 # atan
 - api : atan
   args : (Tensor x)
......
@@ -311,7 +311,7 @@ class BaseAPI(object):
             view_map = {}
             in_out_mapping_list = api_item_yaml[mode].split(',')
             for item in in_out_mapping_list:
-                result = re.search(r"(?P<in>\w+)\s*->\s(?P<out>\w+)", item)
+                result = re.search(r"(?P<in>\w+)\s*->\s*(?P<out>\w+)", item)
                 in_val = result.group('in')
                 out_val = result.group('out')
                 assert in_val in self.inputs['names'], \
@@ -840,6 +840,8 @@ PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
         if self.is_base_api:
             api_code = self.gene_base_api_code()
             if len(self.inplace_map) > 0:
+                if self.api[-1] == '_':
+                    api_code = ""
                 api_code = api_code + self.gene_base_api_code(inplace_flag=True)
             return api_code
......
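The one-character regex change above is easy to miss: the old pattern required exactly one whitespace character after '->', so an inplace mapping written without a space there failed to parse. A quick check:

```python
import re

old_pattern = r"(?P<in>\w+)\s*->\s(?P<out>\w+)"
new_pattern = r"(?P<in>\w+)\s*->\s*(?P<out>\w+)"

item = "output->out"  # mapping written with no space after '->'

assert re.search(old_pattern, item) is None          # old pattern: no match
match = re.search(new_pattern, item)                 # new pattern: matches
assert (match.group('in'), match.group('out')) == ('output', 'out')

# Mappings with spaces, e.g. "output -> out", match under both patterns.
assert re.search(old_pattern, "output -> out") is not None
```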
@@ -132,7 +132,15 @@
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
-    param : [out_grad]
+  kernel :
+    func : assign
+
+- backward_api : assign_out__grad
+  forward : assign_out_ (Tensor x, Tensor output) -> Tensor(out)
+  args : (Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : UnchangedInferMeta
   kernel :
     func : assign
......