diff --git a/demo/slimfacenet/train_eval.py b/demo/slimfacenet/train_eval.py index 5303bfdf0086919d23e9fe794427c9a25ef0e676..4df84d8b563b98deca7e12bad06640ee22825602 100644 --- a/demo/slimfacenet/train_eval.py +++ b/demo/slimfacenet/train_eval.py @@ -150,7 +150,10 @@ def train(exe, train_program, train_out, test_program, test_out, args): def build_program(program, startup, args, is_train=True): - num_trainers = len(os.getenv('CUDA_VISIBLE_DEVICES').split(',')) + if args.use_gpu: + num_trainers = len(os.getenv('CUDA_VISIBLE_DEVICES').split(',')) + else: + num_trainers = 1 places = fluid.cuda_places() if args.use_gpu else fluid.CPUPlace() train_dataset = CASIA_Face(root=args.train_data_dir) @@ -166,7 +169,7 @@ def build_program(program, startup, args, is_train=True): image = fluid.data( name='image', shape=[-1, 3, 112, 96], dtype='float32') label = fluid.data(name='label', shape=[-1, 1], dtype='int64') - train_reader = paddle.fluid.io.batch( + train_reader = fluid.io.batch( train_dataset.reader, batch_size=args.train_batchsize // num_trainers, drop_last=False) @@ -187,7 +190,7 @@ def build_program(program, startup, args, is_train=True): else: nl, nr, flods, flags = parse_filelist(args.test_data_dir) test_dataset = LFW(nl, nr) - test_reader = paddle.fluid.io.batch( + test_reader = fluid.io.batch( test_dataset.reader, batch_size=args.test_batchsize, drop_last=False) @@ -231,7 +234,7 @@ def build_program(program, startup, args, is_train=True): def quant_val_reader_batch(): nl, nr, flods, flags = parse_filelist(args.test_data_dir) test_dataset = LFW(nl, nr) - test_reader = paddle.fluid.io.batch( + test_reader = fluid.io.batch( test_dataset.reader, batch_size=1, drop_last=False) shuffle_reader = fluid.io.shuffle(test_reader, 3) @@ -298,14 +301,16 @@ def main(): help='The path of the extract features save, must be .mat file') args = parser.parse_args() - num_trainers = len(os.getenv('CUDA_VISIBLE_DEVICES').split(',')) + if args.use_gpu: + num_trainers = len(os.getenv('CUDA_VISIBLE_DEVICES').split(',')) + else: + num_trainers = 1 print(args) print('num_trainers: {}'.format(num_trainers)) if args.save_ckpt == None: args.save_ckpt = 'output' - if not os.path.exists(args.save_ckpt): - subprocess.call(['mkdir', '-p', args.save_ckpt]) - + if not os.path.isdir(args.save_ckpt): + os.makedirs(args.save_ckpt) with open(os.path.join(args.save_ckpt, 'log.txt'), 'w+') as f: f.writelines(str(args) + '\n') f.writelines('num_trainers: {}'.format(num_trainers) + '\n') @@ -346,7 +351,7 @@ def main(): executor=exe) nl, nr, flods, flags = parse_filelist(args.test_data_dir) test_dataset = LFW(nl, nr) - test_reader = paddle.fluid.io.batch( + test_reader = fluid.io.batch( test_dataset.reader, batch_size=args.test_batchsize, drop_last=False)