分类任务:从MNIST说起(二)二元分类器及精度、召回率


我把自己的原码放在这里了,用Jupyter notebook编写的,包含该系列的所有代码,有需要的同学请自取!!

附件点击下载

CV界的HelloWorld:MNIST手写字识别(二)训练二元分类器

从这一篇博客开始着手编程进行MNIST数据集的分类。一开始先让我们简化一下问题,先尝试只识别一个数字,如数字5.这样“数字5检测器”就变成了一个二元分类器,分类结果只有5非5

用SGD搭建二元分类器

首先划分训练集和测试集,这里我是按照6:1的比例,前60000张作为训练集,最后10000张作为测试集。

X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

然后为分类任务创建目标向量:

# 若为5则为true,非5则为false
y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5)

分类器我们采用随机梯度下降(stochastic parallel gradient descent,SGD)算法,也是很多神经网络采用的优化算法。这个分类器能够有效处理大型数据集。

from sklearn.linear_model import SGDClassifier

sgd_clf = SGDClassifier(max_iter=1000, tol=1e-3, random_state=42)
sgd_clf.fit(X_train, y_train_5)

然后用它来检测数字5的图片(见上一篇博客CV界的HelloWorld:MNIST手写字识别(一)数据集介绍):

>>>sgd_clf.predict([some_digit])

输出:

array([ True])

分类正确!

测试分类器性能

我们采用交叉验证来测量准确率。
关于K-折交叉验证:该方法将训练集随机分割成$K$个不同的自己,每个子集称为一个折叠,然后对模型进行$K$次训练和评估,每次挑选1个折叠进行评估,使用另外$K-1$个折叠进行训练。

交叉验证的两种方法:
法一:自己实现交叉验证的代码

from sklearn.model_selection import StratifiedKFold
from sklearn.base import clone

skfolds = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)

for train_index, test_index in skfolds.split(X_train, y_train_5):
    clone_clf = clone(sgd_clf)
    X_train_folds = X_train[train_index]
    y_train_folds = y_train_5[train_index]
    X_test_fold = X_train[test_index]
    y_test_fold = y_train_5[test_index]

    clone_clf.fit(X_train_folds, y_train_folds)
    y_pred = clone_clf.predict(X_test_fold)
    n_correct = sum(y_pred == y_test_fold)
    print(n_correct / len(y_pred))

输出:

0.9669
0.91625
0.96785

法二:直接调用内置函数!(理直气壮)

from sklearn.model_selection import cross_val_predict
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)

输出

array([0.95035, 0.96035, 0.9604 ])

题外话,我们是否应该重复造轮子

虽然Scikit-learn有内置的交叉验证函数cross_val_predict,但是依然建议大家手动实现一下交叉验证的函数,这样对交叉验证的理解也会加深。大家都说“不要重复发明轮子”,但没有人说“不要重复造轮子”。动手实现别人封装好的代码无论对概念的理解还是对自身编程水平的提高都有很大的锻炼。

关于分类结果的说明

准确率高并不代表模型有效

从上面的结果中我们可以看到,我们的准确率非常高,达到了$93%$,但这并不代表我们模型能够准确分类。因为数据集里有超过$90%$的数据是非5,只有$10%$的数据是5。换句话说,就算模型把所有的预测都输出非5,模型也能有$90%$的准确率。这也说明了准确率无法成为分类器的首要性能指标,特别是处理有偏数据集时。

混淆矩阵

评估分类器性能的更好方法是混淆矩阵(Confusion Matrix)。混淆矩阵的总体思路是:统计A类别实例被分成B类别的次数。例如:想要知道分类器将数字3和数字5混淆了多少次,只需要通过混淆矩阵的第5行第3列来查看。

让我们用训练集进行预测,通过交叉验证绘制混淆矩阵。

from sklearn.model_selection import cross_val_predict

y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)

y_train_predy_train_score一样,用来执行K-折交叉验证,但返回的是每个折叠的预测。每个实例都会得到一个干净的预测,即模型预测时使用的数据在训练期间从未出现过

然后通过confusion_matrix()函数计算混淆矩阵:

from sklearn.metrics import confusion_matrix

confusion_matrix(y_train_5, y_train_pred)

输出:

array([[53892, 687],
[1891, 3530]], dtype=int64)

