From c44ffb08b33e5dfc3348c68319ef94548001ad46 Mon Sep 17 00:00:00 2001 From: phlrain Date: Fri, 2 Nov 2018 13:36:42 +0800 Subject: [PATCH] fix python35 bug --- fluid/PaddleNLP/language_model/lstm/_ce.py | 7 ++++--- fluid/PaddleNLP/language_model/lstm/reader.py | 5 ++++- fluid/PaddleNLP/language_model/lstm/train.py | 8 +++++--- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/fluid/PaddleNLP/language_model/lstm/_ce.py b/fluid/PaddleNLP/language_model/lstm/_ce.py index f537f6aa..338d09b6 100644 --- a/fluid/PaddleNLP/language_model/lstm/_ce.py +++ b/fluid/PaddleNLP/language_model/lstm/_ce.py @@ -35,9 +35,10 @@ def parse_log(log): for line in log.split('\n'): fs = line.strip().split('\t') print(fs) - kpi_name = fs[0] - kpi_value = float(fs[1]) - yield kpi_name, kpi_value + if len(fs) == 3 and fs[0] == 'ptblm': + kpi_name = fs[1] + kpi_value = float(fs[2]) + yield kpi_name, kpi_value def log_to_ce(log): diff --git a/fluid/PaddleNLP/language_model/lstm/reader.py b/fluid/PaddleNLP/language_model/lstm/reader.py index 50e8835e..8c0551d8 100644 --- a/fluid/PaddleNLP/language_model/lstm/reader.py +++ b/fluid/PaddleNLP/language_model/lstm/reader.py @@ -28,7 +28,10 @@ Py3 = sys.version_info[0] == 3 def _read_words(filename): data = [] with open(filename, "r") as f: - return f.read().decode("utf-8").replace("\n", "").split() + if Py3: + return f.read().replace("\n", "").split() + else: + return f.read().decode("utf-8").replace("\n", "").split() def _build_vocab(filename): diff --git a/fluid/PaddleNLP/language_model/lstm/train.py b/fluid/PaddleNLP/language_model/lstm/train.py index fc058c6a..42bab12b 100644 --- a/fluid/PaddleNLP/language_model/lstm/train.py +++ b/fluid/PaddleNLP/language_model/lstm/train.py @@ -258,7 +258,8 @@ def train(): fetch_list=[ loss.name, last_hidden.name, last_cell.name, 'learning_rate' - ]) + ], + use_program_cache=True) cost_train = np.array(fetch_outs[0]) init_hidden = np.array(fetch_outs[1]) @@ -282,8 +283,9 @@ def train(): print("train ppl", ppl[0]) if epoch_id == max_epoch - 1 and args.enable_ce: - print("lstm_language_model_duration\t%s" % (total_time / max_epoch)) - print("lstm_language_model_loss\t%s" % ppl[0]) + print("ptblm\tlstm_language_model_duration\t%s" % + (total_time / max_epoch)) + print("ptblm\tlstm_language_model_loss\t%s" % ppl[0]) model_path = os.path.join("model_new/", str(epoch_id)) if not os.path.isdir(model_path): -- GitLab