Commit 8b8a02d6 authored by F flytocc

add update_freq option for gradient accumulation

Parent ed820223
@@ -15,7 +15,7 @@ Global:
   save_inference_dir: ./inference
   # training model under @to_static
   to_static: False
+  update_freq: 8
 
 # model ema
 EMA:
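With `update_freq: 8`, gradients from 8 consecutive batches are accumulated before each optimizer update, so the effective (total) batch size is 8x what fits in memory per step. A quick sanity check of the arithmetic, with hypothetical per-device numbers (only `update_freq = 8` comes from this diff):

```python
# Hypothetical sizing sketch; per_device_batch and num_devices are assumed,
# only update_freq = 8 comes from the config change above.
per_device_batch = 64
num_devices = 8
update_freq = 8

effective_batch = per_device_batch * num_devices * update_freq
print(effective_batch)  # 4096 -- consistent with the "total_batch_size 4096"
                        # comment in the learning-rate change below
```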
@@ -51,7 +51,7 @@ Optimizer:
   one_dim_param_no_weight_decay: True
   lr:
     name: Cosine
-    learning_rate: 5e-4
+    learning_rate: 4e-3 # lr 4e-3 for total_batch_size 4096
     eta_min: 1e-6
     warmup_epoch: 20
     warmup_start_lr: 0
...
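The learning-rate bump from 5e-4 to 4e-3 is exactly a factor of 8, matching `update_freq: 8`; this looks like the common linear scaling rule (scale the LR proportionally to the effective batch size). A minimal sketch of that rule, assuming it is what the author applied:

```python
# Linear LR scaling sketch (assumption: lr scales linearly with the
# accumulation factor; the diff itself only shows the before/after values).
base_lr = 5e-4     # previous learning_rate in this config
update_freq = 8    # gradient accumulation steps
scaled_lr = base_lr * update_freq
assert abs(scaled_lr - 4e-3) < 1e-12  # matches the new learning_rate
```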
@@ -119,6 +119,9 @@ class Engine(object):
         # EMA model
         self.ema = "EMA" in self.config and self.mode == "train"
 
+        # gradient accumulation
+        self.update_freq = self.config["Global"].get("update_freq", 1)
+
         if "class_num" in config["Global"]:
             global_class_num = config["Global"]["class_num"]
             if "class_num" not in config["Arch"]:
@@ -229,7 +232,7 @@ class Engine(object):
         if self.mode == 'train':
             self.optimizer, self.lr_sch = build_optimizer(
                 self.config["Optimizer"], self.config["Global"]["epochs"],
-                len(self.train_dataloader),
+                len(self.train_dataloader) // self.update_freq,
                 [self.model, self.train_loss_func])
 
         # for amp training
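Dividing the per-epoch step count by `update_freq` keeps the LR schedule aligned with optimizer updates: the scheduler now advances once per accumulated update rather than once per batch, so the cosine decay still spans the intended number of epochs. Illustrative arithmetic (the dataloader length is assumed):

```python
# Scheduler-step sketch; iters_per_epoch is a made-up example value.
iters_per_epoch = 1250   # hypothetical len(train_dataloader)
update_freq = 8
steps_per_epoch = iters_per_epoch // update_freq
print(steps_per_epoch)   # 156 scheduler steps per epoch instead of 1250
```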
@@ -312,6 +315,7 @@ class Engine(object):
         self.max_iter = len(self.train_dataloader) - 1 if platform.system(
         ) == "Windows" else len(self.train_dataloader)
+        self.max_iter = self.max_iter // self.update_freq * self.update_freq
         for epoch_id in range(best_metric["epoch"] + 1,
                               self.config["Global"]["epochs"] + 1):
             acc = 0.0
...
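Rounding `max_iter` down to a multiple of `update_freq` drops the trailing batches that could not complete a full accumulation cycle; otherwise their gradients would be accumulated at the end of the epoch but never applied. With the same assumed numbers as above:

```python
# Truncation sketch with the hypothetical dataloader length from above.
max_iter = 1250
update_freq = 8
max_iter = max_iter // update_freq * update_freq
print(max_iter)  # 1248: the last 2 batches of the epoch are skipped
```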
@@ -53,25 +53,32 @@ def train_epoch(engine, epoch_id, print_batch_step):
             out = forward(engine, batch)
             loss_dict = engine.train_loss_func(out, batch[1])
 
+        # loss
+        loss = loss_dict["loss"] / engine.update_freq
+
         # step opt
         if engine.amp:
-            scaled = engine.scaler.scale(loss_dict["loss"])
+            scaled = engine.scaler.scale(loss)
             scaled.backward()
-            for i in range(len(engine.optimizer)):
-                engine.scaler.minimize(engine.optimizer[i], scaled)
+            if (iter_id + 1) % engine.update_freq == 0:
+                for i in range(len(engine.optimizer)):
+                    engine.scaler.minimize(engine.optimizer[i], scaled)
         else:
-            loss_dict["loss"].backward()
-            for i in range(len(engine.optimizer)):
-                engine.optimizer[i].step()
-
-        # clear grad
-        for i in range(len(engine.optimizer)):
-            engine.optimizer[i].clear_grad()
-
-        # step lr
-        for i in range(len(engine.lr_sch)):
-            engine.lr_sch[i].step()
-
-        # update ema
-        if engine.ema:
-            engine.model_ema.update(engine.model)
+            loss.backward()
+            if (iter_id + 1) % engine.update_freq == 0:
+                for i in range(len(engine.optimizer)):
+                    engine.optimizer[i].step()
+
+        if (iter_id + 1) % engine.update_freq == 0:
+            # clear grad
+            for i in range(len(engine.optimizer)):
+                engine.optimizer[i].clear_grad()
+
+            # step lr
+            for i in range(len(engine.lr_sch)):
+                engine.lr_sch[i].step()
+
+            # update ema
+            if engine.ema:
+                engine.model_ema.update(engine.model)
 
         # below code just for logging
         # update metric_for_logger
...
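Stripped of AMP, EMA, and the per-optimizer loops, the pattern the new `train_epoch` implements reduces to the sketch below (assuming Paddle 2.x; the model, data, and hyperparameters are made up for illustration). Dividing the loss by `update_freq` turns the sum of accumulated gradients into an average, so one accumulated step approximates one step on the larger effective batch:

```python
import paddle

# Minimal gradient-accumulation loop (illustrative, not PaddleClas code).
model = paddle.nn.Linear(16, 2)
opt = paddle.optimizer.SGD(learning_rate=4e-3, parameters=model.parameters())
loss_fn = paddle.nn.MSELoss()
update_freq = 8

for iter_id in range(64):
    x = paddle.randn([32, 16])
    y = paddle.randn([32, 2])
    # scale the loss so summed gradients average over update_freq batches
    loss = loss_fn(model(x), y) / update_freq
    loss.backward()  # gradients accumulate across backward calls
    if (iter_id + 1) % update_freq == 0:
        opt.step()        # one parameter update per update_freq batches
        opt.clear_grad()  # reset accumulated gradients for the next cycle
```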