Conclusion: Summary and Future Extensions

Part 8 of Jaxformer (Part 7: Training Results | The End)

Conclusion

Throughout this guide, we walked step by step through building a modern, scalable transformer model in JAX, focusing on both architectural advances and scaling practices.

Article Summaries

Together, these parts form a zero-to-one guide on how to scale transformers from a simple baseline to cutting-edge distributed training.

Future Directions

In the future, this work can be extended with newer techniques, such as replacing GPipe with DualPipe for pipeline parallelism and incorporating higher dimensions of parallelism such as expert and/or sequence parallelism. The tokenization pipeline could also be extended to stream Parquet files over a distributed network. Both ideas are sketched below.
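To make these directions a little more concrete, here are two minimal sketches rather than definitive implementations. The first shows how a JAX device mesh could gain extra named axes for expert and sequence parallelism; the 8-device shape and the axis names are illustrative assumptions, not the repo's actual configuration.

```python
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec

# Illustrative only: fold expert and sequence axes into the device mesh.
# The product of the axis sizes must equal the total number of devices;
# the (2, 2, 2, 1) shape below assumes 8 accelerators.
devices = mesh_utils.create_device_mesh((2, 2, 2, 1))
mesh = Mesh(devices, axis_names=("data", "expert", "seq", "model"))

# Activations of shape (batch, seq, d_model) could then be split along
# both the data and sequence axes:
spec = PartitionSpec("data", "seq", None)
```

The second sketches streaming tokenizer input from Parquet shards with PyArrow, so the corpus never has to fit in memory; the `data/` path, the `text` column, and the `split()` stand-in tokenizer are all placeholders.

```python
import pyarrow.dataset as ds

# Illustrative only: read record batches lazily instead of loading the
# whole corpus. The same loop works whether "data/" is a local directory
# or a mount backed by remote object storage.
dataset = ds.dataset("data/", format="parquet")
for batch in dataset.to_batches(columns=["text"], batch_size=1024):
    for text in batch.column("text").to_pylist():
        tokens = text.split()  # stand-in for the real tokenizer
```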

Authors & Contact

We are all currently first- and second-year undergraduate students at the University of Waterloo studying Computer Science.

How to get in touch: leave a comment on any page, reach us on socials, or start a discussion thread on the GitHub repo.

| Author | Twitter / X | LinkedIn |
| --- | --- | --- |
| Aditya Makkar | @AdityaMakkar000 | Aditya Makkar |
| Divya Makkar | @_DivyaMakkar | Divya Makkar |
| Chinmay Jindal | @chinmayjindal_ | Chinmay Jindal |