Tutorial: Writing JAX-like code in PyTorch with functorch
The majority of deep learning frameworks are organized around a clear object-oriented (OO) philosophy, in which the parameters and logic of each component are neatly encapsulated inside an object (e.g., the nn.Module of PyTorch or the keras.Model of TensorFlow). A small number of frameworks, most notably JAX, have instead embraced a pure functional programming (FP) perspective. The FP approach has two major advantages, which are not immediately apparent. On the one hand, parameters live outside the models, making it very easy to manipulate them in between function calls.^{1} On the other hand, it is simple to compose functional transformations of our model to obtain very sophisticated behaviours.^{2}
JAX in particular has steadily gained popularity in many fields requiring more flexibility than the standard OO frameworks, leading to an explosion of projects combining the two worlds and bringing some form of PyTorch-like functionality to JAX (e.g., Flax, Haiku, Objax, Equinox). In this blog post, we go in the opposite direction and take a sneak peek at functorch, a prototype library under development that gives users the possibility of writing JAX-like code in standard PyTorch.
🔗 The code for this tutorial is available in a Colab notebook: https://colab.research.google.com/drive/1Oi6q6w9QdaE4Bx_rAzcspGJVclqbnW7?usp=sharing (functorch being under development means that most likely the code will break very soon, or is already broken).
Outline of the tutorial
To highlight the differences between PyTorch and JAX, it is helpful to look at specific examples. For this tutorial, we look at a toy multiview scenario: given an image model, we want to develop a new model that works on a set of images, corresponding to different views of the same input. While this can be done in many ways, the vmap (vectorized map) function from JAX is particularly elegant.
To this end, we show three implementations: JAX (easy), PyTorch (harder), and finally functorch (much easier). In moving from PyTorch to JAX, we need to transform the objects representing our neural networks into (pure) functions. As an added bonus, we show a simplified version of how functorch itself tackles the issue.
⚡ In this tutorial, we see how to use JAX-like programming in PyTorch. As a companion (writing PyTorch-like code in JAX), I strongly suggest From PyTorch to JAX: towards neural net frameworks that purify stateful code by Sabrina Mielke.
Warming up
Let us start with a tiny convolutional neural network (CNN):
import torch
from torch import nn

tinycnn = nn.Sequential(
    nn.Conv2d(3, 5, kernel_size=3),
    nn.Conv2d(5, 5, kernel_size=3)
)
Nothing out of the ordinary. For completeness, let us run the CNN on a random 16x16 RGB image:
x = torch.randn((1, 3, 16, 16))
y = tinycnn(x)
We say that models in PyTorch are stateful, because the result of the last line depends on the internal parameters of the module, which in turn depend on the context of the execution (e.g., calls to optimizer.step() can modify their values as a side effect).
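To make the statefulness concrete, here is a minimal sketch (using a one-weight linear layer as a stand-in for the CNN): mutating the parameters in-place, as an optimizer step would, changes the output for the very same input.

```python
import torch
from torch import nn

lin = nn.Linear(1, 1, bias=False)   # a one-weight stand-in for the CNN
x = torch.ones(1, 1)

with torch.no_grad():
    lin.weight.fill_(0.5)  # fix the parameter for reproducibility
    y1 = lin(x)
    lin.weight += 1.0      # a side effect, e.g. an optimizer step
    y2 = lin(x)

print(y1.item(), y2.item())  # 0.5 1.5: same input, different outputs
```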
Now, suppose we are asked to implement a multiview variant of the CNN, i.e., a CNN taking as input several photos of the same object, where for simplicity we assume a fixed number of views. A common solution to multiview problems is to process each image with a shared CNN, then pool the results using, e.g., max-pooling, to obtain a representation which is independent of the order of the views:^{3}
Concretely, we represent our multiview dataset by adding a trailing dimension to our input tensor, representing the views of the same input:
x = torch.randn((
1, # Batch dimension
3, # Channels
16, # Width
16, # Height
5 # Views
))
Applying a CNN on a new dimension requires vectorizing its operations along the new axis, something which calls to mind the vmap transform from JAX. To investigate this, let us see how this problem would be solved in pure JAX.^{4}
Solution 1: JAX with vmap
We use the stax module from JAX to reimplement the same neural network as above:
tinycnn_init, tinycnn_apply = stax.serial(
stax.Conv(5, (3,3)),
stax.Conv(5, (3,3))
)
Because JAX fully embraces the functional paradigm, a model is not represented as an instance of an object. Instead, it is represented with a pair of functions to, respectively, initialize the model and apply it. The model application is stateless, because we explicitly need to pass the parameters to run the model (pure functions not being allowed to have side effects):
key = jax.random.PRNGKey(42)
_, params = tinycnn_init(key, (1, 16, 16, 3)) # Channels last
x = jax.random.normal(key, (3, 16, 16, 3))
y = tinycnn_apply(params, x)
The advantage of working with functions is that it is very easy to transform them into other functions. For example, we can vectorize our tinycnn_apply function along our new axis, to make it work on the multiview tensor:
tinycnn_apply_multiview = jax.vmap(partial(tinycnn_apply, params), in_axes=4, out_axes=4)
It works out of the box:
x = jax.random.normal(key, (1, 16, 16, 3, 5))
print(tinycnn_apply_multiview(x).shape)
# Out: (1, 12, 12, 5, 5)
Can we replicate this in PyTorch?
Solution 2: PyTorch with vmap
PyTorch already has a number of prototype implementations of JAX functionalities, including a vmap prototype and a functional autograd API. However, both modules work on pure functions, while PyTorch modules are classes. This basically requires us to “extract” the forward method of the module and “purify” it by making the parameters an explicit argument. Let me steal the following image, with a reminder to read this fantastic blog post:
When the autograd Functional API was released, I asked about its interoperability on the PyTorch forum, and one of the developers of the API suggested the following transformation (simplified for readability, see the associated notebook):
def del_attr(obj, names):
    # Recursively delete a (possibly nested) attribute, e.g. "0.weight"
    if len(names) == 1:
        delattr(obj, names[0])
    else:
        del_attr(getattr(obj, names[0]), names[1:])

def set_attr(obj, names, val):
    if len(names) == 1:
        setattr(obj, names[0], val)
    else:
        set_attr(getattr(obj, names[0]), names[1:], val)

def make_functional(module):
    orig_params = tuple(module.parameters())
    # Remove all the parameters in the model
    names = []
    for name, p in list(module.named_parameters()):
        del_attr(module, name.split("."))
        names.append(name)
    # This is equivalent to the "tinycnn_apply" function above
    def functional_module_fw(params, x):
        for name, p in zip(names, params):
            set_attr(module, name.split("."), p)
        return module(x)
    return orig_params, functional_module_fw
The transformation lifts the parameters from inside the module, and defines a function taking them as an explicit input.^{5} Now we can call the function just like in JAX:
params, model_fcn = make_functional(tinycnn)
x = torch.randn((1, 3, 16, 16))
y = model_fcn(params, x)
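To see the lifting pattern in isolation, stripped of all PyTorch machinery, here is a pure-Python toy version (Scaler and make_functional_toy are invented names for illustration):

```python
class Scaler:
    """A tiny stateful "model": the result depends on self.w."""
    def __init__(self, w):
        self.w = w                 # internal state (a "parameter")
    def __call__(self, x):
        return self.w * x

def make_functional_toy(module):
    orig_params = (module.w,)      # lift the state out...
    def functional_fw(params, x):
        module.w = params[0]       # ...re-attach it at call time...
        return module(x)           # ...and run the original forward
    return orig_params, functional_fw

params, fw = make_functional_toy(Scaler(3.0))
print(fw(params, 2.0))    # 6.0
print(fw((10.0,), 2.0))   # 20.0: same function, different parameters
```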
More importantly, now that we have a pure function, we can use the PyTorch vmap! The next instructions are almost line-by-line identical to JAX:
model_fcn_multiview = torch.vmap(partial(model_fcn, params), in_dims=4, out_dims=4)
x = torch.randn((1, 3, 16, 16, 5))
print(model_fcn_multiview(x).shape)
# Out: torch.Size([1, 5, 12, 12, 5])
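For comparison, and as noted in footnote 4, the same multiview computation can also be written without vmap by folding the views into the batch dimension; a minimal sketch, assuming the same tiny CNN as above:

```python
import torch
from torch import nn

tinycnn = nn.Sequential(nn.Conv2d(3, 5, kernel_size=3),
                        nn.Conv2d(5, 5, kernel_size=3))
x = torch.randn((1, 3, 16, 16, 5))  # (batch, channels, width, height, views)

# Fold the views into the batch dimension...
xb = x.permute(0, 4, 1, 2, 3).reshape(-1, 3, 16, 16)  # (batch*views, 3, 16, 16)
yb = tinycnn(xb)                                      # (batch*views, 5, 12, 12)
# ...then unfold them back into a trailing dimension.
y = yb.reshape(1, 5, 5, 12, 12).permute(0, 2, 3, 4, 1)
print(y.shape)  # torch.Size([1, 5, 12, 12, 5])
```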
Solution 3: functorch, finally
Although it works, it is tricky, and it also breaks down in some situations (e.g., if a layer has buffers, like those of batch normalization, we need to extract them independently). However, if you run the code above, you will get an interesting warning, telling you that the development of these functional tools has moved to a separate project, functorch (see also #42368 RFC: torch.vmap and related issues).
functorch is interesting because it promises to be an almost complete port of JAX functionalities inside PyTorch. To begin with, it has its own, more professional version of make_functional:
import functorch
model_fcn, params = functorch.make_functional(tinycnn)
It has a more advanced version of vmap, with several improvements over the (now deprecated) variant in PyTorch:
model_fcn_multiview = functorch.vmap(partial(model_fcn, params), in_dims=4, out_dims=4)
print(model_fcn_multiview(x).shape)
# Out: torch.Size([1, 5, 12, 12, 5])
Apart from vmap, functorch has its own implementation of grad, vjp, and jacrev, along with a variant of make_functional that also works with buffers, and some additional methods apparently under way.
These can be composed like in JAX to, e.g., obtain the Jacobian of the multiview function:
jacobian_fcn = functorch.jacrev(model_fcn_multiview)
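As a tiny, self-contained illustration of this composability on a plain function (using torch.func, where functorch's transforms are exposed in recent PyTorch releases), here is grad composed with itself and then vectorized with vmap:

```python
import torch
from torch.func import grad, vmap

# f(x) = x**3, so grad(grad(f))(x) = 6 * x
f = lambda x: x ** 3
d2f = grad(grad(f))                 # JAX-style composition of transforms
xs = torch.tensor([1.0, 2.0, 3.0])
print(vmap(d2f)(xs))                # [6., 12., 18.]
```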
This is it for the tutorial! functorch is still under development, but it promises to be an interesting middle ground for users wanting only a “tiny bit of JAX” in their PyTorch applications. If you are intrigued, you should definitely go take a look!
Liked the tutorial? Remember that you can follow me on Twitter for almost daily updates. Feel free to reach out with ideas, feedback, or comments of any kind.
Footnotes and references
1. This is incredibly useful when moving beyond pure first-order optimization. For example, two years ago I was playing with an optimization method called successive convex approximations, which required (among other things) computing an expression of the form \(w^\top A w\), for some matrix \(A\) and the network’s parameters \(w\). Using JAX’s pytrees, this was trivial. ↩
2. Consider the classical example jacfwd(jacrev(f)) of composing forward-over-reverse automatic differentiation to compute dense Hessians. A personal example here (from the same paper as footnote 1) was linearizing a network \(f\) and then repeatedly computing its product with several vectors \(v\), which is even a primitive in JAX. ↩
3. Su, H., Maji, S., Kalogerakis, E. and Learned-Miller, E., 2015. Multi-view convolutional neural networks for 3D shape recognition. IEEE ICCV (pp. 945-953). ↩
4. 🧠 Since this is a toy problem, we can of course implement the multiview CNN in a myriad of alternative ways (e.g., transpositions and reshapings). ↩
5. Dynamically modifying the behaviour of an object at runtime for compatibility with other methods is called monkey patching. higher has a similar “monkey patched” transformation of a module to compute meta-loops of optimization. See also #58839 [RFC] Functional API design for torch.nn.Module for some additional comments on this. ↩