Commit 0e28a39d authored by gaotingquan, committed by Wei Shengyu

refactor

Parent fad5c8e3
@@ -447,6 +447,8 @@ class Engine(object):
             level=self.amp_level,
             save_dtype='float32')
 
+        self.amp_level = self.config["AMP"].get("level", "O1").upper()
+
     def _init_dist(self):
         # check the gpu num
         world_size = dist.get_world_size()
......
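The first hunk caches the normalized AMP level on the engine at init time, so the train loop below no longer re-reads the config on every iteration. A minimal sketch of the lookup-and-normalize pattern, using a plain dict in place of the engine config (the values here are made up):

# Plain dict standing in for engine.config; keys mirror the AMP
# section of a PaddleClas YAML config.
config = {"AMP": {"level": "o2"}}

# .get("level", "O1") falls back to "O1" when the key is absent;
# .upper() normalizes lower-case values such as "o2".
amp_level = config["AMP"].get("level", "O1").upper()
assert amp_level in ("O1", "O2")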
@@ -36,31 +36,25 @@ def regular_train_epoch(engine, epoch_id, print_batch_step):
         batch[1] = batch[1].reshape([batch_size, -1])
         engine.global_step += 1
 
-        # image input
+        # forward & backward & step opt
         if engine.amp:
-            amp_level = engine.config["AMP"].get("level", "O1").upper()
             with paddle.amp.auto_cast(
                     custom_black_list={
                         "flatten_contiguous_range", "greater_than"
                     },
-                    level=amp_level):
+                    level=engine.amp_level):
                 out = engine.model(batch)
                 loss_dict = engine.train_loss_func(out, batch[1])
-        else:
-            out = engine.model(batch)
-            loss_dict = engine.train_loss_func(out, batch[1])
-
-        # loss
-        loss = loss_dict["loss"] / engine.update_freq
-
-        # backward & step opt
-        if engine.amp:
+            loss = loss_dict["loss"] / engine.update_freq
             scaled = engine.scaler.scale(loss)
             scaled.backward()
             if (iter_id + 1) % engine.update_freq == 0:
                 for i in range(len(engine.optimizer)):
                     engine.scaler.minimize(engine.optimizer[i], scaled)
         else:
+            out = engine.model(batch)
+            loss_dict = engine.train_loss_func(out, batch[1])
+            loss = loss_dict["loss"] / engine.update_freq
+            loss.backward()
             if (iter_id + 1) % engine.update_freq == 0:
                 for i in range(len(engine.optimizer)):
......
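The second hunk merges the previously separate forward, loss, and backward sections into a single if/else, so each branch reads straight through. A self-contained sketch of the AMP branch with gradient accumulation over update_freq iterations (not the PaddleClas code: single optimizer, made-up model and data):

import paddle

model = paddle.nn.Linear(8, 2)
optimizer = paddle.optimizer.SGD(
    learning_rate=0.1, parameters=model.parameters())
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
update_freq = 2  # accumulate gradients over this many iterations

for iter_id in range(4):
    x = paddle.randn([4, 8])
    with paddle.amp.auto_cast(level="O1"):
        loss = model(x).mean()
    # divide so the accumulated gradient matches one large batch
    loss = loss / update_freq
    scaled = scaler.scale(loss)
    scaled.backward()
    if (iter_id + 1) % update_freq == 0:
        # unscales gradients, steps the optimizer, updates the loss scale
        scaler.minimize(optimizer, scaled)
        optimizer.clear_grad()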
import errno
import os

import paddle

from . import logger
def _mkdir_if_not_exist(path):
    """
    Create the directory if it does not exist, ignoring the race when
    multiple processes call mkdir concurrently.
    """
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno == errno.EEXIST and os.path.isdir(path):
                logger.warning(
                    '{} already exists (created by another process)'.format(
                        path))
            else:
                raise OSError('Failed to mkdir {}'.format(path)) from e
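_mkdir_if_not_exist tolerates the race where several data-parallel workers create the output directory at once. On Python 3 the standard library covers the same race directly; a one-line equivalent (path is illustrative):

import os
os.makedirs("output/ResNet50", exist_ok=True)  # no error if it already exists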
def _extract_student_weights(all_params, student_prefix="Student."):
    # strip the student prefix so the weights load into a standalone model;
    # keys are expected to start with the prefix
    s_params = {
        key[len(student_prefix):]: all_params[key]
        for key in all_params if student_prefix in key
    }
    return s_params
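_extract_student_weights strips the distillation prefix so the student submodel can be saved and reloaded on its own. A toy illustration (made-up keys and values):

all_params = {
    "Student.fc.weight": "w_s",
    "Student.fc.bias": "b_s",
    "Teacher.fc.weight": "w_t",
}
print(_extract_student_weights(all_params))
# {'fc.weight': 'w_s', 'fc.bias': 'b_s'}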
class ModelSaver(object):
    def __init__(self,
                 engine,
                 net_name="model",
                 loss_name="train_loss_func",
                 opt_name="optimizer",
                 model_ema_name="model_ema"):
        # attribute names used to look up the net, loss, optimizer and
        # EMA model on the engine
        self.engine = engine
        self.net_name = net_name
        self.loss_name = loss_name
        self.opt_name = opt_name
        self.model_ema_name = model_ema_name

        arch_name = engine.config["Arch"]["name"]
        self.output_dir = os.path.join(engine.output_dir, arch_name)
        _mkdir_if_not_exist(self.output_dir)

    def save(self, metric_info, prefix='ppcls', save_student_model=False):
        # only the rank-0 process writes checkpoints
        if paddle.distributed.get_rank() != 0:
            return

        save_dir = os.path.join(self.output_dir, prefix)

        params_state_dict = getattr(self.engine, self.net_name).state_dict()
        loss = getattr(self.engine, self.loss_name)
        if loss is not None:
            loss_state_dict = loss.state_dict()
            keys_inter = set(params_state_dict.keys()) & set(
                loss_state_dict.keys())
            assert len(keys_inter) == 0, \
                f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}"
            params_state_dict.update(loss_state_dict)

        if save_student_model:
            s_params = _extract_student_weights(params_state_dict)
            if len(s_params) > 0:
                paddle.save(s_params, save_dir + "_student.pdparams")

        paddle.save(params_state_dict, save_dir + ".pdparams")

        model_ema = getattr(self.engine, self.model_ema_name)
        if model_ema is not None:
            paddle.save(model_ema.module.state_dict(),
                        save_dir + ".ema.pdparams")

        optimizer = getattr(self.engine, self.opt_name)
        paddle.save([opt.state_dict() for opt in optimizer],
                    save_dir + ".pdopt")
        paddle.save(metric_info, save_dir + ".pdstates")
        logger.info("Model saved in {}".format(save_dir))