diff --git a/configs/quick_start/ResNet50_vd_finetune_kunlun.yaml b/configs/quick_start/ResNet50_vd_finetune_kunlun.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e931f327693d43550a6e463746d92bcf502c3986 --- /dev/null +++ b/configs/quick_start/ResNet50_vd_finetune_kunlun.yaml @@ -0,0 +1,71 @@ +mode: 'train' +ARCHITECTURE: + name: 'ResNet50_vd' +pretrained_model: "./pretrained/ResNet50_vd_pretrained" +load_static_weights: true +model_save_dir: "./output/" +classes_num: 102 +total_images: 1020 +save_interval: 1 +validate: True +valid_interval: 1 +epochs: 120 +topk: 5 +image_shape: [3, 224, 224] + +LEARNING_RATE: + function: 'Cosine' + params: + lr: 0.00375 + +OPTIMIZER: + function: 'Momentum' + params: + momentum: 0.9 + regularizer: + function: 'L2' + factor: 0.000001 + +TRAIN: + batch_size: 20 + num_workers: 1 + file_list: "./dataset/flowers102/train_list.txt" + data_dir: "./dataset/flowers102/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1./255. + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + +VALID: + batch_size: 20 + num_workers: 1 + file_list: "./dataset/flowers102/val_list.txt" + data_dir: "./dataset/flowers102/" + shuffle_seed: 0 + transforms: + - DecodeImage: + to_rgb: True + to_np: False + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: diff --git a/tools/static/train.py b/tools/static/train.py index 247b8ea4587bf753e55d41f899ff94aed32342be..6de487a51ae0c139f1a61b6312a485b9a916afc2 100644 --- a/tools/static/train.py +++ b/tools/static/train.py @@ -63,9 +63,12 @@ def main(args): config = get_config(args.config, overrides=args.override, show=True) # assign the place - use_gpu = config.get("use_gpu", True) - assert use_gpu is True, "gpu must be true in static mode!" - place = paddle.set_device("gpu") + use_gpu = config.get("use_gpu", False) + use_xpu = config.get("use_xpu", False) + assert (use_gpu or use_xpu) is True, "gpu or xpu must be true in static mode!" + assert (use_gpu and use_xpu) is not True, "gpu and xpu can not be true in the same time in static mode!" + + place = paddle.set_device('gpu' if use_gpu else 'xpu') # startup_prog is used to do some parameter init work, # and train prog is used to hold the network @@ -75,12 +78,12 @@ def main(args): best_top1_acc = 0.0 # best top1 acc record train_fetchs, lr_scheduler, train_feeds = program.build( - config, train_prog, startup_prog, is_train=True) + config, train_prog, startup_prog, is_train=True, is_distributed=config.get("is_distributed", True)) if config.validate: valid_prog = paddle.static.Program() valid_fetchs, _, valid_feeds = program.build( - config, valid_prog, startup_prog, is_train=False) + config, valid_prog, startup_prog, is_train=False, is_distributed=config.get("is_distributed", True)) # clone to prune some content which is irrelevant in valid_prog valid_prog = valid_prog.clone(for_test=True) @@ -94,7 +97,10 @@ def main(args): if config.validate and paddle.distributed.get_rank() == 0: valid_dataloader = Reader(config, 'valid', places=place)() - compiled_valid_prog = program.compile(config, valid_prog) + if use_xpu: + compiled_valid_prog = valid_prog + else: + compiled_valid_prog = program.compile(config, valid_prog) vdl_writer = None if args.vdl_dir: