diff --git a/python/paddle/dataset/sentiment.py b/python/paddle/dataset/sentiment.py index 10bc33ac69c41444491a220a7c7fd664b2ffb7c2..721cb5a819282d5ef130de4d4596116326349d71 100644 --- a/python/paddle/dataset/sentiment.py +++ b/python/paddle/dataset/sentiment.py @@ -26,14 +26,17 @@ import six import collections from itertools import chain +import os import nltk from nltk.corpus import movie_reviews -import ssl -ssl._create_default_https_context = ssl._create_unverified_context +import zipfile from functools import cmp_to_key import paddle.dataset.common +URL = "https://corpora.bj.bcebos.com/movie_reviews%2Fmovie_reviews.zip" +MD5 = '155de2b77c6834dd8eea7cbe88e93acb' + __all__ = ['train', 'test', 'get_word_dict'] NUM_TRAINING_INSTANCES = 1600 NUM_TOTAL_INSTANCES = 2000 @@ -44,6 +47,14 @@ def download_data_if_not_yet(): Download the data set, if the data set is not download. """ try: + # download and extract movie_reviews.zip + paddle.dataset.common.download( + URL, 'corpora', md5sum=MD5, save_name='movie_reviews.zip') + path = os.path.join(paddle.dataset.common.DATA_HOME, 'corpora') + filename = os.path.join(path, 'movie_reviews.zip') + zip_file = zipfile.ZipFile(filename) + zip_file.extractall(path) + zip_file.close() # make sure that nltk can find the data if paddle.dataset.common.DATA_HOME not in nltk.data.path: nltk.data.path.append(paddle.dataset.common.DATA_HOME) diff --git a/python/paddle/fluid/tests/unittests/test_dataset_sentiment.py b/python/paddle/fluid/tests/unittests/test_dataset_sentiment.py index f92c1e5264b3dd4347dc6bc1d2f958408f5d7b3a..b5d5d33fa3fc32a054c23c80d471ce70dd745d08 100644 --- a/python/paddle/fluid/tests/unittests/test_dataset_sentiment.py +++ b/python/paddle/fluid/tests/unittests/test_dataset_sentiment.py @@ -31,15 +31,6 @@ MD5 = '155de2b77c6834dd8eea7cbe88e93acb' class TestDatasetSentiment(unittest.TestCase): """ TestCases for Sentiment. """ - def setUp(self): - paddle.dataset.common.download( - URL, 'corpora', md5sum=MD5, save_name='movie_reviews.zip') - path = os.path.join(paddle.dataset.common.DATA_HOME, 'corpora') - filename = os.path.join(path, 'movie_reviews.zip') - zip_file = zipfile.ZipFile(filename) - zip_file.extractall(path) - zip_file.close() - def test_get_word_dict(self): """ Testcase for get_word_dict. """ words_freq_sorted = paddle.dataset.sentiment.get_word_dict()