未验证 提交 acbb5dbe 编写于 作者: R ronnywang 提交者: GitHub

[CustomDevice] add amp support (#42035)

上级 c7a258fe
......@@ -220,6 +220,7 @@ inline bool NeedCast(const std::shared_ptr<VarType>& var) {
paddle::platform::is_cuda_pinned_place(place) ||
paddle::platform::is_xpu_place(place) ||
paddle::platform::is_mlu_place(place) ||
+ paddle::platform::is_custom_place(place) ||
paddle::platform::is_npu_place(place) ||
paddle::platform::is_npu_pinned_place(place)) {
// CUDAPinnedPlace is added for varbase created by dataloader
......
......@@ -276,9 +276,10 @@ def amp_guard(enable=True,
if enable and not (tracer._expected_place.is_gpu_place() or
tracer._expected_place.is_xpu_place() or
tracer._expected_place.is_mlu_place() or
- tracer._expected_place.is_npu_place()):
+ tracer._expected_place.is_npu_place() or
+ tracer._expected_place.is_custom_place()):
warnings.warn(
- 'amp_guard can only be enabled on CUDAPlace, XPUPlace, MLUPlace, and NPUPlace, current place is %s, so it makes no effect.'
+ 'amp_guard can only be enabled on CUDAPlace, XPUPlace, MLUPlace, NPUPlace, and CustomPlace, current place is %s, so it makes no effect.'
% tracer._expected_place)
enable = False
# For npu:
......@@ -293,6 +294,10 @@ def amp_guard(enable=True,
if tracer._expected_place.is_mlu_place() and (dtype == 'bfloat16'):
warnings.warn('MLUPlace only support float16 amp.')
enable = False
+ # For custom device:
+ if tracer._expected_place.is_custom_place() and (dtype == 'bfloat16'):
+ warnings.warn('CustomPlace only support float16 amp.')
+ enable = False
# For gpu float16: Compute Capability should >= 7.
# For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11.
if tracer._expected_place.is_gpu_place():
......
......@@ -107,9 +107,10 @@ class AmpScaler(object):
if enable and not (tracer._expected_place.is_gpu_place() or
tracer._expected_place.is_xpu_place() or
tracer._expected_place.is_mlu_place() or
- tracer._expected_place.is_npu_place()):
+ tracer._expected_place.is_npu_place() or
+ tracer._expected_place.is_custom_place()):
warnings.warn(
- 'AmpScaler can only be enabled on CUDAPlace, XPUPlace, MLUPlace and NPUPlace, current place is %s, so it makes no effect.'
+ 'AmpScaler can only be enabled on CUDAPlace, XPUPlace, MLUPlace, NPUPlace and CustomPlace, current place is %s, so it makes no effect.'
% tracer._expected_place)
enable = False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册