新增 HOG + LogisticRegression 纯传统基线模型(零神经网络依赖)
This commit is contained in:
parent
3fee1c82ab
commit
562fc4142a
3 changed files with 153 additions and 3 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -17,6 +17,7 @@
|
||||||
!/baseline/VGG_KNN.py
|
!/baseline/VGG_KNN.py
|
||||||
!/baseline/compare_models.py
|
!/baseline/compare_models.py
|
||||||
!/baseline/ResNet34_Pretrained_10pct.py
|
!/baseline/ResNet34_Pretrained_10pct.py
|
||||||
|
!/baseline/HOG_Baseline.py
|
||||||
!/training_log.csv
|
!/training_log.csv
|
||||||
!/confusion_matrix.png
|
!/confusion_matrix.png
|
||||||
!/roc_curve.png
|
!/roc_curve.png
|
||||||
|
|
|
||||||
147
baseline/HOG_Baseline.py
Normal file
147
baseline/HOG_Baseline.py
Normal file
|
|
@ -0,0 +1,147 @@
|
||||||
|
"""
|
||||||
|
baseline/HOG_Baseline.py
|
||||||
|
HOG + 颜色直方图特征提取 + LogisticRegression 四分类
|
||||||
|
纯传统 CV/ML 基线,零神经网络依赖
|
||||||
|
可独立运行,也可被 compare_models.py 导入
|
||||||
|
author: yukun-hh
|
||||||
|
date: 2026-5-14
|
||||||
|
"""
|
||||||
|
import sys, os
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import matplotlib
|
||||||
|
|
||||||
|
from skimage.feature import hog
|
||||||
|
from sklearn.linear_model import LogisticRegression
|
||||||
|
from sklearn.metrics import (
|
||||||
|
accuracy_score, f1_score,
|
||||||
|
confusion_matrix, ConfusionMatrixDisplay,
|
||||||
|
classification_report,
|
||||||
|
)
|
||||||
|
|
||||||
|
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
|
||||||
|
matplotlib.rcParams['axes.unicode_minus'] = False
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# ★★★ 可配置参数 ★★★
|
||||||
|
# ============================================================
|
||||||
|
DATA_ROOT = '../../trash_division_data/ultimate_4_class/'
|
||||||
|
IMAGE_SIZE = 128
|
||||||
|
HOG_ORIENTATIONS = 9
|
||||||
|
HOG_PIXELS_PER_CELL = (8, 8)
|
||||||
|
HOG_CELLS_PER_BLOCK = (2, 2)
|
||||||
|
COLOR_BINS = 32
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
|
||||||
|
NUM_CLASSES = 4
|
||||||
|
|
||||||
|
|
||||||
|
def extract_hog_color(image):
|
||||||
|
img = image.convert('RGB').resize((IMAGE_SIZE, IMAGE_SIZE))
|
||||||
|
arr = np.array(img, dtype=np.float64) / 255.0
|
||||||
|
|
||||||
|
hog_feat = hog(arr, orientations=HOG_ORIENTATIONS,
|
||||||
|
pixels_per_cell=HOG_PIXELS_PER_CELL,
|
||||||
|
cells_per_block=HOG_CELLS_PER_BLOCK,
|
||||||
|
channel_axis=2, feature_vector=True)
|
||||||
|
|
||||||
|
color_feat = []
|
||||||
|
for c in range(3):
|
||||||
|
hist, _ = np.histogram(arr[:, :, c], bins=COLOR_BINS, range=(0, 1))
|
||||||
|
color_feat.append(hist)
|
||||||
|
color_feat = np.concatenate(color_feat)
|
||||||
|
|
||||||
|
return np.concatenate([hog_feat, color_feat])
|
||||||
|
|
||||||
|
|
||||||
|
class HOGLRBaseline:
|
||||||
|
def __init__(self, data_root=DATA_ROOT, image_size=IMAGE_SIZE):
|
||||||
|
self.data_root = data_root
|
||||||
|
self.image_size = image_size
|
||||||
|
self.clf = LogisticRegression(
|
||||||
|
C=1.0, max_iter=1000,
|
||||||
|
multi_class='multinomial', solver='lbfgs',
|
||||||
|
n_jobs=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _load_data(self, split):
|
||||||
|
dir_path = os.path.join(self.data_root, split)
|
||||||
|
features, labels = [], []
|
||||||
|
for class_id in range(1, NUM_CLASSES + 1):
|
||||||
|
class_dir = os.path.join(dir_path, str(class_id))
|
||||||
|
if not os.path.isdir(class_dir):
|
||||||
|
continue
|
||||||
|
files = sorted(os.listdir(class_dir))
|
||||||
|
for fname in tqdm(files, desc=f'{split}/class_{class_id}'):
|
||||||
|
fpath = os.path.join(class_dir, fname)
|
||||||
|
try:
|
||||||
|
with Image.open(fpath) as img:
|
||||||
|
feat = extract_hog_color(img)
|
||||||
|
features.append(feat)
|
||||||
|
labels.append(class_id - 1)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
print(f" {split}: {len(features)} 张")
|
||||||
|
return np.array(features, dtype=np.float32), np.array(labels)
|
||||||
|
|
||||||
|
def fit(self, train_dir=None):
|
||||||
|
if train_dir is None:
|
||||||
|
train_dir = 'train'
|
||||||
|
print(" 提取训练集 HOG 特征 ...")
|
||||||
|
X, y = self._load_data(train_dir)
|
||||||
|
self.clf.fit(X, y)
|
||||||
|
|
||||||
|
def predict(self, val_dir=None):
|
||||||
|
if val_dir is None:
|
||||||
|
val_dir = 'val'
|
||||||
|
print(" 提取验证集 HOG 特征 ...")
|
||||||
|
X, y = self._load_data(val_dir)
|
||||||
|
preds = self.clf.predict(X)
|
||||||
|
probs = self.clf.predict_proba(X)
|
||||||
|
return y, preds, probs
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# compare_models.py 导入接口
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def get_hog_lr_preds(train_loader, val_loader, device):
|
||||||
|
baseline = HOGLRBaseline()
|
||||||
|
baseline.fit('train')
|
||||||
|
return baseline.predict('val')
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# 独立运行入口
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
out_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
print("HOG + LogisticRegression 基线")
|
||||||
|
baseline = HOGLRBaseline()
|
||||||
|
|
||||||
|
baseline.fit('train')
|
||||||
|
y_true, y_preds, y_probs = baseline.predict('val')
|
||||||
|
|
||||||
|
acc = accuracy_score(y_true, y_preds)
|
||||||
|
macro_f1 = f1_score(y_true, y_preds, average='macro')
|
||||||
|
print(f"\n验证集 Accuracy: {acc:.4f}")
|
||||||
|
print(f"验证集 Macro-F1: {macro_f1:.4f}")
|
||||||
|
print(f"\n分类报告:\n{classification_report(y_true, y_preds, target_names=CLASS_NAMES)}")
|
||||||
|
|
||||||
|
cm = confusion_matrix(y_true, y_preds)
|
||||||
|
fig, ax = plt.subplots(figsize=(8, 7))
|
||||||
|
ConfusionMatrixDisplay(cm, display_labels=CLASS_NAMES).plot(
|
||||||
|
ax=ax, cmap='Blues', values_format='d', xticks_rotation=30)
|
||||||
|
ax.set_title('HOG + LogisticRegression 混淆矩阵', fontsize=14)
|
||||||
|
plt.tight_layout()
|
||||||
|
cm_path = os.path.join(out_dir, 'hog_lr_confusion_matrix.png')
|
||||||
|
plt.savefig(cm_path, dpi=150, bbox_inches='tight')
|
||||||
|
plt.show()
|
||||||
|
print(f"混淆矩阵已保存: {cm_path}")
|
||||||
|
|
@ -26,6 +26,7 @@ from Model import Net
|
||||||
from Dataloader import RobustImageFolder
|
from Dataloader import RobustImageFolder
|
||||||
from baseline.VGG_KNN import VGGKNNBaseline
|
from baseline.VGG_KNN import VGGKNNBaseline
|
||||||
from baseline.ResNet34_Pretrained_10pct import get_resnet34_10pct_preds
|
from baseline.ResNet34_Pretrained_10pct import get_resnet34_10pct_preds
|
||||||
|
from baseline.HOG_Baseline import get_hog_lr_preds
|
||||||
|
|
||||||
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
|
matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans']
|
||||||
matplotlib.rcParams['axes.unicode_minus'] = False
|
matplotlib.rcParams['axes.unicode_minus'] = False
|
||||||
|
|
@ -81,9 +82,10 @@ def get_vgg_knn_preds(train_loader, val_loader, device):
|
||||||
# ============================================================
|
# ============================================================
|
||||||
|
|
||||||
MODELS = [
|
MODELS = [
|
||||||
('ResNet-34', get_resnet34_preds),
|
('ResNet-34', get_resnet34_preds),
|
||||||
('ResNet-34 (10% Fine-tune)', get_resnet34_10pct_preds),
|
('ResNet-34 (10% Fine-tune)', get_resnet34_10pct_preds),
|
||||||
('VGG16 + KNN (K=5)', get_vgg_knn_preds),
|
('VGG16 + KNN (K=5)', get_vgg_knn_preds),
|
||||||
|
('HOG + LogisticRegression', get_hog_lr_preds),
|
||||||
# 未来轻松扩展示例:
|
# 未来轻松扩展示例:
|
||||||
# ('ResNet-18 (pretrained)', get_resnet18_preds),
|
# ('ResNet-18 (pretrained)', get_resnet18_preds),
|
||||||
# ('ResNet-50 (pretrained)', get_resnet50_preds),
|
# ('ResNet-50 (pretrained)', get_resnet50_preds),
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue