diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
index 33c01c83c47837cc441d4b1a3cdfaed26f3da2bb..b87e7d5f8c14c127fcd9ad4dce9a2b563e4403ae 100644
--- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
+++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py
@@ -801,10 +801,10 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
                 if is_optional:
                     arg_str = f"const paddle::optional {name}"
                     amp_tensors_vector_optional_list.append(
-                        f"if ({name}.is_initialized()) amp_tensors_vector.push_back({name}.get());\n"
+                        f"if ({name}.get_ptr() != nullptr) amp_tensors_vector.push_back({{ *({name}.get_ptr()) }});\n"
                     )
                     amp_autocast_optional_list.append(
-                        f"auto NEW_{name} = {name}.is_initialized() ? egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name) : {name};\n"
+                        f"auto NEW_{name} = ({name}.get_ptr() != nullptr) ? paddle::make_optional(egr::EagerAmpAutoCast(\"{name}\", *({name}.get_ptr()), amp_dst_dtype, op_name)) : {name};\n"
                     )
                 else:
                     if inplace_map and name in inplace_map.keys():
@@ -895,7 +895,8 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
             amp_tensors_vector_optional_list)
         amp_get_dst_dtype_str = f"auto amp_dst_dtype = egr::GetAmpDestDtype(op_name, amp_tensors_vector);\n"
         amp_autocast_list_str = " ".join(
-            amp_autocast_list) + " ".join(amp_autocast_optional_list)
+            amp_autocast_list) + " " + " ".join(
+                amp_autocast_optional_list)
         amp_inputs_call_args_str = ", ".join(amp_inputs_call_list)
         amp_call_str = f"return {forward_function_name}({amp_inputs_call_args_str});"
         if is_inplaced or (forward_api_name == "cast"):
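
For reference, substituting a hypothetical optional input named "bias" into the new f-strings gives roughly the following generated C++. This is only a sketch of the emitted output (the name "bias" is illustrative and surrounding generated code is omitted); the old templates emitted an is_initialized()/get() based check, while the new ones test get_ptr() and wrap the autocast result back into an optional via paddle::make_optional:

    // Sketch (assumption): C++ emitted by the new templates for a hypothetical
    // optional input "bias"; not a complete, compilable unit on its own.
    if (bias.get_ptr() != nullptr) amp_tensors_vector.push_back({ *(bias.get_ptr()) });
    auto NEW_bias = (bias.get_ptr() != nullptr) ? paddle::make_optional(egr::EagerAmpAutoCast("bias", *(bias.get_ptr()), amp_dst_dtype, op_name)) : bias;

The second hunk separately inserts a single space between the joined required-input autocast statements and the joined optional-input autocast statements when assembling amp_autocast_list_str.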