Commit 8b8a02d6 authored by F flytocc

add update_freq option for gradient accumulation

Parent ed820223
@@ -15,7 +15,7 @@ Global:
   save_inference_dir: ./inference
   # training model under @to_static
   to_static: False
+  update_freq: 8
 
 # model ema
 EMA:
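
Note: with `update_freq: 8` the optimizer is stepped once every 8 iterations, so the effective global batch size is 8x the per-iteration batch. A rough check, assuming illustrative values for the per-device batch size and device count (neither is part of this commit):

```python
# Back-of-the-envelope check of what update_freq: 8 implies; the per-device
# batch size and device count below are assumptions, not from this commit.
per_device_batch_size = 64   # assumed
num_devices = 8              # assumed
update_freq = 8              # from the config above

effective_batch_size = per_device_batch_size * num_devices * update_freq
print(effective_batch_size)  # 4096, matching "total_batch_size 4096" below
```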
@@ -51,7 +51,7 @@ Optimizer:
   one_dim_param_no_weight_decay: True
   lr:
     name: Cosine
-    learning_rate: 5e-4
+    learning_rate: 4e-3 # lr 4e-3 for total_batch_size 4096
     eta_min: 1e-6
     warmup_epoch: 20
     warmup_start_lr: 0
......
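
The new learning rate follows the linear scaling rule: the comment ties 4e-3 to a total batch size of 4096, which is exactly 8x the old 5e-4. A sketch of the arithmetic, assuming a reference batch size of 512 (inferred from the ratio, not stated in the diff):

```python
base_lr = 5e-4            # the old value removed above
base_batch_size = 512     # assumed reference batch size, inferred from the ratio
total_batch_size = 4096   # from the config comment

scaled_lr = base_lr * total_batch_size / base_batch_size
print(scaled_lr)  # 0.004, i.e. the new 4e-3
```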
@@ -119,6 +119,9 @@ class Engine(object):
         # EMA model
         self.ema = "EMA" in self.config and self.mode == "train"
 
+        # gradient accumulation
+        self.update_freq = self.config["Global"].get("update_freq", 1)
+
         if "class_num" in config["Global"]:
             global_class_num = config["Global"]["class_num"]
             if "class_num" not in config["Arch"]:
@@ -229,7 +232,7 @@ class Engine(object):
         if self.mode == 'train':
             self.optimizer, self.lr_sch = build_optimizer(
                 self.config["Optimizer"], self.config["Global"]["epochs"],
-                len(self.train_dataloader),
+                len(self.train_dataloader) // self.update_freq,
                 [self.model, self.train_loss_func])
 
         # for amp training
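
Dividing the dataloader length by `update_freq` sizes the LR schedule in optimizer steps rather than dataloader iterations; otherwise a per-step scheduler would decay `update_freq` times too fast. With illustrative numbers:

```python
iters_per_epoch = 1252   # assumed dataloader length, for illustration only
update_freq = 8
steps_per_epoch = iters_per_epoch // update_freq
print(steps_per_epoch)   # 156 optimizer steps per epoch drive the schedule
```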
@@ -312,6 +315,7 @@ class Engine(object):
         self.max_iter = len(self.train_dataloader) - 1 if platform.system(
         ) == "Windows" else len(self.train_dataloader)
+        self.max_iter = self.max_iter // self.update_freq * self.update_freq
         for epoch_id in range(best_metric["epoch"] + 1,
                               self.config["Global"]["epochs"] + 1):
             acc = 0.0
......
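
Truncating `max_iter` to a multiple of `update_freq` drops trailing iterations that cannot fill a full accumulation window, so no half-accumulated gradients are left pending at the end of an epoch. For example, with assumed numbers:

```python
max_iter = 1254          # assumed iteration count for one epoch
update_freq = 8
max_iter = max_iter // update_freq * update_freq
print(max_iter)          # 1248; the trailing 6 iterations are skipped
```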
@@ -53,25 +53,32 @@ def train_epoch(engine, epoch_id, print_batch_step):
         out = forward(engine, batch)
         loss_dict = engine.train_loss_func(out, batch[1])
 
+        # loss
+        loss = loss_dict["loss"] / engine.update_freq
+
         # step opt
         if engine.amp:
-            scaled = engine.scaler.scale(loss_dict["loss"])
+            scaled = engine.scaler.scale(loss)
             scaled.backward()
-            for i in range(len(engine.optimizer)):
-                engine.scaler.minimize(engine.optimizer[i], scaled)
+            if (iter_id + 1) % engine.update_freq == 0:
+                for i in range(len(engine.optimizer)):
+                    engine.scaler.minimize(engine.optimizer[i], scaled)
         else:
-            loss_dict["loss"].backward()
+            loss.backward()
+            if (iter_id + 1) % engine.update_freq == 0:
+                for i in range(len(engine.optimizer)):
+                    engine.optimizer[i].step()
+
+        if (iter_id + 1) % engine.update_freq == 0:
+            # clear grad
             for i in range(len(engine.optimizer)):
-                engine.optimizer[i].step()
-        # clear grad
-        for i in range(len(engine.optimizer)):
-            engine.optimizer[i].clear_grad()
-        # step lr
-        for i in range(len(engine.lr_sch)):
-            engine.lr_sch[i].step()
-        # update ema
-        if engine.ema:
-            engine.model_ema.update(engine.model)
+                engine.optimizer[i].clear_grad()
+            # step lr
+            for i in range(len(engine.lr_sch)):
+                engine.lr_sch[i].step()
+            # update ema
+            if engine.ema:
+                engine.model_ema.update(engine.model)
 
         # below code just for logging
         # update metric_for_logger
......
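
Putting the pieces together, the pattern this commit implements is standard gradient accumulation: scale each loss by 1/update_freq, call backward every iteration (Paddle's dynamic graph accumulates gradients until clear_grad), and step/clear only at window boundaries. A minimal runnable sketch with a toy model; everything below is illustrative, and a single optimizer stands in for the engine's optimizer list:

```python
import paddle

# Toy setup (hypothetical model and random data) to exercise the pattern.
model = paddle.nn.Linear(4, 2)
optimizer = paddle.optimizer.SGD(
    learning_rate=0.1, parameters=model.parameters())
update_freq = 8

for iter_id in range(16):
    x = paddle.randn([32, 4])
    y = paddle.randn([32, 2])
    loss = paddle.nn.functional.mse_loss(model(x), y) / update_freq
    loss.backward()                       # grads accumulate across iterations
    if (iter_id + 1) % update_freq == 0:  # one full accumulation window
        optimizer.step()                  # apply the accumulated gradients
        optimizer.clear_grad()            # reset for the next window
```

Dividing the loss by `update_freq` makes the accumulated gradient the mean over the window, so one accumulated step approximates a single large-batch step at the same learning rate.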