AGENTS.md
Project
CNN-based garbage classification into 4 classes: kitchen waste (厨余垃圾), recyclables (可回收物), other waste (其他垃圾), and hazardous waste (有害垃圾). ResNet-34 architecture, ~21M params, 256×256 RGB input, ~900 lines across 11 Python files. No package structure.
Pipeline (order matters)
python Merge_classes.py # merges 265 → 4 classes, creates ../trash_division_data/ultimate_4_class/
python Train.py # trains the model, saves best_model.pth + training_log.csv
python Finetune.py # optional: freezes conv1/bn1/layer2/layer3 (currently crashes, see Finetune-specific), saves finetuned_model.pth + finetune_log.csv
python Evaluate.py # plots confusion matrix / ROC / PR curves from best_model.pth
python Curve.py # plots loss/f1/acc/lr curves from training_log.csv
python baseline/VGG_KNN.py # VGG16 feature extraction + KNN baseline
python baseline/compare_models.py # compares multiple models (ROC + accuracy bar chart)
Model.py also runs standalone: python Model.py prints a torchsummary parameter summary.
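The run order above can be captured in a small driver script. This is a minimal sketch of a hypothetical run_pipeline.py. It is not part of the repo, and it would need a .gitignore whitelist entry to be tracked. It assumes it is run from the repo root. It only prints the commands unless `--run` is passed, and it skips Finetune.py unless `--finetune` is given, since that step is marked optional.

```python
import subprocess
import sys

# Steps from the pipeline list, in the order they must run.
# The boolean marks steps that are optional (only Finetune.py per the notes).
PIPELINE = [
    ("Merge_classes.py", False),
    ("Train.py", False),
    ("Finetune.py", True),
    ("Evaluate.py", False),
    ("Curve.py", False),
    ("baseline/VGG_KNN.py", False),
    ("baseline/compare_models.py", False),
]


def pipeline_commands(include_optional=False):
    """Return the commands to run, in order, as argument lists."""
    return [
        [sys.executable, script]
        for script, optional in PIPELINE
        if include_optional or not optional
    ]


def run_pipeline(include_optional=False, dry_run=True):
    """Print each command; unless dry_run, execute it and stop at the first failure."""
    for cmd in pipeline_commands(include_optional):
        print(" ".join(cmd))
        if not dry_run:
            # check=True raises CalledProcessError, so a later step never
            # runs on top of a failed earlier one. Script paths are relative
            # to the current directory, hence "run from the repo root".
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_pipeline(
        include_optional="--finetune" in sys.argv,
        dry_run="--run" not in sys.argv,
    )
```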
Dependencies
No requirements.txt — install manually: torch, torchvision, tqdm, matplotlib, pandas, Pillow, torchsummary. Evaluate.py and baseline/*.py additionally need scikit-learn.
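Since there is no requirements.txt, you can probe each import name to see what is missing. This is a minimal sketch of a hypothetical helper script. Two pip names differ from their import names: Pillow imports as PIL and scikit-learn as sklearn.

```python
import importlib.util

# pip name -> import name for the dependencies listed above.
CORE = {
    "torch": "torch",
    "torchvision": "torchvision",
    "tqdm": "tqdm",
    "matplotlib": "matplotlib",
    "pandas": "pandas",
    "Pillow": "PIL",
    "torchsummary": "torchsummary",
}
# Needed only by Evaluate.py and baseline/*.py.
EXTRA = {"scikit-learn": "sklearn"}


def missing_packages(packages):
    """Return the pip names whose top-level import name cannot be found."""
    return [
        pip_name
        for pip_name, module in packages.items()
        if importlib.util.find_spec(module) is None
    ]


if __name__ == "__main__":
    missing = missing_packages({**CORE, **EXTRA})
    if missing:
        print("pip install " + " ".join(missing))
    else:
        print("all dependencies found")
```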
Data setup
Data is expected outside the repo at ../trash_division_data/ (a sibling dir). Merge_classes.py reads val/classname.txt there. Train.py and Finetune.py expect ultimate_4_class/{train,val}/ with class-numbered subdirs (1/ to 4/). All paths are relative to the repo root, so run every script from there.
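Below is a minimal pre-flight check for the layout that Train.py and Finetune.py expect, sketched as a hypothetical helper. DATA_ROOT mirrors the relative path above and assumes the script runs from the repo root.

```python
from pathlib import Path

# ../trash_division_data/ultimate_4_class/{train,val}/{1,2,3,4}/
DATA_ROOT = Path("../trash_division_data/ultimate_4_class")
SPLITS = ("train", "val")
CLASS_DIRS = ("1", "2", "3", "4")


def missing_dirs(root=DATA_ROOT):
    """List the expected split/class directories that do not exist under root."""
    root = Path(root)
    return [
        str(root / split / cls)
        for split in SPLITS
        for cls in CLASS_DIRS
        if not (root / split / cls).is_dir()
    ]


if __name__ == "__main__":
    problems = missing_dirs()
    print("data layout OK" if not problems else "missing:\n" + "\n".join(problems))
```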
.gitignore — whitelist pattern
.gitignore uses * (ignore everything) then un-ignores specific files with ! patterns. Any new file you add to the repo must be explicitly whitelisted or it will be invisible to git. The current whitelist includes: Dataloader.py, LICENSE, Merge_classes.py, Model.py, README.md, THIRD_PARTY_LICENSES.md, Train.py, .gitattributes, .gitignore, plus Finetune.py, Curve.py, Evaluate.py, AGENTS.md, 4× output PNG, training_log.csv, and baseline/.
best_model.pth and finetuned_model.pth (~125 MB each) are not whitelisted, so they are untracked. Back them up manually if needed. finetune_log.csv, and any training_log*.csv other than training_log.csv, is not whitelisted either. To confirm whether a specific file is tracked, run `git check-ignore -v <file>`.
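The whitelist semantics can be illustrated with a simplified matcher. EXAMPLE_GITIGNORE is a made-up excerpt, not the repo's real file. The model handles top-level file names only. It does not cover git's directory rules, such as the rule that a file cannot be re-included if a parent directory is excluded.

```python
from fnmatch import fnmatchcase

# Made-up excerpt in the same style: ignore everything, then re-include names.
EXAMPLE_GITIGNORE = """\
*
!.gitignore
!Train.py
!Model.py
"""


def is_ignored(name, gitignore_text):
    """Simplified gitignore check for a top-level file name.

    As in git, the LAST matching pattern decides; a leading "!" negates it.
    """
    ignored = False
    for line in gitignore_text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        negate = line.startswith("!")
        pattern = line[1:] if negate else line
        if fnmatchcase(name, pattern):
            ignored = not negate
    return ignored


if __name__ == "__main__":
    for f in ("Train.py", "NewScript.py", "best_model.pth"):
        print(f, "ignored" if is_ignored(f, EXAMPLE_GITIGNORE) else "tracked")
```

A new file such as NewScript.py stays invisible to git until a matching `!` line is added after the `*` line.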
Gotchas
- Windows: set `num_workers=0` in `create_dataloaders()` call sites (Train.py:191, Finetune.py:196, Dataloader.py:229)
- Device selection priority: `cuda > xpu > cpu` (`xpu` = Intel GPU)
- Training auto-resumes from `best_model.pth` if present in repo root; fine-tuning auto-loads it too
- `Dataloader.py` uses `RobustImageFolder`: scans all images, skips corrupted ones (tqdm progress), slow on first load
- Image normalization: hardcoded ImageNet stats (`mean=[0.485, 0.456, 0.406]`, `std=[0.229, 0.224, 0.225]`)
- `create_dataloaders()` has a `val_split` parameter that's never used; the code always expects a pre-split `val/` folder
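The device priority, the Windows worker setting, and the normalization stats above can be sketched as plain functions. pick_device, default_num_workers, and normalize_pixel are hypothetical names, not the repo's. The value 4 for non-Windows workers is an arbitrary placeholder.

normalize_pixel mirrors what ToTensor followed by Normalize does to one pixel: scale to [0, 1], subtract the channel mean, and divide by the channel std. `torch.xpu` only exists in recent PyTorch builds, so the demo guards it with `hasattr`.

```python
import sys

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def pick_device(cuda_available, xpu_available):
    """Priority from the notes: cuda > xpu > cpu."""
    if cuda_available:
        return "cuda"
    if xpu_available:
        return "xpu"
    return "cpu"


def default_num_workers(platform=sys.platform, other_workers=4):
    """0 on Windows (per the gotcha); other_workers is a placeholder elsewhere."""
    return 0 if platform == "win32" else other_workers


def normalize_pixel(rgb_uint8):
    """Scale an 8-bit RGB pixel to [0, 1], then apply (x - mean) / std per channel."""
    return tuple(
        (value / 255.0 - mean) / std
        for value, mean, std in zip(rgb_uint8, IMAGENET_MEAN, IMAGENET_STD)
    )


if __name__ == "__main__":
    try:
        import torch

        device = pick_device(
            torch.cuda.is_available(),
            hasattr(torch, "xpu") and torch.xpu.is_available(),
        )
    except ImportError:
        device = "cpu"
    print(device, default_num_workers(), normalize_pixel((255, 255, 255)))
```

A pure white pixel maps to roughly (2.25, 2.43, 2.64), which is why normalized inputs are not confined to [0, 1].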
Finetune-specific
- BUG: `freeze_base_layers()` references `model.stage2` and `model.stage3`, but the model uses `layer2`/`layer3`, so it crashes at runtime. Fix it to `model.layer2`/`model.layer3`, or delete the function, since it would freeze `layer2` + `layer3` while its docstring says only conv1 + stage2.
- Once the attribute names are fixed, `freeze_base_layers()` freezes `conv1`, `bn1`, `layer2`, AND `layer3` (despite the docstring saying only conv1 + stage2). Only `layer1`, `layer4`, and `fc` remain trainable.
- Class weights use `power=1.5` (vs `power=1.0` in Train), which amplifies minority-class weighting
- Defaults: `lr=0.0001`, `epochs=30` (vs `lr=0.001`, `epochs=20` in Train)
- Writes `finetune_log.csv` (Train writes `training_log.csv`)
- Loads `best_model.pth`, then saves `finetuned_model.pth`
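The following sketch covers the attribute-name fix and shows how the power exponent changes class weights. freeze_named_layers is a hypothetical stand-in for `freeze_base_layers()` that keeps its post-fix behavior. It relies only on `getattr()` and `.parameters()`, so it works on an `nn.Module`.

The weight formula `(total / count) ** power` is an ASSUMPTION for illustration, since these notes only record the exponents. The class counts in the demo are made up.

```python
# Post-fix behavior: real attribute names from Model.py (layer2/layer3, not stage2/stage3).
FROZEN_ATTRS = ("conv1", "bn1", "layer2", "layer3")


def freeze_named_layers(model, names=FROZEN_ATTRS):
    """Set requires_grad=False on every parameter of the named submodules."""
    for name in names:
        module = getattr(model, name)  # AttributeError on a wrong name such as "stage2"
        for param in module.parameters():
            param.requires_grad = False
    return model


def class_weights(counts, power):
    """ASSUMED formula: w_c = (total / count_c) ** power, rescaled to mean 1.

    Not necessarily what Train.py or Finetune.py compute.
    """
    total = sum(counts)
    raw = [(total / c) ** power for c in counts]
    mean = sum(raw) / len(raw)
    return [w / mean for w in raw]


if __name__ == "__main__":
    counts = [4000, 2000, 1000, 500]  # made-up class sizes
    for power in (1.0, 1.5):
        w = class_weights(counts, power)
        print(power, [round(x, 3) for x in w], "max/min =", round(w[-1] / w[0], 2))
```

Under the assumed formula, the largest-to-smallest weight ratio is the count ratio raised to `power`. For the made-up counts (ratio 8), that gives 8 at `power=1.0` and about 22.6 at `power=1.5`. When building the optimizer, pass only trainable parameters, e.g. `[p for p in model.parameters() if p.requires_grad]`.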
Curve.py
- Hardcoded to read `training_log.csv` only; won't work for `finetune_log.csv`
- Requires `pandas`; saves `training_curves.png`
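One way to lift the hardcoded path is a small argument parser. `--log` and `--out` are hypothetical flags that Curve.py does not have today. The sketch keeps training_curves.png as the output name for training_log.csv, so existing behavior is unchanged. It also assumes finetune_log.csv has the same columns as training_log.csv, which these notes do not confirm.

```python
import argparse
from pathlib import Path


def parse_args(argv=None):
    """Hypothetical CLI for Curve.py; argv=None falls back to sys.argv."""
    parser = argparse.ArgumentParser(
        description="Plot loss/f1/acc/lr curves from a training log"
    )
    parser.add_argument("--log", default="training_log.csv",
                        help="CSV written by Train.py or Finetune.py")
    parser.add_argument("--out", default=None,
                        help="output PNG (default derived from --log)")
    args = parser.parse_args(argv)
    if args.out is None:
        stem = Path(args.log).stem
        # Keep the current file name for the default log; derive one otherwise.
        args.out = "training_curves.png" if stem == "training_log" else f"{stem}_curves.png"
    return args

# The plotting code would then read args.log with pandas.read_csv
# and save the figure to args.out instead of the hardcoded names.
```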
Evaluate.py
- Hardcoded constants at the top of the `__main__` block: `MODEL_PATH`, `DATA_ROOT`, `BATCH_SIZE`, `NUM_WORKERS`
- Loads the model from `best_model.pth` by default; handles both a bare state_dict and `model_state_dict`/`model` key wrappers
- Saves `confusion_matrix.png`, `roc_curve.png`, `pr_curve.png`
- Requires `scikit-learn`
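The checkpoint handling described above amounts to unwrapping a dict. This is a minimal sketch with a hypothetical extract_state_dict helper. The key order (`model_state_dict` before `model`) is an assumption and may not match Evaluate.py.

```python
def extract_state_dict(checkpoint):
    """Return the weights from a bare state_dict, or from a dict that wraps
    them under "model_state_dict" or "model"."""
    if isinstance(checkpoint, dict):
        for key in ("model_state_dict", "model"):
            # A real state_dict maps "layer.weight"-style names to tensors,
            # so a dict-valued "model_state_dict"/"model" entry means a wrapper.
            if key in checkpoint and isinstance(checkpoint[key], dict):
                return checkpoint[key]
    return checkpoint

# Usage with PyTorch (not executed here):
#   ckpt = torch.load("best_model.pth", map_location="cpu")
#   model.load_state_dict(extract_state_dict(ckpt))
```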
baseline/ (VGG_KNN.py + compare_models.py)
- `baseline/VGG_KNN.py` can run standalone (`python baseline/VGG_KNN.py`) or be imported from `compare_models.py`
- Uses `sys.path.insert` at the top so it can import root-level modules (`Model`, `Dataloader`) from a subdirectory
- `compare_models.py` has a `MODELS` registry list: add new models by writing a `get_xxx_preds(train_loader, val_loader, device)` function and adding one line to the list; no plot code changes needed
- VGG16 feature dimension: 25088 (512 channels × 7×7 avgpool)
- KNN uses `predict_proba` (neighbor voting proportions) for ROC curves: coarse-grained but yields a valid AUC
- Output: `baseline/roc_comparison.png`, `baseline/accuracy_bar.png`, `baseline/vgg_knn_confusion_matrix.png`
- Compare scripts write images to the `baseline/` dir (not repo root)
Model architecture reference
Model.py attribute names (for freezing / layer access):
- `conv1`, `bn1`, `relu`, `maxpool`
- `layer1`, `layer2`, `layer3`, `layer4`
- `avgpool`, `dropout` (`nn.Dropout`), `fc` (`nn.Linear(512, 4)`)
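The attribute list can be traced to spatial sizes for the 256×256 input. This worked sketch ASSUMES standard ResNet-34 hyperparameters: a 7×7 stride-2 `conv1`, a 3×3 stride-2 `maxpool`, and stride 2 in the first block of `layer2` to `layer4`. Check Model.py before relying on it.

```python
def conv_out(size, kernel, stride, padding):
    """Output size of a conv/pool layer: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


# ASSUMED (kernel, stride, padding) of the size-changing op in each stage;
# the remaining 3x3 stride-1 padding-1 convs keep the size unchanged.
STAGES = [
    ("conv1", 7, 2, 3),
    ("maxpool", 3, 2, 1),
    ("layer1", 3, 1, 1),
    ("layer2", 3, 2, 1),
    ("layer3", 3, 2, 1),
    ("layer4", 3, 2, 1),
]


def trace_shapes(input_size=256):
    """Return the square feature-map side length after each stage."""
    sizes = {}
    n = input_size
    for name, k, s, p in STAGES:
        n = conv_out(n, k, s, p)
        sizes[name] = n
    return sizes


if __name__ == "__main__":
    for name, n in trace_shapes().items():
        print(f"{name:8s} -> {n}x{n}")
```

Under these assumptions, `layer4` outputs 512 × 8 × 8. `avgpool` must collapse that to 512 values to match `fc = nn.Linear(512, 4)`.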
Testing
No test suite.