未验证 提交 6a8a8f7a 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

fix log (#6086)

上级 05afe0ef
...@@ -17,9 +17,9 @@ import sys ...@@ -17,9 +17,9 @@ import sys
__dir__ = os.path.dirname(os.path.abspath(__file__)) __dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__) sys.path.append(__dir__)
sys.path.append(os.path.abspath(os.path.join(__dir__, '..', '..', '..'))) sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..', '..', '..')))
sys.path.append( sys.path.insert(
os.path.abspath(os.path.join(__dir__, '..', '..', '..', 'tools'))) 0, os.path.abspath(os.path.join(__dir__, '..', '..', '..', 'tools')))
import argparse import argparse
...@@ -129,7 +129,6 @@ def main(): ...@@ -129,7 +129,6 @@ def main():
quanter.quantize(model) quanter.quantize(model)
load_model(config, model) load_model(config, model)
model.eval()
# build metric # build metric
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
...@@ -142,6 +141,7 @@ def main(): ...@@ -142,6 +141,7 @@ def main():
# start eval # start eval
metric = program.eval(model, valid_dataloader, post_process_class, metric = program.eval(model, valid_dataloader, post_process_class,
eval_class, model_type, use_srn) eval_class, model_type, use_srn)
model.eval()
logger.info('metric eval ***************') logger.info('metric eval ***************')
for k, v in metric.items(): for k, v in metric.items():
...@@ -156,7 +156,6 @@ def main(): ...@@ -156,7 +156,6 @@ def main():
if arch_config["algorithm"] in ["Distillation", ]: # distillation model if arch_config["algorithm"] in ["Distillation", ]: # distillation model
archs = list(arch_config["Models"].values()) archs = list(arch_config["Models"].values())
for idx, name in enumerate(model.model_name_list): for idx, name in enumerate(model.model_name_list):
model.model_list[idx].eval()
sub_model_save_path = os.path.join(save_path, name, "inference") sub_model_save_path = os.path.join(save_path, name, "inference")
export_single_model(model.model_list[idx], archs[idx], export_single_model(model.model_list[idx], archs[idx],
sub_model_save_path, logger, quanter) sub_model_save_path, logger, quanter)
......
...@@ -92,6 +92,9 @@ class BaseModel(nn.Layer): ...@@ -92,6 +92,9 @@ class BaseModel(nn.Layer):
else: else:
y["head_out"] = x y["head_out"] = x
if self.return_all_feats: if self.return_all_feats:
return y if self.training:
return y
else:
return {"head_out": y["head_out"]}
else: else:
return x return x
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册