提交 4af9be96 编写于 作者: A Amy 提交者: drpngx

support passing in a source url to the mnist read_data_sets function, to make...

support passing in a source url to the mnist read_data_sets function, to make it easier to use 'fashion mnist' etc. (#12983)
上级 79517578
......@@ -30,7 +30,7 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.platform import gfile
# CVDF mirror of http://yann.lecun.com/exdb/mnist/
SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
DEFAULT_SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
def _read32(bytestream):
......@@ -215,7 +215,8 @@ def read_data_sets(train_dir,
dtype=dtypes.float32,
reshape=True,
validation_size=5000,
seed=None):
seed=None,
source_url=DEFAULT_SOURCE_URL):
if fake_data:
def fake():
......@@ -227,28 +228,31 @@ def read_data_sets(train_dir,
test = fake()
return base.Datasets(train=train, validation=validation, test=test)
if not source_url: # empty string check
source_url = DEFAULT_SOURCE_URL
TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
SOURCE_URL + TRAIN_IMAGES)
source_url + TRAIN_IMAGES)
with gfile.Open(local_file, 'rb') as f:
train_images = extract_images(f)
local_file = base.maybe_download(TRAIN_LABELS, train_dir,
SOURCE_URL + TRAIN_LABELS)
source_url + TRAIN_LABELS)
with gfile.Open(local_file, 'rb') as f:
train_labels = extract_labels(f, one_hot=one_hot)
local_file = base.maybe_download(TEST_IMAGES, train_dir,
SOURCE_URL + TEST_IMAGES)
source_url + TEST_IMAGES)
with gfile.Open(local_file, 'rb') as f:
test_images = extract_images(f)
local_file = base.maybe_download(TEST_LABELS, train_dir,
SOURCE_URL + TEST_LABELS)
source_url + TEST_LABELS)
with gfile.Open(local_file, 'rb') as f:
test_labels = extract_labels(f, one_hot=one_hot)
......@@ -262,13 +266,13 @@ def read_data_sets(train_dir,
train_images = train_images[validation_size:]
train_labels = train_labels[validation_size:]
options = dict(dtype=dtype, reshape=reshape, seed=seed)
train = DataSet(train_images, train_labels, **options)
validation = DataSet(validation_images, validation_labels, **options)
test = DataSet(test_images, test_labels, **options)
return base.Datasets(train=train, validation=validation, test=test)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册