Was ist Google JAX? Alles, was Sie wissen müssen

Google JAX, auch bekannt als Just After Execution, ist ein von Google entwickeltes Framework, das darauf abzielt, Aufgaben im Bereich des maschinellen Lernens zu beschleunigen.

Man kann es sich als eine Python-Bibliothek vorstellen, die dazu dient, die Ausführung von Prozessen wie wissenschaftlichen Berechnungen, Funktionstransformationen, Deep Learning, neuronalen Netzen und vielem mehr zu beschleunigen.

Google JAX im Detail

Das grundlegendste Paket für numerische Berechnungen in Python ist NumPy, das Funktionen wie Aggregationen, Vektoroperationen, lineare Algebra, die Manipulation von n-dimensionalen Arrays und Matrizen sowie viele weitere fortgeschrittene Funktionen bereitstellt.

Was wäre, wenn wir die Berechnungen, die wir mit NumPy durchführen, noch schneller gestalten könnten – insbesondere bei sehr großen Datensätzen?

Gibt es eine Möglichkeit, dass unsere Programme auf verschiedenen Prozessortypen wie GPU oder TPU ohne Änderungen am Code gleich gut funktionieren?

Wäre es nicht ideal, wenn ein System in der Lage wäre, zusammensetzbare Funktionstransformationen automatisch und effizienter durchzuführen?

Google JAX ist eine Bibliothek (oder ein Framework, wie Wikipedia es nennt), die genau das tut und möglicherweise noch viel mehr. Es wurde entwickelt, um die Leistung zu optimieren und Aufgaben im Bereich des maschinellen Lernens (ML) und des Deep Learnings effizient auszuführen. Google JAX bietet die folgenden Transformationsfunktionen, die es von anderen ML-Bibliotheken abheben und bei fortgeschrittenen wissenschaftlichen Berechnungen für Deep Learning und neuronale Netze helfen:

  • Automatische Differenzierung
  • Automatische Vektorisierung
  • Automatische Parallelisierung
  • Just-in-time (JIT)-Kompilierung

Die besonderen Funktionen von Google JAX

Alle Transformationen nutzen XLA (Accelerated Linear Algebra) für eine höhere Performance und Speicheroptimierung. XLA ist eine domänenspezifische, optimierende Compiler-Engine, die lineare Algebra ausführt und TensorFlow-Modelle beschleunigt. Die Nutzung von XLA mit Ihrem Python-Code erfordert keine grundlegenden Änderungen am Code!

Lassen Sie uns jede dieser Funktionen im Detail betrachten.

Die Funktionen von Google JAX

Google JAX zeichnet sich durch wichtige, zusammensetzbare Transformationsfunktionen aus, die die Leistung steigern und Deep-Learning-Aufgaben effizienter machen. Dazu gehört zum Beispiel die automatische Differenzierung, um den Gradienten einer Funktion zu bestimmen und Ableitungen beliebiger Ordnung zu ermitteln. Auch Autoparallelisierung und JIT ermöglichen die gleichzeitige Ausführung mehrerer Aufgaben. Diese Transformationen sind essenziell für Anwendungsbereiche wie Robotik, Spiele und Forschung.

Eine zusammensetzbare Transformationsfunktion ist eine reine Funktion, die einen Datensatz in eine andere Form umwandelt. Sie werden als zusammensetzbar bezeichnet, weil sie in sich abgeschlossen sind (d.h., diese Funktionen haben keine Abhängigkeiten zum Rest des Programms) und zustandslos sind (d.h., die gleiche Eingabe führt immer zur gleichen Ausgabe).

Y(x) = T: (f(x))

In der obigen Gleichung ist f(x) die Originalfunktion, auf die eine Transformation angewendet wird. Y(x) ist die Funktion, die nach der Transformation resultiert.

Wenn Sie zum Beispiel eine Funktion mit dem Namen „total_bill_amt“ haben und das Ergebnis als Funktionstransformation benötigen, können Sie einfach die gewünschte Transformation nutzen, z.B. den Gradienten (grad):

