博客
关于我
残差网络(RESNET)
阅读量:361 次
发布时间:2019-03-04

本文共 3000 字,大约阅读时间需要 10 分钟。

残差网络(ResNet)

引言

残差网络(ResNet)是由何恺明等人于2015年提出的深度神经网络架构,广泛应用于图像分类任务中。其核心思想是通过引入残差块(residual blocks),解决深度网络训练中梯度消失问题,从而实现更深的网络结构。

残差块设计

残差块通过跳跃连接(skip connection),将早期层的输出直接与当前层的输出相加,防止梯度消失问题。其设计基于以下关键观察:

  • 增加层会扩展模型的可学习空间,但若新层仅作为恒等映射使用,模型性能不会有显著提升。
  • 实践中,随着网络深度增加,训练误差反而上升,这被称为“深度陷阱”。
残差块实现

残差块由两个卷积层、批量归一化层和激活函数组成,具体步骤如下:

  • 第一个卷积层(3x3)将输入通道数变换为目标输出通道数。
  • 第二个卷积层(3x3)再次应用同样的通道数变换。
  • 通过跳跃连接,将第一个卷积层的输出直接与第二个卷积层的输出相加。
  • 为了实现跳跃连接,若需要改变通道数,需在输入通道数与目标通道数之间插入一个1x1卷积层。
  • 示例代码
    import torchimport torch.nn as nnimport torch.nn.functional as Fclass Residual(nn.Module):    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):        super(Residual, self).__init__()        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)        self.bn1 = nn.BatchNorm2d(out_channels)        self.relu = nn.ReLU()        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)        self.bn2 = nn.BatchNorm2d(out_channels)        if use_1x1conv:            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)        else:            self.conv3 = None    def forward(self, X):        Y = self.relu(self.bn1(self.conv1(X)))        Y = self.bn2(self.conv2(Y))        if self.conv3:            X = self.conv3(X)        return self.relu(Y + X)
    输入输出形状验证
    blk = Residual(3, 3)  # 没有使用1x1卷积层X = torch.randn((4, 3, 6, 6))  # 输入形状为(批量,通道,高度,宽度)print(blk(X).shape)  # 输出:(4, 3, 6, 6)

    ResNet模型架构

    ResNet的核心结构包括:

  • 初始卷积层:7x7卷积层(64通道),后接最大池化层。
  • 4个残差块模块,每个模块包含2个残差块。
  • 全局平均池化层后接全连接层输出。
  • 模块设计
    • 第一个模块:通道数与输入一致。
    • 后续模块:通道数每层翻倍,输出尺寸减半。
    代码实现
    import torchimport torch.nn as nnimport torch.nn.functional as Fdef resnet_block(in_channels, out_channels, num_residuals, first_block=False):    blk = []    if first_block:        assert in_channels == out_channels        blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2))    else:        blk.append(Residual(in_channels, out_channels))        blk.append(Residual(out_channels, out_channels))    return blk# ResNet-18架构net = nn.Sequential(    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),    nn.BatchNorm2d(64),    nn.ReLU(),    nn.MaxPool2d(kernel_size=3, stride=2, padding=1))for i in range(4):    net.add_module(f"resnet_block{i+1}", resnet_block(64, 64 if i == 0 else 128, 2, first_block=(i == 0)))net.add_module("global_avg_pool", nn.AdaptiveAvgPool2d((1, 1)))net.add_module("fc", nn.Sequential(    nn.Flatten(),    nn.Linear(512, 10)))

    训练与结果

    在fashion-mnist数据集上训练ResNet-18模型:

    batch_size = 256train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)lr = 0.001optimizer = torch.optim.Adam(net.parameters(), lr=lr)d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs=5)

    训练结果显示,ResNet-18在5个epoch内取得了较好的性能:

    epoch 1, loss 0.0015, train acc 0.853, test acc 0.885, time 31.0 secepoch 2, loss 0.0010, train acc 0.910, test acc 0.899, time 31.8 sec...

    小结

    ResNet通过引入残差块解决了深度网络中的梯度消失问题,实现了更深的网络结构。其简单的设计和高效的训练效果使其在图像分类任务中广泛应用,成为深度学习领域的重要进展之一。

    转载地址:http://lpfr.baihongyu.com/

    你可能感兴趣的文章
    Objective-C实现FTP文件上传(附完整源码)
    查看>>
    Objective-C实现FTP文件下载(附完整源码)
    查看>>
    Objective-C实现fuzzy operations模糊运算算法(附完整源码)
    查看>>
    Objective-C实现Gale-Shapley盖尔-沙普利算法(附完整源码)
    查看>>
    Objective-C实现gamma recursive伽玛递归算法(附完整源码)
    查看>>
    Objective-C实现gamma 伽玛功能算法(附完整源码)
    查看>>
    Objective-C实现gauss easte高斯复活节日期算法(附完整源码)
    查看>>
    Objective-C实现gaussian filter高斯滤波器算法(附完整源码)
    查看>>
    Objective-C实现gaussian naive bayes高斯贝叶斯算法(附完整源码)
    查看>>
    Objective-C实现gaussian高斯算法(附完整源码)
    查看>>
    Objective-C实现geometric series几何系列算法(附完整源码)
    查看>>
    Objective-C实现getline函数功能(附完整源码)
    查看>>
    Objective-C实现gnome sortt侏儒排序算法(附完整源码)
    查看>>
    Objective-C实现graph list图列算法(附完整源码)
    查看>>
    Objective-C实现GraphEdge图边算法(附完整源码)
    查看>>
    Objective-C实现GraphVertex图顶点算法(附完整源码)
    查看>>
    Objective-C实现greatest common divisor最大公约数算法(附完整源码)
    查看>>
    Objective-C实现greedy coin change贪心硬币找零算法(附完整源码)
    查看>>
    Objective-C实现greedy knapsack贪婪的背包算法(附完整源码)
    查看>>
    Objective-C实现GridGet算法(附完整源码)
    查看>>