Commit fd7df95f authored by malin10

add simnet

Parent 6dcb9cae
@@ -93,7 +93,7 @@ class SingleTrainer(TranspileTrainer):
                     metrics = [epoch, batch_id]
                     metrics.extend(metrics_rets)

-                    if batch_id % 10 == 0 and batch_id != 0:
+                    if batch_id % self.fetch_period == 0 and batch_id != 0:
                         print(metrics_format.format(*metrics))
                     batch_id += 1
             except fluid.core.EOFException:
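The new self.fetch_period makes the metric-print interval configurable instead of hard-coding 10. A minimal sketch of how such a value could be resolved, assuming the same envs helper used by the readers and model below; the exact config key and namespace are assumptions, not part of this diff:

    # Hypothetical lookup; the "train.trainer" namespace and default of 10 are assumptions.
    from fleetrec.core.utils import envs

    fetch_period = int(envs.get_global_env("fetch_period", 10, "train.trainer"))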
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
evaluate:
  workspace: "fleetrec.models.recall.multiview-simnet"
  reader:
    batch_size: 2
    class: "{workspace}/evaluate_reader.py"
    test_data_path: "{workspace}/data/test"

train:
  trainer:
    # for cluster training
    strategy: "async"

  epochs: 2
  workspace: "fleetrec.models.recall.multiview-simnet"

  reader:
    batch_size: 2
    class: "{workspace}/reader.py"
    train_data_path: "{workspace}/data/train"
    dataset_class: "DataLoader"

  model:
    models: "{workspace}/model.py"
    hyper_parameters:
      use_DataLoader: True
      query_encoder: "bow"
      title_encoder: "bow"
      query_encode_dim: 128
      title_encode_dim: 128
      query_slots: 1
      title_slots: 1
      sparse_feature_dim: 1000001
      embedding_dim: 128
      hidden_size: 128
      learning_rate: 0.0001
      optimizer: adam

  save:
    increment:
      dirname: "increment"
      epoch_interval: 1
      save_last: True
    inference:
      dirname: "inference"
      epoch_interval: 1
      save_last: True
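Every key under hyper_parameters is resolved at runtime through the envs helper against the "train.model" namespace. For example, mirroring the calls made in model.py below:

    from fleetrec.core.utils import envs

    # Resolve to 1 and "bow" respectively, given the YAML above.
    query_slots = envs.get_global_env("hyper_parameters.query_slots", None, "train.model")
    query_encoder = envs.get_global_env("hyper_parameters.query_encoder", None, "train.model")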
55845:q0 48327:q0 35594:q0 45144:q0 24234:q0 30304:q0 49505:q0 81291:q0 41458:q0 14444:q0 48595:pt0 33252:pt0 80121:pt0 48187:pt0 19290:pt0 86838:pt0 12952:pt0 22651:pt0 40981:pt0 93151:pt0
24310:q0 95198:q0 63888:q0 97388:q0 35618:q0 60812:q0 15200:q0 56153:q0 40836:q0 20601:q0 61771:pt0 91433:pt0 23561:pt0 5193:pt0 7638:pt0 83280:pt0 40560:pt0 3866:pt0 46393:pt0 23540:pt0
27457:q0 11157:q0 67566:q0 79598:q0 43460:q0 23949:q0 8785:q0 32809:q0 11198:q0 85918:q0 8067:pt0 30818:pt0 7356:pt0 38800:pt0 10263:pt0 71683:pt0 2327:pt0 18645:pt0 3697:pt0 59405:pt0
67244:q0 11147:q0 32445:q0 50824:q0 23953:q0 69579:q0 61298:q0 29212:q0 4404:q0 20147:q0 91983:pt0 14086:pt0 62007:pt0 48478:pt0 21500:pt0 48079:pt0 25472:pt0 80782:pt0 196:pt0 25996:pt0
23980:q0 28095:q0 76849:q0 4840:q0 13727:q0 6899:q0 14224:q0 29154:q0 67655:q0 19190:q0 55244:pt0 78364:pt0 6822:pt0 9469:pt0 88192:pt0 20879:pt0 46695:pt0 77738:pt0 56719:pt0 34339:pt0
21762:q0 45574:q0 14707:q0 91857:q0 498:q0 69851:q0 44184:q0 88230:q0 68280:q0 63441:q0 29662:pt0 67343:pt0 17316:pt0 67547:pt0 20075:pt0 42813:pt0 48618:pt0 71078:pt0 64804:pt0 71161:pt0
26983:q0 15077:q0 78400:q0 20527:q0 5551:q0 53694:q0 25733:q0 22458:q0 51732:q0 55983:q0 27832:pt0 25228:pt0 88149:pt0 42938:pt0 1728:pt0 31127:pt0 43884:pt0 88393:pt0 31921:pt0 6008:pt0
10009:q0 81206:q0 67854:q0 44704:q0 71528:q0 33799:q0 11805:q0 19961:q0 42334:q0 47131:q0 81425:pt0 18282:pt0 75162:pt0 85100:pt0 66930:pt0 58086:pt0 14809:pt0 71246:pt0 16668:pt0 40496:pt0
10494:q0 17795:q0 9906:q0 76400:q0 23409:q0 52849:q0 37389:q0 32100:q0 99920:q0 48401:q0 35078:pt0 34381:pt0 17627:pt0 96420:pt0 51059:pt0 1526:pt0 70144:pt0 76407:pt0 49928:pt0 66158:pt0
61679:q0 16128:q0 14316:q0 99879:q0 98866:q0 26097:q0 94332:q0 85755:q0 86293:q0 77971:q0 78059:pt0 58096:pt0 18534:pt0 22886:pt0 39979:pt0 50215:pt0 49305:pt0 83042:pt0 21844:pt0 20832:pt0
25212:q0 41019:q0 15221:q0 26969:q0 36669:q0 15986:q0 91749:q0 30848:q0 65210:q0 36795:q0 51801:pt0 148:pt0 64025:pt0 91107:pt0 45193:pt0 15358:pt0 37016:pt0 98657:pt0 8768:pt0 50232:pt0 1313:nt0 86725:nt0 98273:nt0 46754:nt0 53202:nt0 73359:nt0 57339:nt0 97310:nt0 95286:nt0 42304:nt0
91803:q0 22382:q0 95998:q0 79155:q0 62328:q0 36070:q0 46321:q0 49510:q0 95638:q0 57873:q0 37491:pt0 41388:pt0 41649:pt0 84972:pt0 85092:pt0 19921:pt0 53701:pt0 70145:pt0 53337:pt0 97445:pt0 52620:nt0 79645:nt0 9555:nt0 35554:nt0 60410:nt0 69824:nt0 1487:nt0 61492:nt0 57026:nt0 42018:nt0
8247:q0 70601:q0 70209:q0 27625:q0 2652:q0 44564:q0 79847:q0 75873:q0 43830:q0 25367:q0 9294:pt0 11471:pt0 56945:pt0 17886:pt0 39367:pt0 21254:pt0 59394:pt0 8827:pt0 22590:pt0 46047:pt0 66963:nt0 25474:nt0 38485:nt0 732:nt0 96098:nt0 78423:nt0 29482:nt0 63866:nt0 76600:nt0 62664:nt0
14162:q0 60298:q0 83441:q0 90760:q0 88224:q0 70442:q0 37425:q0 50530:q0 50017:q0 50288:q0 36582:pt0 87172:pt0 7095:pt0 89474:pt0 90924:pt0 58990:pt0 88493:pt0 67453:pt0 78688:pt0 42423:pt0 53442:nt0 59360:nt0 445:nt0 63133:nt0 57171:nt0 8207:nt0 8781:nt0 61454:nt0 59407:nt0 5189:nt0
95981:q0 11454:q0 73927:q0 78505:q0 25738:q0 77610:q0 34547:q0 83948:q0 87500:q0 71928:q0 38269:pt0 75996:pt0 64291:pt0 215:pt0 32570:pt0 13733:pt0 15304:pt0 67986:pt0 2283:pt0 7896:pt0 53977:nt0 63572:nt0 98439:nt0 57037:nt0 60009:nt0 92660:nt0 413:nt0 10434:nt0 13035:nt0 33110:nt0
56719:q0 31980:q0 80014:q0 10699:q0 59425:q0 53792:q0 3984:q0 25257:q0 17241:q0 82107:q0 71965:pt0 53900:pt0 84616:pt0 97909:pt0 11625:pt0 80883:pt0 40321:pt0 89692:pt0 64363:pt0 70647:pt0 5444:nt0 415:nt0 21854:nt0 94962:nt0 12220:nt0 50927:nt0 13578:nt0 52078:nt0 32889:nt0 94443:nt0
45603:q0 34278:q0 29984:q0 14052:q0 44562:q0 13997:q0 87924:q0 61856:q0 5458:q0 48804:q0 42902:pt0 28880:pt0 68089:pt0 74598:pt0 33197:pt0 76521:pt0 44762:pt0 58170:pt0 14177:pt0 21283:pt0 64523:nt0 66038:nt0 34411:nt0 88249:nt0 42915:nt0 9998:nt0 65033:nt0 70132:nt0 63762:nt0 7497:nt0
11740:q0 84220:q0 43427:q0 59656:q0 25221:q0 89764:q0 52901:q0 81268:q0 76015:q0 52799:q0 93405:pt0 32788:pt0 36498:pt0 37733:pt0 12795:pt0 55438:pt0 60294:pt0 56537:pt0 35317:pt0 25310:pt0 1499:nt0 1305:nt0 48984:nt0 57311:nt0 55083:nt0 8319:nt0 53953:nt0 83839:nt0 89471:nt0 78813:nt0
7045:q0 31725:q0 40138:q0 84358:q0 16071:q0 32227:q0 17767:q0 26566:q0 98709:q0 71006:q0 67541:pt0 92703:pt0 32306:pt0 60506:pt0 75276:pt0 35969:pt0 41749:pt0 23469:pt0 28621:pt0 35213:pt0 82816:nt0 55050:nt0 85484:nt0 76618:nt0 46177:nt0 54583:nt0 9357:nt0 87694:nt0 78601:nt0 88601:nt0
72413:q0 46396:q0 7065:q0 91955:q0 59212:q0 48775:q0 66636:q0 394:q0 82077:q0 18533:q0 58905:pt0 40190:pt0 52536:pt0 20779:pt0 76068:pt0 70402:pt0 52102:pt0 3167:pt0 72461:pt0 29606:pt0 89297:nt0 33717:nt0 78957:nt0 42046:nt0 16408:nt0 80806:nt0 19095:nt0 81176:nt0 16634:nt0 72387:nt0
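Each line above is one sample: whitespace-separated feasign:slot tokens, where the slot is q<i> for query features, pt<i> for positive-title features and, in the training data only, nt<i> for negative-title features. A self-contained sketch of the token format, using values from the first test line:

    line = "55845:q0 48595:pt0"
    for token in line.split():
        feasign, slot = token.split(":")
        print(slot, int(feasign))  # q0 55845, then pt0 48595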
#! /bin/bash
set -e
echo "begin to prepare data"
mkdir -p data/train
mkdir -p data/test
python generate_synthetic_data.py
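Note that both the script and generate_synthetic_data.py use paths relative to the current directory, so run sh data_process.sh from this model's directory; it creates data/train/train.txt and data/test/test.txt.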
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import io
import copy
import random
from fleetrec.core.reader import Reader
from fleetrec.core.utils import envs
class EvaluateReader(Reader):
def init(self):
self.query_slots = envs.get_global_env("hyper_parameters.query_slots", None, "train.model")
self.title_slots = envs.get_global_env("hyper_parameters.title_slots", None, "train.model")
self.all_slots = []
for i in range(self.query_slots):
self.all_slots.append('q' + str(i))
for i in range(self.title_slots):
self.all_slots.append('pt' + str(i))
self._all_slots_dict = dict()
for index, slot in enumerate(self.all_slots):
self._all_slots_dict[slot] = [False, index]
def generate_sample(self, line):
def data_iter():
elements = line.rstrip().split()
padding = 0
output = [(slot, []) for slot in self.all_slots]
for elem in elements:
feasign, slot = elem.split(':')
                if slot not in self._all_slots_dict:
continue
self._all_slots_dict[slot][0] = True
index = self._all_slots_dict[slot][1]
output[index][1].append(int(feasign))
for slot in self._all_slots_dict:
visit, index = self._all_slots_dict[slot]
if visit:
self._all_slots_dict[slot][0] = False
else:
output[index][1].append(padding)
yield output
return data_iter
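EvaluateReader pads every slot that received no feature with a single 0, so each fluid.data input downstream always gets at least one id per slot. A self-contained sketch of that padding rule, independent of fleetrec:

    all_slots = ["q0", "pt0"]
    output = {slot: [] for slot in all_slots}
    for token in "55845:q0".split():
        feasign, slot = token.split(":")
        output[slot].append(int(feasign))
    for ids in output.values():
        if not ids:
            ids.append(0)  # padding feasign
    print(output)  # {'q0': [55845], 'pt0': [0]}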
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
class Dataset:
def __init__(self):
pass
class SyntheticDataset(Dataset):
def __init__(self, sparse_feature_dim, query_slot_num, title_slot_num, dataset_size=10000):
# ids are randomly generated
self.ids_per_slot = 10
self.sparse_feature_dim = sparse_feature_dim
self.query_slot_num = query_slot_num
self.title_slot_num = title_slot_num
self.dataset_size = dataset_size
def _reader_creator(self, is_train):
def generate_ids(num, space):
return [random.randint(0, space - 1) for i in range(num)]
def reader():
for i in range(self.dataset_size):
query_slots = []
pos_title_slots = []
neg_title_slots = []
for i in range(self.query_slot_num):
qslot = generate_ids(self.ids_per_slot,
self.sparse_feature_dim)
qslot = [str(fea) + ':q' + str(i) for fea in qslot]
query_slots += qslot
for i in range(self.title_slot_num):
pt_slot = generate_ids(self.ids_per_slot,
self.sparse_feature_dim)
pt_slot = [str(fea) + ':pt' + str(i) for fea in pt_slot]
pos_title_slots += pt_slot
if is_train:
for i in range(self.title_slot_num):
nt_slot = generate_ids(self.ids_per_slot,
self.sparse_feature_dim)
nt_slot = [str(fea) + ':nt' + str(i) for fea in nt_slot]
neg_title_slots += nt_slot
yield query_slots + pos_title_slots + neg_title_slots
else:
yield query_slots + pos_title_slots
return reader
def train(self):
return self._reader_creator(True)
def valid(self):
return self._reader_creator(True)
def test(self):
return self._reader_creator(False)
if __name__ == '__main__':
sparse_feature_dim = 1000001
query_slots = 1
title_slots = 1
dataset_size = 10
dataset = SyntheticDataset(sparse_feature_dim, query_slots, title_slots, dataset_size)
train_reader = dataset.train()
test_reader = dataset.test()
with open("data/train/train.txt", 'w') as fout:
for data in train_reader():
fout.write(' '.join(data))
fout.write("\n")
with open("data/test/test.txt", 'w') as fout:
for data in test_reader():
fout.write(' '.join(data))
fout.write("\n")
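The __main__ block above writes ten lines each to data/train/train.txt and data/test/test.txt. The readers returned by train()/test() are plain generator factories and can be inspected directly, for example:

    dataset = SyntheticDataset(1000001, 1, 1, dataset_size=3)
    for sample in dataset.train()():
        # train samples carry q0, pt0 and nt0 tokens; test samples omit nt0
        assert any(tok.endswith(":nt0") for tok in sample)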
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import math
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.layers.tensor as tensor
import paddle.fluid.layers.control_flow as cf
from fleetrec.core.utils import envs
from fleetrec.core.model import Model as ModelBase
class BowEncoder(object):
""" bow-encoder """
def __init__(self):
self.param_name = ""
def forward(self, emb):
return fluid.layers.sequence_pool(input=emb, pool_type='sum')
class CNNEncoder(object):
""" cnn-encoder"""
def __init__(self,
param_name="cnn",
win_size=3,
ksize=128,
act='tanh',
pool_type='max'):
self.param_name = param_name
self.win_size = win_size
self.ksize = ksize
self.act = act
self.pool_type = pool_type
def forward(self, emb):
return fluid.nets.sequence_conv_pool(
input=emb,
num_filters=self.ksize,
filter_size=self.win_size,
act=self.act,
pool_type=self.pool_type,
param_attr=self.param_name + ".param",
bias_attr=self.param_name + ".bias")
class GrnnEncoder(object):
""" grnn-encoder """
def __init__(self, param_name="grnn", hidden_size=128):
self.param_name = param_name
self.hidden_size = hidden_size
def forward(self, emb):
fc0 = fluid.layers.fc(input=emb,
size=self.hidden_size * 3,
param_attr=self.param_name + "_fc.w",
bias_attr=False)
gru_h = fluid.layers.dynamic_gru(
input=fc0,
size=self.hidden_size,
is_reverse=False,
param_attr=self.param_name + ".param",
bias_attr=self.param_name + ".bias")
return fluid.layers.sequence_pool(input=gru_h, pool_type='max')
class SimpleEncoderFactory(object):
    """Create an encoder instance by type name."""

    def create(self, enc_type, enc_hid_size):
        if enc_type == "bow":
            return BowEncoder()
        elif enc_type == "cnn":
            return CNNEncoder(ksize=enc_hid_size)
        elif enc_type == "gru":
            return GrnnEncoder(hidden_size=enc_hid_size)
        raise ValueError("unknown encoder type: %s" % enc_type)
class Model(ModelBase):
def __init__(self, config):
ModelBase.__init__(self, config)
self.init_config()
def init_config(self):
self._fetch_interval = 1
query_encoder = envs.get_global_env("hyper_parameters.query_encoder", None, self._namespace)
title_encoder = envs.get_global_env("hyper_parameters.title_encoder", None, self._namespace)
query_encode_dim = envs.get_global_env("hyper_parameters.query_encode_dim", None, self._namespace)
title_encode_dim = envs.get_global_env("hyper_parameters.title_encode_dim", None, self._namespace)
query_slots = envs.get_global_env("hyper_parameters.query_slots", None, self._namespace)
title_slots = envs.get_global_env("hyper_parameters.title_slots", None, self._namespace)
factory = SimpleEncoderFactory()
self.query_encoders = [
factory.create(query_encoder, query_encode_dim)
for i in range(query_slots)
]
self.title_encoders = [
factory.create(title_encoder, title_encode_dim)
for i in range(title_slots)
]
self.emb_size = envs.get_global_env("hyper_parameters.sparse_feature_dim", None, self._namespace)
self.emb_dim = envs.get_global_env("hyper_parameters.embedding_dim", None, self._namespace)
self.emb_shape = [self.emb_size, self.emb_dim]
self.hidden_size = envs.get_global_env("hyper_parameters.hidden_size", None, self._namespace)
self.margin = 0.1
def input(self, is_train=True):
self.q_slots = [
fluid.data(
name="q%d" % i, shape=[None, 1], lod_level=1, dtype='int64')
for i in range(len(self.query_encoders))
]
self.pt_slots = [
fluid.data(
name="pt%d" % i, shape=[None, 1], lod_level=1, dtype='int64')
for i in range(len(self.title_encoders))
]
        if not is_train:
return self.q_slots + self.pt_slots
self.nt_slots = [
fluid.data(
name="nt%d" % i, shape=[None, 1], lod_level=1, dtype='int64')
for i in range(len(self.title_encoders))
]
return self.q_slots + self.pt_slots + self.nt_slots
def train_input(self):
res = self.input()
self._data_var = res
use_dataloader = envs.get_global_env("hyper_parameters.use_DataLoader", False, self._namespace)
if self._platform != "LINUX" or use_dataloader:
self._data_loader = fluid.io.DataLoader.from_generator(
feed_list=self._data_var, capacity=256, use_double_buffer=False, iterable=False)
def get_acc(self, x, y):
less = tensor.cast(cf.less_than(x, y), dtype='float32')
label_ones = fluid.layers.fill_constant_batch_size_like(
input=x, dtype='float32', shape=[-1, 1], value=1.0)
correct = fluid.layers.reduce_sum(less)
total = fluid.layers.reduce_sum(label_ones)
acc = fluid.layers.elementwise_div(correct, total)
return acc
def net(self):
q_embs = [
fluid.embedding(
input=query, size=self.emb_shape, param_attr="emb")
for query in self.q_slots
]
pt_embs = [
fluid.embedding(
input=title, size=self.emb_shape, param_attr="emb")
for title in self.pt_slots
]
nt_embs = [
fluid.embedding(
input=title, size=self.emb_shape, param_attr="emb")
for title in self.nt_slots
]
# encode each embedding field with encoder
q_encodes = [
self.query_encoders[i].forward(emb) for i, emb in enumerate(q_embs)
]
pt_encodes = [
self.title_encoders[i].forward(emb) for i, emb in enumerate(pt_embs)
]
nt_encodes = [
self.title_encoders[i].forward(emb) for i, emb in enumerate(nt_embs)
]
# concat multi view for query, pos_title, neg_title
q_concat = fluid.layers.concat(q_encodes)
pt_concat = fluid.layers.concat(pt_encodes)
nt_concat = fluid.layers.concat(nt_encodes)
# projection of hidden layer
q_hid = fluid.layers.fc(q_concat,
size=self.hidden_size,
param_attr='q_fc.w',
bias_attr='q_fc.b')
pt_hid = fluid.layers.fc(pt_concat,
size=self.hidden_size,
param_attr='t_fc.w',
bias_attr='t_fc.b')
nt_hid = fluid.layers.fc(nt_concat,
size=self.hidden_size,
param_attr='t_fc.w',
bias_attr='t_fc.b')
# cosine of hidden layers
cos_pos = fluid.layers.cos_sim(q_hid, pt_hid)
cos_neg = fluid.layers.cos_sim(q_hid, nt_hid)
# pairwise hinge_loss
loss_part1 = fluid.layers.elementwise_sub(
tensor.fill_constant_batch_size_like(
input=cos_pos,
shape=[-1, 1],
value=self.margin,
dtype='float32'),
cos_pos)
loss_part2 = fluid.layers.elementwise_add(loss_part1, cos_neg)
loss_part3 = fluid.layers.elementwise_max(
tensor.fill_constant_batch_size_like(
input=loss_part2, shape=[-1, 1], value=0.0, dtype='float32'),
loss_part2)
self.avg_cost = fluid.layers.mean(loss_part3)
self.acc = self.get_acc(cos_neg, cos_pos)
def avg_loss(self):
self._cost = self.avg_cost
def metrics(self):
self._metrics["loss"] = self.avg_cost
self._metrics["acc"] = self.acc
def train_net(self):
self.train_input()
self.net()
self.avg_loss()
self.metrics()
def optimizer(self):
learning_rate = envs.get_global_env("hyper_parameters.learning_rate", None, self._namespace)
optimizer = fluid.optimizer.Adam(learning_rate=learning_rate)
return optimizer
def infer_input(self):
res = self.input(is_train=False)
self._infer_data_var = res
self._infer_data_loader = fluid.io.DataLoader.from_generator(
feed_list=self._infer_data_var, capacity=64, use_double_buffer=False, iterable=False)
def infer_net(self):
self.infer_input()
# lookup embedding for each slot
q_embs = [
fluid.embedding(
input=query, size=self.emb_shape, param_attr="emb")
for query in self.q_slots
]
pt_embs = [
fluid.embedding(
input=title, size=self.emb_shape, param_attr="emb")
for title in self.pt_slots
]
# encode each embedding field with encoder
q_encodes = [
self.query_encoders[i].forward(emb) for i, emb in enumerate(q_embs)
]
pt_encodes = [
self.title_encoders[i].forward(emb) for i, emb in enumerate(pt_embs)
]
# concat multi view for query, pos_title, neg_title
q_concat = fluid.layers.concat(q_encodes)
pt_concat = fluid.layers.concat(pt_encodes)
# projection of hidden layer
q_hid = fluid.layers.fc(q_concat,
size=self.hidden_size,
param_attr='q_fc.w',
bias_attr='q_fc.b')
pt_hid = fluid.layers.fc(pt_concat,
size=self.hidden_size,
param_attr='t_fc.w',
bias_attr='t_fc.b')
# cosine of hidden layers
cos = fluid.layers.cos_sim(q_hid, pt_hid)
self._infer_results['query_pt_sim'] = cos
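For reference, train_net above optimizes a pairwise hinge loss over cosine similarities, loss = mean(max(0, margin - cos(q, pt) + cos(q, nt))) with margin 0.1, while acc counts how often the positive title outranks the negative one. A plain-numpy restatement of the same loss, for illustration only and not part of the commit:

    import numpy as np

    def hinge_loss(cos_pos, cos_neg, margin=0.1):
        # mean over the batch of max(0, margin - cos_pos + cos_neg)
        return np.maximum(0.0, margin - cos_pos + cos_neg).mean()

    print(hinge_loss(np.array([0.9, 0.2]), np.array([0.1, 0.3])))  # 0.1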
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import io
import copy
import random
from fleetrec.core.reader import Reader
from fleetrec.core.utils import envs
class TrainReader(Reader):
def init(self):
self.query_slots = envs.get_global_env("hyper_parameters.query_slots", None, "train.model")
self.title_slots = envs.get_global_env("hyper_parameters.title_slots", None, "train.model")
self.all_slots = []
for i in range(self.query_slots):
self.all_slots.append('q' + str(i))
for i in range(self.title_slots):
self.all_slots.append('pt' + str(i))
for i in range(self.title_slots):
self.all_slots.append('nt' + str(i))
self._all_slots_dict = dict()
for index, slot in enumerate(self.all_slots):
self._all_slots_dict[slot] = [False, index]
def generate_sample(self, line):
def data_iter():
elements = line.rstrip().split()
padding = 0
output = [(slot, []) for slot in self.all_slots]
for elem in elements:
feasign, slot = elem.split(':')
                if slot not in self._all_slots_dict:
continue
self._all_slots_dict[slot][0] = True
index = self._all_slots_dict[slot][1]
output[index][1].append(int(feasign))
for slot in self._all_slots_dict:
visit, index = self._all_slots_dict[slot]
if visit:
self._all_slots_dict[slot][0] = False
else:
output[index][1].append(padding)
yield output
return data_iter