未验证 提交 e32e4a1d 编写于 作者: Z zyfncg 提交者: GitHub

Fix bug of amp code-gen (#44570)

* fix bug of amp code_gen

* fix bug
上级 6b6f7a21
......@@ -172,6 +172,25 @@ AMP_DYGRAPH_FUNCTION_TEMPLATE = \
}}
"""
# C++ snippet template for the AMP (auto mixed precision) handling of an
# in-place forward function. It is expanded with str.format(): every `{}` is a
# positional placeholder (11 in total) and doubled braces `{{` / `}}` become
# literal C++ braces in the generated code.
#
# The generated C++:
#   1. aliases `result_type` to the decltype of the forward call `{}({})` and
#      declares an empty std::unique_ptr<result_type> named `out_ptr`;
#   2. if the controller's AMP level is not O0, emits the placeholder-supplied
#      preparation statements, builds `amp_tensors_vector`, and stores the
#      result of the call `{}({})` in `out_ptr`;
#   3. otherwise stores the result of the call `{}({})` in `out_ptr` directly;
#   4. binds `out` as a reference to `*out_ptr` for the code that follows.
#
# NOTE(review): holding the result through a unique_ptr presumably lets `out`
# be declared outside the if/else without default-constructing result_type --
# confirm against the generated code.
# NOTE(review): the visible .format() call passes, in order,
# inplaced_fwd_function_name, dygraph_function_call_str,
# kernel_trans2_op_name_str, amp_tensors_vector_list_str,
# amp_tensors_vector_optional_list_str, amp_get_dst_dtype_str; the remaining
# five arguments are outside this view -- verify at the call site.
INPLACE_AMP_DYGRAPH_FUNCTION_TEMPLATE = \
"""
using result_type = decltype({}({}));
std::unique_ptr<result_type> out_ptr;
// AMP Logic
if (egr::Controller::Instance().GetAMPLevel() != paddle::imperative::AmpLevel::O0) {{
VLOG(5) << "Check and Prepare For AMP";
{}
paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallVectorSize> amp_tensors_vector = {};
{}
{}
{}
out_ptr = std::make_unique<result_type>({}({}));
}} else {{
out_ptr = std::make_unique<result_type>({}({}));
}}
result_type& out = *out_ptr;
"""
FUNCTION_SET_DEVICE_TEMPLATE = \
"""{} if (paddle::platform::is_gpu_place(place)) {{
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
......@@ -531,7 +550,7 @@ class PythonCSingleFunctionGenerator(FunctionGeneratorBase):
inplaced_fwd_function_name, dygraph_function_call_str,
inplaced_fwd_function_name, dygraph_function_call_str)
-inplace_amp_dygraph_function_str = AMP_DYGRAPH_FUNCTION_TEMPLATE.format(
+inplace_amp_dygraph_function_str = INPLACE_AMP_DYGRAPH_FUNCTION_TEMPLATE.format(
inplaced_fwd_function_name, dygraph_function_call_str,
kernel_trans2_op_name_str, amp_tensors_vector_list_str,
amp_tensors_vector_optional_list_str, amp_get_dst_dtype_str,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册