提交 76e08faa 编写于 作者: W wuzewu

Show processing information when evaluating the model

上级 19d9ceb4
...@@ -23,7 +23,7 @@ import paddle ...@@ -23,7 +23,7 @@ import paddle
from paddle.distributed import ParallelEnv from paddle.distributed import ParallelEnv
from visualdl import LogWriter from visualdl import LogWriter
from paddlehub.utils.log import logger from paddlehub.utils.log import logger, processing
from paddlehub.utils.utils import Timer from paddlehub.utils.utils import Timer
...@@ -111,14 +111,14 @@ class Trainer(object): ...@@ -111,14 +111,14 @@ class Trainer(object):
self.current_epoch, metric_msg)) self.current_epoch, metric_msg))
# load model checkpoint # load model checkpoint
model_params_path = os.path.join(self.checkpoint_dir, 'epoch_{}'.format(self.current_epoch), 'model.pdparmas') model_params_path = os.path.join(self.checkpoint_dir, 'epoch_{}'.format(self.current_epoch), 'model.pdparams')
state_dict = paddle.load(model_params_path) state_dict = paddle.load(model_params_path)
self.model.set_dict(state_dict) self.model.set_state_dict(state_dict)
# load optimizer checkpoint # load optimizer checkpoint
optim_params_path = os.path.join(self.checkpoint_dir, 'epoch_{}'.format(self.current_epoch), 'model.pdopt') optim_params_path = os.path.join(self.checkpoint_dir, 'epoch_{}'.format(self.current_epoch), 'model.pdopt')
state_dict = paddle.load(optim_params_path) state_dict = paddle.load(optim_params_path)
self.optimizer.set_dict(state_dict) self.optimizer.set_state_dict(state_dict)
def _save_checkpoint(self): def _save_checkpoint(self):
'''Save model checkpoint and state dict''' '''Save model checkpoint and state dict'''
...@@ -131,7 +131,7 @@ class Trainer(object): ...@@ -131,7 +131,7 @@ class Trainer(object):
model_params_path = os.path.join(save_dir, 'model.pdparams') model_params_path = os.path.join(save_dir, 'model.pdparams')
optim_params_path = os.path.join(save_dir, 'model.pdopt') optim_params_path = os.path.join(save_dir, 'model.pdopt')
paddle.save(self.model.state_dict(), model_params_path) paddle.save(self.model.state_dict(), model_params_path)
paddle.save(self.model.state_dict(), optim_params_path) paddle.save(self.optimizer.state_dict(), optim_params_path)
def _save_metrics(self): def _save_metrics(self):
with open(os.path.join(self.checkpoint_dir, 'metrics.pkl'), 'wb') as file: with open(os.path.join(self.checkpoint_dir, 'metrics.pkl'), 'wb') as file:
...@@ -162,10 +162,6 @@ class Trainer(object): ...@@ -162,10 +162,6 @@ class Trainer(object):
log_interval(int) : Log the train infomation every `log_interval` steps. log_interval(int) : Log the train infomation every `log_interval` steps.
save_interval(int) : Save the checkpoint every `save_interval` epochs. save_interval(int) : Save the checkpoint every `save_interval` epochs.
''' '''
use_gpu = True
place = 'gpu' if use_gpu else 'cpu'
paddle.set_device(place)
batch_sampler = paddle.io.DistributedBatchSampler( batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=batch_size, shuffle=True, drop_last=False) train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
loader = paddle.io.DataLoader( loader = paddle.io.DataLoader(
...@@ -253,10 +249,6 @@ class Trainer(object): ...@@ -253,10 +249,6 @@ class Trainer(object):
batch_size(int) : Batch size of per step, default is 1. batch_size(int) : Batch size of per step, default is 1.
num_workers(int) : Number of subprocess to load data, default is 0. num_workers(int) : Number of subprocess to load data, default is 0.
''' '''
use_gpu = True
place = 'gpu' if use_gpu else 'cpu'
paddle.set_device(place)
batch_sampler = paddle.io.DistributedBatchSampler( batch_sampler = paddle.io.DistributedBatchSampler(
eval_dataset, batch_size=batch_size, shuffle=False, drop_last=False) eval_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
...@@ -268,18 +260,19 @@ class Trainer(object): ...@@ -268,18 +260,19 @@ class Trainer(object):
sum_metrics = defaultdict(int) sum_metrics = defaultdict(int)
avg_metrics = defaultdict(int) avg_metrics = defaultdict(int)
for batch_idx, batch in enumerate(loader): with processing('Evaluation on validation dataset'):
result = self.validation_step(batch, batch_idx) for batch_idx, batch in enumerate(loader):
loss = result.get('loss', None) result = self.validation_step(batch, batch_idx)
metrics = result.get('metrics', {}) loss = result.get('loss', None)
bs = batch[0].shape[0] metrics = result.get('metrics', {})
num_samples += bs bs = batch[0].shape[0]
num_samples += bs
if loss: if loss:
avg_loss += loss.numpy()[0] * bs avg_loss += loss.numpy()[0] * bs
for metric, value in metrics.items(): for metric, value in metrics.items():
sum_metrics[metric] += value.numpy()[0] * bs sum_metrics[metric] += value.numpy()[0] * bs
# print avg metrics and loss # print avg metrics and loss
print_msg = '[Evaluation result]' print_msg = '[Evaluation result]'
......
...@@ -13,11 +13,13 @@ ...@@ -13,11 +13,13 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib
import copy import copy
import functools import functools
import logging import logging
import sys import sys
import time import time
import threading
from typing import List from typing import List
import colorlog import colorlog
...@@ -102,6 +104,13 @@ class Logger(object): ...@@ -102,6 +104,13 @@ class Logger(object):
self.logger.log(log_level, msg) self.logger.log(log_level, msg)
@contextlib.contextmanager
def use_terminator(self, terminator: str):
old_terminator = self.handler.terminator
self.handler.terminator = terminator
yield
self.handler.terminator = old_terminator
class ProgressBar(object): class ProgressBar(object):
''' '''
...@@ -161,6 +170,28 @@ class ProgressBar(object): ...@@ -161,6 +170,28 @@ class ProgressBar(object):
sys.stdout.write('\n') sys.stdout.write('\n')
@contextlib.contextmanager
def processing(msg: str, interval: float = 0.1):
    '''
    Display an animated spinner on the logger while a long-running block executes.

    A background thread repeatedly logs `msg` followed by a rotating
    frame (``\\ | / -``), rewriting the same line via the logger's '\r'
    terminator, until the managed block finishes.

    Args:
        msg(str)        : Message to show next to the spinner.
        interval(float) : Seconds between spinner frames. Default is 0.1.

    Usage:
        with processing('Evaluation on validation dataset'):
            do_long_running_work()
    '''
    end = False

    def _printer():
        # Spinner frames cycled once per `interval` seconds.
        frames = ['\\', '|', '/', '-']
        index = 0
        while not end:
            with logger.use_terminator('\r'):
                logger.info('{}: {}'.format(msg, frames[index % len(frames)]))
            time.sleep(interval)
            index += 1

    # daemon=True so a leaked spinner can never keep the process alive.
    t = threading.Thread(target=_printer, daemon=True)
    t.start()
    try:
        yield
    finally:
        # Stop the spinner even when the managed block raises, and join
        # so spinner output cannot interleave with subsequent logging.
        end = True
        t.join()
class FormattedText(object): class FormattedText(object):
''' '''
Cross-platform formatted string Cross-platform formatted string
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册