当前位置: 首页 > news >正文

【机器学习】用 TensorFlow 实现词向量训练全流程

【机器学习】用 TensorFlow 实现词向量训练全流程

简介

在自然语言处理(NLP)的世界里,词向量是连接离散文本与连续数值计算的桥梁。Word2Vec 作为经典的词向量生成模型,以其简洁高效的特点至今仍被广泛应用。本文将基于 TensorFlow,从数据预处理到模型训练、评估,从零实现一个完整的 Word2Vec 模型。

Word2Vec核心原理

Word2Vec 的核心思想是 “物以类聚,人以群分”—— 在文本中经常共同出现的词语具有相似的语义。它主要通过两种架构实现:

1.Skip-gram:以中心词为输入,预测其上下文词(本文采用此架构)
2.CBOW:以上下文词为输入,预测中心词

为解决大规模词汇表的计算效率问题,模型引入了负采样(Negative Sampling)技术:将多分类问题转化为二分类问题,每次训练仅采样少量负例词与正例词进行对比学习,大幅降低计算成本。

环境准备与依赖导入

首先导入所需要的库,涵盖数据处理、文件操作、数值计算和深度学习框架等工具:

import collections
import os
import random
import urllib
import zipfile
import numpy as np
import tensorflow as tf

参数配置

将参数分为训练参数、测试样例和模型核心参数三类:

# 训练参数
learning_rate = 0.1 # 学习率
batch_size = 128 # 每次训练迭代使用的样本数量
num_steps = 3000000 # 总的训练步数
display_step = 10000 # 每训练10000步显示一次训练信息
eval_step = 200000 # 每训练200000步进行一次模型评估
# 测试样例
eval_words = ['nine','of','going','hardware','american','britain']
# Word2vec参数
embedding_size = 200 # 词向量维度
max_vocabulary_size = 50000 # 语料库词语数
min_occurrence = 10 # 最小词频,低于此频率的词会被过滤掉
skip_window = 3 # 左右窗口大小
num_skips = 2 # 每次选取2个上下文词与当前词构成训练样本对
num_sampled = 64 # 负采样,每次训练时随机采样64个负例词来优化模型

参数调优小技巧:
1.词向量维度通常设置50-300,维度越高语义表达越丰富但计算成本越高
2.学习率可根据损失曲线调整,若损失波动大则减小学习率
3.负采样数量一般取5-20,64是针对大规模语料的优化值

数据预处理

数据预处理是模型训练的基础,主要包括数据集加载、词频统计、词汇表构建和训练样本四个步骤。

1.加载与清洗数据集

本研究中使用的是经典的text8数据集,包括预处理的维基百科文本,共约1700万词:

# 加载训练数据
data_path = '../text8.zip'
with zipfile.ZipFile(data_path) as f:text_word = f.read(f.namelist()[0]).lower().split()
print(len(text_word))

运行结果:17005207

2.词频统计与词汇表构建

过滤低频词并构建词汇表,同时将低频词统一归为“UNK”(未知词):

import collections
# 创建一个计数器,计算每个词出现了多少次
count = [('UNK',-1)]
# 基于词频返回max_vocabulary_size个常用词
count.extend(collections.Counter(text_word).most_common(max_vocabulary_size-1))
# 剔除掉出现次数少于'min_occurrence'的词
for i in range(len(count)-1,-1,-1):if count[i][1]<min_occurrence: # 从start到end每次step多少count.pop(i)else:# 判断时,从小到大排序的,所以跳出时候剩下的都是满足条件的break
# 计算语料库大小
vocabulary_size = len(count)
# 每个词都分配一个ID
word2id = dict()
for i,(word,_) in enumerate(count):word2id[word] = i
print(count[0:10])
print(word2id)

运行结果:
在这里插入图片描述
在这里插入图片描述

3.文本转ID序列

将原始文本词列表转换为模型可处理的ID序列,并统计UNK的实际数量:

