Unverified commit 7862c3a8, authored by Jiaqi Liu, committed by GitHub

Add CoupletDataset (#5247)

* move CoupletDataset to paddlenlp.datasets
Parent a911e2d5
@@ -21,7 +21,7 @@ import numpy as np
import paddle
from paddlenlp.data import Vocab, Pad
from paddlenlp.data import SamplerHelper
-from paddlenlp.datasets import TranslationDataset
+from paddlenlp.datasets import CoupletDataset
def create_train_loader(batch_size=128):
@@ -65,50 +65,3 @@ def prepare_input(insts, pad_id):
        [inst[1] for inst in insts])
    tgt_mask = (tgt[:, :-1] != pad_id).astype(paddle.get_default_dtype())
    return src, src_length, tgt[:, :-1], tgt[:, 1:, np.newaxis], tgt_mask
-class CoupletDataset(TranslationDataset):
-    URL = "https://paddlenlp.bj.bcebos.com/datasets/couplet.tar.gz"
-    SPLITS = {
-        'train': TranslationDataset.META_INFO(
-            os.path.join("couplet", "train_src.tsv"),
-            os.path.join("couplet", "train_tgt.tsv"),
-            "ad137385ad5e264ac4a54fe8c95d1583",
-            "daf4dd79dbf26040696eee0d645ef5ad"),
-        'dev': TranslationDataset.META_INFO(
-            os.path.join("couplet", "dev_src.tsv"),
-            os.path.join("couplet", "dev_tgt.tsv"),
-            "65bf9e72fa8fdf0482751c1fd6b6833c",
-            "3bc3b300b19d170923edfa8491352951"),
-        'test': TranslationDataset.META_INFO(
-            os.path.join("couplet", "test_src.tsv"),
-            os.path.join("couplet", "test_tgt.tsv"),
-            "f0a7366dfa0acac884b9f4901aac2cc1",
-            "56664bff3f2edfd7a751a55a689f90c2")
-    }
-    VOCAB_INFO = (os.path.join("couplet", "vocab.txt"), os.path.join(
-        "couplet", "vocab.txt"), "0bea1445c7c7fb659b856bb07e54a604",
-        "0bea1445c7c7fb659b856bb07e54a604")
-    UNK_TOKEN = '<unk>'
-    BOS_TOKEN = '<s>'
-    EOS_TOKEN = '</s>'
-    MD5 = '5c0dcde8eec6a517492227041c2e2d54'
-
-    def __init__(self, mode='train', root=None):
-        data_select = ('train', 'dev', 'test')
-        if mode not in data_select:
-            raise TypeError(
-                '`train`, `dev` or `test` is supported but `{}` is passed in'.
-                format(mode))
-        # Download and read data
-        self.data = self.get_data(mode=mode, root=root)
-        self.vocab, _ = self.get_vocab(root)
-        self.transform()
-
-    def transform(self):
-        eos_id = self.vocab[self.EOS_TOKEN]
-        bos_id = self.vocab[self.BOS_TOKEN]
-        self.data = [(
-            [bos_id] + self.vocab.to_indices(data[0].split("\x02")) + [eos_id],
-            [bos_id] + self.vocab.to_indices(data[1].split("\x02")) + [eos_id])
-            for data in self.data]
@@ -24,4 +24,5 @@ from .squad import *
from .translation import *
from .dureader import *
from .cnndm import *
from .poetry import *
+from .couplet import *
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from paddlenlp.datasets import TranslationDataset
__all__ = ['CoupletDataset']
class CoupletDataset(TranslationDataset):
    URL = "https://paddlenlp.bj.bcebos.com/datasets/couplet.tar.gz"
    SPLITS = {
        'train': TranslationDataset.META_INFO(
            os.path.join("couplet", "train_src.tsv"),
            os.path.join("couplet", "train_tgt.tsv"),
            "ad137385ad5e264ac4a54fe8c95d1583",
            "daf4dd79dbf26040696eee0d645ef5ad"),
        'dev': TranslationDataset.META_INFO(
            os.path.join("couplet", "dev_src.tsv"),
            os.path.join("couplet", "dev_tgt.tsv"),
            "65bf9e72fa8fdf0482751c1fd6b6833c",
            "3bc3b300b19d170923edfa8491352951"),
        'test': TranslationDataset.META_INFO(
            os.path.join("couplet", "test_src.tsv"),
            os.path.join("couplet", "test_tgt.tsv"),
            "f0a7366dfa0acac884b9f4901aac2cc1",
            "56664bff3f2edfd7a751a55a689f90c2")
    }
    VOCAB_INFO = (os.path.join("couplet", "vocab.txt"), os.path.join(
        "couplet", "vocab.txt"), "0bea1445c7c7fb659b856bb07e54a604",
        "0bea1445c7c7fb659b856bb07e54a604")
    UNK_TOKEN = '<unk>'
    BOS_TOKEN = '<s>'
    EOS_TOKEN = '</s>'
    MD5 = '5c0dcde8eec6a517492227041c2e2d54'

    def __init__(self, mode='train', root=None):
        data_select = ('train', 'dev', 'test')
        if mode not in data_select:
            raise TypeError(
                '`train`, `dev` or `test` is supported but `{}` is passed in'.
                format(mode))
        # Download and read data
        self.data = self.get_data(mode=mode, root=root)
        self.vocab, _ = self.get_vocab(root)
        self.transform()

    def transform(self):
        eos_id = self.vocab[self.EOS_TOKEN]
        bos_id = self.vocab[self.BOS_TOKEN]
        self.data = [(
            [bos_id] + self.vocab.to_indices(data[0].split("\x02")) + [eos_id],
            [bos_id] + self.vocab.to_indices(data[1].split("\x02")) + [eos_id])
            for data in self.data]
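For reference, a minimal usage sketch of the relocated dataset (assuming a PaddleNLP build that already includes this commit; every attribute used below — `data`, `vocab`, the BOS/EOS class attributes — is defined in the class above, and downloading happens on first use):

from paddlenlp.datasets import CoupletDataset

# First use downloads couplet.tar.gz to the default data directory (or `root`,
# if given); transform() then maps each "\x02"-separated character sequence
# to <s> ... </s> token ids.
train_ds = CoupletDataset(mode='train')
dev_ds = CoupletDataset(mode='dev')

vocab = train_ds.vocab
bos_id = vocab[CoupletDataset.BOS_TOKEN]
eos_id = vocab[CoupletDataset.EOS_TOKEN]

# Each entry in `data` is a (source_ids, target_ids) pair already wrapped
# with the bos/eos ids by transform().
src_ids, tgt_ids = train_ds.data[0]
print(len(train_ds.data), src_ids[:5], tgt_ids[:5])

Padding and batching remain on the example side, e.g. via the `prepare_input` helper shown in the first hunk, which pads the id lists and builds the target mask for the seq2seq model.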