提交 70e77f98 编写于 作者: X Xin Pan

fix lm

上级 3eacba37
...@@ -21,6 +21,11 @@ def parse_args(): ...@@ -21,6 +21,11 @@ def parse_args():
action='store_true', action='store_true',
help='If set, run \ help='If set, run \
the task with continuous evaluation logs.') the task with continuous evaluation logs.')
parser.add_argument(
'--num_devices',
type=int,
default=1,
help='Number of GPU devices')
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -165,13 +170,13 @@ def train(train_reader, ...@@ -165,13 +170,13 @@ def train(train_reader,
print("finish training") print("finish training")
def get_cards(enable_ce): def get_cards(args):
if enable_ce: if args.enable_ce:
cards = os.environ.get('CUDA_VISIBLE_DEVICES') cards = os.environ.get('CUDA_VISIBLE_DEVICES')
num = len(cards.split(",")) num = len(cards.split(","))
return num return num
else: else:
return fluid.core.get_cuda_device_count() return args.num_devices
def train_net(): def train_net():
...@@ -179,7 +184,7 @@ def train_net(): ...@@ -179,7 +184,7 @@ def train_net():
batch_size = 20 batch_size = 20
args = parse_args() args = parse_args()
vocab, train_reader, test_reader = utils.prepare_data( vocab, train_reader, test_reader = utils.prepare_data(
batch_size=batch_size * get_cards(args.enable_ce), buffer_size=1000, \ batch_size=batch_size * get_cards(args), buffer_size=1000, \
word_freq_threshold=0, enable_ce = args.enable_ce) word_freq_threshold=0, enable_ce = args.enable_ce)
train( train(
train_reader=train_reader, train_reader=train_reader,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册