提交 73e2cde6 编写于 作者: G gaotingquan 提交者: Wei Shengyu

mv some attrs to __init__()

上级 0d7e595f
......@@ -50,6 +50,9 @@ class Engine(object):
assert mode in ["train", "eval", "infer", "export"]
self.mode = mode
self.config = config
self.start_eval_epoch = self.config["Global"].get("start_eval_epoch",
0) - 1
self.epochs = self.config["Global"].get("epochs", 1)
# set seed
self._init_seed()
......@@ -113,8 +116,6 @@ class Engine(object):
assert self.mode == "train"
print_batch_step = self.config['Global']['print_batch_step']
save_interval = self.config["Global"]["save_interval"]
start_eval_epoch = self.config["Global"].get("start_eval_epoch", 0) - 1
epochs = self.config["Global"]["epochs"]
best_metric = {
"metric": -1.0,
......@@ -140,20 +141,20 @@ class Engine(object):
# global iter counter
self.global_step = 0
for epoch_id in range(best_metric["epoch"] + 1, epochs + 1):
for epoch_id in range(best_metric["epoch"] + 1, self.epochs + 1):
# for one epoch train
self.train_epoch_func(self, epoch_id, print_batch_step)
metric_msg = ", ".join(
[self.output_info[key].avg_info for key in self.output_info])
logger.info("[Train][Epoch {}/{}][Avg]{}".format(epoch_id, epochs,
metric_msg))
logger.info("[Train][Epoch {}/{}][Avg]{}".format(
epoch_id, self.epochs, metric_msg))
self.output_info.clear()
acc = 0.0
if self.config["Global"][
"eval_during_train"] and epoch_id % self.config["Global"][
"eval_interval"] == 0 and epoch_id > start_eval_epoch:
"eval_interval"] == 0 and epoch_id > self.start_eval_epoch:
acc = self.eval(epoch_id)
# step lr (by epoch) according to given metric, such as acc
......@@ -469,7 +470,6 @@ class Engine(object):
self.model_ema)
if metric_info is not None:
best_metric.update(metric_info)
return best_metric
class ExportModel(TheseusLayer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册