# 2025.03.03: tested OK
from safetensors.torch import load_file
import yaml
# Parse the image-model YAML configuration into a plain Python object.
# The encoding is given explicitly so that Chinese text in the file is decoded correctly.
with open("configs/maggie_image.yaml", 'r', encoding='utf8') as file:
    data = yaml.safe_load(file)
class Config:
    """Expose a mapping of settings as instance attributes.

    Every keyword argument becomes an attribute of the same name. Values are
    stored exactly as given; nested dictionaries are not converted.
    """

    def __init__(self, **kwargs):
        """Store each keyword argument as an attribute on the instance."""
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __repr__(self):
        """Return a ``Config(name=value, ...)`` string listing the stored attributes."""
        fields = ", ".join(f"{name}={value!r}" for name, value in vars(self).items())
        return f"{type(self).__name__}({fields})"
# Turn the parsed YAML dictionary into an attribute-style object.
config = Config(**data)

# Build the image model from the `model` entry of the configuration.
# Alternative kept from the original notes (disabled):
# image_model = MaGGIe.from_pretrained("chuonghm/maggie-image-him50k-cvpr24")
image_model = MaGGIe(cfg=config.model)

# Load the tensors from the local safetensors file and assign them to the
# model's state dict.
tensors = load_file("model_image/model.safetensors")
image_model.load_state_dict(tensors)

# Switch to evaluation mode, then move the model to the GPU.
image_model = image_model.eval().cuda()

# NOTE(review): CONFIG is not defined in the visible code — confirm where it
# comes from and whether this call belongs before the model is built.
CONFIG.merge_from_file("configs/maggie_image.yaml")