# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import io
import logging
import os
import sys
import time

import numpy as np
import paddle
import paddle.fluid as fluid
import six

import preprocess

# Used by check_version() to report an unsatisfied PaddlePaddle version.
logger = logging.getLogger(__name__)


def BuildWord_IdMap(dict_path):
    """Load the vocabulary file into word->id and id->word dictionaries.

    Each non-blank line of the UTF-8 file is split on single spaces; the
    first field is taken as the word and the second as its integer id.

    Args:
        dict_path: path of the dictionary file.

    Returns:
        A ``(word_to_id, id_to_word)`` tuple of dicts.
    """
    word_to_id = dict()
    id_to_word = dict()
    with io.open(dict_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing one) instead of failing
            # with IndexError on the missing id field.
            if not line.strip():
                continue
            # Split once; int() tolerates the trailing newline on the id.
            fields = line.split(' ')
            word = fields[0]
            word_id = int(fields[1])
            word_to_id[word] = word_id
            id_to_word[word_id] = word
    return word_to_id, id_to_word


def prepare_data(file_dir, dict_path, batch_size):
    """Build the vocabulary and a batched reader over the files in file_dir.

    Args:
        file_dir: directory whose files are read by ``test``.
        dict_path: path of the dictionary file passed to BuildWord_IdMap.
        batch_size: batch size handed to ``fluid.io.batch``.

    Returns:
        A ``(vocab_size, reader, id_to_word)`` tuple, where ``vocab_size``
        is the number of entries in the id->word map.
    """
    w2i, i2w = BuildWord_IdMap(dict_path)
    vocab_size = len(i2w)
    reader = fluid.io.batch(test(file_dir, w2i), batch_size)
    return vocab_size, reader, i2w


def check_version(with_shuffle_batch=False):
    """
    Log error and exit when the installed version of paddlepaddle is
    not satisfied.

    Args:
        with_shuffle_batch: when true, version 1.7.0 is required instead
            of 1.6.0.
    """
    err = ("PaddlePaddle version 1.6 or higher is required, "
           "or a suitable develop version is satisfied as well. \n"
           "Please make sure the version is good with your code.")

    try:
        if with_shuffle_batch:
            fluid.require_version('1.7.0')
        else:
            fluid.require_version('1.6.0')
    except Exception:
        logger.error(err)
        sys.exit(1)


def native_to_unicode(s):
    """Return s as a unicode string, dropping undecodable bytes if needed.

    A strict UTF-8 decode is attempted first; only if that raises
    UnicodeDecodeError is the input decoded again with errors ignored.
    """
    if _is_unicode(s):
        return s
    try:
        return _to_unicode(s)
    except UnicodeDecodeError:
        return _to_unicode(s, ignore_errors=True)


def _is_unicode(s):
    """Return True if s is a text string (unicode on PY2, str on PY3)."""
    return isinstance(s, six.text_type)


def _to_unicode(s, ignore_errors=False):
    """Decode s as UTF-8 unless it is already a unicode string.

    Args:
        s: a unicode string, or an object with a ``decode`` method.
        ignore_errors: decode with ``errors="ignore"`` instead of
            ``errors="strict"``.
    """
    if _is_unicode(s):
        return s
    error_mode = "ignore" if ignore_errors else "strict"
    return s.decode("utf-8", errors=error_mode)


def strip_lines(line, vocab):
    """Convert line to unicode and blank out its out-of-vocabulary words."""
    return _replace_oov(vocab, native_to_unicode(line))


def _replace_oov(original_vocab, line):
    """Replace each out-of-vocab word with the empty string.

    This maintains compatibility with published results.

    Args:
        original_vocab: a container of strings supporting ``in``
            (the standard vocabulary for the dataset)
        line: a unicode string - a space-delimited sequence of words.

    Returns:
        a unicode string - a space-delimited sequence of words.
    """
    return u" ".join([
        word if word in original_vocab else u"" for word in line.split()
    ])


def reader_creator(file_dir, word_to_id):
    """Create a reader over every file in file_dir.

    Lines containing ':' are skipped. Every other line is lower-cased,
    stripped of out-of-vocabulary words and split on whitespace; its first
    four words are mapped to ids.

    Args:
        file_dir: directory whose files are read as UTF-8 text.
        word_to_id: dict mapping a word to its integer id.

    Returns:
        A no-argument generator function yielding, per line, the tuple
        ``([id0], [id1], [id2], [id3], [id0, id1, id2])``.
    """

    def reader():
        for fi in os.listdir(file_dir):
            with io.open(
                    os.path.join(file_dir, fi), "r", encoding='utf-8') as f:
                for line in f:
                    if ':' in line:
                        continue
                    words = strip_lines(line.lower(), word_to_id).split()
                    # NOTE(review): out-of-vocab words were replaced by ""
                    # and are dropped by split(), so a line left with fewer
                    # than four words raises IndexError here - confirm
                    # whether such lines should be skipped instead.
                    id0 = word_to_id[words[0]]
                    id1 = word_to_id[words[1]]
                    id2 = word_to_id[words[2]]
                    id3 = word_to_id[words[3]]
                    yield [id0], [id1], [id2], [id3], [id0, id1, id2]

    return reader


def test(test_dir, w2i):
    """Return the reader created by reader_creator for test_dir."""
    return reader_creator(test_dir, w2i)