当前位置: 首页 > news >正文

有趣的机器学习-利用神经网络来模拟“古龙”写作风格的输出器

前言

在探索大模型落地应用的旅程中,我们常常聚焦于其解决严肃商业问题的解决方案、策略,如:优化客服、生成报告、辅助决策……这些固然重要,但技术的魅力远不止于此。有时,跳出“实用主义”的框架,用一点“玩心”去触碰技术的边界,反而能更深刻地理解其内核。

今天,我们就来做这样一次有趣的“侧向探索”。我们将暂时放下繁重的企业用例,一同踏入一个充满武侠与诗意的创作领域——让机器学习模型学习古龙先生的独特文风。

古龙的笔触,以短句、留白、悬念和极具画面感的意境著称,这与现代神经网络处理序列数据的模式形成了奇妙的映照。本文将为你完整呈现,如何一步步构建一个“古龙风”文本生成器。从数据集的采集与清洗,到模型结构的设计与调优,再到最终那带着几分“江湖气”的文字输出,我们将亲历整个过程。

这不仅是一个妙趣横生的技术玩具,更是一个绝佳的微型案例。它能让你直观地理解神经网络如何“学习”一种抽象的风格,其中涉及的数据质量、模型局限性和调参技巧,正是所有大模型落地项目中都会遇到的、最真实的挑战与经验。希望这次轻松的实践,能为你接下来的严肃项目,带来新的灵感和启发。

1. 万能模仿大师:神经网络如何习得“风格”之魂

在我们开启构建“古龙风”写作模型的奇妙之旅前,一个根本性问题需要被清晰地解答:神经网络,这个看似神秘的计算模型,究竟凭借何种能力,可以捕捉并复现像文学风格这般抽象且精妙的事物?

简而言之,您可以将神经网络视为一位不知疲倦、感知敏锐的模仿者。它无需理解“侠义”的哲学内涵或“孤独”的情感深度,它的核心能力在于,通过对海量数据进行统计分析,识别出其中深藏的、构成特定模式的复杂规律。我们的任务,便是将这位“模仿者”培养成专精于古龙文体的特型演员。

1.1 解构智能:从微观决策到宏观涌现

神经网络的能力并非凭空产生,它源于其独特的、受生物学启发的基础架构。

1.1.1 基石:作为基本决策单元的神经元

神经网络的核心组件是“神经元”,它本质上是一个进行加权决策的简单处理器

  • 信息输入:它接收来自其他神经元的多个输入信号(例如,信号A代表“上文是否出现短句?”,信号B代表“是否提及自然意象?”,信号C代表“语境基调是否凝练?”)。
  • 权重评估:神经元并非平等看待所有输入,它会为每一个信号分配一个权重,这代表了该信号在本次决策中的重要程度。
  • 综合判断与激活:神经元将所有输入信号与对应的权重相乘后求和,得到一个总刺激值。随后,它会根据一个预设的激活函数来判断是否被“激活”。如果总刺激值超过某个阈值,神经元便会输出一个信号至下一层;反之,则保持静默。

由此可见,单个神经元的功能极为基础和局限,仅能完成一次简单的非线性变换。

1.1.2 架构之力:深度层级结构与抽象理解

单一神经元能力有限,但当它们以特定方式大规模互联时,便能涌现出强大的智能。

  • 输入层:负责接收并分发原始数据,如一个文本序列。
  • 隐藏层:这是网络的核心计算区域,通常由多层神经元构成。每一层都对前一层的输出进行加工,从中提取出更加复杂和抽象的特征。浅层网络可能识别词汇与基础语法,而更深层的网络则可能逐步捕捉到句式结构、修辞手法乃至整体的行文节奏与意境。
  • 输出层:汇聚所有中间结果,生成最终的输出,例如预测下一个字符的概率分布。

这种“深度”结构,使得网络能够进行特征的分层提取与抽象,从具体的数据中逐步构建出对整体风格的宏观把握。

1.2 学习的本质:基于概率的持续优化

神经网络的学习过程,可以精辟地概括为一个在高维空间中进行持续自我优化的概率游戏。

让我们以学习古龙先生的短句风格为例:

  • 目标:训练网络在接收到输入“秋风。落叶。”后,能够生成一个风格契合的下文。

  • 学习机制的闭环流程

    1. 前向传播:将输入数据送入网络。网络基于其当前内部参数(权重),经过层层计算,最终输出一个预测结果。在训练初期,它可能会给出一个平淡无奇且不符合语境的词,如“很多”。

    2. 损失计算:我们将网络的预测“很多”与训练数据中的真实下文(例如,“人独立。”)进行比对。通过一个称为损失函数的数学工具,精确计算出当前预测的“错误程度”。

    3. 反向传播:这是学习过程的关键。损失值被从输出层向输入层反向传递。在此过程中,网络会利用梯度下降算法,精确计算出每一个权重参数对最终误差应负的“责任”大小。

    4. 参数更新:网络根据反向传播计算出的梯度,对其所有权重进行微调。那些导致错误预测的权重被降低,而那些有助于正确预测的权重则被增强。

  • 一个简明的实例: 假设网络在学习过程中遇到短语“一剑刺出,快如”。负责关联“快如”与“流星”的连接权重初始较低,而关联“快如”与“奔马”的权重较高。网络可能首次输出“奔马”。但训练数据表明,在古龙先生的语料库中,“流星”是更高频且更具风格特色的搭配。通过反向传播,网络会系统性地提升“流星”路径的权重,同时降低“奔马”路径的权重。经过大量此类调整,网络便能内化这种语言偏好。

