未验证 提交 f7fe635f 编写于 作者: Q Qdriving 提交者: GitHub

Update reader.py

 Neg item and neg cat should be paried
上级 3b9c100b
......@@ -90,33 +90,23 @@ class Reader(ReaderBase):
for i in range(len(b)):
neg_item[i] = []
neg_cat[i] = []
# Neg item and neg cat should be paried
if len(self.neg_candidate_item) < self.max_neg_item:
self.neg_candidate_item.extend(b[i][0])
self.neg_candidate_cat.extend(b[i][1])
if len(self.neg_candidate_item) > self.max_neg_item:
self.neg_candidate_item = self.neg_candidate_item[
0:self.max_neg_item]
self.neg_candidate_item = self.neg_candidate_item[0:self.max_neg_item]
self.neg_candidate_cat = self.neg_candidate_cat[0:self.max_neg_item]
else:
len_seq = len(b[i][0])
start_idx = random.randint(0, self.max_neg_item - len_seq - 1)
self.neg_candidate_item[start_idx:start_idx + len_seq + 1] = b[
i][0]
self.neg_candidate_item[start_idx:start_idx + len_seq + 1] = b[i][0]
self.neg_candidate_cat[start_idx:start_idx + len_seq + 1] = b[i][1]
if len(self.neg_candidate_cat) < self.max_neg_cat:
self.neg_candidate_cat.extend(b[i][1])
if len(self.neg_candidate_cat) > self.max_neg_cat:
self.neg_candidate_cat = self.neg_candidate_cat[
0:self.max_neg_cat]
else:
len_seq = len(b[i][1])
start_idx = random.randint(0, self.max_neg_cat - len_seq - 1)
self.neg_candidate_item[start_idx:start_idx + len_seq + 1] = b[
i][1]
for _ in range(len(b[i][0])):
neg_item[i].append(self.neg_candidate_item[random.randint(
0, len(self.neg_candidate_item) - 1)])
for _ in range(len(b[i][1])):
neg_cat[i].append(self.neg_candidate_cat[random.randint(
0, len(self.neg_candidate_cat) - 1)])
randindex = random.randint(0, len(self.neg_candidate_item) - 1)
neg_item[i].append(self.neg_candidate_item[randindex])
neg_cat[i].append(self.neg_candidate_cat[randindex])
len_array = [len(x[0]) for x in b]
mask = np.array(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册