未验证 提交 52f07ab4 编写于 作者: Z zhangbo9674 提交者: GitHub

Fix amp with optiontional api bug (#40980)

* fix amp with optiontional api bug

* refine optional code for amp
上级 0695e1ac
......@@ -801,10 +801,10 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
if is_optional:
arg_str = f"const paddle::optional<const paddle::experimental::Tensor&> {name}"
amp_tensors_vector_optional_list.append(
f"if ({name}.is)initialized() amp_tensors_vector.push_back({name}.get()));\n"
f"if ({name}.get_ptr() != nullptr) amp_tensors_vector.push_back({{ *({name}.get_ptr()) }});\n"
)
amp_autocast_optional_list.append(
f"auto NEW_{name} = {name}.is_initialized() ? egr::EagerAmpAutoCast(\"{name}\", {name}, amp_dst_dtype, op_name) : {name};\n"
f"auto NEW_{name} = ({name}.get_ptr() != nullptr) ? paddle::make_optional<const paddle::experimental::Tensor&>(egr::EagerAmpAutoCast(\"{name}\", *({name}.get_ptr()), amp_dst_dtype, op_name)) : {name};\n"
)
else:
if inplace_map and name in inplace_map.keys():
......@@ -895,7 +895,8 @@ class DygraphSingleFunctionGenerator(FunctionGeneratorBase):
amp_tensors_vector_optional_list)
amp_get_dst_dtype_str = f"auto amp_dst_dtype = egr::GetAmpDestDtype(op_name, amp_tensors_vector);\n"
amp_autocast_list_str = " ".join(
amp_autocast_list) + " ".join(amp_autocast_optional_list)
amp_autocast_list) + " " + " ".join(
amp_autocast_optional_list)
amp_inputs_call_args_str = ", ".join(amp_inputs_call_list)
amp_call_str = f"return {forward_function_name}({amp_inputs_call_args_str});"
if is_inplaced or (forward_api_name == "cast"):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册