未验证 提交 69dd43d1 编写于 作者: F furnace 提交者: GitHub

[NPU] add AMP O1 support (#40362)

* [NPU] add AMP O1 support

* [NPU] fix NOTE and warnings
上级 2c5edb4f
...@@ -209,7 +209,9 @@ inline bool NeedCast(const std::shared_ptr<VarType>& var) { ...@@ -209,7 +209,9 @@ inline bool NeedCast(const std::shared_ptr<VarType>& var) {
auto data_type = GetDataType<VarType>(var); auto data_type = GetDataType<VarType>(var);
if (paddle::platform::is_gpu_place(place) || if (paddle::platform::is_gpu_place(place) ||
paddle::platform::is_cuda_pinned_place(place) || paddle::platform::is_cuda_pinned_place(place) ||
paddle::platform::is_xpu_place(place)) { paddle::platform::is_xpu_place(place) ||
paddle::platform::is_npu_place(place) ||
paddle::platform::is_npu_pinned_place(place)) {
// CudaPinndePlace is added for varbase created by dataloader // CudaPinndePlace is added for varbase created by dataloader
if (data_type == paddle::framework::proto::VarType::FP32 || if (data_type == paddle::framework::proto::VarType::FP32 ||
data_type == paddle::framework::proto::VarType::FP16 || data_type == paddle::framework::proto::VarType::FP16 ||
......
...@@ -88,6 +88,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = { ...@@ -88,6 +88,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"nce", {"nce",
{"Input", "Label", "Weight", "Bias", "SampleWeight", "CustomDistProbs", {"Input", "Label", "Weight", "Bias", "SampleWeight", "CustomDistProbs",
"CustomDistAlias", "CustomDistAliasProbs"}}, "CustomDistAlias", "CustomDistAliasProbs"}},
{"check_finite_and_unscale", {"X", "Scale", "FloatStatus"}},
}; };
// NOTE(zhiqiu): Like op_ins_map. // NOTE(zhiqiu): Like op_ins_map.
......
...@@ -271,14 +271,19 @@ def amp_guard(enable=True, ...@@ -271,14 +271,19 @@ def amp_guard(enable=True,
"current_tracer is None, maybe it is not in imperative mode.") "current_tracer is None, maybe it is not in imperative mode.")
# check device_type: # check device_type:
# NOTE: Now, amp only support gpu for float16 and bfloat16, xpu for float16. # NOTE: Now, amp only support gpu for float16 and bfloat16, xpu for float16, npu for float16.
# Maybe we will support cpu for bfloat16. # Maybe we will support cpu for bfloat16.
if enable and not (tracer._expected_place.is_gpu_place() or if enable and not (tracer._expected_place.is_gpu_place() or
tracer._expected_place.is_xpu_place()): tracer._expected_place.is_xpu_place() or
tracer._expected_place.is_npu_place()):
warnings.warn( warnings.warn(
'amp_guard can only be enabled on CUDAPlace and XPUPlace, current place is %s, so it makes no effect.' 'amp_guard can only be enabled on CUDAPlace, XPUPlace, and NPUPlace, current place is %s, so it makes no effect.'
% tracer._expected_place) % tracer._expected_place)
enable = False enable = False
# For npu:
if tracer._expected_place.is_npu_place() and (dtype == 'bfloat16'):
warnings.warn('NPUPlace only support float16 amp.')
enable = False
# For xpu: # For xpu:
if tracer._expected_place.is_xpu_place() and (dtype == 'bfloat16'): if tracer._expected_place.is_xpu_place() and (dtype == 'bfloat16'):
warnings.warn('XPUPlace only support float16 amp.') warnings.warn('XPUPlace only support float16 amp.')
......
...@@ -105,9 +105,10 @@ class AmpScaler(object): ...@@ -105,9 +105,10 @@ class AmpScaler(object):
"current_tracer is None, maybe it is not in imperative mode.") "current_tracer is None, maybe it is not in imperative mode.")
if enable and not (tracer._expected_place.is_gpu_place() or if enable and not (tracer._expected_place.is_gpu_place() or
tracer._expected_place.is_xpu_place()): tracer._expected_place.is_xpu_place() or
tracer._expected_place.is_npu_place()):
warnings.warn( warnings.warn(
'AmpScaler can only be enabled on CUDAPlace and XPUPlace, current place is %s, so it makes no effect.' 'AmpScaler can only be enabled on CUDAPlace, XPUPlace and NPUPlace, current place is %s, so it makes no effect.'
% tracer._expected_place) % tracer._expected_place)
enable = False enable = False
...@@ -286,6 +287,19 @@ class AmpScaler(object): ...@@ -286,6 +287,19 @@ class AmpScaler(object):
) and (param._grad_ivar().dtype == core.VarDesc.VarType.FP32 ) and (param._grad_ivar().dtype == core.VarDesc.VarType.FP32
) )
] ]
if core.is_compiled_with_npu():
float_status = _C_ops.alloc_float_status()
_C_ops.clear_float_status(float_status, float_status)
if len(param_grads_fp16):
_C_ops.check_finite_and_unscale(param_grads_fp16, self._scale,
float_status, param_grads_fp16,
self._temp_found_inf_fp16)
if len(param_grads_fp32):
_C_ops.check_finite_and_unscale(param_grads_fp32, self._scale,
float_status, param_grads_fp32,
self._temp_found_inf_fp32)
else:
if len(param_grads_fp16): if len(param_grads_fp16):
_C_ops.check_finite_and_unscale(param_grads_fp16, self._scale, _C_ops.check_finite_and_unscale(param_grads_fp16, self._scale,
param_grads_fp16, param_grads_fp16,
...@@ -294,6 +308,7 @@ class AmpScaler(object): ...@@ -294,6 +308,7 @@ class AmpScaler(object):
_C_ops.check_finite_and_unscale(param_grads_fp32, self._scale, _C_ops.check_finite_and_unscale(param_grads_fp32, self._scale,
param_grads_fp32, param_grads_fp32,
self._temp_found_inf_fp32) self._temp_found_inf_fp32)
if len(param_grads_fp16) and len(param_grads_fp32): if len(param_grads_fp16) and len(param_grads_fp32):
self._found_inf = self._temp_found_inf_fp16 or self._temp_found_inf_fp32 self._found_inf = self._temp_found_inf_fp16 or self._temp_found_inf_fp32
elif len(param_grads_fp16): elif len(param_grads_fp16):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册