diff --git a/.gitignore b/.gitignore index 8bacd63..53ddded 100644 --- a/.gitignore +++ b/.gitignore @@ -18,6 +18,9 @@ !/baseline/compare_models.py !/baseline/ResNet34_Pretrained_10pct.py !/baseline/HOG_Baseline.py +!/baseline/roc_comparison.png +!/baseline/pr_comparison.png +!/baseline/accuracy_bar.png !/training_log.csv !/confusion_matrix.png !/roc_curve.png diff --git a/README.md b/README.md index 6c47516..e859ff4 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,7 @@ - Pillow - torchsummary - scikit-learn(仅 `Evaluate.py` 需要) +- scikit-image(仅 `baseline/HOG_Baseline.py` 需要) ## 快速开始 @@ -110,8 +111,9 @@ 4. **评估与可视化**: ```bash - python Evaluate.py # 混淆矩阵、ROC 曲线、PR 曲线 - python Curve.py # 训练过程的 loss/f1/acc/lr 曲线 + python Evaluate.py # 混淆矩阵、ROC 曲线、PR 曲线 + python Curve.py # 训练过程的 loss/f1/acc/lr 曲线 + python baseline/compare_models.py # 多模型基线对比(ROC/PR/准确率) ``` > **注意**: @@ -123,9 +125,6 @@ | 文件 | 功能 | |---|---| -| `Baseline.py` → `baseline/` | 基线模型目录,VGG16+KNN 及多模型对比 | -| `baseline/VGG_KNN.py` | VGG16 预训练特征提取 + KNN 四分类 | -| `baseline/compare_models.py` | 多模型 ROC 曲线与准确率柱状图对比 | | `Train.py` | 训练主脚本,包含训练循环、验证、评估 | | `Finetune.py` | 微调脚本,冻结浅层后微调深层网络 | | `Dataloader.py` | 数据加载模块,包含 RobustImageFolder 和 DataLoader 创建 | @@ -133,6 +132,10 @@ | `Merge_classes.py` | 数据集预处理,265 类合并为 4 类 | | `Evaluate.py` | 模型评估,绘制混淆矩阵、ROC 曲线、PR 曲线 | | `Curve.py` | 训练曲线绘制,从 CSV 读取并绘制 loss/f1/acc/lr 曲线 | +| `baseline/VGG_KNN.py` | VGG16 预训练特征提取 + KNN 四分类基线 | +| `baseline/ResNet34_Pretrained_10pct.py` | ResNet-34 ImageNet 预训练 + 10% 数据微调 | +| `baseline/HOG_Baseline.py` | HOG + 颜色直方图 + LogisticRegression(纯传统 CV) | +| `baseline/compare_models.py` | 多模型对比(ROC / PR 曲线 + 准确率柱状图) | | `training_log.csv` | 训练日志,记录每轮 epoch 的 loss、f1、acc、lr | | `best_model.pth` | 训练好的最佳模型权重(约 125 MB,不纳入版本控制) | | `AGENTS.md` | AI 助手指南(开发辅助) | @@ -145,7 +148,12 @@ trash-division/ ├── AGENTS.md # AI 助手指南 ├── baseline/ # 基线模型目录 │ ├── VGG_KNN.py # VGG16 + KNN 分类脚本 -│ └── compare_models.py # 多模型对比脚本 +│ ├── ResNet34_Pretrained_10pct.py # ResNet-34 ImageNet 预训练 + 10% 微调 +│ ├── HOG_Baseline.py # HOG + LogisticRegression 纯传统基线 +│ ├── compare_models.py # 多模型对比脚本 +│ ├── roc_comparison.png # 多模型 ROC 对比(compare_models.py 输出) +│ ├── pr_comparison.png # 多模型 PR 对比(compare_models.py 输出) +│ └── accuracy_bar.png # 多模型准确率对比(compare_models.py 输出) ├── best_model.pth # 最佳模型权重(不纳入版本控制) ├── Curve.py # 训练曲线绘制脚本 ├── Dataloader.py # 数据加载模块 @@ -218,6 +226,28 @@ python Curve.py ![training_curves](training_curves.png) +### 基线模型对比 + +`compare_models.py` 对所有模型在验证集上统一评估,生成三张对比图表: + +```bash +python baseline/compare_models.py +``` + +对比阵容:ResNet-34、ResNet-34 (10% Fine-tune)、VGG16 + KNN、HOG + LogisticRegression。 + +**ROC 曲线对比** + +![roc_comparison](baseline/roc_comparison.png) + +**PR 曲线对比** + +![pr_comparison](baseline/pr_comparison.png) + +**准确率柱状图** + +![accuracy_bar](baseline/accuracy_bar.png) + ## 许可证 本项目主代码采用 [MIT 许可证](LICENSE)。 diff --git a/baseline/accuracy_bar.png b/baseline/accuracy_bar.png new file mode 100644 index 0000000..a5ddaef Binary files /dev/null and b/baseline/accuracy_bar.png differ diff --git a/baseline/pr_comparison.png b/baseline/pr_comparison.png new file mode 100644 index 0000000..03f80ef Binary files /dev/null and b/baseline/pr_comparison.png differ diff --git a/baseline/roc_comparison.png b/baseline/roc_comparison.png new file mode 100644 index 0000000..66ae8cc Binary files /dev/null and b/baseline/roc_comparison.png differ