当前位置: 首页 > news >正文

元学习之孪生网络Siamese Network

        简介:元学习是一种思想,一般以神经网络作为特征嵌入的工具,实现对数据特征的提取,然后通过构造某种指标以引导优化器对模型参数进行优化。而最小化距离是最常见的学习目标,这就是熟知的度量学习,度量学习里面经典的训练范式就是孪生网络。

1、小样本学习

        小样本学习是指用于训练的数据很少,以分类任务为例,minist数据一共有10个类别,每个类别差不多有几百张图片,传统的训练方式是一股脑的把所有训练集数据给端到端模型进行训练,得到一个模型,然后在测试集上测试。

        在小样本学习当中,每个类别仅能够使用很少的图片,比如10个类别每个类别使用5张图片,则称为10ways-5shots,10个类别每个类别使用2张图片,则称为10ways-2shots。在这么少的数据情况下,一般的端到端模型肯定学不到东西,导致效果变差。

        那么换个思路,让神经网络生成表征即可,但是得按照我的思路进行生成,思路就是你神经网络生成的样本表征需要满足下面的条件:相同的图片表征距离尽量靠近、不相同的图片表征距离尽量原理,然后构造一个自定义损失函数,进行训练即可。

        可以看到,度量学习本质上就是在神经网络后面添加一个额外的网络层,这个网络层对神经网络的输出表征进行处理,输出一个度量值,也就是自定义了一个损失函数网络层。在torch当中,原理层面就说构造了一个新的计算图,使得优化器的优化目标进行了改变,而这种改变也会使得神经网络的权重变成我们想要的情况,也就是这个自定义的度量损失函数指导了神经网络权重的学习,这就是元学习的体现。

        换一种说法就是,有一个初始的神经网络,我们需要改变他的权重,但不能直接让这个神经网络去参与训练。我们需要对神经网络的输出进行加工,得到另一种令人接受结果,然后使用万能的优化器优化这个结果,当这个结果确实令人接受了,那么神经网络的权重自然而然也就令人接受了。

 2、孪生网络数据集

        下面是孪生网络的数据集格式。  


from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import random
import torch.functional as F
from tqdm import tqdm

class SiameseDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):

        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform or transforms.Compose([
            transforms.Resize((256, 256)),  # 调整图片大小
            transforms.ToTensor(),  # 转换为张量
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  
        ])
        self.label_to_indices = self._create_label_to_indices()

    def _create_label_to_indices(self):
        """
        创建一个字典,将每个标签映射到具有该标签的所有图像的索引列表
        """
        label_to_indices = {}
        for idx, label in enumerate(self.labels):
            if label not in label_to_indices:
                label_to_indices[label] = []
            label_to_indices[label].append(idx)
        return label_to_indices

    def __len__(self):
        """返回数据集的大小"""
        return len(self.image_paths)

    def __getitem__(self, index):
        """
        返回一对图像和一个标签,指示这对图像是否属于同一类别
        """
        # 随机选择是否返回同一类别的图像对
        label = self.labels[index]
        if random.random() < 0.5:
            # 选择同一类别的图像
            siamese_index = random.choice(self.label_to_indices[label])
            target = 1  # 1 表示同一类别
        else:
            # 选择不同类别的图像
            other_labels = [l for l in self.label_to_indices.keys() if l != label]
            other_label = random.choice(other_labels)
            siamese_index = random.choice(self.label_to_indices[other_label])
            target = 0  # 0 表示不同类别

        # 加载图像
        image1 = Image.open(self.image_paths[index]).convert('RGB')
        image2 = Image.open(self.image_paths[siamese_index]).convert('RGB')

        # 应用变换
        if self.transform:
            image1 = self.transform(image1)
            image2 = self.transform(image2)

        return image1, image2, target

 3、损失函数


class ContrastiveLoss(nn.Module):
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss_contrastive

相关文章:

  • Whisper+T5-translate实现python实时语音翻译
  • 【MySQL】高频 SQL 50 题(基础版)
  • 每日一题——矩阵最长递增路径
  • 算法-哈希表03-快乐数
  • Django ORM:外键字段的命名与查询机制解析
  • Linux进程调度
  • DeepSeek 开放平台无法充值使用 改用其他中转平台API调用DeepSeek-chat模型方法
  • 变电站激光驱鸟器:绿色技术助力电网安全,减少鸟类威胁
  • C# 异步编程Async/Await 原理及使用详解
  • 【2023 K8s CKA】云原生K8s管理员认证课-零基础 考题更新免费学-全新PSI考试系统
  • Git子模块实战:大型后台管理系统模块拆分实践
  • elementUI rules 判断 el-cascader控件修改值未生效
  • Qt中QApplication 类和uic、moc程序
  • Node.js调用DeepSeek Api 实现本地智能聊天的简单应用
  • DeepSeek R1生成图片总结(虽然本身是不能直接生成图片,但是可以想办法利用别的工具一起实现)
  • Linux入侵检查流程
  • 使用 Visual Studio Code (VS Code) 开发 Python 图形界面程序
  • 你认为如何理解“约定大于配置”?
  • CentOS 系统上安装 Anaconda3-2022.05-Linux-x86_64.sh linux安装python3.9
  • 缓存三大问题及其解决方案
  • 怀化找什么人做网站/品牌seo如何优化
  • 如何做网站访问日志/企业网站管理系统
  • wordpress rss修改/化工网站关键词优化
  • 网站做鸭/百度网址大全在哪里找
  • 自建企业邮箱/网站seo的优化怎么做
  • 寻乌网站建设/培训机构招生方案模板