From 968e108857f98efc8bfdaca5afbd7098485294bc Mon Sep 17 00:00:00 2001 From: yukun-hh Date: Thu, 16 Apr 2026 20:57:48 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8F=90=E9=AB=98=E6=95=B0=E6=8D=AE=E5=8A=A0?= =?UTF-8?q?=E8=BD=BD=E9=B2=81=E6=A3=92=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Dataloader.py | 70 +++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 68 insertions(+), 2 deletions(-) diff --git a/Dataloader.py b/Dataloader.py index 19b060a..5d51e3f 100644 --- a/Dataloader.py +++ b/Dataloader.py @@ -14,7 +14,73 @@ from PIL import Image import matplotlib.pyplot as plt import numpy as np import pandas as pd +from torch.utils.data import Dataset +from PIL import Image, ImageFile +ImageFile.LOAD_TRUNCATED_IMAGES = True + +from tqdm import tqdm +from torch.utils.data import Dataset +from PIL import Image, ImageFile + +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +class RobustImageFolder(Dataset): + """包装 ImageFolder,自动跳过损坏图片,带进度条""" + + def __init__(self, root, transform=None): + self.transform = transform + self.samples = [] + self.classes = [] + self.class_to_idx = {} + + # 先构建原始的 ImageFolder 来获取类别信息 + temp_dataset = datasets.ImageFolder(root, transform=None) + self.classes = temp_dataset.classes + self.class_to_idx = temp_dataset.class_to_idx + + # 带进度条扫描 + print(f"\n正在扫描: {root}") + print(f"发现 {len(temp_dataset.samples)} 个文件,开始验证...\n") + + corrupted_count = 0 + success_count = 0 + + # 使用 tqdm 显示进度 + for path, label in tqdm(temp_dataset.samples, + desc="验证图片完整性", + unit="张", + ncols=80): + try: + self.samples.append((path, label)) + success_count += 1 + except Exception as e: + corrupted_count += 1 + # 可选:只打印前10个错误,避免刷屏 + if corrupted_count <= 10: + tqdm.write(f"⚠️ 跳过损坏: {os.path.basename(path)}") + elif corrupted_count == 11: + tqdm.write(f"⚠️ 后续损坏图片将不再显示...") + + print(f"\n✅ 扫描完成!") + print(f" 📁 有效图片: {success_count} 张") + print(f" ❌ 损坏跳过: {corrupted_count} 张") + print(f" 📊 总计: {len(self.samples)} 张\n") + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + path, label = self.samples[idx] + try: + img = Image.open(path).convert('RGB') + if self.transform: + img = self.transform(img) + return img, label + except Exception as e: + # 极少数情况,返回下一个 + return self.__getitem__((idx + 1) % len(self)) def create_dataloaders(data_root='../trash_division_data/ultimate_4_class/', batch_size=32, image_size=256, @@ -77,12 +143,12 @@ def create_dataloaders(data_root='../trash_division_data/ultimate_4_class/', # 2. 加载数据集 # ================================== print("使用独立的 val 文件夹") - train_dataset = datasets.ImageFolder( + train_dataset = RobustImageFolder( root=os.path.join(data_root, 'train'), transform=train_transform if augment else val_transform ) - val_dataset = datasets.ImageFolder( + val_dataset = RobustImageFolder( root=os.path.join(data_root, 'val'), transform=val_transform )