提交 73e2cde6 编写于 作者: G gaotingquan 提交者: Wei Shengyu

mv some attrs to __init__()

上级 0d7e595f
...@@ -50,6 +50,9 @@ class Engine(object): ...@@ -50,6 +50,9 @@ class Engine(object):
assert mode in ["train", "eval", "infer", "export"] assert mode in ["train", "eval", "infer", "export"]
self.mode = mode self.mode = mode
self.config = config self.config = config
self.start_eval_epoch = self.config["Global"].get("start_eval_epoch",
0) - 1
self.epochs = self.config["Global"].get("epochs", 1)
# set seed # set seed
self._init_seed() self._init_seed()
...@@ -113,8 +116,6 @@ class Engine(object): ...@@ -113,8 +116,6 @@ class Engine(object):
assert self.mode == "train" assert self.mode == "train"
print_batch_step = self.config['Global']['print_batch_step'] print_batch_step = self.config['Global']['print_batch_step']
save_interval = self.config["Global"]["save_interval"] save_interval = self.config["Global"]["save_interval"]
start_eval_epoch = self.config["Global"].get("start_eval_epoch", 0) - 1
epochs = self.config["Global"]["epochs"]
best_metric = { best_metric = {
"metric": -1.0, "metric": -1.0,
...@@ -140,20 +141,20 @@ class Engine(object): ...@@ -140,20 +141,20 @@ class Engine(object):
# global iter counter # global iter counter
self.global_step = 0 self.global_step = 0
for epoch_id in range(best_metric["epoch"] + 1, epochs + 1): for epoch_id in range(best_metric["epoch"] + 1, self.epochs + 1):
# for one epoch train # for one epoch train
self.train_epoch_func(self, epoch_id, print_batch_step) self.train_epoch_func(self, epoch_id, print_batch_step)
metric_msg = ", ".join( metric_msg = ", ".join(
[self.output_info[key].avg_info for key in self.output_info]) [self.output_info[key].avg_info for key in self.output_info])
logger.info("[Train][Epoch {}/{}][Avg]{}".format(epoch_id, epochs, logger.info("[Train][Epoch {}/{}][Avg]{}".format(
metric_msg)) epoch_id, self.epochs, metric_msg))
self.output_info.clear() self.output_info.clear()
acc = 0.0 acc = 0.0
if self.config["Global"][ if self.config["Global"][
"eval_during_train"] and epoch_id % self.config["Global"][ "eval_during_train"] and epoch_id % self.config["Global"][
"eval_interval"] == 0 and epoch_id > start_eval_epoch: "eval_interval"] == 0 and epoch_id > self.start_eval_epoch:
acc = self.eval(epoch_id) acc = self.eval(epoch_id)
# step lr (by epoch) according to given metric, such as acc # step lr (by epoch) according to given metric, such as acc
...@@ -469,7 +470,6 @@ class Engine(object): ...@@ -469,7 +470,6 @@ class Engine(object):
self.model_ema) self.model_ema)
if metric_info is not None: if metric_info is not None:
best_metric.update(metric_info) best_metric.update(metric_info)
return best_metric
class ExportModel(TheseusLayer): class ExportModel(TheseusLayer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册