提交 70e77f98 作者: Xin Pan

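fix lm: add a --num_devices argument (default 1); get_cards() now takes the parsed args and, when enable_ce is not set, returns args.num_devices instead of fluid.core.get_cuda_device_count()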

上级 3eacba37
...@@ -21,6 +21,11 @@ def parse_args(): ...@@ -21,6 +21,11 @@ def parse_args():
action='store_true', action='store_true',
help='If set, run \ help='If set, run \
the task with continuous evaluation logs.') the task with continuous evaluation logs.')
parser.add_argument(
'--num_devices',
type=int,
default=1,
help='Number of GPU devices')
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -132,7 +137,7 @@ def train(train_reader, ...@@ -132,7 +137,7 @@ def train(train_reader,
"src_wordseq": lod_src_wordseq, "src_wordseq": lod_src_wordseq,
"dst_wordseq": lod_dst_wordseq "dst_wordseq": lod_dst_wordseq
}, },
fetch_list=fetch_list) fetch_list=fetch_list)
avg_ppl = np.exp(ret_avg_cost[0]) avg_ppl = np.exp(ret_avg_cost[0])
newest_ppl = np.mean(avg_ppl) newest_ppl = np.mean(avg_ppl)
if i % 100 == 0: if i % 100 == 0:
...@@ -153,7 +158,7 @@ def train(train_reader, ...@@ -153,7 +158,7 @@ def train(train_reader,
print("kpis imikolov_20_avg_ppl %s" % newest_ppl) print("kpis imikolov_20_avg_ppl %s" % newest_ppl)
else: else:
print("kpis imikolov_20_pass_duration_card%s %s" % \ print("kpis imikolov_20_pass_duration_card%s %s" % \
(gpu_num, total_time / epoch_idx)) (gpu_num, total_time / epoch_idx))
print("kpis imikolov_20_avg_ppl_card%s %s" % print("kpis imikolov_20_avg_ppl_card%s %s" %
(gpu_num, newest_ppl)) (gpu_num, newest_ppl))
save_dir = "%s/epoch_%d" % (model_dir, epoch_idx) save_dir = "%s/epoch_%d" % (model_dir, epoch_idx)
...@@ -165,13 +170,13 @@ def train(train_reader, ...@@ -165,13 +170,13 @@ def train(train_reader,
print("finish training") print("finish training")
def get_cards(enable_ce): def get_cards(args):
if enable_ce: if args.enable_ce:
cards = os.environ.get('CUDA_VISIBLE_DEVICES') cards = os.environ.get('CUDA_VISIBLE_DEVICES')
num = len(cards.split(",")) num = len(cards.split(","))
return num return num
else: else:
return fluid.core.get_cuda_device_count() return args.num_devices
def train_net(): def train_net():
...@@ -179,7 +184,7 @@ def train_net(): ...@@ -179,7 +184,7 @@ def train_net():
batch_size = 20 batch_size = 20
args = parse_args() args = parse_args()
vocab, train_reader, test_reader = utils.prepare_data( vocab, train_reader, test_reader = utils.prepare_data(
batch_size=batch_size * get_cards(args.enable_ce), buffer_size=1000, \ batch_size=batch_size * get_cards(args), buffer_size=1000, \
word_freq_threshold=0, enable_ce = args.enable_ce) word_freq_threshold=0, enable_ce = args.enable_ce)
train( train(
train_reader=train_reader, train_reader=train_reader,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册