projects
(wip)
JAX
~50 papers implemented in JAX
4 are open source
FT Transformer
- Implementation of "Revisiting Deep Learning Models for Tabular Data" in JAX
pip install fttjax
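For context, the feature tokenizer in the FT-Transformer paper maps each numerical feature x_j to a token b_j + x_j·W_j, and a learned [CLS] token is added to the sequence before the Transformer layers. The sketch below covers only the numerical path in plain jax.numpy. The function names and the init scale are hypothetical and are not the fttjax API.

```python
import jax
import jax.numpy as jnp


def init_numerical_tokenizer(key, n_features, d_token):
    """Hypothetical init: one weight W_j and one bias b_j per feature, plus a [CLS] vector."""
    k_w, k_b, k_cls = jax.random.split(key, 3)
    return {
        "weight": jax.random.normal(k_w, (n_features, d_token)) * 0.02,
        "bias": jax.random.normal(k_b, (n_features, d_token)) * 0.02,
        "cls": jax.random.normal(k_cls, (d_token,)) * 0.02,
    }


def tokenize(params, x):
    """x: (batch, n_features) -> tokens: (batch, 1 + n_features, d_token)."""
    # Feature j becomes b_j + x_j * W_j (x_j broadcast over the token dimension).
    tokens = x[..., None] * params["weight"] + params["bias"]
    # Put the learned [CLS] token in front of every row in the batch.
    cls = jnp.broadcast_to(params["cls"], (x.shape[0], 1, tokens.shape[-1]))
    return jnp.concatenate([cls, tokens], axis=1)


params = init_numerical_tokenizer(jax.random.PRNGKey(0), n_features=3, d_token=8)
x = jnp.array([[0.5, -1.0, 2.0], [1.5, 0.0, -0.3]])
tokens = tokenize(params, x)
print(tokens.shape)  # (2, 4, 8)
```

Categorical features would use a per-feature embedding lookup in place of x_j·W_j; that path is left out here.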
GateLoop Transformer
- Implementation of "GateLoop: Fully Data-Controlled Linear Recurrence for Sequence Modeling" in JAX
pip install gateloop
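The core of GateLoop is a linear recurrence whose state-transition gates are computed from the input, roughly h_t = a_t ⊙ h_{t-1} + k_tᵀ v_t and y_t = q_t h_t. Below is a simplified single-head sketch using jax.lax.scan. It assumes real-valued sigmoid gates with one gate per key dimension in place of the more general gates described in the paper, and the function and weight names are hypothetical, not the gateloop package API.

```python
import jax
import jax.numpy as jnp


def gated_linear_recurrence(q, k, v, a):
    """q, k, a: (T, d_k); v: (T, d_v) -> y: (T, d_v).

    Simplified sketch: real-valued gates a_t in (0, 1), one per key dimension.
    """
    d_k, d_v = k.shape[-1], v.shape[-1]

    def step(h, inputs):
        q_t, k_t, v_t, a_t = inputs
        # Decay each row of the (d_k, d_v) state by its gate, then write k_t^T v_t.
        h = a_t[:, None] * h + jnp.outer(k_t, v_t)
        # Read the state out with the query.
        return h, q_t @ h

    h0 = jnp.zeros((d_k, d_v), dtype=v.dtype)
    _, y = jax.lax.scan(step, h0, (q, k, v, a))
    return y


# Hypothetical projections: the gates come from the input, i.e. they are data-controlled.
T, d_model, d_k, d_v = 6, 8, 4, 4
keys = jax.random.split(jax.random.PRNGKey(0), 5)
x = jax.random.normal(keys[0], (T, d_model))
w_q = jax.random.normal(keys[1], (d_model, d_k)) * 0.1
w_k = jax.random.normal(keys[2], (d_model, d_k)) * 0.1
w_v = jax.random.normal(keys[3], (d_model, d_v)) * 0.1
w_a = jax.random.normal(keys[4], (d_model, d_k)) * 0.1
y = gated_linear_recurrence(x @ w_q, x @ w_k, x @ w_v, jax.nn.sigmoid(x @ w_a))
print(y.shape)  # (6, 4)
```

With all gates fixed at 1 this reduces to causal linear attention; with all gates at 0 each step only sees its own key/value pair.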
Bidirectional Cross Attention
- Implementation of "Perceiving Longer Sequences With Bi-Directional Cross-Attention Transformers" in JAX
pip install bidirectional-cross-attention-jax
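The key mechanism in bi-directional cross-attention is that one similarity matrix between two sets (for example latents and input tokens) is computed once and normalized along each axis, so information flows both ways for roughly the cost of a single cross-attention. A minimal single-head, unbatched sketch in jax.numpy follows; the weight names and init are hypothetical and are not the bidirectional-cross-attention-jax API.

```python
import jax
import jax.numpy as jnp


def bidirectional_cross_attention(x, ctx, params):
    """x: (n, d_x), ctx: (m, d_ctx) -> (n, d_v), (m, d_v). Single head, no batch."""
    qk = x @ params["x_qk"]          # (n, d) shared query/key projection for x
    ctx_qk = ctx @ params["ctx_qk"]  # (m, d) shared query/key projection for ctx
    v = x @ params["x_v"]            # (n, d_v)
    ctx_v = ctx @ params["ctx_v"]    # (m, d_v)

    # One similarity matrix, computed once and reused in both directions.
    sim = qk @ ctx_qk.T / jnp.sqrt(qk.shape[-1])  # (n, m)
    attn = jax.nn.softmax(sim, axis=-1)     # each x row attends over ctx
    ctx_attn = jax.nn.softmax(sim, axis=0)  # each ctx column attends over x

    out = attn @ ctx_v        # (n, d_v): x updated from ctx
    ctx_out = ctx_attn.T @ v  # (m, d_v): ctx updated from x
    return out, ctx_out


def init_params(key, d_x, d_ctx, d=16, d_v=16):
    k1, k2, k3, k4 = jax.random.split(key, 4)
    return {
        "x_qk": jax.random.normal(k1, (d_x, d)) * 0.1,
        "ctx_qk": jax.random.normal(k2, (d_ctx, d)) * 0.1,
        "x_v": jax.random.normal(k3, (d_x, d_v)) * 0.1,
        "ctx_v": jax.random.normal(k4, (d_ctx, d_v)) * 0.1,
    }


latents = jax.random.normal(jax.random.PRNGKey(1), (8, 32))   # e.g. 8 latents
tokens = jax.random.normal(jax.random.PRNGKey(2), (100, 24))  # e.g. 100 input tokens
params = init_params(jax.random.PRNGKey(0), d_x=32, d_ctx=24)
latent_out, token_out = bidirectional_cross_attention(latents, tokens, params)
print(latent_out.shape, token_out.shape)  # (8, 16) (100, 16)
```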
Axial Positional Embedding
- JAX implementation of the axial positional embedding used in "Reformer: The Efficient Transformer"
pip install axial-positional-embedding-jax
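Axial positional embedding avoids storing one d-dimensional vector per position: positions are laid out on an n_rows × n_cols grid, and each position's embedding is the concatenation of a row embedding and a column embedding. For a 64 × 64 grid (4096 positions) with d = 512 split as 256 + 256, that is 2 × 64 × 256 = 32,768 parameters instead of 4096 × 512 = 2,097,152. A minimal sketch; the function name and the row-major grid layout are assumptions, not the axial-positional-embedding-jax API.

```python
import jax
import jax.numpy as jnp


def axial_positional_embedding(row_emb, col_emb, seq_len):
    """row_emb: (n_rows, d_row), col_emb: (n_cols, d_col) -> (seq_len, d_row + d_col).

    Position p sits at row p // n_cols, column p % n_cols (row-major layout).
    """
    n_rows, d_row = row_emb.shape
    n_cols, d_col = col_emb.shape
    if seq_len > n_rows * n_cols:
        raise ValueError("sequence longer than the axial grid")
    rows = jnp.broadcast_to(row_emb[:, None, :], (n_rows, n_cols, d_row))
    cols = jnp.broadcast_to(col_emb[None, :, :], (n_rows, n_cols, d_col))
    grid = jnp.concatenate([rows, cols], axis=-1)  # (n_rows, n_cols, d_row + d_col)
    return grid.reshape(n_rows * n_cols, d_row + d_col)[:seq_len]


k_row, k_col = jax.random.split(jax.random.PRNGKey(0))
row_emb = jax.random.normal(k_row, (64, 256))
col_emb = jax.random.normal(k_col, (64, 256))
pos = axial_positional_embedding(row_emb, col_emb, seq_len=4096)
print(pos.shape)  # (4096, 512)
```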
JAXVision
Common utilities and architectures for computer vision in JAX
Currently used in production
Will be open-sourced soon
JAX equivalent of torchvision
axra.org
This website has a lot of cool WASM projects, all built from scratch by me
+ ~100 more incomplete ones