Residual Networks (ResNet), proposed by Kaiming He et al. in 2015, are a deep neural network architecture widely used for image classification. The core idea is to introduce residual blocks, which mitigate the vanishing-gradient problem in deep network training and thereby make much deeper architectures trainable.
A residual block uses a skip connection to add an earlier layer's output directly to the current layer's output, giving gradients a direct path back to early layers. The design rests on a key observation: it is easier to learn the residual mapping F(x) = H(x) − x than the underlying mapping H(x) itself, because when the identity mapping is already near-optimal, the layers only need to push F(x) toward zero.
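As a minimal illustration (a hypothetical toy example, not from the original post): with output y = x + g(x), the identity branch contributes a gradient of exactly 1, so early layers receive useful gradient even when g is nearly inert.

```python
import torch

# Toy sketch of gradient flow through a skip connection y = x + g(x).
# g is a stand-in "layer" that contributes nothing to the output.
x = torch.ones(3, requires_grad=True)
g = lambda t: 0.0 * t
y = (x + g(x)).sum()   # skip connection adds x back in
y.backward()
print(x.grad)          # tensor([1., 1., 1.]) -- the identity path passes gradient 1 through
```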
A residual block consists of two 3×3 convolutional layers, each followed by batch normalization, together with ReLU activations, wired as follows:
import torch
import torch.nn as nn


class Residual(nn.Module):
    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):
        super(Residual, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               padding=1, stride=stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if use_1x1conv:
            # 1x1 convolution reshapes the shortcut to match the main path
            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                                   stride=stride)
        else:
            self.conv3 = None

    def forward(self, X):
        Y = self.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        return self.relu(Y + X)
blk = Residual(3, 3)  # no 1x1 convolution on the shortcut
X = torch.randn((4, 3, 6, 6))  # input shape: (batch, channels, height, width)
print(blk(X).shape)  # torch.Size([4, 3, 6, 6])
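When a block changes the channel count or uses stride 2, the shortcut itself must be reshaped to match. A self-contained sketch of what the 1×1 projection does to the shapes, using the same numbers a `Residual(3, 6, use_1x1conv=True, stride=2)` block would see:

```python
import torch
import torch.nn as nn

# The shortcut projection used when use_1x1conv=True: a 1x1 conv with the same
# stride as the main path, so channels (3 -> 6) and spatial size (6x6 -> 3x3) match.
conv3 = nn.Conv2d(3, 6, kernel_size=1, stride=2)
X = torch.randn(4, 3, 6, 6)
print(conv3(X).shape)  # torch.Size([4, 6, 3, 3])
```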
The core structure of ResNet is built by stacking these residual blocks into stages:
import torch
import torch.nn as nn

def resnet_block(in_channels, out_channels, num_residuals, first_block=False):
    if first_block:
        assert in_channels == out_channels  # the first stage keeps the channel count
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            # first block of a later stage: halve spatial size, change channels
            blk.append(Residual(in_channels, out_channels, use_1x1conv=True, stride=2))
        else:
            blk.append(Residual(out_channels, out_channels))
    return blk

# ResNet-18 architecture
net = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
# channel schedule 64 -> 128 -> 256 -> 512, so the final Linear(512, 10) matches
channels = [64, 128, 256, 512]
for i, out_channels in enumerate(channels):
    in_channels = 64 if i == 0 else channels[i - 1]
    net.add_module(f"resnet_block{i + 1}",
                   nn.Sequential(*resnet_block(in_channels, out_channels, 2,
                                               first_block=(i == 0))))
net.add_module("global_avg_pool", nn.AdaptiveAvgPool2d((1, 1)))
net.add_module("fc", nn.Sequential(
    nn.Flatten(),
    nn.Linear(512, 10)))

Train the ResNet-18 model on the Fashion-MNIST dataset:
batch_size = 256
# `d2l` refers to the book's companion helper package (e.g. import d2lzh_pytorch as d2l)
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)

lr = 0.001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs=5)
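A quick way to sanity-check the `resize=96` choice and the 512 features seen by the final `Linear` layer is to track the spatial size through every stride-2 stage with the standard convolution output formula (a sketch, independent of the network above):

```python
# Spatial-size bookkeeping for a 96x96 input through the ResNet-18 schedule.
def conv_out(size, kernel, stride, padding):
    return (size + 2 * padding - kernel) // stride + 1

s = conv_out(96, 7, 2, 3)      # stem 7x7 conv, stride 2  -> 48
s = conv_out(s, 3, 2, 1)       # 3x3 max pool, stride 2   -> 24
for _ in range(3):             # stages 2-4 each start with a stride-2 residual block
    s = conv_out(s, 3, 2, 1)   # 24 -> 12 -> 6 -> 3
print(s)  # 3: global average pooling then reduces 3x3 to 1x1, so fc sees 512 features
```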
The training log shows that ResNet-18 reaches solid accuracy within 5 epochs:
epoch 1, loss 0.0015, train acc 0.853, test acc 0.885, time 31.0 sec
epoch 2, loss 0.0010, train acc 0.910, test acc 0.899, time 31.8 sec
...
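If the `d2l` helpers are not installed, the training call above reduces to a plain PyTorch loop. A minimal self-contained sketch with placeholder data and a placeholder model (not Fashion-MNIST or the ResNet above):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
X = torch.randn(32, 8)           # placeholder inputs
y = torch.randint(0, 2, (32,))   # placeholder labels
model = nn.Linear(8, 2)          # placeholder model
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

first_loss = None
for step in range(50):
    optimizer.zero_grad()
    loss = loss_fn(model(X), y)  # forward pass
    loss.backward()              # backpropagate
    optimizer.step()             # update parameters
    if first_loss is None:
        first_loss = loss.item()
print(first_loss, loss.item())   # loss should shrink over the 50 steps
```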
By introducing residual blocks, ResNet alleviates the vanishing-gradient problem in deep networks and makes much deeper architectures trainable. Its simple design and efficient training have made it a standard choice for image classification and one of the landmark advances in deep learning.
Reposted from: http://lpfr.baihongyu.com/