Part 5 of Jaxformer (Part 4: Dataset & Config | Part 6: Mixture of Experts)
We now introduce the main training script that will be used to launch the training. This covers the infrastructure, distributed functions and training loops that will sync all devices together.
We will begin by configuring JAX. XLA flags tune performance; the ones below affect GPU kernels and the scheduling of communication between GPUs. We follow general practice and enable the flags that give faster performance. Note that we train on TPUs, but the flags do not hurt performance in general.
import os
os.environ["XLA_FLAGS"] = (
    "--xla_gpu_triton_gemm_any=True --xla_gpu_enable_latency_hiding_scheduler=true "
)
We can also use JAX’s optional disk cache, which lets JAX store copies of compiled programs on disk and saves recompilation time when running the same or similar tasks repeatedly. We use remote file storage to sync the cache across multi-controller nodes.
jax.config.update("jax_compilation_cache_dir", "gs://jaxformer-cache/")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update(
    "jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir"
)
We can begin by writing a helper function to print values on multi-controller JAX since we only want to print once on the main node’s process.
def log(msg: str):
    if jax.process_index() == 0:
        print(msg)
We can now create a function to set up the mesh for the given TPU topology. The function takes the number of devices per axis and the name of each axis, and returns the JAX mesh. It first puts the devices into a NumPy array and checks that the device count matches the number of devices the desired mesh needs.
We try to use the jax.make_mesh function, as that builds the most optimized mesh for the TPU topology; however, if it cannot, it throws an exception, so we wrap it in a try/except and make the mesh ourselves. The mesh is then returned.
def init_devices(
    axes: Tuple[int, ...], axes_name: Tuple[str, ...]
) -> jax.sharding.Mesh:
    devices = np.array(jax.devices())
    # print for convenience
    # Assumes you are on TPU
    for idx in np.ndindex(devices.shape):
        d = devices[idx]
        log(
            f"  {idx} ID: {d.id}, Process: {d.process_index}, "
            f"Coords: {d.coords}, Core: {d.core_on_chip}"
        )
    assert devices.size == np.prod(axes), (
        f"Expected {np.prod(axes)} devices, got {devices.shape[0]}"
    )
    try:
        mesh = jax.make_mesh(axes, axes_name)
    except Exception:  # avoid a bare except, which would also swallow KeyboardInterrupt
        log("Failed to create mesh with make_mesh, falling back to sharding.Mesh")
        mesh = jax.sharding.Mesh(devices.reshape(axes), axes_name)
    return mesh
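The fallback branch can be tried on a single host. The sketch below is only an illustration: it takes whatever devices `jax.devices()` reports locally and puts them all on the `dp` axis, so the axis sizes are assumptions rather than the ones used for training.

```python
import jax
import numpy as np

# Use every locally visible device; on a laptop this is usually one CPU device.
devices = np.array(jax.devices())
axes = (devices.size, 1, 1)  # illustrative: all devices on the data axis
axes_name = ("dp", "pp", "tp")

# Same construction as the fallback branch of init_devices.
mesh = jax.sharding.Mesh(devices.reshape(axes), axes_name)

# mesh.shape maps each axis name to its size.
print(dict(mesh.shape))
```

The printed mapping should list the three axis names with sizes whose product equals the local device count.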
The three different axes represent the different parallelism strategies we have and can be visualized as shown below.
 
With the helper functions established, we can now begin the main training function. Our main function takes in the config we described earlier. Since we assume this script is for 3-D parallelism, we can assign a variable to the device count of each axis and set up the key with the initial seed.
def main(cfg: config):
    key = jax.random.PRNGKey(cfg.seed)
    DATA_PARALLEL, LAYER_PARALLEL, TENSOR_PARALLEL = cfg.device_config.n_device_axis
We can now initialize and log our mesh.
def main(cfg: config):
  ...
  axes = (*cfg.device_config.n_device_axis,)
  axes_name = ("dp", "pp", "tp")
  mesh = init_devices(axes, axes_name)
  log(mesh)
The next step is to set up checkpointing for our model using the orbax library and its orbax.checkpoint (ocp) module, which conveniently handles checkpointing across multiple processes and on remote storage for us. All we need to do is give it the Google Cloud Storage URL and make sure to run it on every process. We first build the directory from the config by combining the GCS URL with the unique name of the run, and set up the orbax checkpoint manager. We can then use it to check whether a latest step exists, in which case we are resuming from a previous run.
checkpoint_dir = cfg.output_dir + cfg.name
options = ocp.CheckpointManagerOptions(max_to_keep=1)
checkpoint_manager = ocp.CheckpointManager(checkpoint_dir, options=options)
load = checkpoint_manager.latest_step() is not None
We begin setting up our dataset. We first make the data partition. Every data shard that is loaded has the form (G, M, B, T), where G is the total number of batches in a shard, M is the number of microbatches in a batch (for pipelining), B is the batch size per microbatch and T is the sequence length. We want M to be split across the pipeline axis, B across the data axis and T initially across the tensor axis, as discussed previously. This gives the following PartitionSpec and NamedSharding, which we use to initialize the dataset class written previously.
data_spec = P(None, "pp", "dp", "tp")
data_partition = jax.sharding.NamedSharding(mesh, data_spec)
train_dataset, val_dataset = Dataset.getDataset(
    cfg.data_config,
    partition=data_partition,
    dp=DATA_PARALLEL,
    pp=LAYER_PARALLEL,
    tp=TENSOR_PARALLEL,
)
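To make the partitioning concrete, the per-device shape of a shard can be worked out by hand. The helper `local_shard_shape` and all the sizes below are hypothetical, chosen only to illustrate how `P(None, "pp", "dp", "tp")` divides a (G, M, B, T) array.

```python
from typing import Dict, Optional, Sequence, Tuple


def local_shard_shape(
    global_shape: Sequence[int],
    spec: Sequence[Optional[str]],
    axis_sizes: Dict[str, int],
) -> Tuple[int, ...]:
    """Divide each dimension by the size of the mesh axis it is sharded over."""
    local = []
    for dim, axis in zip(global_shape, spec):
        size = 1 if axis is None else axis_sizes[axis]
        assert dim % size == 0, f"dim {dim} is not divisible by axis size {size}"
        local.append(dim // size)
    return tuple(local)


axis_sizes = {"dp": 4, "pp": 2, "tp": 2}   # hypothetical mesh
spec = (None, "pp", "dp", "tp")            # mirrors data_spec
shape = local_shard_shape((8, 4, 16, 1024), spec, axis_sizes)
print(shape)  # (8, 2, 4, 512)
```

G is left whole on every device, while M, B and T are each divided by the size of their mesh axis.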
We can now create our model using shardedModel, split off an init key and initialize our params.
model = shardedModel(cfg.model_config)
log("creating sharded model ...")
key, init_key = jax.random.split(key, 2)
params = model.init_weights(init_key, mesh)
We can now use these params to initialize our optimizer. Since the params are sharded, the optimizer states will be sharded as well.
lr_scheduler = optax.warmup_cosine_decay_schedule(
    init_value=cfg.lr.min_lr,
    peak_value=cfg.lr.max_lr,
    warmup_steps=cfg.lr.warmup_steps,
    decay_steps=cfg.lr.end_steps,
    end_value=cfg.lr.end_lr,
)
tx = optax.chain(
    optax.clip_by_global_norm(cfg.grad_clip_norm),  # read from the cfg instance, not the config class
    optax.inject_hyperparams(optax.adamw)(learning_rate=lr_scheduler),
)
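For intuition about the schedule, here is a hand-written sketch of a linear warmup followed by a cosine decay. It is a re-implementation for illustration under the usual definition (with `decay_steps` counting the warmup steps too), not Optax's own code, and the numbers are hypothetical.

```python
import math


def warmup_cosine(step, init_value, peak_value, warmup_steps, decay_steps, end_value):
    """Linear warmup to peak_value, then cosine decay to end_value at decay_steps."""
    if step < warmup_steps:
        frac = step / warmup_steps
        return init_value + frac * (peak_value - init_value)
    # Progress through the decay phase, clamped so the value stays at end_value afterwards.
    progress = min((step - warmup_steps) / (decay_steps - warmup_steps), 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_value + (peak_value - end_value) * cosine


# Hypothetical values standing in for cfg.lr.*
sched = dict(init_value=1e-5, peak_value=3e-4, warmup_steps=100,
             decay_steps=1000, end_value=3e-5)
for s in (0, 50, 100, 550, 1000):
    print(s, warmup_cosine(s, **sched))
```

The value starts at `init_value`, reaches `peak_value` at the end of warmup, and settles at `end_value` once `decay_steps` is reached.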
One bug that we observed in Optax is that state leaves with no dims (i.e., scalar values) are not replicated across devices, which leads to errors when reloading from checkpoints and when passing them to distributed function calls (e.g., the train step written below). Hence we write a simple map function: if a value has no dimensions, replicate it across every device.
default_sharding = jax.sharding.NamedSharding(mesh, P())
opt_state = jax.tree.map(
    lambda x: x if jnp.ndim(x) != 0 else jax.device_put(x, default_sharding),
    tx.init(params),
)
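The same map can be tried on a toy state. The sketch below assumes a mesh built from the local devices and a made-up two-leaf state; it only shows that the scalar leaf ends up replicated while the array leaf is left alone.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative mesh: all local devices on the dp axis.
mesh = Mesh(np.array(jax.devices()).reshape(-1, 1, 1), ("dp", "pp", "tp"))
replicated = NamedSharding(mesh, P())

# Toy stand-in for an optimizer state: one scalar counter and one array leaf.
toy_state = {"count": jnp.zeros((), jnp.int32), "mu": jnp.ones((4, 4))}

toy_state = jax.tree.map(
    lambda x: x if jnp.ndim(x) != 0 else jax.device_put(x, replicated),
    toy_state,
)
print(toy_state["count"].sharding)
```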
We can now set up our miscellaneous variables: the starting step, whether to use wandb (it has to be enabled in the config, and only on process 0), and a placeholder for the run id.
init_step = 0
use_wandb = cfg.wandb is True and jax.process_index() == 0
wandb_id = None
We can also write a helper that builds the PyTree to checkpoint. Note we decrement the shard_idx for the train/val dataset because loading a shard increments it by 1, so we want to revert that change.
def make_save_tree(step):
  model_state = {
      "params": params,
      "opt_state": opt_state,
  }
  save_tree = {
      "state": model_state,
      "key": jax.device_get(key),
      "train_step_idx": train_dataset.step_idx,
      "train_shard_idx": (train_dataset.shard_idx - 1) % len(train_dataset.data),
      "val_step_idx": val_dataset.step_idx,
      "val_shard_idx": (val_dataset.shard_idx - 1) % len(val_dataset.data),
      "step": step,
  }
  metadata = {
      "wandb_id": wandb_id
  }
  return save_tree, metadata
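The modulo in the shard index matters at the wrap-around point. A small sketch, assuming (hypothetically) four shards and a counter that wraps back to 0 after the last shard is loaded:

```python
num_shards = 4  # hypothetical number of shards


def saved_shard_idx(shard_idx_after_load: int, num_shards: int) -> int:
    """Undo the +1 applied when a shard is loaded, wrapping around at zero."""
    return (shard_idx_after_load - 1) % num_shards


# After loading shard 2 the counter reads 3, so we store 2 and reload it on resume.
print(saved_shard_idx(3, num_shards))  # 2
# After loading the last shard the counter has wrapped to 0; the modulo maps it back to 3.
print(saved_shard_idx(0, num_shards))  # 3
```

Python's `%` returns a non-negative result for a positive modulus, which is why `-1 % 4` gives 3 rather than -1.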
Our save_checkpoint function can now just take the step and call the checkpoint manager.
def save_checkpoint(
    step,
):
    save_tree, metadata = make_save_tree(step)
    checkpoint_manager.save(step, args=ocp.args.Composite(
      state=ocp.args.StandardSave(save_tree),
      metadata=ocp.args.JsonSave(metadata)
    ))
Before the main training functions or training loop, we should add model-loading logic if we want to resume from a checkpoint. Since we always initialize, we can pass Orbax the sharding and array metadata from the current parameters and use that to load with the correct sharding.
def main(cfg: config):
    ...
    if load:
        # get PyTree metadata (only the save tree; the JSON metadata is restored separately)
        save_tree, _ = make_save_tree(init_step)
        abstract_tree_state = jax.tree.map(
          ocp.utils.to_shape_dtype_struct, save_tree
        )
        # load checkpoint
        tree = checkpoint_manager.restore(
            checkpoint_manager.latest_step(),
            args=ocp.args.Composite(
                state=ocp.args.StandardRestore(abstract_tree_state),
                metadata=ocp.args.JsonRestore(),
        ))
        # assign all variables
        tree_state, tree_metadata = tree.state, tree.metadata
        init_step = tree_state["step"]
        log(f"loading checkpoint @ step {init_step}")
        key = jnp.asarray(tree_state["key"])  # rebind the PRNG key; arrays have no .key attribute
        params = tree_state["state"]["params"]
        opt_state = tree_state["state"]["opt_state"]
        train_dataset.step_idx = tree_state["train_step_idx"]
        train_dataset.shard_idx = tree_state["train_shard_idx"]
        train_dataset.load_next_shard()
        val_dataset.step_idx = tree_state["val_step_idx"]
        val_dataset.shard_idx = tree_state["val_shard_idx"]
        val_dataset.load_next_shard()
        wandb_id = tree_metadata["wandb_id"]
        if use_wandb:
            assert wandb_id is not None, "wandb_id is None"
            wandb.init(
                entity="waterloo2",
                project="jaxformer",
                name=cfg.name,
                resume="must",
                id=wandb_id,
                config=asdict(cfg),
            )
Otherwise, if we are not loading, we can save the first checkpoint and initialize the wandb run if needed.
def main(cfg: config):
  ...
  if load:
      ...
  else:
      log("no checkpoint found, saving init copy")
      save_checkpoint(init_step)
      if use_wandb:
          wandb.init(
              entity="waterloo2",
              project="jaxformer",
              name=cfg.name,
              resume="allow",
              config=asdict(cfg),
          )
          wandb_id = wandb.run.id
Finally, we can print our parameter count for convenience.
param_count = jax.tree.reduce(
    lambda x, y: x + y.size,
    params,
    0,
)
log(f"Total parameters: {param_count:,}")
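The reduction is easy to check on a toy tree. The shapes below are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter tree: an embedding plus one block with a weight and a bias.
toy_params = {
    "embedding": jnp.zeros((100, 16)),
    "block": {"w": jnp.zeros((16, 64)), "b": jnp.zeros((64,))},
}

# Fold over the leaves, adding each leaf's element count to the running total.
toy_count = jax.tree.reduce(lambda acc, leaf: acc + leaf.size, toy_params, 0)
print(f"Total parameters: {toy_count:,}")  # Total parameters: 2,688
```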
Now we can introduce the step functions that call our model. We begin by writing a general step that runs the model forward and returns the loss along with other metrics. It uses collective operations (e.g., pmean), which is allowed because the function will ultimately be wrapped in a shard_map; outside such a context the named mesh axes are unbound, so pmean has no information about the distributed setting. Our step function wraps loss_fn in a closure so that loss_fn can read the train flag.
def step(params, x, y, key, train):
  def loss_fn(params, x, y, key):
      ...
  return loss_fn(params, x, y, key)
We can first get the logits from the model by calling pipe_step, discarding the cache output.
def loss_fn(...):
    logits, _ = model.pipe_step(
        params,
        x,
        key=key,
        train=train,
    )
We then use JAX's built-in log_softmax to turn the logits into log-probs and reshape them into a 2D tensor, folding all dims other than the vocabulary distribution into a single batch dimension. The targets are flattened to match.
def loss_fn(...):
    ...
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    M, B, T, V = logits.shape
    y = y.reshape(-1)
    log_probs = log_probs.reshape(M * B * T, V)
Note that logits is originally 4D, with dimensions (microbatch, batch per microbatch, sequence, vocab). We get our cross-entropy loss by vmapping over the flattened batch, selecting the log-prob at each target index with a dynamic slice, negating and then taking the mean.
def loss_fn(...):
    loss_idx = lambda x, idx: jax.lax.dynamic_slice(x, (idx,), (1,))
    loss_cross = -(jax.vmap(loss_idx, in_axes=(0, 0))(log_probs, y)).mean()
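As a sanity check, the vmap plus dynamic-slice selection should agree with a plain gather. The sketch below uses random toy tensors with made-up sizes and compares it against `jnp.take_along_axis`.

```python
import jax
import jax.numpy as jnp

M, B, T, V = 2, 3, 5, 11  # toy sizes
k_logits, k_targets = jax.random.split(jax.random.PRNGKey(0))
logits = jax.random.normal(k_logits, (M, B, T, V))
y = jax.random.randint(k_targets, (M, B, T), 0, V)  # target token ids

log_probs = jax.nn.log_softmax(logits, axis=-1).reshape(M * B * T, V)
targets = y.reshape(-1)

# vmap + dynamic_slice, as in loss_fn: pick one log-prob per row.
pick = lambda row, idx: jax.lax.dynamic_slice(row, (idx,), (1,))
loss_slice = -jax.vmap(pick, in_axes=(0, 0))(log_probs, targets).mean()

# Equivalent gather along the vocab axis.
loss_gather = -jnp.take_along_axis(log_probs, targets[:, None], axis=-1).mean()

print(loss_slice, loss_gather)
```

Both expressions compute the mean negative log-likelihood of the target tokens, so the two printed values should match.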
To perform FSDP, we average the loss over the dp axis. We do the same for the pp and tp axes as well, since batches span multiple devices. The jax.grad function will then handle the reverse communication operations needed to propagate the gradients.
loss_cross = jax.lax.pmean(loss_cross, axis_name="dp")
loss_cross = jax.lax.pmean(loss_cross, axis_name="tp")
loss_cross = jax.lax.pmean(loss_cross, axis_name="pp")
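What pmean does can be seen on a single machine by using `jax.vmap` with an `axis_name` as a stand-in for the mesh axis. This is only an illustration of the collective; the training script itself runs it under shard_map, and the data below is made up.

```python
import jax
import jax.numpy as jnp


def local_loss(x):
    # Each "device" computes its own local mean, then averages it across the named axis.
    return jax.lax.pmean(x.mean(), axis_name="dp")


# Three fake devices, each holding two values (local means are 2, 6 and 10).
per_device_data = jnp.array([[1.0, 3.0], [5.0, 7.0], [9.0, 11.0]])
synced = jax.vmap(local_loss, axis_name="dp")(per_device_data)
print(synced)  # [6. 6. 6.]
```

After the pmean every position along the axis holds the same averaged loss, which is why the metrics can later be declared as replicated.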
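We can make a dict of metrics and return that as well. This will be useful when we have to log stats on MoE. For now the total loss is just the cross-entropy term.
loss = loss_cross  # assumed: no auxiliary loss terms until MoE is added
metrics = {
    "loss": loss,
    "loss_cross": loss_cross,
}
return loss, metrics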
Now we can get the partition specs of the params, optimizer state and key, and write the distributed step function.
param_spec = shardedModel.get_p_spec(
    [model.embedding, model.block], mesh, cfg.model_config
)
opt_spec = jax.tree.map(lambda x: x.sharding.spec, opt_state)
key_spec = P("dp", "pp", "tp")
Note each device gets a unique key: we want the random operations on every device to differ, otherwise there is no reason to use more than one device. We start by writing the train step. Here, our step function is the value-and-grad of the step function written previously, since we also want to compute the gradients.
def train_step(params, opt_state, x, y, key):
    step_fn = jax.value_and_grad(step, has_aux=True)
Then to allow for gradient accumulation we write a single step function. We take in the past gradients and the batch and then accumulate the grads.
def train_step(params, opt_state, x, y, key):
    step_fn = jax.value_and_grad(step, has_aux=True)
    def single_step(grads, batch):
        (_, metrics), grads_current = step_fn(params, *batch, train=True)
        grads = jax.tree.map(lambda x, y: x + y, grads, grads_current)
        return grads, metrics
We can then initialize the grads and reshape the keys to be PRNG keys again (get rid of leading dims) and then use the jax.lax.scan function to sequentially loop over the leading dim of the (x,y,key) batch.
def train_step(params, opt_state, x, y, key):
    step_fn = jax.value_and_grad(step, has_aux=True)
    def single_step(grads, batch):
        ...
        return grads, metrics
    grads = jax.tree.map(lambda x: jnp.zeros_like(x), params)
    key = key.reshape(cfg.grad_step, 2)
    grads, metrics = jax.lax.scan(
        single_step,
        grads,
        (x, y, key),
    )
We then average the gradients and metrics, apply the updates to the parameters, and return the updated parameters, optimizer state, and metrics. Thus, our final function is as follows:
def train_step(params, opt_state, x, y, key):
    step_fn = jax.value_and_grad(step, has_aux=True)
    def single_step(grads, batch):
        (_, metrics), grads_current = step_fn(params, *batch, train=True)
        grads = jax.tree.map(lambda x, y: x + y, grads, grads_current)
        return grads, metrics
    grads = jax.tree.map(lambda x: jnp.zeros_like(x), params)
    key = key.reshape(cfg.grad_step, 2)
    grads, metrics = jax.lax.scan(
        single_step,
        grads,
        (x, y, key),
    )
    grads = jax.tree.map(lambda x: x / cfg.grad_step, grads)
    metrics = jax.tree.map(lambda x: x.mean(axis=0), metrics)
    updates, opt_state = tx.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, metrics
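The accumulation pattern can be checked in isolation. The sketch below swaps the model for a toy least-squares loss (an assumption made purely for illustration) and verifies that averaging the gradients of equally sized microbatches reproduces the full-batch gradient.

```python
import jax
import jax.numpy as jnp


def toy_loss(w, x, y):
    # Mean squared error of a linear model; stands in for the real step function.
    return jnp.mean((x @ w - y) ** 2)


grad_fn = jax.value_and_grad(toy_loss)


def accumulate(w, xs, ys):
    def single_step(grads, batch):
        loss, g = grad_fn(w, *batch)
        return grads + g, loss  # carry the running gradient sum, emit the loss

    grads, losses = jax.lax.scan(single_step, jnp.zeros_like(w), (xs, ys))
    return grads / xs.shape[0], losses.mean()


k_x, k_y, k_w = jax.random.split(jax.random.PRNGKey(0), 3)
xs = jax.random.normal(k_x, (4, 8, 3))  # 4 accumulation steps, batch of 8, 3 features
ys = jax.random.normal(k_y, (4, 8))
w = jax.random.normal(k_w, (3,))

acc_grads, acc_loss = accumulate(w, xs, ys)
full_loss, full_grads = grad_fn(w, xs.reshape(-1, 3), ys.reshape(-1))
print(acc_grads, full_grads)
```

The equality holds because every microbatch has the same size, so the mean of the per-microbatch means equals the mean over the whole batch.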
We now can wrap this function under a shard map to allow for the distributed training to occur. For the arguments, we use the spec we have defined throughout the script and the outputs follow the same way. Metrics are replicated across every device.
@jax.jit
@partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(param_spec, opt_spec, data_spec, data_spec, key_spec),
    out_specs=(param_spec, opt_spec, P()),
    check_vma=False,
)
def train_step(params, opt_state, x, y, key):
    ...
    return params, opt_state, metrics
We can similarly write the eval_step function. The only difference is we don’t have to keep the grads, thus the carry argument of the jax.lax.scan can be ignored in the single_step function.
@jax.jit
@partial(
    jax.shard_map,
    mesh=mesh,
    in_specs=(param_spec, data_spec, data_spec),
    out_specs=P(),
    check_vma=False,
)
def eval_step(params, x, y):
    def single_step(_, batch):
        loss, metrics = step(
          params, *batch, key=jax.random.PRNGKey(0), train=False
        )  # Key does not matter
        return loss, metrics
    _, metrics = jax.lax.scan(single_step, 0, (x, y))
    metrics = jax.tree.map(lambda x: x.mean(axis=0), metrics)
    return metrics
We now define our final variables and sync devices. Note we split off the sample key before the loop, since we want the same random key for every inference run to see how the model evolves. We also keep a list of training losses, which we append to at every step and average whenever we print at an eval step.
def main(cfg: config):
    ...
    total_steps = cfg.training_steps
    total_tokens = train_dataset.tokens_per_step
    jax.experimental.multihost_utils.sync_global_devices("sync")
    log(f"Total steps: {total_steps}")
    log(f"Total tokens per step: {total_tokens:,}")
    key, sample_key = jax.random.split(key, 2)
    start = time.time()
    train_loss = [] # used to keep track of loss and averaged when printing
Our last helper function makes the keys. We want a fresh key for each device and for each gradient-accumulation step. The number of steps is therefore a parameter; it must be static, so the function is re-jitted for each new value. This doesn’t slow us down, since in practice it is always called with the same value and so is compiled only once.
@partial(jax.jit, static_argnames=["steps"])
def make_sharded_key(key, steps=1):
    key = jax.random.split(
      key, DATA_PARALLEL * LAYER_PARALLEL * TENSOR_PARALLEL * steps
    ) # one key per device per accumulation step, stacked along the leading dim
    key = jnp.asarray(key).reshape(
      (DATA_PARALLEL, LAYER_PARALLEL, TENSOR_PARALLEL, steps, 2)
    )
    return key
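To see the resulting layout, here is the same split and reshape on toy axis sizes. The sizes are hypothetical, and the sketch assumes JAX's default legacy uint32 keys, where each key has shape (2,).

```python
import jax
import jax.numpy as jnp

dp, pp, tp, steps = 2, 1, 2, 3  # toy axis sizes and accumulation steps
key = jax.random.PRNGKey(0)

# One key per (device, accumulation step), laid out to match key_spec.
keys = jax.random.split(key, dp * pp * tp * steps).reshape(dp, pp, tp, steps, 2)
print(keys.shape)  # (2, 1, 2, 3, 2)

# Under shard_map each device sees a (1, 1, 1, steps, 2) block, which
# train_step reshapes to (steps, 2): one key per accumulation step.
local = keys[0:1, 0:1, 0:1].reshape(steps, 2)
print(local.shape)  # (3, 2)
```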
Then, the final training loop can be written. We start by splitting our key and then making our train keys, getting our data and finally calling the train step.
def main(cfg: config):
    for current_step in range(init_step, total_steps):
        key, train_key = jax.random.split(key)
        train_key = make_sharded_key(train_key, steps=cfg.grad_step)
        x, y = train_dataset(step=cfg.grad_step)
        params, opt_state, metrics = train_step(params, opt_state, x, y, train_key)
        train_loss.append(metrics["loss"])
We then add wandb logging metrics and add in our eval step.
def main(cfg: config):
    for current_step in range(init_step, total_steps):
        ...
        if use_wandb:
            wandb_log = {
                "step": current_step,
                "loss/train_loss": metrics["loss"],
                "loss/train_cross_entropy_loss": metrics["loss_cross"],
                "lr": opt_state[1].hyperparams["learning_rate"],
            }
        # evaluate at loop level so that it also runs when wandb is disabled
        if current_step % cfg.checkpoint_steps == 0:
            time_per_batch = time.time() - start
            eval_x, eval_y = val_dataset(step=cfg.eval_steps)
            val_metrics = eval_step(params, eval_x, eval_y)
            if use_wandb:
                wandb_log["loss/val_loss"] = val_metrics["loss"]
                wandb_log["loss/val_cross_entropy_loss"] = val_metrics["loss_cross"]
            jax.experimental.multihost_utils.sync_global_devices("sync")
            tokens_per_second = cfg.checkpoint_steps * total_tokens / time_per_batch
            train_loss = jnp.array(train_loss).mean().item()
            eval_loss = val_metrics["loss"].item()
            log_string = f"Step {current_step + 1}, Loss: {train_loss:.4f}, Eval Loss: {eval_loss:.4f}, tk/s: {tokens_per_second:,.2f}"
            log(log_string)
To avoid slowdown we can checkpoint every 10 eval steps. We can also include checkpointing to get the final training loop.
def main(cfg: config):
    ...
    for current_step in range(init_step, total_steps):
        key, train_key = jax.random.split(key)
        train_key = make_sharded_key(train_key, steps=cfg.grad_step)
        x, y = train_dataset(step=cfg.grad_step)
        params, opt_state, metrics = train_step(params, opt_state, x, y, train_key)
        train_loss.append(metrics["loss"])
        if use_wandb:
            wandb_log = {
                "step": current_step,
                "loss/train_loss": metrics["loss"],
                "loss/train_cross_entropy_loss": metrics["loss_cross"],
                "lr": opt_state[1].hyperparams["learning_rate"],
            }
        if current_step % cfg.checkpoint_steps == 0:
            time_per_batch = time.time() - start
            eval_x, eval_y = val_dataset(step=cfg.eval_steps)
            val_metrics = eval_step(params, eval_x, eval_y)
            if use_wandb:
                wandb_log["loss/val_loss"] = val_metrics["loss"]
                wandb_log["loss/val_cross_entropy_loss"] = val_metrics["loss_cross"]
            jax.experimental.multihost_utils.sync_global_devices("sync")
            tokens_per_second = cfg.checkpoint_steps * total_tokens / time_per_batch
            train_loss = jnp.array(train_loss).mean().item()
            eval_loss = val_metrics["loss"].item()
            log_string = f"Step {current_step + 1}, Loss: {train_loss:.4f}, Eval Loss: {eval_loss:.4f}, tk/s: {tokens_per_second:,.2f}"
            log(log_string)
            start = time.time()
            train_loss = []
        # parenthesise the product: `%` and `*` share precedence and associate left to right
        if current_step % (10 * cfg.checkpoint_steps) == 0:
            save_checkpoint(current_step)
        if use_wandb:
            wandb.log(data=wandb_log, step=current_step)
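The cadence of evals and checkpoint saves can be listed directly. The interval below is hypothetical. Because `%` and `*` have the same precedence in Python and associate left to right, the product has to be parenthesised to get a save every 10 evals.

```python
checkpoint_steps = 5  # hypothetical eval interval

eval_steps = [s for s in range(0, 101) if s % checkpoint_steps == 0]
save_steps = [s for s in range(0, 101) if s % (10 * checkpoint_steps) == 0]
print(save_steps)  # [0, 50, 100]

# Without parentheses the expression parses as (s % 10) * checkpoint_steps,
# which is zero whenever s is a multiple of 10.
unparenthesised = [s for s in range(0, 101) if s % 10 * checkpoint_steps == 0]
print(unparenthesised[:4])  # [0, 10, 20, 30]
```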
Finally, we end the main function by calling wandb.finish() if we are using wandb. To kick off training, we add a main guard that calls jax.distributed.initialize() to sync the multi-controller processes and prints the cfg returned by parse_args().
if __name__ == "__main__":
    jax.distributed.initialize()
    cfg = parse_args()
    print(json.dumps(cfg.__dict__, indent=4, default=lambda o: o.__dict__))
    main(cfg)
We now look at how to scale this model further with MoE.