未验证 提交 6f1bb3d6 编写于 作者: G Guoxia Wang 提交者: GitHub

Fix AdamW epsilon in CUDA kernel (#37746)

上级 340dfb26
......@@ -27,25 +27,25 @@ __global__ void AdamWKernelREG(MT beta1, MT beta2, MT epsilon, MT coeff,
T* param_out, const MT* master_param,
MT* master_param_out, int ndim) {
MT lr = *lr_ * lr_ratio;
MT lr_orig = lr;
MT beta1_pow = beta1_pow_;
MT beta2_pow = beta2_pow_;
lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
(static_cast<MT>(1.0) - beta1_pow);
int id = blockIdx.x * blockDim.x + threadIdx.x;
for (; id < ndim; id += gridDim.x * blockDim.x) {
MT p = master_param ? master_param[id] : static_cast<MT>(param[id]);
MT g = static_cast<MT>(grad[id]);
MT mom1 = moment1[id];
MT mom2 = moment2[id];
MT mom1 = static_cast<MT>(moment1[id]);
MT mom2 = static_cast<MT>(moment2[id]);
p *= (static_cast<MT>(1.0) - lr * coeff);
mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
p -= lr_orig * coeff * p;
p -= lr * (mom1 /
(sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
MT denom = (sqrt(mom2) / sqrt(static_cast<MT>(1.0) - beta2_pow)) + epsilon;
p += (mom1 / denom) * (-(lr / (static_cast<MT>(1.0) - beta1_pow)));
moment1_out[id] = mom1;
moment2_out[id] = mom2;
......@@ -63,13 +63,9 @@ __global__ void AdamWKernelMEM(
MT* moment2_out, const MT* lr_, const T* grad, const T* param, T* param_out,
const MT* master_param, MT* master_param_out, int ndim) {
MT lr = *lr_ * lr_ratio;
MT lr_orig = lr;
MT beta1_pow = *beta1_pow_;
MT beta2_pow = *beta2_pow_;
lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
(static_cast<MT>(1.0) - beta1_pow);
int id = blockIdx.x * blockDim.x + threadIdx.x;
for (; id < ndim; id += gridDim.x * blockDim.x) {
......@@ -77,11 +73,15 @@ __global__ void AdamWKernelMEM(
MT g = static_cast<MT>(grad[id]);
MT mom1 = static_cast<MT>(moment1[id]);
MT mom2 = static_cast<MT>(moment2[id]);
p *= (static_cast<MT>(1.0) - lr * coeff);
mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
p -= lr_orig * coeff * p;
p -= lr * (mom1 /
(sqrt(mom2) + epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
MT denom = (sqrt(mom2) / sqrt(static_cast<MT>(1.0) - beta2_pow)) + epsilon;
p += (mom1 / denom) * (-(lr / (static_cast<MT>(1.0) - beta1_pow)));
moment1_out[id] = mom1;
moment2_out[id] = mom2;
......@@ -109,10 +109,6 @@ __global__ void SparseAdamWCUDAKernelREG(
int ndim) {
int id = blockIdx.x * blockDim.x + threadIdx.x;
MT lr = *lr_ * lr_ratio;
MT lr_orig = lr;
lr *= sqrt(static_cast<MT>(1.0) - beta2_pow) /
(static_cast<MT>(1.0) - beta1_pow);
for (; id < ndim; id += blockDim.x * gridDim.x) {
auto row_idx =
......@@ -120,17 +116,23 @@ __global__ void SparseAdamWCUDAKernelREG(
if (lazy_mode && row_idx < 0) {
return;
} else {
MT mom1 = mom1_[id];
MT mom2 = mom2_[id];
MT mom1 = static_cast<MT>(mom1_[id]);
MT mom2 = static_cast<MT>(mom2_[id]);
MT p = master_param ? master_param[id] : static_cast<MT>(param_[id]);
MT g = row_idx >= 0
? static_cast<MT>(grad_[row_idx * row_numel + id % row_numel])
: static_cast<MT>(0);
p *= (static_cast<MT>(1.0) - lr * coeff);
mom1 = beta1 * mom1 + (static_cast<MT>(1.0) - beta1) * g;
mom2 = beta2 * mom2 + (static_cast<MT>(1.0) - beta2) * g * g;
p -= lr_orig * coeff * p;
p -= lr * (mom1 / (sqrt(mom2) +
epsilon * sqrt(static_cast<MT>(1.0) - beta2_pow)));
MT denom =
(sqrt(mom2) / sqrt(static_cast<MT>(1.0) - beta2_pow)) + epsilon;
p += (mom1 / denom) * (-(lr / (static_cast<MT>(1.0) - beta1_pow)));
// Write back to global memory
mom1_out_[id] = mom1;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册