未验证 提交 e95ea8fc 编写于 作者: W wuzhihua 提交者: GitHub

Merge pull request #156 from frankwhzhang/0724_fix_listwise

fix batch bug in listwise
...@@ -17,29 +17,30 @@ workspace: "paddlerec.models.rerank.listwise" ...@@ -17,29 +17,30 @@ workspace: "paddlerec.models.rerank.listwise"
dataset: dataset:
- name: dataset_train - name: dataset_train
batch_size: 5
type: DataLoader type: DataLoader
data_path: "{workspace}/data/train" data_path: "{workspace}/data/train"
data_converter: "{workspace}/random_reader.py" data_converter: "{workspace}/random_reader.py"
- name: dataset_infer - name: dataset_infer
batch_size: 5
type: DataLoader type: DataLoader
data_path: "{workspace}/data/test" data_path: "{workspace}/data/test"
data_converter: "{workspace}/random_reader.py" data_converter: "{workspace}/random_reader.py"
hyper_parameters: hyper_parameters:
optimizer:
class: sgd
learning_rate: 0.01
strategy: async
hidden_size: 128 hidden_size: 128
user_vocab: 200 user_vocab: 200
item_vocab: 1000 item_vocab: 1000
item_len: 5 item_len: 5
embed_size: 16 embed_size: 16
batch_size: 1 batch_size: 1
optimizer:
class: sgd
learning_rate: 0.01
strategy: async
#use infer_runner mode and modify 'phase' below if infer #use infer_runner mode and modify 'phase' below if infer
mode: train_runner mode: [train_runner, infer_runner]
#mode: infer_runner
runner: runner:
- name: train_runner - name: train_runner
...@@ -48,19 +49,22 @@ runner: ...@@ -48,19 +49,22 @@ runner:
epochs: 3 epochs: 3
save_checkpoint_interval: 2 save_checkpoint_interval: 2
save_inference_interval: 4 save_inference_interval: 4
save_checkpoint_path: "increment" save_checkpoint_path: "increment_listwise"
save_inference_path: "inference" save_inference_path: "inference"
print_interval: 1
phases: [train]
- name: infer_runner - name: infer_runner
class: infer class: infer
init_model_path: "increment/0" init_model_path: "increment_listwise/2"
device: cpu device: cpu
phases: [infer]
phase: phase:
- name: train - name: train
model: "{workspace}/model.py" model: "{workspace}/model.py"
dataset_name: dataset_train dataset_name: dataset_train
thread_num: 1 thread_num: 1
#- name: infer - name: infer
# model: "{workspace}/model.py" model: "{workspace}/model.py"
# dataset_name: dataset_infer dataset_name: dataset_infer
# thread_num: 1 thread_num: 1
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册