Merge branch 'data_cleaning_test'
This commit is contained in:
commit
793852eedd
3 changed files with 122 additions and 8 deletions
|
|
@ -15,7 +15,7 @@ import matplotlib.pyplot as plt
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
|
||||||
def create_dataloaders(data_root='..',
|
def create_dataloaders(data_root='../trash_division_data/ultimate_4_class/',
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
image_size=256,
|
image_size=256,
|
||||||
val_split=0.2,
|
val_split=0.2,
|
||||||
|
|
@ -111,9 +111,8 @@ def create_dataloaders(data_root='..',
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. 获取类别名称
|
# 4. 获取类别名称
|
||||||
class_names = train_dataset.classes if hasattr(train_dataset, 'classes') else ['0', '1', '2', '3']
|
class_names = train_dataset.classes
|
||||||
print(f"类别: {class_names}")
|
print(f"类别: {class_names}")
|
||||||
print(f"类别映射: {train_dataset.class_to_idx if hasattr(train_dataset, 'class_to_idx') else '0-3'}")
|
|
||||||
|
|
||||||
return train_loader, val_loader, class_names
|
return train_loader, val_loader, class_names
|
||||||
|
|
||||||
|
|
@ -158,7 +157,7 @@ def visualize_batch(dataloader, class_names, num_images=8):
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
train_loader, val_loader, class_names = create_dataloaders(
|
train_loader, val_loader, class_names = create_dataloaders(
|
||||||
data_root='..', # 与trash-division同级文件夹
|
data_root='../trash_division_data/ultimate_4_class/', # 与trash-division同级文件夹
|
||||||
batch_size=32, # 根据你的显存调整
|
batch_size=32, # 根据你的显存调整
|
||||||
image_size=256, # 与你模型输入一致
|
image_size=256, # 与你模型输入一致
|
||||||
num_workers=4, # Windows 可能需设为 0
|
num_workers=4, # Windows 可能需设为 0
|
||||||
|
|
|
||||||
110
Merge_classes.py
Normal file
110
Merge_classes.py
Normal file
|
|
@ -0,0 +1,110 @@
|
||||||
|
"""将原数据集合并为我们需要的四个大类
|
||||||
|
已修改成相对路径 具体配置方法详见README.md
|
||||||
|
|
||||||
|
author:
|
||||||
|
weikaiwen
|
||||||
|
|
||||||
|
厨余垃圾-1
|
||||||
|
可回收物-2
|
||||||
|
其他垃圾-3
|
||||||
|
有害垃圾-4
|
||||||
|
|
||||||
|
未知-0
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
# ================= 1. 配置你的路径 =================
|
||||||
|
# 注意:请确保相对路径正确,以下为示例
|
||||||
|
ORIGINAL_DATA_DIR = '../trash_division_data' # 原始数据集的目录
|
||||||
|
NEW_DATA_DIR = '../trash_division_data/ultimate_4_class' # 合并后的新目录
|
||||||
|
CLASSNAME_FILE = '../trash_division_data/val/classname.txt' # txt 文件的位置
|
||||||
|
# ===================================================
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def build_mapping(classname_file=None):
    """Build the mapping from sub-class folder name to merged class id.

    Every line of the class-name file is expected to look like
    ``<big class>-<small class>``. The zero-based line number, as a string,
    is used as the key because the dataset folders are named by numeric id
    (e.g. ``'0'``). Lines without a ``-`` (such as blank lines) are skipped
    but still consume a line number.

    Merged class ids follow the scheme documented at the top of this module:
    kitchen waste -> '1', recyclables -> '2', other waste -> '3',
    hazardous waste -> '4', anything unrecognised -> '0'.

    Args:
        classname_file: Path of the class-name text file. Defaults to the
            module-level ``CLASSNAME_FILE`` when omitted.

    Returns:
        dict mapping the line index (str) to the merged class id (str).
    """
    if classname_file is None:
        classname_file = CLASSNAME_FILE

    big_class_ids = {
        '厨余垃圾': '1',
        '可回收物': '2',
        '其他垃圾': '3',
        '其它垃圾': '3',  # variant spelling of the same category
        '有害垃圾': '4',
    }

    # 'utf-8-sig' transparently drops a leading BOM, which would otherwise be
    # glued to the first category name and make it unrecognisable.
    with open(classname_file, 'r', encoding='utf-8-sig') as f:
        lines = f.read().splitlines()

    mapping = {}
    for idx, line in enumerate(lines):
        big_class, separator, _ = line.partition('-')
        if not separator:
            # No '-' on this line (e.g. blank line): nothing to map.
            continue
        # Unrecognised categories go to '0' (unknown) instead of being
        # silently labelled as hazardous waste.
        mapping[str(idx)] = big_class_ids.get(big_class.strip(), '0')

    return mapping
|
||||||
|
|
||||||
|
def merge_dataset():
    """Copy the original fine-grained dataset into merged big-class folders.

    For each of the ``train`` and ``val`` splits under ``ORIGINAL_DATA_DIR``,
    every sub-class folder is looked up in the mapping returned by
    ``build_mapping()`` and its files are copied into
    ``NEW_DATA_DIR/<split>/<merged class id>``. Sub-class folders missing
    from the mapping go to class ``"0"``. Copied files are prefixed with the
    sub-class folder name so that identically named images from different
    sub-classes (e.g. ``001.jpg``) do not overwrite each other.

    A split whose source directory is missing is reported and skipped.
    """
    print("正在解析类别映射文件...")
    class_mapping = build_mapping()

    # Process both the training and the validation split.
    for split in ['train', 'val']:
        original_split_dir = os.path.join(ORIGINAL_DATA_DIR, split)
        new_split_dir = os.path.join(NEW_DATA_DIR, split)

        # isdir (rather than exists) also guards against a stray file with
        # the split's name, which os.listdir could not handle.
        if not os.path.isdir(original_split_dir):
            print(f"⚠️ 找不到文件夹: {original_split_dir},跳过处理。")
            continue

        print(f"\n🚀 开始合并 [{split}] 集合...")

        # Walk the original sub-class folders in a deterministic order.
        for sub_class in sorted(os.listdir(original_split_dir)):
            sub_class_path = os.path.join(original_split_dir, sub_class)

            # Ignore hidden files / description files; only handle folders.
            if not os.path.isdir(sub_class_path):
                continue

            # Look up which merged class this sub-class belongs to.
            target_big_class = class_mapping.get(sub_class, "0")
            target_dir = os.path.join(new_split_dir, target_big_class)
            os.makedirs(target_dir, exist_ok=True)

            for img in sorted(os.listdir(sub_class_path)):
                src_img_path = os.path.join(sub_class_path, img)

                # shutil.copy cannot copy a directory; skip anything that is
                # not a regular file instead of aborting the whole merge.
                if not os.path.isfile(src_img_path):
                    continue

                # Prefix with the sub-class name to avoid name collisions.
                dst_img_path = os.path.join(target_dir, f"{sub_class}_{img}")
                shutil.copy(src_img_path, dst_img_path)

        print(f"✅ [{split}] 集合四大类合并完成!")
|
||||||
|
|
||||||
|
# Script entry point: run the merge only when this file is executed directly.
if __name__ == "__main__":
    merge_dataset()
|
||||||
13
Train.py
13
Train.py
|
|
@ -12,7 +12,7 @@ import torch.optim as optim
|
||||||
from tqdm import tqdm # 进度条,可选
|
from tqdm import tqdm # 进度条,可选
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
from Model import Net
|
from Model import Net
|
||||||
|
from Dataloader import create_dataloaders
|
||||||
def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch):
|
def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch):
|
||||||
"""训练一个epoch"""
|
"""训练一个epoch"""
|
||||||
model.train() # 设置为训练模式
|
model.train() # 设置为训练模式
|
||||||
|
|
@ -171,11 +171,16 @@ def plot_training_history(history):
|
||||||
# ========== 使用示例 ==========
|
# ========== 使用示例 ==========
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# 假设你的 dataloader 已经写好了
|
# 假设你的 dataloader 已经写好了
|
||||||
# train_loader = ...
|
train_loader, val_loader, class_names = create_dataloaders(
|
||||||
# val_loader = ...
|
data_root='../trash_division_data/ultimate_4_class/', # 与trash-division同级文件夹
|
||||||
|
batch_size=32, # 根据你的显存调整
|
||||||
|
image_size=256, # 与你模型输入一致
|
||||||
|
num_workers=4, # Windows 可能需设为 0
|
||||||
|
augment=True # 训练时使用数据增强
|
||||||
|
)
|
||||||
|
|
||||||
# 1. 创建模型
|
# 1. 创建模型
|
||||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
device = torch.device('cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'cpu')
|
||||||
model = Net().get_network() # 根据你的 Net 类调整
|
model = Net().get_network() # 根据你的 Net 类调整
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue