Mi az a Google JAX? Minden, amit tudnod kell

A Google JAX vagy Just After Execution egy keretrendszer, amelyet a Google fejlesztett ki a gépi tanulási feladatok felgyorsítására.

Tekinthetjük Python könyvtárának, amely segít a gyorsabb feladatvégrehajtásban, a tudományos számításokban, a függvényátalakításokban, a mély tanulásban, a neurális hálózatokban és még sok másban.

A Google JAX-ről

A Python legalapvetőbb számítási csomagja a NumPy csomag, amely tartalmazza az összes funkciót, például az aggregációt, a vektorműveleteket, a lineáris algebrát, az n-dimenziós tömb- és mátrixmanipulációkat, valamint sok más fejlett funkciót.

Mi lenne, ha tovább gyorsítanánk a NumPy használatával végzett számításokat – különösen hatalmas adatkészletek esetén?

Van valami, ami egyformán jól működhet különböző típusú processzorokon, például GPU-n vagy TPU-n, kódmódosítás nélkül?

Mi a helyzet, ha a rendszer automatikusan és hatékonyabban tudna komponálható függvénytranszformációkat végrehajtani?

A Google JAX egy könyvtár (vagy keretrendszer, ahogy a Wikipédia mondja), amely pontosan ezt teszi, és talán még sokkal többet. A teljesítmény optimalizálására és a gépi tanulási (ML) és a mély tanulási feladatok hatékony végrehajtására készült. A Google JAX a következő átalakítási funkciókat kínálja, amelyek egyedivé teszik más ML-könyvtárak közül, és segítik a fejlett tudományos számításokat a mély tanuláshoz és a neurális hálózatokhoz:

  • Automatikus megkülönböztetés
  • Automatikus vektorizálás
  • Automatikus párhuzamosítás
  • Just-in-time (JIT) összeállítás

A Google JAX egyedi funkciói

Az összes transzformáció XLA-t (Accelerated Linear Algebra) használ a nagyobb teljesítmény és memóriaoptimalizálás érdekében. Az XLA egy tartományspecifikus optimalizáló fordítómotor, amely lineáris algebrát hajt végre, és felgyorsítja a TensorFlow modelleket. Az XLA használata a Python-kódon nem igényel jelentős kódmódosítást!

Vizsgáljuk meg részletesen ezeket a funkciókat.

A Google JAX szolgáltatásai

A Google JAX fontos összeállítható átalakítási funkciókat tartalmaz a teljesítmény javítása és a mélytanulási feladatok hatékonyabb végrehajtása érdekében. Például az automatikus differenciálás egy függvény gradiensének lekéréséhez és tetszőleges sorrendű származékok kereséséhez. Hasonlóképpen, az automatikus párhuzamosítás és a JIT több feladat párhuzamos végrehajtásához. Ezek az átalakítások kulcsfontosságúak az olyan alkalmazásokban, mint a robotika, a játék, és még a kutatás is.

Az összeállítható transzformációs függvény egy tiszta függvény, amely egy adathalmazt egy másik formába alakít át. Összeállíthatónak nevezzük őket, mivel önállóak (azaz ezeknek a függvényeknek nincs függősége a program többi részétől), és állapot nélküliek (azaz ugyanaz a bemenet mindig ugyanazt a kimenetet eredményezi).

Y(x) = T: (f(x))

A fenti egyenletben f(x) az eredeti függvény, amelyre a transzformációt alkalmazzuk. Y(x) az eredő függvény a transzformáció alkalmazása után.

  DevOps bevezetés kezdőknek

Például, ha van egy ‘total_bill_amt’ nevű függvénye, és az eredményt függvénytranszformációként szeretné elérni, egyszerűen használhatja a kívánt transzformációt, mondjuk gradienst (grad):

grad_total_bill = grad(teljes_számla_összeg)

A numerikus függvények grad()-hoz hasonló függvényekkel történő transzformációjával könnyen megkaphatjuk a magasabb rendű származékaikat, amelyeket széles körben használhatunk mélytanulási optimalizálási algoritmusokban, mint például a gradiens süllyedés, ezáltal gyorsabbá és hatékonyabbá téve az algoritmusokat. Hasonlóképpen, a jit() használatával Python programokat fordíthatunk le éppen időben (lustán).

#1. Automatikus megkülönböztetés

A Python az autograd függvényt használja a NumPy és a natív Python kód automatikus megkülönböztetésére. A JAX az autograd (azaz grad) módosított változatát használja, és az XLA-t (Accelerated Linear Algebra) kombinálja, hogy automatikus differenciálást hajtson végre, és bármilyen sorrendű származékokat keressen a GPU (Graphic Processing Units) és a TPU (Tensor Processing Units) számára.]

