Читать книгу Praxiseinstieg Machine Learning mit Scikit-Learn, Keras und TensorFlow - Aurélien Géron - Страница 104
Fehleranalyse
ОглавлениеIn einem echten Projekt würden Sie an dieser Stelle natürlich die Schritte auf der Checkliste für Machine-Learning-Projekte abarbeiten (siehe Anhang B): Optionen zur Datenaufarbeitung abwägen, mehrere Modelle ausprobieren, die besten in eine engere Auswahl ziehen, deren Hyperparameter mit GridSearchCV optimieren und so viel wie möglich automatisieren. Wir nehmen an dieser Stelle an, dass Sie ein vielversprechendes Modell gefunden haben und nach Verbesserungsmöglichkeiten suchen. Eine Möglichkeit ist, die Arten von Fehlern zu untersuchen, die das Modell begeht.
Sie können sich zunächst die Konfusionsmatrix ansehen. Dazu müssen Sie als Erstes über die Funktion cross_val_predict() Vorhersagen treffen und anschließend die bereits gezeigte Funktion confusion_matrix() aufrufen:
>>> y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)
>>> conf_mx = confusion_matrix(y_train, y_train_pred)
>>> conf_mx
array([[5578, 0, 22, 7, 8, 45, 35, 5, 222, 1],
[ 0, 6410, 35, 26, 4, 44, 4, 8, 198, 13],
[ 28, 27, 5232, 100, 74, 27, 68, 37, 354, 11],
[ 23, 18, 115, 5254, 2, 209, 26, 38, 373, 73],
[ 11, 14, 45, 12, 5219, 11, 33, 26, 299, 172],
[ 26, 16, 31, 173, 54, 4484, 76, 14, 482, 65],
[ 31, 17, 45, 2, 42, 98, 5556, 3, 123, 1],
[ 20, 10, 53, 27, 50, 13, 3, 5696, 173, 220],
[ 17, 64, 47, 91, 3, 125, 24, 11, 5421, 48],
[ 24, 18, 29, 67, 116, 39, 1, 174, 329, 5152]])
Dies ist eine Menge Zahlen. Es ist oft bequemer, sich eine Konfusionsmatrix über die Matplotlib-Funktion matshow() als Diagramm anzeigen zu lassen:
plt.matshow(conf_mx, cmap=plt.cm.gray)
plt.show()
Diese Konfusionsmatrix sieht recht gut aus, da die meisten Bilder auf der Hauptdiagonalen liegen. Dies bedeutet, dass sie korrekt zugeordnet wurden. Die 5en sehen etwas dunkler als die anderen Ziffern aus. Dies könnte bedeuten, dass es weniger Bilder von 5en im Datensatz gibt oder dass der Klassifikator bei den 5en nicht so gut wie bei den übrigen Ziffern abschneidet. Es lässt sich zeigen, dass hier beides der Fall ist.
Heben wir im Diagramm nun die Fehler hervor. Zunächst müssen Sie jeden Wert in der Konfusionsmatrix durch die Anzahl der Bilder in der entsprechenden Kategorie teilen, sodass Sie den Anteil der Fehler statt der absoluten Anzahl Fehler vergleichen können (Letzteres würde die häufigen Kategorien übertrieben schlecht aussehen lassen):
row_sums = conf_mx.sum(axis=1, keepdims=True)
norm_conf_mx = conf_mx / row_sums
Nun füllen wir die Diagonale mit Nullen auf, um nur die Fehler zu betrachten, und plotten das Ergebnis:
np.fill_diagonal(norm_conf_mx, 0)
plt.matshow(norm_conf_mx, cmap=plt.cm.gray)
plt.show()
Nun sehen Sie deutlich, welche Arten von Fehler der Klassifikator begeht. Wie gesagt, die Zeilen stehen für die tatsächlichen Kategorien und die Spalten für die vorhergesagten. Die Spalte der Kategorie 8 ist recht hell, es wurden also recht viele Bilder fälschlicherweise als 8 zugeordnet. Aber die Zeile für Kategorie 8 ist gar nicht so schlecht, was Ihnen zeigt, dass echte 8er im Allgemeinen korrekt als 8er klassifiziert werden. Wie Sie sehen, muss die Konfusionsmatrix nicht notwendigerweise symmetrisch sein. Sie können auch erkennen, dass 3er und 5er häufig verwechselt werden (in beide Richtungen).
Eine Analyse der Konfusionsmatrix gibt Ihnen häufig Aufschluss darüber, wie sich Ihr Klassifikator verbessern lässt. Aus diesem Diagramm sehen Sie, dass sich Ihre Mühe auf das Verbessern der Klassifikation der falschen 8er konzentrieren sollte. Sie könnten also beispielsweise versuchen, zusätzliche Trainingsdaten für Ziffern zu sammeln, die wie 8er aussehen (es aber nicht sind), sodass der Klassifikator lernen kann, sie von den echten 8ern zu unterscheiden. Oder Sie probieren, neue Merkmale ermitteln, die dem Klassifikator helfen können – beispielsweise ein Algorithmus, der die geschlossenen Schleifen zählt (eine 8 hat zwei, eine 6 hat eine, eine 5 keine). Oder Sie könnten die Bilder vorverarbeiten (z.B. mit Scikit-Image, Pillow oder OpenCV), um vorhandene Muster wie die Schleifen besser hervorzuheben.
Das Analysieren einzelner Fehler hilft auch dabei, zu erkennen, was Ihr Klassifikator eigentlich tut und warum er scheitert. Dies ist aber schwieriger und zeitlich aufwendiger. Wir könnten etwa Beispiele von 3en und 5en plotten:
cl_a, cl_b = 3, 5
X_aa = X_train[(y_train == cl_a) & (y_train_pred == cl_a)]
X_ab = X_train[(y_train == cl_a) & (y_train_pred == cl_b)]
X_ba = X_train[(y_train == cl_b) & (y_train_pred == cl_a)]
X_bb = X_train[(y_train == cl_b) & (y_train_pred == cl_b)]
plt.figure(figsize=(8,8))
plt.subplot(221); plot_digits(X_aa[:25], images_per_row=5)
plt.subplot(222); plot_digits(X_ab[:25], images_per_row=5)
plt.subplot(223); plot_digits(X_ba[:25], images_per_row=5)
plt.subplot(224); plot_digits(X_bb[:25], images_per_row=5)
plt.show()
Die zwei 5 × 5-Blöcke auf der linken Seite zeigen als 3en klassifizierte Ziffern, und die zwei 5 × 5-Blöcke auf der rechten Seite zeigen als 5en klassifizierte Ziffern. Einige der falsch zugeordneten Ziffern (z.B. in den Blöcken links unten und rechts oben) sind so unleserlich, dass sogar ein Mensch bei der Zuordnung Schwierigkeiten hätte (z.B. sieht die 5 in der ersten Zeile und zweiten Spalte wirklich wie eine schlecht geschriebene 3 aus). Es ist schwer nachzuvollziehen, warum der Klassifikator diese Fehler gemacht hat.3 Der Grund ist, dass wir einen einfachen SGDClas sifier verwendet haben, ein lineares Modell. Dieses ordnet in jeder Kategorie ein Gewicht pro Pixel zu und zählt bei einem neuen Bild einfach nur die gewichteten Intensitäten der Pixel zusammen, um einen Score für jede Kategorie zu berechnen. Da sich die 3en und 5en nur um einige Pixel unterscheiden, kommt das Modell bei diesen Ziffern leicht durcheinander.
Der Hauptunterschied zwischen 3en und 5en ist die Stellung der kurzen Linie, die die obere Linie mit dem unteren Bogen verbindet. Wenn Sie eine 3 zeichnen und diese Verbindungslinie ein Stück weiter links platzieren, könnte der Klassifikator das Bild leicht für eine 5 halten und umgekehrt. Anders gesagt, reagiert unser Klassifikator sehr sensibel auf Verschiebungen und Rotationen des Bilds. Ein Möglichkeit, der Verwechslung von 3 und 5 entgegenzuwirken, wäre daher, die Bilder so vorzuverarbeiten, dass sie alle zentriert und wenig gedreht sind. Vermutlich ließen sich auf diese Weise auch weitere Fehler vermeiden.