未验证 提交 3dfa7179 编写于 作者: T Teng Xi 提交者: GitHub

fix windows env (#312)

* fix windows env

* fix windows envs
上级 cf4d540f
......@@ -150,7 +150,10 @@ def train(exe, train_program, train_out, test_program, test_out, args):
def build_program(program, startup, args, is_train=True):
if args.use_gpu:
num_trainers = len(os.getenv('CUDA_VISIBLE_DEVICES').split(','))
else:
num_trainers = 1
places = fluid.cuda_places() if args.use_gpu else fluid.CPUPlace()
train_dataset = CASIA_Face(root=args.train_data_dir)
......@@ -166,7 +169,7 @@ def build_program(program, startup, args, is_train=True):
image = fluid.data(
name='image', shape=[-1, 3, 112, 96], dtype='float32')
label = fluid.data(name='label', shape=[-1, 1], dtype='int64')
train_reader = paddle.fluid.io.batch(
train_reader = fluid.io.batch(
train_dataset.reader,
batch_size=args.train_batchsize // num_trainers,
drop_last=False)
......@@ -187,7 +190,7 @@ def build_program(program, startup, args, is_train=True):
else:
nl, nr, flods, flags = parse_filelist(args.test_data_dir)
test_dataset = LFW(nl, nr)
test_reader = paddle.fluid.io.batch(
test_reader = fluid.io.batch(
test_dataset.reader,
batch_size=args.test_batchsize,
drop_last=False)
......@@ -231,7 +234,7 @@ def build_program(program, startup, args, is_train=True):
def quant_val_reader_batch():
nl, nr, flods, flags = parse_filelist(args.test_data_dir)
test_dataset = LFW(nl, nr)
test_reader = paddle.fluid.io.batch(
test_reader = fluid.io.batch(
test_dataset.reader, batch_size=1, drop_last=False)
shuffle_reader = fluid.io.shuffle(test_reader, 3)
......@@ -298,14 +301,16 @@ def main():
help='The path of the extract features save, must be .mat file')
args = parser.parse_args()
if args.use_gpu:
num_trainers = len(os.getenv('CUDA_VISIBLE_DEVICES').split(','))
else:
num_trainers = 1
print(args)
print('num_trainers: {}'.format(num_trainers))
if args.save_ckpt == None:
args.save_ckpt = 'output'
if not os.path.exists(args.save_ckpt):
subprocess.call(['mkdir', '-p', args.save_ckpt])
if not os.path.isdir(args.save_ckpt):
os.makedirs(args.save_ckpt)
with open(os.path.join(args.save_ckpt, 'log.txt'), 'w+') as f:
f.writelines(str(args) + '\n')
f.writelines('num_trainers: {}'.format(num_trainers) + '\n')
......@@ -346,7 +351,7 @@ def main():
executor=exe)
nl, nr, flods, flags = parse_filelist(args.test_data_dir)
test_dataset = LFW(nl, nr)
test_reader = paddle.fluid.io.batch(
test_reader = fluid.io.batch(
test_dataset.reader,
batch_size=args.test_batchsize,
drop_last=False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册