未验证 提交 9a722e53 编写于 作者: T Teng Xi 提交者: GitHub

fix windows CPU envs (#315)

上级 3dfa7179
......@@ -151,9 +151,9 @@ def train(exe, train_program, train_out, test_program, test_out, args):
def build_program(program, startup, args, is_train=True):
if args.use_gpu:
num_trainers = len(os.getenv('CUDA_VISIBLE_DEVICES').split(','))
num_trainers = fluid.core.get_cuda_device_count()
else:
num_trainers = 1
num_trainers = int(os.environ.get('CPU_NUM', 1))
places = fluid.cuda_places() if args.use_gpu else fluid.CPUPlace()
train_dataset = CASIA_Face(root=args.train_data_dir)
......@@ -302,9 +302,9 @@ def main():
args = parser.parse_args()
if args.use_gpu:
num_trainers = len(os.getenv('CUDA_VISIBLE_DEVICES').split(','))
num_trainers = fluid.core.get_cuda_device_count()
else:
num_trainers = 1
num_trainers = int(os.environ.get('CPU_NUM', 1))
print(args)
print('num_trainers: {}'.format(num_trainers))
if args.save_ckpt == None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册