diff --git a/models/recall/tdm/tdm_reader.py b/models/recall/tdm/tdm_reader.py index 32d33aeb40cd94138395acc03556e93f634a86d5..17413249c2adc54d450129c76ac0761e27ba27a4 100644 --- a/models/recall/tdm/tdm_reader.py +++ b/models/recall/tdm/tdm_reader.py @@ -33,8 +33,8 @@ class TrainReader(Reader): This function needs to be implemented by the user, based on data format """ features = (line.strip('\n')).split('\t') - input_emb = features[0].split(' ') - item_label = [features[1]] + input_emb = map(float, features[0].split(' ')) + item_label = [int(features[1])] feature_name = ["input_emb", "item_label"] yield zip(feature_name, [input_emb] + [item_label])