未验证 提交 69d53643 编写于 作者: T tangwei12 提交者: GitHub

Merge pull request #12 from frankwhzhang/200520_listwise

add listwise
......@@ -177,6 +177,7 @@ python -m paddlerec.run -m ./models/rank/dnn/config.yaml -b backend.yaml
| 多任务 | [ESMM](models/multitask/esmm/model.py) | ✓ | ✓ | ✓ |
| 多任务 | [MMOE](models/multitask/mmoe/model.py) | ✓ | ✓ | ✓ |
| 多任务 | [ShareBottom](models/multitask/share-bottom/model.py) | ✓ | ✓ | ✓ |
| 重排序 | [Listwise](models/rerank/listwise/model.py) | ✓ | x | ✓ |
......
......@@ -37,6 +37,10 @@ class Model(object):
self._fetch_interval = 20
self._namespace = "train.model"
self._platform = envs.get_platform()
self._init_hyper_parameters()
def _init_hyper_parameters(self):
pass
def _init_slots(self):
sparse_slots = envs.get_global_env("sparse_slots", None,
......@@ -129,12 +133,37 @@ class Model(object):
print(">>>>>>>>>>>.learnig rate: %s" % learning_rate)
return self._build_optimizer(optimizer, learning_rate)
@abc.abstractmethod
def input_data(self, is_infer=False):
return None
def net(self, is_infer=False):
return None
def _construct_reader(self, is_infer=False):
if is_infer:
self._infer_data_loader = fluid.io.DataLoader.from_generator(
feed_list=self._infer_data_var,
capacity=64,
use_double_buffer=False,
iterable=False)
else:
dataset_class = envs.get_global_env("dataset_class", None,
"train.reader")
if dataset_class == "DataLoader":
self._data_loader = fluid.io.DataLoader.from_generator(
feed_list=self._data_var,
capacity=64,
use_double_buffer=False,
iterable=False)
def train_net(self):
"""R
"""
pass
input_data = self.input_data(is_infer=False)
self._data_var = input_data
self._construct_reader(is_infer=False)
self.net(input_data, is_infer=False)
@abc.abstractmethod
def infer_net(self):
pass
input_data = self.input_data(is_infer=True)
self._infer_data_var = input_data
self._construct_reader(is_infer=True)
self.net(input_data, is_infer=True)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
evaluate:
reader:
batch_size: 1
class: "{workspace}/random_infer_reader.py"
test_data_path: "{workspace}/data/train"
train:
trainer:
# for cluster training
strategy: "async"
epochs: 3
workspace: "paddlerec.models.rerank.listwise"
device: cpu
reader:
batch_size: 2
class: "{workspace}/random_reader.py"
train_data_path: "{workspace}/data/train"
dataset_class: "DataLoader"
model:
models: "{workspace}/model.py"
hyper_parameters:
hidden_size: 128
user_vocab: 200
item_vocab: 1000
item_len: 5
embed_size: 16
learning_rate: 0.01
optimizer: sgd
save:
increment:
dirname: "increment"
epoch_interval: 2
save_last: True
inference:
dirname: "inference"
epoch_interval: 4
save_last: True
4764,174,1
4764,2958,0
4764,452,0
4764,1946,0
4764,3208,0
2044,2237,1
2044,1998,0
2044,328,0
2044,1542,0
2044,1932,0
4276,65,1
4276,3247,0
4276,942,0
4276,3666,0
4276,2222,0
3933,682,1
3933,2451,0
3933,3695,0
3933,1643,0
3933,3568,0
1151,1265,1
1151,118,0
1151,2532,0
1151,2083,0
1151,2350,0
1757,876,1
1757,201,0
1757,3633,0
1757,1068,0
1757,2549,0
3370,276,1
3370,2435,0
3370,606,0
3370,910,0
3370,2146,0
5137,1018,1
5137,2163,0
5137,3167,0
5137,2315,0
5137,3595,0
3933,2831,1
3933,2881,0
3933,2949,0
3933,3660,0
3933,417,0
3102,999,1
3102,1902,0
3102,2161,0
3102,3042,0
3102,1113,0
2022,336,1
2022,1672,0
2022,2656,0
2022,3649,0
2022,883,0
2664,655,1
2664,3660,0
2664,1711,0
2664,3386,0
2664,1668,0
25,701,1
25,32,0
25,2482,0
25,3177,0
25,2767,0
1738,1643,1
1738,2187,0
1738,228,0
1738,650,0
1738,3101,0
5411,1241,1
5411,2546,0
5411,3019,0
5411,3618,0
5411,1674,0
638,579,1
638,3512,0
638,783,0
638,2111,0
638,1880,0
3554,200,1
3554,2893,0
3554,2428,0
3554,969,0
3554,2741,0
4283,1074,1
4283,3056,0
4283,2032,0
4283,405,0
4283,1505,0
5111,200,1
5111,3488,0
5111,477,0
5111,2790,0
5111,40,0
3964,515,1
3964,1528,0
3964,2173,0
3964,1701,0
3964,2832,0
4764,174,1
4764,2958,0
4764,452,0
4764,1946,0
4764,3208,0
2044,2237,1
2044,1998,0
2044,328,0
2044,1542,0
2044,1932,0
4276,65,1
4276,3247,0
4276,942,0
4276,3666,0
4276,2222,0
3933,682,1
3933,2451,0
3933,3695,0
3933,1643,0
3933,3568,0
1151,1265,1
1151,118,0
1151,2532,0
1151,2083,0
1151,2350,0
1757,876,1
1757,201,0
1757,3633,0
1757,1068,0
1757,2549,0
3370,276,1
3370,2435,0
3370,606,0
3370,910,0
3370,2146,0
5137,1018,1
5137,2163,0
5137,3167,0
5137,2315,0
5137,3595,0
3933,2831,1
3933,2881,0
3933,2949,0
3933,3660,0
3933,417,0
3102,999,1
3102,1902,0
3102,2161,0
3102,3042,0
3102,1113,0
2022,336,1
2022,1672,0
2022,2656,0
2022,3649,0
2022,883,0
2664,655,1
2664,3660,0
2664,1711,0
2664,3386,0
2664,1668,0
25,701,1
25,32,0
25,2482,0
25,3177,0
25,2767,0
1738,1643,1
1738,2187,0
1738,228,0
1738,650,0
1738,3101,0
5411,1241,1
5411,2546,0
5411,3019,0
5411,3618,0
5411,1674,0
638,579,1
638,3512,0
638,783,0
638,2111,0
638,1880,0
3554,200,1
3554,2893,0
3554,2428,0
3554,969,0
3554,2741,0
4283,1074,1
4283,3056,0
4283,2032,0
4283,405,0
4283,1505,0
5111,200,1
5111,3488,0
5111,477,0
5111,2790,0
5111,40,0
3964,515,1
3964,1528,0
3964,2173,0
3964,1701,0
3964,2832,0
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import paddle.fluid as fluid
from paddlerec.core.utils import envs
from paddlerec.core.model import Model as ModelBase
class Model(ModelBase):
    """Listwise re-ranking model.

    Encodes the user, embeds each candidate item together with its position
    in the list, runs a user-initialized bi-directional GRU over the item
    sequence, and emits a per-item click probability (2-way softmax).
    Reference: "Sequential Evaluation and Generation Framework for
    Combinatorial Recommender System" (arXiv:1902.00245).
    """

    def __init__(self, config):
        ModelBase.__init__(self, config)

    def _init_hyper_parameters(self):
        """Load hyper-parameters from the global config.

        Fix: ``item_len`` was previously read from the key
        "hyper_parameters.self.item_len", which matches neither the config
        nor the readers (they use "hyper_parameters.item_len"), so
        ``self.item_len`` always fell back to the ``None`` default.
        """
        self.item_len = envs.get_global_env("hyper_parameters.item_len",
                                            None, self._namespace)
        self.hidden_size = envs.get_global_env("hyper_parameters.hidden_size",
                                               None, self._namespace)
        self.user_vocab = envs.get_global_env("hyper_parameters.user_vocab",
                                              None, self._namespace)
        self.item_vocab = envs.get_global_env("hyper_parameters.item_vocab",
                                              None, self._namespace)
        self.embed_size = envs.get_global_env("hyper_parameters.embed_size",
                                              None, self._namespace)

    def input_data(self, is_infer=False):
        """Declare the feed variables.

        Returns:
            list: [user_slot_names, item_slot_names, lens, labels].
            Training and inference share the same feed list, so ``is_infer``
            has no effect here (kept for interface compatibility).
        """
        user_slot_names = fluid.data(
            name='user_slot_names',
            shape=[None, 1],
            dtype='int64',
            lod_level=1)
        item_slot_names = fluid.data(
            name='item_slot_names',
            shape=[None, self.item_len],
            dtype='int64',
            lod_level=1)
        lens = fluid.data(name='lens', shape=[None], dtype='int64')
        labels = fluid.data(
            name='labels',
            shape=[None, self.item_len],
            dtype='int64',
            lod_level=1)
        # demo: how to use is_infer — both modes need the same inputs, so a
        # single return replaces the former identical if/else branches.
        return [user_slot_names, item_slot_names, lens, labels]

    def net(self, inputs, is_infer=False):
        """Build the forward graph.

        Args:
            inputs: list from ``input_data()``:
                [user_slot_names, item_slot_names, lens, labels].
            is_infer: when True only the AUC metric is registered; no loss
                is built.
        """
        # --- user encode ---
        user_embedding = fluid.embedding(
            input=inputs[0],
            size=[self.user_vocab, self.embed_size],
            param_attr=fluid.ParamAttr(
                initializer=fluid.initializer.Xavier(),
                regularizer=fluid.regularizer.L2Decay(1e-5)),
            is_sparse=True)
        user_feature = fluid.layers.fc(
            input=user_embedding,
            size=self.hidden_size,
            param_attr=fluid.ParamAttr(
                initializer=fluid.initializer.TruncatedNormal(
                    loc=0.0, scale=np.sqrt(1.0 / self.hidden_size))),
            bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
                value=0.0)),
            act='relu',
            name='user_feature_fc')
        # --- item encode ---
        item_embedding = fluid.embedding(
            input=inputs[1],
            size=[self.item_vocab, self.embed_size],
            param_attr=fluid.ParamAttr(
                initializer=fluid.initializer.Xavier(),
                regularizer=fluid.regularizer.L2Decay(1e-5)),
            is_sparse=True)
        # Strip the padded tail of each list so downstream sequence ops see
        # a LoDTensor with the true per-list lengths (inputs[2]).
        item_embedding = fluid.layers.sequence_unpad(
            x=item_embedding, length=inputs[2])
        item_fc = fluid.layers.fc(
            input=item_embedding,
            size=self.hidden_size,
            param_attr=fluid.ParamAttr(
                initializer=fluid.initializer.TruncatedNormal(
                    loc=0.0, scale=np.sqrt(1.0 / self.hidden_size))),
            bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
                value=0.0)),
            act='relu',
            name='item_fc')
        # Positional index of every item inside its own list.
        pos = self._fluid_sequence_get_pos(item_fc)
        # NOTE(review): the position-embedding table is sized with
        # user_vocab; item_len positions would seem sufficient — confirm
        # this is intentional before changing it.
        pos_embed = fluid.embedding(
            input=pos,
            size=[self.user_vocab, self.embed_size],
            param_attr=fluid.ParamAttr(
                initializer=fluid.initializer.Xavier(),
                regularizer=fluid.regularizer.L2Decay(1e-5)),
            is_sparse=True)
        pos_embed = fluid.layers.squeeze(pos_embed, [1])
        # --- bi-directional GRU over the item sequence, user as h_0 ---
        # dynamic_gru requires input width == 3 * hidden_size.
        gru_input = fluid.layers.fc(
            input=fluid.layers.concat([item_fc, pos_embed], 1),
            size=self.hidden_size * 3,
            name='item_gru_fc')
        item_gru_forward = fluid.layers.dynamic_gru(
            input=gru_input,
            size=self.hidden_size,
            is_reverse=False,
            h_0=user_feature)
        item_gru_backward = fluid.layers.dynamic_gru(
            input=gru_input,
            size=self.hidden_size,
            is_reverse=True,
            h_0=user_feature)
        item_gru = fluid.layers.concat(
            [item_gru_forward, item_gru_backward], axis=1)
        out_click_fc1 = fluid.layers.fc(
            input=item_gru,
            size=self.hidden_size,
            param_attr=fluid.ParamAttr(
                initializer=fluid.initializer.TruncatedNormal(
                    loc=0.0, scale=np.sqrt(1.0 / self.hidden_size))),
            bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
                value=0.0)),
            act='relu',
            name='out_click_fc1')
        click_prob = fluid.layers.fc(input=out_click_fc1,
                                     size=2,
                                     act='softmax',
                                     name='out_click_fc2')
        # Labels must be unpadded the same way as the items.
        labels = fluid.layers.sequence_unpad(x=inputs[3], length=inputs[2])
        auc_val, batch_auc, auc_states = fluid.layers.auc(input=click_prob,
                                                          label=labels)
        if is_infer:
            self._infer_results["AUC"] = auc_val
            return
        loss = fluid.layers.reduce_mean(
            fluid.layers.cross_entropy(
                input=click_prob, label=labels))
        self._cost = loss
        self._metrics['auc'] = auc_val

    def _fluid_sequence_pad(self, input, pad_value, maxlen=None):
        """Pad a LoD sequence tensor to a dense batch.

        Args:
            input: LoDTensor of shape (batch*seq_len, dim).
            pad_value: scalar used for the padded positions.
            maxlen: optional fixed padded length.

        Returns:
            Dense tensor of shape (batch, max_seq_len, dim).
        """
        pad_value = fluid.layers.cast(
            fluid.layers.assign(input=np.array([pad_value], 'float32')),
            input.dtype)
        input_padded, _ = fluid.layers.sequence_pad(
            input, pad_value,
            maxlen=maxlen)  # (batch, max_seq_len, 1), (batch, 1)
        # TODO, maxlen=300, used to solve issues: https://github.com/PaddlePaddle/Paddle/issues/14164
        return input_padded

    def _fluid_sequence_get_pos(self, lodtensor):
        """Return the 0-based position of each element within its sequence.

        args:
            lodtensor: lod = [[0,4,7]]
        return:
            pos: lod = [[0,4,7]]
                 data = [0,1,2,3,0,1,3]
                 shape = [-1, 1]
        """
        lodtensor = fluid.layers.reduce_sum(lodtensor, dim=1, keep_dim=True)
        # Fix: the assert message called .shape() — `shape` is an attribute,
        # so a failing assert would raise TypeError instead of reporting it.
        assert lodtensor.shape == (-1, 1), (lodtensor.shape)
        ones = fluid.layers.cast(lodtensor * 0 + 1,
                                 'float32')  # (batch*seq_len, 1)
        ones_padded = self._fluid_sequence_pad(ones,
                                               0)  # (batch, max_seq_len, 1)
        ones_padded = fluid.layers.squeeze(ones_padded,
                                           [2])  # (batch, max_seq_len)
        seq_len = fluid.layers.cast(
            fluid.layers.reduce_sum(
                ones_padded, 1, keep_dim=True), 'int64')  # (batch, 1)
        seq_len = fluid.layers.squeeze(seq_len, [1])
        # Exclusive cumsum over ones yields 0,1,2,... per row.
        pos = fluid.layers.cast(
            fluid.layers.cumsum(
                ones_padded, 1, exclusive=True), 'int64')
        pos = fluid.layers.sequence_unpad(pos, seq_len)  # (batch*seq_len, 1)
        pos.stop_gradient = True
        return pos
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
from paddlerec.core.reader import Reader
from paddlerec.core.utils import envs
from collections import defaultdict
class EvaluateReader(Reader):
    """Random-data reader used for evaluation.

    The input files are ignored entirely: every batch is synthesized from
    the configured vocabulary sizes in ``generate_batch_from_trainfiles``.
    """

    def init(self):
        # Model hyper-parameters live under the train.model namespace.
        self.user_vocab = envs.get_global_env("hyper_parameters.user_vocab",
                                              None, "train.model")
        self.item_vocab = envs.get_global_env("hyper_parameters.item_vocab",
                                              None, "train.model")
        self.item_len = envs.get_global_env("hyper_parameters.item_len", None,
                                            "train.model")
        # NOTE(review): batch_size is read from train.reader even though
        # this is the evaluate reader — confirm this is intentional.
        self.batch_size = envs.get_global_env("batch_size", None,
                                              "train.reader")

    def reader_creator(self):
        def reader():
            # One randomly generated batch per invocation.
            users = [[int(np.random.randint(self.user_vocab))]
                     for _ in range(self.batch_size)]
            items = np.random.randint(
                self.item_vocab,
                size=(self.batch_size, self.item_len)).tolist()
            lengths = [self.item_len] * self.batch_size
            labels = np.random.randint(
                2, size=(self.batch_size, self.item_len)).tolist()
            yield [users, items, lengths, labels]

        return reader

    def generate_batch_from_trainfiles(self, files):
        # ``files`` is unused: batches are synthesized, not read from disk.
        return fluid.io.batch(
            self.reader_creator(), batch_size=self.batch_size)

    def generate_sample(self, line):
        """Unused — data comes from generate_batch_from_trainfiles."""

        def reader():
            # Intentionally empty: per-line parsing is not needed here.
            pass

        return reader
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
from paddlerec.core.reader import Reader
from paddlerec.core.utils import envs
from collections import defaultdict
class TrainReader(Reader):
    """Random-data reader used for training.

    The input files are ignored entirely: every batch is synthesized from
    the configured vocabulary sizes in ``generate_batch_from_trainfiles``.
    """

    def init(self):
        # All configuration is pulled from the train.* namespaces.
        self.user_vocab = envs.get_global_env("hyper_parameters.user_vocab",
                                              None, "train.model")
        self.item_vocab = envs.get_global_env("hyper_parameters.item_vocab",
                                              None, "train.model")
        self.item_len = envs.get_global_env("hyper_parameters.item_len", None,
                                            "train.model")
        self.batch_size = envs.get_global_env("batch_size", None,
                                              "train.reader")

    def reader_creator(self):
        def reader():
            # One randomly generated batch per invocation.
            users = [[int(np.random.randint(self.user_vocab))]
                     for _ in range(self.batch_size)]
            items = np.random.randint(
                self.item_vocab,
                size=(self.batch_size, self.item_len)).tolist()
            lengths = [self.item_len] * self.batch_size
            labels = np.random.randint(
                2, size=(self.batch_size, self.item_len)).tolist()
            yield [users, items, lengths, labels]

        return reader

    def generate_batch_from_trainfiles(self, files):
        # ``files`` is unused: batches are synthesized, not read from disk.
        return fluid.io.batch(
            self.reader_creator(), batch_size=self.batch_size)

    def generate_sample(self, line):
        """Unused — data comes from generate_batch_from_trainfiles."""

        def reader():
            # Intentionally empty: per-line parsing is not needed here.
            pass

        return reader
