提交 3135fbcc 编写于 作者: S seiriosPlus

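Fix Adam bias correction in LargeScaleFuseAdamOpKernel: scale the learning rate by sqrt(1 - beta2_pow) / (1 - beta1_pow), and use beta1/beta2 (not their accumulated powers) in the moment updates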

上级 582f593a
......@@ -115,10 +115,10 @@ class LargeScaleFuseAdamOpKernel<platform::CPUDeviceContext, T>
"param_row should have the same size with grad_row"));
T lr_ = lr[0];
- T beta1_ = beta1_pow->data<T>()[0];
- T beta2_ = beta2_pow->data<T>()[0];
+ T beta1_pow_ = beta1_pow->data<T>()[0];
+ T beta2_pow_ = beta2_pow->data<T>()[0];
- lr_ *= sqrt(1 - beta1_) / (1 - beta2_);
+ lr_ *= sqrt(1 - beta2_pow_) / (1 - beta1_pow_);
for (size_t i = 0; i < in_rows.size(); i++) {
auto &params = values[i][0];
......@@ -131,8 +131,8 @@ class LargeScaleFuseAdamOpKernel<platform::CPUDeviceContext, T>
for (int x = 0; x < grad_width; ++x) {
auto g = grad_v.data<T>()[grad_width * i + x];
- m1_data[x] = beta1_ * m1_data[x] + (1 - beta1_) * g;
- m2_data[x] = beta2_ * m2_data[x] + (1 - beta2_) * g * g;
+ m1_data[x] = beta1 * m1_data[x] + (1 - beta1) * g;
+ m2_data[x] = beta2 * m2_data[x] + (1 - beta2) * g * g;
p_data[x] -= lr_ * (m1_data[x] / (sqrt(m2_data[x]) + epsilon));
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册