结果异常
来源:2-4 LFM算法训练数据抽取

WPaulG
2022-10-08
老师，我这里跑出来的结果和您的不一样。
def get_train_data(input_file, score_thr=4.0):
    """Build balanced positive/negative training samples for the LFM model.

    The first line of ``input_file`` is skipped as a header. Every other
    line is split on commas, and lines with fewer than four fields are
    ignored. Field 0 is the user id, field 1 the item id and field 2 the
    rating.

    For each user:
      * ratings >= ``score_thr`` become positive samples (label 1);
      * lower ratings become negative candidates. Each candidate is ranked by
        its item's value in ``get_ave_score(input_file)``, or 0 when the item
        is missing.

    Each user contributes ``min(#positives, #negative candidates)`` samples
    of each label. Positives are taken in file order and negatives from the
    top of the ranking. Users for whom that number is 0 contribute nothing.

    Users are emitted in the order they first appear in the file. This is
    dict insertion order, which is guaranteed only on Python 3.7+. On older
    interpreters the order, and so any preview of the first rows, can differ.

    Args:
        input_file: path to the user/item rating file.
        score_thr: rating at or above which a rating counts as positive.
            Defaults to 4.0, the value that was previously hard-coded.

    Returns:
        A list ``[(userid, itemid, label), ...]`` with label 1 or 0, or
        ``[]`` when ``input_file`` does not exist.

    Raises:
        ValueError: if a data line's rating field is not a number.
    """
    if not os.path.exists(input_file):
        return []
    score_dict = get_ave_score(input_file)
    pos_dict = {}
    neg_dict = {}
    # ``with`` guarantees the file is closed even if float() raises below.
    with open(input_file) as fp:
        next(fp, None)  # skip the header line
        for line in fp:
            item = line.strip().split(',')
            if len(item) < 4:
                continue
            userid, itemid, rating = item[0], item[1], float(item[2])
            # Register the user in both dicts on first sight. pos_dict's key
            # order then matches each user's first appearance in the file.
            pos_list = pos_dict.setdefault(userid, [])
            neg_list = neg_dict.setdefault(userid, [])
            if rating >= score_thr:
                pos_list.append((itemid, 1))
            else:
                neg_list.append((itemid, score_dict.get(itemid, 0)))

    train_data = []
    for userid, pos_list in pos_dict.items():
        neg_list = neg_dict.get(userid, [])
        data_num = min(len(pos_list), len(neg_list))
        if data_num == 0:
            continue
        train_data += [(userid, itemid, label) for itemid, label in pos_list[:data_num]]
        # Highest-scoring negatives first. sorted() is stable, so ties keep
        # their original file order.
        top_neg = sorted(neg_list, key=lambda element: element[1], reverse=True)[:data_num]
        train_data += [(userid, itemid, 0) for itemid, _ in top_neg]
    return train_data
if __name__ == "__main__":
    # Smoke run: build samples from the course ratings file, then report
    # how many were produced along with a preview of the first 20.
    train_data = get_train_data("../data/ratings.txt")
    sample_count, preview = len(train_data), train_data[:20]
    print(sample_count)
    print(preview)
写回答
1回答
-
David
2022-10-09
请把图片里的文字粘贴一下，那张图片看不见。
032022-10-11
相似问题