未验证 提交 c60acca4 编写于 作者: Z zyfncg 提交者: GitHub

Add assign_out_ yaml (#42833)

* add assign_out_ yaml

* fix final_state_assign

* fix inplace bug

* add inplace_check_blacklist for assign

* fix merge conflict
上级 c921a812
......@@ -174,7 +174,10 @@ def RecoverBaseNameOfInplaceFunction(function_name):
def GetInplacedFunctionName(function_name):
return function_name + "_"
inplace_func_name = function_name
if inplace_func_name[-1] != '_':
inplace_func_name += '_'
return inplace_func_name
def GetForwardFunctionName(string):
......
......@@ -34,6 +34,13 @@ from codegen_utils import FunctionGeneratorBase, GeneratorBase
from codegen_utils import ops_to_fill_zero_for_empty_grads
from codegen_utils import AssertMessage, GetIndent
# Note: assign is a inplace api when parameter(output) isn't none,
# so we should check parameter(output) with rule of inplace.
# But because there is no check in old dygraph mode, in order to
# keeping the code compatible, here we also skip inplace check in new dygraph temporarily,
# and this will be fixed in the futrue.
inplace_check_blacklist = set(["assign_out_"])
###########
## Utils ##
......@@ -848,13 +855,15 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
def GenerateForwardDefinitionAndDeclaration(self, is_inplaced):
namespace = self.namespace
if self.forward_api_name[-1] == '_' and not is_inplaced:
return
forward_api_name = GetInplacedFunctionName(
self.forward_api_name) if is_inplaced else self.forward_api_name
forward_inputs_position_map = self.forward_inputs_position_map
forward_outputs_position_map = self.forward_outputs_position_map
forward_attrs_list = self.forward_attrs_list
backward_grad_outputs_map = self.backward_grad_outputs_map
optional_inputs = self.optional_inputs
intermediate_outputs = self.intermediate_outputs
......@@ -994,17 +1003,26 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
inputs_autograd_meta_list = []
compute_require_grad_args_list = ["trace_backward"]
for name, (ttype, pos) in forward_inputs_position_map.items():
input_autograd_meta_name = GetAutoGradMetaName(name)
if IsPlainTensorType(ttype):
input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
else:
assert IsVectorTensorType(ttype)
input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name)
input_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
input_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
inputs_autograd_meta_list.append(input_autograd_meta)
compute_require_grad_args_list.append(input_autograd_meta_name)
# Has corresponding grad output
has_corresponding_grad_output = False
for _, (_, corresponding_pos,
_) in backward_grad_outputs_map.items():
if pos == corresponding_pos:
has_corresponding_grad_output = True
if has_corresponding_grad_output or (
name in forward_inplace_map and
forward_api_name not in inplace_check_blacklist):
input_autograd_meta_name = GetAutoGradMetaName(name)
if IsPlainTensorType(ttype):
input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});"
else:
assert IsVectorTensorType(ttype)
input_autograd_meta_vec_name = GetAutoGradMetaVectorName(
name)
input_autograd_meta = f"{indent}std::vector<egr::AutogradMeta*> {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n"
input_autograd_meta += f"{indent}std::vector<egr::AutogradMeta*>* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};"
inputs_autograd_meta_list.append(input_autograd_meta)
compute_require_grad_args_list.append(input_autograd_meta_name)
inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list)
compute_require_grad_args_str = ",".join(compute_require_grad_args_list)
......@@ -1038,9 +1056,11 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
bump_inplace_version_str = ""
if is_inplaced:
for inplace_name in forward_inplace_map.keys():
inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name)
check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
inplace_name, inplace_autograd_meta_name)
if forward_api_name not in inplace_check_blacklist:
inplace_autograd_meta_name = GetAutoGradMetaName(
inplace_name)
check_inplace_str += CHECK_INPLACE_TEMPLATE.format(
inplace_name, inplace_autograd_meta_name)
bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format(
inplace_name, inplace_name)
......
......@@ -381,7 +381,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
break
# Generate Python-C Function Definetion
self.python_c_function_str += PYTHON_C_FUNCTION_TEMPLATE.format(
python_c_inplace_func_str = PYTHON_C_FUNCTION_TEMPLATE.format(
inplaced_forward_api_name, pythonc_record_event_str,
inplaced_forward_api_name, get_eager_tensor_str,
parse_attributes_str, set_device_str,
......@@ -389,11 +389,20 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
inplaced_fwd_function_name, dygraph_function_call_str,
return_str)
# Generate Python-C Function Registration
self.python_c_function_reg_str += "\n," + PYTHON_C_FUNCTION_REG_TEMPLATE.format(
python_c_inplace_func_reg_str = PYTHON_C_FUNCTION_REG_TEMPLATE.format(
forward_api_name_prefix, inplaced_forward_api_name, namespace,
inplaced_forward_api_name, inplaced_forward_api_name)
# self.forward_api_name ending with '_' means it only has inplace api
if self.forward_api_name[-1] == '_':
self.python_c_function_str = python_c_inplace_func_str
# Generate Python-C Function Registration
self.python_c_function_reg_str = python_c_inplace_func_reg_str
else:
self.python_c_function_str += python_c_inplace_func_str
# Generate Python-C Function Registration
self.python_c_function_reg_str += "\n," + python_c_inplace_func_reg_str
def run(self):
# Initialized is_forward_only
self.CollectIsForwardOnly()
......
......@@ -1510,12 +1510,14 @@ def assign(x, output=None):
# isinstance(VarBase, Variable) == False. It will cause return None
# after this api.
if isinstance(input, (Variable, core.VarBase)):
if _non_static_mode():
if in_dygraph_mode():
if output is None:
if _in_legacy_dygraph():
output = core.VarBase()
else:
output = core.eager.Tensor()
output = _C_ops.final_state_assign(input)
else:
_C_ops.final_state_assign_out_(input, output)
elif _in_legacy_dygraph():
if output is None:
output = core.VarBase()
_C_ops.assign(input, output)
else:
check_dtype(input.dtype, 'input', [
......@@ -1575,7 +1577,7 @@ def assign(x, output=None):
value_name: values
})
if is_inplace and _non_static_mode():
if is_inplace and _in_legacy_dygraph():
output._bump_inplace_version()
return output
......
......@@ -189,6 +189,18 @@
func : assign
backward : assign_grad
- api : assign_out_
args : (Tensor x, Tensor output)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : assign
param : [x]
inplace : (output -> out)
backward : assign_out__grad
# atan
- api : atan
args : (Tensor x)
......
......@@ -311,7 +311,7 @@ class BaseAPI(object):
view_map = {}
in_out_mapping_list = api_item_yaml[mode].split(',')
for item in in_out_mapping_list:
result = re.search(r"(?P<in>\w+)\s*->\s(?P<out>\w+)", item)
result = re.search(r"(?P<in>\w+)\s*->\s*(?P<out>\w+)", item)
in_val = result.group('in')
out_val = result.group('out')
assert in_val in self.inputs['names'], \
......@@ -840,6 +840,8 @@ PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{
if self.is_base_api:
api_code = self.gene_base_api_code()
if len(self.inplace_map) > 0:
if self.api[-1] == '_':
api_code = ""
api_code = api_code + self.gene_base_api_code(inplace_flag=True)
return api_code
......
......@@ -132,7 +132,15 @@
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
param : [out_grad]
kernel :
func : assign
- backward_api : assign_out__grad
forward : assign_out_ (Tensor x, Tensor output) -> Tensor(out)
args : (Tensor out_grad)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
kernel :
func : assign
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册