diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py index 9d5706f65bdf0c3e4ca12897667a44748ad0db8e..2aa44de8497731285aa85adecaee4c125e1fa066 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py @@ -172,6 +172,25 @@ AMP_DYGRAPH_FUNCTION_TEMPLATE = \ }} """ +INPLACE_AMP_DYGRAPH_FUNCTION_TEMPLATE = \ +""" + using result_type = decltype({}({})); + std::unique_ptr out_ptr; + // AMP Logic + if (egr::Controller::Instance().GetAMPLevel() != paddle::imperative::AmpLevel::O0) {{ + VLOG(5) << "Check and Prepare For AMP"; + {} + paddle::small_vector, egr::kSlotSmallVectorSize> amp_tensors_vector = {}; + {} + {} + {} + out_ptr = std::make_unique({}({})); + }} else {{ + out_ptr = std::make_unique({}({})); + }} + result_type& out = *out_ptr; +""" + FUNCTION_SET_DEVICE_TEMPLATE = \ """{} if (paddle::platform::is_gpu_place(place)) {{ #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -531,7 +550,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): inplaced_fwd_function_name, dygraph_function_call_str, inplaced_fwd_function_name, dygraph_function_call_str) - inplace_amp_dygraph_function_str = AMP_DYGRAPH_FUNCTION_TEMPLATE.format( + inplace_amp_dygraph_function_str = INPLACE_AMP_DYGRAPH_FUNCTION_TEMPLATE.format( inplaced_fwd_function_name, dygraph_function_call_str, kernel_trans2_op_name_str, amp_tensors_vector_list_str, amp_tensors_vector_optional_list_str, amp_get_dst_dtype_str,