
LLaVA-NeXT Study Notes

1. Prompt

Could you please write a corresponding documentation entry in Markdown for the `next_obs` variable, following the format and style of the PyTorch documentation?

2. Run

```bash
pip uninstall transformers -y && \
pip install git+https://github.com/huggingface/transformers.git@v4.40.0
pip uninstall sentence-transformers -y
```
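
To confirm the pinned build is the one actually imported, a quick check (assuming a standard Python environment):

```python
# Verify the installed transformers version matches the pin above
import transformers

print(transformers.__version__)  # expected: 4.40.0
```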

3. load_pretrained_model

```python
llava.model.builder.load_pretrained_model(
    model_path,
    model_base=None,
    model_name='llava',
    load_8bit=False,
    load_4bit=False,
    device_map='auto',
    torch_dtype='float16',
    attn_implementation='flash_attention_2',
    customized_config=None,
    overwrite_config=None,
    **kwargs,
)
```

Loads a pretrained LLaVA model, language model, or LoRA adapter from a local path or Hugging Face Hub.

This function automatically detects the model type based on model_name and loads the appropriate model architecture. It supports various base language models (Llama, Qwen, Mixtral, Mistral, Gemma) and can handle full models, LoRA adapters, and multimodal projector-only checkpoints.

Parameters

model_path (str)

Path to the model directory or Hugging Face Hub model identifier. This can be:

  • A local directory path containing model files
  • A Hugging Face Hub model ID (e.g., "lmms-lab/llava-onevision-qwen2-0.5b-si")

model_base (str, optional)

Path to the base language model. Required when:

  • Loading LoRA adapters (the base model must be specified)
  • Loading multimodal projector-only checkpoints (the base language model must be provided)

If None, the function assumes the model at model_path is a complete model.

model_name (str, optional, defaults to "llava")

Model identifier used to determine the model architecture. The function automatically detects the model type based on keywords in this string (a sketch of this dispatch follows the list):

  • "llava" or contains "llava": Loads a LLaVA multimodal model
  • "qwen" or "quyen": Loads Qwen-based model (supports MoE variants)
  • "llama": Loads Llama-based model
  • "mixtral": Loads Mixtral-based model
  • "mistral" or "zephyr": Loads Mistral-based model
  • "gemma": Loads Gemma-based model
  • "lora": Indicates a LoRA adapter model
  • "mpt": Loads MPT model (language model only)

load_8bit (bool, optional, defaults to False)

If True, loads the model in 8-bit quantization using BitsAndBytes. This reduces memory usage at the cost of some performance. Mutually exclusive with load_4bit.

load_4bit (bool, optional, defaults to False)

If True, loads the model in 4-bit quantization using BitsAndBytes with NF4 quantization type. This significantly reduces memory usage. Mutually exclusive with load_8bit.

Note: When using quantization, the model uses BitsAndBytesConfig with the following settings (a construction sketch follows the list):

  • bnb_4bit_compute_dtype=torch.float16
  • bnb_4bit_use_double_quant=True
  • bnb_4bit_quant_type="nf4"
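
For reference, the equivalent configuration in plain transformers would be constructed roughly like this (a sketch; the builder creates this internally):

```python
import torch
from transformers import BitsAndBytesConfig

# 4-bit NF4 quantization config matching the settings listed above
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
```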

device_map (str or dict, optional, defaults to "auto")

Controls how the model is distributed across devices. Can be one of the following (a manual-mapping sketch follows the list):

  • "auto": Automatically determines the optimal device mapping
  • "cpu": Loads the model on CPU
  • "cuda": Loads the model on the default CUDA device
  • A dictionary: Manual device mapping for specific layers
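
A dictionary mapping pins named modules to device indices. The module names below are hypothetical and depend on the actual architecture; inspect model.named_modules() for the real ones:

```python
# Hypothetical manual mapping for a two-GPU split; actual module names
# depend on the loaded architecture.
device_map = {
    "model.embed_tokens": 0,
    "model.layers.0": 0,
    "model.layers.1": 1,
    "model.norm": 1,
    "lm_head": 1,
}
```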

torch_dtype (str or torch.dtype, optional, defaults to "float16")

Data type for model weights. Supported values:

  • "float16" or torch.float16: Half precision (default)
  • "bfloat16" or torch.bfloat16: Brain floating point (requires compatible hardware)

Note: This parameter is ignored when load_8bit=True or load_4bit=True.
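
When the parameter is honored (i.e., no quantization), a common pattern is to request bfloat16 only where the hardware supports it, for example:

```python
import torch

# Prefer bfloat16 on supporting hardware, otherwise fall back to float16
dtype = (
    "bfloat16"
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else "float16"
)
```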

attn_implementation (str, optional, defaults to "flash_attention_2")

Attention implementation to use. Options:

  • "flash_attention_2": Uses Flash Attention 2 (requires flash-attn package)
  • None: Uses the default PyTorch attention implementation (SDPA)

Warning: If "flash_attention_2" is specified but the flash-attn package is not installed, the function will raise an ImportError. Set attn_implementation=None to use standard attention if Flash Attention is unavailable.
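
One way to honor this warning is to probe for the package before choosing the implementation (a sketch):

```python
import importlib.util

# Fall back to standard (SDPA) attention when flash-attn is not installed
attn_impl = "flash_attention_2" if importlib.util.find_spec("flash_attn") else None
```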

customized_config (PretrainedConfig, optional, defaults to None)

Custom model configuration object. If provided, this configuration will be used instead of loading from model_path. Useful for custom model configurations or debugging.

overwrite_config (dict, optional, defaults to None)

Dictionary of configuration attributes to overwrite after loading the base configuration. Keys should be configuration attribute names, and values are the new values to set.

Example:

```python
overwrite_config = {"mm_projector_type": "mlp2x_gelu"}
```

**kwargs

Additional keyword arguments passed to the underlying from_pretrained() method. Common options include (a usage sketch follows the list):

  • trust_remote_code: Whether to trust remote code in the model repository
  • low_cpu_mem_usage: Whether to use low CPU memory usage mode (defaults to True)
  • torch_dtype: Overrides the torch_dtype parameter if specified
  • Any other arguments supported by transformers.PreTrainedModel.from_pretrained()
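
Extra keyword arguments are simply forwarded, so a call can look like this (trust_remote_code shown purely as an illustration):

```python
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="lmms-lab/llava-onevision-qwen2-0.5b-si",
    model_name="llava_qwen",
    trust_remote_code=True,  # forwarded to from_pretrained()
)
```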

Returns

Returns a tuple of four elements:

tokenizer (PreTrainedTokenizer)

The tokenizer associated with the loaded model. For LLaVA models, special image tokens are automatically added to the tokenizer vocabulary.
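
In LLaVA codebases, prompts containing the image token are usually tokenized with a helper rather than the raw tokenizer. The helper and constant names below come from the LLaVA repository and are stated here as an assumption, not part of this function's return contract:

```python
from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.mm_utils import tokenizer_image_token

# Tokenize a prompt that contains the special <image> token
prompt = DEFAULT_IMAGE_TOKEN + "\nDescribe this image."
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
```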

model (PreTrainedModel)

The loaded model instance. The exact class depends on the model architecture:

  • LlavaLlamaForCausalLM for Llama-based models
  • LlavaQwenForCausalLM for Qwen-based models
  • LlavaMixtralForCausalLM for Mixtral-based models
  • LlavaMistralForCausalLM for Mistral-based models
  • LlavaGemmaForCausalLM for Gemma-based models
  • AutoModelForCausalLM for pure language models

image_processor (ImageProcessor or None)

The image processor for LLaVA models. Returns None for pure language models. The processor is obtained from the vision tower’s image processor.

context_len (int)

The maximum context length (sequence length) supported by the model. Determined by the following fallback chain (sketched in code after the list):

  1. model.config.max_sequence_length (if available)
  2. model.config.max_position_embeddings (if available)
  3. model.config.tokenizer_model_max_length (if available)
  4. Defaults to 2048 if none of the above are found
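
The fallback chain sketched in code (illustrative):

```python
# Illustrative fallback chain for determining context_len
if hasattr(model.config, "max_sequence_length"):
    context_len = model.config.max_sequence_length
elif hasattr(model.config, "max_position_embeddings"):
    context_len = model.config.max_position_embeddings
elif hasattr(model.config, "tokenizer_model_max_length"):
    context_len = model.config.tokenizer_model_max_length
else:
    context_len = 2048
```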

Examples

Basic Usage: Loading a LLaVA Model

```python
from llava.model.builder import load_pretrained_model

# Load a LLaVA model from Hugging Face Hub
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="lmms-lab/llava-onevision-qwen2-0.5b-si",
    model_base=None,
    model_name="llava_qwen",
    device_map="auto",
)
print(f"Model loaded with context length: {context_len}")
```

Loading with Quantization (4-bit)

```python
from llava.model.builder import load_pretrained_model

# Load a 7B model with 4-bit quantization to reduce memory usage
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="lmms-lab/llava-next-7b",
    model_base=None,
    model_name="llava_qwen",
    load_4bit=True,
    device_map="auto",
    attn_implementation=None,  # Disable Flash Attention if not available
)
```

Loading without Flash Attention

```python
from llava.model.builder import load_pretrained_model

# Load a model without Flash Attention (useful if flash-attn is not installed)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="lmms-lab/llava-onevision-qwen2-0.5b-si",
    model_base=None,
    model_name="llava_qwen",
    device_map="auto",
    attn_implementation=None,  # Use standard attention
)
```

Loading a LoRA Adapter

```python
from llava.model.builder import load_pretrained_model

# Load a LoRA adapter together with its base model
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="/path/to/lora/adapter",
    model_base="/path/to/base/model",
    model_name="llava_llama_lora",
    device_map="auto",
)
```

Loading with Custom Configuration

```python
from llava.model.builder import load_pretrained_model
from llava.model.language_model.llava_qwen import LlavaQwenConfig

# Create a custom configuration
custom_config = LlavaQwenConfig.from_pretrained("lmms-lab/llava-onevision-qwen2-0.5b-si")
custom_config.mm_projector_type = "mlp2x_gelu"

# Load with the custom configuration
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="lmms-lab/llava-onevision-qwen2-0.5b-si",
    model_base=None,
    model_name="llava_qwen",
    customized_config=custom_config,
    device_map="auto",
)
```

Loading a Pure Language Model

```python
from llava.model.builder import load_pretrained_model

# Load a pure language model (not multimodal)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="meta-llama/Llama-2-7b-hf",
    model_base=None,
    model_name="llama",  # Must not contain "llava"
    device_map="auto",
)
# image_processor will be None for language models
```

Notes

  • Memory Usage: For large models, consider using load_4bit=True or load_8bit=True to reduce memory requirements. 4-bit quantization can reduce memory usage by approximately 75%.

  • Flash Attention: Flash Attention 2 requires compatible hardware (typically Ampere architecture or newer GPUs). For older GPUs (e.g., RTX 2080 Ti), set attn_implementation=None.

  • Model Detection: The function automatically detects the model architecture based on model_name. Ensure the model name contains appropriate keywords (e.g., "qwen", "llama", "llava") for correct loading.

  • Special Tokens: For LLaVA models, the function automatically adds special image tokens (<image>, <im_start>, <im_end>) to the tokenizer and resizes token embeddings accordingly.

  • Vision Tower Loading: For LLaVA models, the vision tower is loaded lazily. It will be loaded when first accessed, or can be explicitly loaded by calling model.get_vision_tower().load_model() (see the sketch after these notes).

  • Device Mapping: When using device_map="auto", the function uses Hugging Face’s accelerate library to automatically distribute the model across available devices. For single GPU setups, the model will be placed on that GPU.
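
As referenced in the Vision Tower Loading note above, explicit loading can look like this (a sketch; the `is_loaded` flag is how LLaVA's vision towers track lazy state, stated here as an assumption):

```python
# Explicitly load the lazily-initialized vision tower
vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:  # assumption: LLaVA vision towers expose is_loaded
    vision_tower.load_model()

# The image processor returned by load_pretrained_model comes from here
image_processor = vision_tower.image_processor
```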