data = list() # 用于存储转换后的单词ID序列
unk_count = 0 # 用于统计未知词(UNK)的出现次数
for word in text_word:# 全部转换成idindex = word2id.get(word,0)if index == 0:unk_count += 1data.append(index)
count[0] = ('UNK',unk_count) # 更新未知词的计数
id2word = dict(zip(word2id.values(),word2id.keys())) # 构建一个与word2id反向的字典,用于通过ID查找对应的单词
print('Words count:', len(text_word)) # 总单词数量
print('Unique words:',len(set(text_word))) # 去重后的单词数量
print('Vocabulary size:',vocabulary_size) # 词汇表大小
print('Most common words:',count[:10]) # 出现频率最高的前10个词

运行结果:
在这里插入图片描述

预处理效果验证标准:
1.UNK 占比不宜过高(本文约 2.6%),否则说明词汇表过滤过度
2.高频词应符合语言规律(如英语中的 “the”、“of” 等)

4.批次生成训练样本

实现next_batch函数,通过滑动窗口从ID序列中提取(中心词,上下文词)样本对:

data_index = 0
def next_batch(batch_size,num_skips,skip_window):global data_index # 全局变量,记录当前数据读取位置assert batch_size % num_skips == 0 # 确保batch_size能被num_skips整除assert num_skips <= 2*skip_window # 确保采样数不超过上下文总词数batch = np.ndarray(shape=(batch_size), dtype=np.int32) # 存储中心词ID的数组labels = np.ndarray(shape=(batch_size,1), dtype=np.int32) # 存储上文词ID的数组span = 2*skip_window + 1 # 7为窗口,左3右3中间1 buffer = collections.deque(maxlen=span) # 创建一个固定长度为7的队列if data_index + span > len(data): # 如果剩余数据不足一个窗口,从头开始data_index = 0buffer.extend(data[data_index:data_index+span]) # 填充初始窗口数据data_index += span # 更新数据读取位置for i in range(batch_size//num_skips): # 每个中心词生成2个样本# 上下文词的索引(排除中心词的位置)context_word = [w for w in range(span) if w!=skip_window]# 从上下文词中随机选择2个作为标签word_to_use = random.sample(context_word,num_skips)for j,context_word in enumerate(word_to_use): # 遍历每一个候选词,用其当作输出也就是标签# 批量存储中心词batch[i*num_skips+j] = buffer[skip_window]# 标签中存储对应的上下文词labels[i*num_skips+j,0]=buffer[context_word]if data_index == len(data): # 如果数据读完,从头补充窗口buffer.extend(data[0:span])data_index = spanelse:buffer.append(data[data_index]) # 窗口右移一个词data_index += 1# 调整数据索引(循环读取)data_index = (data_index+len(data)-span)%len(data)return batch,labels

样本生成逻辑示例:
1.窗口:[A, B, C, D, E, F, G](中心词为 D,skip_window=3)
2.上下文词:[A, B, C, E, F, G]
3.随机选择 2 个上下文词(如 B 和 E),生成样本(D,B)和(D,E)

模型构建

模型构建分为词嵌入层、损失函数、优化器和评估函数四个核心模块,全部指定在 CPU 上运行(词嵌入计算在 CPU 上更高效)。

1.词嵌入层:词语到向量的映射

词嵌入层是 Word2Vec 的核心,通过可训练的矩阵将离散 ID 转换为连续向量:

with tf.device('/cpu:0'):# 词嵌入矩阵:[词汇表大小, 词向量维度],正态分布初始化embedding = tf.Variable(tf.random.normal([vocabulary_size, embedding_size]))# NCE损失权重矩阵:与词嵌入矩阵形状一致nce_weights = tf.Variable(tf.random.normal([vocabulary_size, embedding_size]))# NCE损失偏置项:初始化为0nce_biases = tf.Variable(tf.zeros([vocabulary_size]))

2.词嵌入查询

实现get_embeddings函数,根据词 ID 查找对应的词向量:

def get_embeddings(x):with tf.device('/cpu:0'):# 从嵌入矩阵中查找x对应的词向量x_embed = tf.nn.embedding_lookup(embedding, x)return x_embed

3.NCE损失

使用负采样的 NCE 损失函数,解决大规模词汇表的计算问题:

NCE损失工作原理:
1.对每个正例(中心词 - 上下文词对)采样多个负例(随机词)
2.训练模型区分 “正例对” 和 “负例对”
3.无需计算所有词汇的概率,大幅提升效率

