提交 2dc06a83 编写于 作者: L liaogang

Update trainning

上级 a0fde5dd
...@@ -87,7 +87,7 @@ We use the [MovieLens ml-1m](http://files.grouplens.org/datasets/movielens/ml-1m ...@@ -87,7 +87,7 @@ We use the [MovieLens ml-1m](http://files.grouplens.org/datasets/movielens/ml-1m
help(paddle.v2.dataset.movielens) help(paddle.v2.dataset.movielens)
``` ```
The raw `MoiveLens` contains movie ratings, relevant features form both movies and users. The raw `MoiveLens` contains movie ratings, relevant features from both movies and users.
For instance, one movie's feature could be: For instance, one movie's feature could be:
```python ```python
...@@ -252,7 +252,7 @@ The movie ID and the movie type are mapped to their corresponding hidden layers. ...@@ -252,7 +252,7 @@ The movie ID and the movie type are mapped to their corresponding hidden layers.
inference = paddle.layer.cos_sim(a=usr_combined_features, b=mov_combined_features, size=1, scale=5) inference = paddle.layer.cos_sim(a=usr_combined_features, b=mov_combined_features, size=1, scale=5)
``` ```
进而,我们使用余弦相似度计算用户特征与电影特征的相似性。并将这个相似性拟合(回归)到用户评分上。 Finally, we can use cosine similarity to calculate the similarity between user characteristics and movie features.
```python ```python
cost = paddle.layer.regression_cost( cost = paddle.layer.regression_cost(
...@@ -261,13 +261,11 @@ cost = paddle.layer.regression_cost( ...@@ -261,13 +261,11 @@ cost = paddle.layer.regression_cost(
name='score', type=paddle.data_type.dense_vector(1))) name='score', type=paddle.data_type.dense_vector(1)))
``` ```
至此,我们的优化目标就是这个网络配置中的cost了。
## Model Training ## Model Training
### Define Parameters ### Define Parameters
First we define the model parameters according to the previous model configuration cost. First we define the model parameters according to the previous model configuration `cost`.
```python ```python
# Create parameters # Create parameters
...@@ -290,18 +288,18 @@ trainer = paddle.trainer.SGD(cost=cost, parameters=parameters, ...@@ -290,18 +288,18 @@ trainer = paddle.trainer.SGD(cost=cost, parameters=parameters,
### Training ### Training
下面我们开始训练过程。 `paddle.dataset.movielens.train` will yield records during each pass, after shuffling, a batch input is generated for training.
我们直接使用Paddle提供的数据集读取程序。paddle.dataset.movielens.train()和paddle.dataset.movielens.test()分别做训练和预测数据集。并且通过reader_dict来指定每一个数据和data_layer的对应关系。
例如,这里的reader_dict表示的是,对于数据层 user_id,使用了reader中每一条数据的第0个元素。gender_id数据层使用了第1个元素。以此类推。
训练过程是完全自动的。我们可以使用event_handler来观察训练过程,或进行测试等。这里我们在event_handler里面绘制了训练误差曲线和测试误差曲线。并且保存了模型。
```python ```python
%matplotlib inline reader=paddle.reader.batch(
paddle.reader.shuffle(
paddle.dataset.movielens.trai(), buf_size=8192),
batch_size=256)
```
import matplotlib.pyplot as plt `feeding` is devoted to specify the correspondence between each yield record and `paddle.layer.data`. For instance, the first column of data generated by `movielens.train` corresponds to `user_id` feature.
from IPython import display
import cPickle
```python
feeding = { feeding = {
'user_id': 0, 'user_id': 0,
'gender_id': 1, 'gender_id': 1,
...@@ -312,12 +310,11 @@ feeding = { ...@@ -312,12 +310,11 @@ feeding = {
'movie_title': 6, 'movie_title': 6,
'score': 7 'score': 7
} }
```
step=0 Callback function `event_handler` is used to track training and testing process that might be triggered once the action to which it is attached is executed.
train_costs=[],[]
test_costs=[],[]
```python
def event_handler(event): def event_handler(event):
global step global step
global train_costs global train_costs
...@@ -342,15 +339,27 @@ def event_handler(event): ...@@ -342,15 +339,27 @@ def event_handler(event):
display.display(plt.gcf()) display.display(plt.gcf())
plt.gcf().clear() plt.gcf().clear()
step += 1 step += 1
```
Finally, we can invoke `trainer.train` to start training:
```python
%matplotlib inline
import matplotlib.pyplot as plt
from IPython import display
import cPickle
step=0
train_costs=[],[]
test_costs=[],[]
trainer.train( trainer.train(
reader=paddle.batch( reader=reader,
paddle.reader.shuffle(
paddle.dataset.movielens.train(), buf_size=8192),
batch_size=256),
event_handler=event_handler, event_handler=event_handler,
feeding=feeding, feeding=feeding,
num_passes=2) num_passes=200)
``` ```
## Conclusion ## Conclusion
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册