未验证 提交 3dfa7179 编写于 作者: T Teng Xi 提交者: GitHub

fix windows env (#312)

* fix windows env

* fix windows envs
上级 cf4d540f
...@@ -150,7 +150,10 @@ def train(exe, train_program, train_out, test_program, test_out, args): ...@@ -150,7 +150,10 @@ def train(exe, train_program, train_out, test_program, test_out, args):
def build_program(program, startup, args, is_train=True): def build_program(program, startup, args, is_train=True):
if args.use_gpu:
num_trainers = len(os.getenv('CUDA_VISIBLE_DEVICES').split(',')) num_trainers = len(os.getenv('CUDA_VISIBLE_DEVICES').split(','))
else:
num_trainers = 1
places = fluid.cuda_places() if args.use_gpu else fluid.CPUPlace() places = fluid.cuda_places() if args.use_gpu else fluid.CPUPlace()
train_dataset = CASIA_Face(root=args.train_data_dir) train_dataset = CASIA_Face(root=args.train_data_dir)
...@@ -166,7 +169,7 @@ def build_program(program, startup, args, is_train=True): ...@@ -166,7 +169,7 @@ def build_program(program, startup, args, is_train=True):
image = fluid.data( image = fluid.data(
name='image', shape=[-1, 3, 112, 96], dtype='float32') name='image', shape=[-1, 3, 112, 96], dtype='float32')
label = fluid.data(name='label', shape=[-1, 1], dtype='int64') label = fluid.data(name='label', shape=[-1, 1], dtype='int64')
train_reader = paddle.fluid.io.batch( train_reader = fluid.io.batch(
train_dataset.reader, train_dataset.reader,
batch_size=args.train_batchsize // num_trainers, batch_size=args.train_batchsize // num_trainers,
drop_last=False) drop_last=False)
...@@ -187,7 +190,7 @@ def build_program(program, startup, args, is_train=True): ...@@ -187,7 +190,7 @@ def build_program(program, startup, args, is_train=True):
else: else:
nl, nr, flods, flags = parse_filelist(args.test_data_dir) nl, nr, flods, flags = parse_filelist(args.test_data_dir)
test_dataset = LFW(nl, nr) test_dataset = LFW(nl, nr)
test_reader = paddle.fluid.io.batch( test_reader = fluid.io.batch(
test_dataset.reader, test_dataset.reader,
batch_size=args.test_batchsize, batch_size=args.test_batchsize,
drop_last=False) drop_last=False)
...@@ -231,7 +234,7 @@ def build_program(program, startup, args, is_train=True): ...@@ -231,7 +234,7 @@ def build_program(program, startup, args, is_train=True):
def quant_val_reader_batch(): def quant_val_reader_batch():
nl, nr, flods, flags = parse_filelist(args.test_data_dir) nl, nr, flods, flags = parse_filelist(args.test_data_dir)
test_dataset = LFW(nl, nr) test_dataset = LFW(nl, nr)
test_reader = paddle.fluid.io.batch( test_reader = fluid.io.batch(
test_dataset.reader, batch_size=1, drop_last=False) test_dataset.reader, batch_size=1, drop_last=False)
shuffle_reader = fluid.io.shuffle(test_reader, 3) shuffle_reader = fluid.io.shuffle(test_reader, 3)
...@@ -298,14 +301,16 @@ def main(): ...@@ -298,14 +301,16 @@ def main():
help='The path of the extract features save, must be .mat file') help='The path of the extract features save, must be .mat file')
args = parser.parse_args() args = parser.parse_args()
if args.use_gpu:
num_trainers = len(os.getenv('CUDA_VISIBLE_DEVICES').split(',')) num_trainers = len(os.getenv('CUDA_VISIBLE_DEVICES').split(','))
else:
num_trainers = 1
print(args) print(args)
print('num_trainers: {}'.format(num_trainers)) print('num_trainers: {}'.format(num_trainers))
if args.save_ckpt == None: if args.save_ckpt == None:
args.save_ckpt = 'output' args.save_ckpt = 'output'
if not os.path.exists(args.save_ckpt): if not os.path.isdir(args.save_ckpt):
subprocess.call(['mkdir', '-p', args.save_ckpt]) os.makedirs(args.save_ckpt)
with open(os.path.join(args.save_ckpt, 'log.txt'), 'w+') as f: with open(os.path.join(args.save_ckpt, 'log.txt'), 'w+') as f:
f.writelines(str(args) + '\n') f.writelines(str(args) + '\n')
f.writelines('num_trainers: {}'.format(num_trainers) + '\n') f.writelines('num_trainers: {}'.format(num_trainers) + '\n')
...@@ -346,7 +351,7 @@ def main(): ...@@ -346,7 +351,7 @@ def main():
executor=exe) executor=exe)
nl, nr, flods, flags = parse_filelist(args.test_data_dir) nl, nr, flods, flags = parse_filelist(args.test_data_dir)
test_dataset = LFW(nl, nr) test_dataset = LFW(nl, nr)
test_reader = paddle.fluid.io.batch( test_reader = fluid.io.batch(
test_dataset.reader, test_dataset.reader,
batch_size=args.test_batchsize, batch_size=args.test_batchsize,
drop_last=False) drop_last=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册