diff --git a/fluid/language_model/train.py b/fluid/language_model/train.py index ac8dcd2949cd4054c091c3f64ce04164e2dcbe9b..f3e7a7398bf13e14c74ce1d10d90b7bf34031698 100644 --- a/fluid/language_model/train.py +++ b/fluid/language_model/train.py @@ -164,10 +164,13 @@ def train(train_reader, print("finish training") -def get_cards(): - cards = os.environ.get('CUDA_VISIBLE_DEVICES') - num = len(cards.split(",")) - return num +def get_cards(enable_ce): + if enable_ce: + cards = os.environ.get('CUDA_VISIBLE_DEVICES') + num = len(cards.split(",")) + return num + else: + return fluid.core.get_cuda_device_count() def train_net(): @@ -175,7 +178,7 @@ def train_net(): batch_size = 20 args = parse_args() vocab, train_reader, test_reader = utils.prepare_data( - batch_size=batch_size * get_cards(), buffer_size=1000, \ + batch_size=batch_size * get_cards(args.enable_ce), buffer_size=1000, \ word_freq_threshold=0, enable_ce = args.enable_ce) train( train_reader=train_reader,