Unverified commit acbb5dbe, authored by ronnywang, committed by GitHub

[CustomDevice] add amp support (#42035)

Parent c7a258fe
@@ -220,6 +220,7 @@ inline bool NeedCast(const std::shared_ptr<VarType>& var) {
       paddle::platform::is_cuda_pinned_place(place) ||
       paddle::platform::is_xpu_place(place) ||
       paddle::platform::is_mlu_place(place) ||
+      paddle::platform::is_custom_place(place) ||
       paddle::platform::is_npu_place(place) ||
       paddle::platform::is_npu_pinned_place(place)) {
     // CudaPinndePlace is added for varbase created by dataloader
@@ -276,9 +276,10 @@ def amp_guard(enable=True,
     if enable and not (tracer._expected_place.is_gpu_place() or
                        tracer._expected_place.is_xpu_place() or
                        tracer._expected_place.is_mlu_place() or
-                       tracer._expected_place.is_npu_place()):
+                       tracer._expected_place.is_npu_place() or
+                       tracer._expected_place.is_custom_place()):
         warnings.warn(
-            'amp_guard can only be enabled on CUDAPlace, XPUPlace, MLUPlace, and NPUPlace, current place is %s, so it makes no effect.'
+            'amp_guard can only be enabled on CUDAPlace, XPUPlace, MLUPlace, NPUPlace, and CustomPlace, current place is %s, so it makes no effect.'
             % tracer._expected_place)
         enable = False
     # For npu:
@@ -293,6 +294,10 @@ def amp_guard(enable=True,
     if tracer._expected_place.is_mlu_place() and (dtype == 'bfloat16'):
         warnings.warn('MLUPlace only support float16 amp.')
         enable = False
+    # For custom device:
+    if tracer._expected_place.is_custom_place() and (dtype == 'bfloat16'):
+        warnings.warn('CustomPlace only support float16 amp.')
+        enable = False
     # For gpu float16: Compute Capability should >= 7.
     # For gpu bfloat16: Compute Capability should >= 8 & CUDA Version should >= 11.
     if tracer._expected_place.is_gpu_place():
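For context, here is a minimal usage sketch of the behavior these amp_guard changes enable, via paddle.amp.auto_cast, the public wrapper around amp_guard. The device type 'custom_cpu' is an assumption (the name of Paddle's sample CustomDevice plugin); substitute whatever plugin is actually installed:

import paddle

# Assumption: a CustomDevice plugin registered as 'custom_cpu' is installed.
paddle.set_device('custom_cpu')

conv = paddle.nn.Conv2D(4, 6, (3, 3))
x = paddle.rand([1, 4, 28, 28])

# After this patch, auto_cast no longer warns and silently disables itself
# on a CustomPlace. Only float16 is supported: dtype='bfloat16' still
# triggers the warning added above and turns AMP off.
with paddle.amp.auto_cast(dtype='float16'):
    y = conv(x)  # runs in float16 where the op supports it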
@@ -107,9 +107,10 @@ class AmpScaler(object):
         if enable and not (tracer._expected_place.is_gpu_place() or
                            tracer._expected_place.is_xpu_place() or
                            tracer._expected_place.is_mlu_place() or
-                           tracer._expected_place.is_npu_place()):
+                           tracer._expected_place.is_npu_place() or
+                           tracer._expected_place.is_custom_place()):
             warnings.warn(
-                'AmpScaler can only be enabled on CUDAPlace, XPUPlace, MLUPlace and NPUPlace, current place is %s, so it makes no effect.'
+                'AmpScaler can only be enabled on CUDAPlace, XPUPlace, MLUPlace, NPUPlace and CustomPlace, current place is %s, so it makes no effect.'
                 % tracer._expected_place)
             enable = False
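Likewise, a hedged sketch of dynamic loss scaling on a custom device, using paddle.amp.GradScaler, the public counterpart of the AmpScaler class patched above. The device name is again the assumed sample plugin, not something this commit mandates:

import paddle

paddle.set_device('custom_cpu')  # assumption: sample CustomDevice plugin

model = paddle.nn.Linear(10, 10)
opt = paddle.optimizer.SGD(learning_rate=0.01,
                           parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

data = paddle.rand([4, 10])
with paddle.amp.auto_cast(dtype='float16'):
    loss = model(data).mean()

scaled = scaler.scale(loss)   # scale the loss to reduce float16 underflow
scaled.backward()
scaler.minimize(opt, scaled)  # unscales gradients, then runs the opt step
opt.clear_grad()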