grad_total_bill = grad(total_bill_amt)

Durch die Transformation numerischer Funktionen mit Funktionen wie grad() können wir leicht ihre Ableitungen höherer Ordnung erhalten, die wir umfassend in Deep-Learning-Optimierungsalgorithmen wie dem Gradientenabstieg verwenden können, wodurch die Algorithmen schneller und effizienter werden. Ebenso können wir mit jit() Python-Programme zur Laufzeit (lazy) kompilieren.

#1. Automatische Differenzierung

Python nutzt die Autograd-Funktion, um NumPy und nativen Python-Code automatisch zu differenzieren. JAX verwendet eine modifizierte Version von Autograd (d.h. grad) und kombiniert diese mit XLA (Accelerated Linear Algebra), um eine automatische Differenzierung zu realisieren und Ableitungen beliebiger Ordnung für GPU (Graphic Processing Units) und TPU (Tensor Processing Units) zu finden.

Eine kurze Erläuterung zu TPU, GPU und CPU: Die CPU oder Central Processing Unit ist für die Verwaltung aller Prozesse auf dem Computer zuständig. Die GPU ist ein zusätzlicher Prozessor, der die Rechenleistung erhöht und anspruchsvolle Operationen ausführt. Die TPU ist eine leistungsstarke Einheit, die speziell für komplexe und rechenintensive Aufgaben wie KI und Deep-Learning-Algorithmen entwickelt wurde.

Ähnlich der Autograd-Funktion, die Schleifen, Rekursionen, Verzweigungen usw. differenzieren kann, verwendet JAX die Funktion grad(), um Gradienten im umgekehrten Modus (Backpropagation) zu bestimmen. Darüber hinaus können wir mit grad eine Funktion beliebiger Ordnung differenzieren:

grad(grad(grad(sin θ))) (1.0)

Automatische Differenzierung höherer Ordnung

Wie bereits erwähnt, ist grad sehr hilfreich, um die partiellen Ableitungen einer Funktion zu finden. Eine partielle Ableitung können wir verwenden, um den Gradientenabstieg einer Kostenfunktion in Bezug auf die Parameter eines neuronalen Netzwerks im Deep Learning zu berechnen und so Verluste zu minimieren.

Berechnung partieller Ableitungen

Nehmen wir an, eine Funktion hat mehrere Variablen, x, y und z. Die Bestimmung der Ableitung einer Variablen, während die anderen Variablen konstant gehalten werden, wird als partielle Ableitung bezeichnet. Gehen wir von folgender Funktion aus:

f(x,y,z) = x + 2y + z²

Beispiel zur Darstellung der partiellen Ableitung

Die partielle Ableitung von x ist ∂f/∂x, die uns angibt, wie sich eine Funktion für eine Variable verändert, wenn die anderen konstant sind. Wenn wir dies manuell durchführen, müssen wir ein Programm zur Differenzierung schreiben, es auf jede Variable anwenden und dann den Gradientenabstieg berechnen. Dies wäre bei mehreren Variablen eine komplexe und zeitaufwendige Angelegenheit.

Die automatische Differenzierung zerlegt die Funktion in eine Reihe elementarer Operationen wie +, -, *, /, sin, cos, tan, exp usw. und wendet dann die Kettenregel an, um die Ableitung zu berechnen. Dies können wir sowohl im Vorwärts- als auch im Rückwärtsmodus tun.

Das ist noch nicht alles! All diese Berechnungen laufen sehr schnell ab (denken Sie an Millionen von Berechnungen wie die oben genannte und die Zeit, die das in Anspruch nehmen könnte!). XLA kümmert sich um Geschwindigkeit und Leistung.

#2. Beschleunigte Lineare Algebra

Betrachten wir die vorherige Gleichung. Ohne XLA wären für die Berechnung drei (oder mehr) Kernel notwendig, wobei jeder Kernel eine kleinere Aufgabe übernimmt. Zum Beispiel:

Kernel k1 –> x * 2y (Multiplikation)

k2 –> x * 2y + z (Addition)

k3 –> Reduktion

Wenn die gleiche Aufgabe von XLA ausgeführt wird, kümmert sich ein einziger Kernel um alle Zwischenoperationen, indem er sie zusammenfasst. Die Zwischenergebnisse elementarer Operationen werden gestreamt, anstatt sie im Speicher abzulegen, was Speicherplatz spart und die Geschwindigkeit erhöht.

#3. Just-in-Time-Kompilierung

JAX nutzt intern den XLA-Compiler, um die Ausführungsgeschwindigkeit zu erhöhen. XLA kann die Geschwindigkeit von CPU, GPU und TPU steigern. All dies ist durch die JIT-Codeausführung möglich. Um dies zu nutzen, können wir jit per Import verwenden:

from jax import jit
def my_function(x):
	…………some lines of code
my_function_jit = jit(my_function)

Eine weitere Möglichkeit ist, jit über die Funktionsdefinition zu dekorieren:

@jit
def my_function(x):
	…………some lines of code

Dieser Code ist wesentlich schneller, da die Transformation die kompilierte Version des Codes an den Aufrufer zurückgibt, anstatt den Python-Interpreter zu nutzen. Dies ist besonders nützlich bei Vektoreingaben wie Arrays und Matrizen.

Dasselbe gilt auch für alle bestehenden Python-Funktionen. Zum Beispiel für Funktionen aus dem NumPy-Paket. In diesem Fall sollten wir jax.numpy als jnp anstelle von NumPy importieren:

import jax
import jax.numpy as jnp

x = jnp.array([[1,2,3,4], [5,6,7,8]])

Sobald dies geschehen ist, ersetzt das Kern-JAX-Array-Objekt namens DeviceArray das Standard-NumPy-Array. DeviceArray ist „lazy“ – die Werte werden im Beschleuniger gespeichert, bis sie benötigt werden. Dies bedeutet auch, dass das JAX-Programm nicht darauf wartet, dass Ergebnisse an das aufrufende Python-Programm zurückgegeben werden, und somit einem asynchronen Versand folgt.

#4. Automatische Vektorisierung (vmap)

In der typischen Welt des maschinellen Lernens haben wir Datensätze mit einer Million oder mehr Datenpunkten. Sehr wahrscheinlich würden wir einige Berechnungen oder Manipulationen an jedem oder den meisten dieser Datenpunkte durchführen – was eine sehr zeit- und speicherintensive Aufgabe ist! Wenn Sie beispielsweise das Quadrat jedes Datenpunktes im Datensatz berechnen wollen, wäre Ihr erster Gedanke, eine Schleife zu erstellen und das Quadrat nacheinander zu berechnen – argh!

Wenn wir diese Punkte als Vektoren darstellen, könnten wir alle Quadrate auf einmal erstellen, indem wir Vektor- oder Matrixmanipulationen an den Datenpunkten mit unserem beliebten NumPy durchführen. Und wenn Ihr Programm das automatisch tun könnte – was will man mehr? Genau das macht JAX! Es kann alle Ihre Datenpunkte automatisch vektorisieren, so dass Sie alle Operationen problemlos an ihnen durchführen können – wodurch Ihre Algorithmen wesentlich schneller und effizienter werden.

JAX verwendet die vmap-Funktion zur automatischen Vektorisierung. Betrachten wir das folgende Array:

x = jnp.array([1,2,3,4,5,6,7,8,9,10])
y = jnp.square(x)

Wenn Sie genau das Obige tun, wird die Quadratmethode für jeden Punkt im Array ausgeführt. Wenn Sie jedoch Folgendes tun:

vmap(jnp.square(x))

