mixtral-8x7b#

mixtral-8x7b is a Mixture-of-Experts (MoE) model. In this tutorial, we will introduce how to run inference with and finetune the model.

Online Demo of Finetuned Model 🚀🚀🚀

We host a web demo 💻here, which shows a mixtral-8x7b model finetuned on evol-codealpaca-v1 and ultrachat_200k, with LoRA and Bias tuning.

Features#

With LLaMA2-Accessory, mixtral-8x7b enjoys the following features:

  1. Two Implementations

  2. Load Balancing Loss

  3. Tensor Parallel and FSDP for efficient training

  4. Distributed and/or quantized inference

  5. Multi-modal support

Model Implementation#

There are generally two approaches to implementing the Mixture-of-Experts (MoE) layers:

  1. The base implementation (Distribute Different Experts to Different GPUs): For example, given 8 experts and 4 GPUs, each GPU will be allocated two experts. This is the approach adopted by DiscoResearch and llama-mistral.

  2. The sparse implementation (Distribute a Part of Each Expert to Every GPU): For example, given 8 experts and 4 GPUs, each GPU will hold 1/4 of every expert. This partitioning of individual experts is achieved by splitting along the FFN hidden dimension. This is the approach officially adopted by MistralAI. We call it the sparse approach because it reformulates MoE computation into block-sparse operations.

LLaMA2-Accessory supports both implementations, and the two are completely interchangeable. However, benefiting from meticulously designed operators and the naturally balanced computation load across GPUs, the second implementation is generally more efficient. On the other hand, the first one may be easier for beginners to understand and is also easier to combine with LoRA.
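To make the difference concrete, below is a minimal, illustrative PyTorch sketch (hypothetical dimensions and simplified expert FFNs, not the repository's actual code) of what each GPU holds under the two approaches:

import torch.nn as nn

# Illustration only: two ways to place 8 experts on 4 GPUs.
DIM, HIDDEN, N_EXPERTS, N_GPUS = 4096, 14336, 8, 4

def expert_ffn(dim: int, hidden: int) -> nn.Module:
    # A Mixtral-style expert reduced to a plain two-layer FFN for brevity.
    return nn.Sequential(nn.Linear(dim, hidden, bias=False),
                         nn.SiLU(),
                         nn.Linear(hidden, dim, bias=False))

# 1) Base approach: each GPU owns complete experts (8 experts / 4 GPUs = 2 whole experts).
whole_experts_on_this_gpu = nn.ModuleList(
    expert_ffn(DIM, HIDDEN) for _ in range(N_EXPERTS // N_GPUS))

# 2) Sparse / tensor-parallel approach: each GPU holds a 1/4 slice of every
#    expert, obtained by splitting along the FFN hidden dimension.
expert_slices_on_this_gpu = nn.ModuleList(
    expert_ffn(DIM, HIDDEN // N_GPUS) for _ in range(N_EXPERTS))

# Both layouts keep the same number of expert parameters on each GPU.
def n_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters())

assert n_params(whole_experts_on_this_gpu) == n_params(expert_slices_on_this_gpu)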

The base implementation of Mixtral-8x7b is in mixtral.py; a corresponding PEFT version (supporting bias/norm/LoRA tuning) is in mixtral_peft.py. For the base implementation, we prioritize simplicity over efficiency.

The sparse implementation of Mixtral-8x7b is in mixtral_sparse.py. Based on this implementation, mixtral_sparse_ens.py implements a multi-modal model with an architecture similar to SPHINX, but using mixtral-8x7b instead of LLaMA2 as the LLM backbone. We are actively working on this multi-modal model, and the checkpoint will be released soon. For the sparse implementation, we place greater emphasis on efficiency. Specifically, we have referred to the official implementation and introduced some efficient operators from megablocks and stk.

Install#

Please follow the instructions here to install LLaMA2-Accessory, which is an easy-to-use and comprehensive toolkit for LLM development. If you want to use the sparse implementation of mixtral-8x7b, please also install megablocks and stk according to their official guides.

Prepare Checkpoint#

Given the official mixtral-8x7b checkpoints, a step of format conversion is needed to make them usable by LLaMA2-Accessory. We have released the off-the-shelf converted checkpoints. Alternatively, you can convert them yourself by following the guides below.

Important

Although the base and the sparse implementations are equivalent realizations of the same model, their checkpoints are not interchangeable. Please make sure to use the correct checkpoint.

A. Download Converted Checkpoints#

The converted checkpoints are released at 🤗HuggingFace. For the base implementation, the checkpoint is provided at 🤗base checkpoint; for the sparse implementation, the checkpoint is provided at 🤗sparse checkpoint. Please download all the files in the folders to your machine.

B. Convert by Yourself#

1. prepare the original checkpoints#

The original checkpoints (torrent release) are available at https://huggingface.co/someone13574/mixtral-8x7b-32kseqlen. Please first download the 10 splits and then cat them into one following the official guide. After this step, you should have the consolidated.00.pth file.
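If you prefer to do the concatenation from Python instead of with cat, the following sketch is equivalent; note that the split file naming pattern used below is an assumption, so check the official instructions for the actual names:

# Sketch: merge the downloaded splits back into a single consolidated.00.pth.
# The "consolidated.00.pth-split*" pattern is an assumption -- verify the real
# split names in the torrent release before running.
import shutil
from pathlib import Path

splits = sorted(Path(".").glob("consolidated.00.pth-split*"))
assert len(splits) == 10, f"expected 10 splits, found {len(splits)}"
with open("consolidated.00.pth", "wb") as out:
    for part in splits:
        with open(part, "rb") as src:
            shutil.copyfileobj(src, out)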

2. convert#

For the base implementation

Download the split.py script and put it in the same directory as consolidated.00.pth. Run the following command to convert:

python split.py

After running, you should see a folder named converted created, with eight consolidated.**-of-08.model.pth files therein.

For the sparse implementation

Download the split_sparse.py script and put it in the same directory as consolidated.00.pth. Run the following command to convert:

python split_sparse.py

After running, you should see a folder named converted_sparse created, with eight consolidated.**-of-08.model.pth files therein.

3. prepare other resources#

For the base implementation

Finally, please download the remaining three files, namely the tokenizer.model, config.json, and meta.json shown in the Result section below, from our HuggingFace repo, and put them under the converted directory, next to the weight files you obtained in the previous step.

For the sparse implementation

Finally, please download the remaining three files, namely the tokenizer.model, config.json, and meta.json shown in the Result section below, from our HuggingFace repo, and put them under the converted_sparse directory, next to the weight files you obtained in the previous step.

Result#

Whether you have downloaded the converted checkpoints or converted them on your own, you should end up with the following file structure:

path/to/converted OR path/to/converted_sparse
# model weights
├── consolidated.00-of-08.model.pth
├── consolidated.01-of-08.model.pth
├── consolidated.02-of-08.model.pth
├── consolidated.03-of-08.model.pth
├── consolidated.04-of-08.model.pth
├── consolidated.05-of-08.model.pth
├── consolidated.06-of-08.model.pth
├── consolidated.07-of-08.model.pth
# spm-format tokenizer 
├── tokenizer.model
# model configuration
├── config.json
# meta information, currently only contains model type
└── meta.json
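
As a quick sanity check before moving on, you can verify that the directory matches this layout (a small sketch; adjust the path to your own):

# Sketch: verify the converted checkpoint directory before running inference.
from pathlib import Path

ckpt_dir = Path("/path/to/converted")  # or /path/to/converted_sparse
weights = sorted(ckpt_dir.glob("consolidated.*.model.pth"))
missing = [name for name in ("tokenizer.model", "config.json", "meta.json")
           if not (ckpt_dir / name).exists()]
print(f"{len(weights)} weight shard(s) found; missing files: {missing or 'none'}")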

Inference#

Simple Inference#

You can run inference on 8, 4, 2, or 1 GPUs. With tensor parallelism and distributed MoE, the more GPUs you use, the lower the memory and computation load on each individual GPU. The following code exemplifies the inference process.

from accessory.model.meta import MetaModel

import random 
import numpy as np

import torch
import torch.distributed as dist
import multiprocessing as mp

def main(world_size, rank) -> None:
    # specify random seed to ensure consistent token sampling among model parallel ranks
    random.seed(0)
    torch.random.manual_seed(0)
    np.random.seed(0)
    
    dist.init_process_group(
        backend="nccl", rank=rank, world_size=world_size,
        init_method="tcp://127.0.0.1:23560",
    )
    torch.cuda.set_device(rank)
    
    pretrained_path = "/path/to/converted"  # converted checkpoints of either base or sparse format
    # mp_group identifies which ranks will work collaboratively through model parallelism
    model = MetaModel.from_pretrained(pretrained_path, max_seq_len=2048,
                                      mp_group=dist.new_group(ranks=list(range(dist.get_world_size()))))

    prompt = "The best programming language in the world is"

    response = model.generate([prompt], images=None, max_gen_len=512)[0]
    if rank == 0:  # without this filter, the response will be printed `world_size` times
        print(response)
    # or if you want to generate the response token by token
    response = None
    for response_in_progress in model.stream_generate(prompt, image=None, max_gen_len=512):
        response = response_in_progress['text']
        if rank == 0:
            print(response)


if __name__ == "__main__":
    N_GPU = 8 # 1, 2, 4, or 8
    if N_GPU == 1:
        main(world_size=1, rank=0)
    elif N_GPU > 1:
        # You can use whatever method, e.g. torchrun, slurm, etc. for distributed launch
        # Just be sure to initialize torch distributed (by invoking dist.init_process_group)
        # before creating the model if model parallel size > 1 is used
        mp.set_start_method("spawn")
        processes = []
        for rank in range(N_GPU):
            process = mp.Process(target=main, args=(N_GPU, rank))
            process.start()
            processes.append(process)
        for process in processes:
            process.join()
    else:
        raise ValueError

A thorough tutorial on inference with LLaMA2-Accessory can be found in the documentation. In the above example, pretrained_path should be replaced with the real path of the checkpoints prepared in the previous section. The from_pretrained method will then probe the meta.json file in the given path to discern the type of LLM used, namely the llama_type argument for initializing a MetaModel. For the base implementation, llama_type is mixtral; for the sparse implementation, it is mixtral_sparse.
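If you are unsure which implementation a checkpoint targets, you can simply inspect its meta.json; the sketch below just prints the file, since its exact schema beyond recording the model type is not described here:

# Sketch: peek at meta.json to see which implementation a checkpoint targets.
# The file is only documented as recording the model type; the released file
# is the ground truth for its exact schema.
import json

with open("/path/to/converted/meta.json") as f:
    print(json.load(f))  # expect a value of "mixtral" (base) or "mixtral_sparse" (sparse)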

Host Local Demo#

LLaMA2-Accessory provides a series of gradio demos for efficient interaction with your model. To host a local demo for the pretrained mixtral-8x7b model, follow the steps below:

cd LLaMA2-Accessory/accessory
torchrun --nproc-per-node=$N_GPUS_TO_USE --master-port=$PORT demos/single_turn.py \
--pretrained_path $PRETRAINED_PATH

As mentioned in the Simple Inference section, $N_GPUS_TO_USE can be 1, 2, 4, or 8. $PRETRAINED_PATH should be the directory containing the converted (base or sparse) checkpoints, and $PORT can be any free port.

Tip

The demos/single_turn.py file is designed to support both pretrained models and models finetuned with the alpaca-style template. For pretrained models, please set the system_prompt option to None in the Web GUI. See the LLaMA2-Accessory documentation to learn more about finetuning and inference.

Finetuning#

LLaMA2-Accessory supports both full-parameter and parameter-efficient finetuning of mixtral-8x7b. It also supports the load balancing regularization loss. More advanced MoE support will come soon.
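For intuition, the sketch below shows a load-balancing auxiliary loss of the kind commonly used for MoE models (e.g., in Switch Transformers); it illustrates the idea and is not the repository's exact implementation:

# Sketch: auxiliary loss that encourages the router to use all experts evenly.
import torch

def load_balancing_loss(router_logits: torch.Tensor, num_experts: int, top_k: int = 2) -> torch.Tensor:
    # router_logits: (num_tokens, num_experts) gating scores of one MoE layer.
    probs = torch.softmax(router_logits, dim=-1)                  # (T, E)
    top_idx = probs.topk(top_k, dim=-1).indices                   # (T, k)
    dispatch = torch.zeros_like(probs).scatter(1, top_idx, 1.0)   # one-hot dispatch mask
    tokens_per_expert = dispatch.mean(dim=0)                      # fraction of tokens routed to each expert
    prob_per_expert = probs.mean(dim=0)                           # mean router probability per expert
    # Minimized when both distributions are uniform (1 / num_experts).
    return num_experts * torch.sum(tokens_per_expert * prob_per_expert)

In practice, such a term is scaled by a small coefficient and added to the language-modeling loss for every MoE layer.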

Data#

We use the evol-codealpaca-v1 and ultrachat_200k datasets to exemplify finetuning.

The two datasets are referred to by the dialog_ultrachat200kWizardcode.yaml file, which is then used by the *.sh experiments shown below to define the data for finetuning. Note that the data need to be processed into the format usable by LLaMA2-Accessory. For convenience, we provide the processed data files for 💾evol-codealpaca-v1 and 💾ultrachat_200k. Please move them to the positions specified by dialog_ultrachat200kWizardcode.yaml
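
Before launching the finetuning scripts, it can help to confirm that the data files referenced by the yaml actually exist at the expected positions. The sketch below is deliberately schema-agnostic (the yaml path is a placeholder, and the check requires the pyyaml package); open the real dialog_ultrachat200kWizardcode.yaml to see its actual structure:

# Sketch: list every .json/.jsonl path referenced by the data yaml and report
# the ones that do not exist locally. The yaml path below is a placeholder.
from pathlib import Path
import yaml

cfg = yaml.safe_load(Path("path/to/dialog_ultrachat200kWizardcode.yaml").read_text())

def iter_strings(node):
    # Walk the parsed yaml and yield every string leaf.
    if isinstance(node, str):
        yield node
    elif isinstance(node, dict):
        for value in node.values():
            yield from iter_strings(value)
    elif isinstance(node, list):
        for item in node:
            yield from iter_strings(item)

for value in iter_strings(cfg):
    if value.endswith((".json", ".jsonl")) and not Path(value).exists():
        print("missing data file:", value)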

Full Finetune#

For the base implementation:

cd LLaMA2-Accessory/accessory
srun -n32 --gres=gpu:8 --ntasks-per-node=8 bash \
exps/finetune/sg/dialog_ultrachat200kWizardcode_mixtral.sh \
/path/to/converted \
/path/to/converted/config.json \
/path/to/converted/tokenizer.model

For the sparse implementation, change dialog_ultrachat200kWizardcode_mixtral.sh to dialog_ultrachat200kWizardcode_mixtralSparse.sh (where the only difference is that the llama_type argument changes from mixtral to mixtral_sparse), and /path/to/converted to /path/to/converted_sparse.

PEFT#

cd LLaMA2-Accessory/accessory
srun -n16 --gres=gpu:8 --ntasks-per-node=8 bash \
exps/finetune/sg/dialog_ultrachat200kWizardcode_mixtralPeft.sh \
/path/to/converted \
/path/to/converted/config.json \
/path/to/converted/tokenizer.model

Finetuned Model Release:

Host Local Demo

cd LLaMA2-Accessory/accessory
python demos/multi_turn.py --n_gpus $N_GPUS_TO_USE --pretrained_path $PATH_TO_FINETUNED

See the LLaMA2-Accessory documentation to learn more about finetuning and inference.

Acknowledgement#