提交 a2020871 编写于 作者: L liaogang

Update api v2 in image classification

上级 095eb370
......@@ -162,7 +162,7 @@ from vgg import vgg_bn_drop
from resnet import resnet_cifar10
# PaddlePaddle init
paddle.init(use_gpu=True)
paddle.init(use_gpu=False, trainer_count=1)
```
本教程中我们提供了VGG和ResNet两个模型的配置。
......@@ -221,7 +221,7 @@ paddle.init(use_gpu=True)
return fc2
```
2.1. 首先定义了一组卷积网络,即conv_block。卷积核大小为3x3,池化窗口大小为2x2,窗口滑动大小为2,groups决定每组VGG模块是几次连续的卷积操作,dropouts指定Dropout操作的概率。所使用的`img_conv_group`是在`paddle.trainer_config_helpers`中预定义的模块,由若干组 `Conv->BN->ReLu->Dropout` 和 一组 `Pooling` 组成,
2.1. 首先定义了一组卷积网络,即conv_block。卷积核大小为3x3,池化窗口大小为2x2,窗口滑动大小为2,groups决定每组VGG模块是几次连续的卷积操作,dropouts指定Dropout操作的概率。所使用的`img_conv_group`是在`paddle.networks`中预定义的模块,由若干组 `Conv->BN->ReLu->Dropout` 和 一组 `Pooling` 组成,
2.2. 五组卷积操作,即 5个conv_block。 第一、二组采用两次连续的卷积操作。第三、四、五组采用三次连续的卷积操作。每组最后一个卷积后面Dropout概率为0,即不使用Dropout操作。
......@@ -327,11 +327,6 @@ def resnet_cifar10(ipt, depth=32):
return pool
```
### 优化算法
## 训练模型
### 定义参数
......@@ -351,7 +346,7 @@ print parameters.keys()
### 构造训练(Trainer)
根据网络拓扑结构和模型参数来构造出trainer用来训练,在构造时还需指定优化方法,这里使用最基本的SGD方法(momentum设置为0),同时设定了学习率、正则等。
根据网络拓扑结构和模型参数来构造出trainer用来训练,在构造时还需指定优化方法,这里使用最基本的Momentum方法,同时设定了学习率、正则等。
```python
# Create optimizer
......@@ -367,7 +362,7 @@ momentum_optimizer = paddle.optimizer.Momentum(
# Create trainer
trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=momentum_optimizer)
update_equation=momentum_optimizer)
```
通过 `learning_rate_decay_a` (简写$a$) 、`learning_rate_decay_b` (简写$b$) 和 `learning_rate_schedule` 指定学习率调整策略,这里采用离散指数的方式调节学习率,计算公式如下, $n$ 代表已经处理过的累计总样本数,$lr_{0}$ 即为 `settings` 里设置的 `learning_rate`
......@@ -380,17 +375,17 @@ $$ lr = lr_{0} * a^ {\lfloor \frac{n}{ b}\rfloor} $$
cifar.train10()每次产生一条样本,在完成shuffle和batch之后,作为训练的输入。
```python
reader=paddle.reader.batched(
reader=paddle.reader.batch(
paddle.reader.shuffle(
paddle.dataset.cifar.train10(), buf_size=50000),
batch_size=128)
```
通过`reader_dict`来指定每一个数据和`paddle.layer.data`的对应关系。例如: `cifar.train10()`产生数据的第0列对应image层的特征。
通过`feeding`来指定每一个数据和`paddle.layer.data`的对应关系。例如: `cifar.train10()`产生数据的第0列对应image层的特征。
```python
reader_dict={'image': 0,
'label': 1}
feeding={'image': 0,
'label': 1}
```
可以使用`event_handler`回调函数来观察训练过程,或进行测试等, 该回调函数是`trainer.train`函数里设定。
......@@ -407,10 +402,10 @@ def event_handler(event):
sys.stdout.flush()
if isinstance(event, paddle.event.EndPass):
result = trainer.test(
reader=paddle.reader.batched(
reader=paddle.reader.batch(
paddle.dataset.cifar.test10(), batch_size=128),
reader_dict={'image': 0,
'label': 1})
'label': 1})
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
```
......@@ -421,7 +416,7 @@ trainer.train(
reader=reader,
num_passes=200,
event_handler=event_handler,
reader_dict=reader_dict)
feeding=feeding)
```
一轮训练log示例如下所示,经过1个pass, 训练集上平均error为0.6875 ,测试集上平均error为0.8852 。
......
......@@ -13,7 +13,9 @@
# limitations under the License
import sys
import paddle.v2 as paddle
from vgg import vgg_bn_drop
from resnet import resnet_cifar10
......@@ -23,7 +25,7 @@ def main():
classdim = 10
# PaddlePaddle init
paddle.init(use_gpu=True, trainer_count=1)
paddle.init(use_gpu=False, trainer_count=1)
image = paddle.layer.data(
name="image", type=paddle.data_type.dense_vector(datadim))
......@@ -66,10 +68,10 @@ def main():
sys.stdout.flush()
if isinstance(event, paddle.event.EndPass):
result = trainer.test(
reader=paddle.reader.batched(
reader=paddle.batch(
paddle.dataset.cifar.test10(), batch_size=128),
reader_dict={'image': 0,
'label': 1})
feeding={'image': 0,
'label': 1})
print "\nTest with Pass %d, %s" % (event.pass_id, result.metrics)
# Create trainer
......@@ -77,14 +79,14 @@ def main():
parameters=parameters,
update_equation=momentum_optimizer)
trainer.train(
reader=paddle.reader.batched(
reader=paddle.batch(
paddle.reader.shuffle(
paddle.dataset.cifar.train10(), buf_size=50000),
batch_size=128),
num_passes=200,
event_handler=event_handler,
reader_dict={'image': 0,
'label': 1})
feeding={'image': 0,
'label': 1})
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册