About Me

I'm currently a research scientist at Google DeepMind. I am interested in the infrastructure, data, and modeling aspects of large-scale self-supervised learning. Specifically, my research focuses on language models, multimodal language models, and reinforcement learning. In my spare time, I also develop open source infrastructure for training large-scale models.

I obtained my PhD from UC Berkeley, where I was advised by Professor Sergey Levine. Before starting my PhD at Berkeley, I completed the one-year Google AI Residency program. Before joining Google, I studied computer science and statistics at the University of California, Berkeley. During my undergraduate studies, I worked with Professor Pieter Abbeel, Professor Sergey Levine, and Professor Alexei Efros as a research assistant in the Berkeley Artificial Intelligence Research (BAIR) Lab.

Research Software

Scalax

Scalax is a collection of utilities for helping developers easily scale up JAX-based machine learning models. The main idea of scalax is simple: users write model and training code for a single GPU/TPU, and rely on scalax to automatically scale it up to hundreds of GPUs/TPUs. This is made possible by the JAX jit compiler, and scalax provides a set of utilities to help users obtain the sharding annotations required by the jit compiler. Because scalax wraps around the jit compiler, existing JAX code can be scaled up with minimal changes.
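To illustrate the mechanism scalax builds on (this is plain JAX rather than scalax's own API, and the mesh axis name and toy model are only for illustration), a single-device training step becomes a sharded one once jit is given a device mesh and sharding annotations:

    import jax
    import jax.numpy as jnp
    from jax.experimental import mesh_utils
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # Ordinary single-device training step: no parallelism logic inside.
    def train_step(params, batch):
        def loss_fn(p):
            preds = batch['x'] @ p['w']
            return jnp.mean((preds - batch['y']) ** 2)
        grads = jax.grad(loss_fn)(params)
        return jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)

    # A 1D mesh over all available devices; 'data' is an illustrative axis name.
    mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ('data',))

    # Sharding annotations: replicate the parameters, split the batch along 'data'.
    replicated = NamedSharding(mesh, P())
    data_sharded = NamedSharding(mesh, P('data'))

    # jit compiles the same step into a program that runs across the whole mesh.
    sharded_train_step = jax.jit(
        train_step,
        in_shardings=(replicated, data_sharded),
        out_shardings=replicated,
    )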

MinText

MinText is a minimal but scalable implementation of large language models in JAX. Specifically, it implements the LLaMA architecture in a clean and modular way, which makes it easy to modify and extend. The codebase is designed to be a didactic example of how one can implement a large language model from scratch in JAX with fairly minimal code, while still retaining the ability to scale to large models on thousands of accelerators. To that end, it supports a combination of data parallelism, fully sharded model parallelism, sequence parallelism (ring attention) and tensor parallelism.
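As a rough sketch of how such a combination of parallelism strategies is commonly expressed in JAX (this is illustrative rather than MinText's actual code, and the axis names are assumptions), each strategy corresponds to a named axis of a multi-dimensional device mesh, and arrays are annotated with which mesh axes shard which dimensions:

    import numpy as np
    import jax
    from jax.sharding import Mesh, PartitionSpec as P

    # Illustrative 4-axis mesh: data, fully sharded, sequence, and tensor axes.
    # For simplicity all devices go on the sequence axis here; a real run would
    # pick nontrivial sizes per axis (e.g. 2 x 2 x 4 x 2 on a 32-chip pod).
    devices = np.array(jax.devices()).reshape(1, 1, jax.device_count(), 1)
    mesh = Mesh(devices, axis_names=('dp', 'fsdp', 'sp', 'tp'))

    # Activations [batch, sequence, hidden]: batch split over the data axes,
    # sequence split for ring attention, hidden split over the tensor axis.
    activation_spec = P(('dp', 'fsdp'), 'sp', 'tp')

    # A weight matrix [hidden_in, hidden_out]: sharded over fsdp and tensor axes.
    weight_spec = P('fsdp', 'tp')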

EasyLM

Large language models (LLMs) made easy: EasyLM is a one-stop solution for pre-training, fine-tuning, evaluating and serving LLMs in JAX/Flax. EasyLM can scale up LLM training to hundreds of TPU/GPU accelerators by leveraging JAX's pjit functionality, without the complexity found in many other frameworks.
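For reference, here is a minimal sketch of the pjit mechanism EasyLM leverages (this is plain JAX, not EasyLM's API; the axis names and shapes are illustrative): inside a Mesh context, pjit takes PartitionSpecs describing how each input and output is laid out across the device mesh.

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.experimental.pjit import pjit
    from jax.sharding import Mesh, PartitionSpec as P

    # A 2D mesh with data-parallel and model-parallel axes (sizes are illustrative).
    mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), axis_names=('dp', 'mp'))

    def forward(w, x):
        return jnp.dot(x, w)

    # Shard the weight over the model axis and the batch over the data axis.
    sharded_forward = pjit(
        forward,
        in_shardings=(P(None, 'mp'), P('dp', None)),
        out_shardings=P('dp', 'mp'),
    )

    with mesh:
        w = jnp.ones((512, 512))
        x = jnp.ones((128, 512))
        y = sharded_forward(w, x)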

TPU Pod Commander

TPU Pod Commander is a package for setting up and launching jobs on Google Cloud TPU pods. It provides a simple, easy-to-use interface for controlling Cloud TPUs via command-line arguments and configuration scripts.
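Conceptually (this sketch is not TPU Pod Commander's actual API), launching a job on a TPU pod means running the same command on every worker host of the pod slice, which the gcloud CLI supports directly; the TPU name and zone below are placeholders:

    import subprocess

    TPU_NAME = 'my-tpu-pod'   # placeholder
    ZONE = 'us-central2-b'    # placeholder

    def run_on_all_workers(command: str) -> None:
        # gcloud fans the command out to every worker VM in the pod slice.
        subprocess.run([
            'gcloud', 'compute', 'tpus', 'tpu-vm', 'ssh', TPU_NAME,
            f'--zone={ZONE}', '--worker=all', f'--command={command}',
        ], check=True)

    run_on_all_workers('python3 train.py')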

MLXU

Machine Learning eXperiment Utilities: convenient utilities for running machine learning experiments, parsing experiment configurations and logging results.
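A short sketch of the intended workflow follows; the helper names here (define_flags_with_default and run) are recalled from the library's typical usage and should be treated as assumptions:

    import mlxu

    # Declare experiment configuration as flags with defaults (assumed helper name).
    FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
        seed=42,
        learning_rate=1e-4,
        total_steps=10000,
    )

    def main(argv):
        # Flag values are available on FLAGS after command-line parsing.
        print(f'Running with learning rate {FLAGS.learning_rate}')

    if __name__ == '__main__':
        # Assumed entry-point wrapper around absl.app.run.
        mlxu.run(main)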

Open Source Models

OpenLLaMA: An Open Reproduction of LLaMA

OpenLLaMA is a permissively licensed open source reproduction of Meta AI's LLaMA large language model. We provide PyTorch and JAX weights of our pre-trained OpenLLaMA 7B model, as well as evaluation results and a comparison against the original LLaMA models. The OpenLLaMA model weights can serve as a drop-in replacement for the original LLaMA in downstream applications.
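As an example of the drop-in usage, the weights can be loaded with the standard Hugging Face transformers LLaMA classes; the repository name below is an assumption and should be checked against the actual release:

    import torch
    from transformers import LlamaForCausalLM, LlamaTokenizer

    # Assumed Hugging Face repository name for the 7B release.
    model_path = 'openlm-research/open_llama_7b'

    tokenizer = LlamaTokenizer.from_pretrained(model_path)
    model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)

    prompt = 'Q: What is the largest animal?\nA:'
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    output = model.generate(input_ids, max_new_tokens=32)
    print(tokenizer.decode(output[0]))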

Koala: A Dialogue Model for Academic Research

Check out our recent release of Koala, a dialogue language model fine-tuned on top of Meta's LLaMA using dialogue data gathered from the web.