diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c68b013 --- /dev/null +++ b/.gitignore @@ -0,0 +1,19 @@ +* +!/.gitattributes +!/Dataloader.py +!/LICENSE +!/Merge_classes.py +!/Model.py +!/README.md +!/THIRD_PARTY_LICENSES.md +!/Train.py +!/AGENTS.md +!/Finetune.py +!/Curve.py +!/Evaluate.py +!/training_log.csv +!/confusion_matrix.png +!/roc_curve.png +!/pr_curve.png +!/training_curves.png +!.gitignore diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..d7c3df2 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,72 @@ +# AGENTS.md + +## Project + +CNN-based garbage classification (4 classes: 厨余垃圾/可回收物/其他垃圾/有害垃圾). ResNet-34 architecture, ~21M params, 256×256 RGB input, ~900 lines across 8 Python files. No package structure. + +## Pipeline (order matters) + +```bash +python Merge_classes.py # merges 265 → 4 classes, creates ../trash_division_data/ultimate_4_class/ +python Train.py # trains the model, saves best_model.pth + training_log.csv +python Finetune.py # optional: freezes early layers, saves finetuned_model.pth + finetune_log.csv +python Evaluate.py # plots confusion matrix / ROC / PR curves from best_model.pth +python Curve.py # plots loss/f1/acc/lr curves from training_log.csv +``` + +Also usable standalone: `python Model.py` prints `torchsummary` parameter summary. + +## Dependencies + +No `requirements.txt` — install manually: `torch`, `torchvision`, `tqdm`, `matplotlib`, `pandas`, `Pillow`, `torchsummary`. `Evaluate.py` additionally needs `scikit-learn`. + +## Data setup + +Data expected **outside repo** at `../trash_division_data/` (sibling dir). `Merge_classes.py` reads `val/classname.txt` there; `Train.py` and `Finetune.py` expect `ultimate_4_class/{train,val}/` with class-numbered subdirs (`1/` to `4/`). All paths relative to repo root. + +## .gitignore — whitelist pattern + +`.gitignore` uses `*` (ignore everything) then un-ignores specific files with `!` patterns. **Any new file you add to the repo must be explicitly whitelisted** or it will be invisible to git. The current whitelist: `Dataloader.py`, `LICENSE`, `Merge_classes.py`, `Model.py`, `README.md`, `THIRD_PARTY_LICENSES.md`, `Train.py`, `.gitattributes`, `.gitignore`. + +`best_model.pth` and `finetuned_model.pth` are **untracked** (~125 MB each) — back them up manually if needed. `Finetune.py`, `Curve.py`, `Evaluate.py`, `AGENTS.md`, `training_log*.csv`, and `finetune_log.csv` are also untracked (not in whitelist). + +## Gotchas + +- **Windows: set `num_workers=0`** in `create_dataloaders()` call sites (`Train.py:191`, `Finetune.py:196`, `Dataloader.py:229`) +- Device selection priority: `cuda > xpu > cpu` (`xpu` = Intel GPU) +- Training auto-resumes from `best_model.pth` if present in repo root; fine-tuning auto-loads it too +- `Dataloader.py` uses `RobustImageFolder` — scans all images, skips corrupted ones (tqdm progress), slow on first load +- Image normalization: hardcoded ImageNet stats (`mean=[0.485, 0.456, 0.406]`, `std=[0.229, 0.224, 0.225]`) +- `create_dataloaders()` has a `val_split` parameter that's **never used** — the code always expects a pre-split `val/` folder + +### Finetune-specific + +- **BUG**: `freeze_base_layers()` references `model.stage2` and `model.stage3` but the model uses `layer2`/`layer3`. This crashes at runtime — fix to `model.layer2`/`model.layer3` (or delete the function, since it would freeze `layer2`+`layer3` while docstring says only conv1+stage2). +- `freeze_base_layers()` actually freezes **conv1, bn1, layer2, AND layer3** (despite docstring saying only conv1 + stage2). Only layer1, layer4, and fc are trainable. +- Class weights use `power=1.5` (vs `power=1.0` in Train) — amplifies minority-class weighting +- Defaults: `lr=0.0001`, `epochs=30` (vs `lr=0.001`, `epochs=20` in Train) +- Writes `finetune_log.csv` (Train writes `training_log.csv`) +- Loads `best_model.pth` then saves `finetuned_model.pth` + +### Curve.py + +- Hardcoded to read `training_log.csv` only — won't work for `finetune_log.csv` +- Requires `pandas`, saves `training_curves.png` + +### Evaluate.py + +- Hardcoded constants at top of `__main__` block: `MODEL_PATH`, `DATA_ROOT`, `BATCH_SIZE`, `NUM_WORKERS` +- Loads model from `best_model.pth` by default; handles both bare state_dict and `model_state_dict`/`model` key wrappers +- Saves `confusion_matrix.png`, `roc_curve.png`, `pr_curve.png` +- Requires `scikit-learn` + +## Model architecture reference + +`Model.py` attribute names (for freezing / layer access): +- `conv1`, `bn1`, `relu`, `maxpool` +- `layer1`, `layer2`, `layer3`, `layer4` +- `avgpool`, `dropout` (nn.Dropout), `fc` (nn.Linear(512, 4)) + +## Testing + +No test suite. diff --git a/README.md b/README.md index 7e89626..a34e646 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,7 @@ - [文件说明](#文件说明) - [目录结构](#目录结构) - [训练细节](#训练细节) +- [评估与可视化](#评估与可视化) - [许可证](#许可证) --- @@ -84,6 +85,7 @@ - pandas - Pillow - torchsummary +- scikit-learn(仅 `Evaluate.py` 需要) ## 快速开始 @@ -99,6 +101,19 @@ python Train.py ``` +3. **微调模型**(可选,冻结浅层、微调深层): + + ```bash + python Finetune.py + ``` + +4. **评估与可视化**: + + ```bash + python Evaluate.py # 混淆矩阵、ROC 曲线、PR 曲线 + python Curve.py # 训练过程的 loss/f1/acc/lr 曲线 + ``` + > **注意**: > - 数据目录默认为 `../trash_division_data/ultimate_4_class/`,需先运行合并脚本 > - Windows 系统需将 `num_workers` 设为 `0`(参见 `Dataloader.py` 和 `Train.py`) @@ -109,10 +124,14 @@ | 文件 | 功能 | |---|---| | `Train.py` | 训练主脚本,包含训练循环、验证、评估 | +| `Finetune.py` | 微调脚本,冻结浅层后微调深层网络 | | `Dataloader.py` | 数据加载模块,包含 RobustImageFolder 和 DataLoader 创建 | | `Model.py` | 模型定义,ResNet-34(BasicBlock)+ Dropout | | `Merge_classes.py` | 数据集预处理,265 类合并为 4 类 | -| `best_model.pth` | 训练好的最佳模型权重(约 125 MB) | +| `Evaluate.py` | 模型评估,绘制混淆矩阵、ROC 曲线、PR 曲线 | +| `Curve.py` | 训练曲线绘制,从 CSV 读取并绘制 loss/f1/acc/lr 曲线 | +| `training_log.csv` | 训练日志,记录每轮 epoch 的 loss、f1、acc、lr | +| `best_model.pth` | 训练好的最佳模型权重(约 125 MB,不纳入版本控制) | | `AGENTS.md` | AI 助手指南(开发辅助) | | `THIRD_PARTY_LICENSES.md` | 第三方数据集许可证声明 | @@ -121,15 +140,23 @@ ``` trash-division/ ├── AGENTS.md # AI 助手指南 -├── best_model.pth # 最佳模型权重 +├── best_model.pth # 最佳模型权重(不纳入版本控制) +├── Curve.py # 训练曲线绘制脚本 ├── Dataloader.py # 数据加载模块 +├── Evaluate.py # 模型评估可视化脚本 +├── Finetune.py # 微调脚本 ├── .gitattributes # Git 属性配置 ├── LICENSE # MIT 许可证 ├── Merge_classes.py # 数据集预处理脚本 ├── Model.py # 模型定义 ├── README.md # 项目说明(本文件) ├── THIRD_PARTY_LICENSES.md # 第三方许可证声明 -└── Train.py # 训练主脚本 +├── Train.py # 训练主脚本 +├── training_log.csv # 训练日志 +├── confusion_matrix.png # 混淆矩阵(Evaluate.py 输出) +├── roc_curve.png # ROC 曲线(Evaluate.py 输出) +├── pr_curve.png # PR 曲线(Evaluate.py 输出) +└── training_curves.png # 训练曲线(Curve.py 输出) ``` ## 训练细节 @@ -149,6 +176,43 @@ trash-division/ 训练时数据增强管线:RandomResizedCrop(256, scale=(0.8, 1.0)) + RandomHorizontalFlip(p=0.5) + RandomRotation(+-15 deg) + ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2) +## 评估与可视化 + +训练完成后,`training_log.csv` 会记录每个 epoch 的训练/验证指标。以下两个脚本用于可视化分析: + +### Evaluate.py — 模型评估 + +在验证集上推理,生成三张评估图表: + +| 图表 | 说明 | +|---|---| +| `confusion_matrix.png` | 混淆矩阵,展示各类别的预测分布 | +| `roc_curve.png` | ROC 曲线(One-vs-Rest),含各类别 AUC 和 Macro-avg AUC | +| `pr_curve.png` | Precision-Recall 曲线,含各类别 Average Precision | + +```bash +python Evaluate.py +``` + +脚本顶部的 `MODEL_PATH`、`DATA_ROOT`、`BATCH_SIZE`、`NUM_WORKERS` 可按需修改。 + +### Curve.py — 训练曲线 + +从 `training_log.csv` 读取训练日志,绘制四张子图: + +| 子图 | 说明 | +|---|---| +| Loss vs Epoch | 训练/验证损失曲线 | +| F1 Score vs Epoch | 训练/验证宏平均 F1 曲线 | +| Accuracy vs Epoch | 训练/验证准确率曲线 | +| Learning Rate vs Epoch | 余弦退火学习率变化曲线 | + +```bash +python Curve.py +``` + +输出文件为 `training_curves.png`。 + ## 许可证 本项目主代码采用 [MIT 许可证](LICENSE)。 diff --git a/confusion_matrix.png b/confusion_matrix.png new file mode 100644 index 0000000..765f7c9 Binary files /dev/null and b/confusion_matrix.png differ diff --git a/pr_curve.png b/pr_curve.png new file mode 100644 index 0000000..f24c81d Binary files /dev/null and b/pr_curve.png differ diff --git a/roc_curve.png b/roc_curve.png new file mode 100644 index 0000000..05cdfbd Binary files /dev/null and b/roc_curve.png differ diff --git a/training_curves.png b/training_curves.png new file mode 100644 index 0000000..482cf5e Binary files /dev/null and b/training_curves.png differ diff --git a/training_log.csv b/training_log.csv new file mode 100644 index 0000000..9795ca2 --- /dev/null +++ b/training_log.csv @@ -0,0 +1,81 @@ +epoch,train_loss,train_f1,train_acc,val_loss,val_f1,val_acc,lr,best +1,1.0409312975676923,0.4329540729522705,48.04254100337675,1.1043149583345566,0.4398210048675537,48.66548042704626,0.004998072590601808,best +2,0.9862563783744079,0.4695238769054413,52.5943680656054,0.9867177669753319,0.5062971115112305,58.397207774431976,0.00499229333433282,best +3,0.9462850892451784,0.49421826004981995,55.40279787747226,1.0144445673589866,0.4907984733581543,53.15494114426499,0.004982671142387316, +4,0.910117163585685,0.514958381652832,57.832097202122526,0.8787865286946395,0.5453917980194092,62.544483985765126,0.004969220851487844,best +5,0.8786031692946986,0.5320333242416382,59.74282440906898,1.0686318878927787,0.4803737998008728,52.73063235696688,0.004951963201008076, +6,0.8518873820889128,0.5481140613555908,61.51938615533044,0.7650798693964196,0.6073676347732544,68.98439638653161,0.004930924800994191,best +7,0.8256270786701512,0.5604796409606934,62.90249638205499,0.8796401012116773,0.5789190530776978,62.63345195729537,0.004906138091134118, +8,0.8003699506646013,0.5742803812026978,64.3014351181862,0.9246643470014378,0.5521833896636963,60.88831097727895,0.004877641290737884, +9,0.780536473588097,0.5827116966247559,65.25642185238785,0.8404132533719564,0.5876226425170898,65.89789214344374,0.00484547833980621, +10,0.7604798049209087,0.595557451248169,66.54079232995659,0.9228097118810533,0.564703643321991,60.77881193539557,0.004809698831278217, +11,0.7410275131047088,0.6043155789375305,67.35784491075735,0.7576604621266131,0.6295210123062134,69.83985765124555,0.0047703579345627035,best +12,0.7195374228732343,0.6127941608428955,68.07766521948867,0.9624476881507583,0.5630610585212708,61.59321105940323,0.00472751631047092, +13,0.6997139808122973,0.6210756301879883,68.85175470332851,0.7615812296349155,0.6177672147750854,69.12127018888584,0.004681240017681994, +14,0.6824904908837182,0.630592942237854,69.65448625180898,0.6715762626299165,0.6534035205841064,73.48754448398577,0.004631600410885231,best +15,0.6653590450468583,0.6379610300064087,70.39088880849012,0.694461440047988,0.6517682075500488,73.0906104571585,0.004578674030756364, +16,0.6514209577758935,0.6478185653686523,71.27879281234925,0.7036816360785346,0.6470745801925659,71.76977826444019,0.004522542485937369, +17,0.6330186040776395,0.6530008316040039,71.7935962373372,0.7222367905930823,0.6418735980987549,71.07856556255133,0.004463292327201863, +18,0.6166394593717968,0.6634106040000916,72.6038651712494,0.6067476719332303,0.6886636018753052,77.49110320284697,0.004401014914000078,best +19,0.5973944908975692,0.6721534729003906,73.47367945007235,0.6952472055509845,0.6622275114059448,71.79715302491103,0.004335806273589214, +20,0.5820678306183721,0.6758297681808472,73.73145803183792,0.7708474785342401,0.6217234134674072,68.74486723241172,0.004267766952966369, +21,0.5650806297110982,0.6851130723953247,74.5741377231066,0.7461620579141478,0.6384793519973755,71.35231316725978,0.004197001863832355, +22,0.5500074683958588,0.6915749311447144,75.099493487699,0.6420613189380593,0.672810435295105,74.9452504790583,0.00412362012082546, +23,0.5367840825001858,0.6979560852050781,75.66102870236372,0.6252713002082977,0.6949211359024048,75.32849712565014,0.0040477348732745845,best +24,0.5234906795055925,0.7052106857299805,76.26025084418717,0.7471277477021352,0.6447888016700745,69.70982753900904,0.003969463130731182, +25,0.5044557179049829,0.7132176160812378,76.91148094548963,0.6325626891507145,0.6857903003692627,75.52012044894607,0.0038889255825490052, +26,0.4938347232885195,0.7174828052520752,77.23784973468403,0.5635758755127437,0.70375657081604,78.50396934026827,0.003806246411789872,best +27,0.4793313278116239,0.7242900133132935,77.87475880366618,0.5505193975847648,0.7201660871505737,78.89405967697783,0.003721553103742388,best +28,0.46573570758837524,0.7336312532424927,78.60437771345876,0.640248272807638,0.6859503984451294,75.13003011223651,0.003634976249348867, +29,0.44927708754289913,0.737967312335968,78.9352689339122,0.6151526539644867,0.7065733075141907,74.69887763482069,0.00354664934384357, +30,0.4373503708129221,0.7443608045578003,79.40409430776653,0.5578661908719627,0.7272701263427734,78.20969066520668,0.0034567085809127244,best +31,0.42717206794400175,0.7488712072372437,79.7779486251809,0.58909761693554,0.7034546136856079,76.88201478237066,0.003365292642693732, +32,0.4100124511779706,0.7580570578575134,80.60630728412929,0.6458172935624336,0.6865078210830688,75.32849712565014,0.0032725424859373683, +33,0.3993677339991451,0.763430118560791,80.9876989869754,0.47995706558097007,0.754202127456665,81.65891048453327,0.003178601124662685,best +34,0.3858378949555808,0.7697042226791382,81.54772672455378,0.6427663844838523,0.6931804418563843,74.28141253764029,0.0030836134096397633, +35,0.37397055771404913,0.7764154672622681,81.9992161119151,0.6085299244000101,0.7046636343002319,77.08048179578428,0.0029877258050403205, +36,0.3597575335889638,0.7818952798843384,82.53587795465509,0.5254679805051781,0.7415529489517212,78.89405967697783,0.002891086162600577, +37,0.3487578573732404,0.7871347665786743,82.8999336710082,0.5125140355052995,0.748577356338501,80.1875171092253,0.002793843493644594, +38,0.3325814358527052,0.7965956330299377,83.59035817655571,0.5408413317834649,0.7290798425674438,79.87270736381056,0.002696147739319612, +39,0.3261546248608721,0.7988470792770386,83.75316570188133,0.5301555857539602,0.7376729249954224,80.06433068710649,0.002598149539397671, +40,0.30964642827472305,0.8070269823074341,84.52725518572117,0.5468305750249544,0.7365171313285828,78.73665480427046,0.0024999999999999996, +41,0.3009217412674191,0.8119726777076721,84.96140858658949,0.46898612490165015,0.7599539756774902,82.18587462359704,0.002401850460602329,best +42,0.28925693789887874,0.8200639486312866,85.6458031837916,0.5167866427621677,0.7465909123420715,80.83766767040788,0.0023038522606803878, +43,0.2707157838268379,0.8313596248626709,86.46360950313556,0.5156203284349763,0.7596548199653625,80.6049822064057,0.0022061565063554063, +44,0.2580273799019566,0.836384654045105,86.8412325132658,0.5318487190707494,0.746901273727417,80.27648508075555,0.0021089138373994237, +45,0.2504911308580703,0.8410984873771667,87.30779667149059,0.49164087495639364,0.763725221157074,81.82315904735833,0.00201227419495968,best +46,0.24104372076995695,0.8451772928237915,87.64094910757356,0.5290114263981752,0.7580969333648682,80.94032302217356,0.0019163865903602372, +47,0.22337641519549614,0.8570870161056519,88.56201760733236,0.43634469677838694,0.7913081049919128,85.10128661374213,0.0018213988753373142,best +48,0.2128122905210861,0.8645581603050232,89.1152616980222,0.43183545479456625,0.7972898483276367,85.49137695045168,0.001727457514062632,best +49,0.2003101470318182,0.8717849254608154,89.69187168355042,0.4289672785945178,0.806715726852417,85.7993430057487,0.0016347073573062686,best +50,0.1888495338707803,0.8796613216400146,90.34687047756874,0.4568697272132202,0.7956517338752747,84.64275937585546,0.0015432914190872762, +51,0.1756466486088274,0.886497437953949,90.97096599131693,0.4541556305781541,0.7938134670257568,84.45113605255953,0.001453350656156431, +52,0.16742963044469736,0.8907681703567505,91.3101483357453,0.4230425570913775,0.8147625923156738,86.100465370928,0.0013650237506511336,best +53,0.15311133117022804,0.9007841944694519,92.11212614568258,0.4146759752586969,0.8208259344100952,86.83274021352314,0.0012784468962576128,best +54,0.1423091164071722,0.9078108668327332,92.6714001447178,0.4709351422719488,0.8029968738555908,85.71721872433616,0.0011937535882101285, +55,0.13160189816137902,0.9135022163391113,93.10932223830197,0.40829240685941226,0.8264325857162476,87.42814125376403,0.0011110744174509947,best +56,0.12707359487800943,0.9174070358276367,93.4409671972986,0.42565728100299705,0.8230471611022949,87.65398302764851,0.0010305368692688178, +57,0.11291898237482914,0.925153374671936,94.10953328509407,0.43774206922898184,0.8247347474098206,87.5376402956474,0.0009522651267254161, +58,0.10370979833767516,0.9329074025154114,94.62584418716835,0.4068256767521223,0.8333848118782043,88.05091705447578,0.0008763798791745416,best +59,0.0946220946482491,0.9380815029144287,95.09768451519537,0.41103389083751746,0.8357677459716797,88.4889132220093,0.0008029981361676465,best +60,0.08804238645213414,0.9423004388809204,95.45194163048721,0.4207381762007377,0.8308929204940796,88.17410347659458,0.0007322330470336316, +61,0.07913849578165794,0.9495129585266113,95.9916184273999,0.4083157278823748,0.8420299291610718,89.00218998083767,0.0006641937264107861,best +62,0.06981624146470565,0.9559226036071777,96.49662325132658,0.41332166066585485,0.843511700630188,88.98850260060225,0.0005989850859999229,best +63,0.06394793773276639,0.9602090120315552,96.79811866859623,0.41052334102789995,0.8489691019058228,89.35121817684096,0.0005367076727981376,best +64,0.057007493751794744,0.9636315107345581,97.11016642547034,0.3970402057488879,0.8546013832092285,90.04927456884752,0.00047745751406263185,best +65,0.05285091761448427,0.967146635055542,97.40035576459238,0.4247235585867853,0.8433754444122314,89.34437448672324,0.0004213259692436376, +66,0.04614944799553407,0.9710246324539185,97.6912988422576,0.4035414461747053,0.8538572788238525,89.85080755543389,0.00036839958911476966, +67,0.042909558690492254,0.9727644920349121,97.8428002894356,0.41578795896298,0.8530543446540833,89.91924445661101,0.0003187599823180077, +68,0.03640206224887977,0.9769999980926514,98.15710926193921,0.42073161891477256,0.8551151156425476,89.94661921708185,0.0002724836895290806,best +69,0.034432517173010414,0.9790080785751343,98.34328268210324,0.4133664407223553,0.8576940298080444,90.28880372296743,0.00022964206543729668,best +70,0.030669637766023813,0.9817556142807007,98.5249336710082,0.418308269951463,0.8569411039352417,90.12455516014235,0.00019030116872178321, +71,0.028112305183924133,0.9827903509140015,98.606337433671,0.4151474575991667,0.8596312999725342,90.37092800437996,0.00015452166019378966,best +72,0.024704152367817256,0.9853801727294922,98.81135431741437,0.4153705465811558,0.8635820746421814,90.68573774979468,0.0001223587092621162,best +73,0.024846541488804174,0.9855506420135498,98.8369814278823,0.4177400436290088,0.8632140159606934,90.59676977826444,9.38619088658821e-05, +74,0.022639600746625622,0.9868491888046265,98.94702725518572,0.41732572613841307,0.8648342490196228,90.78154941144265,6.907519900580863e-05,best +75,0.02120214593173326,0.9878177642822266,99.0231548480463,0.4163925825270714,0.866214394569397,90.89789214344374,4.803679899192394e-05,best +76,0.019741657997631577,0.9883521795272827,99.06385672937772,0.42005763620917286,0.8647006750106812,90.82945524226663,3.077914851215586e-05, +77,0.019116416042495393,0.9889511466026306,99.10003617945007,0.4159400745789841,0.8657370805740356,90.84998631261976,1.7328857612684272e-05, +78,0.019259902796210714,0.9888157844543457,99.0962674867342,0.4192042892654481,0.8641382455825806,90.69942513003011,7.706665667180091e-06, +79,0.01933925595445387,0.9887675046920776,99.0759165460685,0.4180937778044573,0.8662786483764648,90.84998631261976,1.9274093981927482e-06,best +80,0.01922732148408437,0.9889604449272156,99.10078991799324,0.41794140280912484,0.864332914352417,90.82261155214891,0.0,