通过在海量数据上无数次重复这一“预测-比较-调整”的循环,神经网络最终将其内部的数百万乃至数十亿个参数,调整至一个能够极其精确地反映训练数据统计特性的状态。它并非“懂得”了风格,而是将风格数字化为一系列复杂的、高维的概率映射关系

至此,我们已经为这位“模仿大师”奠定了坚实的理论基础。接下来,我们将进入激动人心的实战环节,探讨如何将古龙先生笔下那个充满意境的武侠世界,转化为可供其学习和吸收的数字化养料

2. 铸造灵魂之书:将古龙风格转化为机器可读的数据

我们已经知道,神经网络是一位通过数据来学习的“模仿大师”。那么,要想让它学会古龙的文风,我们首先必须为它准备一份精炼的“教材”——一份能完美体现古龙三大核心特色的高质量数据集。这个过程,就如同一位铸剑师在锻造前精心挑选和提炼铁矿,材料的纯度,直接决定了未来宝剑的锋芒。

我们的核心目标,是将古龙笔下那些充满意境与张力的文字,转化为神经网络能够理解的数字序列。但在此之前,我们需要对“教材”的内容进行严格的界定与梳理。

2.1 解构文风:定义机器学习的三大支柱

在投喂数据之前,我们必须让机器明确知道要学什么。因此,我们先将古龙文风的精髓,拆解为三个可被量化和捕捉的维度:

2.1.1 节奏之魂:短句与断句

古龙的文字节奏感极强,他惜墨如金,用最少的字词营造出最强的画面感和悬念。我们的数据集将极力捕捉这种“碎片化”的叙事节奏。

  • 核心特征:句子平均长度短,大量使用句号分割意象,形成类似电影分镜头的效果。
  • 实例样本
    • “风在吹。灯在动。”
    • “他站着。不动。”
    • “刀光一闪。血已溅出。” 这种结构强迫神经网络去学习如何在极简的语境中,建立词与词、句与句之间的强烈关联。
2.1.2 对话之骨:电报体与冷峻哲思

古龙笔下的对话极具辨识度,省略了所有不必要的寒暄与修饰,直抵核心,常常在简洁的问答中蕴含机锋与悬念。

  • 核心特征:对话轮次简洁,问答直接,信息密度高,常带有哲学意味或未言明的潜台词。
  • 实例样本
    • “你怕了?”“不怕。”
    • “谁是敌人?”“看不见的,才是敌人。”
    • “为何出手?”“因为他该杀。”
2.1.3 意境之韵:散文诗化与韵律

这是古龙文风中最为精妙的部分,他将散文的意境与诗歌的韵律融入小说,大量使用排比、设问和对仗,使语言本身成为一种审美对象。

  • 核心特征:语言富有音乐性,结构工整,通过重复和提问来渲染气氛和深化主题。
  • 实例样本
    • “月圆之夜,紫禁之巅,一剑西来,天外飞仙。”
    • “歌者的歌,舞者的舞,剑客的剑,只要不死,就不能停。”
    • “路本是同样的路,只看你怎样去走。”

2.2 数据炼金术:从原始文本到训练样本

拥有了清晰的标准后,我们便可以开始“数据炼金”的过程。这绝非简单的复制粘贴,而是一个精细的提纯与重构流程。

  1. 原料采集:我们从古龙先生的多部代表作中,系统地摘录符合上述三大特征的句子、对话片段和段落。这个过程需要人工筛选,确保送入“熔炉”的都是最精华的部分。

  2. 清洗与标准化:去除文本中的版权信息、章节标题等非内容性文字,确保数据集的纯粹性。

  3. 构造训练对:这是让神经网络学会“创作”的关键一步。我们将完整的句子或段落进行切割,构造出“输入-输出”对。

    • 例如,我们可以将“月圆之夜,紫禁之巅”作为输入,而将“一剑西来,天外飞仙”作为期望的输出。更常见的做法是使用“滑动窗口”,例如,将“黄昏。街上。”作为输入,将“一个人。”作为目标。通过无数个这样的“上下文-下一个词”配对,网络便能逐渐领悟到古龙的用词习惯和行文逻辑。

通过这一系列严谨的数据准备工作,我们不再是向神经网络杂乱无章地投喂文本,而是为它提供了一份结构化的、目标明确的风格词典

至此,一份蕴含着古龙文风灵魂的“教材”已经编纂完毕。这位“模仿大师”即将开始它最重要的课程。在下一部分中,我们将见证它是如何“消化”这本词典,并最终从它的“笔”下,流淌出带有那份独特意蕴的文字。

3. 构建样本数据

为了演示一个完整的神经网络我们会使用LSTM,对于以下三种风格进行核心训练:

  1. 短句为主(多用单句、断句,避免冗长修饰)

  2. 电报体对话(简洁、冷峻、带哲思或悬念)

  3. 散文诗化(韵律感、排比、设问、意境营造)

3.1 为什么要使用 LSTM

  • 优点:能精确捕捉“短句断开”的节奏(如“。”后换意象)、标点使用习惯、重复字词(如“秋天来晚了”多次出现)。

  • 缺点:训练慢、需要更多数据——但你的样本虽少,风格高度一致,字符级反而更利于模仿韵律。

  • 若用词级,需自定义分词(如将“秋天来晚了”视为一个词),否则“秋天/来/晚/了”会丢失整体节奏。

