Künstliche neuronale Netze sind eine Form des Deep Learning. Der beste Weg, um ihre Funktionsweise vollständig zu durchdringen, besteht darin, sich selbst die Hände "schmutzig" zu machen. Dieser Artikel liefert Ihnen dafür die Grundlage und demonstriert, wie Sie ein neuronales Netzwerk in Java aufbauen und trainieren. Unser Beispiel für diesen Artikel ist dabei keineswegs ein produktionsreifes System - vielmehr gibt es in verständlicher Form Aufschluss über alle Hauptkomponenten.
Ein grundlegendes, neuronales Netz
Ein neuronales Netz ist ein Graph, bestehend aus Knoten (Nodes), die Neuronen genannt werden. Das Neuron ist die Grundeinheit der Berechnung: Es empfängt Eingaben und verarbeitet diese mithilfe:
eines Weight-per-Input-Algorithmus,
eines Bias-per-Node-Algorithmus sowie
eines Final-Function-Processor-Algorithmus.
Nachfolgende Abbildung zeigt ein Neuron mit zwei Inputs:
Dieses Modell ist sehr variabel - im Folgenden verwenden wir diese Konfiguration.
Unser erster Schritt besteht darin, eine Neuron
-Class zu modellieren, die diese Werte enthalten soll. Eine erste Version der Class sehen Sie im folgenden Listing 1 - diese wird sich im weiteren Verlauf verändern, wenn weitere Funktionen hinzukommen.
Listing 1: Eine einfache Neuron-Class
class Neuron {
Random random = new Random();
private Double bias = random.nextDouble(-1, 1);
public Double weight1 = random.nextDouble(-1, 1);
private Double weight2 = random.nextDouble(-1, 1);
public double compute(double input1, double input2){
double preActivation = (this.weight1 * input1) + (this.weight2 * input2) + this.bias;
double output = Util.sigmoid(preActivation);
return output;
}
}
Wie Sie sehen, ist die Neuron
-Class recht simpel und weist drei Mitglieder auf: bias
, weight1
und weight2
. Jedes dieser Mitglieder wird mit einem zufälligen Double zwischen -1 und 1 initialisiert.
Geht es darum, den Output des Neurons zu berechnen, folgen wir dem in Abbildung 1 gezeigten Algorithmus: Wir multiplizieren jede Eingabe mit ihrer Gewichtung plus dem Bias: input1 * weight1 + input2 * weight2 + bias. So erhalten wir die unverarbeitete Berechnung (preActivation
), die wir durch die Aktivierungsfunktion laufen lassen. In diesem Fall verwenden wir die Sigmoid-Aktivierungsfunktion, die Werte in einem Bereich von -1 bis 1 komprimiert. Im Folgenden die statische Util.sigmoid()
-Methode.
Listing 2: Sigmoid-Aktivierungsfunktion
public class Util {
public static double sigmoid(double in){
return 1 / (1 + Math.exp(-in));
}
}
Nachdem wir nun die Funktionsweise von Neuronen beleuchtet haben, gilt es, einige Neuronen in ein Netzwerk einfügen. Dazu nutzen wir eine Network
-Class mit einer Liste von Neuronen.
Listing 3: Die Neural Network Class
class Network {
List<Neuron> neurons = Arrays.asList(
new Neuron(), new Neuron(), new Neuron(), /* input nodes */
new Neuron(), new Neuron(), /* hidden nodes */
new Neuron()); /* output node */
}
}
Obwohl die Liste der Neuronen eindimensional ist, werden wir sie während der Nutzung zu einem Netzwerk verbinden. Die ersten drei Neuronen sind Inputs, die folgenden beiden versteckt und das letzte der Output-Knoten.
Eine Prediction anstoßen
Nun soll es darum gehen, ein Netzwerk zu Prediction-Zwecken einzusetzen. Dazu verwenden wir einen einfachen Datensatz mit zwei ganzzahligen Inputs und einem Antwortformat von 0 bis 1. In unserem Beispiel wird eine Kombination aus Gewicht und Größe verwendet, um das Geschlecht einer Person zu erraten.
Dabei wird davon ausgegangen, dass mehr Gewicht und Größe auf eine männliche Person hindeuten. Dieselbe Formel ließe sich für jede beliebige Wahrscheinlichkeitsrechnung mit zwei Faktoren und einem Output nutzen. Den Input könnte man auch als Vektor betrachten - und somit die Gesamtfunktion der Neuronen als Umwandlung eines Vektors in einem Skalarwert. Die Prediction-Phase des Netzes gestaltet sich wie folgt.
Listing 4: Network prediction
public Double predict(Integer input1, Integer input2){
return neurons.get(5).compute(
neurons.get(4).compute(
neurons.get(2).compute(input1, input2),
neurons.get(1).compute(input1, input2)
),
neurons.get(3).compute(
neurons.get(1).compute(input1, input2),
neurons.get(0).compute(input1, input2)
)
);
}
Listing 4 zeigt, dass die beiden Inputs an die ersten drei Neuronen fließen. Deren Outputs werden an die Neuronen 4 und 5 weitergeleitet wird, die wiederum in das Output-Neuron einspeisen. Dieser Prozess wird als Feedforward bezeichnet. Nun könnten wir das Netz zu einer Prediction auffordern.
Listing 5: Prediction
Network network = new Network();
Double prediction = network.predict(Arrays.asList(115, 66));
System.out.println("prediction: " + prediction);
Das würde sicher zu Ergebnissen führen - die allerdings nur auf Zufallswerten und Bias basieren. Für eine echte Prediction ist es nötig, das Netzwerk zuvor zu trainieren.
Das Netzwerk trainieren
Das Training eines neuronalen Netzwerks folgt einem Prozess, der als Backpropagation bekannt ist. Der beinhaltet im Grunde, Änderungen rückwärts durch das Netzwerk zu "schieben", damit sich der Output in Richtung eines gewünschten Zielwerts bewegt. Backpropagation lässt sich mit Hilfe von Funktionsdifferenzierung durchführen - für unser Beispiel werden wir allerdings einen anderen Weg gehen und jedem Neuron die Fähigkeit verleihen, zu "mutieren".
In jeder Trainingsrunde (auch Epoch genannt) wählen wir ein anderes Neuron aus, um eine kleine, zufällige Anpassung an einer seiner Eigenschaften (weight1
, weight2
oder bias
) vorzunehmen und dann zu prüfen, ob sich die Ergebnisse verbessern. Ist das der Fall, behalten wir diese Änderung mit einer remember()
-Methode bei. Wenn sich die Ergebnisse verschlechtert haben, machen wir sie mit einer forget()
-Methode rückgängig.
Um die Änderungen zu tracken, fügen wir Class-Mitglieder hinzu (old*
-Versionen von weights und bias). Im Folgenden betrachten wir die Methoden mutate()
, remember()
und forget()
.
Listing 6: mutate(), remember(), forget()
public class Neuron() {
private Double oldBias = random.nextDouble(-1, 1), bias = random.nextDouble(-1, 1);
public Double oldWeight1 = random.nextDouble(-1, 1), weight1 = random.nextDouble(-1, 1);
private Double oldWeight2 = random.nextDouble(-1, 1), weight2 = random.nextDouble(-1, 1);
public void mutate(){
int propertyToChange = random.nextInt(0, 3);
Double changeFactor = random.nextDouble(-1, 1);
if (propertyToChange == 0){
this.bias += changeFactor;
} else if (propertyToChange == 1){
this.weight1 += changeFactor;
} else {
this.weight2 += changeFactor;
};
}
public void forget(){
bias = oldBias;
weight1 = oldWeight1;
weight2 = oldWeight2;
}
public void remember(){
oldBias = bias;
oldWeight1 = weight1;
oldWeight2 = weight2;
}
}
Zusammengefasst:
Die
mutate()
-Methode wählt eine zufällige Eigenschaft und einen zufälligen Wert zwischen -1 und 1 aus und ändert dann die Eigenschaft.Die
forget()
-Methode setzt diese Änderung auf den alten Wert zurück.Die
remember()
-Methode kopiert den neuen Wert in den Puffer.
Um nun die neuen Fähigkeiten unseres Neuron
s zu nutzen, fügen wir Network
eine train()
-Methode hinzu.
Listing 7: Die Network.train()-Methode
public void train(List<List<Integer>> data, List<Double> answers){
Double bestEpochLoss = null;
for (int epoch = 0; epoch < 1000; epoch++){
// adapt neuron
Neuron epochNeuron = neurons.get(epoch % 6);
List<Double> predictions = new ArrayList<Double>();
for (int i = 0; i < data.size(); i++){
predictions.add(i, this.predict(data.get(i).get(0), data.get(i).get(1)));
}
Double thisEpochLoss = Util.meanSquareLoss(answers, predictions);
if (bestEpochLoss == null){
bestEpochLoss = thisEpochLoss;
epochNeuron.remember();
} else {
if (thisEpochLoss < bestEpochLoss){
bestEpochLoss = thisEpochLoss;
epochNeuron.remember();
} else {
epochNeuron.forget();
}
}
}
Die train()
-Methode iteriert eintausendmal über die aufgeführten data
, answers
und list
s. Es handelt sich um gleich große Trainingsmengen: data
beinhaltet Input-Werte, answers
die bekannten, richtigen Antworten. Die Methode ermittelt dann einen Wert darüber, wie nahe das Ergebnis des Netzwerks den bekannten, richtigen Antworten kommt. Dann wird ein zufälliges Neuron verändert (mutiert), wobei die Änderung beibehalten wird, wenn ein neuer Test ergibt, dass sie eine bessere Vorhersage zur Folge hatte.
Die Ergebnisse lassen sich mithilfe der Mean-Squared-Error (MSE) -Formel überprüfen - einer dafür gängigen Methode.
Listing 8: MSE-Funktion
public static Double meanSquareLoss(List<Double> correctAnswers, List<Double> predictedAnswers){
double sumSquare = 0;
for (int i = 0; i < correctAnswers.size(); i++){
double error = correctAnswers.get(i) - predictedAnswers.get(i);
sumSquare += (error * error);
}
return sumSquare / (correctAnswers.size());
}
System feinabstimmen
Nun müssen wir nur noch einige Trainingsdaten in das Netz fließen lassen und es mit weiteren Predictions austesten. Im Folgenden betrachten wir, wie man Trainingsdaten bereitstellt.
Listing 9: Trainingsdaten
List<List<Integer>> data = new ArrayList<List<Integer>>();
data.add(Arrays.asList(115, 66));
data.add(Arrays.asList(175, 78));
data.add(Arrays.asList(205, 72));
data.add(Arrays.asList(120, 67));
List<Double> answers = Arrays.asList(1.0,0.0,0.0,1.0);
Network network = new Network();
network.train(data, answers);
In Listing 9 bestehen unsere Trainingsdaten aus einer Liste von zweidimensionalen Integer-Sets (wir könnten sie uns als Gewicht und Größe vorstellen) und einer Liste von Antworten (wobei 1.0 weiblich und 0.0 männlich ist).
Wenn wir den Trainingsalgorithmus eine Logging-Funktionalität hinzufügen, erhalten wir nachfolgendes Resultat.
Listing 10. Trainingsdaten-Logging
// Logging:
if (epoch % 10 == 0) System.out.println(String.format("Epoch: %s | bestEpochLoss: %.15f | thisEpochLoss: %.15f", epoch, bestEpochLoss, thisEpochLoss));
// output:
Epoch: 910 | bestEpochLoss: 0.034404863820424 | thisEpochLoss: 0.034437939546120
Epoch: 920 | bestEpochLoss: 0.033875954196897 | thisEpochLoss: 0.431451026477016
Epoch: 930 | bestEpochLoss: 0.032509260025490 | thisEpochLoss: 0.032509260025490
Epoch: 940 | bestEpochLoss: 0.003092720117159 | thisEpochLoss: 0.003098025397281
Epoch: 950 | bestEpochLoss: 0.002990128276146 | thisEpochLoss: 0.431062364628853
Epoch: 960 | bestEpochLoss: 0.001651762688346 | thisEpochLoss: 0.001651762688346
Epoch: 970 | bestEpochLoss: 0.001637709485751 | thisEpochLoss: 0.001636810460399
Epoch: 980 | bestEpochLoss: 0.001083365453009 | thisEpochLoss: 0.391527869500699
Epoch: 990 | bestEpochLoss: 0.001078338540452 | thisEpochLoss: 0.001078338540452
Wie in Listing 10 zu sehen, nimmt der "Loss" (also die Fehlerabweichung von "100 Prozent korrekt") langsam ab. Das Modell nähert sich also immer mehr einer genauen Vorhersage an. Nun gilt es zu überprüfen, wie gut unser Modell mit echten Daten funktioniert.
Listing 11: Vorhersagen
System.out.println("");
System.out.println(String.format(" male, 167, 73: %.10f", network.predict(167, 73)));
System.out.println(String.format("female, 105, 67: %.10", network.predict(105, 67)));
System.out.println(String.format("female, 120, 72: %.10f | network1000: %.10f", network.predict(120, 72)));
System.out.println(String.format(" male, 143, 67: %.10f | network1000: %.10f", network.predict(143, 67)));
System.out.println(String.format(" male', 130, 66: %.10f | network: %.10f", network.predict(130, 66)));
Wie in Listing 11 zu sehen, füttern wir unser trainiertes neuronales Netz mit Daten und geben die Vorhersagen aus. Das Resultat sieht in etwa wie folgt aus.
Listing 12: Trainierte Predictions
male, 167, 73: 0.0279697143
female, 105, 67: 0.9075809407
female, 120, 72: 0.9075808235
male, 143, 67: 0.0305401413
male, 130, 66: network: 0.9009811922
Listing 12 zeigt, dass das Netzwerk bei den meisten Wertepaaren (Vektoren) ziemlich gute Arbeit geleistet hat. Es gibt den weiblichen Datensätzen eine Schätzung um 0.907 - was ziemlich nahe an 1 ist. Zwei männliche Datensätze weisen 0.027 und 0.030 auf - und nähern sich damit der 0. Der männliche Ausreißer-Datensatz (130, 67) wird als "wahrscheinlich weiblich" angesehen, bei einem Wert von 0.900 allerdings mit geringerer Zuversicht.
Es gibt eine Reihe von Möglichkeiten, die Einstellungen an diesem System zu verändern: Die Anzahl der Epochs in einem Trainingslauf ist dabei ein wichtiger Faktor. Je mehr Epochs, desto besser wird das Modell auf die Daten abgestimmt. Das kann auch die Genauigkeit von Live-Daten verbessern, die mit den Trainingssätzen übereinstimmen. Allerdings kann es auch in einem "Overtraining" resultieren - also einem Modell, das zuversichtlich die falschen Ergebnisse für Randfälle vorhersagt.
Den vollständigen Code für dieses Tutorial finden Sie in diesem GitHub-Repository - zusammen mit einigen zusätzlichen Funktionen. (fm)
Dieser Beitrag basiert auf einem Artikel unserer US-Schwesterpublikation Infoworld.