From 4a00bd43b49f8068f0fa691607150461515be8d5 Mon Sep 17 00:00:00 2001 From: weikaiwen348-code Date: Sun, 12 Apr 2026 13:55:41 +0800 Subject: [PATCH 1/4] Create merge_classes.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 自动生成类别映射并合并四大类数据集的脚本(ai生成还未校验) --- data/merge_classes.py | 83 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 data/merge_classes.py diff --git a/data/merge_classes.py b/data/merge_classes.py new file mode 100644 index 0000000..c14dde6 --- /dev/null +++ b/data/merge_classes.py @@ -0,0 +1,83 @@ +import os +import shutil + +# ================= 1. 配置你的路径 ================= +# 注意:请确保相对路径正确,以下为示例 +ORIGINAL_DATA_DIR = './garbage265' # 原始数据集的目录 +NEW_DATA_DIR = './garbage_4_classes' # 合并后的新目录 +CLASSNAME_FILE = './garbage265/classname.txt' # txt 文件的位置 +# =================================================== + +def build_mapping(): + """让 Python 自动读取 txt 文件并建立映射字典""" + mapping = {} + + # 打开并读取文件 + with open(CLASSNAME_FILE, 'r', encoding='utf-8') as f: + lines = f.read().splitlines() + + for idx, line in enumerate(lines): + # 过滤掉空行 + if '-' not in line: + continue + + # 用 '-' 把字符串一分为二:前面的做大类,后面的做小类 + big_class, small_class = line.split('-', 1) + + # 核心:修复原数据集中的错别字 Bug + if big_class == '其它垃圾': + big_class = '其他垃圾' + + # 为了绝对安全,我们把三种可能出现的文件夹名字全存进字典里: + mapping[str(idx)] = big_class # 应对文件夹名为数字 ID (如 '0') 的情况 + mapping[line] = big_class # 应对文件夹名为完整名称 (如 '厨余垃圾-八宝粥') 的情况 + mapping[small_class] = big_class # 应对文件夹名为小类名称 (如 '八宝粥') 的情况 + + return mapping + +def merge_dataset(): + print("正在解析类别映射文件...") + class_mapping = build_mapping() + + # 同时处理训练集和验证集 + for split in ['train', 'val']: + original_split_dir = os.path.join(ORIGINAL_DATA_DIR, split) + new_split_dir = os.path.join(NEW_DATA_DIR, split) + + if not os.path.exists(original_split_dir): + print(f"⚠️ 找不到文件夹: {original_split_dir},跳过处理。") + continue + + print(f"\n🚀 开始合并 [{split}] 集合...") + + # 遍历原始的 265 个文件夹 + for sub_class in os.listdir(original_split_dir): + sub_class_path = os.path.join(original_split_dir, sub_class) + + # 忽略隐藏文件或说明文件,确保只处理文件夹 + if not os.path.isdir(sub_class_path): + continue + + # 核心:通过字典查询这个小类属于哪个大类 + target_big_class = class_mapping.get(sub_class, "未知分类") + + target_dir = os.path.join(new_split_dir, target_big_class) + if not os.path.exists(target_dir): + os.makedirs(target_dir) + + # 获取该小类文件夹下的所有图片并开始搬运 + images = os.listdir(sub_class_path) + for img in images: + src_img_path = os.path.join(sub_class_path, img) + + # 给新图片加个前缀,防止不同小类有同名图片(比如分别叫 001.jpg 导致互相覆盖) + new_img_name = f"{sub_class}_{img}" + dst_img_path = os.path.join(target_dir, new_img_name) + + # 执行复制操作 + shutil.copy(src_img_path, dst_img_path) + + print(f"✅ [{split}] 集合四大类合并完成!") + +if __name__ == '__main__': + merge_dataset() \ No newline at end of file From 934b1d2a7da0f590d4db433492e05b35ba1cd3f4 Mon Sep 17 00:00:00 2001 From: weikaiwen348-code Date: Mon, 13 Apr 2026 16:26:26 +0800 Subject: [PATCH 2/4] Update merge_classes.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 完成:编写 265 类至 4 大类的数据集自动化合并脚本。 --- data/merge_classes.py | 43 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/data/merge_classes.py b/data/merge_classes.py index c14dde6..a605aa2 100644 --- a/data/merge_classes.py +++ b/data/merge_classes.py @@ -1,13 +1,30 @@ +"""将原数据集合并为我们需要的四个大类 + 运行时先配置路径 + + author: + weikaiwen + + 厨余垃圾-1 + 可回收物-2 + 其他垃圾-3 + 有害垃圾-4 + + 未知-0 +""" + + import os import shutil # ================= 1. 配置你的路径 ================= # 注意:请确保相对路径正确,以下为示例 -ORIGINAL_DATA_DIR = './garbage265' # 原始数据集的目录 -NEW_DATA_DIR = './garbage_4_classes' # 合并后的新目录 -CLASSNAME_FILE = './garbage265/classname.txt' # txt 文件的位置 +ORIGINAL_DATA_DIR = '/Users/weikaiwen/Desktop/trash_division_data' # 原始数据集的目录 +NEW_DATA_DIR = '/Users/weikaiwen/Desktop/trash_division_data/ultimate_4_class' # 合并后的新目录 +CLASSNAME_FILE = '/Users/weikaiwen/Desktop/trash_division_data/val/classname.txt' # txt 文件的位置 # =================================================== + + def build_mapping(): """让 Python 自动读取 txt 文件并建立映射字典""" mapping = {} @@ -24,14 +41,24 @@ def build_mapping(): # 用 '-' 把字符串一分为二:前面的做大类,后面的做小类 big_class, small_class = line.split('-', 1) - # 核心:修复原数据集中的错别字 Bug + # 修改错别字 if big_class == '其它垃圾': big_class = '其他垃圾' - # 为了绝对安全,我们把三种可能出现的文件夹名字全存进字典里: + + # 核心:变为数字分类 + if big_class == '厨余垃圾': + big_class = '1' + elif big_class == '可回收物': + big_class = '2' + elif big_class == '其他垃圾': + big_class = '3' + else : + big_class = '4' + + + # 把文件夹名字全存进字典里: mapping[str(idx)] = big_class # 应对文件夹名为数字 ID (如 '0') 的情况 - mapping[line] = big_class # 应对文件夹名为完整名称 (如 '厨余垃圾-八宝粥') 的情况 - mapping[small_class] = big_class # 应对文件夹名为小类名称 (如 '八宝粥') 的情况 return mapping @@ -59,7 +86,7 @@ def merge_dataset(): continue # 核心:通过字典查询这个小类属于哪个大类 - target_big_class = class_mapping.get(sub_class, "未知分类") + target_big_class = class_mapping.get(sub_class, "0") target_dir = os.path.join(new_split_dir, target_big_class) if not os.path.exists(target_dir): From 543e833fd07314156e728f0287e1c4397e5d17d7 Mon Sep 17 00:00:00 2001 From: weikaiwen348-code Date: Mon, 13 Apr 2026 16:44:23 +0800 Subject: [PATCH 3/4] =?UTF-8?q?=E4=BB=A3=E7=A0=81=E8=B7=AF=E5=BE=84?= =?UTF-8?q?=E8=B0=83=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data/merge_classes.py => merge_classes.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename data/merge_classes.py => merge_classes.py (100%) diff --git a/data/merge_classes.py b/merge_classes.py similarity index 100% rename from data/merge_classes.py rename to merge_classes.py From 1350cdd3197e7749b8edd7d3fe88fa1d7508f414 Mon Sep 17 00:00:00 2001 From: yukun-hh Date: Mon, 13 Apr 2026 22:20:28 +0800 Subject: [PATCH 4/4] =?UTF-8?q?=E6=95=B0=E6=8D=AE=E6=B8=85=E7=90=86?= =?UTF-8?q?=E7=A8=8B=E5=BA=8F=E6=94=B9=E6=88=90=E7=9B=B8=E5=AF=B9=E8=B7=AF?= =?UTF-8?q?=E5=BE=84=20=E5=AE=8C=E6=88=90dataloader?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Dataloader.py | 7 +++---- merge_classes.py => Merge_classes.py | 8 ++++---- Train.py | 13 +++++++++---- 3 files changed, 16 insertions(+), 12 deletions(-) rename merge_classes.py => Merge_classes.py (90%) diff --git a/Dataloader.py b/Dataloader.py index fab833b..7d9ec3e 100644 --- a/Dataloader.py +++ b/Dataloader.py @@ -15,7 +15,7 @@ import matplotlib.pyplot as plt import numpy as np -def create_dataloaders(data_root='..', +def create_dataloaders(data_root='../trash_division_data/ultimate_4_class/', batch_size=32, image_size=256, val_split=0.2, @@ -111,9 +111,8 @@ def create_dataloaders(data_root='..', ) # 4. 获取类别名称 - class_names = train_dataset.classes if hasattr(train_dataset, 'classes') else ['0', '1', '2', '3'] + class_names = train_dataset.classes print(f"类别: {class_names}") - print(f"类别映射: {train_dataset.class_to_idx if hasattr(train_dataset, 'class_to_idx') else '0-3'}") return train_loader, val_loader, class_names @@ -158,7 +157,7 @@ def visualize_batch(dataloader, class_names, num_images=8): if __name__ == '__main__': train_loader, val_loader, class_names = create_dataloaders( - data_root='..', # 与trash-division同级文件夹 + data_root='../trash_division_data/ultimate_4_class/', # 与trash-division同级文件夹 batch_size=32, # 根据你的显存调整 image_size=256, # 与你模型输入一致 num_workers=4, # Windows 可能需设为 0 diff --git a/merge_classes.py b/Merge_classes.py similarity index 90% rename from merge_classes.py rename to Merge_classes.py index a605aa2..145c345 100644 --- a/merge_classes.py +++ b/Merge_classes.py @@ -1,5 +1,5 @@ """将原数据集合并为我们需要的四个大类 - 运行时先配置路径 + 已修改成相对路径 具体配置方法详见README.md author: weikaiwen @@ -18,9 +18,9 @@ import shutil # ================= 1. 配置你的路径 ================= # 注意:请确保相对路径正确,以下为示例 -ORIGINAL_DATA_DIR = '/Users/weikaiwen/Desktop/trash_division_data' # 原始数据集的目录 -NEW_DATA_DIR = '/Users/weikaiwen/Desktop/trash_division_data/ultimate_4_class' # 合并后的新目录 -CLASSNAME_FILE = '/Users/weikaiwen/Desktop/trash_division_data/val/classname.txt' # txt 文件的位置 +ORIGINAL_DATA_DIR = '../trash_division_data' # 原始数据集的目录 +NEW_DATA_DIR = '../trash_division_data/ultimate_4_class' # 合并后的新目录 +CLASSNAME_FILE = '../trash_division_data/val/classname.txt' # txt 文件的位置 # =================================================== diff --git a/Train.py b/Train.py index 6c87c69..e5d462c 100644 --- a/Train.py +++ b/Train.py @@ -12,7 +12,7 @@ import torch.optim as optim from tqdm import tqdm # 进度条,可选 import matplotlib.pyplot as plt from Model import Net - +from Dataloader import create_dataloaders def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch): """训练一个epoch""" model.train() # 设置为训练模式 @@ -171,11 +171,16 @@ def plot_training_history(history): # ========== 使用示例 ========== if __name__ == '__main__': # 假设你的 dataloader 已经写好了 - # train_loader = ... - # val_loader = ... + train_loader, val_loader, class_names = create_dataloaders( + data_root='../trash_division_data/ultimate_4_class/', # 与trash-division同级文件夹 + batch_size=32, # 根据你的显存调整 + image_size=256, # 与你模型输入一致 + num_workers=4, # Windows 可能需设为 0 + augment=True # 训练时使用数据增强 + ) # 1. 创建模型 - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + device = torch.device('cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'cpu') model = Net().get_network() # 根据你的 Net 类调整 model = model.to(device)