3.2 样本数据

我们根据古龙的原著以及风格摘抄甚至生成了一些标准的符合古龙风格的样本数据。

# 古龙风格的样本数据
samples = ["秋,如一杯冷酒。人,如一场旧梦。","剑气,寒如霜。剑心,冷如冰。剑魂,静如水。","江湖,一场未完的梦。人生,一杯未尽的酒。","风,带走了尘埃。雨,带不走记忆。","夜深了。人静了。唯有那杯酒,还在说话。","刀,不在鞘中,在心中。路,不在脚下,在眼前。","雨,不曾停。泪,不曾干。恨,不曾消。","刀在鞘中。人在风里。血在手上。恩怨在心里。","秋天来晚了,落叶飘零处,剑影如霜寒。","他,一人一剑一壶酒,走过万水千山。","云散了。风停了。人走了。只有那把刀,还在等待。","明月,几时有?江湖,几时休?","酒,解千愁。剑,断万仇。却解不开,那心中的结。","天,很高。地,很远。人,很小。心,很大。","秋,带走了夏的热情。留下的,只有冷清。","风,无形。剑,无情。人,无常。心,无悔。","雪,纷纷扬扬。人,来来往往。唯有那杯酒,始终如一。","夜,很长。路,很远。灯,很暗。心,很亮。","花,开过。人,走过。留下的,只有那把剑和那壶酒。","秋天来晚了,人已不在,情已成空。","山,不语。水,不言。人,不动。心,不静。","他来了,如风。他走了,如梦。留下的,只有那把刀和那滴血。","雨,下个不停。泪,流个不休。恨,却永远不会说出口。","雨,洗不尽血。","孤灯下,剑光寒。","谁在笑?笑里藏刀。","酒,越喝越清醒。人,越走越迷茫。","寒鸦啼,夜未央。","月,照大江。风,吹故人。一杯酒,一把剑,一生情。","秋,不语。人,不言。唯有那把剑,在风中轻轻叹息。","他,一人一剑一壶酒,走过了千山万水。却走不出,自己的心。","血,比水冷。情,比火热。","梦,碎在黎明前。","酒入喉,泪入心。","人未语,刀先鸣。"," 空,胜万语。","刀未冷,人已远。","影,随形而生。心,随情而动。","青衫湿,不是雨。","虚,胜实。","长夜尽,酒未干。","人,一生几何?恨,一世几多?","秋天来晚了,江湖已老,故人何在?","痛,说不出口。","笑,不露齿。杀,不见血。","你走后,风也停。","断桥边,雪落无声。","雨打灯,灯不灭。","光,照不进心底。","心若死,剑亦锈。","恨,比酒烈。","月照影,影随心。","路,走不完。情,断不了。","黄沙起,故人归?","剑出时,天地寂。","命,不由人。","心,静如水。剑,快如风。","旧巷深,新愁生。","醉,不在酒。痴,不在情。","命,如纸薄。心,似铁硬。","恨,无声。爱,无形。","死,不可怕。怕的是白活。","风卷残,云不留。","生,不带来。死,不带去。","雾,吞尽前路。心,照见归途。","残阳血,染旧袍。","谁在哭?哭声藏刃。","雪落刃,不化。恨入骨,不言。","孤舟横,不系岸。","情难收,剑先收。","夜无边,酒有底。","风吹旗,旗不动。","你问仇?仇在酒中。","灯将熄,影愈长。","刀无眼,人有泪。","落叶响,非风过。","心若焚,面如冰。","马停处,非故乡。","旧剑鸣,为新恨。","空杯响,胜千言。","雨打石,石不语。","生死间,一念隔。","花未开,已成冢。","风不止,誓不休。","血,干在袖中。","寒星坠,照孤城。","你拔剑,我闭眼。","命如烟,握不住。","影成三,心独行。","火熄后,灰更烫。","路无碑,人自铭。","酒未尽,人先散。","云遮日,日犹在。","生如刃,死如鞘。","雾,漫过无名冢。","残阳下,孤影拖长街。","谁拔剑,斩断旧日誓?","血,滴在未寄信上。","寒鸦叫,惊破三更梦。","你转身,带走半世光。","心若死,何惧刀锋冷?","断桥雪,埋尽相逢路。","酒未冷,人已隔天涯。","风不止,吹散故人名。","空山里,回声胜言语。","旧刀锈,仍能饮仇血。","命如草,却向天而生。","夜雨急,打湿归人衣。","笑一声,藏尽平生痛。","黄沙起,不见来时路。","灯将灭,照见少年骨。","花落时,无人问归期。","恨入髓,面上却含春。","马蹄远,踏碎月如霜。","死,不过一场长眠。","孤舟横,不系红尘岸。","你问我,可曾怕过死?","影成双,心各在一方。","火燃尽,余温烫旧忆。","路尽头,唯有风相送。","泪无声,浸透十年衣。","云遮月,月照杀人刀。","生如寄,死亦非归处。","剑未鸣,敌已跪尘埃。",
]

有了样本数据我们就要开始训练了。

4. 训练神经网络学习到古龙行文的风格

4.1 训练方法论

4.1.1 数据增强

对原始样本进行增强,通过随机删除词和交换相邻词生成更多训练数据