Gyors megjegyzés a TPU-ról, GPU-ról és CPU-ról: a CPU vagy a központi feldolgozó egység kezeli a számítógép összes műveletét. A GPU egy további processzor, amely növeli a számítási teljesítményt és csúcskategóriás műveleteket futtat. A TPU egy nagy teljesítményű egység, amelyet kifejezetten az összetett és nehéz munkaterhelésekhez fejlesztettek ki, mint például az AI és a mélytanulási algoritmusok.

Ugyanúgy, mint az autograd függvény, amely hurkok, rekurziók, elágazások és így tovább tud különbséget tenni, a JAX a grad() függvényt használja a fordított módú színátmenetekhez (backpropagation). Ezenkívül a függvényt bármely sorrendtől megkülönböztethetjük a grad használatával:

grad(grad(grad(sin θ))) (1.0)

Magasabb rendű automatikus megkülönböztetés

Mint korábban említettük, a grad nagyon hasznos egy függvény parciális deriváltjainak megtalálásában. Egy részleges derivált segítségével kiszámíthatjuk egy költségfüggvény gradiens süllyedését a neurális hálózat paramétereihez képest mély tanulásban, hogy minimalizáljuk a veszteségeket.

Részleges derivált számítása

Tegyük fel, hogy egy függvénynek több változója van, x, y és z. Az egyik változó deriváltjának megtalálását a többi változó állandó tartása mellett parciális deriváltnak nevezzük. Tegyük fel, hogy van egy függvényünk,

f(x,y,z) = x + 2y + z2

Példa a részleges derivált bemutatására

Az x parciális deriváltja ∂f/∂x lesz, ami megmondja, hogyan változik egy függvény egy változóra, ha a többi állandó. Ha ezt manuálisan hajtjuk végre, akkor írnunk kell egy programot a megkülönböztetésre, alkalmaznunk kell minden változóra, majd ki kell számítanunk a gradiens süllyedését. Ez összetett és időigényes üggyé válna több változó esetében.

Az automatikus differenciálás a függvényt elemi műveletek halmazára bontja, például +, -, *, / vagy sin, cos, tan, exp stb., majd a láncszabályt alkalmazza a derivált kiszámításához. Ezt előre és hátramenetben is megtehetjük.

Ez nem az! Mindezek a számítások olyan gyorsan megtörténnek (jó, gondoljon a fentiekhez hasonló millió számításra, és az időbe telhet!). Az XLA gondoskodik a sebességről és a teljesítményről.

  Hogyan lehet kikapcsolni a képernyőt egy laptopon

#2. Gyorsított lineáris algebra

Vegyük az előző egyenletet. XLA nélkül a számítás három (vagy több) kernelt vesz igénybe, ahol minden kernel kisebb feladatot fog végrehajtani. Például,

Kernel k1 –> x * 2y (szorzás)

k2 –> x * 2y + z (összeadás)

k3 –> Redukció

Ha ugyanazt a feladatot az XLA hajtja végre, egyetlen kernel gondoskodik az összes közbenső műveletről, összeolvasztva azokat. Az elemi műveletek közbenső eredményei a memóriában való tárolás helyett streamingre kerülnek, így memóriát takarítanak meg és növelik a sebességet.

#3. Just-in-time összeállítás

A JAX belsőleg az XLA fordítót használja a végrehajtás sebességének növelésére. Az XLA növelheti a CPU, a GPU és a TPU sebességét. Mindez a JIT kódvégrehajtással lehetséges. Ennek használatához használhatjuk a jit-et importáláson keresztül:

from jax import jit
def my_function(x):
	…………some lines of code
my_function_jit = jit(my_function)

Egy másik módszer a jit díszítése a függvénydefiníció fölé:

@jit
def my_function(x):
	…………some lines of code

Ez a kód sokkal gyorsabb, mert az átalakítás a kód lefordított verzióját adja vissza a hívónak, nem pedig a Python értelmezőt. Ez különösen hasznos vektor bemeneteknél, például tömböknél és mátrixoknál.

Ugyanez igaz az összes létező python függvényre is. Például a NumPy csomag függvényei. Ebben az esetben a jax.numpy fájlt jnp-ként kell importálnunk a NumPy helyett:

import jax
import jax.numpy as jnp

x = jnp.array([[1,2,3,4], [5,6,7,8]])

Ha ezt megtette, a DeviceArray nevű alapvető JAX tömbobjektum lecseréli a szabványos NumPy tömböt. A DeviceArray lusta – az értékeket a gyorsítóban tartják, amíg szükség van rá. Ez azt is jelenti, hogy a JAX program nem várja meg, hogy az eredmények visszatérjenek a hívó (Python) programhoz, így aszinkron küldést követnek.

#4. Automatikus vektorizálás (vmap)

