Update accuracy_bar.png
parent 25e11c1914
commit 65c9742a1c
4 changed files with 0 additions and 87 deletions
.gitignore (vendored), 1 line changed
@@ -8,7 +8,6 @@
!/THIRD_PARTY_LICENSES.md
!/Train.py
!/Baseline.py
!/AGENTS.md
!/Finetune.py
!/Curve.py
!/Evaluate.py
AGENTS.md, 84 lines changed
@@ -1,84 +0,0 @@
# AGENTS.md

## Project

CNN-based garbage classification (4 classes: kitchen waste / recyclables / other waste / hazardous waste). ResNet-34 architecture, ~21M params, 256×256 RGB input, ~900 lines across 11 Python files. No package structure.

## Pipeline (order matters)

```bash
python Merge_classes.py  # merges 265 → 4 classes, creates ../trash_division_data/ultimate_4_class/
python Train.py  # trains the model, saves best_model.pth + training_log.csv
python Finetune.py  # optional: freezes early layers, saves finetuned_model.pth + finetune_log.csv
python Evaluate.py  # plots confusion matrix / ROC / PR curves from best_model.pth
python Curve.py  # plots loss/f1/acc/lr curves from training_log.csv
python baseline/VGG_KNN.py  # VGG16 feature extraction + KNN baseline
python baseline/compare_models.py  # compares multiple models (ROC + accuracy bar chart)
```

Also usable standalone: `python Model.py` prints a `torchsummary` parameter summary.

## Dependencies

No `requirements.txt`; install manually: `torch`, `torchvision`, `tqdm`, `matplotlib`, `pandas`, `Pillow`, `torchsummary`. `Evaluate.py` and `baseline/*.py` additionally need `scikit-learn`.

## Data setup

Data expected **outside repo** at `../trash_division_data/` (sibling dir). `Merge_classes.py` reads `val/classname.txt` there; `Train.py` and `Finetune.py` expect `ultimate_4_class/{train,val}/` with class-numbered subdirs (`1/` to `4/`). All paths relative to repo root.
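
The layout described above can be checked before training. The following is a minimal sketch, not part of the repository: the helper name `missing_dataset_dirs` is hypothetical, and it only assumes the sibling-directory layout and the class subdirectories `1/` to `4/` stated in this section.

```python
from pathlib import Path


def missing_dataset_dirs(repo_root: Path) -> list[Path]:
    """Return the expected class directories that do not exist yet."""
    # Data lives in a sibling directory of the repo (../trash_division_data/).
    data_root = repo_root.parent / "trash_division_data" / "ultimate_4_class"
    expected = [
        data_root / split / str(class_id)
        for split in ("train", "val")
        for class_id in range(1, 5)  # class-numbered subdirs 1/ to 4/
    ]
    return [path for path in expected if not path.is_dir()]


if __name__ == "__main__":
    # Assumes the script is started from the repo root.
    missing = missing_dataset_dirs(Path.cwd())
    for path in missing:
        print(f"missing: {path}")
    print("layout OK" if not missing else f"{len(missing)} directories missing")
```

Running it from the repo root after `Merge_classes.py` should report an intact layout if the merge step created all eight directories.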

## .gitignore: whitelist pattern

`.gitignore` uses `*` (ignore everything) then un-ignores specific files with `!` patterns. **Any new file you add to the repo must be explicitly whitelisted** or it will be invisible to git. The current whitelist includes: `Dataloader.py`, `LICENSE`, `Merge_classes.py`, `Model.py`, `README.md`, `THIRD_PARTY_LICENSES.md`, `Train.py`, `.gitattributes`, `.gitignore`, plus `Finetune.py`, `Curve.py`, `Evaluate.py`, `AGENTS.md`, 4× output PNG, `training_log.csv`, and `baseline/`.

`best_model.pth` and `finetuned_model.pth` are **untracked** (~125 MB each); back them up manually if needed. `Finetune.py`, `Curve.py`, `Evaluate.py`, `AGENTS.md`, `training_log*.csv`, and `finetune_log.csv` are also untracked (not in whitelist).

## Gotchas

- **Windows: set `num_workers=0`** in `create_dataloaders()` call sites (`Train.py:191`, `Finetune.py:196`, `Dataloader.py:229`)
- Device selection priority: `cuda > xpu > cpu` (`xpu` = Intel GPU)
- Training auto-resumes from `best_model.pth` if present in repo root; fine-tuning auto-loads it too
- `Dataloader.py` uses `RobustImageFolder`, which scans all images and skips corrupted ones (tqdm progress); slow on first load
- Image normalization: hardcoded ImageNet stats (`mean=[0.485, 0.456, 0.406]`, `std=[0.229, 0.224, 0.225]`)
- `create_dataloaders()` has a `val_split` parameter that's **never used**; the code always expects a pre-split `val/` folder

### Finetune-specific

- **BUG**: `freeze_base_layers()` references `model.stage2` and `model.stage3` but the model uses `layer2`/`layer3`. This crashes at runtime; fix to `model.layer2`/`model.layer3` (or delete the function, since it would freeze `layer2`+`layer3` while docstring says only conv1+stage2).
- `freeze_base_layers()` actually freezes **conv1, bn1, layer2, AND layer3** (despite docstring saying only conv1 + stage2). Only layer1, layer4, and fc are trainable.
- Class weights use `power=1.5` (vs `power=1.0` in Train), which amplifies minority-class weighting
- Defaults: `lr=0.0001`, `epochs=30` (vs `lr=0.001`, `epochs=20` in Train)
- Writes `finetune_log.csv` (Train writes `training_log.csv`)
- Loads `best_model.pth` then saves `finetuned_model.pth`

### Curve.py

- Hardcoded to read `training_log.csv` only; won't work for `finetune_log.csv`
- Requires `pandas`, saves `training_curves.png`

### Evaluate.py

- Hardcoded constants at top of `__main__` block: `MODEL_PATH`, `DATA_ROOT`, `BATCH_SIZE`, `NUM_WORKERS`
- Loads model from `best_model.pth` by default; handles both bare state_dict and `model_state_dict`/`model` key wrappers
- Saves `confusion_matrix.png`, `roc_curve.png`, `pr_curve.png`
- Requires `scikit-learn`
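
The checkpoint handling mentioned in the list above (bare state_dict versus `model_state_dict`/`model` wrappers) could look roughly like the sketch below. The function name `extract_state_dict` is hypothetical and the actual logic in `Evaluate.py` may differ; plain dictionaries stand in for tensors so the example runs without PyTorch.

```python
from typing import Any, Mapping


def extract_state_dict(checkpoint: Mapping[str, Any]) -> Mapping[str, Any]:
    """Return the weights mapping from a bare or wrapped checkpoint."""
    # Wrapped checkpoints keep the weights under one of these keys.
    for key in ("model_state_dict", "model"):
        if key in checkpoint and isinstance(checkpoint[key], Mapping):
            return checkpoint[key]
    # Otherwise assume the checkpoint already is a bare state_dict.
    return checkpoint


# Stand-in weights; a real checkpoint would map parameter names to tensors.
weights = {"fc.weight": [0.1, 0.2], "fc.bias": [0.0]}
wrapped = {"model_state_dict": weights, "epoch": 12}
print(extract_state_dict(wrapped) is weights)  # True
print(extract_state_dict(weights) is weights)  # True
```

With PyTorch, the result would typically be passed on as `model.load_state_dict(extract_state_dict(torch.load(MODEL_PATH, map_location="cpu")))`.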

### baseline/ (VGG_KNN.py + compare_models.py)

- `baseline/VGG_KNN.py` can run standalone (`python baseline/VGG_KNN.py`) or be imported from `compare_models.py`
- Uses `sys.path.insert` at top so it can import root-level modules (`Model`, `Dataloader`) from subdirectory
- `compare_models.py` has a `MODELS` registry list: add new models by writing a `get_xxx_preds(train_loader, val_loader, device)` function and adding one line to the list; no plot code changes needed
- VGG16 feature dimension: 25088 (512 channels × 7×7 avgpool)
- KNN uses `predict_proba` (neighbor voting proportions) for ROC curves: coarse-grained, but the AUC is valid
- Output: `baseline/roc_comparison.png`, `baseline/accuracy_bar.png`, `baseline/vgg_knn_confusion_matrix.png`
- Compare scripts output images to `baseline/` dir (not repo root)
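
To illustrate why the KNN scores mentioned above are coarse-grained: with uniform neighbor weights, scikit-learn's `KNeighborsClassifier.predict_proba` returns the fraction of the k nearest neighbors that belong to each class, so only k + 1 distinct score values are possible per class. The sketch below reproduces that voting rule in plain Python; the helper name, the value k = 5, and the neighbor labels are assumptions for illustration, since the settings used in `VGG_KNN.py` are not shown here.

```python
from collections import Counter


def vote_proportions(neighbor_labels, classes):
    """Fraction of neighbors voting for each class, in `classes` order."""
    counts = Counter(neighbor_labels)
    k = len(neighbor_labels)
    return [counts[c] / k for c in classes]


# Hypothetical example: labels of the 5 nearest neighbors of one image.
probs = vote_proportions([1, 1, 3, 1, 4], classes=[1, 2, 3, 4])
print(probs)  # [0.6, 0.0, 0.2, 0.2]
```

Because every score is a multiple of 1/k, the resulting ROC curve has only a few threshold steps, but ranking by these scores still gives a well-defined AUC.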

## Model architecture reference

`Model.py` attribute names (for freezing / layer access):
- `conv1`, `bn1`, `relu`, `maxpool`
- `layer1`, `layer2`, `layer3`, `layer4`
- `avgpool`, `dropout` (nn.Dropout), `fc` (nn.Linear(512, 4))

## Testing

No test suite.

@@ -138,14 +138,12 @@
| `baseline/compare_models.py` | Multi-model comparison (ROC / PR curves + accuracy bar chart) |
| `training_log.csv` | Training log recording loss, f1, acc, and lr for each epoch |
| `best_model.pth` | Best trained model weights (about 125 MB, not under version control) |
| `AGENTS.md` | AI assistant guide (development aid) |
| `THIRD_PARTY_LICENSES.md` | Third-party dataset license notices |

## Directory structure

```
trash-division/
├── AGENTS.md  # AI assistant guide
├── baseline/  # Baseline models directory
│ ├── VGG_KNN.py  # VGG16 + KNN classification script
│ ├── ResNet34_Pretrained_10pct.py  # ResNet-34 ImageNet pretraining + 10% fine-tuning

Binary file not shown.
Size before: 42 KiB. Size after: 45 KiB.