对原始样本进行数据增强(通过随机删除词和交换相邻词)的主要目的有以下几点:

  1. 增加训练数据量:神经网络模型通常需要大量数据才能学习到有效的模式。我们的古龙风格样本数量有限(只有几十个句子),通过数据增强可以显著增加训练样本数量,而不需要手动创建更多样本。

  2. 提高模型泛化能力:通过对原始样本进行轻微变化(删除某个词或交换词序),可以帮助模型学习到更加鲁棒的语言表示,减少对特定词序的过度依赖,提高在新句子上的泛化能力

  3. 防止过拟合:小数据集容易导致模型过拟合,即模型只是"记住"了训练样本而不是真正学习到语言模式。数据增强通过引入变体,可以降低过拟合风险。

  4. 增强语义理解:通过词序变化和词删除,模型需要学习更深层次的语义关系,而不仅仅是表面的词序模式。

  5. 模拟语言变化:古龙风格的语言本身就有多变性,有时词序调整或省略某些词并不会改变整体风格和含义,这种增强可以模拟这种语言灵活性。

在代码中,数据增强是有选择性的:只对较长的句子(长度>6)进行增强,并且只对更长的句子(长度>8)进行删除操作。这是因为较短的句子删除或交换词可能会显著改变句子含义,而较长的句子有更多冗余,更能承受这种变化。这种数据增强技术是自然语言处理中常用的方法,特别适用于数据有限的情况,如我们这个古龙风格文本生成任务。

4.1.2 使用2层LSTM网络

这两层LSTM是垂直堆叠的,具体工作方式如下:

  1. 第一层LSTM:接收词嵌入(embedding)作为输入,处理序列数据,并输出隐藏状态序列。

  2. 第二层LSTM:接收第一层LSTM的输出作为输入,进一步处理并提取更高级的特征。

增加LSTM层数的主要原因是:

1. 增强模型复杂度:单层LSTM可能无法捕捉古龙风格文本中的复杂语言模式。增加到两层可以让模型学习更复杂的语言结构和长期依赖关系。

2. 层次化特征提取:

  • 第一层LSTM通常学习基本的语法和词序模式

  • 第二层LSTM能够学习更高级的语义和风格特征

3. 改善梯度流动:在深度学习中,适当增加网络深度可以改善信息和梯度的流动,使模型更容易学习复杂模式。

4. 提高生成文本质量:两层LSTM通常能生成更连贯、更符合特定风格的文本,因为它能捕捉更多层次的语言特征。

在实践中,增加到两层LSTM通常是一个很好的平衡点 - 一层可能不够复杂,而三层或更多层可能会导致训练困难或过拟合,特别是对于我们这种相对较小的数据集。

4.1.3 温度(temperature)设定(如 0.1~1.5)
  • 低温使模型更“接近输入的原意”,倾向于复现训练数据中的高频短句结构(如“X在Y中。Z在W里。”),避免胡编长句。

  • 高温会破坏“电报体”的冷峻感,产生啰嗦或不合逻辑的修饰。

4.1.4 模型保存和加载
  • 保存训练好的模型和词汇表

  • 加载已有模型继续使用

4.1.5 防止梯度爆炸和过度拟合
  • 添加Dropout防止过拟合

  • 梯度裁剪防止梯度爆炸

防止过拟合:注意我们在两层之间添加了dropout(dropout=0.2),这有助于防止模型过拟合,特别是在我们的训练数据有限的情况下。

