当前位置: 首页 > news >正文

【笔记】训练步骤代码解析

目录

config参数配置

setup_dirs创建训练文件夹

 load_data加载数据

build_model创建模型

train训练


记录一下训练代码中不理解的地方

config参数配置

config = {'data_root': r"D:\project\megnetometer\datasets\WISDM_ar_latest\organized_dataset",'train_dir': 'train','test_dir': 'test','seq_length': 300,  # 序列长度'batch_size': 32,  # 可能需减小batch_size'epochs': 60,'initial_lr': 3e-4,  # 初始学习率'max_lr': 5e-4,'patience': 20}

配置好需要用到的参数,比如数据集地址,训练轮数,批次大小,学习率等

setup_dirs创建训练文件夹

    def setup_dirs(self):self.run_dir = os.path.join(self.config['data_root'], 'run')  os.makedirs(self.run_dir, exist_ok=True)print('创建运行目录run_dir  = ', self.run_dir)# 创建带时间戳的实验目录timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")print('时间戳 = ', timestamp)self.exp_dir = os.path.join(self.run_dir, f"exp_{timestamp}")os.makedirs(self.exp_dir, exist_ok=True)# 保存当前配置with open(os.path.join(self.exp_dir, 'config.json'), 'w') as f:json.dump(self.config, f, indent=2)  # 两个字符缩进,没有则压缩成一行,把config内容存在config.json里

os.path.join(self.config['data_root'], 'run')  

用于拼接文件路径data_root的路径加上run,中间的连接符会根据系统自动调整

os.makedirs(self.exp_dir, exist_ok=True)

创建文件,exist_ok=True当文件夹存在的时候不报错

创建的文件夹用于存放后续训练生成的模型以及保存训练参数等文件

 load_data加载数据

    def load_data(self):"""从按行为分类的目录加载数据(带多级进度条)"""def load_activity_data(subset_dir):"""加载train或test子目录下的数据"""data = []subset_path = os.path.join(self.config['data_root'], subset_dir)  #在数据集路径内读取,由subset_dir决定读取的是训练集还是测试集# 获取所有活动类别目录activities = [d for d in os.listdir(subset_path)if os.path.isdir(os.path.join(subset_path, d))]#print('activities=',activities)#activities= ['Downstairs', 'Jogging', 'Sitting', 'Standing', 'Upstairs', 'Walking']# 第一层进度条:活动类别pbar_activities = tqdm(activities, desc=f"扫描{subset_dir}目录", position=0)for activity in pbar_activities:activity_lower = activity.lower()if activity_lower not in self.label_map:continueactivity_dir = os.path.join(subset_path, activity)#当前活动的目录# 获取所有用户文件user_files = [f for f in os.listdir(activity_dir)if f.endswith('.txt')]#获取所有txt结尾的文件# 第二层进度条:用户文件#pbar_users = tqdm(user_files, desc="读取用户文件", leave=False, position=1)#后面要close,但是已经把所有的进度注释掉了只留下来一个总的第一层进度#print('pbar_users=',pbar_users)for user_file in user_files:file_path = os.path.join(activity_dir, user_file)# 获取文件行数用于进度条with open(file_path, 'r') as f:num_lines = sum(1 for _ in f)# 第三层进度条:读取文件内容with open(file_path, 'r') as f:for line in f:line = line.strip()if not line:continuetry:x, y, z = map(float, line.split(','))data.append({'x': x,'y': y,'z': z,'activity': activity_lower})except ValueError:continuepbar_activities.close()return data# 调用示例print("\n" + "=" * 50)print("开始加载数据集...")train_data = load_activity_data(self.config['train_dir'])#print(train_data)#{'x': 5.33, 'y': 8.73, 'z': -0.42, 'activity': 'walking'},test_data = load_activity_data(self.config['test_dir'])
pbar_activities = tqdm(activities, desc=f"扫描{subset_dir}目录", position=0)

tqdm创建进度条,desc是进度条前面的描述,position用于多级进度条之间的嵌套,以免位置混乱,在运行完之后要关闭进度条

pbar_activities.close()
with open('data.txt', 'r') as f:打开文件夹,r为只读模式

# 转换为模型输入格式(带优化进度条)def create_sequences(data, desc="生成序列"):seq_length = self.config['seq_length']features, labels = [], []total_windows = len(data) - seq_lengthpbar = tqdm(range(total_windows),desc=desc,position=0,bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [速度:{rate_fmt}]")for i in pbar:window = data[i:i + seq_length]# 检查窗口内活动是否一致if len(set(d['activity'] for d in window)) != 1:continuefeatures.append([[d['x'], d['y'], d['z']] for d in window])labels.append(self.label_map[window[0]['activity']])# 每1000次更新一次进度信息if i % 1000 == 0:pbar.set_postfix({"有效窗口": len(features),"跳过窗口": i - len(features) + 1}, refresh=True)return np.array(features), np.array(labels)print("\n正在预处理训练集...")X_train, y_train = create_sequences(train_data, "训练集序列化")#返回的x是数据,y是标签print("\n正在预处理测试集...")X_test, y_test = create_sequences(test_data, "测试集序列化")# 标准化(显示进度)print("\n正在计算标准化参数...")self.mean = np.mean(X_train, axis=(0, 1))self.std = np.std(X_train, axis=(0, 1))print("应用标准化...")X_train = (X_train - self.mean) / (self.std + 1e-8)X_test = (X_test - self.mean) / (self.std + 1e-8)# One-hot编码# 将 NumPy 数组转为 PyTorch 张量,并指定类型为 int64(等价于 .long())y_train = torch.from_numpy(y_train).long()  # 或 .to(torch.int64)y_train = torch.nn.functional.one_hot(y_train.long(), num_classes=len(self.label_map))y_test = torch.from_numpy(y_test).long()  # 或 .to(torch.int64)y_test = torch.nn.functional.one_hot(y_test.long(), num_classes=len(self.label_map))print("\n" + "=" * 50)print("数据预处理完成!")print(f"训练集形状: X_train{X_train.shape}, y_train{y_train.shape}")print(f"测试集形状: X_test{X_test.shape}, y_test{y_test.shape}")print("=" * 50 + "\n")return (X_train, y_train), (X_test, y_test)

滑动窗口开销大,改用向量化滑动窗口(NumPy)

参数标准化全部使用训练集数据

1e-8的作用:防止除零的小常数,特别适用于某些标准差接近0的特征

axis=(0,1):假设您的数据是3D张量(样本×时间步/空间×特征),这样计算每个特征通道的统计量

消除量纲影响:当特征的单位/量纲不同时(如年龄0-100 vs 工资0-100000),标准化使所有特征具有可比性

只使用训练集统计量:测试集必须使用训练集的mean/std,这是为了避免数据泄露(data leakage)

数据泄露:是机器学习中一个常见但严重的问题,指在模型训练过程中意外地使用了测试集或未来数据的信息,导致模型评估结果被高估,无法反映真实性能。这种现象会使模型在实际应用中表现远差于预期。

将分类标签(整数形式)转换为 One-hot 编码,这是机器学习中处理分类任务的常见方法。

build_model创建模型

    def build_model(self):"""构建改进的BiLSTM分类模型"""model = tf.keras.Sequential([tf.keras.layers.InputLayer(input_shape=(self.config['seq_length'], 3)),# 双向LSTM层tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64, return_sequences=True)),tf.keras.layers.BatchNormalization(),tf.keras.layers.Dropout(0.2),tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),tf.keras.layers.BatchNormalization(),# 全连接层tf.keras.layers.Dense(32, activation='relu'),tf.keras.layers.Dropout(0.3),tf.keras.layers.Dense(len(self.label_map), activation='softmax')])model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),loss='categorical_crossentropy',metrics=['accuracy'])return model

两个模型框架TensorFlow更早,但PyTorch的初始设计更现代,以上是TensorFlow的模型。

计算图(Computational Graph) 是描述数学运算和数据处理流程的抽象结构,而 静态图动态图 是两种不同的计算图构建和执行方式。

计算图 是一个有向无环图(DAG),用于表示计算过程:

  • 节点(Node):代表运算(如加法、矩阵乘法)或数据(如张量、变量)。

  • 边(Edge):描述数据流动方向(如张量从一层传递到下一层)。

改用PyTorch模型需要注意

PyTorch更推荐类式构建,而且保存时仅保存模型的参数(权重和偏置),不包含模型结构。如果需要测试,加载时必须先实例化一个结构完全相同的模型,再加载参数。

先创建一个模型类,再去调用 

class BiLSTMModel(nn.Module):def __init__(self, input_size, hidden_size, num_layers, num_classes, bidirectional=True):super(BiLSTMModel, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layersself.num_directions = 2 if bidirectional else 1# 双向LSTMself.lstm = nn.LSTM(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,batch_first=True,bidirectional=bidirectional)# 全连接层(双向时hidden_size需*2)self.fc = nn.Linear(hidden_size * self.num_directions, num_classes)def forward(self, x):# 初始化隐藏状态(可选,PyTorch默认全零)h0 = torch.zeros(self.num_layers * self.num_directions, x.size(0), self.hidden_size).to(x.device)c0 = torch.zeros(self.num_layers * self.num_directions, x.size(0), self.hidden_size).to(x.device)# LSTM前向传播out, _ = self.lstm(x, (h0, c0))  # out形状: (batch, seq_len, hidden_size * num_directions)# 取最后一个时间步的输出out = out[:, -1, :]  # 形状: (batch, hidden_size * num_directions)# 分类层out = self.fc(out)return out

此处构建的就是双向LSTM模型,然后再构建函数调用

    def build_model(self):# 使用示例model = LSTMModel(input_size=3,  # 对应x/y/z特征hidden_size=32,num_layers=2,num_classes=6,  # 类别数bidirectional=True)return model

train训练

    def train(self):"""PyTorch版本训练流程"""# 1. 数据加载与预处理(X_train, y_train), (X_test, y_test) = self.load_data()# 转换为PyTorch张量并移至设备device = torch.device("cuda" if torch.cuda.is_available() else "cpu")X_train = torch.FloatTensor(X_train).to(device)y_train = torch.LongTensor(y_train.argmax(axis=1)).to(device)  # 如果y是one-hotX_test = torch.FloatTensor(X_test).to(device)y_test = torch.LongTensor(y_test.argmax(axis=1)).to(device)# 创建DataLoadertrain_dataset = TensorDataset(X_train, y_train)# 类似zip(features, labels)train_loader = DataLoader(train_dataset,batch_size=self.config['batch_size'],shuffle=True)# 2. 模型初始化self.model = self.build_model().to(device)criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(self.model.parameters(),lr=self.config.get('lr', 0.001))# 3. 回调函数设置"""# 早停early_stopping = EarlyStopping(patience=self.config['patience'],verbose=True,path=os.path.join(self.exp_dir, 'best_model.pth'))"""# 学习率调度scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='min',factor=0.1,patience=5,verbose=True)# TensorBoard日志writer = SummaryWriter(log_dir=os.path.join(self.exp_dir, 'logs'))print("\n开始训练...")print(f"实验目录: {self.exp_dir}")print(f"使用设备: {device}")# 4. 训练循环for epoch in range(self.config['epochs']):self.model.train()train_loss = 0.0# 训练批次for inputs, labels in train_loader:optimizer.zero_grad()outputs = self.model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()train_loss += loss.item()# 验证阶段self.model.eval()with torch.no_grad():test_outputs = self.model(X_test)test_loss = criterion(test_outputs, y_test)_, predicted = torch.max(test_outputs, 1)accuracy = (predicted == y_test).float().mean()# 记录日志writer.add_scalar('Loss/train', train_loss / len(train_loader), epoch)writer.add_scalar('Loss/test', test_loss.item(), epoch)writer.add_scalar('Accuracy/test', accuracy.item(), epoch)# 打印进度print(f"Epoch {epoch + 1}/{self.config['epochs']} | "f"Train Loss: {train_loss / len(train_loader):.4f} | "f"Test Loss: {test_loss.item():.4f} | "f"Accuracy: {accuracy.item():.4f}")# 学习率调整scheduler.step(test_loss)"""# 早停检查early_stopping(test_loss, self.model)if early_stopping.early_stop:print("Early stopping triggered")break"""# 5. 保存最终结果writer.close()self.save_results(X_test, y_test)  # 需要适配PyTorch的保存方法

TensorDatasetDataLoader 都是 PyTorch 官方库中的核心组件,专门用于高效的数据加载和批处理。

torch.utils.data.TensorDataset将多个张量(如特征张量和标签张量)打包成一个数据集对象

dataset = TensorDataset(features, labels)  # 类似zip(features, labels)

torch.utils.data.DataLoader将数据集按批次加载,支持自动批处理、打乱数据、多进程加载等

shuffle=True代表打乱数据,此处是时序信号,但是由于从长序列中通过滑动窗口提取样本每个窗口本身就是一个独立样本,此时打乱窗口顺序是安全的

损失函数 criterion = nn.CrossEntropyLoss()

http://www.dtcms.com/a/275030.html

相关文章:

  • docker安装Consul笔记
  • Java(7.11 设计模式学习)
  • PLC框架-1.3- 汇川PN伺服(3号报文)
  • 多种人脸处理方案——人脸裁剪
  • Webview 中可用的 VS Code 方法
  • G1 垃圾回收算法详解
  • 【TCP/IP】16. 简单网络管理协议
  • 天晟科技携手万表平台,共同推动RWA项目发展
  • 从「小公司人事」到「HRBP」:选对工具,比转岗更能解决成长焦虑
  • Java大厂面试故事:谢飞机的互联网音视频场景技术面试全纪录(Spring Boot、MyBatis、Kafka、Redis、AI等)
  • kubernetes单机部署踩坑笔记
  • DIDCTF-蓝帽杯
  • 谷歌云代理商:谷歌云TPU/GPU如何加速您的AI模型训练和推理
  • 【数据结构与算法】206.反转链表(LeetCode)
  • C++:非类型模板参数,模板特化以及模板的分离编译
  • 实现将文本数据(input_text)转换为input_embeddings的操作
  • 《从依赖纠缠到接口协作:ASP.NET Core注入式开发指南》
  • Vue 表单开发优化实践:如何优雅地合并 `data()` 与 `resetForm()` 中的重复对象
  • Sigma-Aldrich 细胞培养实验方案 | 通过Hoechst DNA染色检测细胞的支原体污染
  • 拔高原理篇
  • 奇哥面试记:SpringBoot整合RabbitMQ与高级特性,一不小心吊打面试官
  • java底层的native和沙箱安全机制
  • Lecture #19 : Multi-Version Concurrency Control
  • 深入理解JVM的垃圾收集(GC)机制
  • Next知识框架、SSR、SSG和ISR知识框架梳理
  • c++——运算符的重载
  • 鸿蒙开发之ArkTS常量与变量的命名规则
  • 面向对象编程
  • [面试] 手写题-选择排序
  • 持有对象-泛型和类型安全的容器