提交 70150f6a 编写于 作者: D dongdaxiang

add ssr training part

上级 ebb16616
......@@ -51,23 +51,27 @@ class CNNEncoder(object):
filter_size=self.win_size,
act=self.act,
pool_type=self.pool_type,
attr=self.param_name)
param_attr=str(self.param_name))
class GrnnEncoder(object):
""" grnn-encoder """
def __init__(self, param_name="grnn.w", hidden_size=128):
self.param_name = args
self.param_name = param_name
self.hidden_size = hidden_size
def forward(self, emb):
fc0 = nn.fc(input=emb, size=self.hidden_size * 3)
fc0 = nn.fc(
input=emb,
size=self.hidden_size * 3,
param_attr=str(str(self.param_name) + "_fc")
)
gru_h = nn.dynamic_gru(
input=emb,
input=fc0,
size=self.hidden_size,
is_reverse=False,
attr=self.param_name)
param_attr=str(self.param_name))
return nn.sequence_pool(input=gru_h, pool_type='max')
......@@ -191,7 +195,7 @@ class MultiviewSimnet(object):
loss_part2)
avg_cost = nn.mean(loss_part3)
correct = self.get_correct(cos_pos, cos_neg)
correct = self.get_correct(cos_neg, cos_pos)
return q_slots + pt_slots + nt_slots, avg_cost, correct
......
......@@ -13,7 +13,16 @@ Sequence Semantic Retrieval(SSR) Model shares the similar idea with Multi-Rate D
- With the representation of news items, we are able to build an vector indexing service online for news prediction and this is the retrieval part of SSR.
## Dataset
Dataset preprocessing follows the method of [GRU4Rec Project](https://github.com/PaddlePaddle/models/tree/develop/fluid/PaddleRec/gru4rec). Note that you should reuse scripts from GRU4Rec project for data preprocessing.
## Training
The command line options for training can be listed by `python train.py -h`
``` bash
fluid train.py --train_file rsc15_train_tr_paddle.txt
```
## Build Index
TBA
## Retrieval
TBA
#Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import paddle.fluid.layers.nn as nn
import paddle.fluid.layers.tensor as tensor
import paddle.fluid.layers.control_flow as cf
import paddle.fluid.layers.io as io
from PaddleRec.multiview_simnet.nets import BowEncoder
from PaddleRec.multiview_simnet.nets import GrnnEncoder
from PaddleRec.multiview_simnet.nets import MultiviewSimnet
class PairwiseHingeLoss(object):
def __init__(self, margin=0.8):
self.margin = margin
def forward(self, pos, neg):
loss_part1 = nn.elementwise_sub(
tensor.fill_constant_batch_size_like(
input=pos,
shape=[-1, 1],
value=self.margin,
dtype='float32'),
pos)
loss_part2 = nn.elementwise_add(loss_part1, neg)
loss_part3 = nn.elementwise_max(
tensor.fill_constant_batch_size_like(
input=loss_part2,
shape=[-1, 1],
value=0.0,
dtype='float32'),
loss_part2)
return loss_part3
class SequenceSemanticRetrieval(object):
""" sequence semantic retrieval model """
def __init__(self, embedding_size, embedding_dim, hidden_size):
self.embedding_size = embedding_size
self.embedding_dim = embedding_dim
self.emb_shape = [self.embedding_size, self.embedding_dim]
self.hidden_size = hidden_size
self.user_encoder = GrnnEncoder(hidden_size=hidden_size)
self.item_encoder = BowEncoder()
self.pairwise_hinge_loss = PairwiseHingeLoss()
def get_correct(self, x, y):
less = tensor.cast(cf.less_than(x, y), dtype='float32')
correct = nn.reduce_sum(less)
return correct
def train(self):
user_data = io.data(
name="user", shape=[1], dtype="int64", lod_level=1
)
pos_item_data = io.data(
name="p_item", shape=[1], dtype="int64", lod_level=1
)
neg_item_data = io.data(
name="n_item", shape=[1], dtype="int64", lod_level=1
)
user_emb = nn.embedding(
input=user_data, size=self.emb_shape, param_attr="emb.item"
)
pos_item_emb = nn.embedding(
input=pos_item_data, size=self.emb_shape, param_attr="emb.item"
)
neg_item_emb = nn.embedding(
input=neg_item_data, size=self.emb_shape, param_attr="emb.item"
)
user_enc = self.user_encoder.forward(user_emb)
pos_item_enc = self.item_encoder.forward(pos_item_emb)
neg_item_enc = self.item_encoder.forward(neg_item_emb)
user_hid = nn.fc(
input=user_enc, size=self.hidden_size, param_attr='user.w', bias_attr="user.b"
)
pos_item_hid = nn.fc(
input=pos_item_enc, size=self.hidden_size, param_attr='item.w', bias_attr="item.b"
)
neg_item_hid = nn.fc(
input=neg_item_enc, size=self.hidden_size, param_attr='item.w', bias_attr="item.b"
)
cos_pos = nn.cos_sim(user_hid, pos_item_hid)
cos_neg = nn.cos_sim(user_hid, neg_item_hid)
hinge_loss = self.pairwise_hinge_loss.forward(cos_pos, cos_neg)
avg_cost = nn.mean(hinge_loss)
correct = self.get_correct(cos_neg, cos_pos)
return [user_data, pos_item_data, neg_item_data], \
pos_item_hid, neg_item_hid, avg_cost, correct
#Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
class Dataset:
def __init__(self):
pass
class Vocab:
def __init__(self):
pass
class YoochooseVocab(Vocab):
def __init__(self):
self.vocab = {}
self.word_array = []
def load(self, filelist):
idx = 0
for f in filelist:
with open(f, "r") as fin:
for line in fin:
group = line.strip().split()
for item in group:
if item not in self.vocab:
self.vocab[item] = idx
self.word_array.append(idx)
idx += 1
else:
self.word_array.append(self.vocab[item])
def get_vocab(self):
return self.vocab
def _get_word_array(self):
return self.word_array
class YoochooseDataset(Dataset):
def __init__(self, y_vocab):
self.vocab_size = len(y_vocab.get_vocab())
self.word_array = y_vocab._get_word_array()
self.vocab = y_vocab.get_vocab()
def sample_neg(self):
return random.randint(0, self.vocab_size - 1)
def sample_neg_from_seq(self, seq):
return seq[random.randint(0, len(seq) - 1)]
# TODO(guru4elephant): wait memory, should be improved
def sample_from_word_freq(self):
return self.word_array[random.randint(0, len(self.word_array) - 1)]
def _reader_creator(self, filelist, is_train):
def reader():
for f in filelist:
with open(f, 'r') as fin:
line_idx = 0
for line in fin:
ids = line.strip().split()
if len(ids) <= 1:
continue
conv_ids = [self.vocab[i] if i in self.vocab else 0 for i in ids]
# random select an index as boundary
# make ids before boundary as sequence
# make id next to boundary right as target
boundary = random.randint(1, len(ids) - 1)
src = conv_ids[:boundary]
pos_tgt = [conv_ids[boundary]]
if is_train:
neg_tgt = [self.sample_from_word_freq()]
yield [src, pos_tgt, neg_tgt]
else:
yield [src, pos_tgt]
return reader
def train(self, file_list):
return self._reader_creator(file_list, True)
def test(self, file_list):
return self._reader_creator(file_list, False)
#Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import argparse
import logging
import paddle.fluid as fluid
import paddle
import reader as reader
from nets import SequenceSemanticRetrieval
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser("sequence semantic retrieval")
parser.add_argument("--train_file", type=str, help="Training file")
parser.add_argument("--valid_file", type=str, help="Validation file")
parser.add_argument(
"--epochs", type=int, default=10, help="Number of epochs for training")
parser.add_argument(
"--model_output_dir",
type=str,
default='model_output',
help="Model output folder")
parser.add_argument(
"--sequence_encode_dim",
type=int,
default=128,
help="Dimension of sequence encoder output")
parser.add_argument(
"--matching_dim",
type=int,
default=128,
help="Dimension of hidden layer")
parser.add_argument(
"--batch_size", type=int, default=128, help="Batch size for training")
parser.add_argument(
"--embedding_dim",
type=int,
default=128,
help="Default Dimension of Embedding")
return parser.parse_args()
def start_train(args):
y_vocab = reader.YoochooseVocab()
y_vocab.load([args.train_file])
logger.info("Load yoochoose vocabulary size: {}".format(len(y_vocab.get_vocab())))
y_data = reader.YoochooseDataset(y_vocab)
train_reader = paddle.batch(
paddle.reader.shuffle(
y_data.train([args.train_file]), buf_size=args.batch_size * 100),
batch_size=args.batch_size)
place = fluid.CPUPlace()
ssr = SequenceSemanticRetrieval(
len(y_vocab.get_vocab()), args.embedding_dim, args.matching_dim
)
input_data, user_rep, item_rep, avg_cost, acc = ssr.train()
optimizer = fluid.optimizer.Adam(learning_rate=1e-4)
optimizer.minimize(avg_cost)
startup_program = fluid.default_startup_program()
loop_program = fluid.default_main_program()
data_list = [var.name for var in input_data]
feeder = fluid.DataFeeder(feed_list=data_list, place=place)
exe = fluid.Executor(place)
exe.run(startup_program)
for pass_id in range(args.epochs):
for batch_id, data in enumerate(train_reader()):
loss_val, correct_val = exe.run(loop_program,
feed=feeder.feed(data),
fetch_list=[avg_cost, acc])
logger.info("Train --> pass: {} batch_id: {} avg_cost: {}, acc: {}".
format(pass_id, batch_id, loss_val,
float(correct_val) / args.batch_size))
fluid.io.save_inference_model(args.model_output_dir,
[var.name for val in input_data],
[user_rep, item_rep, avg_cost, acc], exe)
def main():
args = parse_args()
start_train(args)
if __name__ == "__main__":
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册