Part 8 of Jaxformer (Previous: Part 7, Training Results | The End)
Throughout this guide, we walked step by step through building a modern, scalable transformer model in JAX, focusing on both architectural advances and scaling practices.
Together, these parts form a zero-to-one guide on how to scale transformers from a simple baseline to cutting-edge distributed training.
In the future, this guide could be extended with newer methods, such as replacing GPipe with DualPipe and incorporating additional dimensions of parallelism, such as expert and/or sequence parallelism. The tokenization pipeline could also be extended to stream Parquet files over a distributed network.
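As a rough illustration of what an extra parallelism axis might look like in JAX, here is a minimal sketch that builds a 3D device mesh with `data`, `expert`, and `sequence` axes and shards an activation tensor across it. The axis names, shapes, and device count are illustrative assumptions, not code from this series:

```python
# Hedged sketch: a 3D device mesh adding `expert` and `sequence` axes.
# Assumes exactly 8 local devices; all names/shapes are illustrative.
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()).reshape(2, 2, 2)  # (data, expert, sequence)
mesh = Mesh(devices, axis_names=("data", "expert", "sequence"))

# Shard activations of shape (batch, seq_len, d_model): split the batch
# over `data` and the token dimension over `sequence`.
x = jax.random.normal(jax.random.PRNGKey(0), (8, 1024, 512))
x = jax.device_put(x, NamedSharding(mesh, P("data", "sequence", None)))

# MoE-style expert weights of shape (n_experts, d_in, d_out) would instead
# be split over the `expert` axis, e.g. NamedSharding(mesh, P("expert")).
```

Inside jitted code, collectives over these named axes (for example via `shard_map`) would then implement the expert routing and sequence-parallel attention, but those details are beyond the scope of this series.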
*How to get in touch: leave a comment on any page, reach us on socials, or start a discussion thread on the GitHub repo.*
We are all currently first- and second-year undergraduate students studying Computer Science at the University of Waterloo.
Author | Twitter / X
---|---
Aditya Makkar | @AdityaMakkar000
Divya Makkar | @_DivyaMakkar
Chinmay Jindal | @chinmayjindal_