So verwenden Sie die NumPy argmax () -Funktion in Python

Die NumPy argmax() Funktion: Den Index des maximalen Elements finden

Dieses Tutorial führt Sie durch die Anwendung der NumPy-Funktion argmax(), um den Index des größten Elements in Arrays zu lokalisieren. NumPy ist eine mächtige Python-Bibliothek für wissenschaftliche Berechnungen, die unter anderem mehrdimensionale Arrays bereitstellt, die flexibler als einfache Python-Listen sind. Häufig müssen Sie in einem NumPy-Array den Maximalwert bestimmen. Manchmal ist jedoch nicht nur der Wert selbst, sondern auch seine Position, also der Index, von Interesse.

Hier kommt die Funktion argmax() ins Spiel: Sie hilft Ihnen, genau diesen Index zu ermitteln, sowohl in eindimensionalen als auch in mehrdimensionalen Arrays. Im Folgenden werden wir uns genauer ansehen, wie dies funktioniert.

Vorbereitung

Um dieses Tutorial nachzuvollziehen, sollten Sie Python und die NumPy-Bibliothek installiert haben. Sie können die Beispiele entweder in einer Python-REPL oder in einem Jupyter Notebook ausprobieren. Importieren wir zunächst NumPy, wie es üblich ist, unter dem Alias np:

import numpy as np

Die NumPy-Funktion max() liefert Ihnen den größten Wert innerhalb eines Arrays (optional entlang einer bestimmten Achse):

array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.max(array_1))

# Ausgabe
10

In diesem Beispiel gibt np.max(array_1) korrekterweise 10 zurück.

Nehmen wir an, Sie möchten den Index finden, an dem der Maximalwert auftritt. Eine Möglichkeit ist folgender Ansatz in zwei Schritten:

  • Zuerst den Maximalwert bestimmen.
  • Anschließend den Index dieses maximalen Elements ermitteln.

Im Array array_1 befindet sich der Maximalwert von 10 an Position 4, wenn wir mit 0 zu zählen beginnen. Das erste Element hat den Index 0, das zweite den Index 1 und so weiter.

Um den Index zu finden, können Sie NumPys where()-Funktion verwenden. np.where(condition) gibt ein Array mit allen Indizes zurück, für die die Bedingung zutrifft. Um die Position des Maximums zu finden, setzen wir die Bedingung auf array_1==10, da 10 der Maximalwert im Array ist.

print(int(np.where(array_1==10)[0]))

# Ausgabe
4

Wir haben hier np.where() nur mit der Bedingung benutzt, was nicht die gängigste Art ist, die Funktion einzusetzen.

Hinweis: Die volle Syntax von NumPys where() lautet np.where(condition,x,y) und gibt Folgendes zurück:

  • Elemente aus x, wenn die Bedingung wahr ist, und
  • Elemente aus y, wenn die Bedingung falsch ist.

Durch die Verkettung von np.max() und np.where() könnte man zuerst den Maximalwert finden und anschließend den dazugehörigen Index. Jedoch gibt es eine elegantere Lösung.

Statt des oben gezeigten zweistufigen Prozesses können Sie die NumPy-Funktion argmax() verwenden, die direkt den Index des größten Elements zurückgibt.

Syntax der NumPy argmax() Funktion

Die allgemeine Syntax für die Verwendung der Funktion np.argmax() sieht wie folgt aus:

np.argmax(array,axis,out)
# numpy wurde als np importiert

Hierbei ist:

  • array: ein gültiges NumPy-Array.
  • axis: ein optionaler Parameter. Bei mehrdimensionalen Arrays können Sie den axis-Parameter verwenden, um den Index des Maximums entlang einer bestimmten Achse zu finden.
  • out: ein optionaler Parameter. Er erlaubt Ihnen, die Ausgabe der Funktion argmax() in einem NumPy-Array zu speichern.

Beachten Sie, dass ab NumPy-Version 1.22.0 ein zusätzlicher Parameter keepdims existiert. Wenn wir den Parameter axis in argmax() verwenden, wird das Array entlang dieser Achse reduziert. Wenn Sie keepdims jedoch auf True setzen, wird die Ausgabe die gleiche Form wie das Eingabe-Array beibehalten.

Anwendung von NumPy argmax()

#1. Verwenden wir argmax(), um den Index des größten Elements in array_1 zu finden:

array_1 = np.array([1,5,7,2,10,9,8,4])
print(np.argmax(array_1))

# Ausgabe
4

Die Funktion argmax() gibt 4 zurück, was korrekt ist! ✅

#2. Wenn wir array_1 so verändern, dass 10 zweimal vorkommt, gibt argmax() dennoch nur den Index des ersten Auftretens zurück:

array_1 = np.array([1,5,7,2,10,10,8,4])
print(np.argmax(array_1))

# Ausgabe
4

Für die folgenden Beispiele verwenden wir wieder die ursprüngliche Version von array_1 aus Beispiel 1.

argmax() mit 2D-Arrays

Formen wir array_1 in ein zweidimensionales Array array_2 mit zwei Zeilen und vier Spalten um:

array_2 = array_1.reshape(2,4)
print(array_2)

# Ausgabe
[[ 1  5  7  2]
 [10  9  8  4]]

In einem zweidimensionalen Array bezeichnet die Achse 0 die Zeilen und die Achse 1 die Spalten. NumPy-Arrays beginnen bei 0 mit der Indizierung. Die Indizes für Zeilen und Spalten von array_2 sind wie folgt:

Rufen wir nun argmax() für array_2 auf:

print(np.argmax(array_2))

# Ausgabe
4

Obwohl wir argmax() auf ein zweidimensionales Array angewendet haben, ist die Ausgabe immer noch 4. Das entspricht der Ausgabe des eindimensionalen array_1 im vorherigen Abschnitt.

