提交 e03b6b7b 编写于 作者: L liaogang

Update readme api

上级 dc347297
......@@ -329,7 +329,29 @@ def resnet_cifar10(ipt, depth=32):
### 优化算法
通过 `paddle.optimizer`模块设置训练的优化算法,并指定batch size 、初始学习率、momentum以及L2正则。
## 训练模型
### 定义参数
首先依据模型配置的`cost`定义模型参数。
```python
# Create parameters
parameters = paddle.parameters.create(cost)
```
可以打印参数名字,如果在网络配置中没有指定名字,则默认生成。
```python
print parameters.keys()
```
### 构造训练(Trainer)
根据网络拓扑结构和模型参数来构造出trainer用来训练,在构造时还需指定优化方法,这里使用最基本的SGD方法(momentum设置为0),同时设定了学习率、正则等。
```python
# Create optimizer
......@@ -341,6 +363,11 @@ momentum_optimizer = paddle.optimizer.Momentum(
learning_rate_decay_b=50000 * 100,
learning_rate_schedule='discexp',
batch_size=128)
# Create trainer
trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=momentum_optimizer)
```
通过 `learning_rate_decay_a` (简写$a$) 、`learning_rate_decay_b` (简写$b$) 和 `learning_rate_schedule` 指定学习率调整策略,这里采用离散指数的方式调节学习率,计算公式如下, $n$ 代表已经处理过的累计总样本数,$lr_{0}$ 即为 `settings` 里设置的 `learning_rate`
......@@ -348,7 +375,25 @@ momentum_optimizer = paddle.optimizer.Momentum(
$$ lr = lr_{0} * a^ {\lfloor \frac{n}{ b}\rfloor} $$
## 模型训练
### 训练
cifar.train10()每次产生一条样本,在完成shuffle和batch之后,作为训练的输入。
```python
reader=paddle.reader.batched(
paddle.reader.shuffle(
paddle.dataset.cifar.train10(), buf_size=50000),
batch_size=128)
```
通过`reader_dict`来指定每一个数据和`paddle.layer.data`的对应关系。例如: `cifar.train10()`产生数据的第0列对应image层的特征。
```python
reader_dict={'image': 0,
'label': 1}
```
可以使用`event_handler`回调函数来观察训练过程,或进行测试等, 该回调函数是`trainer.train`函数里设定。
```python
# End batch and end pass event handler
......@@ -367,23 +412,18 @@ def event_handler(event):
reader_dict={'image': 0,
'label': 1})
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
```
# Create trainer
trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=momentum_optimizer)
通过`trainer.train`函数训练:
```python
trainer.train(
reader=paddle.reader.batched(
paddle.reader.shuffle(
paddle.dataset.cifar.train10(), buf_size=50000),
batch_size=128),
reader=reader,
num_passes=200,
event_handler=event_handler,
reader_dict={'image': 0,
'label': 1})
reader_dict=reader_dict)
```
一轮训练log示例如下所示,经过1个pass, 训练集上平均error为0.6875 ,测试集上平均error为0.8852 。
```text
......
......@@ -81,7 +81,7 @@ def main():
paddle.reader.shuffle(
paddle.dataset.cifar.train10(), buf_size=50000),
batch_size=128),
num_passes=5,
num_passes=200,
event_handler=event_handler,
reader_dict={'image': 0,
'label': 1})
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册