当前位置: 首页 > news >正文

生成任务,大模型

一个生成项目

输入:文字描述(但是给的数据集是一串数字,id,ct描述,医生描述)
输出:诊断报告

一、数据处理

import pandas as pd  #处理表格数据

pre_train_file= "data/train.csv"

train_df = pd.read_csv(pre_train_file,header=None,names=["id","input","tgt"]) #读入数据


print(train_df.head())




train_data = train_df.sample(frac=0.9, random_state=0, axis=0)   #采样0.9的比例

val_data = train_df[~train_df.index.isin(train_data.index)]       #干啥的,  过来用

train_data.to_csv("data/pro_train_data.csv", index=False,header=False)

val_data.to_csv("data/pro_val_data.csv", index=False,header=False)

主要是用于从一个CSV文件中读取数据,并将其划分为训练集和验证集,然后将这两个数据集分别保存到新的CSV文件中。

代码逐行解释

导入必要的库
import pandas as pd  # 处理表格数据
  • pandas:一个强大的数据分析和处理库,特别适合处理表格数据(如CSV文件)。
定义文件路径并读取数据
pre_train_file = "data/train.csv"

train_df = pd.read_csv(pre_train_file, header=None, names=["id", "input", "tgt"])  # 读入数据

print(train_df.head())
  • pre_train_file:指定要读取的CSV文件路径。
  • pd.read_csv
    • header=None:表示CSV文件没有表头(第一行不是列名)。
    • names=["id", "input", "tgt"]:为每一列指定名称。
  • print(train_df.head()):打印前五行数据,以便检查读取是否正确。
数据划分
train_data = train_df.sample(frac=0.9, random_state=0, axis=0)  # 采样0.9的比例

val_data = train_df[~train_df.index.isin(train_data.index)]  # 干啥的, 过来用
  • train_data

    • 使用 sample 方法随机采样90%的数据作为训练集。
    • frac=0.9:表示采样的比例为90%。
    • random_state=0:设置随机种子以确保结果可重复。
    • axis=0:表示沿行方向进行采样(默认行为)。
  • val_data

    • 使用 ~train_df.index.isin(train_data.index) 来获取不在训练集中的数据作为验证集。
    • isin(train_data.index) 返回一个布尔数组,指示哪些索引在训练集中。
    • ~ 取反操作符,返回不在训练集中的索引。
保存数据
train_data.to_csv("data/pro_train_data.csv", index=False, header=False)

val_data.to_csv("data/pro_val_data.csv", index=False, header=False)
  • to_csv 方法
    • 将DataFrame保存为CSV文件。
    • index=False:不保存行索引。
    • header=False:不保存列名。

二、处理词表

import sys
import torch
from collections import Counter
from transformers import BertTokenizer
from transformers import BartConfig
from transformers import BartForConditionalGeneration
from model_utils.config import parse_args

args = parse_args()         #设置 ,字典, 属性类  config  {}

