# retrain.py (author: wuzewu): fine-tune a ResNet50 PaddleHub module on the Flowers dataset.
import paddle
import paddle.fluid as fluid

import paddlehub as hub


def train(*, module_dir="ResNet50.hub_module", use_cuda=True, num_epoch=10,
          batch_size=32):
    """Fine-tune a PaddleHub image module on the Flowers dataset and evaluate it.

    Loads the module from ``module_dir``, takes its ``"feature_map"``
    signature as trainable, attaches an image-classification task sized to
    the dataset's label count, and runs ``hub.finetune_and_eval``.

    Calling ``train()`` with no arguments reproduces the original settings.

    Args:
        module_dir: Directory of the PaddleHub module to load.
        use_cuda: Forwarded to ``hub.RunConfig``.
        num_epoch: Number of fine-tuning epochs, forwarded to ``hub.RunConfig``.
        batch_size: Mini-batch size, forwarded to ``hub.RunConfig``.
    """
    module = hub.Module(module_dir=module_dir)
    input_dict, output_dict, program = module.context(
        sign_name="feature_map", trainable=True)
    dataset = hub.dataset.Flowers()

    # Size and normalize images using the values the module itself reports.
    # NOTE(review): "excepted" (sic) is the method name called on the module
    # object; verify against the installed PaddleHub version before renaming.
    data_reader = hub.reader.ImageClassificationReader(
        image_width=module.get_excepted_image_width(),
        image_height=module.get_excepted_image_height(),
        images_mean=module.get_pretrained_images_mean(),
        images_std=module.get_pretrained_images_std(),
        dataset=dataset)

    # The label input and the task head must be built inside the module's
    # program so they are added to the same graph as the pretrained network.
    with fluid.program_guard(program):
        label = fluid.layers.data(name="label", dtype="int64", shape=[1])
        # Presumably the signature's first input/output are the image tensor
        # and the feature map respectively -- confirm for other modules.
        img = input_dict[0]
        feature_map = output_dict[0]

        config = hub.RunConfig(
            use_cuda=use_cuda,
            num_epoch=num_epoch,
            batch_size=batch_size,
            enable_memory_optim=False,
            strategy=hub.finetune.strategy.DefaultFinetuneStrategy())

        # Feed order: image first, then label.
        # Assumes the reader yields samples in this same order -- verify.
        feed_list = [img.name, label.name]

        task = hub.create_img_classification_task(
            feature=feature_map, label=label, num_classes=dataset.num_labels)
        hub.finetune_and_eval(
            task, feed_list=feed_list, data_reader=data_reader, config=config)


def _main():
    """Script entry point: run fine-tuning with the default settings."""
    train()


if __name__ == "__main__":
    _main()