数据清理程序改成相对路径 完成dataloader

This commit is contained in:
yukun-hh 2026-04-13 22:20:28 +08:00
parent 543e833fd0
commit 1350cdd319
3 changed files with 16 additions and 12 deletions

View file

@ -15,7 +15,7 @@ import matplotlib.pyplot as plt
import numpy as np import numpy as np
def create_dataloaders(data_root='..', def create_dataloaders(data_root='../trash_division_data/ultimate_4_class/',
batch_size=32, batch_size=32,
image_size=256, image_size=256,
val_split=0.2, val_split=0.2,
@ -111,9 +111,8 @@ def create_dataloaders(data_root='..',
) )
# 4. 获取类别名称 # 4. 获取类别名称
class_names = train_dataset.classes if hasattr(train_dataset, 'classes') else ['0', '1', '2', '3'] class_names = train_dataset.classes
print(f"类别: {class_names}") print(f"类别: {class_names}")
print(f"类别映射: {train_dataset.class_to_idx if hasattr(train_dataset, 'class_to_idx') else '0-3'}")
return train_loader, val_loader, class_names return train_loader, val_loader, class_names
@ -158,7 +157,7 @@ def visualize_batch(dataloader, class_names, num_images=8):
if __name__ == '__main__': if __name__ == '__main__':
train_loader, val_loader, class_names = create_dataloaders( train_loader, val_loader, class_names = create_dataloaders(
data_root='..', # 与trash-division同级文件夹 data_root='../trash_division_data/ultimate_4_class/', # 与trash-division同级文件夹
batch_size=32, # 根据你的显存调整 batch_size=32, # 根据你的显存调整
image_size=256, # 与你模型输入一致 image_size=256, # 与你模型输入一致
num_workers=4, # Windows 可能需设为 0 num_workers=4, # Windows 可能需设为 0

View file

@ -1,5 +1,5 @@
"""将原数据集合并为我们需要的四个大类 """将原数据集合并为我们需要的四个大类
运行时先配置路径 已修改成相对路径 具体配置方法详见README.md
author author
weikaiwen weikaiwen
@ -18,9 +18,9 @@ import shutil
# ================= 1. 配置你的路径 ================= # ================= 1. 配置你的路径 =================
# 注意:请确保相对路径正确,以下为示例 # 注意:请确保相对路径正确,以下为示例
ORIGINAL_DATA_DIR = '/Users/weikaiwen/Desktop/trash_division_data' # 原始数据集的目录 ORIGINAL_DATA_DIR = '../trash_division_data' # 原始数据集的目录
NEW_DATA_DIR = '/Users/weikaiwen/Desktop/trash_division_data/ultimate_4_class' # 合并后的新目录 NEW_DATA_DIR = '../trash_division_data/ultimate_4_class' # 合并后的新目录
CLASSNAME_FILE = '/Users/weikaiwen/Desktop/trash_division_data/val/classname.txt' # txt 文件的位置 CLASSNAME_FILE = '../trash_division_data/val/classname.txt' # txt 文件的位置
# =================================================== # ===================================================

View file

@ -12,7 +12,7 @@ import torch.optim as optim
from tqdm import tqdm # 进度条,可选 from tqdm import tqdm # 进度条,可选
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from Model import Net from Model import Net
from Dataloader import create_dataloaders
def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch): def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch):
"""训练一个epoch""" """训练一个epoch"""
model.train() # 设置为训练模式 model.train() # 设置为训练模式
@ -171,11 +171,16 @@ def plot_training_history(history):
# ========== 使用示例 ========== # ========== 使用示例 ==========
if __name__ == '__main__': if __name__ == '__main__':
# 假设你的 dataloader 已经写好了 # 假设你的 dataloader 已经写好了
# train_loader = ... train_loader, val_loader, class_names = create_dataloaders(
# val_loader = ... data_root='../trash_division_data/ultimate_4_class/', # 与trash-division同级文件夹
batch_size=32, # 根据你的显存调整
image_size=256, # 与你模型输入一致
num_workers=4, # Windows 可能需设为 0
augment=True # 训练时使用数据增强
)
# 1. 创建模型 # 1. 创建模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'cpu')
model = Net().get_network() # 根据你的 Net 类调整 model = Net().get_network() # 根据你的 Net 类调整
model = model.to(device) model = model.to(device)