提交 91dee973 编写于 作者: W WenmuZhou

change log name to root

上级 60f8f1e1
...@@ -352,7 +352,8 @@ def preprocess(): ...@@ -352,7 +352,8 @@ def preprocess():
with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f: with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False) yaml.dump(dict(config), f, default_flow_style=False, sort_keys=False)
logger = get_logger(log_file='{}/train.log'.format(save_model_dir)) logger = get_logger(
name='root', log_file='{}/train.log'.format(save_model_dir))
if config['Global']['use_visualdl']: if config['Global']['use_visualdl']:
from visualdl import LogWriter from visualdl import LogWriter
vdl_writer_path = '{}/vdl/'.format(save_model_dir) vdl_writer_path = '{}/vdl/'.format(save_model_dir)
......
...@@ -36,7 +36,6 @@ from ppocr.optimizer import build_optimizer ...@@ -36,7 +36,6 @@ from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric from ppocr.metrics import build_metric
from ppocr.utils.save_load import init_model from ppocr.utils.save_load import init_model
from ppocr.utils.utility import print_dict
import tools.program as program import tools.program as program
dist.get_world_size() dist.get_world_size()
...@@ -61,7 +60,7 @@ def main(config, device, logger, vdl_writer): ...@@ -61,7 +60,7 @@ def main(config, device, logger, vdl_writer):
global_config) global_config)
# build model # build model
#for rec algorithm # for rec algorithm
if hasattr(post_process_class, 'character'): if hasattr(post_process_class, 'character'):
char_num = len(getattr(post_process_class, 'character')) char_num = len(getattr(post_process_class, 'character'))
config['Architecture']["Head"]['out_channels'] = char_num config['Architecture']["Head"]['out_channels'] = char_num
...@@ -81,10 +80,11 @@ def main(config, device, logger, vdl_writer): ...@@ -81,10 +80,11 @@ def main(config, device, logger, vdl_writer):
# build metric # build metric
eval_class = build_metric(config['Metric']) eval_class = build_metric(config['Metric'])
# load pretrain model # load pretrain model
pre_best_model_dict = init_model(config, model, logger, optimizer) pre_best_model_dict = init_model(config, model, logger, optimizer)
logger.info('train dataloader has {} iter, valid dataloader has {} iter'.
format(len(train_dataloader), len(valid_dataloader)))
# start train # start train
program.train(config, train_dataloader, valid_dataloader, device, model, program.train(config, train_dataloader, valid_dataloader, device, model,
loss_class, optimizer, lr_scheduler, post_process_class, loss_class, optimizer, lr_scheduler, post_process_class,
...@@ -92,8 +92,7 @@ def main(config, device, logger, vdl_writer): ...@@ -92,8 +92,7 @@ def main(config, device, logger, vdl_writer):
def test_reader(config, device, logger): def test_reader(config, device, logger):
loader = build_dataloader(config, 'Train', device) loader = build_dataloader(config, 'Train', device, logger)
# loader = build_dataloader(config, 'Eval', device)
import time import time
starttime = time.time() starttime = time.time()
count = 0 count = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册