Looks like a nice overview. I've implemented neural ODEs in JAX for low-dimensional problems and it works well, but I keep looking for a good, fast, CPU-first implementation for models that fit in cache and don't need a GPU or heavy Torch/TF machinery.
Anecdotally, I used diffrax (and equinox) throughout last year after jumping between a few differential equation solvers in Python, for a project based on Dynamic Field Theory [1]. I only scratched the surface, but so far, it's been a pleasure to use, and it's quite fast. It also introduced me to equinox [2], by the same author, which I'm using to get the JAX-friendly equivalent of dataclasses.
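Roughly, a "JAX-friendly dataclass" is a dataclass that is also registered as a pytree, so instances can go straight through `jax.jit`, `jax.grad` and `jax.vmap`. Here is a hand-rolled sketch with plain `jax.tree_util`, i.e. the kind of boilerplate that equinox's `eqx.Module` takes care of for you. It is not how equinox implements it, and `FieldParams`, `relax` and their fields are hypothetical names for a toy relaxation term.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
@dataclass(frozen=True)
class FieldParams:
    """Hypothetical parameters for a toy relaxation equation."""
    tau: jax.Array   # time constant: an array leaf, so it is traced and differentiable
    h: jax.Array     # resting level: also an array leaf
    n_units: int     # static metadata, kept out of the traced leaves

    def tree_flatten(self):
        # Array fields become leaves; the int goes into the static aux data.
        return (self.tau, self.h), self.n_units

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        tau, h = children
        return cls(tau=tau, h=h, n_units=aux_data)


@jax.jit
def relax(params: FieldParams, u: jax.Array) -> jax.Array:
    # Toy vector field: du/dt = (-u + h) / tau
    return (-u + params.h) / params.tau


params = FieldParams(tau=jnp.array(10.0), h=jnp.array(-5.0), n_units=4)
u = jnp.zeros(params.n_units)
print(relax(params, u))  # each entry: (-0 + -5) / 10 = -0.5

# Gradients come back as a FieldParams with the same structure.
grads = jax.grad(lambda p: relax(p, u).sum())(params)
print(grads.h, grads.tau)  # d/dh = 4 / 10 = 0.4, d/dtau = 4 * 5 / 100 = 0.2
```

The payoff is that the parameter bundle stays a normal, readable class while JAX transformations treat its array fields as leaves.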
`vmap`-able differential equation solving is really cool.
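A minimal sketch of why this is nice, using a hand-rolled fixed-step RK4 in plain JAX rather than diffrax (`rk4_solve` and `decay` are hypothetical names, not diffrax's API). Because the solver is an ordinary traceable function built on `jax.lax.scan`, `jax.vmap` turns a single solve into a batch of independent solves over initial conditions and parameters, with no Python loop.

```python
import jax
import jax.numpy as jnp


def rk4_solve(f, y0, t0, t1, n_steps, args):
    """Integrate dy/dt = f(t, y, args) from t0 to t1 with fixed-step RK4; return y(t1)."""
    t0 = jnp.asarray(t0, dtype=y0.dtype)
    dt = (t1 - t0) / n_steps

    def step(carry, _):
        t, y = carry
        k1 = f(t, y, args)
        k2 = f(t + dt / 2, y + dt / 2 * k1, args)
        k3 = f(t + dt / 2, y + dt / 2 * k2, args)
        k4 = f(t + dt, y + dt * k3, args)
        y_next = y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        return (t + dt, y_next), None

    # n_steps must be a Python int: lax.scan needs a static trip count.
    (_, y1), _ = jax.lax.scan(step, (t0, y0), None, length=n_steps)
    return y1


def decay(t, y, k):
    # Linear decay dy/dt = -k * y, exact solution y(t) = y0 * exp(-k * t).
    return -k * y


def solve_one(y0, k):
    # A single solve from t = 0 to t = 1 for one (y0, k) pair.
    return rk4_solve(decay, y0, 0.0, 1.0, 100, k)


# vmap maps solve_one over the leading axis of both y0s and ks.
solve_batch = jax.jit(jax.vmap(solve_one))

y0s = jnp.array([1.0, 2.0, 3.0])
ks = jnp.array([0.5, 1.0, 2.0])
print(solve_batch(y0s, ks))  # close to y0s * exp(-ks)
```

diffrax's `diffeqsolve` composes with `jax.vmap` in the same way, so the same batching trick carries over to its solvers.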
A classic NN passes a vector of data through a stack of layers to make a prediction. Backprop adjusts the network weights until the predictions are right. These weights form a vector, and training changes this vector until it reaches values that mean "trained network".
Neural ODE reframes this: instead of a fixed number of discrete layers, the hidden state evolves continuously, and a network f with weights θ gives its rate of change, dh/dt = f(h, t; θ). An ODE solver integrates this from the input at t0 to an output at t1, taking as many steps as it needs. Training is still about the weights: gradients flow back through the solver (or via the adjoint method), and θ is adjusted until the solved outputs match the training data.
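To make that concrete, here is a minimal neural ODE sketch in plain JAX, assuming a fixed-step RK4 solver, a tiny tanh MLP as the vector field, and a toy regression target (all hypothetical choices; this is not diffrax's or equinox's API). The forward pass solves the ODE starting from each input; training differentiates straight through the solver steps with `jax.grad` ("discretise-then-optimise") instead of using the adjoint method.

```python
import jax
import jax.numpy as jnp

LR = 0.1  # hypothetical learning rate


def init_mlp(key, width=16, dim=2):
    # Hypothetical tiny MLP (dim -> width -> dim), parameters stored as a dict pytree.
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (width, dim)) * 0.1,
        "b1": jnp.zeros(width),
        "W2": jax.random.normal(k2, (dim, width)) * 0.1,
        "b2": jnp.zeros(dim),
    }


def vector_field(params, h):
    # dh/dt = f(h; theta): the network outputs the rate of change of the state.
    z = jnp.tanh(params["W1"] @ h + params["b1"])
    return params["W2"] @ z + params["b2"]


def odeint_rk4(params, h0, t1=1.0, n_steps=20):
    # Fixed-step RK4 from t = 0 to t1 (autonomous field, so t is not passed in).
    dt = t1 / n_steps

    def step(h, _):
        k1 = vector_field(params, h)
        k2 = vector_field(params, h + dt / 2 * k1)
        k3 = vector_field(params, h + dt / 2 * k2)
        k4 = vector_field(params, h + dt * k3)
        return h + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4), None

    h1, _ = jax.lax.scan(step, h0, None, length=n_steps)
    return h1


def loss(params, xs, ys):
    # Map each input x to h(t1) by solving the ODE, then compare with its target.
    preds = jax.vmap(lambda x: odeint_rk4(params, x))(xs)
    return jnp.mean((preds - ys) ** 2)


@jax.jit
def sgd_step(params, xs, ys):
    # Backprop through every solver step, then a plain gradient-descent update.
    grads = jax.grad(loss)(params, xs, ys)
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)


key_p, key_x = jax.random.split(jax.random.PRNGKey(0))
params = init_mlp(key_p)
xs = jax.random.normal(key_x, (64, 2))
ys = 0.5 * xs  # toy target: shrink each input toward the origin

before = loss(params, xs, ys)
for _ in range(200):
    params = sgd_step(params, xs, ys)
after = loss(params, xs, ys)
print(before, after)
```

Differentiating through the unrolled solver is simple and fast for small, low-dimensional models like this; the adjoint approach trades extra compute for lower memory when the number of solver steps gets large.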
JAX Talk: Diffrax https://www.youtube.com/watch?v=Jy5Jw8hNiAQ
[1]: https://dynamicfieldtheory.org/
[2]: https://github.com/patrick-kidger/equinox
Kidger's thesis is wonderful https://arxiv.org/abs/2202.02435
JAX is fun, but not as effective as I'd like on CPU.