trash-division/Model.py

108 lines
3.8 KiB
Python
Raw Normal View History

2026-04-10 08:11:49 +00:00
"""
这个文件是模型的定义文件请不要擅自修改如有疑问微信群里反馈
单独运行本文件将会输出模型结构
2026-04-10 13:04:09 +00:00
目前的话是一个36层的模型模型总量应该是在80M左右 如果到时候还是欠拟合的话再考虑去做更深的结构
2026-04-10 08:11:49 +00:00
author : yukun-hh
date : 2026-4-10
"""
import torch
from torch import nn
from torch.nn import functional as F
2026-04-10 13:04:09 +00:00
from torchsummary import summary
2026-04-16 05:55:02 +00:00
# 残差块
2026-04-10 08:11:49 +00:00
class Resblock(nn.Module):
2026-04-16 05:55:02 +00:00
def __init__(self, input_channels, output_channels, use_1x1conv=False, strides=1):
2026-04-10 08:11:49 +00:00
"""
:param input_channels: 进入残差块时的原通道
:param output_channels: 输出时的通道数
:param use_1x1conv: 如果输入和输出通道不相等时需要用一个1x1的卷积层对原来的输入进行一个通道提升
:param strides: 默认1如果大于1起到缩小张量的作用
"""
super().__init__()
2026-04-16 05:55:02 +00:00
self.conv1 = nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1, stride=strides)
self.conv2 = nn.Conv2d(output_channels, output_channels, kernel_size=3, padding=1, stride=1)
2026-04-10 08:11:49 +00:00
if use_1x1conv:
2026-04-16 05:55:02 +00:00
self.conv3 = nn.Conv2d(input_channels, output_channels, kernel_size=1, stride=strides)
2026-04-10 08:11:49 +00:00
else:
self.conv3 = None
self.bn1 = nn.BatchNorm2d(output_channels)
self.bn2 = nn.BatchNorm2d(output_channels)
2026-04-16 05:55:02 +00:00
def forward(self, X):
2026-04-10 08:11:49 +00:00
Y = F.relu(self.bn1(self.conv1(X)))
Y = self.bn2(self.conv2(Y))
if self.conv3 is not None:
X = self.conv3(X)
Y += X
return F.relu(Y)
2026-04-16 05:55:02 +00:00
class Net(nn.Module):
"""
模型的主要结构就在这里了到时也好该和调用
现在必须实现的方法
目前还是以图片缩放到256256构建残差块
"""
def __init__(self):
2026-04-16 05:55:02 +00:00
super().__init__()
# 定义残差块的辅助方法
def resnet_block(input_channels, num_channels, num_residuals, first_block=False):
"""
:param input_channels: 输入维度
:param num_channels: 输出维度
:param num_residuals: 单个残差层的残差块数
:param first_block: 第一块不用下采样 特殊控制
:return: list[nn.Module]
"""
blk = []
for i in range(num_residuals):
if i == 0 and not first_block:
blk.append(Resblock(input_channels, num_channels, use_1x1conv=True, strides=2))
else:
blk.append(Resblock(num_channels, num_channels))
return blk
# 构建网络各层
self.b1 = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
)
"""
7×7 卷积层输出通道 64步长 2填充 3
(3×256×256)->(64×128×128)
批归一化 relu层
最大池化
(64×128×128)->(64×64×64)
"""
2026-04-16 05:55:02 +00:00
self.b2 = nn.Sequential(*resnet_block(64, 64, num_residuals=3, first_block=True))
self.b3 = nn.Sequential(*resnet_block(64, 128, num_residuals=4))
self.b4 = nn.Sequential(*resnet_block(128, 256, num_residuals=6))
self.b5 = nn.Sequential(*resnet_block(256, 512, num_residuals=3))
2026-04-16 05:55:02 +00:00
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.flatten = nn.Flatten()
self.fc = nn.Linear(512, 4)
2026-04-16 05:55:02 +00:00
def forward(self, x):
x = self.b1(x)
x = self.b2(x)
x = self.b3(x)
x = self.b4(x)
x = self.b5(x)
x = self.avgpool(x)
x = self.flatten(x)
x = self.fc(x)
return x
2026-04-10 08:11:49 +00:00
2026-04-16 05:55:02 +00:00
if __name__ == '__main__':
model = Net()
# 使用 torchsummary 查看模型结构
summary(model, input_size=(3, 256, 256))