未验证 提交 e8160b1c 编写于 作者: H Hongyu Liu 提交者: GitHub

Merge pull request #1424 from phlrain/fix_lstm_python35

fix python35 bug
...@@ -35,9 +35,10 @@ def parse_log(log): ...@@ -35,9 +35,10 @@ def parse_log(log):
for line in log.split('\n'): for line in log.split('\n'):
fs = line.strip().split('\t') fs = line.strip().split('\t')
print(fs) print(fs)
kpi_name = fs[0] if len(fs) == 3 and fs[0] == 'ptblm':
kpi_value = float(fs[1]) kpi_name = fs[1]
yield kpi_name, kpi_value kpi_value = float(fs[2])
yield kpi_name, kpi_value
def log_to_ce(log): def log_to_ce(log):
......
...@@ -28,7 +28,10 @@ Py3 = sys.version_info[0] == 3 ...@@ -28,7 +28,10 @@ Py3 = sys.version_info[0] == 3
def _read_words(filename): def _read_words(filename):
data = [] data = []
with open(filename, "r") as f: with open(filename, "r") as f:
return f.read().decode("utf-8").replace("\n", "<eos>").split() if Py3:
return f.read().replace("\n", "<eos>").split()
else:
return f.read().decode("utf-8").replace("\n", "<eos>").split()
def _build_vocab(filename): def _build_vocab(filename):
......
...@@ -258,7 +258,8 @@ def train(): ...@@ -258,7 +258,8 @@ def train():
fetch_list=[ fetch_list=[
loss.name, last_hidden.name, loss.name, last_hidden.name,
last_cell.name, 'learning_rate' last_cell.name, 'learning_rate'
]) ],
use_program_cache=True)
cost_train = np.array(fetch_outs[0]) cost_train = np.array(fetch_outs[0])
init_hidden = np.array(fetch_outs[1]) init_hidden = np.array(fetch_outs[1])
...@@ -282,8 +283,9 @@ def train(): ...@@ -282,8 +283,9 @@ def train():
print("train ppl", ppl[0]) print("train ppl", ppl[0])
if epoch_id == max_epoch - 1 and args.enable_ce: if epoch_id == max_epoch - 1 and args.enable_ce:
print("lstm_language_model_duration\t%s" % (total_time / max_epoch)) print("ptblm\tlstm_language_model_duration\t%s" %
print("lstm_language_model_loss\t%s" % ppl[0]) (total_time / max_epoch))
print("ptblm\tlstm_language_model_loss\t%s" % ppl[0])
model_path = os.path.join("model_new/", str(epoch_id)) model_path = os.path.join("model_new/", str(epoch_id))
if not os.path.isdir(model_path): if not os.path.isdir(model_path):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册