def load_data(path):
    with open(path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    datas = []
    for line in lines:
        line = line.strip().split(",")
        if len(line) == 3:
            # 训练集
            text, target = line[1].split(" "), line[2].split(" ")
            datas.append(text + target)
        else:
            text = line[1].split(" ")
            datas.append(text)
    return datas


train_data = load_data('./data/train.csv')

token2count = Counter()     #计数工具 哈希表

for i in train_data:
    token2count.update(i)       #不需要知道原理


tail = []
ct = 0
for k, v in token2count.items():
    if v >= ct:
        tail.append(k)
tail.sort()
vocab = tail

vocab.insert(0,"[PAD]")
vocab.insert(100,"[UNK]")
vocab.insert(101,"[CLS]")
vocab.insert(102,"[SEP]")
vocab.insert(103,"[MASK]")
vocab.insert(104,"[EOS]")
# tokenizer = BertTokenizer.from_pretrained(args.pre_model_path)
# vocabs = tokenizer.get_vocab()   #获取模型词表

# new_vocabs = list(vocabs.keys())
# print(len(vocabs))
# count = 0
# for v in vocab:         #mn复杂度
#     if v not in vocabs:
#         count += 1
#         new_vocabs.append(v)
# print(len(new_vocabs))
new_vocabs = vocab
with open(args.pre_model_path+'/vocab.txt', 'w', encoding='utf-8') as f:
    for v in new_vocabs:
        f.write(f"{v}\n")    #保存


model = BartForConditionalGeneration.from_pretrained(args.pre_model_path)      #模型
model.resize_token_embeddings(len(new_vocabs))
state_dict = model.state_dict()
torch.save(state_dict, args.pre_model_path+'/pytorch_model.bin')
bartconfig = BartConfig.from_pretrained(args.pre_model_path)
bartconfig.vocab_size = len(new_vocabs)
bartconfig.save_pretrained(args.pre_model_path)

1. 导入必要的库

import sys
import torch
from collections import Counter
from transformers import BertTokenizer
from transformers import BartConfig
from transformers import BartForConditionalGeneration
from model_utils.config import parse_args
  • sys:用于系统相关的操作(如命令行参数)。
  • torch:PyTorch的核心库,用于深度学习模型。
  • Counter:来自 collections 模块,用于统计元素出现的次数。
  • BertTokenizer, BartConfig, BartForConditionalGeneration:来自 transformers 库,分别用于分词、配置和加载预训练模型。
  • parse_args:自定义函数,用于解析命令行参数或配置文件,返回一个包含配置参数的对象。

2. 解析参数

args = parse_args()  # 设置,字典,属性类 config {}
  • parse_args:调用自定义函数解析配置参数,并将其存储在 args 对象中。假设 args 包含诸如 pre_model_path 等路径信息。

3. 定义数据加载函数

def load_data(path):
    with open(path, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    datas = []
    for line in lines:
        line = line.strip().split(",")
        if len(line) == 3:
            # 训练集
            text, target = line[1].split(" "), line[2].split(" ")
            datas.append(text + target)
        else:
            text = line[1].split(" ")
            datas.append(text)
    return datas
  • load_data 函数
    • 打开指定路径的文件并读取每一行。
    • 使用 strip() 去除每行的前后空白字符,并使用 split(",") 将其按逗号分割为列表。
    • 如果列表长度为3(假设是训练集),则将第二列和第三列的数据拆分为单词列表,并合并后添加到 datas 列表中。
    • 如果列表长度不为3,则仅处理第二列的数据,并将其拆分为单词列表后添加到 datas 列表中。
    • 返回 datas 列表。

4. 加载数据

train_data = load_data('./data/train.csv')
  • 调用 load_data 函数加载训练数据,并将结果存储在 train_data 变量中。

5. 统计词频

token2count = Counter()  # 计数工具 哈希表

for i in train_data:
    token2count.update(i)  # 不需要知道原理
  • token2count:使用 Counter 类创建一个哈希表来统计每个单词出现的次数。
  • 遍历 train_data 中的每一行数据,并使用 update 方法更新 token2count,记录每个单词出现的次数。

6. 创建词汇表

tail = []
ct = 0
for k, v in token2count.items():
    if v >= ct:
        tail.append(k)
tail.sort()
vocab = tail

vocab.insert(0, "[PAD]")
vocab.insert(100, "[UNK]")
vocab.insert(101, "[CLS]")
vocab.insert(102, "[SEP]")
vocab.insert(103, "[MASK]")
vocab.insert(104, "[EOS]")
  • tail:筛选出频率大于等于 ct 的单词,并按字母顺序排序。注意这里 ct 设为0,因此所有单词都会被包含进来。
  • vocab:将 tail 赋值给 vocab
  • 插入特殊标记:在 vocab 中插入一些特殊的标记符号(如 [PAD], [UNK], [CLS], [SEP], [MASK], [EOS]),这些标记在自然语言处理任务中具有特定含义。

7. 保存词汇表

new_vocabs = vocab
with open(args.pre_model_path + '/vocab.txt', 'w', encoding='utf-8') as f:
    for v in new_vocabs:
        f.write(f"{v}\n")  # 保存
  • new_vocabs:直接赋值为 vocab
  • 保存词汇表:将词汇表中的每个单词写入 vocab.txt 文件中,文件路径由 args.pre_model_path 指定。

8. 加载预训练模型并调整词汇表大小

model = BartForConditionalGeneration.from_pretrained(args.pre_model_path)  # 模型
model.resize_token_embeddings(len(new_vocabs))
state_dict = model.state_dict()
torch.save(state_dict, args.pre_model_path + '/pytorch_model.bin')

bartconfig = BartConfig.from_pretrained(args.pre_model_path)
bartconfig.vocab_size = len(new_vocabs)
bartconfig.save_pretrained(args.pre_model_path)
  • 加载预训练模型:使用 BartForConditionalGeneration.from_pretrained 加载预训练模型。
  • 调整词汇表大小:使用 resize_token_embeddings 方法调整模型的嵌入层大小以适应新的词汇表。
  • 保存模型状态:将模型的状态字典保存到 pytorch_model.bin 文件中,文件路径由 args.pre_model_path 指定。
  • 更新配置:更新 BartConfig 中的 vocab_size 属性,并保存配置。

三、自监督预训练

from model_utils.pre_data import PreTrainDataset, loadData, MLM_Data
from torch.utils.data import DataLoader, Dataset
from model_utils.models import preModel
import logging        #日志
import os
from model_utils.config import parse_args
from model_utils.utils import setup_device, setup_seed, setup_logging, build_optimizer
import torch
import time
# os.environ['CUDA_VISIBLE_DEVICES']='0'

def train_and_validate(args):
    # 1. load data  model
    model = preModel(args)     #加载预训练模型
    optimizer, scheduler = build_optimizer(args, model)
    # model = model.to(args.device)
    use_pre = False

    if use_pre:
        checkpoint = torch.load(args.pre_file, map_location='cpu')
        new_KEY = model.load_state_dict(checkpoint['model_state_dict'],strict=False)
    if args.device == 'cuda':
        if args.paral == True:
            model = torch.nn.parallel.DataParallel(model.to(args.device))
        else:
            model = model.to(args.device)
        # model = BalancedDataParallel(16, model, dim=0).to(args.device)
    # model = model.to(args.device)
    #-------ema here-----------------

    all_data = loadData(args.data_path)
    train_MLM_data = MLM_Data(all_data, args)

    train_dataloader = DataLoader(train_MLM_data, batch_size=args.batch_size, shuffle=True,collate_fn=train_MLM_data.collate)
    step = 0
    start_time = time.time()

    num_total_steps = len(train_dataloader) * args.max_epochs

    for epoch in range(args.max_epochs):    #开始训练了
        for batch in train_dataloader:
            model.train()
            loss= model(batch)
            loss = loss.mean()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

            step += 1
            if step % args.print_steps == 0:
                time_per_step = (time.time() - start_time) / max(1, step)
                remaining_time = time_per_step * (num_total_steps - step)
                remaining_time = time.strftime('%H:%M:%S', time.gmtime(remaining_time))
                logging.info(f"Epoch {epoch} step {step} eta {remaining_time}: loss {loss:.3f}")

        logging.info(f"VAL_Epoch {epoch} step {step}: loss {loss:.3f}")
        if epoch % 5 == 0:
            torch.save({'epoch': epoch, 'model_state_dict': model.module.state_dict()},
                       f'{args.savedmodel_path}/lr{args.learning_rate}epoch{epoch}loss{loss:.3f}pre_model.bin')

def main():
    args = parse_args()           #设置   字典
    setup_logging()
    setup_device(args)
    setup_seed(args)
    os.makedirs(args.savedmodel_path, exist_ok=True)
    logging.info("Training/evaluation parameters: %s", args)         #LINUX
    train_and_validate(args)





if __name__ == '__main__':
    main()

实现了一个完整的训练和验证流程,包括数据加载、模型初始化、训练循环、日志记录以及模型保存等功能

1. 导入必要的库

from model_utils.pre_data import PreTrainDataset, loadData, MLM_Data
from torch.utils.data import DataLoader, Dataset
from model_utils.models import preModel
import logging        # 日志
import os
from model_utils.config import parse_args
from model_utils.utils import setup_device, setup_seed, setup_logging, build_optimizer
import torch
import time
  • PreTrainDataset, loadData, MLM_Data:自定义模块,用于数据处理。
  • DataLoader, Dataset:PyTorch提供的类,用于数据加载和管理。
  • preModel:自定义模型类。
  • logging:用于记录日志信息。
  • os:用于操作系统相关的操作(如文件路径处理)。
  • parse_args:自定义函数,解析命令行参数或配置文件。
  • setup_device, setup_seed, setup_logging, build_optimizer:自定义工具函数,分别用于设置设备、随机种子、日志记录和优化器构建。
  • torch:PyTorch核心库。
  • time:用于时间相关操作。

2. 定义训练和验证函数

def train_and_validate(args):
    # 1. 加载数据和模型
    model = preModel(args)     # 加载预训练模型
    optimizer, scheduler = build_optimizer(args, model)
    
    use_pre = False

    if use_pre:
        checkpoint = torch.load(args.pre_file, map_location='cpu')
        new_KEY = model.load_state_dict(checkpoint['model_state_dict'], strict=False)

    if args.device == 'cuda':
        if args.paral == True:
            model = torch.nn.parallel.DataParallel(model.to(args.device))
        else:
            model = model.to(args.device)

    all_data = loadData(args.data_path)
    train_MLM_data = MLM_Data(all_data, args)

    train_dataloader = DataLoader(train_MLM_data, batch_size=args.batch_size, shuffle=True, collate_fn=train_MLM_data.collate)
    step = 0
    start_time = time.time()

    num_total_steps = len(train_dataloader) * args.max_epochs

    for epoch in range(args.max_epochs):    # 开始训练了
        for batch in train_dataloader:
            model.train()
            loss = model(batch)
            loss = loss.mean()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

            step += 1
            if step % args.print_steps == 0:
                time_per_step = (time.time() - start_time) / max(1, step)
                remaining_time = time_per_step * (num_total_steps - step)
                remaining_time = time.strftime('%H:%M:%S', time.gmtime(remaining_time))
                logging.info(f"Epoch {epoch} step {step} eta {remaining_time}: loss {loss:.3f}")

        logging.info(f"VAL_Epoch {epoch} step {step}: loss {loss:.3f}")
        if epoch % 5 == 0:
            torch.save({'epoch': epoch, 'model_state_dict': model.module.state_dict()},
                       f'{args.savedmodel_path}/lr{args.learning_rate}epoch{epoch}loss{loss:.3f}pre_model.bin')
解释
  • 加载数据和模型

    • 使用 preModel 类加载预训练模型。
    • 使用 build_optimizer 函数构建优化器和学习率调度器。
    • 如果 use_pre 为真,则从指定路径加载预训练模型的权重。
    • 根据 args.deviceargs.paral 参数决定是否使用多GPU并行训练。
  • 数据加载

    • 使用 loadData 函数加载所有数据。
    • 使用 MLM_Data 类将数据转换为适合训练的数据集格式。
    • 使用 DataLoader 创建数据加载器,支持批量加载和数据打乱。
  • 训练循环

    • 对每个epoch进行遍历。
    • 对每个batch进行前向传播计算损失,反向传播更新权重。
    • 记录训练进度和剩余时间,并在特定步数时打印日志。
    • 每隔5个epoch保存一次模型。

3. 主函数

def main():
    args = parse_args()           # 设置   字典
    setup_logging()
    setup_device(args)
    setup_seed(args)
    os.makedirs(args.savedmodel_path, exist_ok=True)
    logging.info("Training/evaluation parameters: %s", args)         # LINUX
    train_and_validate(args)

if __name__ == '__main__':
    main()
  • main 函数
    • 调用 parse_args 解析命令行参数。
    • 调用 setup_logging 配置日志记录。
    • 调用 setup_devicesetup_seed 分别设置设备和随机种子。
    • 创建保存模型的目录(如果不存在)。
    • 打印训练和评估参数。
    • 调用 train_and_validate 函数开始训练和验证过程。

四、微调

import logging
import os
import time
import torch
from transformers import PretrainedBartModel
from model_utils.config import parse_args
from model_utils.data import create_dataloaders
from model_utils.models import myModel
from model_utils.score import CiderD, CE
from model_utils.utils import setup_device, setup_seed, setup_logging, build_optimizer,array2str
from torch.cuda.amp import autocast as ac
from tqdm import tqdm as tqdm

os.environ['CUDA_VISIBLE_DEVICES']='0'


# 不需要完全理解,  知道每一块在做什么就行   知道之后,  以后再用到, 搬过去就行


def validate(model, loader, args, output_file=None, beam=1, n=-1):
    res, gts = [], {}
    tot = 0
    for (source, targets) in tqdm(loader):
        if n>0 and tot>n:
            break
        source = source.cuda()
        pred = model(source[:, :args. input_l])
        pred = pred.cpu().detach().numpy()
        #print(pred.shape)
        for i in range(pred.shape[0]):
            # res.append({'image_id':tot, 'caption': [array2str(pred[i][2:], args)]})
            # gts[tot] = [array2str(targets[i][1:], args)]
            res.append({'image_id':tot, 'caption': [array2str(pred[i], args)]})
            gts[tot] = [array2str(targets[i][1:], args)]
            tot += 1
    CiderD_scorer = CiderD(df='corpus', sigma=15)
    cider_score, cider_scores = CiderD_scorer.compute_score(gts, res)
    return cider_score



def train_and_validate(args):
    # 1. load data
    train_dataloader, val_dataloader = create_dataloaders(args)
    model = myModel(args)
    use_pre = True
    if use_pre:
        print('use_pre')
        checkpoint = torch.load(args.my_pre_model_path, map_location='cpu')

        new_KEY = model.load_state_dict(checkpoint['model_state_dict'],strict=True)
    optimizer, scheduler = build_optimizer(args, model)
    model = model.to(args.device)
    #-------ema here-----------------


    model.train()
    #-------------------------------
    # loss, results = validate(model, val_dataloader)
    # 3. training
    step = 0
    best_score = args.best_score     #评估指标  准确率

    for epoch in range(args.max_epochs):
        for (source, targets) in tqdm(train_dataloader):
            source = source.cuda()
            targets = targets.cuda()
            model.train()
            pred = model(source[:, :args. input_l], targets[:, :args.output_l])
            loss  = CE(pred[:, :-1], targets[:, 1:])
            loss = loss.mean()
            loss.backward()
            optimizer.step()
            model.zero_grad()
            scheduler.step()
            step += 1

        if epoch % 1 == 0:
            cider_score = validate(model, val_dataloader, args)
            logging.info(f"Epoch {epoch} step {step}: loss {loss:.3f}, cider_score {cider_score}")
            if cider_score >= best_score:
                best_score = cider_score
                torch.save({'epoch': epoch, 'model_state_dict': model.state_dict()},
                        f'{args.savedmodel_path}/model_epoch_{epoch}_cider_score_{cider_score}.bin')



def main():
    args = parse_args()
    setup_logging()
    setup_device(args)
    setup_seed(args)
    os.makedirs(args.savedmodel_path, exist_ok=True)
    logging.info("Training/evaluation parameters: %s", args)
    train_and_validate(args)


if __name__ == '__main__':
    main()

实现了一个完整的训练和验证流程,包括数据加载、模型初始化、训练循环、验证评估以及模型保存等功能。

1. 导入必要的库

import logging
import os
import time
import torch
from transformers import PretrainedBartModel
from model_utils.config import parse_args
from model_utils.data import create_dataloaders
from model_utils.models import myModel
from model_utils.score import CiderD, CE
from model_utils.utils import setup_device, setup_seed, setup_logging, build_optimizer, array2str
from torch.cuda.amp import autocast as ac
from tqdm import tqdm as tqdm

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
  • logging:用于记录日志信息。
  • os:用于操作系统相关的操作(如文件路径处理)。
  • time:用于时间相关操作。
  • torch:PyTorch核心库。
  • PretrainedBartModel:来自 transformers 库的预训练模型基类。
  • parse_args:自定义函数,解析命令行参数或配置文件。
  • create_dataloaders:自定义函数,创建数据加载器。
  • myModel:自定义模型类。
  • CiderD, CE:自定义评分函数,分别用于计算CIDEr-D分数和交叉熵损失。
  • setup_device, setup_seed, setup_logging, build_optimizer, array2str:自定义工具函数,分别用于设置设备、随机种子、日志记录、构建优化器和数组转字符串。
  • autocast:用于混合精度训练。
  • tqdm:用于显示进度条。

2. 定义验证函数

def validate(model, loader, args, output_file=None, beam=1, n=-1):
    res, gts = [], {}
    tot = 0
    for (source, targets) in tqdm(loader):
        if n > 0 and tot > n:
            break
        source = source.cuda()
        pred = model(source[:, :args.input_l])
        pred = pred.cpu().detach().numpy()
        for i in range(pred.shape[0]):
            res.append({'image_id': tot, 'caption': [array2str(pred[i], args)]})
            gts[tot] = [array2str(targets[i][1:], args)]
            tot += 1
    CiderD_scorer = CiderD(df='corpus', sigma=15)
    cider_score, cider_scores = CiderD_scorer.compute_score(gts, res)
    return cider_score
解释
  • 输入参数

    • model: 需要验证的模型。
    • loader: 数据加载器。
    • args: 命令行参数或配置对象。
    • output_file: 输出文件路径(可选)。
    • beam: 束搜索宽度(可选,默认为1)。
    • n: 验证样本数限制(可选,默认为-1,表示不限制)。
  • 逻辑

    • 初始化结果列表 res 和真实标签字典 gts
    • 使用 tqdm 显示进度条遍历数据加载器中的每个批次 (source, targets)
    • source 移动到 GPU 并进行前向传播得到预测结果 pred
    • 将预测结果和真实标签转换为字符串格式并添加到 resgts 中。
    • 使用 CiderD 计算预测结果与真实标签之间的 CIDEr-D 分数。
    • 返回 CIDEr-D 分数。

3. 定义训练和验证函数

def train_and_validate(args):
    # 1. load data
    train_dataloader, val_dataloader = create_dataloaders(args)
    model = myModel(args)
    use_pre = True
    if use_pre:
        print('use_pre')
        checkpoint = torch.load(args.my_pre_model_path, map_location='cpu')
        new_KEY = model.load_state_dict(checkpoint['model_state_dict'], strict=True)
    optimizer, scheduler = build_optimizer(args, model)
    model = model.to(args.device)

    model.train()
    step = 0
    best_score = args.best_score  # 评估指标 准确率

    for epoch in range(args.max_epochs):
        for (source, targets) in tqdm(train_dataloader):
            source = source.cuda()
            targets = targets.cuda()
            model.train()
            pred = model(source[:, :args.input_l], targets[:, :args.output_l])
            loss = CE(pred[:, :-1], targets[:, 1:])
            loss = loss.mean()
            loss.backward()
            optimizer.step()
            model.zero_grad()
            scheduler.step()
            step += 1

        if epoch % 1 == 0:
            cider_score = validate(model, val_dataloader, args)
            logging.info(f"Epoch {epoch} step {step}: loss {loss:.3f}, cider_score {cider_score}")
            if cider_score >= best_score:
                best_score = cider_score
                torch.save({'epoch': epoch, 'model_state_dict': model.state_dict()},
                           f'{args.savedmodel_path}/model_epoch_{epoch}_cider_score_{cider_score}.bin')
解释
  • 加载数据

    • 使用 create_dataloaders 函数加载训练和验证数据加载器。
  • 初始化模型和优化器

    • 使用 myModel 类加载模型。
    • 如果 use_pre 为真,则从指定路径加载预训练模型的权重。
    • 使用 build_optimizer 函数构建优化器和学习率调度器。
    • 将模型移动到指定设备(CPU或GPU)。
  • 训练循环

    • 对每个epoch进行遍历。
    • 对每个batch进行前向传播计算损失,反向传播更新权重。
    • 每个epoch结束后调用 validate 函数计算验证集上的 CIDEr-D 分数。
    • 如果当前 CIDEr-D 分数优于历史最佳分数,则保存模型。

4. 主函数

def main():
    args = parse_args()  # 设置   字典
    setup_logging()
    setup_device(args)
    setup_seed(args)
    os.makedirs(args.savedmodel_path, exist_ok=True)
    logging.info("Training/evaluation parameters: %s", args)  # LINUX
    train_and_validate(args)


if __name__ == '__main__':
    main()
解释
  • 主函数
    • 调用 parse_args 解析命令行参数。
    • 调用 setup_logging 配置日志记录。
    • 调用 setup_devicesetup_seed 分别设置设备和随机种子。
    • 创建保存模型的目录(如果不存在)。
    • 打印训练和评估参数。
    • 调用 train_and_validate 函数开始训练和验证过程。

五、inference

from tqdm import tqdm
import csv
from model_utils.utils import to_device, array2str
from model_utils.models import myModel
from model_utils.data import create_dataloaders
import torch
from model_utils.config import parse_args







def inference(args):
    test_loader = create_dataloaders(args,test=True)
    model = myModel(args)
    print(args.ckpt_file)

    checkpoint = torch.load(args.ckpt_file, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'],strict=False)
    model.to('cuda:0')
    model.eval()

    fp = open(args.test_output_csv, 'w', newline='')
    writer = csv.writer(fp)
    tot = 0
    for source in tqdm(test_loader):
        source = to_device(source, 'cuda:0')
        pred = model(source)
        pred = pred.cpu().numpy()
        for i in range(pred.shape[0]):
            writer.writerow([tot, array2str(pred[i][2:], args)])
            tot += 1
    fp.close()

if __name__ == '__main__':
    args = parse_args()
    inference(args)

实现了一个推理(inference)流程,包括数据加载、模型加载、前向传播以及结果保存等功能。

1. 导入必要的库

from tqdm import tqdm
import csv
from model_utils.utils import to_device, array2str
from model_utils.models import myModel
from model_utils.data import create_dataloaders
import torch
from model_utils.config import parse_args
  • tqdm:用于显示进度条。
  • csv:用于处理CSV文件的读写操作。
  • to_device:自定义函数,将数据移动到指定设备(CPU或GPU)。
  • array2str:自定义函数,将数组转换为字符串。
  • myModel:自定义模型类。
  • create_dataloaders:自定义函数,创建数据加载器。
  • torch:PyTorch核心库。
  • parse_args:自定义函数,解析命令行参数或配置文件。

2. 定义推理函数

def inference(args):
    test_loader = create_dataloaders(args, test=True)
    model = myModel(args)
    print(args.ckpt_file)

    checkpoint = torch.load(args.ckpt_file, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    model.to('cuda:0')
    model.eval()

    fp = open(args.test_output_csv, 'w', newline='')
    writer = csv.writer(fp)
    tot = 0
    for source in tqdm(test_loader):
        source = to_device(source, 'cuda:0')
        pred = model(source)
        pred = pred.cpu().numpy()
        for i in range(pred.shape[0]):
            writer.writerow([tot, array2str(pred[i][2:], args)])
            tot += 1
    fp.close()
解释
  • 加载测试数据

    • 使用 create_dataloaders 函数加载测试数据加载器,设置 test=True 表示加载测试集。
  • 初始化模型并加载权重

    • 使用 myModel 类加载模型。
    • 打印预训练模型路径 args.ckpt_file
    • 使用 torch.load 加载预训练模型的权重,并使用 load_state_dict 方法加载到模型中。
    • 将模型移动到 GPU(cuda:0),并设置为评估模式(model.eval())。
  • 推理过程

    • 打开输出 CSV 文件,并创建 CSV 写入器。
    • 使用 tqdm 显示进度条遍历测试数据加载器中的每个批次 source
    • source 移动到 GPU 并进行前向传播得到预测结果 pred
    • 将预测结果转换为 NumPy 数组,并逐个样本写入 CSV 文件。

3. 主函数

if __name__ == '__main__':
    args = parse_args()
    inference(args)
  • 主函数
    • 调用 parse_args 解析命令行参数。
    • 调用 inference 函数开始推理过程。

相关文章:

  • GHCTF2025--Web
  • Nginx完全指南:从入门到精通(基于Ubuntu系统)
  • MySQL入门手册
  • Vite 打包后Nginx部署配置
  • 二叉树计算
  • _二级继电器程控放大倍数自动设置
  • WWW 2025 | 时间序列(Time Series)论文总结
  • 【计算机网络】深入解析 HTTP 中的 GET 方法、POST 方法和 GET 和 POST 的区别
  • SpringCloud——LoadBalancer负载均衡服务调用
  • Docker入门篇1:搜索镜像、拉取镜像、查看本地镜像列表、删除本地镜像
  • 第13章 安全加固OSI的第8层(网络安全防御实战--蓝军武器库)
  • k倍区间 | 哈希 分巧克力 | 二分 青蛙跳杯子 | BFS
  • Lab18_ SQL injection with filter bypass via XML encoding
  • Codeforces Round 566 (Div. 2) E. Product Oriented Recurrence 矩阵加速、欧拉降幂
  • 通过Nacos API实现微服务不间断部署
  • 从传统到智能:Node-red工控机助力农业大棚高效监控
  • 【Python】Django 中的算法应用与实现
  • Android Configuration相关问题如何定位分析(中英文切换、黑夜白天模式等)
  • 【GPU】什么是 NVLink?
  • 4G铁路工控机在高铁信号控制中的关键作用
  • 中方是否担忧美国主权信用评级下调?外交部:美国应采取负责任的政策措施
  • 广东茂名高州市山体滑坡已致3死1失联,搜救仍在继续
  • 南宁海关辟谣网传“查获600公斤稀土材料”:实为焊锡膏
  • 解读|战国子弹库帛书漂泊海外79年今归国,追索仍将继续
  • 第十一届世界雷达展开幕,尖端装备、“大国重器”集中亮相
  • 高新波任西安电子科技大学校长