Warum ist das so? 🤔

Das liegt daran, dass wir keinen Wert für den Parameter axis angegeben haben. Ohne axis-Parameter gibt argmax() standardmäßig den Index des größten Elements im ‚geglätteten‘ Array zurück.

Was ist ein ‚geglättetes‘ Array? Wenn Sie ein N-dimensionales Array mit den Dimensionen d1 x d2 x … x dN haben, dann ist das geglättete Array ein eindimensionales Array der Länge d1 * d2 * … * dN.

Um zu sehen, wie das geglättete Array für array_2 aussieht, können Sie die Methode flatten() aufrufen:

array_2.flatten()

# Ausgabe
array([ 1,  5,  7,  2, 10,  9,  8,  4])

Maximum entlang der Zeilen (Achse = 0)

Bestimmen wir den Index des Maximums entlang der Zeilen (axis=0):

np.argmax(array_2,axis=0)

# Ausgabe
array([1, 1, 1, 1])

Diese Ausgabe ist etwas schwieriger zu verstehen. Wir betrachten jede Spalte einzeln.

Wir haben axis auf 0 gesetzt, da wir den Index des Maximums entlang der Zeilen suchen. argmax() gibt für jede Spalte die Zeilennummer zurück, in der sich das jeweilige Maximum befindet.

Hier zur besseren Verständlichkeit eine Visualisierung:

Aus dem Diagramm und der Ausgabe von argmax() ergibt sich:

  • In der ersten Spalte (Index 0) befindet sich das Maximum (10) in der zweiten Zeile (Index 1).
  • In der zweiten Spalte (Index 1) befindet sich das Maximum (9) in der zweiten Zeile (Index 1).
  • In der dritten und vierten Spalte (Index 2 und 3) befinden sich die Maxima (8 und 4) in der zweiten Zeile (Index 1).

Daher ist die Ausgabe [1, 1, 1, 1], da das maximale Element in allen Spalten in der zweiten Zeile liegt.

Maximum entlang der Spalten (Achse = 1)

Als Nächstes suchen wir den Index des maximalen Elements entlang der Spalten (axis=1):

np.argmax(array_2,axis=1)
# Ausgabe
array([2, 0])

Können Sie die Ausgabe verstehen?

Wir haben axis auf 1 gesetzt, um den Index des größten Elements entlang der Spalten zu berechnen.

Die Funktion argmax() gibt für jede Zeile die Spaltennummer zurück, die das Maximum enthält.

Hier ist die Visualisierung:

Aus dem Diagramm und der Ausgabe von argmax() ergibt sich:

  • In der ersten Zeile (Index 0) befindet sich das Maximum (7) in der dritten Spalte (Index 2).
  • In der zweiten Zeile (Index 1) befindet sich das Maximum (10) in der ersten Spalte (Index 0).

Ich hoffe, Sie verstehen nun die Bedeutung von array([2, 0]).

Verwendung des optionalen out-Parameters

Sie können den optionalen Parameter out verwenden, um die Ausgabe von argmax() in einem NumPy-Array zu speichern.

Initialisieren wir ein Array aus Nullen, in dem wir die Ausgabe des vorherigen argmax()-Aufrufs speichern, um den Index des Maximums entlang der Spalten (axis=1) zu erhalten:

out_arr = np.zeros((2,))
print(out_arr)
[0. 0.]

Versuchen wir nun, den Index des maximalen Elements entlang der Spalten zu berechnen (axis=1) und speichern die Ausgabe in out_arr:

np.argmax(array_2,axis=1,out=out_arr)

Wir stellen fest, dass der Python-Interpreter einen TypeError ausgibt, da out_arr standardmäßig mit einem Array aus Gleitkommazahlen initialisiert wurde.

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds)
     56     try:
---> 57         return bound(*args, **kwds)
     58     except TypeError:

TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'

Deshalb ist es entscheidend, dass das Ausgabe-Array die richtige Form und den richtigen Datentyp hat. Da Array-Indizes immer ganze Zahlen sind, sollten wir den Parameter dtype auf int setzen, wenn wir das Ausgabe-Array erstellen.

out_arr = np.zeros((2,),dtype=int)
print(out_arr)

# Ausgabe
[0 0]

Jetzt können wir argmax() mit axis und out aufrufen, ohne Fehler zu erhalten:

np.argmax(array_2,axis=1,out=out_arr)

Die Ausgabe von argmax() ist nun im Array out_arr gespeichert:

print(out_arr)
# Ausgabe
[2 0]

Fazit

Ich hoffe, dieses Tutorial hat Ihnen geholfen, die NumPy-Funktion argmax() zu verstehen. Sie können die Codebeispiele in einem Jupyter Notebook testen.

Fassen wir zusammen, was wir gelernt haben:

  • Die NumPy-Funktion argmax() gibt den Index des größten Elements im Array zurück. Wenn das maximale Element mehr als einmal vorkommt, liefert np.argmax(a) den Index des ersten Vorkommens.
  • Bei mehrdimensionalen Arrays können Sie den optionalen Parameter axis verwenden, um den Index des Maximums entlang einer bestimmten Achse zu erhalten. In einem 2D-Array bekommen Sie z.B. mit axis=0 den Index des größten Elements entlang der Zeilen und mit axis=1 entlang der Spalten.
  • Wenn Sie den Rückgabewert in einem anderen Array speichern wollen, können Sie den optionalen Parameter out verwenden. Das Ausgabearray sollte jedoch die passende Form und den richtigen Datentyp haben.

Als Nächstes könnten Sie sich das ausführliche Tutorial zu Python Sets ansehen.