From 3b5d7f7ba93fe410902a128e2ba9764d19c0e821 Mon Sep 17 00:00:00 2001 From: Steffy-zxf <48793257+Steffy-zxf@users.noreply.github.com> Date: Tue, 1 Dec 2020 16:26:03 +0800 Subject: [PATCH] fix trainer bug --- demo/image_classification/train.py | 20 ++++++++++++-------- paddlehub/datasets/chnsenticorp.py | 2 +- paddlehub/finetune/trainer.py | 3 +++ 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/demo/image_classification/train.py b/demo/image_classification/train.py index 17af08bc..3a8a7035 100644 --- a/demo/image_classification/train.py +++ b/demo/image_classification/train.py @@ -5,14 +5,18 @@ from paddlehub.finetune.trainer import Trainer from paddlehub.datasets import Flowers if __name__ == '__main__': - transforms = T.Compose([T.Resize((256, 256)), - T.CenterCrop(224), - T.Normalize(mean=[0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])], - to_rgb=True) - + transforms = T.Compose( + [T.Resize((256, 256)), + T.CenterCrop(224), + T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])], + to_rgb=True) + flowers = Flowers(transforms) flowers_validate = Flowers(transforms, mode='val') - model = hub.Module(name='resnet50_vd_imagenet_ssld', label_list=["roses", "tulips", "daisy", "sunflowers", "dandelion"], load_checkpoint=None) + model = hub.Module( + name='resnet50_vd_imagenet_ssld', + label_list=["roses", "tulips", "daisy", "sunflowers", "dandelion"], + load_checkpoint=None) optimizer = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters()) - trainer = Trainer(model, optimizer, checkpoint_dir='img_classification_ckpt') - trainer.train(flowers, epochs=100, batch_size=32, eval_dataset=flowers_validate, save_interval=10) \ No newline at end of file + trainer = Trainer(model, optimizer, checkpoint_dir='img_classification_ckpt', use_gpu=True) + trainer.train(flowers, epochs=100, batch_size=32, eval_dataset=flowers_validate, save_interval=10) diff --git a/paddlehub/datasets/chnsenticorp.py b/paddlehub/datasets/chnsenticorp.py index d0daef48..3c342851 100644 --- a/paddlehub/datasets/chnsenticorp.py +++ b/paddlehub/datasets/chnsenticorp.py @@ -57,7 +57,7 @@ class ChnSentiCorp(TextClassificationDataset): data_file = 'test.tsv' else: data_file = 'dev.tsv' - super(ChnSentiCorp, self).__init__( + super().__init__( base_path=base_path, tokenizer=tokenizer, max_seq_len=max_seq_len, diff --git a/paddlehub/finetune/trainer.py b/paddlehub/finetune/trainer.py index d62c20d2..d0634fc3 100644 --- a/paddlehub/finetune/trainer.py +++ b/paddlehub/finetune/trainer.py @@ -18,6 +18,7 @@ import time from collections import defaultdict from typing import Any, Callable, Generic, List +import numpy as np import paddle from visualdl import LogWriter @@ -223,6 +224,8 @@ class Trainer(object): if self.use_vdl: self.log_writer.add_scalar( tag='TRAIN/{}'.format(metric), step=timer.current_step, value=value) + if isinstance(value, np.ndarray): + value = value.item() print_msg += ' {}={:.4f}'.format(metric, value) print_msg += ' lr={:.6f} step/sec={:.2f} | ETA {}'.format(lr, timer.timing, timer.eta) -- GitLab