未验证 提交 7e876736 编写于 作者: L LiuChiachi 提交者: GitHub

Fix TranslationDataset bug (#5043)

* fix translation.py bugs

* delete leftover debug prints and commented-out code

* set the default download root of CoupletDataset to None
上级 0ba45dfe
......@@ -97,15 +97,14 @@ class CoupletDataset(TranslationDataset):
EOS_TOKEN = '</s>'
MD5 = '5c0dcde8eec6a517492227041c2e2d54'
def __init__(self, mode='train', root='./'):
def __init__(self, mode='train', root=None):
data_select = ('train', 'dev', 'test')
if mode not in data_select:
raise TypeError(
'`train`, `dev` or `test` is supported but `{}` is passed in'.
format(mode))
# Download data
root = self.get_data(root=root)
self.data = self.read_raw_data(root, mode)
self.data = self.get_data(root=root)
self.vocab, _ = self.get_vocab(root)
self.transform()
......
......@@ -66,7 +66,7 @@ class TranslationDataset(paddle.io.Dataset):
@classmethod
def get_data(cls, mode="train", root=None):
"""
Download dataset if any data file doesn't exist.
Download dataset and read raw data.
Args:
mode(str, optional): Data mode to download. It could be 'train',
'dev' or 'test'. Default: 'train'.
......@@ -74,13 +74,20 @@ class TranslationDataset(paddle.io.Dataset):
provided, dataset will be saved in
`/root/.paddlenlp/datasets/machine_translation`. Default: None.
Returns:
str: All file paths of dataset.
list: Raw data, a list of tuple.
Examples:
.. code-block:: python
from paddlenlp.datasets import IWSLT15
data = IWSLT15.get_data()
"""
root = cls._download_data(mode, root)
data = cls.read_raw_data(mode, root)
return data
@classmethod
def _download_data(cls, mode="train", root=None):
"""Download dataset"""
default_root = os.path.join(DATA_HOME, 'machine_translation',
cls.__name__)
src_filename, tgt_filename, src_data_hash, tgt_data_hash = cls.SPLITS[
......@@ -95,7 +102,6 @@ class TranslationDataset(paddle.io.Dataset):
filename) if root is None else os.path.join(
os.path.expanduser(root), filename)
fullname_list.append(fullname)
# print(fullname)
data_hash_list = [
src_data_hash, tgt_data_hash, cls.VOCAB_INFO[2], cls.VOCAB_INFO[3]
......@@ -108,9 +114,8 @@ class TranslationDataset(paddle.io.Dataset):
warnings.warn(
'md5 check failed for {}, download {} data to {}'.
format(filename, cls.__name__, default_root))
path = get_path_from_url(cls.URL, root, cls.MD5)
break
path = get_path_from_url(cls.URL, default_root, cls.MD5)
return default_root
return root if root is not None else default_root
@classmethod
......@@ -130,9 +135,9 @@ class TranslationDataset(paddle.io.Dataset):
.. code-block:: python
from paddlenlp.datasets import IWSLT15
(src_vocab, tgt_vocab) = IWSLT15.get_vocab()
"""
root = cls.get_data(root=root)
root = cls._download_data(root=root)
src_vocab_filename, tgt_vocab_filename, _, _ = cls.VOCAB_INFO
src_file_path = os.path.join(root, src_vocab_filename)
tgt_file_path = os.path.join(root, tgt_vocab_filename)
......@@ -153,17 +158,15 @@ class TranslationDataset(paddle.io.Dataset):
return (src_vocab, tgt_vocab)
@classmethod
def read_raw_data(cls, root, mode):
def read_raw_data(cls, mode, root):
"""Read raw data from data files
Args:
root(str): Data directory of dataset.
mode(str): Indicates the mode to read. It could be 'train', 'dev' or
mode(str): Indicates the mode to read. It could be 'train', 'dev' or
'test'.
root(str): Data directory of dataset.
Returns:
list: Raw data list.
"""
# print(root)
src_filename, tgt_filename, _, _ = cls.SPLITS[mode]
def read_raw_files(corpus_path):
......@@ -177,7 +180,6 @@ class TranslationDataset(paddle.io.Dataset):
src_path = os.path.join(root, src_filename)
tgt_path = os.path.join(root, tgt_filename)
print(src_path, tgt_path)
src_data = read_raw_files(src_path)
tgt_data = read_raw_files(tgt_path)
......@@ -263,9 +265,8 @@ class IWSLT15(TranslationDataset):
if len(transform_func) != 2:
raise ValueError("`transform_func` must have length of two for"
"source and target.")
# Download data
root = self.get_data(root=root)
self.data = self.read_raw_data(root, mode)
# Download data and read data
self.data = self.get_data(root=root)
if transform_func is not None:
self.data = [(transform_func[0](data[0]),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册