注意:SGD算法的参数不是固定的,所以混淆矩阵的值每次都不一样。

混淆矩阵的行表示实际类别,列表示预测类别。
本例中:第一行表示所有非5的图片中:53057张被正确分类为非5类,687张被错误分类成了5;第二行表示所有5的图片长,1891张被错误分类成非5,3530张被正确分在了5。而一个完美的分类器,只会在其左上到右下的对角线上有非零值:

# 假设分类器是完美的
y_train_perfect_predictions = y_train_5  
confusion_matrix(y_train_5, y_train_perfect_predictions)

输出:

array([[54579, 0],
[ 0, 5421]])

精度和召回率

混淆矩阵能提供大量信息,但我们希望指标能更简洁一些。下面我们介绍精度和召回率的概念。
精度:是正类预测的准确率,公式为:

$$ 精度=\frac{TP}{TP+FP} $$

其中,$TP$表示真正类的数量,$FP$表示假正类的数量。

Scikit-learn提供的精度计算函数如下:

from sklearn.metrics import precision_score, recall_score

precision_score(y_train_5, y_train_pred)

输出:

0.8370879772350012

单独的精度并没有什么意义,因为他忽略了正类实例以外的所有内容。所以需要与其他的指标联合使用,即召回率,或灵敏度。

召回率/灵敏度:分类器正确检测到的正类实例的比率,公式为:

$$ 召回率=\frac{TP}{TP+FN} $$

其中$FN$是假负类的数量。

from sklearn.metrics import precision_score, recall_score

precision_score(y_train_5, y_train_pred)

输出:

0.6511713705958311

从这两个指标中我们可以看到,模型说一张图片是5时,只有$83.7%$的概率是正确的,而且只有$65.1$的数字5倍检测出来了。

F!

我们还可以将精度和召回率组合成一个单一的指标,称为$F_1$分数。$F_1$分数是精度和召回率的谐波平均值。正常的平均值平等对待所有的值,但谐波平均值会给予低值更高的权重。因此,只有当召回率和精度都很高时,分类器才能得到较高的$F_1$分数。
公式如下:

$$ F_1 =\frac{2}{\frac{1}{精度}+\frac{1}{召回率}}=2\times\frac{精度\times召回率}{精度+召回率}=\frac{TP}{TP+\frac{FN+FP}{2}} $$

要计算F1分数,只需要调用f1_score()即可:

from sklearn.metrics import f1_score

f1_score(y_train_5, y_train_pred)

输出:

0.7325171197343846

分类器的精度和召回率相近时,$F_1$分数更高。

精度/召回率权衡

鱼和熊掌不可兼得

模型不能同时增加精度又减少召回率,这就叫做精度/召回率权衡。很多时候我们只能根据需求来选择舍弃精度还是舍弃召回率。

  • 情景一:编写程序来过滤适合儿童观看的视频,我们宁愿过滤掉一些无害的视频(正类),也不愿漏掉一个暴力血腥视频(负类)。
    权衡选择:高精度、低召回率
  • 情景二:编写程序来通过监控视频检测小偷,我们宁愿抓错很多个好人$^{[1]}$(正类),也不愿放过任何一个坏人(负类)
    权衡选择:低精度,高召回率
[1] 没有对好人不敬的意思,安得广厦千万间大辟天下寒士俱欢颜(真诚)

精度/召回率的权衡过程
想要理解精度和召回率的权衡过程,首先要理解SGDClassifier如何进行分类决策。对于每个实例,会根据决策函数计算出一个值,若这个值大于阈值,则将该实例分为正类,否则分为负类。阈值越大,精度越高,但是召回率越低;阈值越小,精度越低,但是召回率越高。

Scikit-learn不允许直接设置阈值,但可以访问它用于预测的决策函数。我们调用decision_function()方法,该方法返回每个实例的分数,然后根据这些分数,我们就可以使用任意阈值进行预测了。
依然以第一张图数字5(变量some_digit)为例:

y_scores = sgd_clf.decision_function([some_digit])
y_scores

输出:

array([2164.22030239])

此时该图的预测分数是2164。
当阈值设为0时,显然该图分数大于阈值,分类器能检测到该图:

threshold = 0
y_some_digit_pred = (y_scores > threshold)
y_some_digit_pred

输出:

array([True])

但当阈值调整为8000时,分类器就无法检测出该图:

