提高数据加载鲁棒性

This commit is contained in:
yukun-hh 2026-04-16 20:57:48 +08:00
parent cff98f70cc
commit 968e108857

View file

@ -14,7 +14,73 @@ from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from tqdm import tqdm
from torch.utils.data import Dataset
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
class RobustImageFolder(Dataset):
    """ImageFolder-style dataset that filters out unreadable images up front.

    Class names and the (path, label) listing are taken from a temporary
    ``datasets.ImageFolder`` built on ``root``. Each listed file is then
    opened and checked with ``Image.verify()`` behind a tqdm progress bar;
    files that fail the check are counted and excluded from ``samples``.

    Args:
        root: Directory handed to ``datasets.ImageFolder``.
        transform: Optional callable applied to every RGB image returned by
            ``__getitem__``.

    Attributes:
        samples: List of ``(path, label)`` pairs that passed verification.
        classes: Class list copied from the temporary ImageFolder.
        class_to_idx: Class-to-index mapping copied from the same object.
    """

    def __init__(self, root, transform=None):
        self.transform = transform
        self.samples = []

        # Build a plain ImageFolder first, only to reuse its class discovery
        # and file listing.
        temp_dataset = datasets.ImageFolder(root, transform=None)
        self.classes = temp_dataset.classes
        self.class_to_idx = temp_dataset.class_to_idx

        print(f"\n正在扫描: {root}")
        print(f"发现 {len(temp_dataset.samples)} 个文件,开始验证...\n")

        corrupted_count = 0
        success_count = 0

        for path, label in tqdm(temp_dataset.samples,
                                desc="验证图片完整性",
                                unit="",
                                ncols=80):
            try:
                # Actually open the file: appending a tuple can never fail,
                # so without this the scan would accept every file.
                with Image.open(path) as probe:
                    probe.verify()
            except Exception:
                # Broad on purpose: verify() raises format-specific errors
                # and any failure here means "do not train on this file".
                corrupted_count += 1
                # Report only the first 10 failures to keep the log readable.
                if corrupted_count <= 10:
                    tqdm.write(f"⚠️ 跳过损坏: {os.path.basename(path)}")
                elif corrupted_count == 11:
                    tqdm.write("⚠️ 后续损坏图片将不再显示...")
            else:
                self.samples.append((path, label))
                success_count += 1

        print(f"\n✅ 扫描完成!")
        print(f" 📁 有效图片: {success_count}")
        print(f" ❌ 损坏跳过: {corrupted_count}")
        print(f" 📊 总计: {len(self.samples)}\n")

    def __len__(self):
        """Return the number of samples that passed verification."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, falling back to later samples.

        If loading or transforming the requested sample fails, the following
        samples are tried in order (wrapping around) and the first one that
        loads is returned together with its own label.

        Raises:
            IndexError: If ``idx`` is outside the sample list.
            RuntimeError: If every sample fails to load.
        """
        total = len(self.samples)
        position = idx
        last_error = None

        # Bounded loop instead of unbounded recursion: each sample is tried
        # at most once, so a fully broken dataset fails with a clear error
        # rather than exhausting the call stack.
        for _ in range(max(total, 1)):
            # An out-of-range idx raises IndexError here, before any fallback.
            path, label = self.samples[position]
            try:
                # The context manager closes the file handle promptly;
                # convert() returns an independent, fully loaded image.
                with Image.open(path) as raw:
                    img = raw.convert('RGB')
                if self.transform:
                    img = self.transform(img)
                return img, label
            except Exception as err:
                # Rare case: the file went bad after the scan, or the
                # transform failed. Move on to the next sample.
                last_error = err
                position = (position + 1) % total

        raise RuntimeError(
            f"no loadable image found among {total} samples "
            f"(started at index {idx})"
        ) from last_error
def create_dataloaders(data_root='../trash_division_data/ultimate_4_class/',
batch_size=32,
image_size=256,
@ -77,12 +143,12 @@ def create_dataloaders(data_root='../trash_division_data/ultimate_4_class/',
# 2. 加载数据集
# ==================================
print("使用独立的 val 文件夹")
train_dataset = datasets.ImageFolder(
train_dataset = RobustImageFolder(
root=os.path.join(data_root, 'train'),
transform=train_transform if augment else val_transform
)
val_dataset = datasets.ImageFolder(
val_dataset = RobustImageFolder(
root=os.path.join(data_root, 'val'),
transform=val_transform
)