JAX:有望取代Tensorflow,谷歌出品的又一超高性能機器學習框架

前言

在機器學習框架方面,JAX是一個新生事物——儘管Tensorflow的競爭對手從技術上講已經在2018年後已經很完備,但直到最近JAX才開始在更廣泛的機器學習研究社區中獲得吸引力。

JAX到底是什麼?根據JAX官方介紹:

JAX是NumPy在CPU、GPU和TPU上的版本,具有高性能機器學習研究的強大自動微分(automatic differentiation)能力。


接下來,我們會具體認識JAX。


JAX:有望取代Tensorflow,谷歌出品的又一超高性能機器學習框架


基礎介紹

就像上面說的,JAX是加速器支持的numpy以及大部分scipy功能,帶有一些通用機器學習操作的便利函數。

我們舉個例子

<code>import jaximport jax.numpy as npdef gpu_backed_hidden_layer(x):    return jax.nn.relu(np.dot(W, x) + b)/<code>

您可以得到numpy精心設計的API,它從2006年就開始使用了,具有Tensorflow和PyTorch等現代ML工具的性能特徵。

JAX還包括通過jax.scipy來支持相當大一部分scipy項目:

<code>from jax.scipy.linalg import svdsingular_vectors, singular_values = svd(x)/<code>

儘管有加速器支持的numpy + scipy版本已經非常有用,但JAX還有一些其他的妙招。首先讓我們看看JAX對自動微分的廣泛支持。

自動微分·Autograd

Autograd是一個用於在numpy和原生python代碼上高效計算梯度的庫。Autograd恰好也是JAX的前身。儘管最初的autograd存儲庫不再被積極開發,但是在autograd上工作的大部分核心團隊已經開始全職從事JAX項目。

就像autograd, JAX允許對一個python函數的輸出求導,只需調用grad:

<code>from jax import graddef hidden_layer(x):    return jax.nn.relu(np.dot(W, x) + b)grad_hidden_layer = grad(hidden_layer)/<code>

您還可以通過本機的python控制結構進行區分——而不需要使用tf.cond:

<code>def absolute_value(x)    if x >= 0:        return x    else:        return -xgrad_absolute_value = grad(absolute_value)/<code>

JAX還支持獲取高階導數——grad函數可以任意連接:

<code>from jax.nn import tanh# grads all the way downprint(grad(grad(grad(tanh)))(1.0))/<code>

默認情況下,grad 為您提供了逆向模式梯度——這是計算梯度最常用的模式,它依賴於緩存激活來提高向後傳遞的效率。反模式差分是計算參數更新最有效的方法。但是,特別是在實現依賴於高階派生的優化方法時,它並不總是最佳選擇。JAX通過jacfwd和jacrev為逆向模式自動差分和正向模式自動差分提供了一流的支持:

<code>from jax import jacfwd, jacrevhessian_fn = jacfwd(jacrev(fn))/<code>

除了grad、jacfwd和jacrev之外,JAX還提供了一些實用程序,用於計算函數的線性逼近、定義自定義梯度操作,以及作為其自動微分支持的一部分。


加速線性代數·XLA

XLA (Accelerated Linear Algebra)是一個特定域的線性代數代碼編譯器,它是JAX將python和numpy表達式轉換成加速器支持的操作的基礎。

除了允許JAX將python + numpy代碼轉換為可以在加速器上運行的操作之外(就像我們在第一個示例中看到的那樣),XLA支持還允許JAX將多個操作融合到一個內核中。它在計算圖中尋找節點簇,這些節點簇可以被重寫以減少計算或中間變量的存儲。Tensorflow關於XLA的文檔使用以下示例來解釋問題可以從XLA編譯中受益的實例類型。

<code>def unoptimized_fn(x, y, z):  return np.sum(x + y * z)/<code>

在沒有XLA的情況下運行,這將作為3個獨立的內核運行——一個乘法、一個加法和一個加法減法。使用XLA運行時,這變成了一個負責所有這三個方面的內核,不需要存儲中間變量,從而節省了時間和內存。


向量化和並行性

雖然Autograd和XLA構成了JAX庫的核心,但是還有兩個JAX函數脫穎而出。你可以使用jax.vmap和jax.pmap用於向量化和基於spmd(單程序多數據)並行的pmap。

為了說明vmap的優點,我們將返回到我們的簡單稠密層的示例,它操作一個由向量x表示的示例。

<code># convention to distinguish between # jax.numpy and numpyimport numpy as onpdef hidden_layer(x):    return jax.nn.relu(np.dot(W, x + b)   print(hidden_layer(np.random.randn(128)).shape)# (128,)/<code>

我們已經編寫了隱含層來獲取單個向量輸入,但實際上我們幾乎總是批量處理輸入以利用向量化計算。使用JAX,您可以使用任何接受單個輸入的函數,並允許它使用JAX .vmap接受一批輸入:

<code>batch_hidden_layer = vmap(hidden_layer)print(batch_hidden_layer(onp.random.randn(32, 128)).shape)# (32, 128)/<code>

它的美妙之處在於,它意味著你或多或少地忽略了模型函數中的批處理維數,並且在你構造模型的時候,在你的頭腦中少了一個張量維數。

如果您有幾個輸入都應該向量化,或者您想沿著軸向量化而不是沿著軸0,您可以使用in_axes參數來指定。


<code>batch_hidden_layer = vmap(hidden_layer, in_axes=(0,))/<code>

JAX用於SPMD paralellism的實用程序,遵循非常類似的API。如果你有一臺4-gpu機器和4個例子,你可以使用pmap在每個設備上運行一個例子。

<code># first dimension must align with number of XLA-enabled devicesspmd_hidden_layer = pmap(hidden_layer)/<code>


和往常一樣,你可以隨心所欲地編寫函數:

<code># hypothetical setup for high-throughput inferenceoutputs = pmap(vmap(hidden_layer))(onp.random.randn(4, 32, 128))print(outputs.shape)# (4, 32, 128)/<code>


為什麼是JAX?

JAX不是因為它都比現有的機器學習框架更加乾淨,或者因為它是比Tensorflow PyTorch更好地設計的東西,而是因為它能讓我們更容易嘗試更多的想法以及探索更廣泛的空間。


分享到:


相關文章: