博客
关于我
残差网络(RESNET)
阅读量:361 次
发布时间:2019-03-04

本文共 3053 字,大约阅读时间需要 10 分钟。

残差网络(ResNet)

引言

残差网络(ResNet)是由何恺明等人于2015年提出的深度神经网络架构,广泛应用于图像分类任务中。其核心思想是通过引入残差块(residual blocks),解决深度网络训练中梯度消失问题,从而实现更深的网络结构。

残差块设计

残差块通过跳跃连接(skip connection),将早期层的输出直接与当前层的输出相加,防止梯度消失问题。其设计基于以下关键观察:

  • 增加层会扩展模型的可学习空间,但若新层仅作为恒等映射使用,模型性能不会有显著提升。
  • 实践中,随着网络深度增加,训练误差反而上升,这被称为“深度陷阱”。
残差块实现

残差块由两个卷积层、批量归一化层和激活函数组成,具体步骤如下:

  • 第一个卷积层(3x3)将输入通道数变换为目标输出通道数。
  • 第二个卷积层(3x3)再次应用同样的通道数变换。
  • 通过跳跃连接,将第一个卷积层的输出直接与第二个卷积层的输出相加。
  • 为了实现跳跃连接,若需要改变通道数,需在输入通道数与目标通道数之间插入一个1x1卷积层。
  • 示例代码
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    class Residual(nn.Module):
    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):
    super(Residual, self).__init__()
    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)
    self.bn1 = nn.BatchNorm2d(out_channels)
    self.relu = nn.ReLU()
    self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
    self.bn2 = nn.BatchNorm2d(out_channels)
    if use_1x1conv:
    self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
    else:
    self.conv3 = None
    def forward(self, X):
    Y = self.relu(self.bn1(self.conv1(X)))
    Y = self.bn2(self.conv2(Y))
    if self.conv3:
    X = self.conv3(X)
    return self.relu(Y + X)
    输入输出形状验证
    blk = Residual(3, 3)  # 没有使用1x1卷积层
    X = torch.randn((4, 3, 6, 6)) # 输入形状为(批量,通道,高度,宽度)
    print(blk(X).shape) # 输出:(4, 3, 6, 6)

    ResNet模型架构

    ResNet的核心结构包括:

  • 初始卷积层:7x7卷积层(64通道),后接最大池化层。
  • 4个残差块模块,每个模块包含2个残差块。
  • 全局平均池化层后接全连接层输出。
  • 模块设计
    • 第一个模块:通道数与输入一致。
    • 后续模块:通道数每层翻倍,输出尺寸减半。
    代码实现
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    def resnet_block(in_channels, out_channels, num_residuals, first_block=False):
    blk = []
    if first_block:
    assert in_channels == out_channels
    blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2))
    else:
    blk.append(Residual(in_channels, out_channels))
    blk.append(Residual(out_channels, out_channels))
    return blk
    # ResNet-18架构
    net = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    )
    for i in range(4):
    net.add_module(f"resnet_block{i+1}", resnet_block(64, 64 if i == 0 else 128, 2, first_block=(i == 0)))
    net.add_module("global_avg_pool", nn.AdaptiveAvgPool2d((1, 1)))
    net.add_module("fc", nn.Sequential(
    nn.Flatten(),
    nn.Linear(512, 10)
    ))

    训练与结果

    在fashion-mnist数据集上训练ResNet-18模型:

    batch_size = 256
    train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)
    lr = 0.001
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs=5)

    训练结果显示,ResNet-18在5个epoch内取得了较好的性能:

    epoch 1, loss 0.0015, train acc 0.853, test acc 0.885, time 31.0 sec
    epoch 2, loss 0.0010, train acc 0.910, test acc 0.899, time 31.8 sec
    ...

    小结

    ResNet通过引入残差块解决了深度网络中的梯度消失问题,实现了更深的网络结构。其简单的设计和高效的训练效果使其在图像分类任务中广泛应用,成为深度学习领域的重要进展之一。

    转载地址:http://lpfr.baihongyu.com/

    你可能感兴趣的文章
    NFS共享文件系统搭建
    查看>>
    nfs复习
    查看>>
    NFS安装配置
    查看>>
    NFS的安装以及windows/linux挂载linux网络文件系统NFS
    查看>>
    NFS的常用挂载参数
    查看>>
    NFS网络文件系统
    查看>>
    nft文件传输_利用remoting实现文件传输-.NET教程,远程及网络应用
    查看>>
    NFV商用可行新华三vBRAS方案实践验证
    查看>>
    ng build --aot --prod生成文件报错
    查看>>
    ng 指令的自定义、使用
    查看>>
    nghttp3使用指南
    查看>>
    Nginx
    查看>>
    nginx + etcd 动态负载均衡实践(三)—— 基于nginx-upsync-module实现
    查看>>
    nginx + etcd 动态负载均衡实践(二)—— 组件安装
    查看>>
    nginx + etcd 动态负载均衡实践(四)—— 基于confd实现
    查看>>
    Nginx + Spring Boot 实现负载均衡
    查看>>
    Nginx + uWSGI + Flask + Vhost
    查看>>
    Nginx - Header详解
    查看>>
    nginx 1.24.0 安装nginx最新稳定版
    查看>>
    nginx css,js合并插件,淘宝nginx合并js,css插件
    查看>>