未验证 提交 832364e1 编写于 作者: Q QingshuChen 提交者: GitHub

support static graph train for kunlun (#441)

上级 8e5bde6c
# Training configuration for ResNet50_vd, initialised from a pretrained
# checkpoint and trained/validated on the flowers102 file lists below.
#
# NOTE(review): neither `use_gpu` nor `use_xpu` is set in this file; presumably
# the device is selected through command-line overrides -- confirm against the
# launch command.
mode: 'train'
ARCHITECTURE:
    name: 'ResNet50_vd'

# Checkpoint to start from and where to write new checkpoints.
pretrained_model: "./pretrained/ResNet50_vd_pretrained"
load_static_weights: true
model_save_dir: "./output/"

# Dataset / schedule settings.
classes_num: 102
total_images: 1020
save_interval: 1
validate: True
valid_interval: 1
epochs: 120
topk: 5
image_shape: [3, 224, 224]

LEARNING_RATE:
    function: 'Cosine'
    params:
        lr: 0.00375

OPTIMIZER:
    function: 'Momentum'
    params:
        momentum: 0.9
    regularizer:
        function: 'L2'
        factor: 0.000001

TRAIN:
    batch_size: 20
    num_workers: 1
    file_list: "./dataset/flowers102/train_list.txt"
    data_dir: "./dataset/flowers102/"
    shuffle_seed: 0
    # Transforms are listed in order: decode, random crop and flip, normalise, to CHW.
    transforms:
        - DecodeImage:
            to_rgb: True
            to_np: False
            channel_first: False
        - RandCropImage:
            size: 224
        - RandFlipImage:
            flip_code: 1
        - NormalizeImage:
            # Written as an expression string; presumably evaluated by the
            # transform at load time -- confirm.
            scale: 1./255.
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:

VALID:
    batch_size: 20
    num_workers: 1
    file_list: "./dataset/flowers102/val_list.txt"
    data_dir: "./dataset/flowers102/"
    shuffle_seed: 0
    # Validation uses a resize + crop instead of the random crop and flip.
    transforms:
        - DecodeImage:
            to_rgb: True
            to_np: False
            channel_first: False
        - ResizeImage:
            resize_short: 256
        - CropImage:
            size: 224
        - NormalizeImage:
            scale: 1.0/255.0
            mean: [0.485, 0.456, 0.406]
            std: [0.229, 0.224, 0.225]
            order: ''
        - ToCHWImage:
......@@ -63,9 +63,12 @@ def main(args):
config = get_config(args.config, overrides=args.override, show=True)
# assign the place
use_gpu = config.get("use_gpu", True)
assert use_gpu is True, "gpu must be true in static mode!"
place = paddle.set_device("gpu")
use_gpu = config.get("use_gpu", False)
use_xpu = config.get("use_xpu", False)
assert (use_gpu or use_xpu) is True, "gpu or xpu must be true in static mode!"
assert (use_gpu and use_xpu) is not True, "gpu and xpu can not be true in the same time in static mode!"
place = paddle.set_device('gpu' if use_gpu else 'xpu')
# startup_prog is used to do some parameter init work,
# and train prog is used to hold the network
......@@ -75,12 +78,12 @@ def main(args):
best_top1_acc = 0.0 # best top1 acc record
train_fetchs, lr_scheduler, train_feeds = program.build(
config, train_prog, startup_prog, is_train=True)
config, train_prog, startup_prog, is_train=True, is_distributed=config.get("is_distributed", True))
if config.validate:
valid_prog = paddle.static.Program()
valid_fetchs, _, valid_feeds = program.build(
config, valid_prog, startup_prog, is_train=False)
config, valid_prog, startup_prog, is_train=False, is_distributed=config.get("is_distributed", True))
# clone to prune some content which is irrelevant in valid_prog
valid_prog = valid_prog.clone(for_test=True)
......@@ -94,7 +97,10 @@ def main(args):
if config.validate and paddle.distributed.get_rank() == 0:
valid_dataloader = Reader(config, 'valid', places=place)()
compiled_valid_prog = program.compile(config, valid_prog)
if use_xpu:
compiled_valid_prog = valid_prog
else:
compiled_valid_prog = program.compile(config, valid_prog)
vdl_writer = None
if args.vdl_dir:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册