A Google JAX vagy Just After Execution egy keretrendszer, amelyet a Google fejlesztett ki a gépi tanulási feladatok felgyorsítására.
Tekinthetjük Python könyvtárának, amely segít a gyorsabb feladatvégrehajtásban, a tudományos számításokban, a függvényátalakításokban, a mély tanulásban, a neurális hálózatokban és még sok másban.
Tartalomjegyzék
A Google JAX-ről
A Python legalapvetőbb számítási csomagja a NumPy csomag, amely tartalmazza az összes funkciót, például az aggregációt, a vektorműveleteket, a lineáris algebrát, az n-dimenziós tömb- és mátrixmanipulációkat, valamint sok más fejlett funkciót.
Mi lenne, ha tovább gyorsítanánk a NumPy használatával végzett számításokat – különösen hatalmas adatkészletek esetén?
Van valami, ami egyformán jól működhet különböző típusú processzorokon, például GPU-n vagy TPU-n, kódmódosítás nélkül?
Mi a helyzet, ha a rendszer automatikusan és hatékonyabban tudna komponálható függvénytranszformációkat végrehajtani?
A Google JAX egy könyvtár (vagy keretrendszer, ahogy a Wikipédia mondja), amely pontosan ezt teszi, és talán még sokkal többet. A teljesítmény optimalizálására és a gépi tanulási (ML) és a mély tanulási feladatok hatékony végrehajtására készült. A Google JAX a következő átalakítási funkciókat kínálja, amelyek egyedivé teszik más ML-könyvtárak közül, és segítik a fejlett tudományos számításokat a mély tanuláshoz és a neurális hálózatokhoz:
- Automatikus megkülönböztetés
- Automatikus vektorizálás
- Automatikus párhuzamosítás
- Just-in-time (JIT) összeállítás
A Google JAX egyedi funkciói
Az összes transzformáció XLA-t (Accelerated Linear Algebra) használ a nagyobb teljesítmény és memóriaoptimalizálás érdekében. Az XLA egy tartományspecifikus optimalizáló fordítómotor, amely lineáris algebrát hajt végre, és felgyorsítja a TensorFlow modelleket. Az XLA használata a Python-kódon nem igényel jelentős kódmódosítást!
Vizsgáljuk meg részletesen ezeket a funkciókat.
A Google JAX szolgáltatásai
A Google JAX fontos összeállítható átalakítási funkciókat tartalmaz a teljesítmény javítása és a mélytanulási feladatok hatékonyabb végrehajtása érdekében. Például az automatikus differenciálás egy függvény gradiensének lekéréséhez és tetszőleges sorrendű származékok kereséséhez. Hasonlóképpen, az automatikus párhuzamosítás és a JIT több feladat párhuzamos végrehajtásához. Ezek az átalakítások kulcsfontosságúak az olyan alkalmazásokban, mint a robotika, a játék, és még a kutatás is.
Az összeállítható transzformációs függvény egy tiszta függvény, amely egy adathalmazt egy másik formába alakít át. Összeállíthatónak nevezzük őket, mivel önállóak (azaz ezeknek a függvényeknek nincs függősége a program többi részétől), és állapot nélküliek (azaz ugyanaz a bemenet mindig ugyanazt a kimenetet eredményezi).
Y(x) = T: (f(x))
A fenti egyenletben f(x) az eredeti függvény, amelyre a transzformációt alkalmazzuk. Y(x) az eredő függvény a transzformáció alkalmazása után.
Például, ha van egy ‘total_bill_amt’ nevű függvénye, és az eredményt függvénytranszformációként szeretné elérni, egyszerűen használhatja a kívánt transzformációt, mondjuk gradienst (grad):
grad_total_bill = grad(teljes_számla_összeg)
A numerikus függvények grad()-hoz hasonló függvényekkel történő transzformációjával könnyen megkaphatjuk a magasabb rendű származékaikat, amelyeket széles körben használhatunk mélytanulási optimalizálási algoritmusokban, mint például a gradiens süllyedés, ezáltal gyorsabbá és hatékonyabbá téve az algoritmusokat. Hasonlóképpen, a jit() használatával Python programokat fordíthatunk le éppen időben (lustán).
#1. Automatikus megkülönböztetés
A Python az autograd függvényt használja a NumPy és a natív Python kód automatikus megkülönböztetésére. A JAX az autograd (azaz grad) módosított változatát használja, és az XLA-t (Accelerated Linear Algebra) kombinálja, hogy automatikus differenciálást hajtson végre, és bármilyen sorrendű származékokat keressen a GPU (Graphic Processing Units) és a TPU (Tensor Processing Units) számára.]
Gyors megjegyzés a TPU-ról, GPU-ról és CPU-ról: a CPU vagy a központi feldolgozó egység kezeli a számítógép összes műveletét. A GPU egy további processzor, amely növeli a számítási teljesítményt és csúcskategóriás műveleteket futtat. A TPU egy nagy teljesítményű egység, amelyet kifejezetten az összetett és nehéz munkaterhelésekhez fejlesztettek ki, mint például az AI és a mélytanulási algoritmusok.
Ugyanúgy, mint az autograd függvény, amely hurkok, rekurziók, elágazások és így tovább tud különbséget tenni, a JAX a grad() függvényt használja a fordított módú színátmenetekhez (backpropagation). Ezenkívül a függvényt bármely sorrendtől megkülönböztethetjük a grad használatával:
grad(grad(grad(sin θ))) (1.0)
Magasabb rendű automatikus megkülönböztetés
Mint korábban említettük, a grad nagyon hasznos egy függvény parciális deriváltjainak megtalálásában. Egy részleges derivált segítségével kiszámíthatjuk egy költségfüggvény gradiens süllyedését a neurális hálózat paramétereihez képest mély tanulásban, hogy minimalizáljuk a veszteségeket.
Részleges derivált számítása
Tegyük fel, hogy egy függvénynek több változója van, x, y és z. Az egyik változó deriváltjának megtalálását a többi változó állandó tartása mellett parciális deriváltnak nevezzük. Tegyük fel, hogy van egy függvényünk,
f(x,y,z) = x + 2y + z2
Példa a részleges derivált bemutatására
Az x parciális deriváltja ∂f/∂x lesz, ami megmondja, hogyan változik egy függvény egy változóra, ha a többi állandó. Ha ezt manuálisan hajtjuk végre, akkor írnunk kell egy programot a megkülönböztetésre, alkalmaznunk kell minden változóra, majd ki kell számítanunk a gradiens süllyedését. Ez összetett és időigényes üggyé válna több változó esetében.
Az automatikus differenciálás a függvényt elemi műveletek halmazára bontja, például +, -, *, / vagy sin, cos, tan, exp stb., majd a láncszabályt alkalmazza a derivált kiszámításához. Ezt előre és hátramenetben is megtehetjük.
Ez nem az! Mindezek a számítások olyan gyorsan megtörténnek (jó, gondoljon a fentiekhez hasonló millió számításra, és az időbe telhet!). Az XLA gondoskodik a sebességről és a teljesítményről.
#2. Gyorsított lineáris algebra
Vegyük az előző egyenletet. XLA nélkül a számítás három (vagy több) kernelt vesz igénybe, ahol minden kernel kisebb feladatot fog végrehajtani. Például,
Kernel k1 –> x * 2y (szorzás)
k2 –> x * 2y + z (összeadás)
k3 –> Redukció
Ha ugyanazt a feladatot az XLA hajtja végre, egyetlen kernel gondoskodik az összes közbenső műveletről, összeolvasztva azokat. Az elemi műveletek közbenső eredményei a memóriában való tárolás helyett streamingre kerülnek, így memóriát takarítanak meg és növelik a sebességet.
#3. Just-in-time összeállítás
A JAX belsőleg az XLA fordítót használja a végrehajtás sebességének növelésére. Az XLA növelheti a CPU, a GPU és a TPU sebességét. Mindez a JIT kódvégrehajtással lehetséges. Ennek használatához használhatjuk a jit-et importáláson keresztül:
from jax import jit def my_function(x): …………some lines of code my_function_jit = jit(my_function)
Egy másik módszer a jit díszítése a függvénydefiníció fölé:
@jit def my_function(x): …………some lines of code
Ez a kód sokkal gyorsabb, mert az átalakítás a kód lefordított verzióját adja vissza a hívónak, nem pedig a Python értelmezőt. Ez különösen hasznos vektor bemeneteknél, például tömböknél és mátrixoknál.
Ugyanez igaz az összes létező python függvényre is. Például a NumPy csomag függvényei. Ebben az esetben a jax.numpy fájlt jnp-ként kell importálnunk a NumPy helyett:
import jax import jax.numpy as jnp x = jnp.array([[1,2,3,4], [5,6,7,8]])
Ha ezt megtette, a DeviceArray nevű alapvető JAX tömbobjektum lecseréli a szabványos NumPy tömböt. A DeviceArray lusta – az értékeket a gyorsítóban tartják, amíg szükség van rá. Ez azt is jelenti, hogy a JAX program nem várja meg, hogy az eredmények visszatérjenek a hívó (Python) programhoz, így aszinkron küldést követnek.
#4. Automatikus vektorizálás (vmap)
Egy tipikus gépi tanulási világban egymillió vagy több adatpontot tartalmazó adatkészleteink vannak. Valószínűleg ezeken az adatpontokon vagy a legtöbben elvégeznénk néhány számítást vagy manipulációt – ami nagyon idő- és memóriaigényes feladat! Például, ha meg szeretné találni az adathalmaz minden egyes adatpontjának négyzetét, akkor először egy hurok létrehozása és a négyzet egyesével történő felvétele jut eszébe – argh!
Ha ezeket a pontokat vektorként hozzuk létre, akkor az összes négyzetet egy menetben meg tudjuk csinálni, ha vektor- vagy mátrix-manipulációkat hajtunk végre az adatpontokon kedvenc NumPy-nkkal. És ha a programod ezt automatikusan megtehetné – kérhetsz még valamit? Pontosan ezt csinálja a JAX! Automatikusan vektorizálhatja az összes adatpontot, így könnyedén végrehajthat rajtuk bármilyen műveletet – így az algoritmusok sokkal gyorsabbak és hatékonyabbak.
A JAX a vmap függvényt használja az automatikus vektorizáláshoz. Tekintsük a következő tömböt:
x = jnp.array([1,2,3,4,5,6,7,8,9,10]) y = jnp.square(x)
A fentiek végrehajtásával a négyzetes metódus a tömb minden pontjára végrehajtódik. De ha a következőket teszi:
vmap(jnp.square(x))
A metódus négyzet csak egyszer fog lefutni, mert az adatpontok a függvény végrehajtása előtt a vmap metódussal automatikusan vektorizálásra kerülnek, a hurok pedig le van tolva a művelet elemi szintjére – így skaláris szorzás helyett mátrixszorzást eredményez, így jobb teljesítményt nyújt. .
#5. SPMD programozás (pmap)
Az SPMD – vagy a Single Program Multiple Data programozás elengedhetetlen a mély tanulási környezetben – gyakran ugyanazokat a funkciókat kell alkalmazni a több GPU-n vagy TPU-n található különböző adatkészleteken. A JAX pumpa nevű funkcióval rendelkezik, amely lehetővé teszi a párhuzamos programozást több GPU-n vagy bármilyen gyorsítón. A JIT-hez hasonlóan a pmap-ot használó programokat az XLA fordítja le, és egyidejűleg hajtja végre a rendszerben. Ez az automatikus párhuzamosítás mind az előre, mind a visszirányú számításoknál működik.
Hogyan működik a pmap
Több transzformációt is alkalmazhatunk egy menetben, tetszőleges sorrendben bármely függvényen:
pmap(vmap(jit(grad(f(x))))
Többféle összeállítható átalakítás
A Google JAX korlátai
A Google JAX fejlesztői jól átgondolták a mélytanulási algoritmusok felgyorsítását, miközben bevezették ezeket a fantasztikus átalakításokat. A tudományos számítási függvények és csomagok a NumPy vonalain vannak, így nem kell aggódnia a tanulási görbe miatt. A JAX-nek azonban a következő korlátai vannak:
- A Google JAX még a fejlesztés korai szakaszában jár, és bár fő célja a teljesítményoptimalizálás, nem nyújt sok előnyt a CPU számítástechnika számára. Úgy tűnik, hogy a NumPy jobban teljesít, és a JAX használata csak növeli a költségeket.
- A JAX még mindig a kutatási szakaszban vagy korai szakaszában van, és további finomhangolásra van szüksége, hogy elérje az olyan keretrendszerek infrastrukturális szabványait, mint a TensorFlow, amelyek már megalapozottabbak és több előre meghatározott modellel, nyílt forráskódú projekttel és tananyaggal rendelkeznek.
- Jelenleg a JAX nem támogatja a Windows operációs rendszert – a működéséhez virtuális gépre van szükség.
- A JAX csak tiszta funkciókon működik – azokon, amelyeknek nincs mellékhatása. A mellékhatásokkal járó funkciókhoz a JAX nem biztos, hogy jó választás.
A JAX telepítése Python környezetben
Ha van python beállítása a rendszeren, és a JAX-et szeretné futtatni a helyi gépen (CPU), használja a következő parancsokat:
pip install --upgrade pip pip install --upgrade "jax[cpu]"
Ha a Google JAX-et GPU-n vagy TPU-n szeretné futtatni, kövesse az alábbi utasításokat GitHub JAX oldalon. A Python beállításához keresse fel a python hivatalos letöltések oldalon.
Következtetés
A Google JAX kiválóan alkalmas hatékony mélytanulási algoritmusok, robotika és kutatások írásához. A korlátok ellenére széles körben használják más keretrendszerekkel, mint például a Haiku, a Len és még sok más. Értékelni fogja, mit csinál a JAX, amikor programokat futtat, és láthatja az időbeli különbségeket a kód végrehajtása során JAX-szal és anélkül. Kezdheti azzal, hogy elolvassa a hivatalos Google JAX dokumentációami elég átfogó.