提交 c07d1ffe 编写于 作者: W wuzewu

Update demo

上级 8917fb23
import paddle.fluid as fluid
import paddle
import paddlehub as hub
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed import ParallelEnv
from paddlehub.finetune.trainer import Trainer
from paddlehub.datasets.flowers import Flowers
from paddlehub.process.transforms import Compose, Resize, Normalize
from paddlehub.module.cv_module import ImageClassifierModule
if __name__ == '__main__':
with fluid.dygraph.guard(fluid.CUDAPlace(ParallelEnv().dev_id)):
transforms = Compose([Resize((224, 224)), Normalize()])
flowers = Flowers(transforms)
flowers_validate = Flowers(transforms, mode='val')
paddle.disable_static(paddle.CUDAPlace(ParallelEnv().dev_id))
transforms = Compose([Resize((224, 224)), Normalize()])
flowers = Flowers(transforms)
flowers_validate = Flowers(transforms, mode='val')
model = hub.Module(directory='mobilenet_v2_animals', class_dim=flowers.num_classes)
# model = hub.Module(name='mobilenet_v2_animals', class_dim=flowers.num_classes)
model = hub.Module(name='mobilenet_v2_imagenet', class_dim=flowers.num_classes)
optimizer = fluid.optimizer.AdamOptimizer(learning_rate=0.001, parameter_list=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls')
optimizer = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='test_ckpt_img_cls')
trainer.train(flowers, epochs=100, batch_size=32, eval_dataset=flowers_validate, save_interval=1)
trainer.train(flowers, epochs=100, batch_size=32, eval_dataset=flowers_validate, save_interval=1)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册