提交 8e2bb7a8 编写于 作者: C caojian05

fix accurancy lower then 92

上级 6a7b9743
......@@ -19,7 +19,9 @@ from easydict import EasyDict as edict
cifar_cfg = edict({
'num_classes': 10,
'lr_init': 0.05,
'lr_init': 0.01,
'lr_max': 0.1,
'warmup_epochs': 5,
'batch_size': 64,
'epoch_size': 70,
'momentum': 0.9,
......
......@@ -38,20 +38,25 @@ random.seed(1)
np.random.seed(1)
def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):
def lr_steps(global_step, lr_init, lr_max, warmup_epochs, total_epochs, steps_per_epoch):
"""Set learning rate."""
lr_each_step = []
total_steps = steps_per_epoch * total_epochs
decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
warmup_steps = steps_per_epoch * warmup_epochs
if warmup_steps != 0:
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
else:
inc_each_step = 0
for i in range(total_steps):
if i < decay_epoch_index[0]:
lr_each_step.append(lr_max)
elif i < decay_epoch_index[1]:
lr_each_step.append(lr_max * 0.1)
elif i < decay_epoch_index[2]:
lr_each_step.append(lr_max * 0.01)
if i < warmup_steps:
lr_value = float(lr_init) + inc_each_step * float(i)
else:
lr_each_step.append(lr_max * 0.001)
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
lr_value = float(lr_max) * base * base
if lr_value < 0.0:
lr_value = 0.0
lr_each_step.append(lr_value)
current_step = global_step
lr_each_step = np.array(lr_each_step).astype(np.float32)
learning_rate = lr_each_step[current_step:]
......@@ -86,7 +91,8 @@ if __name__ == '__main__':
if args_opt.pre_trained:
load_param_into_net(net, load_checkpoint(args_opt.pre_trained))
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num)
lr = lr_steps(0, lr_init=cfg.lr_init, lr_max=cfg.lr_max, warmup_epochs=cfg.warmup_epochs,
total_epochs=cfg.epoch_size, steps_per_epoch=batch_num)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum,
weight_decay=cfg.weight_decay)
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册