未验证 提交 7e876736 编写于 作者: L LiuChiachi 提交者: GitHub

Fix TranslationDataset bug (#5043)

* fix translation.py bugs

* delete useless comments

* set couplet download root to None
上级 0ba45dfe
...@@ -97,15 +97,14 @@ class CoupletDataset(TranslationDataset): ...@@ -97,15 +97,14 @@ class CoupletDataset(TranslationDataset):
EOS_TOKEN = '</s>' EOS_TOKEN = '</s>'
MD5 = '5c0dcde8eec6a517492227041c2e2d54' MD5 = '5c0dcde8eec6a517492227041c2e2d54'
def __init__(self, mode='train', root='./'): def __init__(self, mode='train', root=None):
data_select = ('train', 'dev', 'test') data_select = ('train', 'dev', 'test')
if mode not in data_select: if mode not in data_select:
raise TypeError( raise TypeError(
'`train`, `dev` or `test` is supported but `{}` is passed in'. '`train`, `dev` or `test` is supported but `{}` is passed in'.
format(mode)) format(mode))
# Download data # Download data
root = self.get_data(root=root) self.data = self.get_data(root=root)
self.data = self.read_raw_data(root, mode)
self.vocab, _ = self.get_vocab(root) self.vocab, _ = self.get_vocab(root)
self.transform() self.transform()
......
...@@ -66,7 +66,7 @@ class TranslationDataset(paddle.io.Dataset): ...@@ -66,7 +66,7 @@ class TranslationDataset(paddle.io.Dataset):
@classmethod @classmethod
def get_data(cls, mode="train", root=None): def get_data(cls, mode="train", root=None):
""" """
Download dataset if any data file doesn't exist. Download dataset and read raw data.
Args: Args:
mode(str, optional): Data mode to download. It could be 'train', mode(str, optional): Data mode to download. It could be 'train',
'dev' or 'test'. Default: 'train'. 'dev' or 'test'. Default: 'train'.
...@@ -74,13 +74,20 @@ class TranslationDataset(paddle.io.Dataset): ...@@ -74,13 +74,20 @@ class TranslationDataset(paddle.io.Dataset):
provided, dataset will be saved in provided, dataset will be saved in
`/root/.paddlenlp/datasets/machine_translation`. Default: None. `/root/.paddlenlp/datasets/machine_translation`. Default: None.
Returns: Returns:
str: All file paths of dataset. list: Raw data, a list of tuple.
Examples: Examples:
.. code-block:: python .. code-block:: python
from paddlenlp.datasets import IWSLT15 from paddlenlp.datasets import IWSLT15
data_path = IWSLT15.get_data() data_path = IWSLT15.get_data()
""" """
root = cls._download_data(mode, root)
data = cls.read_raw_data(mode, root)
return data
@classmethod
def _download_data(cls, mode="train", root=None):
"""Download dataset"""
default_root = os.path.join(DATA_HOME, 'machine_translation', default_root = os.path.join(DATA_HOME, 'machine_translation',
cls.__name__) cls.__name__)
src_filename, tgt_filename, src_data_hash, tgt_data_hash = cls.SPLITS[ src_filename, tgt_filename, src_data_hash, tgt_data_hash = cls.SPLITS[
...@@ -95,7 +102,6 @@ class TranslationDataset(paddle.io.Dataset): ...@@ -95,7 +102,6 @@ class TranslationDataset(paddle.io.Dataset):
filename) if root is None else os.path.join( filename) if root is None else os.path.join(
os.path.expanduser(root), filename) os.path.expanduser(root), filename)
fullname_list.append(fullname) fullname_list.append(fullname)
# print(fullname)
data_hash_list = [ data_hash_list = [
src_data_hash, tgt_data_hash, cls.VOCAB_INFO[2], cls.VOCAB_INFO[3] src_data_hash, tgt_data_hash, cls.VOCAB_INFO[2], cls.VOCAB_INFO[3]
...@@ -108,9 +114,8 @@ class TranslationDataset(paddle.io.Dataset): ...@@ -108,9 +114,8 @@ class TranslationDataset(paddle.io.Dataset):
warnings.warn( warnings.warn(
'md5 check failed for {}, download {} data to {}'. 'md5 check failed for {}, download {} data to {}'.
format(filename, cls.__name__, default_root)) format(filename, cls.__name__, default_root))
path = get_path_from_url(cls.URL, root, cls.MD5) path = get_path_from_url(cls.URL, default_root, cls.MD5)
return default_root
break
return root if root is not None else default_root return root if root is not None else default_root
@classmethod @classmethod
...@@ -130,9 +135,9 @@ class TranslationDataset(paddle.io.Dataset): ...@@ -130,9 +135,9 @@ class TranslationDataset(paddle.io.Dataset):
.. code-block:: python .. code-block:: python
from paddlenlp.datasets import IWSLT15 from paddlenlp.datasets import IWSLT15
(src_vocab, tgt_vocab) = IWSLT15.get_vocab() (src_vocab, tgt_vocab) = IWSLT15.get_vocab()
""" """
root = cls.get_data(root=root)
root = cls._download_data(root=root)
src_vocab_filename, tgt_vocab_filename, _, _ = cls.VOCAB_INFO src_vocab_filename, tgt_vocab_filename, _, _ = cls.VOCAB_INFO
src_file_path = os.path.join(root, src_vocab_filename) src_file_path = os.path.join(root, src_vocab_filename)
tgt_file_path = os.path.join(root, tgt_vocab_filename) tgt_file_path = os.path.join(root, tgt_vocab_filename)
...@@ -153,17 +158,15 @@ class TranslationDataset(paddle.io.Dataset): ...@@ -153,17 +158,15 @@ class TranslationDataset(paddle.io.Dataset):
return (src_vocab, tgt_vocab) return (src_vocab, tgt_vocab)
@classmethod @classmethod
def read_raw_data(cls, root, mode): def read_raw_data(cls, mode, root):
"""Read raw data from data files """Read raw data from data files
Args: Args:
root(str): Data directory of dataset. mode(str): Indicates the mode to read. It could be 'train', 'dev' or
mode(str): Indicates the mode to read. It could be 'train', 'dev' or
'test'. 'test'.
root(str): Data directory of dataset.
Returns: Returns:
list: Raw data list. list: Raw data list.
""" """
# print(root)
src_filename, tgt_filename, _, _ = cls.SPLITS[mode] src_filename, tgt_filename, _, _ = cls.SPLITS[mode]
def read_raw_files(corpus_path): def read_raw_files(corpus_path):
...@@ -177,7 +180,6 @@ class TranslationDataset(paddle.io.Dataset): ...@@ -177,7 +180,6 @@ class TranslationDataset(paddle.io.Dataset):
src_path = os.path.join(root, src_filename) src_path = os.path.join(root, src_filename)
tgt_path = os.path.join(root, tgt_filename) tgt_path = os.path.join(root, tgt_filename)
print(src_path, tgt_path)
src_data = read_raw_files(src_path) src_data = read_raw_files(src_path)
tgt_data = read_raw_files(tgt_path) tgt_data = read_raw_files(tgt_path)
...@@ -263,9 +265,8 @@ class IWSLT15(TranslationDataset): ...@@ -263,9 +265,8 @@ class IWSLT15(TranslationDataset):
if len(transform_func) != 2: if len(transform_func) != 2:
raise ValueError("`transform_func` must have length of two for" raise ValueError("`transform_func` must have length of two for"
"source and target.") "source and target.")
# Download data # Download data and read data
root = self.get_data(root=root) self.data = self.get_data(root=root)
self.data = self.read_raw_data(root, mode)
if transform_func is not None: if transform_func is not None:
self.data = [(transform_func[0](data[0]), self.data = [(transform_func[0](data[0]),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册