提交 b0ff535b 编写于 作者: S ScXfjiang

dump momentum_buffer for each iteration

上级 2fe865bf
......@@ -18,6 +18,8 @@ import os
from functools import partial
import pickle as pkl
def reduce_loss_dict(loss_dict):
"""
Reduce the loss dictionary from all processes so that process with rank
......@@ -44,6 +46,7 @@ def reduce_loss_dict(loss_dict):
def do_train(
cfg,
model,
data_loader,
optimizer,
......@@ -103,6 +106,17 @@ def do_train(
losses.backward()
optimizer.step()
if not os.path.exists("model_name2momentum_buffer/"):
os.makedirs("model_name2momentum_buffer/")
state_dict = optimizer.state_dict()
model_name2momentum_buffer = {}
for key, value in model.named_parameters():
if value.requires_grad:
momentum_buffer = state_dict['state'][id(value)]['momentum_buffer'].cpu().detach().numpy()
model_name2momentum_buffer[key] = momentum_buffer
pkl.dump(model_name2momentum_buffer, open("model_name2momentum_buffer/" + os.path.basename(cfg.MODEL.WEIGHT) \
+ "-iteration-" + str(iteration) +'-model_name2momentum_buffer.pkl', 'w'))
batch_time = time.time() - end
end = time.time()
meters.update(time=batch_time, data=data_time)
......@@ -110,7 +124,7 @@ def do_train(
eta_seconds = meters.time.global_avg * (max_iter - iteration)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
if iteration % 20 == 0 or iteration == max_iter:
if iteration % 1 == 0 or iteration == max_iter:
logger.info(
meters.delimiter.join(
[
......
......@@ -25,8 +25,6 @@ from maskrcnn_benchmark.utils.imports import import_file
from maskrcnn_benchmark.utils.logger import setup_logger
from maskrcnn_benchmark.utils.miscellaneous import mkdir
import pickle as pkl
def train(cfg, local_rank, distributed):
model = build_detection_model(cfg)
device = torch.device(cfg.MODEL.DEVICE)
......@@ -53,14 +51,6 @@ def train(cfg, local_rank, distributed):
)
extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
arguments.update(extra_checkpoint_data)
state_dict = optimizer.state_dict()
model_name2momentum_buffer = {}
for key, value in model.named_parameters():
if value.requires_grad:
momentum_buffer = state_dict['state'][id(value)]['momentum_buffer'].cpu().detach().numpy()
model_name2momentum_buffer[key] = momentum_buffer
pkl.dump(model_name2momentum_buffer, open(os.path.basename(cfg.MODEL.WEIGHT) + '.model_name2momentum_buffer.pkl', 'w'))
data_loader = make_data_loader(
cfg,
......@@ -73,6 +63,7 @@ def train(cfg, local_rank, distributed):
arguments["fake_image"] = cfg.DATALOADER.FAKE_IMAGE_DATA_PATH
do_train(
cfg,
model,
data_loader,
optimizer,
......
......@@ -6,7 +6,7 @@ rm -f last_checkpoint
rm -f model_final.pth
rm -f log.txt
rm -f model_0090000.pth
rm -f e2e_mask_rcnn_R_50_FPN_1x.pth.model_name2momentum_buffer.pkl
rm -rf model_name2momentum_buffer
CUDA_VISIBLE_DEVICES=1 \
python ./tools/train_net.py \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册