5. 全代码

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import random
import jieba
import os
import json
import time
import argparse
import platform
from datetime import datetime# 检测并设置设备 - 支持Mac的MPS加速
if torch.backends.mps.is_available() and platform.system() == 'Darwin':device = torch.device("mps")print("使用 Apple MPS (Metal Performance Shaders) 加速")
elif torch.cuda.is_available():device = torch.device("cuda")print("使用 CUDA GPU 加速")
else:device = torch.device("cpu")print("使用 CPU 运行")# 古龙风格的样本数据
samples = ["秋,如一杯冷酒。人,如一场旧梦。","剑气,寒如霜。剑心,冷如冰。剑魂,静如水。","江湖,一场未完的梦。人生,一杯未尽的酒。","风,带走了尘埃。雨,带不走记忆。","夜深了。人静了。唯有那杯酒,还在说话。","刀,不在鞘中,在心中。路,不在脚下,在眼前。","雨,不曾停。泪,不曾干。恨,不曾消。","刀在鞘中。人在风里。血在手上。恩怨在心里。","秋天来晚了,落叶飘零处,剑影如霜寒。","他,一人一剑一壶酒,走过万水千山。","云散了。风停了。人走了。只有那把刀,还在等待。","明月,几时有?江湖,几时休?","酒,解千愁。剑,断万仇。却解不开,那心中的结。","天,很高。地,很远。人,很小。心,很大。","秋,带走了夏的热情。留下的,只有冷清。","风,无形。剑,无情。人,无常。心,无悔。","雪,纷纷扬扬。人,来来往往。唯有那杯酒,始终如一。","夜,很长。路,很远。灯,很暗。心,很亮。","花,开过。人,走过。留下的,只有那把剑和那壶酒。","秋天来晚了,人已不在,情已成空。","山,不语。水,不言。人,不动。心,不静。","他来了,如风。他走了,如梦。留下的,只有那把刀和那滴血。","雨,下个不停。泪,流个不休。恨,却永远不会说出口。","雨,洗不尽血。","孤灯下,剑光寒。","谁在笑?笑里藏刀。","酒,越喝越清醒。人,越走越迷茫。","寒鸦啼,夜未央。","月,照大江。风,吹故人。一杯酒,一把剑,一生情。","秋,不语。人,不言。唯有那把剑,在风中轻轻叹息。","他,一人一剑一壶酒,走过了千山万水。却走不出,自己的心。","血,比水冷。情,比火热。","梦,碎在黎明前。","酒入喉,泪入心。","人未语,刀先鸣。"," 空,胜万语。","刀未冷,人已远。","影,随形而生。心,随情而动。","青衫湿,不是雨。","虚,胜实。","长夜尽,酒未干。","人,一生几何?恨,一世几多?","秋天来晚了,江湖已老,故人何在?","痛,说不出口。","笑,不露齿。杀,不见血。","你走后,风也停。","断桥边,雪落无声。","雨打灯,灯不灭。","光,照不进心底。","心若死,剑亦锈。","恨,比酒烈。","月照影,影随心。","路,走不完。情,断不了。","黄沙起,故人归?","剑出时,天地寂。","命,不由人。","心,静如水。剑,快如风。","旧巷深,新愁生。","醉,不在酒。痴,不在情。","命,如纸薄。心,似铁硬。","恨,无声。爱,无形。","死,不可怕。怕的是白活。","风卷残,云不留。","生,不带来。死,不带去。","雾,吞尽前路。心,照见归途。","残阳血,染旧袍。","谁在哭?哭声藏刃。","雪落刃,不化。恨入骨,不言。","孤舟横,不系岸。","情难收,剑先收。","夜无边,酒有底。","风吹旗,旗不动。","你问仇?仇在酒中。","灯将熄,影愈长。","刀无眼,人有泪。","落叶响,非风过。","心若焚,面如冰。","马停处,非故乡。","旧剑鸣,为新恨。","空杯响,胜千言。","雨打石,石不语。","生死间,一念隔。","花未开,已成冢。","风不止,誓不休。","血,干在袖中。","寒星坠,照孤城。","你拔剑,我闭眼。","命如烟,握不住。","影成三,心独行。","火熄后,灰更烫。","路无碑,人自铭。","酒未尽,人先散。","云遮日,日犹在。","生如刃,死如鞘。","雾,漫过无名冢。","残阳下,孤影拖长街。","谁拔剑,斩断旧日誓?","血,滴在未寄信上。","寒鸦叫,惊破三更梦。","你转身,带走半世光。","心若死,何惧刀锋冷?","断桥雪,埋尽相逢路。","酒未冷,人已隔天涯。","风不止,吹散故人名。","空山里,回声胜言语。","旧刀锈,仍能饮仇血。","命如草,却向天而生。","夜雨急,打湿归人衣。","笑一声,藏尽平生痛。","黄沙起,不见来时路。","灯将灭,照见少年骨。","花落时,无人问归期。","恨入髓,面上却含春。","马蹄远,踏碎月如霜。","死,不过一场长眠。","孤舟横,不系红尘岸。","你问我,可曾怕过死?","影成双,心各在一方。","火燃尽,余温烫旧忆。","路尽头,唯有风相送。","泪无声,浸透十年衣。","云遮月,月照杀人刀。","生如寄,死亦非归处。","剑未鸣,敌已跪尘埃。",
]# 数据预处理
def preprocess_data(samples, seq_length=5):# 分词tokenized_samples = []for sample in samples:words = list(jieba.cut(sample))tokenized_samples.append(words)# 数据增强:创建更多样本augmented_samples = []for words in tokenized_samples:if len(words) > 6:  # 只对较长的句子进行增强# 1. 随机删除一些词if len(words) > 8:delete_idx = random.randint(0, len(words)-1)new_words = words.copy()new_words.pop(delete_idx)augmented_samples.append(new_words)# 2. 随机交换相邻词if len(words) > 4:swap_idx = random.randint(0, len(words)-2)new_words = words.copy()new_words[swap_idx], new_words[swap_idx+1] = new_words[swap_idx+1], new_words[swap_idx]augmented_samples.append(new_words)# 合并原始和增强样本all_samples = tokenized_samples + augmented_samples# 构建词汇表vocab = set()for words in all_samples:vocab.update(words)# 词到索引的映射word_to_idx = {word: i for i, word in enumerate(vocab)}idx_to_word = {i: word for i, word in enumerate(vocab)}vocab_size = len(vocab)print(f"词汇表大小: {vocab_size}")print(f"训练样本数量: 原始={len(tokenized_samples)}, 增强后={len(all_samples)}")# 准备训练数据X = []y = []for words in all_samples:indices = [word_to_idx[word] for word in words]for i in range(len(indices) - seq_length):X.append(indices[i:i+seq_length])y.append(indices[i+seq_length])print(f"训练序列数量: {len(X)}")X = torch.tensor(X, dtype=torch.long)y = torch.tensor(y, dtype=torch.long)return X, y, word_to_idx, idx_to_word, vocab_size# 定义模型
class LSTMModel(nn.Module):def __init__(self, vocab_size, embedding_dim, hidden_dim):super(LSTMModel, self).__init__()self.embedding = nn.Embedding(vocab_size, embedding_dim)self.dropout_emb = nn.Dropout(0.3)  # 添加Dropout防止过拟合self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers=2, batch_first=True, dropout=0.2)  # 增加LSTM层数self.dropout_out = nn.Dropout(0.3)  # 输出层前的Dropoutself.fc = nn.Linear(hidden_dim, vocab_size)def forward(self, x):embeds = self.embedding(x)embeds = self.dropout_emb(embeds)lstm_out, _ = self.lstm(embeds)out = self.dropout_out(lstm_out[:, -1, :])out = self.fc(out)return out# 保存模型和词汇表
def save_model(model, word_to_idx, idx_to_word, save_dir='models'):# 创建保存目录if not os.path.exists(save_dir):os.makedirs(save_dir)# 生成时间戳作为模型标识timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")model_path = os.path.join(save_dir, f'gulongstyle_model_{timestamp}.pt')vocab_path = os.path.join(save_dir, f'gulongstyle_vocab_{timestamp}.json')# 保存模型torch.save(model.state_dict(), model_path)# 保存词汇表vocab_data = {'word_to_idx': word_to_idx,'idx_to_word': {int(k): v for k, v in idx_to_word.items()}  # 确保键是整数}with open(vocab_path, 'w', encoding='utf-8') as f:json.dump(vocab_data, f, ensure_ascii=False, indent=2)print(f"模型已保存到: {model_path}")print(f"词汇表已保存到: {vocab_path}")return model_path, vocab_path# 加载模型和词汇表
def load_model(model_path, vocab_path, embedding_dim=256, hidden_dim=512):# 加载词汇表with open(vocab_path, 'r', encoding='utf-8') as f:vocab_data = json.load(f)word_to_idx = vocab_data['word_to_idx']idx_to_word = {int(k): v for k, v in vocab_data['idx_to_word'].items()}vocab_size = len(word_to_idx)# 创建模型model = LSTMModel(vocab_size, embedding_dim, hidden_dim)# 加载模型参数 - 支持不同设备间加载if torch.backends.mps.is_available() and platform.system() == 'Darwin':# 从CPU加载到MPSmodel.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))else:model.load_state_dict(torch.load(model_path))# 将模型移至适当的设备model = model.to(device)model.eval()  # 设置为评估模式print(f"模型已从 {model_path} 加载")print(f"词汇表已从 {vocab_path} 加载")return model, word_to_idx, idx_to_word# 训练模型
def train_model(X, y, vocab_size, embedding_dim=256, hidden_dim=512, epochs=500, save_dir='models'):# 将模型移至适当的设备(MPS/GPU/CPU)model = LSTMModel(vocab_size, embedding_dim, hidden_dim).to(device)criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=0.001)  # 降低学习率,提高稳定性scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=20, factor=0.5)  # 学习率调度器# 创建小批量数据batch_size = 32  # 增大批量大小以提高训练速度# 将数据移至适当的设备X_device = X.to(device)y_device = y.to(device)dataset = torch.utils.data.TensorDataset(X_device, y_device)dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)best_loss = float('inf')patience_counter = 0patience_limit = 50  # 提前停止的耐心值best_model = None# 记录训练开始时间start_time = time.time()for epoch in range(epochs):model.train()  # 设置为训练模式total_loss = 0for batch_x, batch_y in dataloader:optimizer.zero_grad()output = model(batch_x)loss = criterion(output, batch_y)loss.backward()# 梯度裁剪,防止梯度爆炸torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)optimizer.step()total_loss += loss.item()avg_loss = total_loss / len(dataloader)scheduler.step(avg_loss)  # 更新学习率# 提前停止检查if avg_loss < best_loss:best_loss = avg_losspatience_counter = 0best_model = model.state_dict().copy()  # 保存最佳模型else:patience_counter += 1if patience_counter >= patience_limit:print(f'提前停止训练,Epoch: {epoch+1}/{epochs}')breakif (epoch + 1) % 10 == 0:# 计算已用时间和估计剩余时间elapsed_time = time.time() - start_timeavg_time_per_epoch = elapsed_time / (epoch + 1)remaining_epochs = epochs - (epoch + 1)est_remaining_time = avg_time_per_epoch * remaining_epochsprint(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, LR: {optimizer.param_groups[0]["lr"]:.6f}')print(f'已用时间: {elapsed_time:.1f}秒, 估计剩余时间: {est_remaining_time:.1f}秒')# 加载最佳模型if best_model is not None:model.load_state_dict(best_model)# 计算总训练时间total_time = time.time() - start_timeprint(f'训练完成,总时间: {total_time:.1f}秒')return model# 生成文本
def generate_text(model, seed_text, word_to_idx, idx_to_word, seq_length, num_words=10, temperature=0.8):model.eval()words = list(jieba.cut(seed_text))# 如果种子文本长度不足,使用填充策略if len(words) < seq_length:# 使用种子文本中的词进行填充padding = words * (seq_length // len(words) + 1)words = padding[:seq_length]else:words = words[-seq_length:]result = list(words)for _ in range(num_words):# 将当前序列转换为索引try:x = torch.tensor([[word_to_idx[word] for word in words]], dtype=torch.long).to(device)except KeyError:# 处理词汇表外的词unknown_words = [word for word in words if word not in word_to_idx]print(f"种子文本包含未知词汇: {', '.join(unknown_words)}")# 尝试替换未知词为词表中的词for i, word in enumerate(words):if word not in word_to_idx:# 随机选择词表中的一个词替换replacement = random.choice(list(word_to_idx.keys()))words[i] = replacementprint(f"将未知词 '{word}' 替换为 '{replacement}'")# 重新尝试try:x = torch.tensor([[word_to_idx[word] for word in words]], dtype=torch.long).to(device)except KeyError:print("无法处理未知词汇,停止生成")break# 预测下一个词with torch.no_grad():output = model(x)# 使用温度参数调整分布scaled_output = output / temperatureprobs = torch.softmax(scaled_output, dim=1)# Top-k采样,只保留概率最高的前5个词top_k = 5top_indices = torch.topk(probs, top_k).indices[0]top_probs = probs[0, top_indices]top_probs = top_probs / top_probs.sum()  # 重新归一化# 从top-k中采样next_idx = top_indices[torch.multinomial(top_probs, 1).item()]# 添加预测的词到结果中next_word = idx_to_word[next_idx.item()]result.append(next_word)# 更新序列words = words[1:] + [next_word]return ''.join(result)# 命令行参数解析
def parse_args():parser = argparse.ArgumentParser(description='古龙风格文本生成器')parser.add_argument('--train', action='store_true', help='训练新模型')parser.add_argument('--generate', action='store_true', help='生成文本')parser.add_argument('--model', type=str, help='模型文件路径')parser.add_argument('--vocab', type=str, help='词汇表文件路径')parser.add_argument('--seed', type=str, default='一把剑', help='种子文本')parser.add_argument('--length', type=int, default=10, help='生成文本长度')parser.add_argument('--temperature', type=float, default=0.8, help='采样温度')parser.add_argument('--seq_length', type=int, default=2, help='序列长度')parser.add_argument('--epochs', type=int, default=500, help='训练轮数')return parser.parse_args()# 主函数
def main():# 解析命令行参数args = parse_args()# 如果没有指定操作,默认进入交互模式if not args.train and not args.generate and not args.model:interactive_mode()return# 训练模式if args.train:train_mode(args)# 生成模式elif args.generate:if not args.model or not args.vocab:print("错误:生成模式需要指定模型和词汇表文件路径")returngenerate_mode(args)# 交互模式
def interactive_mode():seed_text = "一把剑"seq_length = 2print("=== 古龙风格文本生成器 - 交互模式 ===")print("1. 训练新模型")print("2. 加载已有模型")choice = input("请选择操作 (1/2): ")if choice == '1':print("\n正在处理数据...")X, y, word_to_idx, idx_to_word, vocab_size = preprocess_data(samples, seq_length)print("\n正在训练模型...")model = train_model(X, y, vocab_size)# 保存模型save_choice = input("\n是否保存模型? (y/n): ")if save_choice.lower() == 'y':model_path, vocab_path = save_model(model, word_to_idx, idx_to_word)# 使用不同的温度参数生成多个结果print("\n生成文本示例:")temperatures = [0.5, 0.8, 1.0]for temp in temperatures:generated_text = generate_text(model, seed_text, word_to_idx, idx_to_word, seq_length, temperature=temp)print(f"温度={temp},'{seed_text}' -> '{generated_text}'")elif choice == '2':model_path = input("\n请输入模型文件路径: ")vocab_path = input("请输入词汇表文件路径: ")try:model, word_to_idx, idx_to_word = load_model(model_path, vocab_path)seq_length = 2  # 默认序列长度except Exception as e:print(f"加载模型失败: {e}")returnelse:print("无效的选择")return# 进入交互式生成模式print("\n=== 进入交互式生成模式 ===")print("可以输入任意文本作为种子,系统将生成古龙风格的后续文本")print("输入'退出'结束程序")while True:user_input = input("\n请输入种子文本: ")if user_input.lower() == '退出':breaktemp = float(input("请输入温度参数 (0.1-1.5,推荐0.8): ") or "0.8")length = int(input("请输入生成长度: ") or "10")generated_text = generate_text(model, user_input, word_to_idx, idx_to_word, seq_length, num_words=length, temperature=temp)print(f"\n生成结果: '{user_input}' -> '{generated_text}'")# 训练模式
def train_mode(args):print("=== 古龙风格文本生成器 - 训练模式 ===")seq_length = args.seq_lengthprint("\n正在处理数据...")X, y, word_to_idx, idx_to_word, vocab_size = preprocess_data(samples, seq_length)print("\n正在训练模型...")model = train_model(X, y, vocab_size, epochs=args.epochs)# 保存模型model_path, vocab_path = save_model(model, word_to_idx, idx_to_word)# 生成示例print("\n生成文本示例:")seed_text = args.seedgenerated_text = generate_text(model, seed_text, word_to_idx, idx_to_word, seq_length, num_words=args.length, temperature=args.temperature)print(f"'{seed_text}' -> '{generated_text}'")# 生成模式
def generate_mode(args):print("=== 古龙风格文本生成器 - 生成模式 ===")try:model, word_to_idx, idx_to_word = load_model(args.model, args.vocab)seq_length = args.seq_lengthexcept Exception as e:print(f"加载模型失败: {e}")returnseed_text = args.seedgenerated_text = generate_text(model, seed_text, word_to_idx, idx_to_word, seq_length, num_words=args.length, temperature=args.temperature)print(f"'{seed_text}' -> '{generated_text}'")# 是否进入交互模式interactive = input("\n是否进入交互模式? (y/n): ")if interactive.lower() == 'y':while True:user_input = input("\n请输入种子文本(输入'退出'结束): ")if user_input == '退出':breaktemp = float(input("请输入温度参数 (0.1-1.5,推荐0.8): ") or "0.8")length = int(input("请输入生成长度: ") or "10")generated_text = generate_text(model, user_input, word_to_idx, idx_to_word, seq_length, num_words=length, temperature=temp)print(f"\n生成结果: '{user_input}' -> '{generated_text}'")if __name__ == "__main__":main()

