我试图计算多标签文本分类的汉明损失和汉明分数
def hamming_score(y_true, y_pred, normalize=True, sample_weight=None):
acc_list = []
for i in range(y_true.shape[0]):
set_true = set( np.where(y_true[i])[0] )
set_pred = set( np.where(y_pred[i])[0] )
tmp_a = None
if len(set_true) == 0 and len(set_pred) == 0:
tmp_a = 1
else:
tmp_a = len(set_true.intersection(set_pred))/float(len(set_true.union(set_pred)))
acc_list.append(tmp_a)
return np.mean(acc_list)
def print_score(y_pred, clf):
print("Clf: ", clf.__class__.__name__)
print("Hamming loss: {}".format(hamming_loss(y_pred, y_test)))
print("Hamming score: {}".format(hamming_score(y_pred, y_test)))
print("---")
nb_clf = MultinomialNB()
sgd = SGDClassifier(loss='hinge', penalty='l2', alpha=1e-3, random_state=42, max_iter=6, tol=None)
lr = LogisticRegression()
mn = MultinomialNB()
for classifier in [nb_clf, sgd, lr, mn]:
clf = OneVsRestClassifier(classifier)
clf.fit(x_train, y_train)
y_pred = clf.predict(x_test)
print_score(y_pred, classifier)
汉明损失的结果发生了,但汉明分数出现了错误,有人能帮我解决这个问题吗?非常感谢。
Clf:多项式NB汉明损失:0.01911111111-----------------------------------------------8 Clf中的ValueError回溯(最近一次调用)。安装(x_系列,y_系列)9 y_pred=clf。预测(x_检验)---
<ipython-input-313-60ed43baa4c1> in print_score(y_pred, clf)
21 print("Clf: ", clf.__class__.__name__)
22 print("Hamming loss: {}".format(hamming_loss(y_pred, y_test)))
---> 23 print("Hamming score: {}".format(hamming_score(predictions, y_test)))
24 print("---")
<ipython-input-313-60ed43baa4c1> in hamming_score(y_true, y_pred, normalize, sample_weight)
8 acc_list = []
9 for i in range(y_true.shape[0]):
---> 10 set_true = set( np.where(y_true[i])[0] )
11 set_pred = set( np.where(y_pred[i])[0] )
12 tmp_a = None
~\Anaconda3\lib\site-packages\scipy\sparse\base.py in __bool__(self)
285 return self.nnz != 0
286 else:
--> 287 raise ValueError("The truth value of an array with more than one "
288 "element is ambiguous. Use a.any() or a.all().")
289 __nonzero__ = __bool__
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all().
您好,您可以尝试下面的代码,看看它是否工作
print("Hamming score: {}".format(hamming_score(y_pred, y_test,normalize=False)))