模型规范设计
This commit is contained in:
parent
793852eedd
commit
b2f7a9c172
3 changed files with 64 additions and 86 deletions
|
|
@ -158,9 +158,9 @@ def visualize_batch(dataloader, class_names, num_images=8):
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
train_loader, val_loader, class_names = create_dataloaders(
|
train_loader, val_loader, class_names = create_dataloaders(
|
||||||
data_root='../trash_division_data/ultimate_4_class/', # 与trash-division同级文件夹
|
data_root='../trash_division_data/ultimate_4_class/', # 与trash-division同级文件夹
|
||||||
batch_size=32, # 根据你的显存调整
|
batch_size=16, # 根据你的显存调整
|
||||||
image_size=256, # 与你模型输入一致
|
image_size=256, # 与你模型输入一致
|
||||||
num_workers=4, # Windows 可能需设为 0
|
num_workers=16, # Windows 可能需设为 0
|
||||||
augment=True # 训练时使用数据增强
|
augment=True # 训练时使用数据增强
|
||||||
)
|
)
|
||||||
visualize_batch(train_loader, class_names, num_images=8)
|
visualize_batch(train_loader, class_names, num_images=8)
|
||||||
|
|
|
||||||
69
Model.py
69
Model.py
|
|
@ -6,16 +6,16 @@ author : yukun-hh
|
||||||
date : 2026-4-10
|
date : 2026-4-10
|
||||||
|
|
||||||
"""
|
"""
|
||||||
#神经网络模型库
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from torchsummary import summary
|
from torchsummary import summary
|
||||||
|
|
||||||
|
|
||||||
# 残差块
|
# 残差块
|
||||||
class Resblock(nn.Module):
|
class Resblock(nn.Module):
|
||||||
def __init__(self, input_channels, output_channels, use_1x1conv=False, strides=1):
|
def __init__(self, input_channels, output_channels, use_1x1conv=False, strides=1):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
:param input_channels: 进入残差块时的原通道
|
:param input_channels: 进入残差块时的原通道
|
||||||
:param output_channels: 输出时的通道数
|
:param output_channels: 输出时的通道数
|
||||||
:param use_1x1conv: 如果输入和输出通道不相等时,需要用一个1x1的卷积层对原来的输入进行一个通道提升
|
:param use_1x1conv: 如果输入和输出通道不相等时,需要用一个1x1的卷积层对原来的输入进行一个通道提升
|
||||||
|
|
@ -30,6 +30,7 @@ class Resblock(nn.Module):
|
||||||
self.conv3 = None
|
self.conv3 = None
|
||||||
self.bn1 = nn.BatchNorm2d(output_channels)
|
self.bn1 = nn.BatchNorm2d(output_channels)
|
||||||
self.bn2 = nn.BatchNorm2d(output_channels)
|
self.bn2 = nn.BatchNorm2d(output_channels)
|
||||||
|
|
||||||
def forward(self, X):
|
def forward(self, X):
|
||||||
Y = F.relu(self.bn1(self.conv1(X)))
|
Y = F.relu(self.bn1(self.conv1(X)))
|
||||||
Y = self.bn2(self.conv2(Y))
|
Y = self.bn2(self.conv2(Y))
|
||||||
|
|
@ -38,15 +39,19 @@ class Resblock(nn.Module):
|
||||||
Y += X
|
Y += X
|
||||||
return F.relu(Y)
|
return F.relu(Y)
|
||||||
|
|
||||||
class Net():
|
|
||||||
|
class Net(nn.Module):
|
||||||
"""
|
"""
|
||||||
模型的主要结构就在这里了,到时也好该和调用
|
模型的主要结构就在这里了,到时也好该和调用
|
||||||
现在必须实现的方法:
|
现在必须实现的方法:
|
||||||
目前还是以图片缩放到256*256构建残差块
|
目前还是以图片缩放到256*256构建残差块
|
||||||
"""
|
"""
|
||||||
net = nn.Sequential()
|
|
||||||
def resnet_block(self,input_channels, num_channels, num_residuals,
|
def __init__(self):
|
||||||
first_block=False):
|
super().__init__()
|
||||||
|
|
||||||
|
# 定义残差块的辅助方法
|
||||||
|
def resnet_block(input_channels, num_channels, num_residuals, first_block=False):
|
||||||
"""
|
"""
|
||||||
:param input_channels: 输入维度
|
:param input_channels: 输入维度
|
||||||
:param num_channels: 输出维度
|
:param num_channels: 输出维度
|
||||||
|
|
@ -55,17 +60,18 @@ class Net():
|
||||||
:return: list[nn.Module]
|
:return: list[nn.Module]
|
||||||
"""
|
"""
|
||||||
blk = []
|
blk = []
|
||||||
|
|
||||||
for i in range(num_residuals):
|
for i in range(num_residuals):
|
||||||
if i == 0 and not first_block:
|
if i == 0 and not first_block:
|
||||||
blk.append(Resblock(input_channels, num_channels,
|
blk.append(Resblock(input_channels, num_channels, use_1x1conv=True, strides=2))
|
||||||
use_1x1conv=True, strides=2))
|
|
||||||
else:
|
else:
|
||||||
blk.append(Resblock(num_channels, num_channels))
|
blk.append(Resblock(num_channels, num_channels))
|
||||||
return blk
|
return blk
|
||||||
def __init__(self):
|
|
||||||
b1 = nn.Sequential( nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
|
# 构建网络各层
|
||||||
nn.BatchNorm2d(64), nn.ReLU(),
|
self.b1 = nn.Sequential(
|
||||||
|
nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),
|
||||||
|
nn.BatchNorm2d(64),
|
||||||
|
nn.ReLU(),
|
||||||
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||||
)
|
)
|
||||||
"""
|
"""
|
||||||
|
|
@ -75,25 +81,28 @@ class Net():
|
||||||
最大池化
|
最大池化
|
||||||
(64×128×128)->(64×64×64)
|
(64×128×128)->(64×64×64)
|
||||||
"""
|
"""
|
||||||
b2 = nn.Sequential(*self.resnet_block(64, 64, num_residuals=3, first_block=True))
|
self.b2 = nn.Sequential(*resnet_block(64, 64, num_residuals=3, first_block=True))
|
||||||
b3 = nn.Sequential(*self.resnet_block(64, 128, num_residuals=4))
|
self.b3 = nn.Sequential(*resnet_block(64, 128, num_residuals=4))
|
||||||
b4 = nn.Sequential(*self.resnet_block(128, 256, num_residuals=6))
|
self.b4 = nn.Sequential(*resnet_block(128, 256, num_residuals=6))
|
||||||
b5 = nn.Sequential(*self.resnet_block(256, 512, num_residuals=3))
|
self.b5 = nn.Sequential(*resnet_block(256, 512, num_residuals=3))
|
||||||
self.net = nn.Sequential(b1, b2, b3, b4, b5,nn.AdaptiveAvgPool2d((1,1)),nn.Flatten(), nn.Linear(512, 4))
|
|
||||||
def get_network(self):
|
|
||||||
return self.net
|
|
||||||
|
|
||||||
|
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
||||||
|
self.flatten = nn.Flatten()
|
||||||
|
self.fc = nn.Linear(512, 4)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.b1(x)
|
||||||
|
x = self.b2(x)
|
||||||
|
x = self.b3(x)
|
||||||
|
x = self.b4(x)
|
||||||
|
x = self.b5(x)
|
||||||
|
x = self.avgpool(x)
|
||||||
|
x = self.flatten(x)
|
||||||
|
x = self.fc(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
Net_new = Net()
|
model = Net()
|
||||||
X = torch.rand(size=(1, 3, 256, 256))
|
# 使用 torchsummary 查看模型结构
|
||||||
summary(Net_new.get_network(), input_size=(3, 256, 256))
|
summary(model, input_size=(3, 256, 256))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
37
Train.py
37
Train.py
|
|
@ -81,14 +81,11 @@ def train(model, train_loader, val_loader, epochs=50, lr=0.001, device='cuda'):
|
||||||
criterion = nn.CrossEntropyLoss() # 多分类用交叉熵
|
criterion = nn.CrossEntropyLoss() # 多分类用交叉熵
|
||||||
|
|
||||||
# 优化器选择(推荐 Adam 或 SGD)
|
# 优化器选择(推荐 Adam 或 SGD)
|
||||||
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
|
|
||||||
# 或者使用 SGD + 动量
|
# 或者使用 SGD + 动量
|
||||||
# optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
|
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=1e-4)
|
||||||
|
|
||||||
# 学习率调度器(可选,帮助收敛)
|
# 学习率调度器(可选,帮助收敛)
|
||||||
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)
|
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
|
||||||
# 或者用余弦退火
|
|
||||||
# scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
|
|
||||||
|
|
||||||
# 2. 记录训练历史
|
# 2. 记录训练历史
|
||||||
history = {
|
history = {
|
||||||
|
|
@ -133,7 +130,6 @@ def train(model, train_loader, val_loader, epochs=50, lr=0.001, device='cuda'):
|
||||||
print(f'✓ 保存最佳模型 (Acc: {val_acc:.2f}%)')
|
print(f'✓ 保存最佳模型 (Acc: {val_acc:.2f}%)')
|
||||||
|
|
||||||
# 4. 绘制训练曲线
|
# 4. 绘制训练曲线
|
||||||
plot_training_history(history)
|
|
||||||
|
|
||||||
print(f'\n{"=" * 50}')
|
print(f'\n{"=" * 50}')
|
||||||
print(f'训练完成!最佳验证准确率: {best_val_acc:.2f}%')
|
print(f'训练完成!最佳验证准确率: {best_val_acc:.2f}%')
|
||||||
|
|
@ -141,33 +137,6 @@ def train(model, train_loader, val_loader, epochs=50, lr=0.001, device='cuda'):
|
||||||
return model, history
|
return model, history
|
||||||
|
|
||||||
|
|
||||||
def plot_training_history(history):
|
|
||||||
"""绘制训练曲线"""
|
|
||||||
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
|
|
||||||
|
|
||||||
# 损失曲线
|
|
||||||
ax1.plot(history['train_loss'], label='Train Loss')
|
|
||||||
ax1.plot(history['val_loss'], label='Val Loss')
|
|
||||||
ax1.set_xlabel('Epoch')
|
|
||||||
ax1.set_ylabel('Loss')
|
|
||||||
ax1.set_title('Training and Validation Loss')
|
|
||||||
ax1.legend()
|
|
||||||
ax1.grid(True)
|
|
||||||
|
|
||||||
# 准确率曲线
|
|
||||||
ax2.plot(history['train_acc'], label='Train Acc')
|
|
||||||
ax2.plot(history['val_acc'], label='Val Acc')
|
|
||||||
ax2.set_xlabel('Epoch')
|
|
||||||
ax2.set_ylabel('Accuracy (%)')
|
|
||||||
ax2.set_title('Training and Validation Accuracy')
|
|
||||||
ax2.legend()
|
|
||||||
ax2.grid(True)
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
plt.savefig('training_history.png', dpi=150)
|
|
||||||
plt.show()
|
|
||||||
|
|
||||||
|
|
||||||
# ========== 使用示例 ==========
|
# ========== 使用示例 ==========
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# 假设你的 dataloader 已经写好了
|
# 假设你的 dataloader 已经写好了
|
||||||
|
|
@ -181,7 +150,7 @@ if __name__ == '__main__':
|
||||||
|
|
||||||
# 1. 创建模型
|
# 1. 创建模型
|
||||||
device = torch.device('cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'cpu')
|
device = torch.device('cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'cpu')
|
||||||
model = Net().get_network() # 根据你的 Net 类调整
|
model = Net() # 根据你的 Net 类调整
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
# 打印模型信息
|
# 打印模型信息
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue