PaddlePaddle / models — commit 8e4eebfc (unverified)
Authored Feb 04, 2020 by anpark; committed via GitHub on Feb 04, 2020.

add kdd2020-p3ac (#4238)

* update README
* fix monopoly info
* add kdd2020-p3ac
* add kdd2020-p3ac

Parent: a026656e

Showing 12 changed files with 1,667 additions and 7 deletions (+1667 −7).
Changed files:

- PaddleST/README.md (+1 −0)
- PaddleST/Research/CIKM2019-MONOPOLY/README.md (+1 −1)
- PaddleST/Research/CIKM2019-MONOPOLY/conf/house_price/house_price.local.template (+1 −1)
- PaddleST/Research/CIKM2019-MONOPOLY/nets/house_price/house_price.py (+8 −5)
- PaddleST/Research/KDD2020-P3AC/README.md (+78 −0)
- PaddleST/Research/KDD2020-P3AC/conf/poi_qac_personalized/poi_qac_personalized.local.conf.template (+342 −0)
- PaddleST/Research/KDD2020-P3AC/datasets/poi_qac_personalized/__init__.py (+0 −0)
- PaddleST/Research/KDD2020-P3AC/datasets/poi_qac_personalized/qac_personalized.py (+577 −0)
- PaddleST/Research/KDD2020-P3AC/docs/framework.png (+0 −0)
- PaddleST/Research/KDD2020-P3AC/nets/poi_qac_personalized/__init__.py (+0 −0)
- PaddleST/Research/KDD2020-P3AC/nets/poi_qac_personalized/qac_personalized.py (+659 −0)
- PaddleST/Research/KDD2020-P3AC/test/__init__.py (+0 −0)
PaddleST/README.md

```diff
@@ -19,3 +19,4 @@ The full list of frontier industrial projects:
 |应用项目|项目简介|开源地址|
 |----|----|----|
 ||||
```
PaddleST/Research/CIKM2019-MONOPOLY/README.md

```diff
@@ -29,7 +29,7 @@ We have conducted extensive experiments with the large-scale urban data of sever
 1. Install Paddle
 
-   This project requires Paddle Fluid 1.5.1 or later; see the [installation guide](http://www.paddlepaddle.org/#quick-start) for setup.
+   This project requires Paddle Fluid 1.6.1 or later; see the [installation guide](http://www.paddlepaddle.org/#quick-start) for setup.
 
 2. Download the code
```
PaddleST/Research/CIKM2019-MONOPOLY/conf/house_price/house_price.local.template

```diff
@@ -280,7 +280,7 @@ num_in_dimension: ${DEFAULT:num_in_dimension}
 num_out_dimension: ${DEFAULT:num_out_dimension}
 
 # Directory where the results are saved to
-eval_dir: ${Train:train_dir}/epochs
+eval_dir: ${Train:train_dir}/checkpoint_1
 
 # The number of samples in each batch
 batch_size: ${DEFAULT:eval_batch_size}
```
PaddleST/Research/CIKM2019-MONOPOLY/nets/house_price/house_price.py

```diff
@@ -77,8 +77,7 @@ class HousePrice(BaseNet):
                 act=act)
         return _fc
 
-    def pred_format(self, result):
+    def pred_format(self, result, **kwargs):
         """
         format pred output
         """
@@ -118,7 +117,7 @@ class HousePrice(BaseNet):
         max_house_num = FLAGS.max_house_num
         max_public_num = FLAGS.max_public_num
+        pred_keys = inputs.keys()
 
         #step1. get house self feature
         if FLAGS.with_house_attr:
             def _get_house_attr(name, attr_vec_size):
@@ -136,6 +135,10 @@ class HousePrice(BaseNet):
         else:
             #no house attr
             house_vec = fluid.layers.reshape(inputs["house_business"], [-1, self.city_info.business_num])
+            pred_keys.remove('house_wuye')
+            pred_keys.remove('house_kfs')
+            pred_keys.remove('house_age')
+            pred_keys.remove('house_lou')
 
         house_self = self.fc_fn(house_vec, 1, act='sigmoid', layer_name='house_self', FLAGS=FLAGS)
         house_self = fluid.layers.reshape(house_self, [-1, 1])
@@ -192,8 +195,8 @@ class HousePrice(BaseNet):
         net_output = {"debug_output": debug_output,
                       "model_output": model_output}
-        model_output['feeded_var_names'] = inputs.keys()
-        model_output['target_vars'] = [label, pred]
+        model_output['feeded_var_names'] = pred_keys
+        model_output['fetch_targets'] = [label, pred]
         model_output['loss'] = avg_cost
 
         #debug_output['pred'] = pred
```
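The last hunk renames the exported key from `target_vars` to `fetch_targets` and narrows the feed list to `pred_keys` (the inputs minus the unused house attributes). A minimal sketch of how such a dict is typically handed to Paddle Fluid 1.x when exporting an inference model is shown below; the helper name `export_inference_model` and the `exe`/`program` arguments are illustrative assumptions, not part of this commit.

```python
import paddle.fluid as fluid

def export_inference_model(model_output, exe, program, path="./infer_model"):
    """Sketch: export a deployable model from the dict assembled in HousePrice.net()."""
    fluid.io.save_inference_model(
        dirname=path,
        feeded_var_names=model_output['feeded_var_names'],  # pred_keys after this commit
        target_vars=model_output['fetch_targets'],           # [label, pred]
        executor=exe,
        main_program=program)
```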
PaddleST/Research/KDD2020-P3AC/README.md (new file, mode 100644)
# P3AC

## Introduction

TODO

![](docs/framework.png)

## Install Guide

### Environment setup

1. Install Paddle

   This project requires Paddle Fluid 1.6.1 or later; see the [installation guide](http://www.paddlepaddle.org/#quick-start) for setup.

2. Download the code

   Clone the repositories locally; this code depends on the [Paddle-EPEP framework](https://github.com/PaddlePaddle/epep).

   ```
   git clone https://github.com/PaddlePaddle/epep.git
   cd epep
   git clone https://github.com/PaddlePaddle/models.git
   ln -s models/PaddleST/Research/KDD2020-P3AC/conf/poi_qac_personalized conf/poi_qac_personalized
   ln -s models/PaddleST/Research/KDD2020-P3AC/datasets/poi_qac_personalized datasets/poi_qac_personalized
   ln -s models/PaddleST/Research/KDD2020-P3AC/nets/poi_qac_personalized nets/poi_qac_personalized
   ```

3. Dependencies

   Python 2.7 is required.

### Experiments

1. Data preparation

   TODO

   ```
   #script to download
   ```

2. Model training

   ```
   cp conf/poi_qac_personalized/poi_qac_personalized.local.conf.template conf/poi_qac_personalized/poi_qac_personalized.local.conf
   sh run.sh -c conf/poi_qac_personalized/poi_qac_personalized.local.conf -m train [ -g 0 ]
   ```

3. Model evaluation

   ```
   pred_gpu=$1
   mode=$2 #query, poi, eval
   if [ $# -lt 2 ]; then
       exit 1
   fi
   # Edit conf/poi_qac_personalized/poi_qac_personalized.local.conf.template and enable CUDA_VISIBLE_DEVICES: <pred_gpu>
   cp conf/poi_qac_personalized/poi_qac_personalized.local.conf.template conf/poi_qac_personalized/poi_qac_personalized.local.conf
   sed -i "s#<pred_gpu>#$pred_gpu#g" conf/poi_qac_personalized/poi_qac_personalized.local.conf
   sed -i "s#<mode>#$mode#g" conf/poi_qac_personalized/poi_qac_personalized.local.conf
   sh run.sh -c poi_qac_personalized.local -m predict 1>../tmp/$mode-pred$pred_gpu.out 2>../tmp/$mode-pred$pred_gpu.err
   ```

## Paper Download

Please feel free to review our paper :)

TODO

## Paper Citation

TODO
PaddleST/Research/KDD2020-P3AC/conf/poi_qac_personalized/poi_qac_personalized.local.conf.template (new file, mode 100644)

```ini
[DEFAULT]
sample_seed: 1234
# The value in `DEFAULT` section will be referenced by other sections.
# For convenience, we put the variables that change frequently here and
# let the other sections refer to them.
debug_mode: False
#reader: dataset | pyreader | async | datafeed | sync
#data_reader: dataset
dataset_mode: Memory
#data_reader: datafeed
data_reader: pyreader
py_reader_iterable: False
#model_type: lstm_net
model_type: cnn_net
vocab_size: 93896
#emb_dim: 200
emb_dim: 128
time_size: 28
tag_size: 371
fc_dim: 64
emb_lr: 1.0
base_lr: 0.001
margin: 0.35
window_size: 3
pooling_type: max
#activate: sigmoid
activate: None
use_attention: True
use_personal: True
max_seq_len: 128
prefix_word_id: True
#print_period: 200
#TODO personal_resident_drive + neg_only_sample
#query cityid trendency, poi tag/alias
#local-cpu | local-gpu | pserver-cpu | pserver-gpu | nccl2
platform: local-gpu
# Input settings
dataset_name: PoiQacPersonalized
CUDA_VISIBLE_DEVICES: 0,1,2,3
#CUDA_VISIBLE_DEVICES: <pred_gpu>
train_batch_size: 128
#train_batch_size: 2
eval_batch_size: 2
#file_list: ../tmp/data/poi/qac/train_data/part-00000
dataset_dir: ../tmp/data/poi/qac/train_data
#init_train_params: ../tmp/data/poi/qac/tencent_pretrain.words
tag_dict_path: None
qac_dict_path: None
kv_path: None
#qac_dict_path: ./datasets/poi_qac_personalized/qac_term.dict
#tag_dict_path: ./datasets/poi_qac_personalized/poi_tag.dict
#kv_path: ../tmp/data/poi/qac/kv
# Model settings
model_name: PoiQacPersonalized
preprocessing_name: None
#file_pattern: %s-part-*
file_pattern: part-
num_in_dimension: 3
num_out_dimension: 4
# Learning options
num_samples_train: 100
num_samples_eval: 10
max_number_of_steps: 155000
[Convert]
# The name of the dataset to convert
dataset_name: ${DEFAULT:dataset_name}
#dataset_dir: ${DEFAULT:dataset_dir}
dataset_dir: stream
# The output Records file name prefix.
dataset_split_name: train
# The number of Records per shard
num_per_shard: 100000
# The dimensions of net input vectors, it is just used by svm dataset
# which of input are sparse tensors now
num_in_dimension: ${DEFAULT:num_in_dimension}
# The output file name pattern with two placeholders ("%s" and "%d"),
# it must correspond to the glob `file_pattern' in Train and Evaluate
# config sections
#file_pattern: %s-part-%05d
file_pattern: part-
[Train]
#######################
# Dataset Configure #
#######################
# The name of the dataset to load
dataset_name: ${DEFAULT:dataset_name}
# The directory where the dataset files are stored
dataset_dir: ${DEFAULT:dataset_dir}
# dataset_split_name
dataset_split_name: train
batch_shuffle_size: 128
#log_exp or hinge
#loss_func: hinge
loss_func: log_exp
neg_sample_num: 5
reader_batch: True
drop_last_batch: False
# The glob pattern for data path, `file_pattern' must contain only one "%s"
# which is the placeholder for split name (such as 'train', 'validation')
file_pattern: ${DEFAULT:file_pattern}
# The file type text or record
file_type: record
# kv path, used in image_sim
kv_path: ${DEFAULT:kv_path}
# The number of input sample for training
num_samples: ${DEFAULT:num_samples_train}
# The number of parallel readers that read data from the dataset
num_readers: 2
# The number of threads used to create the batches
num_preprocessing_threads: 2
# Number of epochs from dataset source
num_epochs_input: 10
###########################
# Basic Train Configure #
###########################
# Directory where checkpoints and event logs are written to.
train_dir: ../tmp/model/poi/qac/save_model
# The max number of ckpt files to store variables
save_max_to_keep: 40
# The frequency with which the model is saved, in seconds.
save_model_secs: None
# The frequency with which the model is saved, in steps.
save_model_steps: 5000
# The name of the architecture to train
model_name: ${DEFAULT:model_name}
# The dimensions of net input vectors, it is just used by svm dataset
# which of input are sparse tensors now
num_in_dimension: ${DEFAULT:num_in_dimension}
# The dimensions of net output vector, it will be num of classes in image classify task
num_out_dimension: ${DEFAULT:num_out_dimension}
#####################################
# Training Optimization Configure #
#####################################
# The number of samples in each batch
batch_size: ${DEFAULT:train_batch_size}
# The maximum number of training steps
max_number_of_steps: ${DEFAULT:max_number_of_steps}
# The weight decay on the model weights
#weight_decay: 0.00000001
weight_decay: None
# The decay to use for the moving average. If left as None, then moving averages are not used
moving_average_decay: None
# ***************** learning rate options ***************** #
# Specifies how the learning rate is decayed. One of "fixed", "exponential" or "polynomial"
learning_rate_decay_type: fixed
# Learning rate decay factor
learning_rate_decay_factor: 0.1
# Proportion of training steps to perform linear learning rate warmup for
learning_rate_warmup_proportion: 0.1
init_learning_rate: 0
learning_rate_warmup_steps: 10000
# The minimal end learning rate used by a polynomial decay learning rate
end_learning_rate: 0.0001
# Number of epochs after which learning rate decays
num_epochs_per_decay: 10
# A boolean, whether or not it should cycle beyond decay_steps
learning_rate_polynomial_decay_cycle: False
# ******************* optimizer options ******************* #
# The name of the optimizer, one of the following:
# "adadelta", "adagrad", "adam", "ftrl", "momentum", "sgd" or "rmsprop"
#optimizer: weight_decay_adam
optimizer: adam
#optimizer: sgd
# Epsilon term for the optimizer, used for adadelta, adam, rmsprop
opt_epsilon: 1e-8
# conf for adadelta
# The decay rate for adadelta
adadelta_rho: 0.95
# Starting value for the AdaGrad accumulators
adagrad_initial_accumulator_value: 0.1
# conf for adam
# The exponential decay rate for the 1st moment estimates
adam_beta1: 0.9
# The exponential decay rate for the 2nd moment estimates
adam_beta2: 0.997
adam_weight_decay: 0.01
#adam_exclude_from_weight_decay: LayerNorm,layer_norm,bias
# conf for ftrl
# The learning rate power
ftrl_learning_rate_power: -0.1
# Starting value for the FTRL accumulators
ftrl_initial_accumulator_value: 0.1
# The FTRL l1 regularization strength
ftrl_l1: 0.0
# The FTRL l2 regularization strength
ftrl_l2: 0.01
# conf for momentum
# The momentum for the MomentumOptimizer and RMSPropOptimizer
momentum: 0.9
# conf for rmsprop
# Decay term for RMSProp
rmsprop_decay: 0.9
# Number of model clones to deploy
num_gpus: 3
#############################
# Log and Trace Configure #
#############################
# The frequency with which logs are print
log_every_n_steps: 100
# The frequency with which logs are trace.
trace_every_n_steps: 1
[Evaluate]
# process mode: pred, eval or export
#proc_name: eval
proc_name: pred
#data_reader: datafeed
py_reader_iterable: True
#platform: hadoop
platform: local-gpu
qac_dict_path: ./datasets/poi_qac_personalized/qac_term.dict
tag_dict_path: ./datasets/poi_qac_personalized/poi_tag.dict
#kv_path: ../tmp/data/poi/qac/kv
# The directory where the dataset files are stored
#file_list: ../tmp/x.bug
file_list: ../tmp/data/poi/qac/recall_data/<mode>/part-0<pred_gpu>
#file_list: ../tmp/data/poi/qac/ltr_data/<mode>/part-0<pred_gpu>
#dataset_dir: stream_record
# The directory where the model was written to or an absolute path to a checkpoint file
init_pretrain_model: ../tmp/model/poi/qac/save_model_logexp/checkpoint_125000
#init_pretrain_model: ../tmp/model/poi/qac/save_model_personal_logexp/checkpoint_125000
#init_pretrain_model: ../tmp/model/poi/qac/save_model_wordid_logexp/checkpoint_125000
#init_pretrain_model: ../tmp/model/poi/qac/save_model_personal_wordid_logexp/checkpoint_125000
#init_pretrain_model: ../tmp/model/poi/qac/save_model_attention_logexp/checkpoint_125000
#init_pretrain_model: ../tmp/model/poi/qac/save_model_attention_personal_logexp/checkpoint_125000
#init_pretrain_model: ../tmp/model/poi/qac/save_model_attention_wordid_logexp/checkpoint_125000
#init_pretrain_model: ../tmp/model/poi/qac/save_model_attention_personal_wordid_logexp/checkpoint_125000
model_type: cnn_net
fc_dim: 64
use_attention: False
use_personal: False
prefix_word_id: False
#dump_vec: query
#dump_vec: <mode>
dump_vec: eval
# The number of samples in each batch
#batch_size: ${DEFAULT:eval_batch_size}
batch_size: 1
# The file type text or record
#file_type: record
file_type: text
reader_batch: False
# only exectute evaluation once
eval_once: True
#######################
# Dataset Configure #
#######################
# The name of the dataset to load
dataset_name: ${DEFAULT:dataset_name}
# The name of the train/test split
dataset_split_name: validation
# The glob pattern for data path, `file_pattern' must contain only one "%s"
# which is the placeholder for split name (such as 'train', 'validation')
file_pattern: ${DEFAULT:file_pattern}
# The number of input sample for evaluation
num_samples: ${DEFAULT:num_samples_eval}
# The number of parallel readers that read data from the dataset
num_readers: 2
# The number of threads used to create the batches
num_preprocessing_threads: 1
# Number of epochs from dataset source
num_epochs_input: 1
# The name of the architecture to evaluate
model_name: ${DEFAULT:model_name}
# The dimensions of net input vectors, it is just used by svm dataset
# which of input are sparse tensors now
num_in_dimension: ${DEFAULT:num_in_dimension}
# The dimensions of net output vector, it will be num of classes in image classify task
num_out_dimension: ${DEFAULT:num_out_dimension}
# Directory where the results are saved to
eval_dir: ${Train:train_dir}/checkpoint_1
```
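The template references values across sections with `${Section:key}` placeholders, for example `${DEFAULT:train_batch_size}` and `${Train:train_dir}`. The EPEP framework resolves these through its own `LoadConfFile` action; as a rough illustration only, the same interpolation style can be reproduced with Python's standard `configparser` (a sketch assuming stock configparser semantics, not the framework's actual loader, which also targets Python 2.7):

```python
from configparser import ConfigParser, ExtendedInterpolation  # Python 3 illustration

def load_section(path, section):
    """Read one section of the template, resolving ${Section:key} references."""
    parser = ConfigParser(interpolation=ExtendedInterpolation())
    parser.read(path)
    return dict(parser[section])

# e.g. batch_size resolves to ${DEFAULT:train_batch_size} = 128
conf = load_section('conf/poi_qac_personalized/poi_qac_personalized.local.conf', 'Train')
print(conf['batch_size'], conf['train_dir'])
```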
PaddleST/Research/KDD2020-P3AC/datasets/poi_qac_personalized/__init__.py (new empty file, mode 100644)
PaddleST/Research/KDD2020-P3AC/datasets/poi_qac_personalized/qac_personalized.py (new file, mode 100644)

```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""
Specify the brief poi_qac_personalized.py
"""

import os
import sys
import re
import time
import numpy as np
import random

import paddle.fluid as fluid

from datasets.base_dataset import BaseDataset

reload(sys)
sys.setdefaultencoding('gb18030')

base_rule = re.compile("[\1\2]")


class PoiQacPersonalized(BaseDataset):
    """
    PoiQacPersonalized dataset
    """

    def __init__(self, flags):
        super(PoiQacPersonalized, self).__init__(flags)
        self.inited_dict = False

    def parse_context(self, inputs):
        """
        provide input context

        set inputs_kv: please set key as the same as layer.data.name
        notice:
        (1) If user defined "inputs key" is different from layer.data.name,
            the frame will rewrite "inputs key" with layer.data.name
        (2) The param "inputs" will be passed to user defined nets class through
            the nets class interface function: net(self, FLAGS, inputs)
        """
        if self._flags.use_personal:
            #inputs['user_loc_geoid'] = fluid.layers.data(name="user_loc_geoid", shape=[40], dtype="int64", lod_level=0) #from clk poi
            #inputs['user_bound_geoid'] = fluid.layers.data(name="user_bound_geoid", shape=[40], dtype="int64", lod_level=0) #from clk poi
            #inputs['user_time_id'] = fluid.layers.data(name="user_time_geoid", shape=[1], dtype="int64", lod_level=1) #from clk poi
            inputs['user_clk_geoid'] = fluid.layers.data(name="user_clk_geoid", shape=[40], dtype="int64", lod_level=0) #from clk poi
            inputs['user_tag_id'] = fluid.layers.data(name="user_tag_id", shape=[1], dtype="int64", lod_level=1) #from clk poi
            inputs['user_resident_geoid'] = fluid.layers.data(name="user_resident_geoid", shape=[40], dtype="int64", lod_level=0) #home, company
            inputs['user_navi_drive'] = fluid.layers.data(name="user_navi_drive", shape=[1], dtype="int64", lod_level=0) #driver or not

        inputs['prefix_letter_id'] = fluid.layers.data(name="prefix_letter_id", shape=[1], dtype="int64", lod_level=1)
        if self._flags.prefix_word_id:
            inputs['prefix_word_id'] = fluid.layers.data(name="prefix_word_id", shape=[1], dtype="int64", lod_level=1)
        inputs['prefix_loc_geoid'] = fluid.layers.data(name="prefix_loc_geoid", shape=[40], dtype="int64", lod_level=0)
        if self._flags.use_personal:
            inputs['prefix_time_id'] = fluid.layers.data(name="prefix_time_id", shape=[1], dtype="int64", lod_level=1)

        inputs['pos_name_letter_id'] = fluid.layers.data(name="pos_name_letter_id", shape=[1], dtype="int64", lod_level=1)
        inputs['pos_name_word_id'] = fluid.layers.data(name="pos_name_word_id", shape=[1], dtype="int64", lod_level=1)
        inputs['pos_addr_letter_id'] = fluid.layers.data(name="pos_addr_letter_id", shape=[1], dtype="int64", lod_level=1)
        inputs['pos_addr_word_id'] = fluid.layers.data(name="pos_addr_word_id", shape=[1], dtype="int64", lod_level=1)
        inputs['pos_loc_geoid'] = fluid.layers.data(name="pos_loc_geoid", shape=[40], dtype="int64", lod_level=0)
        if self._flags.use_personal:
            inputs['pos_tag_id'] = fluid.layers.data(name="pos_tag_id", shape=[1], dtype="int64", lod_level=1)

        if self.is_training:
            inputs['neg_name_letter_id'] = fluid.layers.data(name="neg_name_letter_id", shape=[1], dtype="int64", lod_level=1)
            inputs['neg_name_word_id'] = fluid.layers.data(name="neg_name_word_id", shape=[1], dtype="int64", lod_level=1)
            inputs['neg_addr_letter_id'] = fluid.layers.data(name="neg_addr_letter_id", shape=[1], dtype="int64", lod_level=1)
            inputs['neg_addr_word_id'] = fluid.layers.data(name="neg_addr_word_id", shape=[1], dtype="int64", lod_level=1)
            inputs['neg_loc_geoid'] = fluid.layers.data(name="neg_loc_geoid", shape=[40], dtype="int64", lod_level=0)
            if self._flags.use_personal:
                inputs['neg_tag_id'] = fluid.layers.data(name="neg_tag_id", shape=[1], dtype="int64", lod_level=1)
        else:
            #for predict label
            inputs['label'] = fluid.layers.data(name="label", shape=[1], dtype="int64", lod_level=0)

        context = {"inputs": inputs}
        #set debug list, print info during training
        #debug_list = [key for key in inputs]
        #context["debug_list"] = ["prefix_ids", "label"]
        return context

    def _init_dict(self):
        """
        init dict
        """
        if self.inited_dict:
            return
        if self._flags.platform in ('local-gpu', 'pserver-gpu', 'slurm'):
            gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
            self.place = fluid.CUDAPlace(gpu_id)
        else:
            self.place = fluid.CPUPlace()

        self.term_dict = {}
        if self._flags.qac_dict_path is not None:
            with open(self._flags.qac_dict_path, 'r') as f:
                for line in f:
                    term, term_id = line.strip('\r\n').split('\t')
                    self.term_dict[term] = int(term_id)

        self.tag_info = {}
        if self._flags.tag_dict_path is not None:
            with open(self._flags.tag_dict_path, 'r') as f:
                for line in f:
                    tag, level, tid = line.strip('\r\n').split('\t')
                    self.tag_info[tag] = map(int, tid.split(','))

        self.user_kv = None
        self.poi_kv = None
        if self._flags.kv_path is not None:
            self.poi_kv = {}
            with open(self._flags.kv_path + "/sug_raw.dat", "r") as f:
                for line in f:
                    pid, val = line.strip('\r\n').split('\t', 1)
                    self.poi_kv[pid] = val
            self.user_kv = {}
            with open(self._flags.kv_path + "/user_profile.dat", "r") as f:
                for line in f:
                    uid, val = line.strip('\r\n').split('\t', 1)
                    self.user_kv[uid] = val
            sys.stderr.write("load user kv:%s\n" % self._flags.kv_path)

        self.inited_dict = True
        sys.stderr.write("loaded term dict:%s, tag_dict:%s\n" % (len(self.term_dict), len(self.tag_info)))

    def _get_time_id(self, ts):
        """
        get time id:0-27
        """
        ts_struct = time.localtime(ts)
        week = ts_struct[6]
        hour = ts_struct[3]
        base = 0
        if hour >= 0 and hour < 6:
            base = 0
        elif hour >= 6 and hour < 12:
            base = 1
        elif hour >= 12 and hour < 18:
            base = 2
        else:
            base = 3
        final = week * 4 + base
        return final

    def _pad_batch_data(self, insts, pad_idx, return_max_len=True, return_num_token=False):
        """
        Pad the instances to the max sequence length in batch, and generate the
        corresponding position data and attention bias.
        """
        return_list = []
        max_len = max(len(inst) for inst in insts)
        # Any token included in dict can be used to pad, since the paddings' loss
        # will be masked out by weights and make no effect on parameter gradients.
        inst_data = np.array([inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
        return_list += [inst_data.astype("int64").reshape([-1, 1])]
        if return_max_len:
            return_list += [max_len]
        if return_num_token:
            num_token = 0
            for inst in insts:
                num_token += len(inst)
            return_list += [num_token]
        return return_list if len(return_list) > 1 else return_list[0]

    def _get_tagid(self, tag_str):
        if len(tag_str.strip()) < 1:
            return []
        tags = set()
        for t in tag_str.split():
            if ':' in t:
                t = t.split(':')[0]
            t = t.lower()
            if t in self.tag_info:
                tags.update(self.tag_info[t])
        return list(tags)

    def _get_ids(self, seg_info):
        #phraseseg, basicseg = seg_info
        if len(seg_info) < 2:
            return [0], [0]
        _, bt = [x.split('\3') for x in seg_info]
        rq = "".join(bt)
        bl = [t.encode('gb18030') for t in rq.decode('gb18030')]
        letter_ids = []
        for t in bl:
            letter_ids.append(self.term_dict.get(t.lower(), 1))
            if len(letter_ids) >= self._flags.max_seq_len:
                break
        word_ids = []
        for t in bt:
            word_ids.append(self.term_dict.get(t.lower(), 1))
            if len(word_ids) >= self._flags.max_seq_len:
                break
        return letter_ids, word_ids

    def _get_poi_ids(self, poi_str, max_num=0):
        if len(poi_str) < 1:
            return []
        ids = []
        all_p = poi_str.split('\1')
        pidx = range(0, len(all_p))
        if max_num > 0:
            #neg sample: last 10 is negative sampling
            if len(all_p) > max_num:
                neg_s_idx = len(all_p) - 10
                pidx = [1, 2] + random.sample(pidx[3:neg_s_idx], max_num - 13) + pidx[neg_s_idx:]
            else:
                pidx = pidx[1:]
        bids = set()
        for x in pidx:
            poi_seg = all_p[x].split('\2')
            tagid = [0]
            if len(poi_seg) >= 9:
                #name, uid, index, name_lid, name_wid, addr_lid, addr_wid, geohash, tagid
                bid = poi_seg[1]
                name_letter_id = map(int, poi_seg[3].split())[:self._flags.max_seq_len]
                name_word_id = map(int, poi_seg[4].split())[:self._flags.max_seq_len]
                addr_letter_id = map(int, poi_seg[5].split())[:self._flags.max_seq_len]
                addr_word_id = map(int, poi_seg[6].split())[:self._flags.max_seq_len]
                ghid = map(int, poi_seg[7].split(','))
                if len(poi_seg[8]) > 0:
                    tagid = map(int, poi_seg[8].split(','))
            else:
                #raw_text: uid, name, addr, xy, tag, alias
                bid = poi_seg[0]
                name_letter_id, name_word_id = self._get_ids(poi_seg[1])
                addr_letter_id, addr_word_id = self._get_ids(poi_seg[2])
                ghid = map(int, poi_seg[3].split(','))
                if len(poi_seg[4]) > 0:
                    tagid = map(int, poi_seg[4].split(','))

            if not self.is_training and name_letter_id == [0]:
                continue  # empty name
            if bid in bids:
                continue
            bids.add(bid)
            ids.append([name_letter_id, name_word_id, addr_letter_id, addr_word_id, ghid, tagid])
        return ids

    def _get_user_ids(self, cuid, user_str):
        if self.user_kv:
            if cuid in self.user_kv:
                val = self.user_kv[cuid]
                drive_conf, clk_p, res_p = val.split('\t')
            else:
                return []
        else:
            if len(user_str) < 1:
                return []
            drive_conf, clk_p, res_p = user_str.split('\1')

        ids = []
        conf1, conf2 = drive_conf.split('\2')
        is_driver = 0
        if float(conf1) > 0.5 or float(conf2) > 1.5:
            is_driver = 1

        user_clk_geoid = [0] * 40
        user_tag_id = set()
        if len(clk_p) > 0:
            if self.user_kv:
                for p in clk_p.split('\1'):
                    bid, time, loc, bound = p.split('\2')
                    if bid in self.poi_kv:
                        v = self.poi_kv[bid]
                        v = base_rule.sub("", v)
                        info = v.split('\t')
                        #name, addr, ghid, tag, alias
                        ghid = map(int, info[2].split(','))
                        for i in range(len(user_clk_geoid)):
                            user_clk_geoid[i] = user_clk_geoid[i] | ghid[i]
                        user_tag_id.update(self._get_tagid(info[4]))
            else:
                for p in clk_p.split('\2'):
                    bid, gh, tags = p.split('\3')
                    ghid = map(int, gh.split(','))
                    for i in range(len(user_clk_geoid)):
                        user_clk_geoid[i] = user_clk_geoid[i] | ghid[i]
                    if len(tags) > 0:
                        user_tag_id.update(tags.split(','))
        if len(user_tag_id) < 1:
            user_tag_id = [0]
        user_tag_id = map(int, list(user_tag_id))
        ids.append(user_clk_geoid)
        ids.append(user_tag_id)

        user_res_geoid = [0] * 40
        if len(res_p) > 0:
            if self.user_kv:
                for p in res_p.split('\1'):
                    bid, conf = p.split('\2')
                    if bid in self.poi_kv:
                        v = self.poi_kv[bid]
                        v = base_rule.sub("", v)
                        info = v.split('\t')
                        #name, addr, ghid, tag, alias
                        ghid = map(int, info[2].split(','))
                        for i in range(len(user_res_geoid)):
                            user_res_geoid[i] = user_res_geoid[i] | ghid[i]
            else:
                for p in res_p.split('\2'):
                    bid, gh, conf = p.split('\3')
                    ghid = map(int, gh.split(','))
                    for i in range(len(user_res_geoid)):
                        user_res_geoid[i] = user_res_geoid[i] | ghid[i]
        ids.append(user_res_geoid)
        ids.append([is_driver])
        return ids

    def parse_batch(self, data_gen):
        """
        reader_batch must be true: only for train & loss_func is log_exp, other use parse_oneline
        pos : neg = 1 : N
        """
        batch_data = {}

        def _get_lod(k):
            #sys.stderr.write("%s\t%s\t%s\n" % (k, " ".join(map(str, batch_data[k][0])),
            #        " ".join(map(str, batch_data[k][1])) ))
            return fluid.create_lod_tensor(np.array(batch_data[k][0]).reshape([-1, 1]),
                    [batch_data[k][1]], self.place)

        keys = None
        for line in data_gen():
            for s in self.parse_oneline(line):
                for k, v in s:
                    if k not in batch_data:
                        batch_data[k] = [[], []]
                    if not isinstance(v[0], list):
                        v = [v]
                    #pos 1 to N
                    for j in v:
                        batch_data[k][0].extend(j)
                        batch_data[k][1].append(len(j))
                if keys is None:
                    keys = [k for k, _ in s]
                if len(batch_data[keys[0]][1]) == self._flags.batch_size:
                    yield [(k, _get_lod(k)) for k in keys]
                    batch_data = {}
        if not self._flags.drop_last_batch and len(batch_data) != 0:
            yield [(k, _get_lod(k)) for k in keys]

    def parse_oneline(self, line):
        """
        datareader interface
        """
        self._init_dict()
        qid, user, prefix, pos_poi, neg_poi = line.strip("\r\n").split("\t")
        cuid, time, loc_cityid, bound_cityid, loc_gh, bound_gh = qid.split('_')

        #step1
        user_input = []
        if self._flags.use_personal:
            user_ids = self._get_user_ids(cuid, user)
            if len(user_ids) < 1:
                user_ids = [[0] * 40, [0], [0] * 40, [0]]
            user_input = [("user_clk_geoid", user_ids[0]), \
                    ("user_tag_id", user_ids[1]), \
                    ("user_resident_geoid", user_ids[2]), \
                    ("user_navi_drive", user_ids[3])]

        #step2
        prefix_seg = prefix.split('\2')
        prefix_time_id = self._get_time_id(int(time))
        prefix_loc_geoid = [0] * 40
        if len(prefix_seg) >= 4:
            #query, letterid, wordid, ghid, poslen, neglen
            prefix_letter_id = map(int, prefix_seg[1].split())[:self._flags.max_seq_len]
            prefix_word_id = map(int, prefix_seg[2].split())[:self._flags.max_seq_len]
            loc_gh, bound_gh = prefix_seg[3].split('_')
            ghid = map(int, loc_gh.split(','))
            for i in range(len(prefix_loc_geoid)):
                prefix_loc_geoid[i] = prefix_loc_geoid[i] | ghid[i]
            ghid = map(int, bound_gh.split(','))
            for i in range(len(prefix_loc_geoid)):
                prefix_loc_geoid[i] = prefix_loc_geoid[i] | ghid[i]
        else:
            #raw text
            prefix_letter_id, prefix_word_id = self._get_ids(prefix)
            ghid = map(int, loc_gh.split(','))
            for i in range(len(prefix_loc_geoid)):
                prefix_loc_geoid[i] = prefix_loc_geoid[i] | ghid[i]
            ghid = map(int, bound_gh.split(','))
            for i in range(len(prefix_loc_geoid)):
                prefix_loc_geoid[i] = prefix_loc_geoid[i] | ghid[i]

        prefix_input = [("prefix_letter_id", prefix_letter_id), \
                ("prefix_loc_geoid", prefix_loc_geoid)]
        if self._flags.prefix_word_id:
            prefix_input.insert(1, ("prefix_word_id", prefix_word_id))
        if self._flags.use_personal:
            prefix_input.append(("prefix_time_id", [prefix_time_id]))

        #step3
        pos_ids = self._get_poi_ids(pos_poi)
        pos_num = len(pos_ids)
        max_num = 0
        if self.is_training:
            max_num = max(20, self._flags.neg_sample_num)  #last 10 is neg sample
        neg_ids = self._get_poi_ids(neg_poi, max_num=max_num)
        #if not train, add all pois
        if not self.is_training:
            pos_ids.extend(neg_ids)
            if len(pos_ids) < 1:
                pos_ids.append([[0], [0], [0], [0], [0] * 40, [0]])

        #step4
        idx = 0
        for pos_id in pos_ids:
            pos_input = [("pos_name_letter_id", pos_id[0]), \
                    ("pos_name_word_id", pos_id[1]), \
                    ("pos_addr_letter_id", pos_id[2]), \
                    ("pos_addr_word_id", pos_id[3]), \
                    ("pos_loc_geoid", pos_id[4])]
            if self._flags.use_personal:
                pos_input.append(("pos_tag_id", pos_id[5]))

            if self.is_training:
                if len(neg_ids) > self._flags.neg_sample_num:
                    #Noise Contrastive Estimation
                    #if self._flags.neg_sample_num > 3:
                    #    nids_sample = neg_ids[:3]
                    nids_sample = random.sample(neg_ids, self._flags.neg_sample_num)
                else:
                    nids_sample = neg_ids

                if self._flags.reader_batch:
                    if len(nids_sample) != self._flags.neg_sample_num:
                        continue
                    neg_batch = [[], [], [], [], [], []]
                    for neg_id in nids_sample:
                        for i in range(len(neg_batch)):
                            neg_batch[i].append(neg_id[i])
                    neg_input = [("neg_name_letter_id", neg_batch[0]), \
                            ("neg_name_word_id", neg_batch[1]), \
                            ("neg_addr_letter_id", neg_batch[2]), \
                            ("neg_addr_word_id", neg_batch[3]), \
                            ("neg_loc_geoid", neg_batch[4])]
                    if self._flags.use_personal:
                        neg_input.append(("neg_tag_id", neg_batch[5]))
                    yield user_input + prefix_input + pos_input + neg_input
                else:
                    for neg_id in nids_sample:
                        neg_input = [("neg_name_letter_id", neg_id[0]), \
                                ("neg_name_word_id", neg_id[1]), \
                                ("neg_addr_letter_id", neg_id[2]), \
                                ("neg_addr_word_id", neg_id[3]), \
                                ("neg_loc_geoid", neg_id[4])]
                        if self._flags.use_personal:
                            neg_input.append(("neg_tag_id", neg_id[5]))
                        yield user_input + prefix_input + pos_input + neg_input
            else:
                label = int(idx < pos_num)
                yield user_input + prefix_input + pos_input + [("label", [label])]
            idx += 1


if __name__ == '__main__':
    from utils import flags
    from utils.load_conf_file import LoadConfFile
    FLAGS = flags.FLAGS
    flags.DEFINE_custom("conf_file", "./conf/test/test.conf",
            "conf file", action=LoadConfFile, sec_name="Train")

    sys.stderr.write('----------- Configuration Arguments -----------\n')
    for arg, value in sorted(flags.get_flags_dict().items()):
        sys.stderr.write('%s: %s\n' % (arg, value))
    sys.stderr.write('------------------------------------------------\n')

    dataset_instance = PoiQacPersonalized(FLAGS)

    def _dump_vec(data, name):
        print("%s\t%s" % (name, " ".join(map(str, np.array(data)))))

    def _data_generator():
        """
        stdin sample generator: read from stdin
        """
        for line in sys.stdin:
            if not line.strip():
                continue
            yield line

    if FLAGS.reader_batch:
        for sample in dataset_instance.parse_batch(_data_generator):
            _dump_vec(sample[0][1], 'user_clk_geoid')
            _dump_vec(sample[1][1], 'user_tag_id')
            _dump_vec(sample[2][1], 'user_resident_geoid')
            _dump_vec(sample[3][1], 'user_navi_drive')
            _dump_vec(sample[4][1], 'prefix_letter_id')
            _dump_vec(sample[5][1], 'prefix_loc_geoid')
            _dump_vec(sample[6][1], 'prefix_time_id')
            _dump_vec(sample[7][1], 'pos_name_letter_id')
            _dump_vec(sample[10][1], 'pos_addr_word_id')
            _dump_vec(sample[11][1], 'pos_loc_geoid')
            _dump_vec(sample[12][1], 'pos_tag_id')
            _dump_vec(sample[13][1], 'neg_name_letter_id or label')
    else:
        for line in sys.stdin:
            for sample in dataset_instance.parse_oneline(line):
                _dump_vec(sample[0][1], 'user_clk_geoid')
                _dump_vec(sample[1][1], 'user_tag_id')
                _dump_vec(sample[2][1], 'user_resident_geoid')
                _dump_vec(sample[3][1], 'user_navi_drive')
                _dump_vec(sample[4][1], 'prefix_letter_id')
                _dump_vec(sample[5][1], 'prefix_loc_geoid')
                _dump_vec(sample[6][1], 'prefix_time_id')
                _dump_vec(sample[7][1], 'pos_name_letter_id')
                _dump_vec(sample[10][1], 'pos_addr_word_id')
                _dump_vec(sample[11][1], 'pos_loc_geoid')
                _dump_vec(sample[12][1], 'pos_tag_id')
                _dump_vec(sample[13][1], 'neg_name_letter_id or label')
```
PaddleST/Research/KDD2020-P3AC/docs/framework.png (new binary file, 1.2 MB, mode 100644)
PaddleST/Research/KDD2020-P3AC/nets/poi_qac_personalized/__init__.py (new empty file, mode 100644)
PaddleST/Research/KDD2020-P3AC/nets/poi_qac_personalized/qac_personalized.py (new file, mode 100644)

```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""
Specify the brief poi_qac_personalized.py
"""

import math
import numpy as np
import logging
import collections

import paddle.fluid as fluid

from nets.base_net import BaseNet


def ffn(input, d_hid, d_size, name="ffn"):
    """
    Position-wise Feed-Forward Network
    """
    hidden = fluid.layers.fc(input=input,
            size=d_hid,
            num_flatten_dims=1,
            param_attr=fluid.ParamAttr(name=name + '_innerfc_weight'),
            bias_attr=fluid.ParamAttr(name=name + '_innerfc_bias',
                    initializer=fluid.initializer.Constant(0.)),
            act="leaky_relu")
    out = fluid.layers.fc(input=hidden,
            size=d_size,
            num_flatten_dims=1,
            param_attr=fluid.ParamAttr(name=name + '_outerfc_weight'),
            bias_attr=fluid.ParamAttr(name=name + '_outerfc_bias',
                    initializer=fluid.initializer.Constant(0.)))
    return out


def dot_product_attention(query, key, value, d_key, q_mask=None, k_mask=None, dropout_rate=None):
    """
    Args:
        query: a tensor with shape [batch, Q_time, Q_dimension]
        key: a tensor with shape [batch, time, K_dimension]
        value: a tensor with shape [batch, time, V_dimension]
        q_lengths: a tensor with shape [batch]
        k_lengths: a tensor with shape [batch]
    Returns:
        a tensor with shape [batch, query_time, value_dimension]
    Raises:
        AssertionError: if Q_dimension not equal to K_dimension when attention
            type is dot.
    """
    logits = fluid.layers.matmul(x=query, y=key, transpose_y=True, alpha=d_key ** (-0.5))
    if (q_mask is not None) and (k_mask is not None):
        mask = fluid.layers.matmul(x=q_mask, y=k_mask, transpose_y=True)
        another_mask = fluid.layers.scale(mask, scale=float(2 ** 32 - 1), bias=float(-1),
                bias_after_scale=False)
        logits = mask * logits + another_mask
    attention = fluid.layers.softmax(logits)
    if dropout_rate:
        attention = fluid.layers.dropout(input=attention, dropout_prob=dropout_rate, is_test=False)
    atten_out = fluid.layers.matmul(x=attention, y=value)
    return atten_out


def safe_cosine_sim(x, y):
    """
    fluid.layers.cos_sim maybe nan
    avoid nan
    """
    l2x = fluid.layers.l2_normalize(x, axis=-1)
    l2y = fluid.layers.l2_normalize(y, axis=-1)
    cos = fluid.layers.reduce_sum(l2x * l2y, dim=1, keep_dim=True)
    return cos


def loss_neg_log_of_pos(pos_score, neg_score_n, gama=5.0):
    '''
    pos_score: batch_size x 1
    neg_score_n: batch_size x n
    '''
    # n x batch_size
    neg_score_n = fluid.layers.transpose(neg_score_n, [1, 0])
    # 1 x batch_size
    pos_score = fluid.layers.reshape(pos_score, [1, -1])

    exp_pos_score = fluid.layers.exp(pos_score * gama)
    exp_neg_score_n = fluid.layers.exp(neg_score_n * gama)

    ## (n+1) x batch_size
    pos_neg_score = fluid.layers.concat([exp_pos_score, exp_neg_score_n], axis=0)
    ## 1 x batch_size
    exp_sum = fluid.layers.reduce_sum(pos_neg_score, dim=0, keep_dim=True)
    ## 1 x batch_size
    loss = -1.0 * fluid.layers.log(exp_pos_score / exp_sum)
    # batch_size
    loss = fluid.layers.reshape(loss, [-1, 1])
    #return [loss, exp_pos_score, exp_neg_score_n, pos_neg_score, exp_sum]
    return loss


def loss_pairwise_hinge(pos, neg, margin=0.8):
    """
    pairwise
    """
    loss_part1 = fluid.layers.elementwise_sub(
            fluid.layers.fill_constant_batch_size_like(input=pos, shape=[-1, 1],
                    value=margin, dtype='float32'), pos)
    loss_part2 = fluid.layers.elementwise_add(loss_part1, neg)
    loss_part3 = fluid.layers.elementwise_max(
            fluid.layers.fill_constant_batch_size_like(input=loss_part2, shape=[-1, 1],
                    value=0.0, dtype='float32'), loss_part2)
    return loss_part3


class PoiQacPersonalized(BaseNet):
    """
    This module provide nets for poi classification
    """

    def __init__(self, FLAGS):
        super(PoiQacPersonalized, self).__init__(FLAGS)
        self.hid_dim = 128

    def net(self, inputs):
        """
        PoiQacPersonalized interface
        """
        # debug output info during training
        debug_output = collections.OrderedDict()
        model_output = {}
        net_output = {"debug_output": debug_output, "model_output": model_output}

        user_input_keys = ['user_clk_geoid', 'user_tag_id', 'user_resident_geoid', 'user_navi_drive']
        pred_input_keys = ['prefix_letter_id', 'prefix_loc_geoid', 'pos_name_letter_id',
                'pos_name_word_id', 'pos_addr_letter_id', 'pos_addr_word_id', 'pos_loc_geoid']
        query_key_num = 2
        if self._flags.use_personal:
            pred_input_keys.insert(2, 'prefix_time_id')
            pred_input_keys.append('pos_tag_id')
            query_key_num += 2
            if self._flags.prefix_word_id:
                pred_input_keys.insert(1, 'prefix_word_id')
                query_key_num += 1
            pred_input_keys = user_input_keys + pred_input_keys
            query_key_num += len(user_input_keys)
        elif self._flags.prefix_word_id:
            pred_input_keys.insert(1, 'prefix_word_id')
            query_key_num += 1

        #for p in pred_input_keys:
        #    debug_output[p] = inputs[p]

        prefix_vec, prefix_pool = self._get_query_vec(inputs)
        pos_vec, pos_pool = self._get_poi_vec(inputs, 'pos')
        pos_score = safe_cosine_sim(pos_vec, prefix_vec)
        #fluid.layers.Print(pos_score, summarize=10000)
        if self.is_training:
            neg_vec, neg_pool = self._get_poi_vec(inputs, 'neg')
            if self._flags.loss_func == 'log_exp':
                neg_vec = fluid.layers.reshape(neg_vec, [-1, self._flags.fc_dim])
                prefix_expand = fluid.layers.reshape(fluid.layers.expand(prefix_vec,
                        [1, self._flags.neg_sample_num]), [-1, self._flags.fc_dim])
                neg_score = safe_cosine_sim(neg_vec, prefix_expand)
                cost = loss_neg_log_of_pos(pos_score, fluid.layers.reshape(neg_score,
                        [-1, self._flags.neg_sample_num]), 15)
            else:
                neg_score = safe_cosine_sim(neg_vec, prefix_vec)
                cost = loss_pairwise_hinge(pos_score, neg_score, self._flags.margin)
            #debug_output["pos_score"] = pos_score
            #debug_output["neg_score"] = neg_score
            #debug_output['prefix_pool'] = prefix_pool
            #debug_output['pos_pool'] = pos_pool
            #debug_output['neg_pool'] = neg_pool
            loss = fluid.layers.mean(x=cost)
            if self._flags.init_learning_rate > 0:
                # define the optimizer
                #d_model = 1 / (warmup_steps * (learning_rate ** 2))
                with fluid.default_main_program()._lr_schedule_guard():
                    learning_rate = fluid.layers.learning_rate_scheduler.noam_decay(
                            self._flags.emb_dim,
                            self._flags.learning_rate_warmup_steps) * self._flags.init_learning_rate
                optimizer = fluid.optimizer.AdamOptimizer(learning_rate=learning_rate,
                        beta1=self._flags.adam_beta1,
                        beta2=self._flags.adam_beta2,
                        epsilon=self._flags.opt_epsilon)
                logging.info("use noam_decay learning_rate_scheduler for optimizer.")
                net_output["optimizer"] = optimizer
            net_output["loss"] = loss
            model_output['fetch_targets'] = [inputs["prefix_letter_id"], pos_score]
        else:
            if self._flags.dump_vec == "query":
                model_output['fetch_targets'] = [prefix_vec]
                pred_input_keys = pred_input_keys[:query_key_num]
            elif self._flags.dump_vec == "poi":
                model_output['fetch_targets'] = [prefix_vec, pos_score, pos_vec]
            else:
                model_output['fetch_targets'] = [inputs["prefix_letter_id"], pos_score, inputs["label"]]

        model_output['feeded_var_names'] = pred_input_keys
        return net_output

    def _get_query_vec(self, inputs):
        """
        get query & user vec
        """
        if self._flags.use_personal:
            #user_tag_id
            #embedding layer
            tag_emb = fluid.layers.embedding(input=inputs['user_tag_id'], is_sparse=True,
                    size=[self._flags.tag_size, self._flags.emb_dim],
                    param_attr=fluid.ParamAttr(name="tagid_embedding", learning_rate=self._flags.emb_lr),
                    padding_idx=0)
            tag_emb = fluid.layers.sequence_pool(tag_emb, pool_type="sum")
            user_clk_geoid = fluid.layers.reshape(
                    fluid.layers.cast(inputs['user_clk_geoid'], dtype="float32"), [-1, 40])
            user_resident_geoid = fluid.layers.reshape(
                    fluid.layers.cast(inputs['user_resident_geoid'], dtype="float32"), [-1, 40])
            user_profile = fluid.layers.cast(inputs['user_navi_drive'], dtype="float32")
            user_pool = fluid.layers.concat([tag_emb, user_clk_geoid, user_resident_geoid,
                    user_profile], axis=1)
            #fc layer
            user_vec = fluid.layers.fc(input=user_pool, size=self._flags.emb_dim, act="leaky_relu",
                    param_attr=fluid.ParamAttr(name='user_fc_weight'),
                    bias_attr=fluid.ParamAttr(name='user_fc_bias'))
            #fluid.layers.Print(user_vec)

        loc_vec = fluid.layers.reshape(
                fluid.layers.cast(x=inputs['prefix_loc_geoid'], dtype="float32"), [-1, 40])

        if self._flags.model_type == "bilstm_net":
            network = self.bilstm_net
        elif self._flags.model_type == "bow_net":
            network = self.bow_net
        elif self._flags.model_type == "cnn_net":
            network = self.cnn_net
        elif self._flags.model_type == "lstm_net":
            network = self.lstm_net
        elif self._flags.model_type == "gru_net":
            network = self.gru_net
        else:
            raise ValueError("Unknown network type!")

        prefix_letter_pool = network(inputs["prefix_letter_id"], "wordid_embedding",
                self._flags.vocab_size, self._flags.emb_dim, hid_dim=self.hid_dim, fc_dim=0,
                emb_lr=self._flags.emb_lr)
        if self._flags.use_attention:
            #max-pooling
            prefix_letter_pool = fluid.layers.sequence_pool(prefix_letter_pool, pool_type="max")
        prefix_vec = prefix_letter_pool
        if self._flags.prefix_word_id:
            prefix_word_pool = network(inputs["prefix_word_id"], "wordid_embedding",
                    self._flags.vocab_size, self._flags.emb_dim, hid_dim=self.hid_dim, fc_dim=0,
                    emb_lr=self._flags.emb_lr)
            if self._flags.use_attention:
                #max-pooling
                prefix_word_pool = fluid.layers.sequence_pool(prefix_word_pool, pool_type="max")
            prefix_pool = fluid.layers.concat([prefix_letter_pool, prefix_word_pool], axis=1)
            prefix_vec = fluid.layers.fc(input=prefix_pool, size=self.hid_dim, act="leaky_relu",
                    param_attr=fluid.ParamAttr(name='prefix_fc_weight'),
                    bias_attr=fluid.ParamAttr(name='prefix_fc_bias'))
        #vector layer
        #fluid.layers.Print(inputs["prefix_letter_id"])
        #fluid.layers.Print(inputs["prefix_word_id"])
        #fluid.layers.Print(prefix_vec)
        if self._flags.use_personal:
            #prefix_time_id
            time_emb = fluid.layers.embedding(input=inputs['prefix_time_id'], is_sparse=True,
                    size=[self._flags.time_size, self._flags.emb_dim],
                    param_attr=fluid.ParamAttr(name="timeid_embedding", learning_rate=self._flags.emb_lr))
            time_emb = fluid.layers.sequence_pool(time_emb, pool_type="sum")
            context_pool = fluid.layers.concat([prefix_vec, loc_vec, time_emb, user_vec], axis=1)
        else:
            context_pool = fluid.layers.concat([prefix_vec, loc_vec], axis=1)
        context_vec = fluid.layers.fc(input=context_pool, size=self._flags.fc_dim,
                act=self._flags.activate,
                param_attr=fluid.ParamAttr(name='context_fc_weight'),
                bias_attr=fluid.ParamAttr(name='context_fc_bias'))
        return context_vec, context_pool

    def _get_poi_vec(self, inputs, tag):
        """
        get poi vec
        context layer: same with query
        feature extract layer: same with query, same kernal params
        vector layer: fc
        """
        name_letter_pool = self.cnn_net(inputs[tag + "_name_letter_id"], "wordid_embedding",
                self._flags.vocab_size, self._flags.emb_dim, hid_dim=self.hid_dim, fc_dim=0,
                emb_lr=self._flags.emb_lr)
        name_word_pool = self.cnn_net(inputs[tag + "_name_word_id"], "wordid_embedding",
                self._flags.vocab_size, self._flags.emb_dim, hid_dim=self.hid_dim, fc_dim=0,
                emb_lr=self._flags.emb_lr)
        addr_letter_pool = self.cnn_net(inputs[tag + "_addr_letter_id"], "wordid_embedding",
                self._flags.vocab_size, self._flags.emb_dim, hid_dim=self.hid_dim, fc_dim=0,
                emb_lr=self._flags.emb_lr)
        addr_word_pool = self.cnn_net(inputs[tag + "_addr_word_id"], "wordid_embedding",
                self._flags.vocab_size, self._flags.emb_dim, hid_dim=self.hid_dim, fc_dim=0,
                emb_lr=self._flags.emb_lr)
        #fc layer
        loc_vec = fluid.layers.reshape(
                fluid.layers.cast(x=inputs[tag + '_loc_geoid'], dtype="float32"), [-1, 40])
        if self._flags.use_attention:
            addr2name_letter_att = dot_product_attention(name_letter_pool, addr_letter_pool,
                    addr_letter_pool, self.hid_dim)
            name2addr_letter_att = dot_product_attention(addr_letter_pool, name_letter_pool,
                    name_letter_pool, self.hid_dim)
            letter_vec = fluid.layers.sequence_concat([addr2name_letter_att, name2addr_letter_att])
            letter_att = ffn(letter_vec, self.hid_dim, self.hid_dim, "inter_ffn")
            #max-pooling
            name_vec = fluid.layers.sequence_pool(letter_att, pool_type="max")

            addr2name_word_att = dot_product_attention(name_word_pool, addr_word_pool,
                    addr_word_pool, self.hid_dim)
            name2addr_word_att = dot_product_attention(addr_word_pool, name_word_pool,
                    name_word_pool, self.hid_dim)
            word_vec = fluid.layers.sequence_concat([addr2name_word_att, name2addr_word_att])
            word_att = ffn(word_vec, self.hid_dim, self.hid_dim, "inter_ffn")
            #max-pooling
            addr_vec = fluid.layers.sequence_pool(word_att, pool_type="max")
        else:
            name_pool = fluid.layers.concat([name_letter_pool, name_word_pool], axis=1)
            name_vec = fluid.layers.fc(input=name_pool, size=self.hid_dim, act="leaky_relu",
                    param_attr=fluid.ParamAttr(name='name_fc_weight'),
                    bias_attr=fluid.ParamAttr(name='name_fc_bias'))
            addr_pool = fluid.layers.concat([addr_letter_pool, addr_word_pool], axis=1)
            addr_vec = fluid.layers.fc(input=addr_pool, size=self.hid_dim, act="leaky_relu",
                    param_attr=fluid.ParamAttr(name='addr_fc_weight'),
                    bias_attr=fluid.ParamAttr(name='addr_fc_bias'))

        if self._flags.use_personal:
            tag_emb = fluid.layers.embedding(input=inputs[tag + '_tag_id'], is_sparse=True,
                    size=[self._flags.tag_size, self._flags.emb_dim],
                    param_attr=fluid.ParamAttr(name="tagid_embedding", learning_rate=self._flags.emb_lr),
                    padding_idx=0)
            tag_emb = fluid.layers.sequence_pool(tag_emb, pool_type="sum")
            poi_pool = fluid.layers.concat([name_vec, addr_vec, loc_vec, tag_emb], axis=1)
        else:
            poi_pool = fluid.layers.concat([name_vec, addr_vec, loc_vec], axis=1)
        #vector layer
        #fluid.layers.Print(inputs[tag + "_name_letter_id"])
        #fluid.layers.Print(inputs[tag + "_name_word_id"])
        #fluid.layers.Print(poi_pool)
        poi_vec = fluid.layers.fc(input=poi_pool, size=self._flags.fc_dim,
                act=self._flags.activate,
                param_attr=fluid.ParamAttr(name='poi_fc_weight'),
                bias_attr=fluid.ParamAttr(name='poi_fc_bias'))
        return poi_vec, poi_pool

    def train_format(self, result, global_step, epoch_id, batch_id):
        """
        result: one batch train narray
        """
        if global_step == 0 or global_step % self._flags.log_every_n_steps != 0:
            return
        #result[0] default is loss.
        avg_res = np.mean(np.array(result[0]))
        vec = []
        for i in range(1, len(result)):
            res = np.array(result[i])
            vec.append("%s#%s" % (res.shape, ' '.join(str(j) for j in res.flatten())))
        logging.info("epoch[%s], global_step[%s], batch_id[%s], extra_info: "
                "loss[%s], debug[%s]" % (epoch_id, global_step, batch_id, avg_res, ";".join(vec)))

    def init_params(self, place):
        """
        init embed
        """
        def _load_parameter(pretraining_file, vocab_size, word_emb_dim):
            pretrain_word2vec = np.zeros([vocab_size, word_emb_dim], dtype=np.float32)
            for line in open(pretraining_file, 'r'):
                id, _, vec = line.strip('\r\n').split('\t')
                pretrain_word2vec[int(id)] = map(float, vec.split())
            return pretrain_word2vec

        embedding_param = fluid.global_scope().find_var("wordid_embedding").get_tensor()
        pretrain_word2vec = _load_parameter(self._flags.init_train_params,
                self._flags.vocab_size, self._flags.emb_dim)
        embedding_param.set(pretrain_word2vec, place)
        logging.info("init pretrain word2vec:%s" % self._flags.init_train_params)

    def pred_format(self, result, **kwargs):
        """
        format pred output
        """
        if result is None:
            return

        if result == '_PRE_':
            if self._flags.dump_vec not in ('query', 'poi'):
                self.idx2word = {}
                with open(self._flags.qac_dict_path, 'r') as f:
                    for line in f:
                        term, tag, cnt, is_stop, term_id = line.strip('\r\n').split('\t')
                        self.idx2word[int(term_id)] = term
            return

        if result == '_POST_':
            if self._flags.init_pretrain_model is not None:
                path = "%s/infer_model" % (self._flags.export_dir)
                frame_env = kwargs['frame_env']
                fluid.io.save_inference_model(path,
                        frame_env.paddle_env['feeded_var_names'],
                        frame_env.paddle_env['fetch_targets'],
                        frame_env.paddle_env['exe'],
                        frame_env.paddle_env['program'])
            return

        if self._flags.dump_vec == "query":
            prefix_vec = np.array(result[0])
            for q in prefix_vec:
                print("qid\t%s" % (" ".join(map(str, q))))
        elif self._flags.dump_vec == "poi":
            poi_score = np.array(result[1])
            poi_vec = np.array(result[2])
            for i in range(len(poi_score)):
                print("bid\t%s\t%s" % (poi_score[i][0], " ".join(map(str, poi_vec[i]))))
        else:
            prefix_id = result[0]
            pred_score = np.array(result[1])
            label = np.array(result[2])
            for i in range(len(pred_score)):
                start = prefix_id.lod()[0][i]
                end = prefix_id.lod()[0][i + 1]
                words = []
                for idx in np.array(prefix_id)[start:end]:
                    words.append(self.idx2word.get(idx[0], "UNK"))
                print("qid_%s\t%s\t%s" % ("".join(words), label[i][0], pred_score[i][0]))

    def bow_net(self, data, layer_name, dict_dim, emb_dim=128, hid_dim=128, fc_dim=128, emb_lr=0.1):
        """
        bow net
        """
        # embedding layer
        emb = fluid.layers.embedding(input=data, is_sparse=True, size=[dict_dim, emb_dim],
                param_attr=fluid.ParamAttr(name=layer_name, learning_rate=emb_lr), padding_idx=0)
        # bow layer
        bow = fluid.layers.sequence_pool(input=emb, pool_type='sum')
        #bow = fluid.layers.tanh(bow)
        #bow = fluid.layers.softsign(bow)
        # full connect layer
        if fc_dim > 0:
            bow = fluid.layers.fc(input=bow, size=fc_dim, act=self._flags.activate)
        return bow

    def cnn_net(self, data, layer_name, dict_dim, emb_dim=128, hid_dim=128, fc_dim=96,
            win_size=3, emb_lr=0.1):
        """
        conv net
        """
        # embedding layer
        emb = fluid.layers.embedding(input=data, is_sparse=True, size=[dict_dim, emb_dim],
                param_attr=fluid.ParamAttr(name=layer_name, learning_rate=emb_lr), padding_idx=0)
        param_attr = fluid.ParamAttr(
                name="conv_weight",
                initializer=fluid.initializer.TruncatedNormalInitializer(loc=0.0, scale=0.1))
        bias_attr = fluid.ParamAttr(
                name="conv_bias",
                initializer=fluid.initializer.Constant(0.0))
        if self._flags.use_attention:
            # convolution layer
            conv = fluid.layers.sequence_conv(input=emb, num_filters=hid_dim, filter_size=win_size,
                    param_attr=param_attr, bias_attr=bias_attr, act="leaky_relu")  #tanh
            att = dot_product_attention(conv, conv, conv, hid_dim)
            conv = ffn(att, hid_dim, hid_dim, "intra_ffn")
        else:
            # convolution layer
            conv = fluid.nets.sequence_conv_pool(input=emb, num_filters=hid_dim, filter_size=win_size,
                    param_attr=param_attr, bias_attr=bias_attr, act="leaky_relu",  #tanh
                    pool_type="max")
        # full connect layer
        if fc_dim > 0:
            conv = fluid.layers.fc(input=conv, size=fc_dim, act=self._flags.activate)
        return conv

    def lstm_net(self, data, layer_name, dict_dim, emb_dim=128, hid_dim=128, fc_dim=96, emb_lr=0.1):
        """
        lstm net
        """
        # embedding layer
        emb = fluid.layers.embedding(input=data, is_sparse=True, size=[dict_dim, emb_dim],
                param_attr=fluid.ParamAttr(name=layer_name, learning_rate=emb_lr), padding_idx=0)
        # Lstm layer
        fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4,
                param_attr=fluid.ParamAttr(name='lstm_fc_weight'),
                bias_attr=fluid.ParamAttr(name='lstm_fc_bias'))
        lstm_h, c = fluid.layers.dynamic_lstm(input=fc0, size=hid_dim * 4, is_reverse=False,
                param_attr=fluid.ParamAttr(name='lstm_weight'),
                bias_attr=fluid.ParamAttr(name='lstm_bias'))
        # max pooling layer
        lstm = fluid.layers.sequence_pool(input=lstm_h, pool_type='max')
        lstm = fluid.layers.tanh(lstm)
        # full connect layer
        if fc_dim > 0:
            lstm = fluid.layers.fc(input=lstm, size=fc_dim, act=self._flags.activate)
        return lstm

    def bilstm_net(self, data, layer_name, dict_dim, emb_dim=128, hid_dim=128, fc_dim=96, emb_lr=0.1):
        """
        bi-Lstm net
        """
        # embedding layer
        emb = fluid.layers.embedding(input=data, is_sparse=True, size=[dict_dim, emb_dim],
                param_attr=fluid.ParamAttr(name=layer_name, learning_rate=emb_lr), padding_idx=0)
        #LSTM layer
        fc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)
        rfc0 = fluid.layers.fc(input=emb, size=hid_dim * 4)
        lstm_h, c = fluid.layers.dynamic_lstm(input=fc0, size=hid_dim * 4, is_reverse=False)
        rlstm_h, c = fluid.layers.dynamic_lstm(input=rfc0, size=hid_dim * 4, is_reverse=True)
        # extract last layer
        lstm_last = fluid.layers.sequence_last_step(input=lstm_h)
        rlstm_last = fluid.layers.sequence_last_step(input=rlstm_h)
        #lstm_last = fluid.layers.tanh(lstm_last)
        #rlstm_last = fluid.layers.tanh(rlstm_last)
        # concat layer
        bi_lstm = fluid.layers.concat(input=[lstm_last, rlstm_last], axis=1)
        # full connect layer
        if fc_dim > 0:
            bi_lstm = fluid.layers.fc(input=bi_lstm, size=fc_dim, act=self._flags.activate)
        return bi_lstm

    def gru_net(self, data, layer_name, dict_dim, emb_dim=128, hid_dim=128, fc_dim=96, emb_lr=0.1):
        """
        gru net
        """
        emb = fluid.layers.embedding(input=data, is_sparse=True, size=[dict_dim, emb_dim],
                param_attr=fluid.ParamAttr(name=layer_name, learning_rate=emb_lr), padding_idx=0)
        #gru layer
        fc0 = fluid.layers.fc(input=emb, size=hid_dim * 3)
        gru = fluid.layers.dynamic_gru(input=fc0, size=hid_dim, is_reverse=False)
        gru = fluid.layers.sequence_pool(input=gru, pool_type='max')
        #gru = fluid.layers.tanh(gru)
        if fc_dim > 0:
            gru = fluid.layers.fc(input=gru, size=fc_dim, act=self._flags.activate)
        return gru
```
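`loss_neg_log_of_pos` above is a temperature-scaled softmax cross-entropy over one positive and `neg_sample_num` negative cosine scores (the net calls it with `gama=15`): the loss is -log(exp(gama*pos) / (exp(gama*pos) + sum_j exp(gama*neg_j))). A NumPy re-derivation of the same formula, offered only as a sanity-check sketch and not part of the repository:

```python
import numpy as np

def neg_log_of_pos(pos_score, neg_score_n, gama=5.0):
    """pos_score: [batch, 1], neg_score_n: [batch, n] cosine similarities in [-1, 1]."""
    exp_pos = np.exp(gama * pos_score)                    # [batch, 1]
    exp_neg = np.exp(gama * neg_score_n)                  # [batch, n]
    denom = exp_pos + exp_neg.sum(axis=1, keepdims=True)  # [batch, 1]
    return -np.log(exp_pos / denom)                       # [batch, 1], smaller when pos >> negs

# illustrative scores only
pos = np.array([[0.9], [0.2]])
neg = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
print(neg_log_of_pos(pos, neg, gama=15.0))
```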
PaddleST/Research/KDD2020-P3AC/test/__init__.py (new empty file, mode 100644)