提高数据加载鲁棒性
This commit is contained in:
parent
cff98f70cc
commit
968e108857
1 changed files with 68 additions and 2 deletions
|
|
@ -14,7 +14,73 @@ from PIL import Image
|
|||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from torch.utils.data import Dataset
|
||||
from PIL import Image, ImageFile
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
from tqdm import tqdm
|
||||
from torch.utils.data import Dataset
|
||||
from PIL import Image, ImageFile
|
||||
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
|
||||
class RobustImageFolder(Dataset):
    """ImageFolder-style dataset that skips unreadable images.

    The directory is first indexed with ``datasets.ImageFolder`` to obtain
    the class list and the candidate files.  Every candidate is then checked
    once, with a progress bar, and only files that pass the check are kept
    in ``self.samples``.

    Attributes:
        transform: Optional callable applied to each loaded image.
        loader: Optional callable ``path -> image``.  When ``None`` the
            image is opened with PIL and converted to RGB.
        samples: List of ``(path, label)`` pairs that passed validation.
        classes: Class names as reported by ``datasets.ImageFolder``.
        class_to_idx: Class-name-to-label mapping from ``datasets.ImageFolder``.
    """

    # How many following samples ``__getitem__`` may try after a failure
    # before giving up.  Bounded so a systematic error (for example a broken
    # transform) surfaces quickly instead of walking the whole dataset.
    _MAX_FALLBACKS = 10

    def __init__(self, root, transform=None, loader=None):
        """Index ``root`` and keep only the images that can be read.

        Args:
            root: Directory laid out as ``root/<class_name>/<image files>``.
            transform: Optional callable applied to each loaded image.
            loader: Optional callable ``path -> image`` used instead of the
                default PIL RGB loader, both for validation and loading.
        """
        self.transform = transform
        self.loader = loader
        self.samples = []

        # Build a plain ImageFolder first to get the class metadata and the
        # list of candidate files.
        temp_dataset = datasets.ImageFolder(root, transform=None)
        self.classes = temp_dataset.classes
        self.class_to_idx = temp_dataset.class_to_idx

        print(f"\n正在扫描: {root}")
        print(f"发现 {len(temp_dataset.samples)} 个文件,开始验证...\n")

        corrupted_count = 0
        for path, label in tqdm(temp_dataset.samples,
                                desc="验证图片完整性",
                                unit="张",
                                ncols=80):
            try:
                self._check_readable(path)
            except Exception:
                # Deliberately broad: image decoders raise many unrelated
                # exception types for broken files, and one bad file must
                # not abort the scan.  The file is counted and reported.
                corrupted_count += 1
                # Only report the first 10 failures to avoid flooding the log.
                if corrupted_count <= 10:
                    tqdm.write(f"⚠️ 跳过损坏: {os.path.basename(path)}")
                elif corrupted_count == 11:
                    tqdm.write("⚠️ 后续损坏图片将不再显示...")
            else:
                self.samples.append((path, label))

        success_count = len(self.samples)
        print("\n✅ 扫描完成!")
        print(f" 📁 有效图片: {success_count} 张")
        print(f" ❌ 损坏跳过: {corrupted_count} 张")
        print(f" 📊 总计: {len(self.samples)} 张\n")

    def _check_readable(self, path):
        """Raise if the file at ``path`` cannot be read as an image."""
        if self.loader is not None:
            # A custom loader defines what "readable" means for this dataset.
            self.loader(path)
            return
        # verify() checks file integrity without decoding the pixel data;
        # the context manager closes the file handle afterwards.
        with Image.open(path) as img:
            img.verify()

    def _load(self, path):
        """Return the image stored at ``path`` (RGB PIL image by default)."""
        if self.loader is not None:
            return self.loader(path)
        # convert() decodes the data and returns a new image, so the file
        # can be closed as soon as the ``with`` block ends.
        with Image.open(path) as img:
            return img.convert('RGB')

    def __len__(self):
        """Return the number of samples that passed validation."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, falling back on failure.

        If loading or transforming the requested sample fails, the following
        samples (wrapping around at the end) are tried instead, up to
        ``_MAX_FALLBACKS`` of them.

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If the requested sample and every fallback failed;
                the last underlying error is attached as ``__cause__``.
        """
        # Index directly first so an out-of-range idx still raises
        # IndexError, which sequence-style iteration relies on to stop.
        path, label = self.samples[idx]
        total = len(self.samples)
        attempts = min(total, self._MAX_FALLBACKS + 1)

        last_error = None
        for offset in range(attempts):
            if offset:
                path, label = self.samples[(idx + offset) % total]
            try:
                img = self._load(path)
                if self.transform is not None:
                    img = self.transform(img)
            except Exception as exc:
                # Rare case (file changed after the scan, transform error):
                # remember the error and try the next sample.
                last_error = exc
            else:
                return img, label

        raise RuntimeError(
            f"failed to load sample {idx} and {attempts - 1} fallback sample(s)"
        ) from last_error
|
||||
def create_dataloaders(data_root='../trash_division_data/ultimate_4_class/',
|
||||
batch_size=32,
|
||||
image_size=256,
|
||||
|
|
@ -77,12 +143,12 @@ def create_dataloaders(data_root='../trash_division_data/ultimate_4_class/',
|
|||
# 2. 加载数据集
|
||||
# ==================================
|
||||
print("使用独立的 val 文件夹")
|
||||
train_dataset = datasets.ImageFolder(
|
||||
train_dataset = RobustImageFolder(
|
||||
root=os.path.join(data_root, 'train'),
|
||||
transform=train_transform if augment else val_transform
|
||||
)
|
||||
|
||||
val_dataset = datasets.ImageFolder(
|
||||
val_dataset = RobustImageFolder(
|
||||
root=os.path.join(data_root, 'val'),
|
||||
transform=val_transform
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue