提交 8ef34ea6 编写于 作者: M malin10

update w2v

上级 c5908fe6
...@@ -38,11 +38,15 @@ ...@@ -38,11 +38,15 @@
- [FAQ](#FAQ) - [FAQ](#FAQ)
## 模型简介 ## 模型简介
本例实现了skip-gram模式的word2vector模型,如下图中的skip-gram部分: 本例实现了skip-gram模式的word2vector模型,如下图所示:
![](https://ai-studio-static-online.cdn.bcebos.com/bf0217bcb42e455284290a100670072989d432939b7e43e38b78eea6b60732c0) <p align="center">
<img align="center" src="../../../doc/imgs/word2vec.png">
<p>
以每一个词为中心词X,然后在窗口内和临近的词Y组成样本对(X,Y)用于网络训练。在实际训练过程中还会根据自定义的负采样率生成负样本来加强训练的效果 以每一个词为中心词X,然后在窗口内和临近的词Y组成样本对(X,Y)用于网络训练。在实际训练过程中还会根据自定义的负采样率生成负样本来加强训练的效果
具体的训练思路如下: 具体的训练思路如下:
![](https://ai-studio-static-online.cdn.bcebos.com/ad45d6dfff2f4fa69c639a6b5d2bbe46c79ac67658464aa0a069a04eb2800cb6) <p align="center">
<img align="center" src="../../../doc/imgs/w2v_train.png">
<p>
推荐用户参考[ IPython Notebook demo](https://aistudio.baidu.com/aistudio/projectDetail/124377)教程获取更详细的信息。 推荐用户参考[ IPython Notebook demo](https://aistudio.baidu.com/aistudio/projectDetail/124377)教程获取更详细的信息。
......
...@@ -42,38 +42,40 @@ hyper_parameters: ...@@ -42,38 +42,40 @@ hyper_parameters:
window_size: 5 window_size: 5
# select runner by name # select runner by name
mode: train_runner mode: [single_cpu_train, single_cpu_infer]
# config of each runner. # config of each runner.
# runner is a kind of paddle training class, which wraps the train/infer process. # runner is a kind of paddle training class, which wraps the train/infer process.
runner: runner:
- name: train_runner - name: single_cpu_train
class: train class: train
# num of epochs # num of epochs
epochs: 2 epochs: 5
# device to run training or infer # device to run training or infer
device: cpu device: cpu
save_checkpoint_interval: 1 # save model interval of epochs save_checkpoint_interval: 1 # save model interval of epochs
save_inference_interval: 1 # save inference save_inference_interval: 1 # save inference
save_checkpoint_path: "increment" # save checkpoint path save_checkpoint_path: "increment_w2v" # save checkpoint path
save_inference_path: "inference" # save inference path save_inference_path: "inference_w2v" # save inference path
save_inference_feed_varnames: [] # feed vars of save inference save_inference_feed_varnames: [] # feed vars of save inference
save_inference_fetch_varnames: [] # fetch vars of save inference save_inference_fetch_varnames: [] # fetch vars of save inference
init_model_path: "" # load model path init_model_path: "" # load model path
print_interval: 1 print_interval: 1
- name: infer_runner phases: [phase1]
- name: single_cpu_infer
class: infer class: infer
# device to run training or infer # device to run training or infer
device: cpu device: cpu
init_model_path: "increment/0" # load model path init_model_path: "increment_w2v" # load model path
print_interval: 1 print_interval: 1
phases: [phase2]
# runner will run all the phase in each epoch # runner will run all the phase in each epoch
phase: phase:
- name: phase1 - name: phase1
model: "{workspace}/model.py" # user-defined model model: "{workspace}/model.py" # user-defined model
dataset_name: dataset_train # select dataset by name dataset_name: dataset_train # select dataset by name
thread_num: 5
- name: phase2
model: "{workspace}/model.py" # user-defined model
dataset_name: dataset_infer # select dataset by name
thread_num: 1 thread_num: 1
# - name: phase2
# model: "{workspace}/model.py" # user-defined model
# dataset_name: dataset_infer # select dataset by name
# thread_num: 1
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册