未验证 提交 f7fe635f 编写于 作者: Q Qdriving 提交者: GitHub

Update reader.py

 Neg item and neg cat should be paried
上级 3b9c100b
...@@ -90,33 +90,23 @@ class Reader(ReaderBase): ...@@ -90,33 +90,23 @@ class Reader(ReaderBase):
for i in range(len(b)): for i in range(len(b)):
neg_item[i] = [] neg_item[i] = []
neg_cat[i] = [] neg_cat[i] = []
# Neg item and neg cat should be paried
if len(self.neg_candidate_item) < self.max_neg_item: if len(self.neg_candidate_item) < self.max_neg_item:
self.neg_candidate_item.extend(b[i][0]) self.neg_candidate_item.extend(b[i][0])
self.neg_candidate_cat.extend(b[i][1])
if len(self.neg_candidate_item) > self.max_neg_item: if len(self.neg_candidate_item) > self.max_neg_item:
self.neg_candidate_item = self.neg_candidate_item[ self.neg_candidate_item = self.neg_candidate_item[0:self.max_neg_item]
0:self.max_neg_item] self.neg_candidate_cat = self.neg_candidate_cat[0:self.max_neg_item]
else: else:
len_seq = len(b[i][0]) len_seq = len(b[i][0])
start_idx = random.randint(0, self.max_neg_item - len_seq - 1) start_idx = random.randint(0, self.max_neg_item - len_seq - 1)
self.neg_candidate_item[start_idx:start_idx + len_seq + 1] = b[ self.neg_candidate_item[start_idx:start_idx + len_seq + 1] = b[i][0]
i][0] self.neg_candidate_cat[start_idx:start_idx + len_seq + 1] = b[i][1]
if len(self.neg_candidate_cat) < self.max_neg_cat: for _ in range(len(b[i][0])):
self.neg_candidate_cat.extend(b[i][1]) randindex = random.randint(0, len(self.neg_candidate_item) - 1)
if len(self.neg_candidate_cat) > self.max_neg_cat: neg_item[i].append(self.neg_candidate_item[randindex])
self.neg_candidate_cat = self.neg_candidate_cat[ neg_cat[i].append(self.neg_candidate_cat[randindex])
0:self.max_neg_cat]
else:
len_seq = len(b[i][1])
start_idx = random.randint(0, self.max_neg_cat - len_seq - 1)
self.neg_candidate_item[start_idx:start_idx + len_seq + 1] = b[
i][1]
for _ in range(len(b[i][0])):
neg_item[i].append(self.neg_candidate_item[random.randint(
0, len(self.neg_candidate_item) - 1)])
for _ in range(len(b[i][1])):
neg_cat[i].append(self.neg_candidate_cat[random.randint(
0, len(self.neg_candidate_cat) - 1)])
len_array = [len(x[0]) for x in b] len_array = [len(x[0]) for x in b]
mask = np.array( mask = np.array(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册