threshold = 8000
y_some_digit_pred = (y_scores > threshold)
y_some_digit_pred

输出:

array([False])

如何设置阈值?

首先,使用cross_val_predict()函数获取训练集中所有实例的分数,注意是决策分数,不是预测结果

y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3,
                             method="decision_function")

而后就可以使用precision_recall_curve()函数计算所有可能的阈值的精度和召回率:

from sklearn.metrics import precision_recall_curve
precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)

绘制Matplotlib绘制精度和召回率相对于阈值的函数图:

def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    plt.plot(thresholds, precisions[:-1], "b--", label="Precision", linewidth=2)
    plt.plot(thresholds, recalls[:-1], "g-", label="Recall", linewidth=2)
    plt.legend(loc="center right", fontsize=16)
    plt.xlabel("Threshold", fontsize=16) 
    plt.grid(True) 
    plt.axis([-50000, 50000, 0, 1]) 


recall_90_precision = recalls[np.argmax(precisions >= 0.90)]
threshold_90_precision = thresholds[np.argmax(precisions >= 0.90)]


plt.figure(figsize=(8, 4))                                                                  # Not shown
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)
plt.plot([threshold_90_precision, threshold_90_precision], [0., 0.9], "r:")                 # Not shown
plt.plot([-50000, threshold_90_precision], [0.9, 0.9], "r:")                                # Not shown
plt.plot([-50000, threshold_90_precision], [recall_90_precision, recall_90_precision], "r:")# Not shown
plt.plot([threshold_90_precision], [0.9], "ro")                                             # Not shown
plt.plot([threshold_90_precision], [recall_90_precision], "ro")                             # Not shown
save_fig("precision_recall_vs_threshold_plot")                                              # Not shown
plt.show()

绘制图像如下:



图1

为什么阈值偏高时,精度会出现波动?

其实很好理解,当精度变大时,预测正类中依然可能存在误分类的负类。当我们排除了一个正确的正类时,精度就会下降。比如排除掉一个正确的正类,精度可能会从$4/5(80%)$变成$3/4(75%)$。
但是,召回率绝对不会出现波动。因为当阈值上升时,未被检测到的正类只会越来越多,反之亦然。所以召回率曲线必然是平滑的。


还有一种搜索精度/召回率权衡的方法:直接绘制精度和召回率的函数图。
绘制方法如下:

def plot_precision_vs_recall(precisions, recalls):
    plt.plot(recalls, precisions, "b-", linewidth=2)
    plt.xlabel("召回率", fontsize=16)
    plt.ylabel("精度", fontsize=16)
    plt.axis([0, 1, 0, 1])
    plt.grid(True)

plt.figure(figsize=(8, 6))
plot_precision_vs_recall(precisions, recalls)
plt.plot([recall_90_precision, recall_90_precision], [0., 0.9], "r:")
plt.plot([0.0, recall_90_precision], [0.9, 0.9], "r:")
plt.plot([recall_90_precision], [0.9], "ro")
save_fig("precision_vs_recall_plot")
plt.show()

绘制图像如下:



图2

可以看到,从$80%$的召回率往右,精度开始急剧下降。所以我们尽量在陡降发生前选择一个精度/召回率权衡,比如在召回率$60%$左右。

假设我们现在需要一个精度为$90%$的分类器。查找图一,会发现需要设置3500的阈值。查找方式当然不是用眼看!使用函数搜索:

threshold_90_precision = thresholds[np.argmax(precisions >= 0.90)]
threshold_90_precision

输出:

3370.0194991439557

要进行预测,除了调用分类器的predict()方法,也可以运行这段代码:

y_train_pred_90 = (y_scores >= threshold_90_precision)
print('精度:',precision_score(y_train_5, y_train_pred_90)) # 精度
print('召回率:',recall_score(y_train_5, y_train_pred_90)) # 精度

输出:

精度: 0.9000345901072293
召回率: 0.4799852425751706

到此为止我们可以按照自己的想法随心所欲的指定二分类器的精度了!(代价是召回率哭了)

声明:奋斗小刘|版权所有,违者必究|如未注明,均为原创|本网站采用BY-NC-SA协议进行授权

转载:转载请注明原文链接 - 分类任务:从MNIST说起(二)二元分类器及精度、召回率


Make Everyday Count