5.1 如何运行

直接运行:

python WordsPredict.py

训练模式:

python3 WordsPredict.py --train --epochs 200

生成模式:

python WordsPredict.py --generate --model path/to/model.pt --vocab path/to/vocab.json --seed "江湖"

--model参数的解释

  • 这是指向训练好的模型文件的路径

  • 模型文件会在训练后自动保存在 models 目录下

  • 文件命名格式为:gulongstyle_model_时间戳.pt

  • 例如:models/gulongstyle_model_20251014_153045.pt

--vocab参数的解释

  • 词汇表文件包含了模型训练时使用的所有词语及其索引映射

  • 它也会保存在 models 目录下

  • 文件命名格式为:gulongstyle_vocab_时间戳.json

  • 例如:models/gulongstyle_vocab_20251014_153045.json

5.2 运行起来看效果

安装必要的依赖库(基于python10)

这是我们的requirements.txt文件内容

flask==2.2.3
requests==2.28.2numpy==1.24.3
torch==2.2.2
torchvision==0.17.2
torchaudio==2.2.2
jieba==0.42.1

我们可以使用以下命令来安装这些必要的库

pip install -r requirements.txt
训练

我们先用以下语句来运行训练

python WordsPredict.py --generate --model path/to/model.pt --vocab path/to/vocab.json --seed "拨刀相见"

