Commit bf315f9a authored by Aston Zhang

update data_iter

Parent 878d70f6
......@@ -59,7 +59,7 @@ plt.show()
## Reading the Data
When training a model, we need to iterate over the data set and repeatedly read mini-batches of data samples. Here we define a function that on each call returns the features and labels of a batch-size number of random samples.
When training a model, we need to iterate over the data set and repeatedly read mini-batches of data samples. Here we define a function that on each call returns the features and labels of `batch_size` (batch size) random samples.
```{.python .input n=5}
batch_size = 10
......@@ -74,7 +74,7 @@ def data_iter(batch_size, features, labels):
yield features.take(j), labels.take(j)
```
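Only the `yield` line of the generator survives the collapsed hunk above; a minimal sketch of the complete `data_iter`, assuming the MXNet `nd` module and Python's `random` module are imported as elsewhere in the chapter, could look like this:

```python
import random
from mxnet import nd

def data_iter(batch_size, features, labels):
    """Yield the features and labels of batch_size random examples at a time."""
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)  # examples are read in random order
    for i in range(0, num_examples, batch_size):
        j = nd.array(indices[i: min(i + batch_size, num_examples)])
        yield features.take(j), labels.take(j)  # take selects rows by index
```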
Let us read the first mini-batch of data samples and print it. The feature shape of each batch is `(10, 2)`, corresponding to the batch size and the number of inputs; the label shape equals the batch size.
Let us read the first mini-batch of data samples and print it. The feature shape of each batch is (10, 2), corresponding to the batch size and the number of inputs; the label shape equals the batch size.
```{.python .input n=6}
for X, y in data_iter(batch_size, features, labels):
......@@ -86,7 +86,7 @@ for X, y in data_iter(batch_size, features, labels):
## Initializing Model Parameters
We randomize the weights as normally distributed random numbers with mean 0 and variance 0.01, and initialize the bias to 0.
We initialize the weights as normally distributed random numbers with mean 0 and standard deviation 0.01, and initialize the bias to 0.
```{.python .input n=7}
w = nd.random.normal(scale=0.01, shape=(num_inputs, 1))
......@@ -123,7 +123,7 @@ def squared_loss(y_hat, y):
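Only the first line of the initialization code is visible in the hunk above; a minimal sketch of the full step, assuming `num_inputs = 2` as in this chapter's synthetic regression data, might be:

```python
from mxnet import nd

num_inputs = 2  # number of input features, assumed from this chapter's synthetic data
w = nd.random.normal(scale=0.01, shape=(num_inputs, 1))  # mean 0, standard deviation 0.01
b = nd.zeros(shape=(1,))
# Attach gradient buffers so autograd can compute gradients with respect to w and b.
w.attach_grad()
b.attach_grad()
```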
## Defining the Optimization Algorithm
The following `sgd` function implements the model update step of the mini-batch stochastic gradient descent algorithm introduced in the previous section.
The following `sgd` function implements the mini-batch stochastic gradient descent algorithm introduced in the previous section. It optimizes the loss function by iteratively updating the model parameters.
```{.python .input n=11}
def sgd(params, lr, batch_size):
......
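The body of `sgd` is collapsed in the hunk above; a minimal sketch, assuming the mini-batch loss is summed before calling `backward`, is:

```python
def sgd(params, lr, batch_size):
    """Mini-batch stochastic gradient descent update."""
    for param in params:
        # In-place update; dividing by batch_size turns the summed gradient
        # over the mini-batch into an average.
        param[:] = param - lr * param.grad / batch_size
```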
......@@ -96,7 +96,7 @@ def fit_and_plot(lambd):
train_ls = []
test_ls = []
for _ in range(num_epochs):
for X, y in gb.data_iter(batch_size, n_train, features, labels):
for X, y in gb.data_iter(batch_size, features, labels):
with autograd.record():
# Add the L2 norm penalty term.
l = loss(net(X, w, b), y) + lambd * l2_penalty(w)
......
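The `l2_penalty` helper called in this hunk is not shown in the diff; a minimal sketch consistent with the halved squared-loss convention used in the chapter (stated here as an assumption) would be:

```python
def l2_penalty(w):
    # Half of the squared L2 norm of the weights; lambd scales it in the loss above.
    return (w ** 2).sum() / 2
```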
......@@ -93,7 +93,7 @@ def optimize(batch_size, rho, num_epochs, log_interval):
ls = [loss(net(features, w, b), labels).mean().asnumpy()]
for epoch in range(1, num_epochs + 1):
for batch_i, (X, y) in enumerate(
gb.data_iter(batch_size, num_examples, features, labels)):
gb.data_iter(batch_size, features, labels)):
with autograd.record():
l = loss(net(X, w, b), y)
l.backward()
......
......@@ -100,7 +100,7 @@ def optimize(batch_size, lr, num_epochs, log_interval):
ls = [loss(net(features, w, b), labels).mean().asnumpy()]
for epoch in range(1, num_epochs + 1):
for batch_i, (X, y) in enumerate(
gb.data_iter(batch_size, num_examples, features, labels)):
gb.data_iter(batch_size, features, labels)):
with autograd.record():
l = loss(net(X, w, b), y)
l.backward()
......
......@@ -109,7 +109,7 @@ def optimize(batch_size, lr, num_epochs, log_interval):
t = 0
for epoch in range(1, num_epochs + 1):
for batch_i, (X, y) in enumerate(
gb.data_iter(batch_size, num_examples, features, labels)):
gb.data_iter(batch_size, features, labels)):
with autograd.record():
l = loss(net(X, w, b), y)
l.backward()
......
......@@ -172,7 +172,7 @@ def optimize(batch_size, lr, num_epochs, log_interval, decay_epoch):
if decay_epoch and epoch > decay_epoch:
lr *= 0.1
for batch_i, (X, y) in enumerate(
gb.data_iter(batch_size, num_examples, features, labels)):
gb.data_iter(batch_size, features, labels)):
with autograd.record():
l = loss(net(X, w, b), y)
# First sum the elements of l to get the mini-batch loss, then compute the gradients of the parameters.
......
......@@ -135,7 +135,7 @@ def optimize(batch_size, lr, mom, num_epochs, log_interval):
if epoch > 2:
lr *= 0.1
for batch_i, (X, y) in enumerate(
gb.data_iter(batch_size, num_examples, features, labels)):
gb.data_iter(batch_size, features, labels)):
with autograd.record():
l = loss(net(X, w, b), y)
l.backward()
......
......@@ -89,7 +89,7 @@ def optimize(batch_size, lr, gamma, num_epochs, log_interval):
ls = [loss(net(features, w, b), labels).mean().asnumpy()]
for epoch in range(1, num_epochs + 1):
for batch_i, (X, y) in enumerate(
gb.data_iter(batch_size, num_examples, features, labels)):
gb.data_iter(batch_size, features, labels)):
with autograd.record():
l = loss(net(X, w, b), y)
l.backward()
......
......@@ -377,8 +377,9 @@ def train_and_predict_rnn(rnn, is_random_iter, num_epochs, num_steps,
ctx, idx_to_char, char_to_idx, get_inputs, is_lstm))
def data_iter(batch_size, num_examples, features, labels):
def data_iter(batch_size, features, labels):
"""Iterate through a data set."""
num_examples = len(features)
indices = list(range(num_examples))
random.shuffle(indices)
for i in range(0, num_examples, batch_size):
......
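To illustrate the signature change this commit applies throughout the callers above, a minimal usage sketch (assuming `gluonbook` is installed and imported as `gb`, with toy data standing in for the chapter's data set) might look like this:

```python
from mxnet import nd
import gluonbook as gb  # the book's utility module, as used in the chapters above

# Toy data for illustration only.
features = nd.random.normal(shape=(1000, 2))
labels = nd.random.normal(shape=(1000,))
batch_size = 10

# num_examples is now derived inside data_iter via len(features),
# so callers no longer pass it explicitly.
for X, y in gb.data_iter(batch_size, features, labels):
    print(X.shape, y.shape)  # (10, 2) (10,)
    break
```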