未验证 提交 3b5d7f7b 编写于 作者: S Steffy-zxf 提交者: GitHub

fix trainer bug

上级 63183066
......@@ -5,14 +5,18 @@ from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets import Flowers
if __name__ == '__main__':
transforms = T.Compose([T.Resize((256, 256)),
T.CenterCrop(224),
T.Normalize(mean=[0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])],
to_rgb=True)
transforms = T.Compose(
[T.Resize((256, 256)),
T.CenterCrop(224),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])],
to_rgb=True)
flowers = Flowers(transforms)
flowers_validate = Flowers(transforms, mode='val')
model = hub.Module(name='resnet50_vd_imagenet_ssld', label_list=["roses", "tulips", "daisy", "sunflowers", "dandelion"], load_checkpoint=None)
model = hub.Module(
name='resnet50_vd_imagenet_ssld',
label_list=["roses", "tulips", "daisy", "sunflowers", "dandelion"],
load_checkpoint=None)
optimizer = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='img_classification_ckpt')
trainer.train(flowers, epochs=100, batch_size=32, eval_dataset=flowers_validate, save_interval=10)
\ No newline at end of file
trainer = Trainer(model, optimizer, checkpoint_dir='img_classification_ckpt', use_gpu=True)
trainer.train(flowers, epochs=100, batch_size=32, eval_dataset=flowers_validate, save_interval=10)
......@@ -57,7 +57,7 @@ class ChnSentiCorp(TextClassificationDataset):
data_file = 'test.tsv'
else:
data_file = 'dev.tsv'
super(ChnSentiCorp, self).__init__(
super().__init__(
base_path=base_path,
tokenizer=tokenizer,
max_seq_len=max_seq_len,
......
......@@ -18,6 +18,7 @@ import time
from collections import defaultdict
from typing import Any, Callable, Generic, List
import numpy as np
import paddle
from visualdl import LogWriter
......@@ -223,6 +224,8 @@ class Trainer(object):
if self.use_vdl:
self.log_writer.add_scalar(
tag='TRAIN/{}'.format(metric), step=timer.current_step, value=value)
if isinstance(value, np.ndarray):
value = value.item()
print_msg += ' {}={:.4f}'.format(metric, value)
print_msg += ' lr={:.6f} step/sec={:.2f} | ETA {}'.format(lr, timer.timing, timer.eta)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册