Compare commits
5 commits
1557a2c6a0
...
7c82962dc3
| Author | SHA1 | Date | |
|---|---|---|---|
| 7c82962dc3 | |||
|
|
37e24a0b09 | ||
|
|
be45ad32eb | ||
|
|
0e9bd58dce | ||
| 7ebe975186 |
6 changed files with 211 additions and 1 deletions
5
.gitignore
vendored
5
.gitignore
vendored
|
|
@ -8,6 +8,9 @@
|
||||||
!/requirements.txt
|
!/requirements.txt
|
||||||
!/THIRD_PARTY_LICENSES.md
|
!/THIRD_PARTY_LICENSES.md
|
||||||
!/Train.py
|
!/Train.py
|
||||||
|
!/web/
|
||||||
|
!/web/app.py
|
||||||
|
!/web/README.md
|
||||||
!/Baseline.py
|
!/Baseline.py
|
||||||
!/Finetune.py
|
!/Finetune.py
|
||||||
!/Curve.py
|
!/Curve.py
|
||||||
|
|
@ -21,6 +24,8 @@
|
||||||
!/baseline/roc_comparison.png
|
!/baseline/roc_comparison.png
|
||||||
!/baseline/pr_comparison.png
|
!/baseline/pr_comparison.png
|
||||||
!/baseline/accuracy_bar.png
|
!/baseline/accuracy_bar.png
|
||||||
|
!/web/
|
||||||
|
!/web/online.jpg
|
||||||
!/training_log.csv
|
!/training_log.csv
|
||||||
!/confusion_matrix.png
|
!/confusion_matrix.png
|
||||||
!/roc_curve.png
|
!/roc_curve.png
|
||||||
|
|
|
||||||
39
README.md
39
README.md
|
|
@ -6,11 +6,14 @@
|
||||||
|
|
||||||
基于 ResNet-34 架构的 CNN 模型(约 21M 参数),将生活垃圾分为厨余垃圾、可回收物、其他垃圾、有害垃圾四个类别,输入为 256×256 RGB 图像。
|
基于 ResNet-34 架构的 CNN 模型(约 21M 参数),将生活垃圾分为厨余垃圾、可回收物、其他垃圾、有害垃圾四个类别,输入为 256×256 RGB 图像。
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## 目录
|
## 目录
|
||||||
|
|
||||||
- [项目特点](#项目特点)
|
- [项目特点](#项目特点)
|
||||||
|
- [Web 演示](#web-演示)
|
||||||
- [模型架构](#模型架构)
|
- [模型架构](#模型架构)
|
||||||
- [数据集](#数据集)
|
- [数据集](#数据集)
|
||||||
- [环境要求](#环境要求)
|
- [环境要求](#环境要求)
|
||||||
|
|
@ -34,6 +37,36 @@
|
||||||
- **断点续训**:自动检测 `best_model.pth` 并加载继续训练
|
- **断点续训**:自动检测 `best_model.pth` 并加载继续训练
|
||||||
- **多设备支持**:自动选择 CUDA > Intel XPU > CPU
|
- **多设备支持**:自动选择 CUDA > Intel XPU > CPU
|
||||||
|
|
||||||
|
## Web 演示
|
||||||
|
|
||||||
|
本项目提供基于 Gradio 的 Web 界面,上传图片即可实时预测垃圾类别。
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
### 启动
|
||||||
|
|
||||||
|
1. 确保 `best_model.pth` 位于项目根目录(如没有,先运行 `python Train.py`)
|
||||||
|
2. 安装 Gradio 依赖:
|
||||||
|
```bash
|
||||||
|
pip install "gradio>=4.0,<5.0" "pydantic>=2.5,<2.10"
|
||||||
|
```
|
||||||
|
3. 启动 Web 服务:
|
||||||
|
```bash
|
||||||
|
python web/app.py
|
||||||
|
```
|
||||||
|
4. 浏览器打开 `http://127.0.0.1:7860`
|
||||||
|
|
||||||
|
### 配置
|
||||||
|
|
||||||
|
可在 `web/app.py` 底部 `demo.launch()` 中调整:
|
||||||
|
|
||||||
|
| 参数 | 默认值 | 说明 |
|
||||||
|
|---|---|---|
|
||||||
|
| `server_name` | `127.0.0.1` | 局域网访问改为 `0.0.0.0` |
|
||||||
|
| `server_port` | `7860` | 端口冲突时可换 |
|
||||||
|
| `share` | `False` | 改为 `True` 可生成临时公网链接 |
|
||||||
|
| `inbrowser` | `True` | 启动后自动打开浏览器 |
|
||||||
|
|
||||||
## 模型架构
|
## 模型架构
|
||||||
|
|
||||||
模型基于标准 ResNet-34 架构,使用 BasicBlock 构建。
|
模型基于标准 ResNet-34 架构,使用 BasicBlock 构建。
|
||||||
|
|
@ -81,7 +114,7 @@ pip install -r requirements.txt
|
||||||
|
|
||||||
> **注意**:`requirements.txt` 不锁定 PyTorch 的 CUDA / XPU 版本,请根据硬件自行安装对应版本,例如:
|
> **注意**:`requirements.txt` 不锁定 PyTorch 的 CUDA / XPU 版本,请根据硬件自行安装对应版本,例如:
|
||||||
> - NVIDIA GPU:`pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121`
|
> - NVIDIA GPU:`pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121`
|
||||||
> - Intel GPU (XPU):`pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121` 安装
|
> - Intel GPU (XPU):`pip install torch torchvision --index-url https://download.pytorch.org/whl/xpu` 安装
|
||||||
> - CPU:`pip install torch torchvision` 即可
|
> - CPU:`pip install torch torchvision` 即可
|
||||||
|
|
||||||
## 快速开始
|
## 快速开始
|
||||||
|
|
@ -132,6 +165,7 @@ pip install -r requirements.txt
|
||||||
| `baseline/ResNet34_Pretrained_10pct.py` | ResNet-34 ImageNet 预训练 + 10% 数据微调 |
|
| `baseline/ResNet34_Pretrained_10pct.py` | ResNet-34 ImageNet 预训练 + 10% 数据微调 |
|
||||||
| `baseline/HOG_Baseline.py` | HOG + 颜色直方图 + LogisticRegression(纯传统 CV) |
|
| `baseline/HOG_Baseline.py` | HOG + 颜色直方图 + LogisticRegression(纯传统 CV) |
|
||||||
| `baseline/compare_models.py` | 多模型对比(ROC / PR 曲线 + 准确率柱状图) |
|
| `baseline/compare_models.py` | 多模型对比(ROC / PR 曲线 + 准确率柱状图) |
|
||||||
|
| `web/app.py` | Gradio Web 前端,上传图片实时分类 |
|
||||||
| `training_log.csv` | 训练日志,记录每轮 epoch 的 loss、f1、acc、lr |
|
| `training_log.csv` | 训练日志,记录每轮 epoch 的 loss、f1、acc、lr |
|
||||||
| `best_model.pth` | 训练好的最佳模型权重(约 125 MB,不纳入版本控制) |
|
| `best_model.pth` | 训练好的最佳模型权重(约 125 MB,不纳入版本控制) |
|
||||||
| `THIRD_PARTY_LICENSES.md` | 第三方数据集许可证声明 |
|
| `THIRD_PARTY_LICENSES.md` | 第三方数据集许可证声明 |
|
||||||
|
|
@ -148,6 +182,9 @@ trash-division/
|
||||||
│ ├── roc_comparison.png # 多模型 ROC 对比(compare_models.py 输出)
|
│ ├── roc_comparison.png # 多模型 ROC 对比(compare_models.py 输出)
|
||||||
│ ├── pr_comparison.png # 多模型 PR 对比(compare_models.py 输出)
|
│ ├── pr_comparison.png # 多模型 PR 对比(compare_models.py 输出)
|
||||||
│ └── accuracy_bar.png # 多模型准确率对比(compare_models.py 输出)
|
│ └── accuracy_bar.png # 多模型准确率对比(compare_models.py 输出)
|
||||||
|
├── web/ # Web 前端目录
|
||||||
|
│ ├── app.py # Gradio 应用入口
|
||||||
|
│ └── online.jpg # Web 演示截图
|
||||||
├── best_model.pth # 最佳模型权重(不纳入版本控制)
|
├── best_model.pth # 最佳模型权重(不纳入版本控制)
|
||||||
├── Curve.py # 训练曲线绘制脚本
|
├── Curve.py # 训练曲线绘制脚本
|
||||||
├── Dataloader.py # 数据加载模块
|
├── Dataloader.py # 数据加载模块
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
torch>=1.10
|
torch>=1.10
|
||||||
torchvision>=0.11
|
torchvision>=0.11
|
||||||
|
gradio>=4.0,<5.0
|
||||||
|
pydantic>=2.5,<2.10
|
||||||
tqdm
|
tqdm
|
||||||
matplotlib
|
matplotlib
|
||||||
pandas
|
pandas
|
||||||
|
|
|
||||||
62
web/README.md
Normal file
62
web/README.md
Normal file
|
|
@ -0,0 +1,62 @@
|
||||||
|
# Trash Division Web 前端
|
||||||
|
|
||||||
|
基于 Gradio 的垃圾分类识别 Web 应用,上传图片即可预测垃圾类别。
|
||||||
|
|
||||||
|
## 依赖
|
||||||
|
|
||||||
|
除项目根目录 `requirements.txt` 外,Web 前端额外依赖:
|
||||||
|
|
||||||
|
| 包 | 版本 | 说明 |
|
||||||
|
|---|---|---|
|
||||||
|
| `gradio` | `>=4.0,<5.0` | Web UI 框架 |
|
||||||
|
| `pydantic` | `>=2.5,<2.10` | gradio 4.x 兼容性约束(新版会报 `"const" in schema` 错误) |
|
||||||
|
|
||||||
|
> 安装:`pip install gradio>=4.0,<5.0 pydantic>=2.5,<2.10`
|
||||||
|
|
||||||
|
## 启动前准备
|
||||||
|
|
||||||
|
1. **确保 `best_model.pth` 存在**
|
||||||
|
在项目根目录(`trash-division/`)下放置训练好的模型权重。如没有,先运行:
|
||||||
|
```bash
|
||||||
|
cd .. && python Train.py
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **安装依赖**(如还未安装):
|
||||||
|
```bash
|
||||||
|
pip install -r ../requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
## 启动
|
||||||
|
|
||||||
|
在 `web/` 目录下运行:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python app.py
|
||||||
|
```
|
||||||
|
|
||||||
|
或者在项目根目录运行:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python web/app.py
|
||||||
|
```
|
||||||
|
|
||||||
|
启动后浏览器会自动打开 `http://127.0.0.1:7860`。
|
||||||
|
|
||||||
|
## 配置说明
|
||||||
|
|
||||||
|
可在 `app.py` 底部 `demo.launch()` 中调整以下参数:
|
||||||
|
|
||||||
|
| 参数 | 默认值 | 说明 |
|
||||||
|
|---|---|---|
|
||||||
|
| `server_name` | `127.0.0.1` | 本机访问。如需局域网内其他设备访问,改为 `0.0.0.0` |
|
||||||
|
| `server_port` | `7860` | 端口号,冲突时可换 |
|
||||||
|
| `share` | `False` | 改为 `True` 可生成临时公网链接分享给同学 |
|
||||||
|
| `inbrowser` | `True` | 启动后自动打开浏览器 |
|
||||||
|
|
||||||
|
## 兼容性
|
||||||
|
|
||||||
|
| 项 | 说明 |
|
||||||
|
|---|---|
|
||||||
|
| Python | `>=3.9,<3.10`(Gradio 5.x 需 Python 3.10+) |
|
||||||
|
| PyTorch | `>=1.10` |
|
||||||
|
| 设备 | 自动选择 CUDA > Intel XPU > Apple MPS > CPU |
|
||||||
104
web/app.py
Normal file
104
web/app.py
Normal file
|
|
@ -0,0 +1,104 @@
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
# 确保可以从 web/ 目录或项目根目录运行,都能找到 Model.py 和 best_model.pth
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
import torch
|
||||||
|
from torchvision import transforms
|
||||||
|
from PIL import Image
|
||||||
|
from Model import Net # 根据上传的 Model.py,模型类名为 Net
|
||||||
|
|
||||||
|
# 项目根目录(web/ 的上一级)
|
||||||
|
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
# 1. 基础配置与类别映射
|
||||||
|
# 根据 Merge_classes.py,1=厨余垃圾, 2=可回收物, 3=其他垃圾, 4=有害垃圾
|
||||||
|
class_names = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
|
||||||
|
|
||||||
|
# 设备自动选择逻辑,保持与 Train.py 和 Evaluate.py 一致
|
||||||
|
device = torch.device('cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'mps' if torch.mps.is_available() else 'cpu')
|
||||||
|
print(f"当前使用的推理设备: {device}")
|
||||||
|
|
||||||
|
# 2. 初始化模型并加载最佳权重
|
||||||
|
model = Net(num_classes=4)
|
||||||
|
try:
|
||||||
|
# 采用与 Evaluate.py 一致的健壮加载方式
|
||||||
|
model_path = os.path.join(PROJECT_ROOT, 'best_model.pth')
|
||||||
|
state_dict = torch.load(model_path, map_location=device)
|
||||||
|
if 'model_state_dict' in state_dict:
|
||||||
|
state_dict = state_dict['model_state_dict']
|
||||||
|
elif 'model' in state_dict:
|
||||||
|
state_dict = state_dict['model']
|
||||||
|
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
model = model.to(device).eval()
|
||||||
|
print("✅ 成功加载 best_model.pth 权重")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ 模型加载失败,请确保目录下存在 best_model.pth。错误信息: {e}")
|
||||||
|
|
||||||
|
# 3. 定义数据预处理流程 (必须与 Evaluate.py 中的 val_transform 保持完全一致)
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.Resize((256, 256)),
|
||||||
|
transforms.ToTensor(),
|
||||||
|
transforms.Normalize(
|
||||||
|
mean=[0.485, 0.456, 0.406],
|
||||||
|
std=[0.229, 0.224, 0.225]
|
||||||
|
)
|
||||||
|
])
|
||||||
|
|
||||||
|
# 4. 核心推理函数
|
||||||
|
def predict(image):
|
||||||
|
if image is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Gradio 传入的 pil 图像,确保转为 RGB 格式
|
||||||
|
image = image.convert('RGB')
|
||||||
|
|
||||||
|
# 预处理并增加 batch 维度
|
||||||
|
input_tensor = transform(image).unsqueeze(0).to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = model(input_tensor)
|
||||||
|
# 使用 Softmax 将 logits 转换为 0~1 的概率分布
|
||||||
|
probabilities = torch.softmax(outputs, dim=1)[0]
|
||||||
|
|
||||||
|
# 组装为 Gradio Label 组件需要的字典格式 { "类别名": 概率值 }
|
||||||
|
result_dict = {class_names[i]: float(probabilities[i]) for i in range(len(class_names))}
|
||||||
|
return result_dict
|
||||||
|
|
||||||
|
# 5. 构建与美化 Gradio 界面
|
||||||
|
with gr.Blocks(theme=gr.themes.Soft(), title="Trash Division 垃圾分类识别") as demo:
|
||||||
|
gr.Markdown(
|
||||||
|
"""
|
||||||
|
<div style="text-align: center; max-width: 800px; margin: 0 auto;">
|
||||||
|
<h1>🗑️ Trash Division - 智能垃圾分类系统</h1>
|
||||||
|
<p>基于 <b>ResNet-34</b> 架构,支持精准识别:<b>厨余垃圾、可回收物、其他垃圾、有害垃圾</b>。</p>
|
||||||
|
<p><i>同济大学 Python 人工智能程序设计课程小组作业</i></p>
|
||||||
|
</div>
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
with gr.Column(scale=1):
|
||||||
|
# type="pil" 让 Gradio 直接传 PIL Image 对象给预测函数,配合 torchvision 最方便
|
||||||
|
image_input = gr.Image(type="pil", label="上传垃圾图片 (支持拍照)")
|
||||||
|
with gr.Row():
|
||||||
|
clear_btn = gr.Button("清空", variant="secondary")
|
||||||
|
submit_btn = gr.Button("开始识别", variant="primary")
|
||||||
|
|
||||||
|
with gr.Column(scale=1):
|
||||||
|
label_output = gr.Label(num_top_classes=4, label="预测结果与置信度")
|
||||||
|
|
||||||
|
# 绑定点击事件
|
||||||
|
submit_btn.click(fn=predict, inputs=image_input, outputs=label_output)
|
||||||
|
clear_btn.click(lambda: (None, None), inputs=None, outputs=[image_input, label_output])
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 启动 Web 界面
|
||||||
|
demo.launch(
|
||||||
|
server_name="127.0.0.1",
|
||||||
|
server_port=7860,
|
||||||
|
share=False, # 如果你想生成临时公网链接分享给同学测试,改为 True
|
||||||
|
inbrowser=True # 运行后自动在默认浏览器中打开
|
||||||
|
)
|
||||||
BIN
web/online.jpg
Normal file
BIN
web/online.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 1 MiB |
Loading…
Reference in a new issue