From f5e277b1a950aaec105c628b73aadb1388f6f549 Mon Sep 17 00:00:00 2001 From: guochaorong Date: Wed, 8 Aug 2018 15:06:34 +0800 Subject: [PATCH] fix get cuda for language_model --- fluid/language_model/train.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/fluid/language_model/train.py b/fluid/language_model/train.py index ac8dcd29..f3e7a739 100644 --- a/fluid/language_model/train.py +++ b/fluid/language_model/train.py @@ -164,10 +164,13 @@ def train(train_reader, print("finish training") -def get_cards(): - cards = os.environ.get('CUDA_VISIBLE_DEVICES') - num = len(cards.split(",")) - return num +def get_cards(enable_ce): + if enable_ce: + cards = os.environ.get('CUDA_VISIBLE_DEVICES') + num = len(cards.split(",")) + return num + else: + return fluid.core.get_cuda_device_count() def train_net(): @@ -175,7 +178,7 @@ def train_net(): batch_size = 20 args = parse_args() vocab, train_reader, test_reader = utils.prepare_data( - batch_size=batch_size * get_cards(), buffer_size=1000, \ + batch_size=batch_size * get_cards(args.enable_ce), buffer_size=1000, \ word_freq_threshold=0, enable_ce = args.enable_ce) train( train_reader=train_reader, -- GitLab