Commit 76e08faa authored by: W wuzewu

Show processing information when evaluating the model

Parent 19d9ceb4
......@@ -23,7 +23,7 @@ import paddle
from paddle.distributed import ParallelEnv
from visualdl import LogWriter
from paddlehub.utils.log import logger
from paddlehub.utils.log import logger, processing
from paddlehub.utils.utils import Timer
......@@ -111,14 +111,14 @@ class Trainer(object):
self.current_epoch, metric_msg))
# load model checkpoint
model_params_path = os.path.join(self.checkpoint_dir, 'epoch_{}'.format(self.current_epoch), 'model.pdparmas')
model_params_path = os.path.join(self.checkpoint_dir, 'epoch_{}'.format(self.current_epoch), 'model.pdparams')
state_dict = paddle.load(model_params_path)
self.model.set_dict(state_dict)
self.model.set_state_dict(state_dict)
# load optimizer checkpoint
optim_params_path = os.path.join(self.checkpoint_dir, 'epoch_{}'.format(self.current_epoch), 'model.pdopt')
state_dict = paddle.load(optim_params_path)
self.optimizer.set_dict(state_dict)
self.optimizer.set_state_dict(state_dict)
def _save_checkpoint(self):
'''Save model checkpoint and state dict'''
......@@ -131,7 +131,7 @@ class Trainer(object):
model_params_path = os.path.join(save_dir, 'model.pdparams')
optim_params_path = os.path.join(save_dir, 'model.pdopt')
paddle.save(self.model.state_dict(), model_params_path)
paddle.save(self.model.state_dict(), optim_params_path)
paddle.save(self.optimizer.state_dict(), optim_params_path)
def _save_metrics(self):
with open(os.path.join(self.checkpoint_dir, 'metrics.pkl'), 'wb') as file:
......@@ -162,10 +162,6 @@ class Trainer(object):
log_interval(int) : Log the train information every `log_interval` steps.
save_interval(int) : Save the checkpoint every `save_interval` epochs.
'''
use_gpu = True
place = 'gpu' if use_gpu else 'cpu'
paddle.set_device(place)
batch_sampler = paddle.io.DistributedBatchSampler(
train_dataset, batch_size=batch_size, shuffle=True, drop_last=False)
loader = paddle.io.DataLoader(
......@@ -253,10 +249,6 @@ class Trainer(object):
batch_size(int) : Batch size of per step, default is 1.
num_workers(int) : Number of subprocess to load data, default is 0.
'''
use_gpu = True
place = 'gpu' if use_gpu else 'cpu'
paddle.set_device(place)
batch_sampler = paddle.io.DistributedBatchSampler(
eval_dataset, batch_size=batch_size, shuffle=False, drop_last=False)
......@@ -268,18 +260,19 @@ class Trainer(object):
sum_metrics = defaultdict(int)
avg_metrics = defaultdict(int)
for batch_idx, batch in enumerate(loader):
result = self.validation_step(batch, batch_idx)
loss = result.get('loss', None)
metrics = result.get('metrics', {})
bs = batch[0].shape[0]
num_samples += bs
with processing('Evaluation on validation dataset'):
for batch_idx, batch in enumerate(loader):
result = self.validation_step(batch, batch_idx)
loss = result.get('loss', None)
metrics = result.get('metrics', {})
bs = batch[0].shape[0]
num_samples += bs
if loss:
avg_loss += loss.numpy()[0] * bs
if loss:
avg_loss += loss.numpy()[0] * bs
for metric, value in metrics.items():
sum_metrics[metric] += value.numpy()[0] * bs
for metric, value in metrics.items():
sum_metrics[metric] += value.numpy()[0] * bs
# print avg metrics and loss
print_msg = '[Evaluation result]'
......
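
Besides wrapping the evaluation loop in the new progress indicator, the trainer.py hunks above also correct the checkpoint round-trip: the misspelled 'model.pdparmas' path, the set_dict calls replaced with set_state_dict, and the optimizer checkpoint that was previously written from the model's state dict. A minimal sketch of the corrected pattern using the Paddle 2.x API; the SimpleNet layer and the paths here are illustrative only, not part of the commit:

```python
import os
import paddle

# Illustrative layer; the real commit operates on self.model inside Trainer.
class SimpleNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

model = SimpleNet()
optimizer = paddle.optimizer.Adam(parameters=model.parameters())

# Save: model and optimizer state go to separate files ('.pdparams' / '.pdopt').
save_dir = os.path.join('checkpoint', 'epoch_1')
os.makedirs(save_dir, exist_ok=True)
paddle.save(model.state_dict(), os.path.join(save_dir, 'model.pdparams'))
paddle.save(optimizer.state_dict(), os.path.join(save_dir, 'model.pdopt'))

# Load: set_state_dict restores each object from its own file.
model.set_state_dict(paddle.load(os.path.join(save_dir, 'model.pdparams')))
optimizer.set_state_dict(paddle.load(os.path.join(save_dir, 'model.pdopt')))
```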
......@@ -13,11 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import copy
import functools
import logging
import sys
import time
import threading
from typing import List
import colorlog
......@@ -102,6 +104,13 @@ class Logger(object):
self.logger.log(log_level, msg)
@contextlib.contextmanager
def use_terminator(self, terminator: str):
old_terminator = self.handler.terminator
self.handler.terminator = terminator
yield
self.handler.terminator = old_terminator
class ProgressBar(object):
'''
......@@ -161,6 +170,28 @@ class ProgressBar(object):
sys.stdout.write('\n')
@contextlib.contextmanager
def processing(msg: str, interval: float = 0.1):
'''Show a spinner next to `msg` while the wrapped block runs, refreshing every `interval` seconds from a background thread.'''
end = False
def _printer():
index = 0
flags = ['\\', '|', '/', '-']
while not end:
flag = flags[index % len(flags)]
with logger.use_terminator('\r'):
logger.info('{}: {}'.format(msg, flag))
time.sleep(interval)
index += 1
t = threading.Thread(target=_printer)
t.start()
yield
end = True
class FormattedText(object):
'''
Cross-platform formatted string
......
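
Usage sketch of the new processing helper (assuming a PaddleHub build that includes this commit): the background thread repeatedly prints '{msg}: <spinner frame>' on the same terminal line via use_terminator('\r') until the with-block exits, which is how the evaluation loop in trainer.py now reports progress. The sleep below is only a stand-in for real work.

```python
import time
from paddlehub.utils.log import processing

# Stand-in for the real workload (e.g. the validation loop in Trainer.evaluate).
with processing('Evaluation on validation dataset', interval=0.1):
    time.sleep(2)
```

Because _printer only checks `end` between sleeps and the thread is not joined, the last spinner frame can trail the block's exit by up to `interval` seconds.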