Java Deep Learning Tutorial
So bauen Sie ein neuronales Netzwerk auf
System feinabstimmen
Nun müssen wir nur noch einige Trainingsdaten in das Netz fließen lassen und es mit weiteren Predictions austesten. Im Folgenden betrachten wir, wie man Trainingsdaten bereitstellt.
Listing 9: Trainingsdaten
List<List<Integer>> data = new ArrayList<List<Integer>>();
data.add(Arrays.asList(115, 66));
data.add(Arrays.asList(175, 78));
data.add(Arrays.asList(205, 72));
data.add(Arrays.asList(120, 67));
List<Double> answers = Arrays.asList(1.0,0.0,0.0,1.0);
Network network = new Network();
network.train(data, answers);
In Listing 9 bestehen unsere Trainingsdaten aus einer Liste von zweidimensionalen Integer-Sets (wir könnten sie uns als Gewicht und Größe vorstellen) und einer Liste von Antworten (wobei 1.0 weiblich und 0.0 männlich ist).
Wenn wir den Trainingsalgorithmus eine Logging-Funktionalität hinzufügen, erhalten wir nachfolgendes Resultat.
Listing 10. Trainingsdaten-Logging
// Logging:
if (epoch % 10 == 0) System.out.println(String.format("Epoch: %s | bestEpochLoss: %.15f | thisEpochLoss: %.15f", epoch, bestEpochLoss, thisEpochLoss));
// output:
Epoch: 910 | bestEpochLoss: 0.034404863820424 | thisEpochLoss: 0.034437939546120
Epoch: 920 | bestEpochLoss: 0.033875954196897 | thisEpochLoss: 0.431451026477016
Epoch: 930 | bestEpochLoss: 0.032509260025490 | thisEpochLoss: 0.032509260025490
Epoch: 940 | bestEpochLoss: 0.003092720117159 | thisEpochLoss: 0.003098025397281
Epoch: 950 | bestEpochLoss: 0.002990128276146 | thisEpochLoss: 0.431062364628853
Epoch: 960 | bestEpochLoss: 0.001651762688346 | thisEpochLoss: 0.001651762688346
Epoch: 970 | bestEpochLoss: 0.001637709485751 | thisEpochLoss: 0.001636810460399
Epoch: 980 | bestEpochLoss: 0.001083365453009 | thisEpochLoss: 0.391527869500699
Epoch: 990 | bestEpochLoss: 0.001078338540452 | thisEpochLoss: 0.001078338540452
Wie in Listing 10 zu sehen, nimmt der "Loss" (also die Fehlerabweichung von "100 Prozent korrekt") langsam ab. Das Modell nähert sich also immer mehr einer genauen Vorhersage an. Nun gilt es zu überprüfen, wie gut unser Modell mit echten Daten funktioniert.
Listing 11: Vorhersagen
System.out.println("");
System.out.println(String.format(" male, 167, 73: %.10f", network.predict(167, 73)));
System.out.println(String.format("female, 105, 67: %.10", network.predict(105, 67)));
System.out.println(String.format("female, 120, 72: %.10f | network1000: %.10f", network.predict(120, 72)));
System.out.println(String.format(" male, 143, 67: %.10f | network1000: %.10f", network.predict(143, 67)));
System.out.println(String.format(" male', 130, 66: %.10f | network: %.10f", network.predict(130, 66)));
Wie in Listing 11 zu sehen, füttern wir unser trainiertes neuronales Netz mit Daten und geben die Vorhersagen aus. Das Resultat sieht in etwa wie folgt aus.
Listing 12: Trainierte Predictions
male, 167, 73: 0.0279697143
female, 105, 67: 0.9075809407
female, 120, 72: 0.9075808235
male, 143, 67: 0.0305401413
male, 130, 66: network: 0.9009811922
Listing 12 zeigt, dass das Netzwerk bei den meisten Wertepaaren (Vektoren) ziemlich gute Arbeit geleistet hat. Es gibt den weiblichen Datensätzen eine Schätzung um 0.907 - was ziemlich nahe an 1 ist. Zwei männliche Datensätze weisen 0.027 und 0.030 auf - und nähern sich damit der 0. Der männliche Ausreißer-Datensatz (130, 67) wird als "wahrscheinlich weiblich" angesehen, bei einem Wert von 0.900 allerdings mit geringerer Zuversicht.
Es gibt eine Reihe von Möglichkeiten, die Einstellungen an diesem System zu verändern: Die Anzahl der Epochs in einem Trainingslauf ist dabei ein wichtiger Faktor. Je mehr Epochs, desto besser wird das Modell auf die Daten abgestimmt. Das kann auch die Genauigkeit von Live-Daten verbessern, die mit den Trainingssätzen übereinstimmen. Allerdings kann es auch in einem "Overtraining" resultieren - also einem Modell, das zuversichtlich die falschen Ergebnisse für Randfälle vorhersagt.
Den vollständigen Code für dieses Tutorial finden Sie in diesem GitHub-Repository - zusammen mit einigen zusätzlichen Funktionen. (fm)
Dieser Beitrag basiert auf einem Artikel unserer US-Schwesterpublikation Infoworld.