diff --git a/paddle/fluid/framework/details/nan_inf_utils_detail.cc b/paddle/fluid/framework/details/nan_inf_utils_detail.cc index 67a4c0caf8e19d8e09cc7d90fc682f8851dc9200..0ad84f5890acaf1c793000859ed3fbc7c1fc22d3 100644 --- a/paddle/fluid/framework/details/nan_inf_utils_detail.cc +++ b/paddle/fluid/framework/details/nan_inf_utils_detail.cc @@ -207,6 +207,21 @@ void CheckNanInf( PrintNanInf(value, numel, print_num, op_type, var_name); } } + +template <> +void CheckNanInf( + const paddle::platform::bfloat16* value, const size_t numel, int print_num, + const std::string& op_type, const std::string& var_name) { + float sum = 0.0f; +#pragma omp parallel for reduction(+ : sum) + for (size_t i = 0; i < numel; ++i) { + sum += static_cast(value[i] - value[i]); + } + + if (std::isnan(sum) || std::isinf(sum)) { + PrintNanInf(value, numel, print_num, op_type, var_name); + } +} #endif template <>