From e32e4a1d20e5936cdf371a0d37a5e62e3c067c3a Mon Sep 17 00:00:00 2001 From: zyfncg Date: Mon, 25 Jul 2022 09:59:02 +0800 Subject: [PATCH] Fix bug of amp code-gen (#44570) * fix bug of amp code_gen * fix bug --- .../final_state_generator/python_c_gen.py | 21 ++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py index 9d5706f65bd..2aa44de8497 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/python_c_gen.py @@ -172,6 +172,25 @@ AMP_DYGRAPH_FUNCTION_TEMPLATE = \ }} """ +INPLACE_AMP_DYGRAPH_FUNCTION_TEMPLATE = \ +""" + using result_type = decltype({}({})); + std::unique_ptr<result_type> out_ptr; + // AMP Logic + if (egr::Controller::Instance().GetAMPLevel() != paddle::imperative::AmpLevel::O0) {{ + VLOG(5) << "Check and Prepare For AMP"; + {} + paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> amp_tensors_vector = {}; + {} + {} + {} + out_ptr = std::make_unique<result_type>({}({})); + }} else {{ + out_ptr = std::make_unique<result_type>({}({})); + }} + result_type& out = *out_ptr; +""" + FUNCTION_SET_DEVICE_TEMPLATE = \ """{} if (paddle::platform::is_gpu_place(place)) {{ #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -531,7 +550,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase): inplaced_fwd_function_name, dygraph_function_call_str, inplaced_fwd_function_name, dygraph_function_call_str) - inplace_amp_dygraph_function_str = AMP_DYGRAPH_FUNCTION_TEMPLATE.format( + inplace_amp_dygraph_function_str = INPLACE_AMP_DYGRAPH_FUNCTION_TEMPLATE.format( inplaced_fwd_function_name, dygraph_function_call_str, kernel_trans2_op_name_str, amp_tensors_vector_list_str, amp_tensors_vector_optional_list_str, amp_get_dst_dtype_str, -- GitLab