About Me

I'm currently a research scientist at Google DeepMind. I obtained my PhD from UC Berkeley, where I was advised by Professor Sergey Levine. I am interested in the infrastructure, data and modeling aspects of large-scale self-supervised learning. Specifically, my research interests include language models, vision-language models, AI for science and reinforcement learning.

Before starting my PhD at Berkeley, I completed the one-year Google AI Residency program. Before joining Google, I studied computer science and statistics at the University of California, Berkeley. As an undergraduate, I worked as a research assistant with Professor Pieter Abbeel, Professor Sergey Levine and Professor Alexei Efros in the Berkeley Artificial Intelligence Research (BAIR) Lab.


OpenLLaMA: An Open Reproduction of LLaMA

OpenLLaMA is a permissively licensed, open-source reproduction of Meta AI's LLaMA large language model. We provide PyTorch and JAX weights of our pre-trained OpenLLaMA 7B model, as well as evaluation results and comparisons against the original LLaMA models. The OpenLLaMA model weights can serve as a drop-in replacement for the original LLaMA in downstream applications.

Koala: A Dialogue Model for Academic Research

Check out our recent release of Koala, a dialogue language model fine-tuned on top of Meta's LLaMA using dialogue data gathered from the web.

Research Software


Scalax is a collection of utilities that help developers easily scale up JAX-based machine learning models. The main idea of scalax is simple: users write model and training code for a single GPU/TPU and rely on scalax to automatically scale it up to hundreds of GPUs/TPUs. This is made possible by the JAX jit compiler; scalax provides a set of utilities that help users obtain the sharding annotations the jit compiler requires. Because scalax wraps around the jit compiler, existing JAX code can be scaled up with minimal changes.


Easy-to-use model-parallel training and evaluation of large language models in JAX/Flax using pjit on Cloud TPU pods, with support for popular language models such as GPT-J, OPT and RoBERTa.


Machine Learning eXperiment Utilities: convenient utilities for running machine learning experiments, parsing experiment configurations and logging results.


Offline Q-Learning on Diverse Multi-Task Data Both Scales And Generalizes

Aviral Kumar*, Rishabh Agarwal*, Xinyang Geng, George Tucker, Sergey Levine

The potential of offline reinforcement learning (RL) is that high-capacity models trained on large, heterogeneous datasets can lead to agents that generalize broadly, analogous to advances in vision and NLP. However, recent works argue that offline RL methods encounter unique challenges to scaling up model capacity. Drawing on the lessons from these works, we re-examine previous design choices and find that, with appropriate choices (ResNets, cross-entropy based distributional backups, and feature normalization), offline Q-learning algorithms exhibit strong performance that scales with model capacity. Using multi-task Atari as a testbed for scaling and generalization, we train a single policy on 40 games with near-human performance using networks with up to 80 million parameters, finding that model performance scales favorably with capacity. In contrast to prior work, we extrapolate beyond dataset performance even when trained entirely on a large (400M transitions) but highly suboptimal dataset (51% human-level performance). Compared to return-conditioned supervised approaches, offline Q-learning scales similarly with model capacity and has better performance, especially when the dataset is suboptimal. Finally, we show that offline Q-learning with a diverse dataset is sufficient to learn powerful representations that facilitate rapid transfer to novel games and fast online learning on new variations of a training game, improving over existing state-of-the-art representation learning approaches.
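The abstract lists cross-entropy based distributional backups among the key design choices. As a rough illustration of what such a backup involves, here is a C51-style categorical backup in NumPy: the Bellman-shifted distribution r + γz is projected onto a fixed set of atoms, and the predicted distribution is trained against it with cross-entropy. The function names, the atom support and γ are assumptions made for this sketch; it is not the paper's code, and the paper's exact backup may differ in its details.

```python
import numpy as np

def project_distribution(next_probs, rewards, dones, support, gamma=0.99):
    """Project the target distribution r + gamma * z onto a fixed atom support.

    next_probs: (batch, n_atoms) target probabilities at the next state.
    rewards, dones: (batch,) arrays; dones is 1.0 at episode ends.
    support: (n_atoms,) evenly spaced atom locations (an assumed choice).
    """
    v_min, v_max = support[0], support[-1]
    n_atoms = support.shape[0]
    delta_z = (v_max - v_min) / (n_atoms - 1)
    # Shift and shrink every atom, then clip into the support range.
    tz = rewards[:, None] + gamma * (1.0 - dones[:, None]) * support[None, :]
    tz = np.clip(tz, v_min, v_max)
    b = (tz - v_min) / delta_z  # fractional atom index of each shifted atom
    lower = np.floor(b).astype(int)
    upper = np.ceil(b).astype(int)
    target = np.zeros_like(next_probs)
    rows = np.repeat(np.arange(next_probs.shape[0])[:, None], n_atoms, axis=1)
    # Split each shifted atom's mass between its two neighbouring atoms.
    np.add.at(target, (rows, lower), next_probs * (upper - b))
    np.add.at(target, (rows, upper), next_probs * (b - lower))
    # If b lands exactly on an atom, both weights above are zero; restore the mass.
    exact = lower == upper
    np.add.at(target, (rows[exact], lower[exact]), next_probs[exact])
    return target

def cross_entropy_loss(logits, target):
    """Mean cross-entropy between projected targets and predicted atom logits."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -(target * log_probs).sum(axis=1).mean()
```

A projected target always sums to one per row, so the loss is a proper cross-entropy between two categorical distributions over the same atoms; this is what lets the backup be trained as a classification problem rather than a regression.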

Towards Better Few-Shot and Finetuning Performance with Forgetful Causal Language Models

Hao Liu*, Xinyang Geng*, Lisa Lee, Igor Mordatch, Sergey Levine, Sharan Narang, Pieter Abbeel

Large language models (LLMs) trained using the next-token-prediction objective, such as GPT-3 and PaLM, have revolutionized natural language processing in recent years by showing impressive zero-shot and few-shot capabilities across a wide range of tasks. In this work, we propose a simple technique that significantly boosts the performance of LLMs without adding computational cost. Our key observation is that, by performing the next-token prediction task with randomly selected past tokens masked out, we can improve the quality of the learned representations for downstream language understanding tasks. We hypothesize that randomly masking past tokens prevents over-attending to recent tokens and encourages attention to tokens in the distant past. We find that our method, Forgetful Causal Masking (FCM), significantly improves both few-shot and finetuning performance of PaLM. We further consider a simple extension, T-FCM, which introduces bidirectional context to causal language models without altering the sequence order, and further improves finetuning performance.
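To make "next-token prediction with randomly selected past tokens masked out" concrete, the sketch below builds a boolean attention mask: it starts from the usual causal mask and then hides a random subset of tokens from all later queries. The i.i.d. Bernoulli selection rule, the `mask_ratio` parameter and keeping each token visible to itself are assumptions for illustration, not necessarily what the paper does.

```python
import numpy as np

def forgetful_causal_mask(seq_len, mask_ratio, rng):
    """Boolean attention mask: entry [i, j] is True if query i may attend to key j.

    Starts from the standard causal (lower-triangular) mask, then hides a random
    subset of tokens from every query. Each position is kept visible to itself
    (an assumption) so that no row of the mask is empty.
    """
    causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    # Hypothetical selection rule: forget each token independently with mask_ratio.
    forgotten = rng.random(seq_len) < mask_ratio
    mask = causal & ~forgotten[None, :]
    np.fill_diagonal(mask, True)
    return mask
```

With `mask_ratio=0.0` this reduces to the ordinary causal mask, so the training objective and compute cost stay those of a standard decoder; only which past keys are visible changes.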

Multimodal Masked Autoencoders Learn Transferable Representations

Xinyang Geng*, Hao Liu*, Lisa Lee, Dale Schuurmans, Sergey Levine, Pieter Abbeel

Building scalable models to learn from diverse, multimodal data remains an open challenge. For vision-language data, the dominant approaches are based on contrastive learning objectives that train a separate encoder for each modality. While effective, contrastive learning approaches introduce sampling bias depending on the data augmentations used, which can degrade performance on downstream tasks. Moreover, these methods are limited to paired image-text data, and cannot leverage widely-available unpaired data. In this paper, we investigate whether a large multimodal model trained purely via masked token prediction, without using modality-specific encoders or contrastive learning, can learn transferable representations for downstream tasks. We propose a simple and scalable network architecture, the Multimodal Masked Autoencoder (M3AE), which learns a unified encoder for both vision and language data via masked token prediction. We provide an empirical study of M3AE trained on a large-scale image-text dataset, and find that M3AE is able to learn generalizable representations that transfer well to downstream tasks. Surprisingly, we find that M3AE benefits from a higher text mask ratio (50-90%), in contrast to BERT whose standard masking ratio is 15%, due to the joint training of two data modalities. We also provide qualitative analysis showing that the learned representation incorporates meaningful information from both image and language. Lastly, we demonstrate the scalability of M3AE with larger model size and training time, and its flexibility to train on both paired image-text data as well as unpaired data.
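The masking step behind a unified encoder can be sketched as follows: each modality is masked with its own ratio, and the visible image patches and text tokens are concatenated into a single sequence for one shared encoder. The shapes (196 patches, 32 text tokens, width 64), the assumption that both modalities are already embedded to a common width, and the image ratio of 0.75 are illustrative choices, not details taken from the paper; only the idea of a high text ratio (50-90%) comes from the abstract.

```python
import numpy as np

def random_keep_indices(num_tokens, mask_ratio, rng):
    """Sorted indices of the tokens left visible after random masking."""
    num_keep = int(round(num_tokens * (1.0 - mask_ratio)))
    return np.sort(rng.permutation(num_tokens)[:num_keep])

def mask_multimodal_sequence(image_patches, text_embeddings, image_ratio, text_ratio, rng):
    """Mask each modality with its own ratio, then concatenate the visible tokens.

    image_patches: (num_patches, d) patch embeddings.
    text_embeddings: (num_text_tokens, d) token embeddings (same width d assumed).
    Returns the visible sequence plus the kept indices, which a decoder would need
    to place predictions for the masked positions.
    """
    img_keep = random_keep_indices(image_patches.shape[0], image_ratio, rng)
    txt_keep = random_keep_indices(text_embeddings.shape[0], text_ratio, rng)
    visible = np.concatenate([image_patches[img_keep], text_embeddings[txt_keep]], axis=0)
    return visible, img_keep, txt_keep
```

Because only the visible tokens enter the encoder, raising either mask ratio also shortens the encoder's input, which is part of what makes masked-token training cheap to scale.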

Conservative Objective Models for Effective Offline Model-Based Optimization

Brandon Trabucco*, Aviral Kumar*, Xinyang Geng, Sergey Levine

In this paper, we aim to solve data-driven model-based optimization (MBO) problems, where the goal is to find a design input that maximizes an unknown objective function, given access only to a static dataset of inputs and their corresponding objective values. Such data-driven optimization procedures are the only practical methods in many real-world domains where active data collection is expensive (e.g., when optimizing over proteins) or dangerous (e.g., when optimizing over aircraft designs, where actively evaluating malformed designs is unsafe). Typical methods for MBO that optimize the input against a learned model of the unknown score function suffer from erroneous overestimation in the learned model caused by distributional shift, which drives the optimizer to low-scoring or invalid inputs. To overcome this, we propose conservative objective models (COMs), a method that learns a model of the objective function that lower-bounds the actual value of the ground-truth objective on out-of-distribution inputs, and uses it for optimization. In practice, COMs outperform a number of existing methods on a wide range of MBO problems, including optimizing controller parameters, robot morphologies, and superconducting materials.
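A minimal sketch of a conservative training objective in the spirit described above: fit the dataset with a regression loss, and add a penalty that pushes predictions down on designs found by gradient ascent against the model while pushing them up on dataset inputs. To stay self-contained it uses a linear surrogate f(x) = x @ w, whose input gradient is simply w; the function names, `alpha`, the step size and the step count are assumptions, and this is not the paper's implementation.

```python
import numpy as np

def gradient_ascent_designs(x, w, step_size, num_steps):
    """Push designs uphill on a linear surrogate f(x) = x @ w.

    For a linear model the input gradient is w, so each step adds step_size * w.
    A neural surrogate would compute this gradient with autodiff instead.
    """
    for _ in range(num_steps):
        x = x + step_size * w
    return x

def coms_style_loss(x, y, w, alpha, step_size=0.05, num_steps=10):
    """Regression loss plus a conservative penalty (illustrative sketch).

    The penalty is mean f(adversarial designs) - mean f(dataset inputs); in
    training, the adversarial designs would be treated as constants.
    """
    preds = x @ w
    mse = np.mean((preds - y) ** 2)
    x_adv = gradient_ascent_designs(x, w, step_size, num_steps)
    penalty = np.mean(x_adv @ w) - np.mean(preds)
    return mse + alpha * penalty, mse, penalty
```

Minimizing the penalty discourages the model from assigning high scores to the very inputs an optimizer would be drawn to, which is the overestimation failure the abstract describes.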

Design-Bench: Benchmarks for Data-Driven Offline Model-Based Optimization

Brandon Trabucco*, Xinyang Geng*, Aviral Kumar, Sergey Levine

Black-box model-based optimization (MBO) problems, where the goal is to find a design input that maximizes an unknown objective function, are ubiquitous in a wide range of domains, such as the design of proteins, DNA sequences, aircraft, and robots. Solving model-based optimization problems typically requires actively querying the unknown objective function on design proposals, which means physically building the candidate molecule, aircraft, or robot, testing it, and storing the result. This process can be expensive and time-consuming, and one might instead prefer to optimize for the best design using only the data one already has. This setting, called offline MBO, poses substantial algorithmic challenges that differ from those of the more commonly studied online techniques. A number of recent works have demonstrated success with offline MBO for high-dimensional optimization problems using high-capacity deep neural networks. However, the lack of standardized benchmarks in this emerging field is making progress difficult to track. To address this, we present Design-Bench, a benchmark for offline MBO with a unified evaluation protocol and reference implementations of recent methods. Our benchmark includes a suite of diverse and realistic tasks derived from real-world optimization problems in biology, materials science, and robotics that present distinct challenges for offline MBO. Our benchmark and reference implementations are released at github.com/rail-berkeley/design-bench and github.com/rail-berkeley/design-baselines.
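One way a unified offline-MBO evaluation can be summarized is sketched below: score a fixed batch of proposed designs with the ground-truth objective, min-max normalize the scores against reference dataset values, and report percentiles. The normalization references and the choice of the 100th and 50th percentiles are assumptions for illustration; the exact Design-Bench protocol should be checked against the repositories named above.

```python
import numpy as np

def normalized_scores(candidate_scores, dataset_min, dataset_max):
    """Min-max normalize ground-truth scores of proposed designs."""
    return (candidate_scores - dataset_min) / (dataset_max - dataset_min)

def summarize(candidate_scores, dataset_min, dataset_max):
    """Report the best (100th percentile) and median (50th) normalized score."""
    norm = normalized_scores(np.asarray(candidate_scores, dtype=float), dataset_min, dataset_max)
    return {"100th": float(np.percentile(norm, 100)), "50th": float(np.percentile(norm, 50))}
```

Normalizing against fixed reference values lets scores from tasks with very different objective scales (for example, protein fitness versus robot speed) be compared on one axis.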

Meta-Reinforcement Learning Robust to Distributional Shift via Model Identification and Experience Relabeling

Russell Mendonca*, Xinyang Geng*, Chelsea Finn, Sergey Levine

Reinforcement learning algorithms can acquire policies for complex tasks autonomously. However, the number of samples required to learn a diverse set of skills can be prohibitively large. While meta-reinforcement learning methods have enabled agents to leverage prior experience to adapt quickly to new tasks, their performance depends crucially on how close the new task is to the previously experienced tasks. Current approaches are either not able to extrapolate well, or can do so at the expense of requiring extremely large amounts of data for on-policy meta-training. In this work, we present model identification and experience relabeling (MIER), a meta-reinforcement learning algorithm that is both efficient and extrapolates well when faced with out-of-distribution tasks at test time. Our method is based on a simple insight: we recognize that dynamics models can be adapted efficiently and consistently with off-policy data, more easily than policies and value functions. These dynamics models can then be used to continue training policies and value functions for out-of-distribution tasks without using meta-reinforcement learning at all, by generating synthetic experience for the new task.
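The experience-relabeling idea can be illustrated with a toy example: fit a dynamics model on a small amount of new-task data, then reuse old transitions by replacing their next states and rewards with the model's predictions for the new task. MIER adapts a meta-learned dynamics model; here an ordinary least-squares linear model is swapped in purely for illustration, and all function names and the reward function are hypothetical.

```python
import numpy as np

def _with_bias(states, actions):
    """Concatenate state, action and a constant 1 feature for an affine model."""
    return np.hstack([states, actions, np.ones((states.shape[0], 1))])

def fit_linear_dynamics(states, actions, next_states):
    """Least-squares fit of next_state ~= [state, action, 1] @ W on new-task data."""
    W, *_ = np.linalg.lstsq(_with_bias(states, actions), next_states, rcond=None)
    return W

def relabel_experience(states, actions, W, reward_fn):
    """Relabel old (state, action) pairs with predicted next states and
    new-task rewards, yielding synthetic transitions for the new task."""
    synthetic_next = _with_bias(states, actions) @ W
    rewards = reward_fn(states, actions, synthetic_next)
    return synthetic_next, rewards
```

Usage: fit `W` on a handful of new-task transitions, then pass the large replay buffer from earlier tasks through `relabel_experience`; the resulting synthetic data can train the policy and value function with any off-policy algorithm, without further meta-learning.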

Rewriting History with Inverse RL: Hindsight Inference for Policy Improvement

Benjamin Eysenbach*, Xinyang Geng*, Sergey Levine, Ruslan Salakhutdinov

Multi-task reinforcement learning (RL) aims to simultaneously learn policies for solving many tasks. Several prior works have found that relabeling past experience with different reward functions can improve sample efficiency. Relabeling methods typically ask: if, in hindsight, we assume that our experience was optimal for some task, for what task was it optimal? In this paper, we show that hindsight relabeling is inverse RL, an observation that suggests that we can use inverse RL in tandem with RL algorithms to efficiently solve many tasks. We use this idea to generalize goal-relabeling techniques from prior work to arbitrary classes of tasks. Our experiments confirm that relabeling data using inverse RL accelerates learning in general multi-task settings, including goal-reaching, domains with discrete sets of rewards, and those with linear reward functions.
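Viewing relabeling as inverse RL suggests assigning each trajectory a posterior over tasks, roughly p(task | trajectory) proportional to p(task) exp(return under task - log Z(task)). The sketch below estimates each task's log-partition log Z from the batch itself. This batch-based estimate and the function names are assumptions for illustration, not a transcription of the paper's algorithm.

```python
import numpy as np

def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along an axis, keeping dims."""
    m = np.max(x, axis=axis, keepdims=True)
    return m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True))

def relabel_task_posterior(returns, log_prior=None):
    """Approximate p(task | trajectory) for a batch of trajectories.

    returns[i, k] is the return of trajectory i under task k's reward.
    log Z(k) is estimated by a logsumexp over the batch's trajectories, so a
    task whose reward is uniformly large does not dominate the relabeling.
    """
    n_traj, n_tasks = returns.shape
    if log_prior is None:
        log_prior = np.full(n_tasks, -np.log(n_tasks))  # uniform task prior
    log_z = logsumexp(returns, axis=0)  # shape (1, n_tasks)
    logits = log_prior[None, :] + returns - log_z
    return np.exp(logits - logsumexp(logits, axis=1))  # each row sums to 1
```

Subtracting log Z makes the posterior invariant to adding a constant to one task's rewards, which is exactly what naive "relabel with the highest-return task" lacks.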

Dynamical Distance Learning for Unsupervised and Semi-Supervised Skill Discovery

Kristian Hartikainen, Xinyang Geng, Tuomas Haarnoja, Sergey Levine

In ICLR, 2020.

Reinforcement learning requires manual specification of a reward function to learn a task. While in principle this reward function only needs to specify the task goal, in practice reinforcement learning can be very time-consuming or even infeasible unless the reward function is shaped so as to provide a smooth gradient towards a successful outcome. This shaping is difficult to specify by hand, particularly when the task is learned from raw observations, such as images. In this paper, we study how we can automatically learn dynamical distances: a measure of the expected number of time steps to reach a given goal state from any other state. These dynamical distances can be used to provide well-shaped reward functions for reaching new goals, making it possible to learn complex tasks efficiently. We show that dynamical distances can be used in a semi-supervised regime, where unsupervised interaction with the environment is used to learn the dynamical distances, while a small amount of preference supervision is used to determine the task goal, without any manually engineered reward function or goal examples. We evaluate our method both on a real-world robot and in simulation. We show that our method can learn to turn a valve with a real-world 9-DoF hand, using raw image observations and just ten preference labels, without any other supervision.
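A dynamical distance can be learned by supervised regression on the agent's own trajectories: for two states s_i and s_j visited at steps i <= j of the same trajectory, the target is j - i. The sketch below builds those training pairs and shows how a learned distance turns into a shaped reward. The pair construction over all i <= j and the function names are illustrative assumptions; the regression model itself is left out.

```python
import numpy as np

def distance_training_pairs(trajectory):
    """Build (input, target) pairs for regressing a dynamical distance.

    trajectory: (T, state_dim) array of states in visitation order.
    For every pair i <= j, the input is [s_i, s_j] and the target is j - i,
    the number of steps actually taken to get from s_i to s_j.
    """
    T = len(trajectory)
    i_idx, j_idx = np.triu_indices(T)
    inputs = np.concatenate([trajectory[i_idx], trajectory[j_idx]], axis=1)
    targets = (j_idx - i_idx).astype(float)
    return inputs, targets

def shaped_reward(distance_fn, state, goal):
    """Goal-reaching reward: the negative predicted distance to the goal."""
    return -distance_fn(state, goal)
```

Since the targets come from interaction alone, the distance can be trained without any reward; a small amount of preference supervision is then only needed to choose which state to use as `goal`.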

Improved Generalization with Curvature Regularization

Xinyang Geng, Lechao Xiao, Hossein Mobahi, Jeffrey Pennington

In ICML 2018 Workshop: Modern Trends in Nonconvex Optimization for Machine Learning.

Recent advances in high-performance computing and the abundance of large labeled datasets have enabled machine learning practitioners to successfully develop and deploy deep learning models with enormous numbers of parameters. Owing to the large degree of overparameterization, it is perhaps surprising that in practice such models often generalize well. Indeed, identifying which types of large models will generalize well and which will generalize poorly remains an important research direction in the theory of deep learning. One recent proposal is that parameter configurations corresponding to “sharp” minima may generalize worse than those that correspond to “wide” minima. In this paper, we propose a computationally-efficient method for approximating a sharpness measure based on the mean curvature of the loss landscape near a critical point. We devise a new form of regularization for deep learning models based on this sharpness measure. Our experiments on fully connected networks and convolutional networks show that such regularization can significantly improve generalization performance.
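One standard, computationally cheap way to approximate mean curvature, tr(H)/n for the loss Hessian H, is Hutchinson's estimator: average v^T H v over random +/-1 probe vectors v, with each Hessian-vector product H v obtained from a central difference of gradients. This is offered as a generic sketch; whether it matches the paper's approximation is an assumption, and `grad_fn`, `eps` and `num_samples` are illustrative names.

```python
import numpy as np

def hvp_finite_difference(grad_fn, theta, v, eps=1e-4):
    """Hessian-vector product H v via a central difference of gradients."""
    return (grad_fn(theta + eps * v) - grad_fn(theta - eps * v)) / (2.0 * eps)

def mean_curvature(grad_fn, theta, num_samples, rng, eps=1e-4):
    """Hutchinson estimate of tr(H) / n using Rademacher (+/-1) probe vectors."""
    n = theta.shape[0]
    total = 0.0
    for _ in range(num_samples):
        v = rng.choice([-1.0, 1.0], size=n)
        total += v @ hvp_finite_difference(grad_fn, theta, v, eps)
    return total / (num_samples * n)
```

Each sample costs two gradient evaluations regardless of the number of parameters, so a term such as `lam * mean_curvature(...)` could be added to the training loss as a sharpness regularizer (with `lam` a hypothetical weight).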

Automatic Goal Generation for Reinforcement Learning Agents

Carlos Florensa*, David Held*, Xinyang Geng*, Pieter Abbeel

In ICML, 2018.

Reinforcement learning (RL) is a powerful technique to train an agent to perform a task; however, an agent that is trained using RL is only capable of achieving the single task that is specified via its reward function. Such an approach does not scale well to settings in which an agent needs to perform a diverse set of tasks, such as navigating to varying positions in a room or moving objects to varying locations. Instead, we propose a method that allows an agent to automatically discover the range of tasks that it is capable of performing in its environment. We use a generator network to propose tasks for the agent to try to accomplish, each task being specified as reaching a certain parametrized subset of the state-space. The generator network is optimized using adversarial training to produce tasks that are always at the appropriate level of difficulty for the agent, thus automatically producing a curriculum. We show that, by using this framework, an agent can efficiently and automatically learn to perform a wide set of tasks without requiring any prior knowledge of its environment, even when only sparse rewards are available.
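The "appropriate level of difficulty" criterion can be sketched as a labeling step: estimate each goal's success rate from rollouts of the current policy, then mark goals the agent reaches sometimes but not reliably. Such labels could then serve as targets when training the goal generator. The thresholds of 0.1 and 0.9 and the function names are illustrative assumptions.

```python
import numpy as np

def estimate_success_rates(outcomes):
    """outcomes[g, k] is 1 if rollout k reached goal g, else 0."""
    return np.asarray(outcomes, dtype=float).mean(axis=1)

def label_goals(success_rates, r_min=0.1, r_max=0.9):
    """True for goals of intermediate difficulty: neither too hard nor too easy."""
    rates = np.asarray(success_rates, dtype=float)
    return (rates >= r_min) & (rates <= r_max)
```

As the policy improves, goals drift out of the intermediate band, so regenerating goals from the labeled set yields the automatic curriculum the abstract describes.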

Real-Time User-Guided Image Colorization with Learned Deep Priors

Richard Zhang*, Jun-Yan Zhu*, Phillip Isola, Xinyang Geng, Angela S. Lin, Tianhe Yu, Alexei A. Efros


We propose a deep learning approach for user-guided image colorization. The system directly maps a grayscale image, along with sparse, local user “hints”, to an output colorization with a Convolutional Neural Network (CNN). Rather than using hand-defined rules, the network propagates user edits by fusing low-level cues along with high-level semantic information, learned from large-scale data. We train on a million images, with simulated user inputs. To guide the user towards efficient input selection, the system recommends likely colors based on the input image and current user inputs. The colorization is performed in a single feed-forward pass, enabling real-time use. Even with randomly simulated user inputs, we show that the proposed system helps novice users quickly create realistic colorizations, and offers large improvements in colorization quality with just a minute of use. In addition, we demonstrate that the framework can incorporate other user “hints” to the desired colorization, showing an application to color histogram transfer.
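To show how a grayscale image and sparse user hints can be fed to a single CNN, the sketch below stacks a lightness channel, two hint color channels, and a binary mask marking which pixels carry a hint. The Lab-style (L, a, b) layout, the 4-channel ordering and the hint tuple format are assumptions for illustration, not the paper's exact input encoding.

```python
import numpy as np

def build_network_input(lightness, hint_points):
    """Stack the grayscale channel with sparse user color hints.

    lightness: (H, W) array of lightness values.
    hint_points: list of (row, col, a, b) tuples, a user-chosen color at a pixel.
    Returns an (H, W, 4) array: lightness, hinted a, hinted b, and a binary mask.
    The mask distinguishes "hint with zero color" from "no hint here".
    """
    h, w = lightness.shape
    hints_ab = np.zeros((h, w, 2))
    hint_mask = np.zeros((h, w, 1))
    for row, col, a, b in hint_points:
        hints_ab[row, col] = (a, b)
        hint_mask[row, col] = 1.0
    return np.concatenate([lightness[..., None], hints_ab, hint_mask], axis=-1)
```

Encoding hints as extra input channels keeps colorization to one feed-forward pass, which is what makes the real-time interaction possible; simulated users during training would simply sample random `hint_points` from ground-truth colors.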

Deep Reinforcement Learning for Tensegrity Robot Locomotion

Marvin Zhang*, Xinyang Geng*, Jonathan Bruce*, Ken Caluwaerts, Massimo Vespignani, Vytas SunSpiral, Pieter Abbeel, Sergey Levine

In ICRA, 2017.

Tensegrity robots, composed of rigid rods connected by elastic cables, have a number of unique properties that make them appealing for use as planetary exploration rovers. However, control of tensegrity robots remains a difficult problem due to their unusual structures and complex dynamics. In this work, we show how locomotion gaits can be learned automatically using a novel extension of mirror descent guided policy search (MDGPS) applied to periodic locomotion movements, and we demonstrate the effectiveness of our approach on tensegrity robot locomotion. We evaluate our method with real-world and simulated experiments on the SUPERball tensegrity robot, showing that the learned policies generalize to changes in system parameters, unreliable sensor measurements, and variation in environmental conditions, including varied terrains and a range of different gravities. Our experiments demonstrate that our method not only learns fast, power-efficient feedback policies for rolling gaits, but that these policies can succeed with only the limited onboard sensing provided by SUPERball’s accelerometers. We compare the learned feedback policies to learned open-loop policies and hand-engineered controllers, and demonstrate that the learned policy enables the first continuous, reliable locomotion gait for the real SUPERball robot.