# 重排序模型库
## 简介
我们提供了常见的重排序使用的模型算法的PaddleRec实现, 单机训练&预测效果指标以及分布式训练&预测性能指标等。目前实现的模型是 [Listwise](listwise)
模型算法库在持续添加中,欢迎关注。
## 目录
* [整体介绍](#整体介绍)
* [重排序模型列表](#重排序模型列表)
* [使用教程](#使用教程)
* [训练 预测](#训练-预测)
* [效果对比](#效果对比)
* [模型效果列表](#模型效果列表)
## 整体介绍
### 重排序模型列表
| 模型 | 简介 | 论文 |
| :------------------: | :--------------------: | :---------: |
| Listwise | Listwise | [Sequential Evaluation and Generation Framework for Combinatorial Recommender System](https://arxiv.org/pdf/1902.00245.pdf)(2019) |
下面是每个模型的简介(注:图片引用自链接中的论文)
[Listwise](https://arxiv.org/pdf/1902.00245.pdf):
<p align="center">
<img align="center" src="../../doc/imgs/listwise.png">
<p>
## 使用教程
### 训练 预测
```shell
python -m paddlerec.run -m paddlerec.models.rerank.listwise # listwise
```
## 效果对比
### 模型效果列表
| 数据集 | 模型 | loss | auc |
| :------------------: | :--------------------: | :---------: |:---------: |
| -- | Listwise | -- | -- |
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册