提交 c44ffb08 编写于 作者: P phlrain

fix python35 bug

上级 5805cf16
......@@ -35,9 +35,10 @@ def parse_log(log):
for line in log.split('\n'):
fs = line.strip().split('\t')
print(fs)
kpi_name = fs[0]
kpi_value = float(fs[1])
yield kpi_name, kpi_value
if len(fs) == 3 and fs[0] == 'ptblm':
kpi_name = fs[1]
kpi_value = float(fs[2])
yield kpi_name, kpi_value
def log_to_ce(log):
......
......@@ -28,7 +28,10 @@ Py3 = sys.version_info[0] == 3
def _read_words(filename):
data = []
with open(filename, "r") as f:
return f.read().decode("utf-8").replace("\n", "<eos>").split()
if Py3:
return f.read().replace("\n", "<eos>").split()
else:
return f.read().decode("utf-8").replace("\n", "<eos>").split()
def _build_vocab(filename):
......
......@@ -258,7 +258,8 @@ def train():
fetch_list=[
loss.name, last_hidden.name,
last_cell.name, 'learning_rate'
])
],
use_program_cache=True)
cost_train = np.array(fetch_outs[0])
init_hidden = np.array(fetch_outs[1])
......@@ -282,8 +283,9 @@ def train():
print("train ppl", ppl[0])
if epoch_id == max_epoch - 1 and args.enable_ce:
print("lstm_language_model_duration\t%s" % (total_time / max_epoch))
print("lstm_language_model_loss\t%s" % ppl[0])
print("ptblm\tlstm_language_model_duration\t%s" %
(total_time / max_epoch))
print("ptblm\tlstm_language_model_loss\t%s" % ppl[0])
model_path = os.path.join("model_new/", str(epoch_id))
if not os.path.isdir(model_path):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册