From 8c7d113ee3205844273bf434acf07dc95fa53185 Mon Sep 17 00:00:00 2001 From: zhang wenhui Date: Mon, 31 Aug 2020 10:53:23 +0800 Subject: [PATCH] fix esmm (#172) * fix esmm * fix esmm Co-authored-by: wuzhihua <35824027+fuyinno4@users.noreply.github.com> Co-authored-by: tangwei12 --- models/multitask/esmm/README.md | 23 ++++++++++++++++------- models/multitask/esmm/config.yaml | 14 +++++++------- 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/models/multitask/esmm/README.md b/models/multitask/esmm/README.md index 91a1df76..aecd9eda 100644 --- a/models/multitask/esmm/README.md +++ b/models/multitask/esmm/README.md @@ -50,11 +50,6 @@ ESMM是发表在 SIGIR’2018 的论文[《Entire Space Multi-Task Model: An E 数据地址:[Ali-CCP:Alibaba Click and Conversion Prediction]( https://tianchi.aliyun.com/datalab/dataSet.html?dataId=408 ) -``` -cd data -sh run.sh -``` - 数据格式参见demo数据:data/train @@ -108,11 +103,25 @@ CPU环境 ## 论文复现 -用原论文的完整数据复现论文效果需要在config.yaml中修改batch_size=1000, thread_num=8, epoch_num=4 +由于原论文的数据太大,我们选取了部分数据作为训练和测试数据, 建议使用gpu训练。 + +我们的测试ctr auc为0.79+,ctcvr auc为0.82+。 +``` +wget https://paddlerec.bj.bcebos.com/esmm/traindata_10w.csv +wget https://paddlerec.bj.bcebos.com/esmm/testdata_10w.csv +mkdir data/train_data data/test_data +mv traindata_10w.csv data/train_data +mv testdata_10w.csv data/test_data +``` -修改后运行方案:修改config.yaml中的'workspace'为config.yaml的目录位置,执行 +用原论文的完整数据复现论文效果需要在config.yaml中修改batch_size=1024, epoch=10, device=gpu, selected_gpus:"0" +具体配置可以下载config_10w.yaml文件 +``` +wget https://paddlerec.bj.bcebos.com/esmm/config_10w.yaml +``` +修改后运行 ``` python -m paddlerec.run -m /home/your/dir/config.yaml #调试模式 直接指定本地config的绝对路径 ``` diff --git a/models/multitask/esmm/config.yaml b/models/multitask/esmm/config.yaml index 2a4478ba..25cfbe91 100644 --- a/models/multitask/esmm/config.yaml +++ b/models/multitask/esmm/config.yaml @@ -17,19 +17,19 @@ workspace: "models/multitask/esmm" dataset: - name: dataset_train - batch_size: 1 + batch_size: 5 type: QueueDataset data_path: "{workspace}/data/train" data_converter: "{workspace}/esmm_reader.py" - name: dataset_infer - batch_size: 1 + batch_size: 5 type: QueueDataset data_path: "{workspace}/data/test" data_converter: "{workspace}/esmm_reader.py" hyper_parameters: - vocab_size: 10000 - embed_size: 128 + vocab_size: 737946 + embed_size: 12 optimizer: class: adam learning_rate: 0.001 @@ -43,15 +43,15 @@ runner: class: train device: cpu epochs: 3 - save_checkpoint_interval: 2 + save_checkpoint_interval: 1 save_inference_interval: 4 - save_checkpoint_path: "increment" + save_checkpoint_path: "increment_esmm" save_inference_path: "inference" print_interval: 10 phases: [train] - name: infer_runner class: infer - init_model_path: "increment/1" + init_model_path: "increment_esmm/1" device: cpu print_interval: 1 phases: [infer] -- GitLab