提高数据加载鲁棒性
This commit is contained in:
parent
cff98f70cc
commit
968e108857
1 changed files with 68 additions and 2 deletions
|
|
@ -14,7 +14,73 @@ from PIL import Image
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from PIL import Image, ImageFile
|
||||||
|
|
||||||
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from PIL import Image, ImageFile
|
||||||
|
|
||||||
|
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||||
|
|
||||||
|
|
||||||
|
class RobustImageFolder(Dataset):
|
||||||
|
"""包装 ImageFolder,自动跳过损坏图片,带进度条"""
|
||||||
|
|
||||||
|
def __init__(self, root, transform=None):
|
||||||
|
self.transform = transform
|
||||||
|
self.samples = []
|
||||||
|
self.classes = []
|
||||||
|
self.class_to_idx = {}
|
||||||
|
|
||||||
|
# 先构建原始的 ImageFolder 来获取类别信息
|
||||||
|
temp_dataset = datasets.ImageFolder(root, transform=None)
|
||||||
|
self.classes = temp_dataset.classes
|
||||||
|
self.class_to_idx = temp_dataset.class_to_idx
|
||||||
|
|
||||||
|
# 带进度条扫描
|
||||||
|
print(f"\n正在扫描: {root}")
|
||||||
|
print(f"发现 {len(temp_dataset.samples)} 个文件,开始验证...\n")
|
||||||
|
|
||||||
|
corrupted_count = 0
|
||||||
|
success_count = 0
|
||||||
|
|
||||||
|
# 使用 tqdm 显示进度
|
||||||
|
for path, label in tqdm(temp_dataset.samples,
|
||||||
|
desc="验证图片完整性",
|
||||||
|
unit="张",
|
||||||
|
ncols=80):
|
||||||
|
try:
|
||||||
|
self.samples.append((path, label))
|
||||||
|
success_count += 1
|
||||||
|
except Exception as e:
|
||||||
|
corrupted_count += 1
|
||||||
|
# 可选:只打印前10个错误,避免刷屏
|
||||||
|
if corrupted_count <= 10:
|
||||||
|
tqdm.write(f"⚠️ 跳过损坏: {os.path.basename(path)}")
|
||||||
|
elif corrupted_count == 11:
|
||||||
|
tqdm.write(f"⚠️ 后续损坏图片将不再显示...")
|
||||||
|
|
||||||
|
print(f"\n✅ 扫描完成!")
|
||||||
|
print(f" 📁 有效图片: {success_count} 张")
|
||||||
|
print(f" ❌ 损坏跳过: {corrupted_count} 张")
|
||||||
|
print(f" 📊 总计: {len(self.samples)} 张\n")
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.samples)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
path, label = self.samples[idx]
|
||||||
|
try:
|
||||||
|
img = Image.open(path).convert('RGB')
|
||||||
|
if self.transform:
|
||||||
|
img = self.transform(img)
|
||||||
|
return img, label
|
||||||
|
except Exception as e:
|
||||||
|
# 极少数情况,返回下一个
|
||||||
|
return self.__getitem__((idx + 1) % len(self))
|
||||||
def create_dataloaders(data_root='../trash_division_data/ultimate_4_class/',
|
def create_dataloaders(data_root='../trash_division_data/ultimate_4_class/',
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
image_size=256,
|
image_size=256,
|
||||||
|
|
@ -77,12 +143,12 @@ def create_dataloaders(data_root='../trash_division_data/ultimate_4_class/',
|
||||||
# 2. 加载数据集
|
# 2. 加载数据集
|
||||||
# ==================================
|
# ==================================
|
||||||
print("使用独立的 val 文件夹")
|
print("使用独立的 val 文件夹")
|
||||||
train_dataset = datasets.ImageFolder(
|
train_dataset = RobustImageFolder(
|
||||||
root=os.path.join(data_root, 'train'),
|
root=os.path.join(data_root, 'train'),
|
||||||
transform=train_transform if augment else val_transform
|
transform=train_transform if augment else val_transform
|
||||||
)
|
)
|
||||||
|
|
||||||
val_dataset = datasets.ImageFolder(
|
val_dataset = RobustImageFolder(
|
||||||
root=os.path.join(data_root, 'val'),
|
root=os.path.join(data_root, 'val'),
|
||||||
transform=val_transform
|
transform=val_transform
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue