diff --git a/demo/seqToseq/dataprovider.py b/demo/seqToseq/dataprovider.py index 0614d54aaf8bb64f8a4d4a8af364fb993fd8cc7e..5174092df26089bc5661a7d98da62dc7a124c54d 100755 --- a/demo/seqToseq/dataprovider.py +++ b/demo/seqToseq/dataprovider.py @@ -25,11 +25,17 @@ def hook(settings, src_dict_path, trg_dict_path, is_generating, file_list, # job_mode = 0: generating mode settings.job_mode = not is_generating settings.src_dict = dict() - for line_count, line in enumerate(open(src_dict_path, "r")): - settings.src_dict[line.strip()] = line_count + with open(src_dict_path, "r") as fin: + settings.src_dict = { + line.strip(): line_count + for line_count, line in enumerate(fin) + } settings.trg_dict = dict() - for line_count, line in enumerate(open(trg_dict_path, "r")): - settings.trg_dict[line.strip()] = line_count + with open(trg_dict_path, "r") as fin: + settings.trg_dict = { + line.strip(): line_count + for line_count, line in enumerate(fin) + } settings.logger.info("src dict len : %d" % (len(settings.src_dict))) settings.sample_count = 0