未验证 提交 f30d0055 编写于 作者: Z Zhen Wang 提交者: GitHub

Fix the compiling error of update_loss_scaling when using cuda9. (#30538)

上级 81217a94
...@@ -14,6 +14,9 @@ ...@@ -14,6 +14,9 @@
#pragma once #pragma once
#if defined(PADDLE_WITH_CUDA) && defined(__NVCC__)
#include <cuda.h>
#endif // PADDLE_WITH_CUDA && __NVCC__
#include <cmath> #include <cmath>
#include <vector> #include <vector>
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
...@@ -29,13 +32,23 @@ namespace operators { ...@@ -29,13 +32,23 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename T> template <typename T>
HOSTDEVICE void Update(const bool* found_inf_data, inline HOSTDEVICE bool check_finite(T value) {
const T* pre_loss_scaling_data, const int* good_in_data, #if defined(PADDLE_WITH_CUDA) && defined(__NVCC__)
const int* bad_in_data, const int incr_every_n_steps, return isfinite(value);
const int decr_every_n_nan_or_inf, #else
const float incr_ratio, const float decr_ratio, return std::isfinite(value);
T* updated_loss_scaling_data, int* good_out_data, #endif
int* bad_out_data) { }
template <typename T>
inline HOSTDEVICE void Update(const bool* found_inf_data,
const T* pre_loss_scaling_data,
const int* good_in_data, const int* bad_in_data,
const int incr_every_n_steps,
const int decr_every_n_nan_or_inf,
const float incr_ratio, const float decr_ratio,
T* updated_loss_scaling_data, int* good_out_data,
int* bad_out_data) {
if (*found_inf_data) { if (*found_inf_data) {
*good_out_data = 0; *good_out_data = 0;
*bad_out_data = *bad_in_data + 1; *bad_out_data = *bad_in_data + 1;
...@@ -51,7 +64,7 @@ HOSTDEVICE void Update(const bool* found_inf_data, ...@@ -51,7 +64,7 @@ HOSTDEVICE void Update(const bool* found_inf_data,
*good_out_data = *good_in_data + 1; *good_out_data = *good_in_data + 1;
if (*good_out_data == incr_every_n_steps) { if (*good_out_data == incr_every_n_steps) {
T new_loss_scaling = *pre_loss_scaling_data * incr_ratio; T new_loss_scaling = *pre_loss_scaling_data * incr_ratio;
*updated_loss_scaling_data = std::isfinite(new_loss_scaling) *updated_loss_scaling_data = check_finite(new_loss_scaling)
? new_loss_scaling ? new_loss_scaling
: *pre_loss_scaling_data; : *pre_loss_scaling_data;
*good_out_data = 0; *good_out_data = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册