
Original text: jax.readthedocs.io/en/latest/n…

JAX Quick Start


Let's first answer a question: what is JAX?

Simply put, JAX is NumPy with GPU acceleration and automatic differentiation. As we all know, NumPy is the basic numerical computation library in Python and is widely used; hardly anyone doing scientific computing or machine learning in Python can do without it. But NumPy does not support GPUs or other hardware accelerators, has no built-in support for backpropagation, and is further limited by Python's own speed, so few people would use NumPy directly to train or deploy deep learning models in production. This is why deep learning frameworks such as Theano, TensorFlow, and Caffe exist. Yet NumPy has its own unique advantages: it is low level, flexible, easy to debug, has a stable API, and is familiar (consistent with MATLAB), so it is favored by researchers. JAX's main goal is to combine these advantages of NumPy with hardware acceleration. JAX (github.com/google/jax) is now open source and uses the GPU (CUDA) for hardware acceleration. From: www.zhihu.com/question/30…

 

JAX is a scientific computing library (NumPy, SciPy) and neural network library (relu, sigmoid, conv, etc.) that supports accelerators (GPUs and TPUs) and is more flexible and versatile than PyTorch and TensorFlow. That is why I recommend learning it, and why I did this translation: to take everyone through learning and mastering this framework.

Because the author is not an English major, some content will inevitably be mistranslated; criticism and corrections are welcome. For translations the author is not sure about, underlining and parentheses are used to supplement the original word, such as automatic differentiation.

 

Official definition: JAX is NumPy on the CPU, GPU, and TPU, with excellent automatic differentiation for high-performance machine learning research.

 

As an updated version of Autograd, JAX can automatically differentiate native Python and NumPy code. It can differentiate through most of Python's features (including loops, ifs, recursion, and closures), and it can even take derivatives of derivatives. It supports reverse-mode and forward-mode differentiation, and the two can be composed in any order.
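For instance, here is a minimal, self-contained sketch (the function here is made up purely for illustration, not taken from the original text) of differentiating through ordinary Python control flow and taking a derivative of a derivative:

import jax.numpy as jnp
from jax import grad

def f(x):
  # an ordinary Python loop and branch; JAX differentiates through both
  for _ in range(3):
    x = jnp.where(x > 0, x * x, -x)
  return x

print(grad(f)(2.0))        # first derivative of f at x = 2.0
print(grad(grad(f))(2.0))  # derivative of the derivative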

What's new is that JAX uses XLA to compile and run your NumPy code on accelerators such as GPUs and TPUs. By default, compilation happens under the hood, and library calls are just-in-time compiled and executed. But JAX also lets you just-in-time compile your own Python functions into XLA-optimized kernels using a one-function API. Compilation and automatic differentiation can be composed arbitrarily, so you can express sophisticated algorithms and get maximal performance without leaving Python.

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

Matrix multiplication

 

In the following example, we will generate random data. One big difference between NumPy and JAX is the way random numbers are generated. For more details, see Common Gotchas in JAX.

key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)
[-0.372111    0.2642311  -0.18252774 -0.7368198  -0.44030386 -0.15214427 -0.6713536  -0.59086424  0.73168874  0.56730247]
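As a side note (this snippet is an added illustration, not part of the original quickstart): JAX random functions consume an explicit PRNG key, and instead of relying on hidden global state you split the key whenever you need fresh, independent random numbers.

subkey1, subkey2 = random.split(key)   # derive two independent subkeys; the original key is untouched
print(random.normal(subkey1, (3,)))    # draw using subkey1
print(random.normal(subkey2, (3,)))    # a different, independent draw using subkey2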

Let's multiply two big matrices.

size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU
489 ms ± 3.98 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

We add block_until_ready() because JAX uses asynchronous execution by default (see Asynchronous dispatch).
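To see why this matters (an added illustration using the same x as above, not from the original text), timing without block_until_ready() mostly measures how long it takes to dispatch the work rather than the computation itself:

%timeit jnp.dot(x, x.T)                      # returns as soon as the work has been dispatched
%timeit jnp.dot(x, x.T).block_until_ready()  # waits until the result has actually been computed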

JAX NumPy functions can be used on regular NumPy arrays.

import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()
488 ms ± 942 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

This is slower because it has to transfer the data to the GPU every time. You can use device_put() to ensure that an ndarray is backed by device memory.

from jax import device_put

x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()
487 ms ± 9.94 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

The output of device_put() still acts like an ndarray, but it only copies values back to the CPU when they are needed for printing, plotting, saving to disk, branching, etc. The behavior of device_put() is equivalent to the function jit(lambda x: x), but faster.

If you have a GPU (or TPU!), these calls will run on the accelerator and may be much faster than on the CPU. For comparison, here is the same product with plain NumPy on the CPU:

x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x, x.T)
235 ms ± 546 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there are three main ones:

  • jit(), for speeding up your code
  • grad(), for taking gradients (derivatives)
  • vmap(), for automatic vectorization or batching.

Let's go through them one by one. At the end, we'll also compose them in interesting ways.

 

Using jit() to speed up functions

JAX runs transparently on the GPU (or on the CPU if you don't have one; TPU support is coming!). However, in the example above, JAX dispatches kernels to the GPU one operation at a time. If we have a sequence of operations, we can use the @jit decorator to compile multiple operations together with XLA. Let's try it.

def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()
4.4 ms ± 107 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

We can speed it up with @jit, which will jit-compile the first time selu is called and be cached thereafter.

selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()
860 µs ± 27.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Computing gradients with grad()

In addition to evaluating numerical functions, we also want to transform them. One transformation is automatic differentiation. In JAX, just as in Autograd, you can use the grad() function to calculate gradients.

def sum_logistic(x):
  return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
[0.25       0.19661197 0.10499357]

Let’s use finite differences to verify that our result is correct.

def first_finite_differences(f, x):
  eps = 1e-3
  return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
                   for v in jnp.eye(len(x))])


print(first_finite_differences(sum_logistic, x_small))
[0.24998187 0.1964569  0.10502338]

Taking gradients is as easy as calling grad(). grad() and jit() compose and can be mixed arbitrarily. In the example above we jitted sum_logistic and then took its derivative. Let's go further:

print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
0.035325594

For more advanced autodiff, you can use jax.vjp() for reverse-mode vector-Jacobian products and jax.jvp() for forward-mode Jacobian-vector products. The two can be composed arbitrarily with each other and with other JAX transformations. Here is one way to compose them to make a function that efficiently computes full Hessian matrices:

from jax import jacfwd, jacrev
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
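For example (a usage sketch reusing sum_logistic and x_small from above, not part of the original text), applying it to our earlier function yields a 3x3 matrix of second derivatives:

print(hessian(sum_logistic)(x_small))  # shape (3, 3) for a 3-element input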

Automatic vectorization with vmap()

JAX has one more transformation in its API that you may find useful: vmap(), the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into the function's primitive operations for better performance. When composed with jit(), it can be just as fast as adding the batch dimensions by hand.

We'll use a simple example and promote matrix-vector products to matrix-matrix products using vmap(). Although it is easy to do this by hand in this particular case, the same technique can be applied to more complicated functions.

mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))

def apply_matrix(v):
  return jnp.dot(mat, v)

Given a function such as apply_matrix, we could loop over a batch dimension in Python, but the performance of doing so is usually poor.

def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
Naively batched
4.43 ms ± 9.91 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

We know how to batch this operation manually. In this case, jnp.dot handles extra batch dimensions transparently.

@jit
def batched_apply_matrix(v_batched):
  return jnp.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
Manually batched
51.9 µs ± 1.72 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

However, suppose we had a more complicated function without built-in batching support. We can use vmap() to add batching support automatically.

@jit
def vmap_batched_apply_matrix(v_batched):
  return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
Auto-vectorized with vmap
79.7 µs ± 249 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Of course, vmap() can be composed arbitrarily with jit(), grad(), and any other JAX transformation.
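As one illustration of that composability (an added sketch reusing sum_logistic and batched_x from above), vmap(grad(...)) computes per-example gradients of a scalar-valued function in a single vectorized call:

per_example_grads = jit(vmap(grad(sum_logistic)))
print(per_example_grads(batched_x).shape)  # one gradient per row: (10, 100)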

This is just a taste of what JAX can do. We're excited to see what you do with it!