From b420c007e92ce38b1fd097595b41ad60129f9caf Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 25 Jul 2018 05:46:44 +0000 Subject: [PATCH] Add pickle --- .../transformer/reader.py | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/fluid/neural_machine_translation/transformer/reader.py b/fluid/neural_machine_translation/transformer/reader.py index caede1f7..397b8a3d 100644 --- a/fluid/neural_machine_translation/transformer/reader.py +++ b/fluid/neural_machine_translation/transformer/reader.py @@ -2,7 +2,7 @@ import glob import os import random import tarfile -import time +import cPickle class SortType(object): @@ -179,7 +179,8 @@ class DataReader(object): start_mark="", end_mark="", unk_mark="", - seed=0): + seed=0, + pkl_filename=None): self._src_vocab = self.load_dict(src_vocab_fpath) self._only_src = True if trg_vocab_fpath is not None: @@ -195,8 +196,23 @@ class DataReader(object): self._min_length = min_length self._max_length = max_length self._delimiter = delimiter - self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname, - unk_mark) + + if pkl_filename is None: + self.load_src_trg_ids(end_mark, fpattern, start_mark, tar_fname, + unk_mark) + else: + try: + with open(pkl_filename, 'r') as f: + self._src_seq_ids, self._trg_seq_ids, self._sample_infos = cPickle.load( + f) + except: + self.load_src_trg_ids(end_mark, fpattern, start_mark, tarfile, + unk_mark) + with open(pkl_filename, 'w') as f: + cPickle.dump((self._src_seq_ids, self._trg_seq_ids, + self._sample_infos), f, + cPickle.HIGHEST_PROTOCOL) + self._random = random.Random(x=seed) def load_src_trg_ids(self, end_mark, fpattern, start_mark, tar_fname, -- GitLab