提高数据加载鲁棒性

This commit is contained in:
yukun-hh 2026-04-16 20:57:48 +08:00
parent cff98f70cc
commit 968e108857

View file

@ -14,7 +14,73 @@ from PIL import Image
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from torch.utils.data import Dataset
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from tqdm import tqdm
from torch.utils.data import Dataset
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
class RobustImageFolder(Dataset):
    """ImageFolder-style dataset that filters out unreadable images up front.

    The directory is listed with ``datasets.ImageFolder`` (one sub-folder per
    class). Every listed file is then opened and verified with Pillow; files
    that fail verification are left out of ``samples``. A progress bar and a
    short summary are printed while scanning.

    Attributes:
        transform: Optional callable applied to each loaded RGB image.
        samples: List of ``(path, label)`` pairs that passed verification.
        classes: Class names as reported by ``datasets.ImageFolder``.
        class_to_idx: Class-name-to-label mapping from ``datasets.ImageFolder``.
    """

    def __init__(self, root, transform=None):
        """Scan *root*, keeping only the images Pillow can verify.

        Args:
            root: Dataset directory, laid out as ``datasets.ImageFolder`` expects.
            transform: Optional callable applied to each image in ``__getitem__``.
        """
        super().__init__()
        self.transform = transform
        self.samples = []

        # ImageFolder is used only for class discovery and the file listing;
        # decoding is done by this class so failures can be handled.
        listing = datasets.ImageFolder(root, transform=None)
        self.classes = listing.classes
        self.class_to_idx = listing.class_to_idx

        print(f"\n正在扫描: {root}")
        print(f"发现 {len(listing.samples)} 个文件,开始验证...\n")

        corrupted_count = 0
        success_count = 0
        for path, label in tqdm(listing.samples,
                                desc="验证图片完整性",
                                unit="",
                                ncols=80):
            if self._is_readable(path):
                self.samples.append((path, label))
                success_count += 1
                continue

            corrupted_count += 1
            # Only report the first 10 failures so the log is not flooded.
            if corrupted_count <= 10:
                tqdm.write(f"⚠️ 跳过损坏: {os.path.basename(path)}")
            elif corrupted_count == 11:
                tqdm.write("⚠️ 后续损坏图片将不再显示...")

        print("\n✅ 扫描完成!")
        print(f" 📁 有效图片: {success_count}")
        print(f" ❌ 损坏跳过: {corrupted_count}")
        print(f" 📊 总计: {len(self.samples)}\n")

    @staticmethod
    def _is_readable(path):
        """Return True if Pillow can open and verify the file at *path*."""
        try:
            # verify() checks file integrity without decoding the pixel data;
            # the context manager closes the underlying file handle.
            with Image.open(path) as probe:
                probe.verify()
        except Exception:
            # Pillow plugins raise heterogeneous types for broken files
            # (OSError, SyntaxError, ValueError, ...); any failure here
            # simply means "skip this file".
            return False
        return True

    def __len__(self):
        """Return the number of samples that passed verification."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load the sample at *idx*, falling back to later samples on failure.

        If loading or transforming the requested sample raises, the following
        samples are tried in order (wrapping around) until one succeeds.

        Args:
            idx: Sample index; negative indices count from the end.

        Returns:
            Tuple ``(image, label)`` where ``image`` is the RGB image after
            ``transform`` (if one was given).

        Raises:
            IndexError: If *idx* is out of range (including an empty dataset).
            RuntimeError: If no sample in the dataset could be loaded.
        """
        total = len(self.samples)
        if not -total <= idx < total:
            raise IndexError(f"index {idx} out of range for {total} samples")

        last_error = None
        # Bounded loop instead of recursion: each sample is tried at most
        # once, so a fully broken dataset fails fast with a clear error
        # rather than hitting the recursion limit.
        for offset in range(total):
            path, label = self.samples[(idx + offset) % total]
            try:
                with Image.open(path) as handle:
                    img = handle.convert('RGB')
                if self.transform:
                    img = self.transform(img)
                return img, label
            except Exception as err:
                # Rare after the up-front scan (e.g. file changed on disk);
                # move on to the next sample.
                last_error = err

        raise RuntimeError(
            f"none of the {total} samples could be loaded"
        ) from last_error
def create_dataloaders(data_root='../trash_division_data/ultimate_4_class/', def create_dataloaders(data_root='../trash_division_data/ultimate_4_class/',
batch_size=32, batch_size=32,
image_size=256, image_size=256,
@ -77,12 +143,12 @@ def create_dataloaders(data_root='../trash_division_data/ultimate_4_class/',
# 2. 加载数据集 # 2. 加载数据集
# ================================== # ==================================
print("使用独立的 val 文件夹") print("使用独立的 val 文件夹")
train_dataset = datasets.ImageFolder( train_dataset = RobustImageFolder(
root=os.path.join(data_root, 'train'), root=os.path.join(data_root, 'train'),
transform=train_transform if augment else val_transform transform=train_transform if augment else val_transform
) )
val_dataset = datasets.ImageFolder( val_dataset = RobustImageFolder(
root=os.path.join(data_root, 'val'), root=os.path.join(data_root, 'val'),
transform=val_transform transform=val_transform
) )