我们的代码是支持mac/windows/linux不同操作系统的,它具备这样的运行模式:

  1. 自动检测是否可以使用Apple MPS加速
  2. 如果不可用,会尝试使用CUDA(如果有NVIDIA GPU)
  3. 否则回退到CPU

将模型移至GPU设备:

  1. 模型创建后自动移至适当的设备(MPS/GPU/CPU)
  2. 确保所有计算都在同一设备上进行

数据加速优化:

  1. 将训练数据移至GPU设备
  2. 增大批量大小从16到32,提高训练吞吐量

兼顾模型加载兼容性:

  1. 添加跨设备模型加载支持
  2. 确保从CPU保存的模型可以正确加载到MPS设备

文本生成过程时也使用了GPU加速。所以整体训练过程相当的快。

当看到以下输出,代表训练正常结束了,如果是在MPS或者是在GPU上,这个过程只要4~5分钟。

训练成功后在同python运行文件目录下会生成一个名为:models的目录

训练完毕的模型和vocab文件都位于此,这样我们如果每次只是单纯生成古龙风格的短语时就可以每次运行时选择相应和合适的训练的模型而不需要重复训练这个漫长的过程了。

运行
python WordsPredict.py --generate --model path/to/model.pt --vocab path/to/vocab.json --seed "拨刀相见"

