我把自己的原码放在这里了,用Jupyter notebook编写的,包含该系列的所有代码,有需要的同学请自取!!
分类任务:从MNIST说起(四)多分类器
有一些算法(如随机森林或朴素贝叶斯分类器)可以直接处理多个类。也有一些严格的二元分类器(如:支持向量机、线性分类器)。但有很多方法可以用几个二元分类器实现多类分类。
- 法一:剥洋葱式分类器(OvR策略)
训练10个二元分类器,每个数组一个(0-检测器、1-检测器、······),当需要对一张图片进行检测分类时,获取每个分类器的决策分数,哪个分类器给分最高就分为哪个类。OvR策略也成一对多(($One-versus-All$))。 - 法二:排列组合分类器(OvO策略)
为每一对数字训练一个二元分类器,如:一个用于区分0和1、一个用于区分0和2、······如果存在N个类别,则需要训练$\frac{N\times(N-1)}{2}$个分类器。对于MNIST,需要训练45个二元分类器。对一张图片分类时,需要运行45个二分类,看看哪个类获胜最多。
用Scikit-Learn实现SVM多分类器
Scikit-Learn会根据情况自动运行OvR或者OvO。
我们先训练一个小型分类器,用数据集的前1000个样本进行训练:
from sklearn.svm import SVC
svm_clf = SVC(gamma="auto", random_state=42)
svm_clf.fit(X_train[:1000], y_train[:1000]) # y_train, not y_train_5
svm_clf.predict([some_digit])
输出:
array([5], dtype=uint8)
可以看到训练结果正好是第5类,分类正确。
注意:“数字5值第5类”只是个巧合
在训练分类器时,目标类的列表会存储在classes_
属性中,按照值的大小排序。在本例中,classes_
数组中每个类的索引正好对应其类本身(即:索引上的第5个类正好是数字5这个类)但是一般来说,并不会这么巧。所以一定要做好索引与标签的对应。
指定OvR或OvO策略
我们可以强制Scikit-Learn使用一对一或一对剩余策略,可以使用OneVsOne-Classifier
或OneVsRestClassifier
类。
例如,使用OvR策略,基于SVC创建多分类器:
from sklearn.multiclass import OneVsRestClassifier
ovr_clf = OneVsRestClassifier(SVC(gamma="auto", random_state=42))
ovr_clf.fit(X_train[:1000], y_train[:1000])
ovr_clf.predict([some_digit])
输出:
array([5], dtype=uint8)
估计器长度:
len(ovr_clf.estimators_)
10
用多分类器进行分类
我们可以直接用SGD或者随机森林进行分类,这两个分类器可以直接将实例分为多个类。
sgd_clf.fit(X_train, y_train)
sgd_clf.predict([some_digit])
输出:
array([3], dtype=uint8)
看来误差还是很大,5错分成了3
调用decision_function()
查看一下分类器将每个实例分类为每个类的概率列表:
sgd_clf.decision_function([some_digit])
输出:
array([[-31893.03095419, -34419.69069632, -9530.63950739,
1823.73154031, -22320.14822878, -1385.80478895,
-26188.91070951, -16147.51323997, -4604.35491274,
-12050.767298 ]])
可以看到,类3的得分是1823,所有分类中最高的。而正确答案类5的得分是-1385,分类第二高。
我们再用交叉验证测试一下:
cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring="accuracy")
输出:
array([0.87365, 0.85835, 0.8689 ])
每个折叠的准确率都达到了$85%$。如果是纯随机分类器,准确率大约是$10%$。相比之下这个分类器的结果也还能接受。
但我们可以通过简单缩放提高一下准确率:
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train.astype(np.float64))
cross_val_score(sgd_clf, X_train_scaled, y_train, cv=3, scoring="accuracy")
输出:
array([0.8983, 0.891 , 0.9018])
有了明显的提升。
Comments | NOTHING