diff --git a/models/match/dssm/data/preprocess.py b/models/match/dssm/data/preprocess.py index c2a41051eeff01c861acdfb93a0188110684085b..566a149422f7dc53b019de1e718819a5f53a4fee 100644 --- a/models/match/dssm/data/preprocess.py +++ b/models/match/dssm/data/preprocess.py @@ -63,7 +63,8 @@ print("build dict done") #划分训练集和测试集 query_list = list(pos_dict.keys()) #print(len(query_list)) -#random.shuffle(query_list) +np.random.seed(107) +np.random.shuffle(query_list) train_query = query_list[:11600] test_query = query_list[11600:] diff --git a/models/match/dssm/readme.md b/models/match/dssm/readme.md index 5228e634b3bd0d97805130b58a12a3ee07a2f6e0..7acff15ba760ad2d2b08ebcfa410f8911a3b70df 100644 --- a/models/match/dssm/readme.md +++ b/models/match/dssm/readme.md @@ -153,11 +153,11 @@ label.txt中对应的测试集中的标签 4. 退回dssm目录中,打开文件config.yaml,更改其中的参数 将workspace改为您当前的绝对路径。(可用pwd命令获取绝对路径) -将dataset_train中的batch_size从8改为128 -将hyper_parameters中的slice_end从8改为128.当您需要改变batchsize的时候,这个参数也需要随之变化 -将dataset_train中的data_path改为{workspace}/data/big_train -将dataset_infer中的data_path改为{workspace}/data/big_test -将hyper_parameters中的trigram_d改为5913 +将dataset_train中的batch_size从8改为128 +将hyper_parameters中的slice_end从8改为128.当您需要改变batchsize的时候,这个参数也需要随之变化 +将dataset_train中的data_path改为{workspace}/data/big_train +将dataset_infer中的data_path改为{workspace}/data/big_test +将hyper_parameters中的trigram_d改为5913 5. 执行脚本,开始训练.脚本会运行python -m paddlerec.run -m ./config.yaml启动训练,并将结果输出到result文件中。然后启动transform.py整合数据,最后计算出正逆序指标: ``` diff --git a/models/match/readme.md b/models/match/readme.md index 440ad9796605aefe94c2f822daebb73062bfed24..4e614b48f9a81635a6a54334d32f8747dcd68e2d 100755 --- a/models/match/readme.md +++ b/models/match/readme.md @@ -53,6 +53,6 @@ python -m paddlerec.run -m models/contentunderstanding/match-pyramid/config.yaml | 数据集 | 模型 | 正逆序比 | map | | :------------------: | :--------------------: | :---------: |:---------: | -| zhidao | DSSM | 2.25 | -- | +| zhidao | DSSM | 2.75 | -- | | Letor07 | match-pyramid | -- | 0.42 | -| zhidao | multiview-simnet | 1.72 | -- | +| zhidao | multiview-simnet | 13.67 | -- |