def nce_loss(x_embed,y):with tf.device('/cpu:0'):y = tf.cast(y, tf.int64)loss = tf.reduce_mean(tf.nn.nce_loss(weights=nce_weights, # NCE损失的权重矩阵biases=nce_biases, # NCE损失的偏置项inputs=x_embed, # 输入的词嵌入向量labels=y, # 真实标签num_sampled=num_sampled, # 采样的负例数量num_classes=vocabulary_size # 总类别数))return loss

4.优化器

使用随机梯度下降(SGD)优化器更新模型参数:

# 定义随机梯度下降(SGD)优化器,学习率为learning_rate
optimizer = tf.optimizers.SGD(learning_rate) 

5.模型评估

通过计算余弦相似度评估词向量质量,余弦相似度越高说明词语语义越相近:

def evaluate(x_embed):with tf.device('/cpu:0'):# 将输入词向量转换为float32类型,确保数值类型一致x_embed = tf.cast(x_embed, tf.float32)# 对输入词向量进行L2归一化# 归一化向量=原始向量/向量的模长(L2范数)x_embed_norm = x_embed / tf.sqrt(tf.reduce_sum(tf.square(x_embed)))# 对整个嵌入矩阵进行L2归一化embedding_norm = embedding/tf.sqrt(tf.reduce_sum(tf.square(embedding),1,keepdims=True),tf.float32)# 计算余弦相似度:归一化后的输入向量×归一化后的嵌入矩阵(转置)cosine_sim_op = tf.matmul(x_embed_norm, embedding_norm, transpose_b=True)return cosine_sim_op

模型训练与评估

训练过程包含参数优化、损失监控和语义评估三个核心环节,总迭代 300 万步。

1.优化函数

实现run_optimization函数,完成 “前向传播→反向传播→参数更新” 的闭环:

def run_optimization(x, y):with tf.device('/cpu:0'):# 梯度磁带:记录变量操作用于自动求导with tf.GradientTape() as g:# 前向传播:获取词嵌入并计算损失emb = get_embeddings(x)loss = nce_loss(emb, y)# 反向传播:计算损失对参数的梯度gradients = g.gradient(loss, [embedding, nce_weights, nce_biases])# 参数更新:应用梯度调整模型参数optimizer.apply_gradients(zip(gradients, [embedding, nce_weights, nce_biases]))

2.核心训练循环

# 待测试的几个词
x_test = np.array([word2id[w.encode('utf-8')] for w in eval_words])# 训练
for step in range(1, num_steps + 1):batch_x, batch_y = next_batch(batch_size, num_skips, skip_window)run_optimization(batch_x, batch_y)if step % display_step == 0 or step == 1:loss = nce_loss(get_embeddings(batch_x), batch_y)print("step: %i, loss: %f" % (step, loss))# Evaluation.if step % eval_step == 0 or step == 1:print("Evaluation...")sim = evaluate(get_embeddings(x_test)).numpy()for i in range(len(eval_words)):top_k = 8  # 返回前8个最相似的nearest = (-sim[i, :]).argsort()[1:top_k + 1]log_str = '"%s" nearest neighbors:' % eval_words[i]for k in range(top_k):log_str = '%s %s,' % (log_str, id2word[nearest[k]])print(log_str)

运行结果:
从最后的运行来看,模型呈现出 “稳定中微幅波动” 的典型收敛特征;
1.在295 万步达到最低损失 4.69,300 万步最终损失定格在 5.62,说明模型参数已基本达到最优状态,进一步增加训练步数对损失降低的增益有限;
2.从 270 万步到 300 万步的 30 万次迭代中,损失值未出现显著下降(最低值 4.69 与平均值 5.5 差距不足 1),说明模型已触及当前参数配置下的收敛边界。此时继续训练可能陷入 “无效迭代”,既无法降低损失,也难以提升语义表达能力,因此 300 万步是合理的终止节点;

1.数词类:精准捕捉数值逻辑关联:“nine” :b’eight’, b’seven’, b’six’, b’four’, b’five’, b’three’, b’one’, b’two’;
2.功能词类:捕捉高频共现特征:“of” 作为英语中最常用的介词之一,其相似词集中在高频功能词:b’and’, b’the’, b’including’, b’in’, b’modern’;
3.动词类:语义关联待进一步强化:“going” :b’put’, b’little’, b’so’, b’out’, b’long’;
4.技术名词类:行业属性聚类初现:“hardware”(硬件)的相似词在 300 万步有明显优化:b’program’, b’systems’, b’using’, b’computer’, b’source’;
5地域属性词:国家 / 地区语义链成型
“american”(美国的)和 “britain”(英国)的相似词呈现出清晰的地域语义关联:
“american” 的相似词从 “actor”“canadian” 优化为 “canadian”“english”“british”,均为国家 / 地区相关的形容词或名词,符合地域属性聚类逻辑;
“britain” 的相似词稳定在 “france”“germany”“europe” 等欧洲国家或地域名词,准确捕捉了 “英国” 的地理与文化关联。

在这里插入图片描述


文章转载自:

http://w9SZQncJ.yfmLj.cn
http://CIOCkANc.yfmLj.cn
http://xxMA5Qzt.yfmLj.cn
http://raTT5rXw.yfmLj.cn
http://siITwstv.yfmLj.cn
http://tqH6kMaF.yfmLj.cn
http://aesVMn19.yfmLj.cn
http://QX0mgzaf.yfmLj.cn
http://Ff8AENIG.yfmLj.cn
http://IPyLZxak.yfmLj.cn
http://U7UtAQ4C.yfmLj.cn
http://NbNNyCnS.yfmLj.cn
http://mMrNqGa2.yfmLj.cn
http://3F6TRSOS.yfmLj.cn
http://XcuWqFbG.yfmLj.cn
http://WQ85DCDU.yfmLj.cn
http://usRiSnAx.yfmLj.cn
http://tmQCFGLU.yfmLj.cn
http://vkSWqutt.yfmLj.cn
http://Li5zhgMo.yfmLj.cn
http://LGLz30p8.yfmLj.cn
http://flGn0ypO.yfmLj.cn
http://jZBn4F3U.yfmLj.cn
http://V0Wnm7BW.yfmLj.cn
http://COUZ77XZ.yfmLj.cn
http://ySYue04m.yfmLj.cn
http://j47f7akY.yfmLj.cn
http://cVargnhM.yfmLj.cn
http://awU7q9IX.yfmLj.cn
http://KagTNmb6.yfmLj.cn
http://www.dtcms.com/a/384821.html

相关文章:

  • C# --- 使用定时任务实现日志的定时聚合
  • Origin如何将格点色阶条进化为渐变色阶条
  • 非关系数据库(NoSQL):所需软件与环境配置全指南
  • 计算机网络1
  • 字幕编辑工具推荐,Subtitle Edit v4.0.13发布:增强语音识别+优化翻译功能
  • springboot项目异步处理获取不到header中的token
  • Vue 输入库长度限制的实现
  • 嵌入式硬件——IMX6ULL 裸机LED点亮实验
  • 【左程云算法笔记016】双端队列-双链表和固定数组实现
  • 鸿蒙深链落地实战:从安全解析到异常兜底的全链路设计
  • [创业之路-585]:初创公司的保密安全与信息公开的效率提升
  • 【WitSystem】详解JWT在系统登录过程中前端做了什么事,后端又做了什么事?
  • 力扣(LeetCode) ——217. 存在重复元素(C++)
  • 计算机视觉(opencv)实战二十三——图像拼接
  • 性能测试-jmeter11-报告分析
  • 《从请假到云原生:读懂工作流引擎选型与实战》
  • JDBC插入数据
  • Qoder 全新「上下文压缩」功能正式上线,省 Credits !
  • FPGA时序约束(五)--衍生时钟约束
  • 【C语言】第八课 输入输出与文件操作​​
  • 滤波器模块选型指南:关键参数与实用建议
  • 现有的双边拍卖机制——VCG和McAfee
  • Linux 系统、内核及 systemd 服务等相关知识
  • 企业级 Docker 应用:部署、仓库与安全加固
  • 倍福TwinCAT HMI如何关联PLC变量
  • 2025.9.25大模型学习
  • Java开发工具选择指南:Eclipse、NetBeans与IntelliJ IDEA对比
  • C++多线程编程:从基础到高级实践
  • JavaWeb 从入门到面试:Tomcat、Servlet、JSP、过滤器、监听器、分页与Ajax全面解析
  • Java 设计模式——分类及功能:从理论分类到实战场景映射