Читать книгу Praxiseinstieg Machine Learning mit Scikit-Learn, Keras und TensorFlow - Aurélien Géron - Страница 91

MNIST

Оглавление

In diesem Kapitel werden wir den MNIST-Datensatz verwenden, eine Sammlung von 70.000 kleinen Bildern handschriftlicher Ziffern, die von Oberschülern und Mitarbeitern des US Census Bureaus aufgeschrieben wurden. Jedes Bild ist mit der dargestellten Ziffer gelabelt. Dieser Datensatz ist so intensiv untersucht worden, dass er oft als »Hello World« des Machine Learning bezeichnet wird: Wann immer jemand ein neues Klassifikationsverfahren entwickelt, möchte man wissen, wie es auf MNIST abschneidet. Jeder, der Machine Learning lernt, beschäftigt sich früher oder später mit MNIST.

Scikit-Learn enthält viele Hilfsfunktionen zum Herunterladen verbreiteter Datensätze. MNIST ist einer davon. Der folgende Code besorgt den MNIST-Datensatz:1

>>> from sklearn.datasets import fetch_openml

>>> mnist = fetch_openml('mnist_784', version=1)

>>> mnist.keys()

dict_keys(['data', 'target', 'feature_names', 'DESCR', 'details',

'categories', 'url'])

Die von Scikit-Learn heruntergeladenen Datensätze sind für gewöhnlich Dictionaries mit einer ähnlichen Struktur, bestehend aus folgenden Schlüsseln:

 Der Schlüssel DESCR beschreibt den Datensatz.

 Der Schlüssel data enthält ein Array mit einer Zeile pro Datenpunkt und einer Spalte pro Merkmal.

 Der Schlüssel target enthält ein Array mit den Labels.

Betrachten wir die beiden Arrays:

>>> X, y = mnist["data"], mnist["target"]

>>> X.shape

(70000, 784)

>>> y.shape

(70000,)

Es gibt 70.000 Bilder, und jedes davon hat 768 Merkmale. Das liegt daran, dass jedes Bild aus 28 x 28 Pixeln besteht und jedes Merkmal einfach die Intensität eines Pixels von 0 (weiß) bis 255 (schwarz) enthält. Betrachten wir eine Ziffer aus dem Datensatz. Dazu müssen wir lediglich den Merkmalsvektor eines Datenpunkts herausgreifen, zu einem Array mit den Abmessungen 28 x 28 umformatieren und mit der Funktion imshow() aus Matplotlib darstellen:

import matplotlib as mpl

import matplotlib.pyplot as plt

some_digit = X[0]

some_digit_image = some_digit.reshape(28, 28)

plt.imshow(some_digit_image, cmap="binary")

plt.axis("off")

plt.show()


Dieses Bild sieht wie eine 5 aus, was uns das Label auch bestätigt:

>>> y[0]

'5'

Beachten Sie, dass das Label ein String ist. Die meisten ML-Algorithmen erwarten Zahlen, daher wollen wir y in einen Integer casten:

>>> y = y.astype(np.uint8)

Abbildung 3-1 zeigt einige weitere Bilder aus dem MNIST-Datensatz, um Ihnen ein Gefühl für die Komplexität dieser Klassifikationsaufgabe zu geben.

Einen Moment noch! Sie sollten stets einen Testdatensatz erstellen und ihn vor dem genaueren Betrachten der Daten beiseitelegen. Der MNIST-Datensatz ist bereits in Trainingsdaten (die ersten 60.000 Bilder) und Testdaten (die letzten 10.000 Bilder) unterteilt:

X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

Abbildung 3-1: Ziffern aus dem MNIST-Datensatz

Die Trainingsdaten sind schon für uns gemischt, was gut ist, denn damit stellen wir sicher, dass bei der Kreuzvalidierung sämtliche Folds einander ähnlich sind (Sie möchten nicht, dass einige Ziffern in einem Fold fehlen). Außerdem reagieren ein paar Lernalgorithmen sensibel auf die Reihenfolge der Trainingsdatenpunkte und schneiden schlechter ab, wenn sie viele ähnliche Datenpunkte nacheinander erhalten. Das Mischen des Datensatzes sorgt dafür, dass dies nicht passiert.2

Praxiseinstieg Machine Learning mit Scikit-Learn, Keras und TensorFlow

Подняться наверх