提交 f5e277b1 编写于 作者: G guochaorong

fix get cuda for language_model

上级 a79a7bcf
......@@ -164,10 +164,13 @@ def train(train_reader,
print("finish training")
def get_cards():
cards = os.environ.get('CUDA_VISIBLE_DEVICES')
num = len(cards.split(","))
return num
def get_cards(enable_ce):
if enable_ce:
cards = os.environ.get('CUDA_VISIBLE_DEVICES')
num = len(cards.split(","))
return num
else:
return fluid.core.get_cuda_device_count()
def train_net():
......@@ -175,7 +178,7 @@ def train_net():
batch_size = 20
args = parse_args()
vocab, train_reader, test_reader = utils.prepare_data(
batch_size=batch_size * get_cards(), buffer_size=1000, \
batch_size=batch_size * get_cards(args.enable_ce), buffer_size=1000, \
word_freq_threshold=0, enable_ce = args.enable_ce)
train(
train_reader=train_reader,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册