
  • marmaduke 288 days ago
    looks like a nice overview. i’ve implemented neural ODEs in Jax for low dimensional problems and it works well, but I keep looking for a good, fast, CPU-first implementation that is good for models that fit in cache and don’t require a GPU or big Torch/TF machinery.
    • sitkack 288 days ago
      • yberreby 287 days ago
        Anecdotally, I used diffrax (and equinox) throughout last year after jumping between a few differential equation solvers in Python, for a project based on Dynamic Field Theory [1]. I only scratched the surface, but so far, it's been a pleasure to use, and it's quite fast. It also introduced me to equinox [2], by the same author, which I'm using to get the JAX-friendly equivalent of dataclasses.

        `vmap`-able differential equation solving is really cool.

        [1]: https://dynamicfieldtheory.org/
        [2]: https://github.com/patrick-kidger/equinox

      • marmaduke 287 days ago
        no, wrote it by hand for use with my own Heun implementation, since it’s for use within stochastic delayed systems.

        jax is fun but not as effective as i’d like for CPU
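
For readers unfamiliar with Heun's method mentioned above, here is a minimal dependency-free sketch of the deterministic version (an Euler predictor followed by a trapezoidal corrector). It is not the commenter's implementation: the function names `heun_step`, `integrate`, and `decay` are made up for illustration, and it handles neither noise nor delays, which a stochastic delayed system would need.

```python
import math


def heun_step(f, t, y, dt):
    """One step of Heun's method for dy/dt = f(t, y), with y a list of floats."""
    k1 = f(t, y)                                      # slope at the start of the step
    y_pred = [yi + dt * ki for yi, ki in zip(y, k1)]  # Euler predictor
    k2 = f(t + dt, y_pred)                            # slope at the predicted endpoint
    # Corrector: advance with the average of the two slopes.
    return [yi + 0.5 * dt * (a + b) for yi, a, b in zip(y, k1, k2)]


def integrate(f, y0, t0, t1, n_steps):
    """Integrate from t0 to t1 using n_steps fixed-size Heun steps."""
    dt = (t1 - t0) / n_steps
    t, y = t0, list(y0)
    for _ in range(n_steps):
        y = heun_step(f, t, y, dt)
        t += dt
    return y


def decay(t, y):
    """Toy right-hand side (assumption for the demo): dy/dt = -y."""
    return [-yi for yi in y]


y_end = integrate(decay, [1.0], 0.0, 1.0, 100)
exact = math.exp(-1.0)  # analytic solution of dy/dt = -y at t = 1
```

With 100 steps the result should agree with `exact` to several decimal places, since the method is second-order accurate.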

    • barrenko 286 days ago
      How would you describe what a neural ODE is in the simplest possible terms? Let's say I know what an NN and a DE are :).
      • kk58 286 days ago
        classic NN takes a vector of data through layers to make a prediction. Backprop adjusts network weights till predictions are right. These network weights form a vector, and training changes this vector till it hits values that mean "trained network".

        Neural ODE reframes the layers, not the training: instead of a fixed stack of discrete layers, a small network defines how the hidden state changes over a continuous "depth" t, i.e. dh/dt = f(h, t, weights). The forward pass hands this to an ODE solver, which steps the state from the input at t=0 to the output at t=T. The weights of f are still trained by backprop, either straight through the solver steps or via the adjoint method.
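
As a rough illustration of the usual neural ODE formulation, where a small network gives the derivative of a hidden state and an ODE solver performs the forward pass, here is a stdlib-only sketch. Everything in it is an assumption for demonstration: the names (`make_params`, `dynamics`, `odeint_euler`), the layer sizes, the random untrained weights, and the choice of fixed-step Euler as the solver. Training is not shown; in practice one would differentiate through the solver with an autodiff library.

```python
import math
import random


def make_params(dim, hidden, seed=0):
    """Random, untrained weights for a one-hidden-layer network (illustrative)."""
    rng = random.Random(seed)
    w1 = [[rng.gauss(0.0, 0.5) for _ in range(dim)] for _ in range(hidden)]
    b1 = [0.0] * hidden
    w2 = [[rng.gauss(0.0, 0.5) for _ in range(hidden)] for _ in range(dim)]
    b2 = [0.0] * dim
    return w1, b1, w2, b2


def matvec(w, x, b):
    """Compute w @ x + b for a list-of-lists matrix w."""
    return [sum(wij * xj for wij, xj in zip(row, x)) + bi for row, bi in zip(w, b)]


def dynamics(h, params):
    """The network f: returns dh/dt for hidden state h (time-independent here)."""
    w1, b1, w2, b2 = params
    z = [math.tanh(v) for v in matvec(w1, h, b1)]
    return matvec(w2, z, b2)


def odeint_euler(h0, params, t1=1.0, n_steps=20):
    """Forward pass: integrate dh/dt = f(h) from t=0 to t=t1 with Euler steps."""
    dt = t1 / n_steps
    h = list(h0)
    for _ in range(n_steps):
        dh = dynamics(h, params)
        h = [hi + dt * di for hi, di in zip(h, dh)]
    return h


params = make_params(dim=2, hidden=8)
x = [1.0, -0.5]               # input plays the role of h(0)
h_T = odeint_euler(x, params)  # output is h(T), same dimension as the input
```

The number of solver steps takes the place of network depth here, and the same weights are reused at every step.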

        • barrenko 286 days ago
          Pretty cool approach, looking more into it, thank you!