未验证 提交 66831a70 编写于 作者: L Lyon 提交者: GitHub

Develop document load data (#3205)

* add load_mnist.py for document demo

* remove chinese character

* del test code

* del useless code

* del useless code

* modify load_mnist, add tqdm to setup.py

* add tqdm to dev-requirements.txt

* refine load_mnist.py

* add code api doc

* refine docstring

* undo changes of api docs

* add required library 'requests' to setup.py
Co-authored-by: NShenghang Tsai <jackalcooper@gmail.com>
上级 f41be54a
import os
import hashlib
import numpy as np
from tqdm import tqdm
import requests
from oneflow.python.oneflow_export import oneflow_export
def get_sha256hash(file_path, Bytes=1024):
sha256hash = hashlib.sha256()
with open(file_path, "rb") as f:
while True:
data = f.read(Bytes)
if data:
sha256hash.update(data)
else:
break
ret = sha256hash.hexdigest()
return ret
def download_mnist_file(out_path, url):
resp = requests.get(url=url, stream=True)
size = int(resp.headers["Content-Length"]) / 1024
print("File size: %.4f kb, downloading..." % size)
with open(out_path, "wb") as f:
for data in tqdm(
iterable=resp.iter_content(1024), total=size, unit="k", desc=out_path
):
f.write(data)
print("Done!")
def get_mnist_file(sha256, url, out_dir):
path = os.path.join(out_dir, "mnist.npz")
if not (os.path.isfile(path)):
download_mnist_file(path, url)
print("File mnist.npz already exist, path:", path)
if not get_sha256hash(path) == sha256:
cheksum_fail = "sha256 verification failed, remove {0} and try again".format(
path
)
raise Exception(cheksum_fail)
return path
@oneflow_export("data.load_mnist")
def load_mnist(
train_batch_size=100,
test_batch_size=100,
data_format="NCHW",
url="https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist.npz",
hash_check="63d4344077849053dc3036b247fa012b2b381de53fd055a66b539dffd76cf08e",
out_dir=".",
):
r"""Load mnist dataset, return images and labels,
if dataset doesn't exist, then download it to directory that out_dir specified
Args:
train_batch_size (int, optional): batch size for train. Defaults to 100.
test_batch_size (int, optional): batch size for test or evaluate. Defaults to 100.
data_format (str, optional): data format. Defaults to "NCHW".
url (str, optional): url to get mnist.npz. Defaults to "https://oneflow-public.oss-cn-beijing.aliyuncs.com/datasets/mnist.npz".
hash_check (str, optional): file hash value. Defaults to "63d4344077849053dc3036b247fa012b2b381de53fd055a66b539dffd76cf08e".
out_dir (str, optional): dir to save downloaded file. Defaults to "./".
Returns:
[type]: (train_images, train_labels), (test_images, test_labels)
"""
path = get_mnist_file(hash_check, url, out_dir)
with np.load(path, allow_pickle=True) as f:
x_train, y_train = f["x_train"], f["y_train"]
x_test, y_test = f["x_test"], f["y_test"]
def normalize(x, y, batch_size):
x = (x.astype(np.float32) - 128.0) / 255.0
y = y.astype(np.int32)
if data_format == "NCHW":
images = x.reshape((-1, batch_size, 1, x.shape[1], x.shape[2]))
else:
images = x.reshape((-1, batch_size, x.shape[1], x.shape[2], 1))
labels = y.reshape((-1, batch_size))
return images, labels
train_images, train_labels = normalize(x_train, y_train, train_batch_size)
test_images, test_labels = normalize(x_test, y_test, test_batch_size)
return (train_images, train_labels), (test_images, test_labels)
......@@ -37,6 +37,8 @@ sys.argv = ['setup.py'] + remain_args
REQUIRED_PACKAGES = [
'numpy',
'protobuf',
'tqdm',
'requests',
]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册