diff --git a/fluid/PaddleNLP/language_model/lstm/_ce.py b/fluid/PaddleNLP/language_model/lstm/_ce.py
index f537f6aa62dd502c79174fadedda0da621c8eb7b..338d09b63468a982e47f740030bd5ca967d4d00e 100644
--- a/fluid/PaddleNLP/language_model/lstm/_ce.py
+++ b/fluid/PaddleNLP/language_model/lstm/_ce.py
@@ -35,9 +35,10 @@ def parse_log(log):
     for line in log.split('\n'):
         fs = line.strip().split('\t')
         print(fs)
-        kpi_name = fs[0]
-        kpi_value = float(fs[1])
-        yield kpi_name, kpi_value
+        if len(fs) == 3 and fs[0] == 'ptblm':
+            kpi_name = fs[1]
+            kpi_value = float(fs[2])
+            yield kpi_name, kpi_value
 
 
 def log_to_ce(log):
diff --git a/fluid/PaddleNLP/language_model/lstm/reader.py b/fluid/PaddleNLP/language_model/lstm/reader.py
index 50e8835ec8b96bf37a7b972700a588034d41425c..8c0551d84f253d9bfd8420e56185780bee7c3145 100644
--- a/fluid/PaddleNLP/language_model/lstm/reader.py
+++ b/fluid/PaddleNLP/language_model/lstm/reader.py
@@ -28,7 +28,10 @@ Py3 = sys.version_info[0] == 3
 def _read_words(filename):
     data = []
     with open(filename, "r") as f:
-        return f.read().decode("utf-8").replace("\n", "").split()
+        if Py3:
+            return f.read().replace("\n", "").split()
+        else:
+            return f.read().decode("utf-8").replace("\n", "").split()
 
 
 def _build_vocab(filename):
diff --git a/fluid/PaddleNLP/language_model/lstm/train.py b/fluid/PaddleNLP/language_model/lstm/train.py
index fc058c6a0e80f4aeba76656fe505207846d66e2f..42bab12bcbdafdabc5ab14370860b4b0ae269dda 100644
--- a/fluid/PaddleNLP/language_model/lstm/train.py
+++ b/fluid/PaddleNLP/language_model/lstm/train.py
@@ -258,7 +258,8 @@ def train():
                 fetch_list=[
                     loss.name, last_hidden.name, last_cell.name,
                     'learning_rate'
-                ])
+                ],
+                use_program_cache=True)
 
             cost_train = np.array(fetch_outs[0])
             init_hidden = np.array(fetch_outs[1])
@@ -282,8 +283,9 @@ def train():
             print("train ppl", ppl[0])
 
         if epoch_id == max_epoch - 1 and args.enable_ce:
-            print("lstm_language_model_duration\t%s" % (total_time / max_epoch))
-            print("lstm_language_model_loss\t%s" % ppl[0])
+            print("ptblm\tlstm_language_model_duration\t%s" %
+                  (total_time / max_epoch))
+            print("ptblm\tlstm_language_model_loss\t%s" % ppl[0])
 
             model_path = os.path.join("model_new/", str(epoch_id))
             if not os.path.isdir(model_path):
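
Note on how the two sides of this change fit together: train.py now tags its CE output with a "ptblm" prefix, and parse_log in _ce.py only accepts tab-separated lines with exactly three fields whose first field is that tag, so other training output (e.g. "train ppl ...") no longer reaches the KPI tracker. Below is a minimal sketch of the expected behaviour; it assumes _ce.py is importable from the working directory, and the sample log values are invented purely for illustration.

```python
# Illustrative sketch only -- the sample_log contents are made up, not real CE output.
from _ce import parse_log  # assumes _ce.py is on sys.path

sample_log = (
    "train ppl 123.4\n"                            # untagged line: skipped by the new length/prefix check
    "ptblm\tlstm_language_model_duration\t12.5\n"  # three tab-separated fields, 'ptblm' tag: accepted
    "ptblm\tlstm_language_model_loss\t98.7"        # accepted
)

# parse_log also prints each split line; the yielded KPI pairs are:
#   ('lstm_language_model_duration', 12.5)
#   ('lstm_language_model_loss', 98.7)
for kpi_name, kpi_value in parse_log(sample_log):
    print(kpi_name, kpi_value)
```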