From f30d00553ae41c543aaa829145c7b1bce7458b49 Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Tue, 19 Jan 2021 16:45:44 +0800 Subject: [PATCH] Fix the compiling error of update_loss_scaling when using cuda9. (#30538) --- .../operators/amp/update_loss_scaling_op.h | 29 ++++++++++++++----- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/amp/update_loss_scaling_op.h b/paddle/fluid/operators/amp/update_loss_scaling_op.h index db768f3f87..decc3c3b92 100644 --- a/paddle/fluid/operators/amp/update_loss_scaling_op.h +++ b/paddle/fluid/operators/amp/update_loss_scaling_op.h @@ -14,6 +14,9 @@ #pragma once +#if defined(PADDLE_WITH_CUDA) && defined(__NVCC__) +#include +#endif // PADDLE_WITH_CUDA && __NVCC__ #include #include #include "paddle/fluid/framework/operator.h" @@ -29,13 +32,23 @@ namespace operators { using Tensor = framework::Tensor; template -HOSTDEVICE void Update(const bool* found_inf_data, - const T* pre_loss_scaling_data, const int* good_in_data, - const int* bad_in_data, const int incr_every_n_steps, - const int decr_every_n_nan_or_inf, - const float incr_ratio, const float decr_ratio, - T* updated_loss_scaling_data, int* good_out_data, - int* bad_out_data) { +inline HOSTDEVICE bool check_finite(T value) { +#if defined(PADDLE_WITH_CUDA) && defined(__NVCC__) + return isfinite(value); +#else + return std::isfinite(value); +#endif +} + +template +inline HOSTDEVICE void Update(const bool* found_inf_data, + const T* pre_loss_scaling_data, + const int* good_in_data, const int* bad_in_data, + const int incr_every_n_steps, + const int decr_every_n_nan_or_inf, + const float incr_ratio, const float decr_ratio, + T* updated_loss_scaling_data, int* good_out_data, + int* bad_out_data) { if (*found_inf_data) { *good_out_data = 0; *bad_out_data = *bad_in_data + 1; @@ -51,7 +64,7 @@ HOSTDEVICE void Update(const bool* found_inf_data, *good_out_data = *good_in_data + 1; if (*good_out_data == incr_every_n_steps) { T new_loss_scaling = *pre_loss_scaling_data * incr_ratio; - *updated_loss_scaling_data = std::isfinite(new_loss_scaling) + *updated_loss_scaling_data = check_finite(new_loss_scaling) ? new_loss_scaling : *pre_loss_scaling_data; *good_out_data = 0; -- GitLab