From 60dbaba78c9f9f01d8f7ac40dbe5411e695256b3 Mon Sep 17 00:00:00 2001 From: frankwhzhang Date: Fri, 24 Jul 2020 11:48:29 +0800 Subject: [PATCH] fix batch bug in listwise --- models/rerank/listwise/config.yaml | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/models/rerank/listwise/config.yaml b/models/rerank/listwise/config.yaml index 6d06ab09..82432e1c 100644 --- a/models/rerank/listwise/config.yaml +++ b/models/rerank/listwise/config.yaml @@ -17,29 +17,30 @@ workspace: "paddlerec.models.rerank.listwise" dataset: - name: dataset_train + batch_size: 5 type: DataLoader data_path: "{workspace}/data/train" data_converter: "{workspace}/random_reader.py" - name: dataset_infer + batch_size: 5 type: DataLoader data_path: "{workspace}/data/test" data_converter: "{workspace}/random_reader.py" hyper_parameters: + optimizer: + class: sgd + learning_rate: 0.01 + strategy: async hidden_size: 128 user_vocab: 200 item_vocab: 1000 item_len: 5 embed_size: 16 batch_size: 1 - optimizer: - class: sgd - learning_rate: 0.01 - strategy: async #use infer_runner mode and modify 'phase' below if infer -mode: train_runner -#mode: infer_runner +mode: [train_runner, infer_runner] runner: - name: train_runner @@ -48,19 +49,22 @@ runner: epochs: 3 save_checkpoint_interval: 2 save_inference_interval: 4 - save_checkpoint_path: "increment" + save_checkpoint_path: "increment_listwise" save_inference_path: "inference" + print_interval: 1 + phases: [train] - name: infer_runner class: infer - init_model_path: "increment/0" + init_model_path: "increment_listwise/2" device: cpu + phases: [infer] phase: - name: train model: "{workspace}/model.py" dataset_name: dataset_train thread_num: 1 - #- name: infer - # model: "{workspace}/model.py" - # dataset_name: dataset_infer - # thread_num: 1 +- name: infer + model: "{workspace}/model.py" + dataset_name: dataset_infer + thread_num: 1 -- GitLab