In this section, we’ll take a closer look at how JAX works. We’ll also talk a little about the jax.jit() transformation, which compiles a Python function just-in-time (JIT) so that it runs efficiently with XLA.
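
Here is a minimal sketch of jax.jit. The selu function and its constants are made-up illustrative values, not from the original text. The jitted and non-jitted versions compute the same values. The jitted one is traced and compiled with XLA on its first call.

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    # Illustrative elementwise function; the constants are arbitrary example values.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# jax.jit returns a new function that is traced and compiled with XLA on its first call.
selu_jit = jax.jit(selu)

x = jnp.arange(-2.0, 3.0)
print(selu(x))      # executed op by op
print(selu_jit(x))  # executed as one compiled XLA computation
```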

In the previous section, we learned that JAX can transform Python functions to produce new functions. It does this by first converting a Python function into a simple intermediate language called jaxpr. The transformations then operate on this jaxpr representation.
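
jax.grad is one such transformation. It takes a Python function and returns a new Python function that computes the derivative. Here is a minimal sketch using a made-up square function:

```python
import jax

def square(x):
    return x ** 2

# jax.grad transforms `square` into a new Python function computing d/dx x**2 = 2x.
dsquare = jax.grad(square)
print(dsquare(3.0))  # 6.0

# The transformed function can itself be traced; make_jaxpr shows its jaxpr.
print(jax.make_jaxpr(dsquare)(3.0))
```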

Next, we use jax.make_jaxpr to display the jaxpr representation of a Python function.

Conceptually, the first thing a JAX transformation does is trace-specialize the Python function into a small, well-behaved intermediate form. JAX then interprets that form with transformation-specific rules. One reason JAX can pack so much functionality into such a small package is that it starts from a familiar, flexible programming interface (Python with NumPy). It then uses the actual Python interpreter to do most of the heavy lifting, distilling the essence of the computation into a simple, statically typed expression language with limited higher-order features. That language is the jaxpr language.
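
To make "trace-specialized" and "statically typed" more concrete, here is a small sketch using a made-up function f. Tracing the same Python function with inputs of different shapes produces different jaxprs, and each jaxpr input variable carries a static shape and dtype. The attribute path `.jaxpr.invars[i].aval` belongs to JAX's internal jaxpr data structure and may change between versions.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0

# Tracing specializes on the abstract value (shape + dtype) of the input,
# so the same Python function yields differently typed jaxprs.
scalar_jaxpr = jax.make_jaxpr(f)(1.0)
vector_jaxpr = jax.make_jaxpr(f)(jnp.ones(3))

print(scalar_jaxpr)
print(vector_jaxpr)

# Each input variable carries a static type (its abstract value).
print(scalar_jaxpr.jaxpr.invars[0].aval.shape)  # ()
print(vector_jaxpr.jaxpr.invars[0].aval.shape)  # (3,)
```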

import jax
import jax.numpy as jnp

global_list = []

def log2(x):
  global_list.append(x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2)(3.0))
{ lambda  ; a.
  let b = log a
      c = log 2.0
      d = div b c
  in (d,) }

The “Understanding Jaxprs” section of the documentation provides more information about what this output means.
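
You can also walk a jaxpr programmatically instead of just printing it. The sketch below repeats the log2 function without its side effect and lists the primitive applied by each equation. It relies on the `ClosedJaxpr.jaxpr.eqns` and `eqn.primitive.name` attributes, which are JAX internals rather than a stable public API.

```python
import jax
import jax.numpy as jnp

def log2(x):
    ln_x = jnp.log(x)
    ln_2 = jnp.log(2.0)
    return ln_x / ln_2

closed = jax.make_jaxpr(log2)(3.0)

# closed.jaxpr holds the input variables, a list of equations, and the outputs.
for eqn in closed.jaxpr.eqns:
    # Each equation applies one primitive (such as log or div) to variables or literals.
    print(eqn.primitive.name, eqn.invars, "->", eqn.outvars)
```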

It’s important to note that the jaxpr does not capture the function’s side effects: nothing in it corresponds to global_list.append(x). This is a feature, not a bug. JAX is designed to understand side-effect-free code (that is, pure functions).
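
The side effect still runs once, while JAX traces the function. Here is a sketch with jax.jit, assuming the current behavior of caching a trace per input shape and dtype. That caching is an implementation detail, not a guarantee.

```python
import jax
import jax.numpy as jnp

global_list = []

def log2(x):
    global_list.append(x)  # Python side effect: runs while tracing, not recorded in the jaxpr
    return jnp.log(x) / jnp.log(2.0)

log2_jit = jax.jit(log2)
log2_jit(3.0)  # first call traces the function, so the append runs once
log2_jit(4.0)  # same input shape/dtype: the cached trace is reused, no append
print(len(global_list))  # 1 (and the stored item is a tracer, not the number 3.0)
```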

JAX’s internal representation is purely functional, but because Python is so dynamic, this places some restrictions on user code. For example, JAX transformations such as automatic differentiation only support pure functions, and it is up to the user to ensure this. If user code has side effects, the function produced by a JAX transformation may not behave as expected. And because JAX assumes that a traced function is pure, it does not automatically pick up changes to global variables or configuration. The function may need to be retraced.
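
The following sketch uses a made-up global `scale` to show the stale-global problem. The jitted function reads `scale` only at trace time, so later changes to it are ignored as long as the cached trace is reused.

```python
import jax

scale = 2.0

def scaled(x):
    return x * scale  # reads a global, so the function is not pure

scaled_jit = jax.jit(scaled)
print(scaled_jit(1.0))  # 2.0

scale = 10.0
# Same input shape/dtype, so the cached trace (with scale=2.0 baked in) is reused.
print(scaled_jit(1.0))  # still 2.0
```

Passing `scale` as an explicit argument makes the function pure, and the new value is then used on every call.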

During tracing, JAX wraps each argument in a tracer object. These tracers record every JAX operation performed on them during the function call, which runs as ordinary Python. JAX then uses the tracers’ records to reconstruct the entire function, and the output of that reconstruction is the jaxpr. Because the tracers do not record Python side effects, those side effects do not appear in the jaxpr. They do, however, still happen during the trace itself.
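
Because the function body runs as ordinary Python during tracing, Python control flow over non-traced values is executed at trace time, and only the resulting JAX operations are recorded. In this sketch (the function name is made up), a Python loop is unrolled into repeated `sin` equations:

```python
import jax
import jax.numpy as jnp

def apply_sin_n_times(x, n):
    # `n` is a plain Python int, so this loop runs in Python during tracing;
    # the tracer records one `sin` per iteration, producing an unrolled jaxpr.
    for _ in range(n):
        x = jnp.sin(x)
    return x

closed = jax.make_jaxpr(lambda x: apply_sin_n_times(x, 3))(1.0)
print(closed)
print([eqn.primitive.name for eqn in closed.jaxpr.eqns])
```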

def log2_with_print(x):
  print("printed x:", x)
  ln_x = jnp.log(x)
  ln_2 = jnp.log(2.0)
  return ln_x / ln_2

print(jax.make_jaxpr(log2_with_print)(3.))

Note: Python’s print() function is not pure either, because writing text to the output is an I/O side effect. Therefore, print() calls run only during tracing and do not appear in the jaxpr.

printed x: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
{ lambda  ; a.
  let b = log a
      c = log 2.0
      d = div b c
  in (d,) }

See how the printed x is a Traced object? That’s JAX’s internals at work. The fact that the Python code runs at least once is strictly an implementation detail and should not be relied on. It is still useful to understand, though, because you can use it while debugging to print intermediate values of a computation. At trace time these values are tracers that carry shape and dtype information rather than concrete numbers.
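
Here is a sketch of that debugging use under jax.jit, with a made-up function. A plain print() inside the function shows the intermediate's shape and dtype, but only when tracing happens. With the current caching behavior, a second call with the same input shape and dtype prints nothing.

```python
import contextlib
import io
import jax
import jax.numpy as jnp

def sum_of_gram(x):
    y = jnp.dot(x, x.T)
    # At trace time y is a tracer: its shape and dtype are known, its values are not.
    print("y:", y.shape, y.dtype)
    return y.sum()

f = jax.jit(sum_of_gram)

buf = io.StringIO()
with contextlib.redirect_stdout(buf):
    f(jnp.ones((2, 3)))   # first call: traced, so print runs
    f(jnp.zeros((2, 3)))  # same shape/dtype: cached, print does not run
log = buf.getvalue()
print(log)
```

To see concrete values on every call, JAX provides jax.debug.print. It is staged into the computation instead of running only at trace time.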