diff --git a/Dataloader.py b/Dataloader.py index fab833b..7d9ec3e 100644 --- a/Dataloader.py +++ b/Dataloader.py @@ -15,7 +15,7 @@ import matplotlib.pyplot as plt import numpy as np -def create_dataloaders(data_root='..', +def create_dataloaders(data_root='../trash_division_data/ultimate_4_class/', batch_size=32, image_size=256, val_split=0.2, @@ -111,9 +111,8 @@ def create_dataloaders(data_root='..', ) # 4. 获取类别名称 - class_names = train_dataset.classes if hasattr(train_dataset, 'classes') else ['0', '1', '2', '3'] + class_names = train_dataset.classes print(f"类别: {class_names}") - print(f"类别映射: {train_dataset.class_to_idx if hasattr(train_dataset, 'class_to_idx') else '0-3'}") return train_loader, val_loader, class_names @@ -158,7 +157,7 @@ def visualize_batch(dataloader, class_names, num_images=8): if __name__ == '__main__': train_loader, val_loader, class_names = create_dataloaders( - data_root='..', # 与trash-division同级文件夹 + data_root='../trash_division_data/ultimate_4_class/', # 与trash-division同级文件夹 batch_size=32, # 根据你的显存调整 image_size=256, # 与你模型输入一致 num_workers=4, # Windows 可能需设为 0 diff --git a/merge_classes.py b/Merge_classes.py similarity index 90% rename from merge_classes.py rename to Merge_classes.py index a605aa2..145c345 100644 --- a/merge_classes.py +++ b/Merge_classes.py @@ -1,5 +1,5 @@ """将原数据集合并为我们需要的四个大类 - 运行时先配置路径 + 已修改成相对路径 具体配置方法详见README.md author: weikaiwen @@ -18,9 +18,9 @@ import shutil # ================= 1. 配置你的路径 ================= # 注意:请确保相对路径正确,以下为示例 -ORIGINAL_DATA_DIR = '/Users/weikaiwen/Desktop/trash_division_data' # 原始数据集的目录 -NEW_DATA_DIR = '/Users/weikaiwen/Desktop/trash_division_data/ultimate_4_class' # 合并后的新目录 -CLASSNAME_FILE = '/Users/weikaiwen/Desktop/trash_division_data/val/classname.txt' # txt 文件的位置 +ORIGINAL_DATA_DIR = '../trash_division_data' # 原始数据集的目录 +NEW_DATA_DIR = '../trash_division_data/ultimate_4_class' # 合并后的新目录 +CLASSNAME_FILE = '../trash_division_data/val/classname.txt' # txt 文件的位置 # =================================================== diff --git a/Train.py b/Train.py index 6c87c69..e5d462c 100644 --- a/Train.py +++ b/Train.py @@ -12,7 +12,7 @@ import torch.optim as optim from tqdm import tqdm # 进度条,可选 import matplotlib.pyplot as plt from Model import Net - +from Dataloader import create_dataloaders def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch): """训练一个epoch""" model.train() # 设置为训练模式 @@ -171,11 +171,16 @@ def plot_training_history(history): # ========== 使用示例 ========== if __name__ == '__main__': # 假设你的 dataloader 已经写好了 - # train_loader = ... - # val_loader = ... + train_loader, val_loader, class_names = create_dataloaders( + data_root='../trash_division_data/ultimate_4_class/', # 与trash-division同级文件夹 + batch_size=32, # 根据你的显存调整 + image_size=256, # 与你模型输入一致 + num_workers=4, # Windows 可能需设为 0 + augment=True # 训练时使用数据增强 + ) # 1. 创建模型 - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + device = torch.device('cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'cpu') model = Net().get_network() # 根据你的 Net 类调整 model = model.to(device)