Egy tipikus gépi tanulási világban egymillió vagy több adatpontot tartalmazó adatkészleteink vannak. Valószínűleg ezeken az adatpontokon vagy a legtöbben elvégeznénk néhány számítást vagy manipulációt – ami nagyon idő- és memóriaigényes feladat! Például, ha meg szeretné találni az adathalmaz minden egyes adatpontjának négyzetét, akkor először egy hurok létrehozása és a négyzet egyesével történő felvétele jut eszébe – argh!

Ha ezeket a pontokat vektorként hozzuk létre, akkor az összes négyzetet egy menetben meg tudjuk csinálni, ha vektor- vagy mátrix-manipulációkat hajtunk végre az adatpontokon kedvenc NumPy-nkkal. És ha a programod ezt automatikusan megtehetné – kérhetsz még valamit? Pontosan ezt csinálja a JAX! Automatikusan vektorizálhatja az összes adatpontot, így könnyedén végrehajthat rajtuk bármilyen műveletet – így az algoritmusok sokkal gyorsabbak és hatékonyabbak.

A JAX a vmap függvényt használja az automatikus vektorizáláshoz. Tekintsük a következő tömböt:

x = jnp.array([1,2,3,4,5,6,7,8,9,10])
y = jnp.square(x)

A fentiek végrehajtásával a négyzetes metódus a tömb minden pontjára végrehajtódik. De ha a következőket teszi:

vmap(jnp.square(x))

A metódus négyzet csak egyszer fog lefutni, mert az adatpontok a függvény végrehajtása előtt a vmap metódussal automatikusan vektorizálásra kerülnek, a hurok pedig le van tolva a művelet elemi szintjére – így skaláris szorzás helyett mátrixszorzást eredményez, így jobb teljesítményt nyújt. .

  Mely üzletek fogadják el a Samsung Pay fizetést?

#5. SPMD programozás (pmap)

Az SPMD – vagy a Single Program Multiple Data programozás elengedhetetlen a mély tanulási környezetben – gyakran ugyanazokat a funkciókat kell alkalmazni a több GPU-n vagy TPU-n található különböző adatkészleteken. A JAX pumpa nevű funkcióval rendelkezik, amely lehetővé teszi a párhuzamos programozást több GPU-n vagy bármilyen gyorsítón. A JIT-hez hasonlóan a pmap-ot használó programokat az XLA fordítja le, és egyidejűleg hajtja végre a rendszerben. Ez az automatikus párhuzamosítás mind az előre, mind a visszirányú számításoknál működik.

Hogyan működik a pmap

Több transzformációt is alkalmazhatunk egy menetben, tetszőleges sorrendben bármely függvényen:

pmap(vmap(jit(grad(f(x))))

Többféle összeállítható átalakítás

A Google JAX korlátai

A Google JAX fejlesztői jól átgondolták a mélytanulási algoritmusok felgyorsítását, miközben bevezették ezeket a fantasztikus átalakításokat. A tudományos számítási függvények és csomagok a NumPy vonalain vannak, így nem kell aggódnia a tanulási görbe miatt. A JAX-nek azonban a következő korlátai vannak:

  • A Google JAX még a fejlesztés korai szakaszában jár, és bár fő célja a teljesítményoptimalizálás, nem nyújt sok előnyt a CPU számítástechnika számára. Úgy tűnik, hogy a NumPy jobban teljesít, és a JAX használata csak növeli a költségeket.
  • A JAX még mindig a kutatási szakaszban vagy korai szakaszában van, és további finomhangolásra van szüksége, hogy elérje az olyan keretrendszerek infrastrukturális szabványait, mint a TensorFlow, amelyek már megalapozottabbak és több előre meghatározott modellel, nyílt forráskódú projekttel és tananyaggal rendelkeznek.
  • Jelenleg a JAX nem támogatja a Windows operációs rendszert – a működéséhez virtuális gépre van szükség.
  • A JAX csak tiszta funkciókon működik – azokon, amelyeknek nincs mellékhatása. A mellékhatásokkal járó funkciókhoz a JAX nem biztos, hogy jó választás.

A JAX telepítése Python környezetben

Ha van python beállítása a rendszeren, és a JAX-et szeretné futtatni a helyi gépen (CPU), használja a következő parancsokat:

pip install --upgrade pip
pip install --upgrade "jax[cpu]"

Ha a Google JAX-et GPU-n vagy TPU-n szeretné futtatni, kövesse az alábbi utasításokat GitHub JAX oldalon. A Python beállításához keresse fel a python hivatalos letöltések oldalon.

Következtetés

A Google JAX kiválóan alkalmas hatékony mélytanulási algoritmusok, robotika és kutatások írásához. A korlátok ellenére széles körben használják más keretrendszerekkel, mint például a Haiku, a Len és még sok más. Értékelni fogja, mit csinál a JAX, amikor programokat futtat, és láthatja az időbeli különbségeket a kód végrehajtása során JAX-szal és anélkül. Kezdheti azzal, hogy elolvassa a hivatalos Google JAX dokumentációami elég átfogó.