wird die square-Methode nur einmal ausgeführt, da die Datenpunkte nun automatisch mit der vmap-Methode vektorisiert werden, bevor die Funktion ausgeführt wird, und die Schleifenbildung wird auf die elementare Ebene der Operation reduziert – was zu einer Matrixmultiplikation anstelle einer skalaren Multiplikation und damit zu einer besseren Leistung führt.

#5. SPMD-Programmierung (pmap)

SPMD – oder Single Program Multiple Data-Programmierung ist in Deep-Learning-Kontexten essentiell – oft wendet man dieselben Funktionen auf unterschiedliche Datensätze an, die sich auf mehreren GPUs oder TPUs befinden. JAX verfügt über eine Funktion namens pmap, die eine parallele Programmierung auf mehreren GPUs oder einem beliebigen Beschleuniger ermöglicht. Ähnlich wie bei JIT werden Programme, die pmap nutzen, von XLA kompiliert und gleichzeitig auf allen Systemen ausgeführt. Diese automatische Parallelisierung funktioniert sowohl für Vorwärts- als auch für Rückwärtsberechnungen.

Wie funktioniert pmap

Wir können auch mehrere Transformationen gleichzeitig und in beliebiger Reihenfolge auf eine Funktion anwenden, z.B.:

pmap(vmap(jit(grad (f(x)))))

Mehrere zusammensetzbare Transformationen

Einschränkungen von Google JAX

Die Entwickler von Google JAX haben sorgfältig darüber nachgedacht, Deep-Learning-Algorithmen zu beschleunigen und gleichzeitig all diese großartigen Transformationen einzuführen. Die wissenschaftlichen Berechnungsfunktionen und Pakete entsprechen NumPy, so dass Sie sich keine Gedanken über die Lernkurve machen müssen. Dennoch hat JAX folgende Einschränkungen:

  • Google JAX befindet sich noch in einer frühen Entwicklungsphase und obwohl sein Hauptziel die Leistungsoptimierung ist, bietet es für CPU-Berechnungen keinen wesentlichen Vorteil. NumPy scheint eine bessere Leistung zu erzielen und der Einsatz von JAX könnte den Overhead sogar erhöhen.
  • JAX befindet sich noch in der Forschungs- bzw. frühen Entwicklungsphase und muss noch feiner abgestimmt werden, um die Infrastrukturstandards von Frameworks wie TensorFlow zu erreichen, die etablierter sind und über mehr vordefinierte Modelle, Open-Source-Projekte und Lernmaterial verfügen.
  • Aktuell unterstützt JAX kein Windows-Betriebssystem – Sie benötigen eine virtuelle Maschine, damit es funktioniert.
  • JAX funktioniert nur mit reinen Funktionen – also solchen, die keine Seiteneffekte haben. Für Funktionen mit Seiteneffekten ist JAX möglicherweise keine gute Wahl.

So installieren Sie JAX in Ihrer Python-Umgebung

Wenn Sie Python auf Ihrem System eingerichtet haben und JAX auf Ihrem lokalen Computer (CPU) ausführen möchten, verwenden Sie die folgenden Befehle:

pip install --upgrade pip
pip install --upgrade "jax[cpu]"

Wenn Sie Google JAX auf einer GPU oder TPU ausführen möchten, befolgen Sie die Anweisungen auf der GitHub-JAX-Seite. Um Python einzurichten, besuchen Sie die offizielle Python-Downloads-Seite.

Fazit

Google JAX eignet sich hervorragend, um effiziente Deep-Learning-Algorithmen, Anwendungen in der Robotik und Forschung zu entwickeln. Trotz der Einschränkungen wird es häufig in Verbindung mit anderen Frameworks wie Haiku, Flax und vielen mehr eingesetzt. Sie werden in der Lage sein einzuschätzen, was JAX bei der Ausführung von Programmen leistet, und die Zeitunterschiede bei der Ausführung von Code mit und ohne JAX zu erkennen. Sie können mit der Lektüre der offiziellen Google JAX-Dokumentation beginnen, die sehr umfangreich ist.