diff --git a/models/rank/dien/reader.py b/models/rank/dien/reader.py index 6f6fdc10737c808beea0974e4bf83c6bb8770e5b..2368a2d856c40502a4bc101566c67c7ae799613d 100755 --- a/models/rank/dien/reader.py +++ b/models/rank/dien/reader.py @@ -90,33 +90,23 @@ class Reader(ReaderBase): for i in range(len(b)): neg_item[i] = [] neg_cat[i] = [] + # Neg item and neg cat should be paried if len(self.neg_candidate_item) < self.max_neg_item: self.neg_candidate_item.extend(b[i][0]) + self.neg_candidate_cat.extend(b[i][1]) if len(self.neg_candidate_item) > self.max_neg_item: - self.neg_candidate_item = self.neg_candidate_item[ - 0:self.max_neg_item] + self.neg_candidate_item = self.neg_candidate_item[0:self.max_neg_item] + self.neg_candidate_cat = self.neg_candidate_cat[0:self.max_neg_item] else: len_seq = len(b[i][0]) start_idx = random.randint(0, self.max_neg_item - len_seq - 1) - self.neg_candidate_item[start_idx:start_idx + len_seq + 1] = b[ - i][0] + self.neg_candidate_item[start_idx:start_idx + len_seq + 1] = b[i][0] + self.neg_candidate_cat[start_idx:start_idx + len_seq + 1] = b[i][1] - if len(self.neg_candidate_cat) < self.max_neg_cat: - self.neg_candidate_cat.extend(b[i][1]) - if len(self.neg_candidate_cat) > self.max_neg_cat: - self.neg_candidate_cat = self.neg_candidate_cat[ - 0:self.max_neg_cat] - else: - len_seq = len(b[i][1]) - start_idx = random.randint(0, self.max_neg_cat - len_seq - 1) - self.neg_candidate_item[start_idx:start_idx + len_seq + 1] = b[ - i][1] - for _ in range(len(b[i][0])): - neg_item[i].append(self.neg_candidate_item[random.randint( - 0, len(self.neg_candidate_item) - 1)]) - for _ in range(len(b[i][1])): - neg_cat[i].append(self.neg_candidate_cat[random.randint( - 0, len(self.neg_candidate_cat) - 1)]) + for _ in range(len(b[i][0])): + randindex = random.randint(0, len(self.neg_candidate_item) - 1) + neg_item[i].append(self.neg_candidate_item[randindex]) + neg_cat[i].append(self.neg_candidate_cat[randindex]) len_array = [len(x[0]) for x in b] mask = np.array(