Inference#

Preparation#

Before running the inference code, users must ensure that they have correctly installed and configured all necessary environments according to the instructions in the Installation Document.

If you prefer not to delve into the extensive technical details, you can simply execute bash demos/start.sh and enjoy.

Prepare Checkpoints#

Our checkpoints are released at πŸ€—Hugging Face

For newer versions of LLaMA2-Accessory, the meta/config/tokenizer information is saved together with the model weights, so a saved checkpoint should have the following layout:

path/to/checkpoint
# model weights
β”œβ”€β”€ consolidated.00-of-02.model.pth
β”œβ”€β”€ consolidated.01-of-02.model.pth
# spm-format tokenizer 
β”œβ”€β”€ tokenizer.model
# huggingface-format tokenizer
β”œβ”€β”€ tokenizer.json
β”œβ”€β”€ tokenizer_config.json
# model configuration
β”œβ”€β”€ config.json
# meta information, currently only contains model type
└── meta.json

The model weights are split and saved into m consolidated.n-of-m.model.pth files, where m is the model parallel size. For tokenizers, both the spm and huggingface formats are supported; either one is sufficient, and there is no need to keep files of both formats simultaneously.

Legacy Checkpoints

Checkpoints saved by legacy versions of LLaMA2-Accessory only contain the consolidated.*.model.pth weight files. Such checkpoints are still usable, but the llama_type, llama_config, and tokenizer_path information needs to be specified manually.
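
To check which of these files a local checkpoint directory actually contains, and thus whether llama_type, llama_config, and tokenizer_path must be given manually, you can inspect it with a few lines of Python. The following is a minimal sketch using only the standard library, based on the file names in the layout above:

from pathlib import Path

ckpt_dir = Path("/path/to/checkpoint")

# model weights: consolidated.{n}-of-{m}.model.pth, where m is the model parallel size
weight_files = sorted(ckpt_dir.glob("consolidated.*.model.pth"))
print(f"{len(weight_files)} weight shard(s) found (model parallel size)")

# tokenizer: either spm format (tokenizer.model) or huggingface format
has_spm = (ckpt_dir / "tokenizer.model").exists()
has_hf = (ckpt_dir / "tokenizer.json").exists() and (ckpt_dir / "tokenizer_config.json").exists()
print(f"spm tokenizer: {has_spm}, huggingface tokenizer: {has_hf}")

# meta/config information is only saved by newer versions of LLaMA2-Accessory
is_legacy = not ((ckpt_dir / "meta.json").exists() and (ckpt_dir / "config.json").exists())
print("legacy checkpoint: specify llama_type/llama_config/tokenizer_path manually" if is_legacy
      else "new-format checkpoint: MetaModel.from_pretrained can configure itself")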

General Pipeline#

Model Instantiation#

The class method MetaModel.from_pretrained supports convenient instantiation of LLaMA2-Accessory models from pretrained checkpoints.

classmethod from_pretrained(pretrained_path: str | List[str], llama_type: Optional[str] = None, llama_config: Optional[str | List[str]] = None, tokenizer_path: Optional[str] = None, with_visual: bool = False, max_seq_len: int = 4096, mp_group: Optional[torch.distributed.ProcessGroup] = None, dtype=torch.bfloat16, device='cuda', quant=False)

Besides loading the consolidated.*.pth model weights, this function also tries to find the tokenizer, meta.json, and config.json under pretrained_path to configure the tokenizer_path, llama_type, and llama_config of the model. The automatically determined values are overridden by the user's explicit specification of the corresponding arguments.

Parameters:
  • pretrained_path – Paths to directories containing consolidated.*.pth weight files. If multiple paths are given, weights will be loaded sequentially. Now repo_id also can be specified as a path.

  • llama_type – Type of the inner LLM. The corresponding model class definition should be found in accessory/model/LLM/llama_type.py. If not specified, this function will probe the meta.json file under pretrained_path to try to determine the value.

  • llama_config – Inner LLM configurations. Can be one or a list of strings, each of which is the path to a *.json configuration file. If not specified, this function will probe the config.json file under pretrained_path to try to determine the value.

  • tokenizer_path – LLaMA2-Accessory supports both spm tokenizers (provided by Meta, generally named tokenizer.model) and huggingface tokenizers (composed of tokenizer.json and tokenizer_config.json). When using spm tokenizers, tokenizer_path should point to the tokenizer.model file; when using huggingface tokenizers, tokenizer_path should point to the directory containing tokenizer.json and tokenizer_config.json. If not specified, this function will probe the pretrained_path directory for tokenizer in either format.

  • with_visual – Set it to True if the model is expected to receive image input. Inner LLM models rely on this argument to decide whether to instantiate the visual encoder.

  • max_seq_len – max context window size of the model

  • mp_group – If the parameters of the model are not split on multiple GPUs with model parallel, namely model parallel size == 1, then mp_group can be left to None. However, if model parallel is needed, mp_group should be an already initialized torch process group, ranks within which compose a logically complete model.

  • dtype – parameter data type

  • device – parameter device

  • quant – whether to quantize the model to 4bit

Returns:

An accessory.model.meta.MetaModel object with the pretrained checkpoints loaded.

For checkpoints saved with newer versions of LLaMA2-Accessory, resources like the tokenizer and configs are saved together with the model weights. In such cases, only pretrained_path needs to be specified.

from accessory.model.meta import MetaModel
model = MetaModel.from_pretrained("/path/to/pretrained")

Otherwise, for legacy checkpoints, only the consolidated.*.pth model weights are saved. In such cases, llama_type, llama_config, and tokenizer_path must be specified explicitly. Generally, you can assign these three arguments the same values that were used for training.

from accessory.model.meta import MetaModel
example_llama_type='llama_peft'
example_llama_config=[
    '/path/to/llama/7B/params.json',
    'configs/model/finetune/sg/llamaPeft_normBiasLora.json']
example_tokenizer_path='/path/to/tokenizer.model'
model = MetaModel.from_pretrained(
    "/path/to/pretrained", llama_type=example_llama_type,
    llama_config=example_llama_config, tokenizer_path=example_tokenizer_path
)
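
The remaining keyword arguments from the signature above can be combined with either loading style. For example, the following sketch (the path is a placeholder) loads a model in bfloat16 on GPU with 4-bit quantization enabled:

from accessory.model.meta import MetaModel
import torch

model = MetaModel.from_pretrained(
    "/path/to/pretrained",
    max_seq_len=2048,
    dtype=torch.bfloat16,  # parameter data type
    device='cuda',         # parameter device
    quant=True,            # quantize the model to 4-bit
)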

Tip

See the FAQ to learn more about llama_config and tokenizer_path.

Input Construction & Response Generation#

Pretrained Models#

For pretrained models, namely those trained on a large-scale corpus without a specific template, you can use any text as a prompt and have the model continue writing from it.

from accessory.model.meta import MetaModel

model = MetaModel.from_pretrained("/path/to/pretrained", max_seq_len=2048)

# for pretrained model (i.e. trained on large corpus without specific template)
prompt = "The best programming language in the world is"
response = model.generate([prompt], images=None, max_gen_len=512)[0]
print(response)
# or if you want to generate the response token by token
response = None
for response_in_progress in model.stream_generate(prompt, image=None, max_gen_len=512):
    response = response_in_progress['text']
print(response)

Single-turn-finetuned Models#

After instruction finetuning, it is important to keep the template consistent across finetuning and inference. The following shows an example of the Alpaca template, which is the default choice of LLaMA2-Accessory for single-turn finetuning. If you have used different templates during finetuning, don’t forget to continue using them for inference.

from accessory.model.meta import MetaModel
from accessory.data.system_prompt import format_prompt

model = MetaModel.from_pretrained("/path/to/pretrained", max_seq_len=2048)

# for single-turn-finetuned model
instruction = "What's the best programming language in the world?"
prompt = format_prompt({"instruction": instruction}, sys_name="alpaca")
# prompt is equal to:
#   "Below is an instruction that describes a task."
#   "Write a response that appropriately completes the request.\n\n"
#   "### Instruction:\nWhat's the best programming language in the world?\n\n### Response:"

response = model.generate([prompt], images=None, max_gen_len=512)[0]
print(response)
# or if you want to generate the response token by token
response = None
for response_in_progress in model.stream_generate(prompt, image=None, max_gen_len=512):
    response = response_in_progress['text']
print(response)
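
If your finetuning used a template other than Alpaca, reproduce that template verbatim at inference time. The function below is a purely hypothetical custom template, shown only to illustrate keeping the prompt format identical between finetuning and inference; it is not part of LLaMA2-Accessory:

# hypothetical custom template -- use exactly the template your finetuning data used
def my_format_prompt(instruction: str) -> str:
    return f"### Instruction:\n{instruction}\n\n### Response:"

prompt = my_format_prompt("What's the best programming language in the world?")
response = model.generate([prompt], images=None, max_gen_len=512)[0]
print(response)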

Multi-turn-finetuned Models#

Similar to the single-turn case, the template for multi-turn conversation should also be consistent across finetuning and inference. The following shows an example of the default template used by LLaMA2-Accessory.

from accessory.model.meta import MetaModel
from accessory.data.conversation import default_conversation

model = MetaModel.from_pretrained("/path/to/pretrained", max_seq_len=2048)

conv = default_conversation()

# for multi-turn-finetuned model
q1 = "What's the best programming language in the world?"
a1 = "The best programming language in the world in PHP."
q2 = "Are you sure? Why not Python?"
qas = [[q1, a1], [q2, None]]  # leave the last answer, namely the one to generate, to None
conv.load_qas(qas)
prompt = conv.get_prompt()
# prompt is equal to:
#   "A chat between a curious human and an artificial intelligence assistant. "
#   "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n"
#   "### Human: What's the best programming language in the world?\n"
#   "### Assistant: The best programming language in the world in PHP.\n"
#   "### Human: Are you sure? Why not Python?\n"
#   "### Assistant:"

# conv_sep is the symbol marking the end of one response, equal to "###" in this example
conv_sep = conv.response_end_signal

# ------EITHER-------
response = None
for response_in_progress in model.stream_generate(
        prompt, image=None, max_gen_len=512, additional_stop_symbols=[conv_sep]
):
    response = response_in_progress['text']
print(response)
# --------OR---------
response = None
for response_in_progress in model.stream_generate(prompt, image=None, max_gen_len=512):
    sep_pos = response_in_progress["text"].find(conv_sep)
    if sep_pos != -1:
        response = response_in_progress["text"][:sep_pos]
        break
    else:
        response = response_in_progress["text"]
print(response)

Important

For pretrained and single-turn models, the end of the response is controlled by the generation of the <EOS> token. In contrast, for multi-turn models, the end of the response is determined by a template-specific separator, e.g. ### in the example above. Since MetaModel.generate (only when batch size == 1) and MetaModel.stream_generate automatically halt when <EOS> is generated, nothing special about generation halting needs to be taken care of for pretrained and single-turn models. For the multi-turn case, however, the halting symbol needs to be specified explicitly, as shown in the example above.
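
When generating for a batch of prompts, where automatic <EOS> halting does not apply, the separator can be stripped afterwards. Below is a minimal post-processing sketch, assuming prompt_a and prompt_b are placeholders for prompts built with the same conversation template and conv_sep is the separator from the example above:

# batched generation may run past the end of the first assistant turn,
# so cut each output off at the first occurrence of the separator (e.g. "###")
responses = model.generate([prompt_a, prompt_b], images=None, max_gen_len=512)
responses = [r.split(conv_sep)[0].strip() for r in responses]
for r in responses:
    print(r)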

Multi-modal Models#

To run inference with multi-modal models, simply instantiate the MetaModel with with_visual=True and pass the image(s) to the generation function:

from accessory.model.meta import MetaModel
from accessory.data.transform import get_transform
from PIL import Image

model = MetaModel.from_pretrained("/path/to/pretrained", with_visual=True, max_seq_len=2048)

image = Image.open("/path/to/image").convert("RGB")
transform_type = "padded_resize"  # or "resized_center_crop". Make it consistent across training & inference
transform = get_transform(transform_type, getattr(model.llma, 'image_size', 224))
image = transform(image).unsqueeze(0).cuda().bfloat16()

# ---------single turn---------
from accessory.data.system_prompt import format_prompt

prompt = format_prompt({"instruction": "What's in the image?"}, sys_name="alpaca")
response = None
for response_in_progress in model.stream_generate(prompt, image=image, max_gen_len=512):
    response = response_in_progress['text']
print(response)


# ---------multi turn---------
from accessory.data.conversation import default_conversation

qas = [["What's in the image?", None]]
conv = default_conversation()
conv.load_qas(qas)
prompt = conv.get_prompt()
conv_sep = conv.response_end_signal

response = None
for response_in_progress in model.stream_generate(
        prompt, image=image, max_gen_len=512, additional_stop_symbols=[conv_sep]
):
    response = response_in_progress['text']
print(response)

Multi-GPU Inference with Model Parallelism#

When the model parameters are split across multiple GPUs with model parallelism, initialize torch.distributed before creating the model and pass an initialized process group as mp_group, as in the following example:

from accessory.model.meta import MetaModel
from accessory.data.system_prompt import format_prompt

import random 
import numpy as np

import torch
import torch.distributed as dist
import multiprocessing as mp

def main(world_size, rank) -> None:
    # specify random seed to ensure consistent token sampling among model parallel ranks
    random.seed(0)
    torch.random.manual_seed(0)
    np.random.seed(0)
    
    dist.init_process_group(
        backend="nccl", rank=rank, world_size=world_size,
        init_method="tcp://127.0.0.1:23560",
    )
    torch.cuda.set_device(rank)
    
    # mp_group identifies which ranks will work collaboratively through model parallelism
    model = MetaModel.from_pretrained("/path/to/pretrained", max_seq_len=2048,
                                      mp_group=dist.new_group(ranks=list(range(dist.get_world_size()))))

    instruction = "What's the best programming language in the world?"
    prompt = format_prompt({"instruction": instruction}, sys_name="alpaca")

    response = None
    for response_in_progress in model.stream_generate(prompt, image=None, max_gen_len=512):
        response = response_in_progress['text']
        print(response)


if __name__ == "__main__":
    N_GPU = 2
    if N_GPU == 1:
        main(world_size=1, rank=0)
    elif N_GPU > 1:
        # You can use whatever method, e.g. torchrun, slurm, etc. for distributed launch
        # Just be sure to initialize torch distributed (by invoking dist.init_process_group)
        # before creating the model if model parallel size > 1 is used
        mp.set_start_method("spawn")
        processes = []
        for rank in range(N_GPU):
            process = mp.Process(target=main, args=(N_GPU, rank))
            process.start()
            processes.append(process)
        # wait for all worker processes to finish
        for process in processes:
            process.join()
    else:
        raise ValueError

Host Local Demos#

We provide a series of scripts to host local gradio demos for easier interaction with trained LLaMA2-Accessory models.

Important

As mentioned in Prepare Checkpoints, the --llama_type, --llama_config, and --tokenizer_path arguments in the launching commands listed below can be omitted as long as files recording the corresponding information exist under the path that --pretrained_path points to.

Single-turn Single-modal Dialogue#

Use the single_turn.py script for single-turn dialogues:

torchrun --nproc-per-node=$NPROC --master-port=$PORT demos/single_turn.py \
--pretrained_path $PRETRAINED --llama_type $LLAMA_TYPE --llama_config $LLAMA_CONFIG --tokenizer_path $TOKENIZER

# (Optional) Quantization-assisted inference. To run on GPUs with limited VRAM, add the "--quant" flag.
# For example, less than 7GB of VRAM is required for the 7B model.
torchrun --nproc-per-node=$NPROC --master-port=$PORT demos/single_turn.py \
<--some_flags> --quant

Single-turn Multi-modal Dialogue#

Use the single_turn_mm.py script for single-turn multi-modal dialogues:

torchrun --nproc-per-node=$NPROC --master-port=$PORT demos/single_turn_mm.py \
--pretrained_path $PRETRAINED --llama_type $LLAMA_TYPE --llama_config $LLAMA_CONFIG --tokenizer_path $TOKENIZER

# (Optional) Quantization-assisted inference. To run on GPUs with limited VRAM, add the "--quant" flag.
# For example, less than 7GB of VRAM is required for the 7B model.
torchrun --nproc-per-node=$NPROC --master-port=$PORT demos/single_turn_mm.py \
<--some_flags> --quant

Multi-turn Single-modal Dialogue#

For multi-turn single-modal dialogues, use the multi_turn.py script:

python demos/multi_turn.py --n_gpus $NPROC \
--pretrained_path $PRETRAINED --llama_type $LLAMA_TYPE --llama_config $LLAMA_CONFIG --tokenizer_path $TOKENIZER

# (Optional) Quantization-assisted inference. To run on GPUs with limited VRAM, add the "--quant" flag.
# For example, less than 7GB of VRAM is required for the 7B model.
python demos/multi_turn.py <--some_flags> --quant

Multi-turn Multi-modal Dialogue#

For multi-turn multi-modal dialogues, use the multi_turn_mm.py script:

python demos/multi_turn_mm.py --n_gpus $NPROC \
--pretrained_path $PRETRAINED --llama_type $LLAMA_TYPE --llama_config $LLAMA_CONFIG --tokenizer_path $TOKENIZER

# (Optional) Quantization-assisted inference. To run on GPUs with limited VRAM, add the "--quant" flag.
# For example, less than 7GB of VRAM is required for the 7B model.
python demos/multi_turn_mm.py <--some_flags> --quant

Model Zoo#

The released checkpoints are organized as follows:

β”œβ”€β”€ convert
β”‚   └── sg
β”‚       β”œβ”€β”€ mixtral-8x7b-32kseqlen
β”‚       β”œβ”€β”€ Falcon
β”‚       β”œβ”€β”€ Falcon_180b
β”‚       └── InternLM
└── finetune
    β”œβ”€β”€ mm
    β”‚   β”œβ”€β”€ alpacaLlava_llamaQformerv2
    β”‚   β”œβ”€β”€ alpacaLlava_llamaQformerv2_13b
    β”‚   β”œβ”€β”€ alpacaLlava_llamaQformerv2Peft_13b
    β”‚   β”œβ”€β”€ caption_llamaQformerv2
    β”‚   β”œβ”€β”€ caption_llamaQformerv2_13b
    β”‚   └── SPHINX
    β”‚       β”œβ”€β”€ SPHINX
    β”‚       β”œβ”€β”€ SPHINX-1k
    β”‚       └── SPHINX-v2-1k
    └── sg
        β”œβ”€β”€ alpaca
        β”œβ”€β”€ alpaca_internLM_en
        β”œβ”€β”€ alpaca_internLM_zh
        β”œβ”€β”€ alpaca_llamaPeft_normBias
        β”œβ”€β”€ dialog_flan
        β”œβ”€β”€ dialog_lima
        β”œβ”€β”€ dialog_moss
        β”œβ”€β”€ dialog_platypus
        β”œβ”€β”€ dialog_sharegpt
        β”œβ”€β”€ dialog_sharegpt_70b
        β”œβ”€β”€ dialog_ultra
        β”œβ”€β”€ dialog_wizardcode
        β”œβ”€β”€ dialog_wizardcode_codellama
        β”œβ”€β”€ dialog_wizardcode_loadcode220k
        β”œβ”€β”€ dialog_wizardLM
        └── gorilla

How to Apply Delta Weights (Outdated)#

Warning

This section may be outdated as we have now released the full-version (i.e. merged) pretrained weights directly. Applying delta is no longer needed.

We release checkpoints as delta weights to comply with the LLaMA2 model license. To use our provided weights for inference or further tuning, please first add our delta to the original LLaMA2 weights to obtain the full weights:

Instructions:

  1. After agreeing to the License, Acceptable Use Policy, and Meta’s privacy policy, proceed to download the LLaMA2 weights from here.

  2. Utilize the following scripts to obtain finetuned weights by applying our delta. Make sure to download the delta weights from the model release page.

For smaller models such as the PEFT variants, we have retained the delta weights; simply add the --down_diff argument during download.

# For Download
python tools/download.py  --model_name check/in/release/page --input_type sg/or/mm --output_path path/to/save --model_size 7B/13B/70B --down_config --down_diff
# For Merging
python tools/weight_operate.py  --pretrained_path /path/to/llama2/ --delta_path /path/to/delta --output_path /path/to/finetuned
# For Separation
python tools/weight_operate.py  --pretrained_path /path/to/llama2/ --delta_path /path/to/finetuned --output_path /path/to/delta --operate_type extract
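
Conceptually, merging simply adds each delta tensor to the corresponding original LLaMA2 tensor, key by key. The snippet below is an illustrative sketch of that idea only (file names are placeholders, and it assumes both checkpoints are plain state dicts with matching keys and shapes); use tools/weight_operate.py above for actual merging:

import torch

original = torch.load("/path/to/llama2/consolidated.00.pth", map_location="cpu")
delta = torch.load("/path/to/delta/consolidated.00.pth", map_location="cpu")

# full weight = original LLaMA2 weight + released delta
merged = {key: original[key] + delta[key] for key in delta}
torch.save(merged, "/path/to/finetuned/consolidated.00.pth")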