From c07d1ffe345a846e7aff14e1272c2a2f6ff3cad5 Mon Sep 17 00:00:00 2001 From: wuzewu Date: Sun, 27 Sep 2020 10:37:25 +0800 Subject: [PATCH] Update demo --- demo/image_classification/train.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/demo/image_classification/train.py b/demo/image_classification/train.py index 421059cf..01eb6dcb 100644 --- a/demo/image_classification/train.py +++ b/demo/image_classification/train.py @@ -1,21 +1,20 @@ -import paddle.fluid as fluid +import paddle import paddlehub as hub -from paddle.fluid.dygraph.parallel import ParallelEnv +from paddle.distributed import ParallelEnv from paddlehub.finetune.trainer import Trainer from paddlehub.datasets.flowers import Flowers from paddlehub.process.transforms import Compose, Resize, Normalize from paddlehub.module.cv_module import ImageClassifierModule if __name__ == '__main__': - with fluid.dygraph.guard(fluid.CUDAPlace(ParallelEnv().dev_id)): - transforms = Compose([Resize((224, 224)), Normalize()]) - flowers = Flowers(transforms) - flowers_validate = Flowers(transforms, mode='val') + paddle.disable_static(paddle.CUDAPlace(ParallelEnv().dev_id)) + transforms = Compose([Resize((224, 224)), Normalize()]) + flowers = Flowers(transforms) + flowers_validate = Flowers(transforms, mode='val') - model = hub.Module(directory='mobilenet_v2_animals', class_dim=flowers.num_classes) - # model = hub.Module(name='mobilenet_v2_animals', class_dim=flowers.num_classes) + model = hub.Module(name='mobilenet_v2_imagenet', class_dim=flowers.num_classes) - optimizer = fluid.optimizer.AdamOptimizer(learning_rate=0.001, parameter_list=model.parameters()) - trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls') + optimizer = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters()) + trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls') - trainer.train(flowers, epochs=100, batch_size=32, eval_dataset=flowers_validate, save_interval=1) + trainer.train(flowers, epochs=100, batch_size=32, eval_dataset=flowers_validate, save_interval=1) -- GitLab