LLaVA-NeXT Study Notes
1. Prompt
Could you please write a documentation entry for the `next_obs` variable in Markdown, following the format and style of the PyTorch documentation?
2. Run
```bash
pip uninstall transformers -y && \
pip install git+https://github.com/huggingface/transformers.git@v4.40.0
pip uninstall sentence-transformers -y
```
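A quick sanity check (optional) after the reinstall confirms which transformers version is actually imported:

```python
# Optional sanity check: verify the pinned transformers version is the one imported.
import transformers

print(transformers.__version__)  # expected: 4.40.0 after the reinstall above
```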
3. load_pretrained_model
```python
llava.model.builder.load_pretrained_model(
    model_path,
    model_base=None,
    model_name='llava',
    load_8bit=False,
    load_4bit=False,
    device_map='auto',
    torch_dtype='float16',
    attn_implementation='flash_attention_2',
    customized_config=None,
    overwrite_config=None,
    **kwargs,
)
```
Loads a pretrained LLaVA model, language model, or LoRA adapter from a local path or Hugging Face Hub.
This function automatically detects the model type based on model_name and loads the appropriate model architecture. It supports various base language models (Llama, Qwen, Mixtral, Mistral, Gemma) and can handle full models, LoRA adapters, and multimodal projector-only checkpoints.
Parameters
model_path (str)
Path to the model directory or Hugging Face Hub model identifier. This can be:
- A local directory path containing model files
- A Hugging Face Hub model ID (e.g., `"lmms-lab/llava-onevision-qwen2-0.5b-si"`)
model_base (str, optional)
Path to the base language model. Required when:
- Loading LoRA adapters (the base model must be specified)
- Loading multimodal projector-only checkpoints (the base language model must be provided)
If None, the function assumes the model at model_path is a complete model.
model_name (str, optional, defaults to "llava")
Model identifier used to determine the model architecture. The function automatically detects the model type based on keywords in this string:
"llava"or contains"llava": Loads a LLaVA multimodal model"qwen"or"quyen": Loads Qwen-based model (supports MoE variants)"llama": Loads Llama-based model"mixtral": Loads Mixtral-based model"mistral"or"zephyr": Loads Mistral-based model"gemma": Loads Gemma-based model"lora": Indicates a LoRA adapter model"mpt": Loads MPT model (language model only)
load_8bit (bool, optional, defaults to False)
If True, loads the model in 8-bit quantization using BitsAndBytes. This reduces memory usage at the cost of some performance. Mutually exclusive with load_4bit.
load_4bit (bool, optional, defaults to False)
If True, loads the model in 4-bit quantization using BitsAndBytes with NF4 quantization type. This significantly reduces memory usage. Mutually exclusive with load_8bit.
Note: When using quantization, the model uses BitsAndBytesConfig with:
- `bnb_4bit_compute_dtype=torch.float16`
- `bnb_4bit_use_double_quant=True`
- `bnb_4bit_quant_type="nf4"`
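For reference, the equivalent standalone setup with the transformers `BitsAndBytesConfig` API looks like this (a sketch of the settings listed above, not code copied from the builder):

```python
import torch
from transformers import BitsAndBytesConfig

# Mirrors the 4-bit quantization settings described above.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
```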
device_map (str or dict, optional, defaults to "auto")
Controls how the model is distributed across devices. Can be:
"auto": Automatically determines the optimal device mapping"cpu": Loads the model on CPU"cuda": Loads the model on the default CUDA device- A dictionary: Manual device mapping for specific layers
torch_dtype (str or torch.dtype, optional, defaults to "float16")
Data type for model weights. Supported values:
"float16"ortorch.float16: Half precision (default)"bfloat16"ortorch.bfloat16: Brain floating point (requires compatible hardware)
Note: This parameter is ignored when load_8bit=True or load_4bit=True.
attn_implementation (str, optional, defaults to "flash_attention_2")
Attention implementation to use. Options:
"flash_attention_2": Uses Flash Attention 2 (requiresflash-attnpackage)None: Uses the default PyTorch attention implementation (SDPA)
Warning: If "flash_attention_2" is specified but the flash-attn package is not installed, the function will raise an ImportError. Set attn_implementation=None to use standard attention if Flash Attention is unavailable.
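One simple way to guard against this (a convenience sketch, not part of the builder) is to probe for the package before choosing the implementation:

```python
# Fall back to standard attention when flash-attn is not installed.
try:
    import flash_attn  # noqa: F401
    attn_impl = "flash_attention_2"
except ImportError:
    attn_impl = None
```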
customized_config (PretrainedConfig, optional, defaults to None)
Custom model configuration object. If provided, this configuration will be used instead of loading from model_path. Useful for custom model configurations or debugging.
overwrite_config (dict, optional, defaults to None)
Dictionary of configuration attributes to overwrite after loading the base configuration. Keys should be configuration attribute names, and values are the new values to set.
Example:
```python
overwrite_config = {"mm_projector_type": "mlp2x_gelu"}
```
**kwargs
Additional keyword arguments passed to the underlying from_pretrained() method. Common options include:
- `trust_remote_code`: Whether to trust remote code in the model repository
- `low_cpu_mem_usage`: Whether to use low CPU memory usage mode (defaults to `True`)
- `torch_dtype`: Overrides the `torch_dtype` parameter if specified
- Any other arguments supported by `transformers.PreTrainedModel.from_pretrained()`
Returns
Returns a tuple of four elements:
tokenizer (PreTrainedTokenizer)
The tokenizer associated with the loaded model. For LLaVA models, special image tokens are automatically added to the tokenizer vocabulary.
model (PreTrainedModel)
The loaded model instance. The exact class depends on the model architecture:
- `LlavaLlamaForCausalLM` for Llama-based models
- `LlavaQwenForCausalLM` for Qwen-based models
- `LlavaMixtralForCausalLM` for Mixtral-based models
- `LlavaMistralForCausalLM` for Mistral-based models
- `LlavaGemmaForCausalLM` for Gemma-based models
- `AutoModelForCausalLM` for pure language models
image_processor (ImageProcessor or None)
The image processor for LLaVA models. Returns None for pure language models. The processor is obtained from the vision tower’s image processor.
context_len (int)
The maximum context length (sequence length) supported by the model. Determined by:
- `model.config.max_sequence_length` (if available)
- `model.config.max_position_embeddings` (if available)
- `model.config.tokenizer_model_max_length` (if available)
- Defaults to `2048` if none of the above are found
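The resolution order amounts to a simple attribute fallback, roughly as follows (an illustrative sketch, not the builder's actual code):

```python
# Illustrative fallback chain for determining context_len from a model config.
def resolve_context_len(config) -> int:
    return (
        getattr(config, "max_sequence_length", None)
        or getattr(config, "max_position_embeddings", None)
        or getattr(config, "tokenizer_model_max_length", None)
        or 2048
    )
```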
Examples
Basic Usage: Loading a LLaVA Model
```python
from llava.model.builder import load_pretrained_model

# Load a LLaVA model from Hugging Face Hub
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="lmms-lab/llava-onevision-qwen2-0.5b-si",
    model_base=None,
    model_name="llava_qwen",
    device_map="auto",
)
print(f"Model loaded with context length: {context_len}")
```
Loading with Quantization (4-bit)
```python
from llava.model.builder import load_pretrained_model

# Load a 7B model with 4-bit quantization to reduce memory usage
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="lmms-lab/llava-next-7b",
    model_base=None,
    model_name="llava_qwen",
    load_4bit=True,
    device_map="auto",
    attn_implementation=None,  # Disable Flash Attention if not available
)
```
Loading without Flash Attention
```python
from llava.model.builder import load_pretrained_model

# Load a model without Flash Attention (useful if flash-attn is not installed)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="lmms-lab/llava-onevision-qwen2-0.5b-si",
    model_base=None,
    model_name="llava_qwen",
    device_map="auto",
    attn_implementation=None,  # Use standard attention
)
```
Loading a LoRA Adapter
```python
from llava.model.builder import load_pretrained_model

# Load a LoRA adapter with its base model
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="/path/to/lora/adapter",
    model_base="/path/to/base/model",
    model_name="llava_llama_lora",
    device_map="auto",
)
```
Loading with Custom Configuration
```python
from llava.model.builder import load_pretrained_model
from llava.model.language_model.llava_qwen import LlavaQwenConfig

# Create a custom configuration
custom_config = LlavaQwenConfig.from_pretrained("lmms-lab/llava-onevision-qwen2-0.5b-si")
custom_config.mm_projector_type = "mlp2x_gelu"

# Load with the custom configuration
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="lmms-lab/llava-onevision-qwen2-0.5b-si",
    model_base=None,
    model_name="llava_qwen",
    customized_config=custom_config,
    device_map="auto",
)
```
Loading a Pure Language Model
```python
from llava.model.builder import load_pretrained_model

# Load a pure language model (not multimodal)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="meta-llama/Llama-2-7b-hf",
    model_base=None,
    model_name="llama",  # Must not contain "llava"
    device_map="auto",
)
# image_processor will be None for language models
```
Notes
- Memory Usage: For large models, consider using `load_4bit=True` or `load_8bit=True` to reduce memory requirements. 4-bit quantization can reduce memory usage by approximately 75%.
- Flash Attention: Flash Attention 2 requires compatible hardware (typically Ampere-architecture or newer GPUs). For older GPUs (e.g., RTX 2080 Ti), set `attn_implementation=None`.
- Model Detection: The function automatically detects the model architecture based on `model_name`. Ensure the model name contains appropriate keywords (e.g., `"qwen"`, `"llama"`, `"llava"`) for correct loading.
- Special Tokens: For LLaVA models, the function automatically adds special image tokens (`<image>`, `<im_start>`, `<im_end>`) to the tokenizer and resizes the token embeddings accordingly (a sketch of this setup follows these notes).
- Vision Tower Loading: For LLaVA models, the vision tower is loaded lazily. It is loaded when first accessed, or it can be loaded explicitly by calling `model.get_vision_tower().load_model()`.
- Device Mapping: When using `device_map="auto"`, the function uses Hugging Face's `accelerate` library to automatically distribute the model across available devices. For single-GPU setups, the model is placed on that GPU.
