From 52f07ab4c6ffe091ce7cb72a4097e030b45ec906 Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Sun, 27 Mar 2022 08:38:57 +0800 Subject: [PATCH] Fix amp with optiontional api bug (#40980) * fix amp with optiontional api bug * refine optional code for amp --- .../auto_code_generator/final_state_generator/eager_gen.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index 33c01c83c47..b87e7d5f8c1 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -801,10 +801,10 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): if is_optional: arg_str = f"const paddle::optional {name}" amp_tensors_vector_optional_list.append( - f"if ({name}.is)initialized() amp_tensors_vector.push_back({name}.get()));\n" + f"if ({name}.get_ptr() != nullptr) amp_tensors_vector.push_back({{ *({name}.get_ptr()) }});\n" ) amp_autocast_optional_list.append( - f"auto NEW_{name} = {name}.is_initialized() ? egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name) : {name};\n" + f"auto NEW_{name} = ({name}.get_ptr() != nullptr) ? paddle::make_optional(egr::EagerAmpAutoCast(\"{name}\", *({name}.get_ptr()), amp_dst_dtype, op_name)) : {name};\n" ) else: if inplace_map and name in inplace_map.keys(): @@ -895,7 +895,8 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase): amp_tensors_vector_optional_list) amp_get_dst_dtype_str = f"auto amp_dst_dtype = egr::GetAmpDestDtype(op_name, amp_tensors_vector);\n" amp_autocast_list_str = " ".join( - amp_autocast_list) + " ".join(amp_autocast_optional_list) + amp_autocast_list) + " " + " ".join( + amp_autocast_optional_list) amp_inputs_call_args_str = ", ".join(amp_inputs_call_list) amp_call_str = f"return {forward_function_name}({amp_inputs_call_args_str});" if is_inplaced or (forward_api_name == "cast"): -- GitLab