# retrain.py: fine-tune the hub_module_ResNet50 PaddleHub module on the Flowers dataset.
import paddle
import paddle.fluid as fluid

import paddlehub as hub


def train():
    """Fine-tune the ``hub_module_ResNet50`` module on the Flowers dataset.

    Loads the module from the local ``hub_module_ResNet50`` directory,
    requests its ``feature_map`` signature, appends an MLP classifier on top
    of the returned feature map, and hands the resulting task to
    ``hub.finetune_and_eval``.

    The run is configured for CUDA, 10 epochs and a batch size of 32 with
    the default fine-tune strategy. Nothing is returned.
    """
    resnet_module = hub.Module(module_dir="hub_module_ResNet50")
    # context() yields the signature's inputs, its outputs, and the program
    # that holds them. NOTE(review): trainable=True presumably leaves the
    # module's parameters updatable during fine-tuning -- confirm against
    # the PaddleHub version in use.
    input_dict, output_dict, program = resnet_module.context(
        sign_name="feature_map", trainable=True)
    dataset = hub.dataset.Flowers()
    data_reader = hub.ImageClassificationReader(
        image_width=224, image_height=224, dataset=dataset)
    # New layers are created under the module's own program so the label
    # input and the classifier live alongside the loaded network.
    with fluid.program_guard(program):
        label = fluid.layers.data(name="label", dtype="int64", shape=[1])
        # Only the first input / first output of the signature are used.
        img = input_dict[0]
        feature_map = output_dict[0]

        config = hub.RunConfig(
            use_cuda=True,
            num_epoch=10,
            batch_size=32,
            strategy=hub.finetune.strategy.DefaultFinetuneStrategy())

        # Variable names the reader's batches are fed into, in this order.
        feed_list = [img.name, label.name]

        task = hub.append_mlp_classifier(
            feature=feature_map, label=label, num_classes=dataset.num_labels)
        hub.finetune_and_eval(
            task, feed_list=feed_list, data_reader=data_reader, config=config)


# Start fine-tuning only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    train()