python WordsPredict.py --generate --model path/to/model.pt --vocab path/to/vocab.json --seed "人己走"

怎么样?这个效果还是相当不错的。

不过我们为了演示,样本数据太少,如果有2,000左右样本,输出的结果还是很惊艳的。

好了,结束今天的分享,关键还是在于自己多动动手。

http://www.dtcms.com/a/481641.html

相关文章:

  • AI破解数学界遗忘谜题:GPT-5重新发现尘封二十年的埃尔德什问题解法
  • ui网站推荐如何建网站不花钱
  • Java版自助共享空间系统,打造高效无人值守智慧实体门店
  • 《超越单链表的局限:双链表“哨兵位”设计模式,如何让边界处理代码既优雅又健壮?》
  • HENGSHI SENSE 6.0技术白皮书:基于HQL语义层的Agentic BI动态计算引擎架构解析
  • C#实现MySQL→Clickhouse建表语句转换工具
  • 禁止下载app网站东莞网
  • MySQL数据库精研之旅第十九期:存储过程,数据处理的全能工具箱(二)
  • Ubuntu Linux 服务器快速安装 Docker 指南
  • Linux 信号捕捉与软硬中断
  • Linux NTP配置全攻略:从客户端到服务端
  • 二分查找专题总结:从数组越界到掌握“两段性“
  • aws ec2防ssh爆破, aws服务器加固, 亚马逊服务器ssh安全,防止ip扫描ssh。 aws安装fail2ban, ec2配置fail2ban
  • F024 CNN+vue+flask电影推荐系统vue+python+mysql+CNN实现
  • 谷歌生成在线网站地图买外链网站
  • Redis Key的设计
  • Redis 的原子性操作
  • 竹子建站免费版七牛云cdn加速wordpress
  • python进阶_Day8
  • 在React中如何应用函数式编程?
  • selenium的css定位方式有哪些
  • RabbitMq快速入门程序
  • Qt模型控件:QTreeView应用
  • selenium常用的等待有哪些?
  • 基于51单片机水位监测控制自动抽水—LCD1602
  • 电脑系统做的好的几个网站wordpress主题很卡
  • 数据结构和算法篇-环形缓冲区
  • iOS 26 性能分析深度指南 包含帧率、渲染、资源瓶颈与 KeyMob 协助策略
  • vs网站建设弹出窗口代码c网页视频下载神器哪种最好
  • Chrome性能优化秘籍