diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py index 5b48fb74f5383afba3a56fa3e13793ecac2ee110..9849dc48fc490a0618db6bde784a20e4ba30f7e9 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/codegen_utils.py @@ -174,7 +174,10 @@ def RecoverBaseNameOfInplaceFunction(function_name): def GetInplacedFunctionName(function_name): - return function_name + "_" + inplace_func_name = function_name + if inplace_func_name[-1] != '_': + inplace_func_name += '_' + return inplace_func_name def GetForwardFunctionName(string): diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index 9bee8f5f2975354c1d428ab8dcc6cee892402007..403216813dd363cccd1bbc8165e70c881d0d78fa 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -34,6 +34,13 @@ from codegen_utils import FunctionGeneratorBase, GeneratorBase from codegen_utils import ops_to_fill_zero_for_empty_grads from codegen_utils import AssertMessage, GetIndent +# Note: assign is a inplace api when parameter(output) isn't none, +# so we should check parameter(output) with rule of inplace. +# But because there is no check in old dygraph mode, in order to +# keeping the code compatible, here we also skip inplace check in new dygraph temporarily, +# and this will be fixed in the futrue. +inplace_check_blacklist = set(["assign_out_"]) + ########### ## Utils ## @@ -848,13 +855,15 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): def GenerateForwardDefinitionAndDeclaration(self, is_inplaced): namespace = self.namespace - + if self.forward_api_name[-1] == '_' and not is_inplaced: + return forward_api_name = GetInplacedFunctionName( self.forward_api_name) if is_inplaced else self.forward_api_name forward_inputs_position_map = self.forward_inputs_position_map forward_outputs_position_map = self.forward_outputs_position_map forward_attrs_list = self.forward_attrs_list + backward_grad_outputs_map = self.backward_grad_outputs_map optional_inputs = self.optional_inputs intermediate_outputs = self.intermediate_outputs @@ -994,17 +1003,26 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): inputs_autograd_meta_list = [] compute_require_grad_args_list = ["trace_backward"] for name, (ttype, pos) in forward_inputs_position_map.items(): - input_autograd_meta_name = GetAutoGradMetaName(name) - if IsPlainTensorType(ttype): - input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});" - else: - assert IsVectorTensorType(ttype) - input_autograd_meta_vec_name = GetAutoGradMetaVectorName(name) - input_autograd_meta = f"{indent}std::vector {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n" - input_autograd_meta += f"{indent}std::vector* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};" - - inputs_autograd_meta_list.append(input_autograd_meta) - compute_require_grad_args_list.append(input_autograd_meta_name) + # Has corresponding grad output + has_corresponding_grad_output = False + for _, (_, corresponding_pos, + _) in backward_grad_outputs_map.items(): + if pos == corresponding_pos: + has_corresponding_grad_output = True + if has_corresponding_grad_output or ( + name in forward_inplace_map and + forward_api_name not in inplace_check_blacklist): + input_autograd_meta_name = GetAutoGradMetaName(name) + if IsPlainTensorType(ttype): + input_autograd_meta = f"{indent}egr::AutogradMeta* {input_autograd_meta_name} = egr::EagerUtils::nullable_autograd_meta({name});" + else: + assert IsVectorTensorType(ttype) + input_autograd_meta_vec_name = GetAutoGradMetaVectorName( + name) + input_autograd_meta = f"{indent}std::vector {input_autograd_meta_vec_name} = egr::EagerUtils::nullable_autograd_meta({name});\n" + input_autograd_meta += f"{indent}std::vector* {input_autograd_meta_name} = &{input_autograd_meta_vec_name};" + inputs_autograd_meta_list.append(input_autograd_meta) + compute_require_grad_args_list.append(input_autograd_meta_name) inputs_autograd_meta_str = "\n".join(inputs_autograd_meta_list) compute_require_grad_args_str = ",".join(compute_require_grad_args_list) @@ -1038,9 +1056,11 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): bump_inplace_version_str = "" if is_inplaced: for inplace_name in forward_inplace_map.keys(): - inplace_autograd_meta_name = GetAutoGradMetaName(inplace_name) - check_inplace_str += CHECK_INPLACE_TEMPLATE.format( - inplace_name, inplace_autograd_meta_name) + if forward_api_name not in inplace_check_blacklist: + inplace_autograd_meta_name = GetAutoGradMetaName( + inplace_name) + check_inplace_str += CHECK_INPLACE_TEMPLATE.format( + inplace_name, inplace_autograd_meta_name) bump_inplace_version_str += BUMP_INPLACE_VERSION_TEMPLATE.format( inplace_name, inplace_name) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py index 602d38510c04f08702472ccf8359c4afb17e0032..c02400299dfa66085dc34867ae2cbf5e0a6928be 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py @@ -381,7 +381,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): break # Generate Python-C Function Definetion - self.python_c_function_str += PYTHON_C_FUNCTION_TEMPLATE.format( + python_c_inplace_func_str = PYTHON_C_FUNCTION_TEMPLATE.format( inplaced_forward_api_name, pythonc_record_event_str, inplaced_forward_api_name, get_eager_tensor_str, parse_attributes_str, set_device_str, @@ -389,11 +389,20 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): inplaced_fwd_function_name, dygraph_function_call_str, return_str) - # Generate Python-C Function Registration - self.python_c_function_reg_str += "\n," + PYTHON_C_FUNCTION_REG_TEMPLATE.format( + python_c_inplace_func_reg_str = PYTHON_C_FUNCTION_REG_TEMPLATE.format( forward_api_name_prefix, inplaced_forward_api_name, namespace, inplaced_forward_api_name, inplaced_forward_api_name) + # self.forward_api_name ending with '_' means it only has inplace api + if self.forward_api_name[-1] == '_': + self.python_c_function_str = python_c_inplace_func_str + # Generate Python-C Function Registration + self.python_c_function_reg_str = python_c_inplace_func_reg_str + else: + self.python_c_function_str += python_c_inplace_func_str + # Generate Python-C Function Registration + self.python_c_function_reg_str += "\n," + python_c_inplace_func_reg_str + def run(self): # Initialized is_forward_only self.CollectIsForwardOnly() diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index c7e73cec47bead602e4a93b9e436c1a90dca9fa0..d3430ba81b8599ff6c696656f49a5a0a6cc8408b 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1510,12 +1510,14 @@ def assign(x, output=None): # isinstance(VarBase, Variable) == False. It will cause return None # after this api. if isinstance(input, (Variable, core.VarBase)): - if _non_static_mode(): + if in_dygraph_mode(): if output is None: - if _in_legacy_dygraph(): - output = core.VarBase() - else: - output = core.eager.Tensor() + output = _C_ops.final_state_assign(input) + else: + _C_ops.final_state_assign_out_(input, output) + elif _in_legacy_dygraph(): + if output is None: + output = core.VarBase() _C_ops.assign(input, output) else: check_dtype(input.dtype, 'input', [ @@ -1575,7 +1577,7 @@ def assign(x, output=None): value_name: values }) - if is_inplace and _non_static_mode(): + if is_inplace and _in_legacy_dygraph(): output._bump_inplace_version() return output diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 6c15b4a0128330f71dde56e2e4f02244e1f13808..1a740f47f46f5b8008c7d25a6c4a537ef189565f 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -189,6 +189,18 @@ func : assign backward : assign_grad +- api : assign_out_ + args : (Tensor x, Tensor output) + output : Tensor(out) + infer_meta : + func : UnchangedInferMeta + param : [x] + kernel : + func : assign + param : [x] + inplace : (output -> out) + backward : assign_out__grad + # atan - api : atan args : (Tensor x) diff --git a/python/paddle/utils/code_gen/api_base.py b/python/paddle/utils/code_gen/api_base.py index 96896b65f404115a6c4cbd6d227af8be06342a9c..ac9a4315937761a8260f7db4fc130055823fa904 100644 --- a/python/paddle/utils/code_gen/api_base.py +++ b/python/paddle/utils/code_gen/api_base.py @@ -311,7 +311,7 @@ class BaseAPI(object): view_map = {} in_out_mapping_list = api_item_yaml[mode].split(',') for item in in_out_mapping_list: - result = re.search(r"(?P\w+)\s*->\s(?P\w+)", item) + result = re.search(r"(?P\w+)\s*->\s*(?P\w+)", item) in_val = result.group('in') out_val = result.group('out') assert in_val in self.inputs['names'], \ @@ -840,6 +840,8 @@ PADDLE_API {self.get_return_type()} {self.api}({params_code}) {{ if self.is_base_api: api_code = self.gene_base_api_code() if len(self.inplace_map) > 0: + if self.api[-1] == '_': + api_code = "" api_code = api_code + self.gene_base_api_code(inplace_flag=True) return api_code diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index 9b3d2d94c9341eb07bd1f145281323a3dc21ecab..19343c5873db63aec180b3aa1d5a71f2118cee69 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -132,7 +132,15 @@ output : Tensor(x_grad) infer_meta : func : UnchangedInferMeta - param : [out_grad] + kernel : + func : assign + +- backward_api : assign_out__grad + forward : assign_out_ (Tensor x, Tensor output) -> Tensor(out) + args : (Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta kernel : func : assign