Ebből az oktatóanyagból megtudhatja, hogyan használhatja a NumPy argmax() függvényt a tömbök maximális elemének indexének megkeresésére.
A NumPy egy hatékony könyvtár a tudományos számításokhoz Pythonban; N-dimenziós tömböket biztosít, amelyek teljesítményesebbek, mint a Python listák. A NumPy tömbökkel végzett munka során az egyik gyakori művelet a maximális érték megtalálása a tömbben. Néha azonban érdemes megkeresni azt az indexet, amelynél a maximális érték előfordul.
Az argmax() függvény segít megtalálni a maximum indexét mind az egydimenziós, mind a többdimenziós tömbökben. Folytassuk, hogy megtanuljuk, hogyan működik.
Tartalomjegyzék
Hogyan lehet megtalálni a maximális elem indexét egy NumPy tömbben
Az oktatóanyag követéséhez telepítenie kell a Pythont és a NumPy-t. Egy Python REPL elindításával vagy egy Jupyter notebook elindításával kódolhat.
Először is importáljuk a NumPy-t a szokásos álnéven np.
import numpy as np
Használhatja a NumPy max() függvényt, hogy megkapja a maximális értéket egy tömbben (opcionálisan egy adott tengely mentén).
array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.max(array_1)) # Output 10
Ebben az esetben az np.max(tömb_1) 10-et ad vissza, ami helyes.
Tegyük fel, hogy meg szeretné találni azt az indexet, amelynél a tömbben a maximális érték előfordul. A következő kétlépéses megközelítést alkalmazhatja:
A tömb_1-ben a 10-es maximális érték a 4-es indexnél fordul elő, nulla indexelést követően. Az első elem a 0 indexnél van; a második elem az 1-es indexnél van, és így tovább.
Ha meg szeretné keresni azt az indexet, amelynél a maximum előfordul, használja a NumPy where() függvényt. Az np.where(condition) az összes olyan index tömbjét adja vissza, ahol a feltétel igaz.
Be kell érintenie a tömböt, és hozzá kell férnie az első indexnél található elemhez. A maximális érték meghatározásához a feltételt tömb_1==10 értékre állítjuk be; ne feledje, hogy a 10 a tömb_1 maximális értéke.
print(int(np.where(array_1==10)[0])) # Output 4
Az np.where() függvényt csak a feltétellel használtuk, de nem ez a javasolt módszer a függvény használatához.
📑 Megjegyzés: NumPy where() függvény:
np.where(feltétel,x,y) a következőt adja vissza:
– x elemei, ha a feltétel igaz, és
– y elemei, ha a feltétel False.
Ezért az np.max() és np.where() függvények láncolásával megtalálhatjuk a maximális elemet, majd azt az indexet, amelynél előfordul.
A fenti kétlépéses folyamat helyett a NumPy argmax() függvényt használhatjuk a tömb maximális elemének indexének lekéréséhez.
A NumPy argmax() függvény szintaxisa
A NumPy argmax() függvény használatának általános szintaxisa a következő:
np.argmax(array,axis,out) # we've imported numpy under the alias np
A fenti szintaxisban:
- A tömb bármely érvényes NumPy tömb.
- tengely egy opcionális paraméter. Ha többdimenziós tömbökkel dolgozik, az tengely paraméterrel megkeresheti a maximum indexét egy adott tengely mentén.
- Az out egy másik opcionális paraméter. Az out paramétert beállíthatja egy NumPy tömbbe az argmax() függvény kimenetének tárolására.
Megjegyzés: A NumPy 1.22.0-s verziójától kezdve van egy további keepdims paraméter. Ha az argmax() függvényhívásban megadjuk az tengely paraméterét, a tömb a tengely mentén csökken. De ha a keepdims paramétert True értékre állítja, akkor a visszaadott kimenet ugyanolyan alakú lesz, mint a bemeneti tömb.
A NumPy argmax() használata a maximális elem indexének megkereséséhez
#1. Használjuk a NumPy argmax() függvényt, hogy megkeressük a tömb_1 maximális elemének indexét.
array_1 = np.array([1,5,7,2,10,9,8,4]) print(np.argmax(array_1)) # Output 4
Az argmax() függvény 4-et ad vissza, ami helyes! ✅
#2. Ha újradefiniáljuk a tömb_1-et úgy, hogy a 10 kétszer forduljon elő, az argmax() függvény csak az első előfordulás indexét adja vissza.
array_1 = np.array([1,5,7,2,10,10,8,4]) print(np.argmax(array_1)) # Output 4
A többi példa esetében a tömb_1 elemeit használjuk, amelyeket az 1. példában definiáltunk.
A NumPy argmax() használata a maximális elem indexének megkeresésére egy 2D tömbben
Alakítsuk át a NumPy tömb_1 tömbjét egy kétdimenziós tömbbé, amely két sorból és négy oszlopból áll.
array_2 = array_1.reshape(2,4) print(array_2) # Output [[ 1 5 7 2] [10 9 8 4]]
Kétdimenziós tömb esetén a 0-s tengely a sorokat, az 1-es tengely pedig az oszlopokat jelöli. A NumPy tömbök nulla indexelést követnek. Tehát a NumPy tömb tömb_2 sorainak és oszlopainak indexei a következők:
Most hívjuk meg az argmax() függvényt a kétdimenziós tömbön, tömb_2.
print(np.argmax(array_2)) # Output 4
Annak ellenére, hogy a kétdimenziós tömbön meghívtuk az argmax() függvényt, akkor is 4-et ad vissza. Ez megegyezik az előző szakaszban szereplő egydimenziós tömb, tömb_1 kimenetével.
Miért történik ez? 🤔
Ennek az az oka, hogy nem adtunk meg értéket a tengelyparaméterhez. Ha ez a tengelyparaméter nincs beállítva, akkor alapértelmezés szerint az argmax() függvény a maximális elem indexét adja vissza a lapított tömb mentén.
Mi az a lapított tömb? Ha létezik egy d1 x d2 x … x dN alakú N-dimenziós tömb, ahol d1, d2, legfeljebb dN a tömb méretei az N dimenzió mentén, akkor a lapított tömb egy hosszú, egydimenziós méretű tömb. d1 * d2 * … * dN.
Ha ellenőrizni szeretné, hogyan néz ki a lapított tömb a tömb_2 esetén, hívja meg a flatten() metódust, az alábbiak szerint:
array_2.flatten() # Output array([ 1, 5, 7, 2, 10, 9, 8, 4])
Maximális elem indexe a sorok mentén (tengely = 0)
Keressük meg a maximális elem indexét a sorok mentén (tengely = 0).
np.argmax(array_2,axis=0) # Output array([1, 1, 1, 1])
Ezt a kimenetet kissé nehéz lehet megérteni, de meg fogjuk érteni, hogyan működik.
Az tengely paramétert nullára állítottuk (tengely = 0), mivel szeretnénk megtalálni a sorok mentén a maximális elem indexét. Ezért az argmax() függvény azt a sorszámot adja vissza, amelyben a maximális elem előfordul – mindhárom oszlop esetében.
Vizualizáljuk ezt a jobb megértés érdekében.
A fenti diagramból és az argmax() kimenetből a következőket kapjuk:
- A 0 indexnél lévő első oszlop esetében a maximális 10 érték a második sorban található, az index = 1-nél.
- Az 1-es index második oszlopában a 9-es maximális érték a második sorban található, az index = 1-nél.
- A 2. és 3. index harmadik és negyedik oszlopában a 8. és 4. maximális érték a második sorban található, az index = 1 értéknél.
Pontosan ezért van a kimeneti tömbünk ([1, 1, 1, 1]), mert a maximális elem a sorok mentén a második sorban található (minden oszlop esetében).
A maximális elem indexe az oszlopok mentén (tengely = 1)
Ezután az argmax() függvény segítségével keressük meg a maximális elem indexét az oszlopok mentén.
Futtassa a következő kódrészletet, és figyelje meg a kimenetet.
np.argmax(array_2,axis=1)
array([2, 0])
Tudod elemezni a kimenetet?
Az oszlopok mentén a maximális elem indexének kiszámításához tengely = 1-et állítottunk be.
Az argmax() függvény minden sorhoz azt az oszlopszámot adja vissza, amelyben a maximális érték előfordul.
Íme egy vizuális magyarázat:
A fenti diagramból és az argmax() kimenetből a következőket kapjuk:
- Az első sorban a 0 indexnél a maximális 7-es érték a harmadik oszlopban található, az index = 2-nél.
- A második sorban az 1. indexnél a maximális 10-es érték az első oszlopban található, az index = 0-nál.
Remélem most már érted, mi a kimenet, array([2, 0]) azt jelenti.
Az opcionális out paraméter használata a NumPy argmax()-ban
Használhatja az opcionális out paramétert a NumPy argmax() függvényben, hogy a kimenetet egy NumPy tömbben tárolja.
Inicializáljunk egy nullákból álló tömböt az előző argmax() függvényhívás kimenetének tárolásához – hogy megtaláljuk a maximum indexét az oszlopok mentén (tengely= 1).
out_arr = np.zeros((2,)) print(out_arr) [0. 0.]
Most nézzük meg újra azt a példát, hogy megtaláljuk a maximális elem indexét az oszlopok mentén (tengely = 1), és állítsuk be az out-t out_arr értékre, amit fent definiáltunk.
np.argmax(array_2,axis=1,out=out_arr)
Látjuk, hogy a Python értelmező TypeError-t dob, mivel az out_arr alapértelmezés szerint egy float tömbbé lett inicializálva.
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) /usr/local/lib/python3.7/dist-packages/numpy/core/fromnumeric.py in _wrapfunc(obj, method, *args, **kwds) 56 try: ---> 57 return bound(*args, **kwds) 58 except TypeError: TypeError: Cannot cast array data from dtype('float64') to dtype('int64') according to the rule 'safe'
Ezért, amikor a kimeneti tömbbe állítja be az out paramétert, fontos, hogy a kimeneti tömb alakja és adattípusa megfelelő legyen. Mivel a tömb indexei mindig egész számok, a kimeneti tömb meghatározásakor a dtype paramétert int-re kell állítani.
out_arr = np.zeros((2,),dtype=int) print(out_arr) # Output [0 0]
Most már hívhatjuk az argmax() függvényt az axis és az out paraméterekkel is, és ezúttal hiba nélkül fut.
np.argmax(array_2,axis=1,out=out_arr)
Az argmax() függvény kimenete most már elérhető az out_arr tömbben.
print(out_arr) # Output [2 0]
Következtetés
Remélem, ez az oktatóanyag segített megérteni a NumPy argmax() függvény használatát. A kódpéldákat futtathatja egy Jupyter notebookban.
Tekintsük át a tanultakat.
- A NumPy argmax() függvény a tömb maximális elemének indexét adja vissza. Ha a maximális elem többször előfordul egy a tömbben, akkor az np.argmax(a) az elem első előfordulásának indexét adja vissza.
- Ha többdimenziós tömbökkel dolgozik, az opcionális tengelyparaméterrel lekérheti a maximális elem indexét egy adott tengely mentén. Például egy kétdimenziós tömbben: az axis = 0 és az axis = 1 beállításával megkaphatja a maximális elem indexét a sorok, illetve az oszlopok mentén.
- Ha a visszaadott értéket egy másik tömbben szeretné tárolni, beállíthatja az opcionális out paramétert a kimeneti tömbre. A kimeneti tömbnek azonban kompatibilis alakúnak és adattípusúnak kell lennie.
Ezután tekintse meg a Python-készletekről szóló részletes útmutatót.