Unverified commit 9f52a36f, authored by Reza Yazdani, committed by GitHub

tracking optimizer step in cpu-adam when loading checkpoint (#564)

* tracking optimizer step in cpu-adam when loading checkpoint

* add warning/error message for updating optimizer step count

* resolve build issue

* supporting state update from the python side

* track step from python in all cases

* remove comma
Parent: c78c29f9
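For reference, the two bias-correction factors that this commit promotes from locals in Step/Step_4/Step_8 to persistent optimizer state (_bias_correction1, _bias_correction2) are the standard Adam terms. A minimal sketch of the arithmetic in plain Python (illustrative only, not DeepSpeed code):

import math

def adam_bias_correction(lr, beta1, beta2, step, bias_correction=True):
    # Mirrors what update_state() precomputes for the fused kernels (illustrative only).
    if not bias_correction:
        return -lr, 1.0                             # both factors stay 1.0f, so step_size is just -lr
    bc1 = 1.0 - beta1 ** step                       # _bias_correction1
    bc2 = 1.0 / math.sqrt(1.0 - beta2 ** step)      # _bias_correction2
    return -lr / bc1, bc2                           # step_size and the sqrt(variance) scale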
@@ -28,10 +28,7 @@ void Adam_Optimizer::Step(float* _params,
     float betta1_minus1 = 1 - _betta1;
     float betta2_minus1 = 1 - _betta2;
-    float bias_correction1 = 1 - _betta1_t;
-    float bias_correction2 = 1 / sqrt(1 - _betta2_t);
-    float step_size = -1 * _alpha / bias_correction1;
+    float step_size = -1 * _alpha / _bias_correction1;
     float w_decay = -1 * _alpha * _weight_decay;
     size_t rounded_size = 0;
@@ -48,7 +45,7 @@ void Adam_Optimizer::Step(float* _params,
     betta2_minus1_4.data = SIMD_SET(betta2_minus1);
     AVX_Data bias2_sqrt;
-    bias2_sqrt.data = SIMD_SET(bias_correction2);
+    bias2_sqrt.data = SIMD_SET(_bias_correction2);
     AVX_Data eps_4;
     eps_4.data = SIMD_SET(_eps);
@@ -130,7 +127,7 @@ void Adam_Optimizer::Step(float* _params,
             variance = grad * betta2_minus1 + variance;
             grad = sqrt(variance);
-            grad = grad * bias_correction2 + _eps;
+            grad = grad * _bias_correction2 + _eps;
             grad = momentum / grad;
             if (_weight_decay > 0 && _adamw_mode) { param += w_decay * param; }
             param = grad * step_size + param;
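The scalar tail loop patched above performs one Adam/AdamW element update against the precomputed factors. A rough Python equivalent of that per-element arithmetic (illustrative only; the moment updates are the standard Adam ones that precede the lines shown):

import math

def adam_element_update(param, grad, momentum, variance,
                        beta1, beta2, eps, lr, weight_decay,
                        bias_correction1, bias_correction2, adamw_mode=True):
    step_size = -lr / bias_correction1                        # step_size = -1 * _alpha / _bias_correction1
    w_decay = -lr * weight_decay                              # w_decay   = -1 * _alpha * _weight_decay
    momentum = momentum * beta1 + grad * (1 - beta1)          # first-moment update
    variance = variance * beta2 + grad * grad * (1 - beta2)   # second-moment update
    denom = math.sqrt(variance) * bias_correction2 + eps      # grad = grad * _bias_correction2 + _eps
    update = momentum / denom                                 # grad = momentum / grad
    if weight_decay > 0 and adamw_mode:
        param += w_decay * param                              # decoupled (AdamW-style) weight decay
    return update * step_size + param, momentum, variance     # param = grad * step_size + param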
@@ -172,16 +169,13 @@ void Adam_Optimizer::Step_4(float* _params,
     AVX_Data betta2_minus1_4;
     betta2_minus1_4.data = SIMD_SET(betta2_minus1);
-    float bias_correction1 = 1 - _betta1_t;
-    float bias_correction2 = 1 / sqrt(1 - _betta2_t);
-    // AVX_Data bias_correction1_4 = SIMD_SET(bias_correction1);
     AVX_Data bias2_sqrt;
-    bias2_sqrt.data = SIMD_SET(bias_correction2);
+    bias2_sqrt.data = SIMD_SET(_bias_correction2);
     AVX_Data eps_4;
     eps_4.data = SIMD_SET(_eps);
-    float step_size = -1 * _alpha / bias_correction1;
+    float step_size = -1 * _alpha / _bias_correction1;
     AVX_Data step_size_4;
     step_size_4.data = SIMD_SET(step_size);
@@ -386,16 +380,13 @@ void Adam_Optimizer::Step_8(float* _params,
     AVX_Data betta2_minus1_4;
     betta2_minus1_4.data = SIMD_SET(betta2_minus1);
-    float bias_correction1 = 1 - _betta1_t;
-    float bias_correction2 = 1 / sqrt(1 - _betta2_t);
-    // AVX_Data bias_correction1_4 = SIMD_SET(bias_correction1);
     AVX_Data bias2_sqrt;
-    bias2_sqrt.data = SIMD_SET(bias_correction2);
+    bias2_sqrt.data = SIMD_SET(_bias_correction2);
     AVX_Data eps_4;
     eps_4.data = SIMD_SET(_eps);
-    float step_size = -1 * _alpha / bias_correction1;
+    float step_size = -1 * _alpha / _bias_correction1;
     AVX_Data step_size_4;
     step_size_4.data = SIMD_SET(step_size);
@@ -611,6 +602,11 @@ void Adam_Optimizer::Step_8(float* _params,
 int ds_adam_step(int optimizer_id,
                  size_t step,
                  float lr,
+                 float beta1,
+                 float beta2,
+                 float epsilon,
+                 float weight_decay,
+                 bool bias_correction,
                  torch::Tensor& params,
                  torch::Tensor& grads,
                  torch::Tensor& exp_avg,
......@@ -628,8 +624,8 @@ int ds_adam_step(int optimizer_id,
std::shared_ptr<Adam_Optimizer> opt =
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep(step);
opt->update_lr(lr);
opt->IncrementStep(step, beta1, beta2);
opt->update_state(lr, epsilon, weight_decay, bias_correction);
opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0));
return 0;
@@ -638,6 +634,11 @@ int ds_adam_step(int optimizer_id,
 int ds_adam_step_plus_copy(int optimizer_id,
                            size_t step,
                            float lr,
+                           float beta1,
+                           float beta2,
+                           float epsilon,
+                           float weight_decay,
+                           bool bias_correction,
                            torch::Tensor& params,
                            torch::Tensor& grads,
                            torch::Tensor& exp_avg,
@@ -658,8 +659,8 @@ int ds_adam_step_plus_copy(int optimizer_id,
     std::shared_ptr<Adam_Optimizer> opt =
         std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
-    opt->IncrementStep(step);
-    opt->update_lr(lr);
+    opt->IncrementStep(step, beta1, beta2);
+    opt->update_state(lr, epsilon, weight_decay, bias_correction);
     opt->Step_8(
         params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, params_c.size(0), gpu_params_ptr);
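Numerically, forwarding the real step count (and betas) through these bindings matters most right after a checkpoint load. A quick back-of-the-envelope check, assuming betas = (0.9, 0.999) and a run resumed at step 1000 (plain Python, illustrative only):

import math

beta1, beta2, step = 0.9, 0.999, 1000
bc1 = 1 - beta1 ** step                  # ~1.0 (0.9**1000 is on the order of 1e-46)
bc2 = 1 / math.sqrt(1 - beta2 ** step)   # ~1.26

# If the factors were instead computed as though this were step 1, they would be
# 1 - 0.9 = 0.1 and 1 / sqrt(0.001) ~ 31.6: roughly a 10x larger effective step size
# and a ~25x larger normalized gradient for the first resumed updates.
print(bc1, bc2)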
@@ -89,18 +89,40 @@ public:
                 float* _exp_avg_sq,
                 size_t _param_size,
                 __half* dev_params = nullptr);
-    inline void IncrementStep(size_t step)
+    inline void IncrementStep(size_t step, float beta1, float beta2)
     {
-        if (_step < step) {
+        if (beta1 != _betta1 || beta2 != _betta2) {
             _step = step;
+            _betta1 = beta1;
+            _betta2 = beta2;
             _betta1_t = std::pow(_betta1, step);
             _betta2_t = std::pow(_betta2, step);
         } else {
             _step++;
             if (_step != step) {
-                throw std::runtime_error("Optimizer lost track of step count!\n");
+                _betta1_t = std::pow(_betta1, step);
+                _betta2_t = std::pow(_betta2, step);
+                _step = step;
+            } else {
+                _betta1_t *= _betta1;
+                _betta2_t *= _betta2;
             }
-            _betta1_t *= _betta1;
-            _betta2_t *= _betta2;
         }
     }
-    inline void update_lr(float lr) { _alpha = lr; }
+    inline void update_state(float lr, float epsilon, float weight_decay, bool bias_correction)
+    {
+        _alpha = lr;
+        _eps = epsilon;
+        _weight_decay = weight_decay;
+        _bias_correction1 = 1.0f;
+        _bias_correction2 = 1.0f;
+        if (bias_correction == 1) {
+            _bias_correction1 = 1 - _betta1_t;
+            _bias_correction2 = 1 / sqrt(1 - _betta2_t);
+        }
+    }
 private:
 #if defined(__AVX512__) or defined(__AVX256__)
@@ -124,6 +146,9 @@ private:
     float _betta2_t;
     size_t _step;
+    float _bias_correction1;
+    float _bias_correction2;
     float* _doubled_buffer[2];
     bool _buf_index;
     bool _adamw_mode;
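A compact Python mirror of the new IncrementStep/update_state bookkeeping above may help; it is illustrative only and not part of the DeepSpeed API:

import math

class AdamStepTracker:
    # Mirrors the C++ step/beta bookkeeping shown in the header diff (illustrative).
    def __init__(self, beta1=0.9, beta2=0.999):
        self.step, self.beta1, self.beta2 = 0, beta1, beta2
        self.beta1_t, self.beta2_t = 1.0, 1.0

    def increment_step(self, step, beta1, beta2):
        if (beta1, beta2) != (self.beta1, self.beta2):
            # Betas changed (e.g. a different param group): rebuild the powers outright.
            self.beta1, self.beta2 = beta1, beta2
            self.beta1_t, self.beta2_t = beta1 ** step, beta2 ** step
            self.step = step
        else:
            self.step += 1
            if self.step != step:
                # Step count jumped, e.g. a checkpoint restored state['step']:
                # resynchronize instead of raising "Optimizer lost track of step count!".
                self.beta1_t, self.beta2_t = beta1 ** step, beta2 ** step
                self.step = step
            else:
                self.beta1_t *= beta1
                self.beta2_t *= beta2

    def update_state(self, lr, eps, weight_decay, bias_correction=True):
        self.alpha, self.eps, self.weight_decay = lr, eps, weight_decay
        self.bias_correction1 = self.bias_correction2 = 1.0
        if bias_correction:
            self.bias_correction1 = 1 - self.beta1_t
            self.bias_correction2 = 1 / math.sqrt(1 - self.beta2_t)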
@@ -50,6 +50,7 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
     def __init__(self,
                  model_params,
                  lr=1e-3,
+                 bias_correction=True,
                  betas=(0.9,
                         0.999),
                  eps=1e-8,
@@ -61,6 +62,7 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
                             betas=betas,
                             eps=eps,
                             weight_decay=weight_decay,
+                            bias_correction=bias_correction,
                             amsgrad=amsgrad)
         super(DeepSpeedCPUAdam, self).__init__(model_params, default_args)
@@ -112,12 +114,18 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
                 #memory_format=torch.preserve_format)
                 state['step'] += 1
+                beta1, beta2 = group['betas']
                 if fp16_param_groups is not None:
                     self.ds_opt_adam.adam_update_copy(
                         self.opt_id,
                         state['step'],
                         group['lr'],
+                        beta1,
+                        beta2,
+                        group['eps'],
+                        group['weight_decay'],
+                        group['bias_correction'],
                         p.data,
                         p.grad.data,
                         state['exp_avg'],
@@ -127,6 +135,11 @@ class DeepSpeedCPUAdam(torch.optim.Optimizer):
                     self.ds_opt_adam.adam_update(self.opt_id,
                                                  state['step'],
                                                  group['lr'],
+                                                 beta1,
+                                                 beta2,
+                                                 group['eps'],
+                                                 group['weight_decay'],
+                                                 group['bias_correction'],
                                                  p.data,
                                                  p.grad.data,
                                                  state['exp_avg'],
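From the user's side, the change means a resumed run can simply restore the optimizer state dict and keep stepping; the restored state['step'] and the group's betas/eps/weight_decay/bias_correction are forwarded to the C++ optimizer on every step() call. A minimal resume sketch, assuming the cpu_adam extension is built and DeepSpeedCPUAdam is importable from deepspeed.ops.adam:

import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

model = torch.nn.Linear(16, 16)                       # CPU, fp32 params
optimizer = DeepSpeedCPUAdam(model.parameters(), lr=1e-3, bias_correction=True)

# ... train for a while, then checkpoint ...
ckpt = {'model': model.state_dict(), 'optimizer': optimizer.state_dict()}

# On resume, the restored step count is picked up by the C++ side on the next step(),
# so the beta powers are recomputed instead of drifting or raising an error.
optimizer.load_state_dict(ckpt['optimizer'])
loss = model(torch.randn(4, 16)).sum()
loss.backward()
optimizer.step()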