Merge branch 'data_cleaning_test' into main

This commit is contained in:
weikaiwen348-code 2026-06-03 20:53:07 +08:00
commit 37e24a0b09
4 changed files with 171 additions and 0 deletions

3
.gitignore vendored
View file

@ -8,6 +8,9 @@
!/requirements.txt
!/THIRD_PARTY_LICENSES.md
!/Train.py
!/web/
!/web/app.py
!/web/README.md
!/Baseline.py
!/Finetune.py
!/Curve.py

View file

@ -1,5 +1,7 @@
torch>=1.10
torchvision>=0.11
gradio>=4.0,<5.0
pydantic>=2.5,<2.10
tqdm
matplotlib
pandas

62
web/README.md Normal file
View file

@ -0,0 +1,62 @@
# Trash Division Web 前端
基于 Gradio 的垃圾分类识别 Web 应用,上传图片即可预测垃圾类别。
## 依赖
除项目根目录 `requirements.txt`Web 前端额外依赖:
| 包 | 版本 | 说明 |
|---|---|---|
| `gradio` | `>=4.0,<5.0` | Web UI 框架 |
| `pydantic` | `>=2.5,<2.10` | gradio 4.x 兼容性约束(新版会报 `"const" in schema` 错误) |
> 安装:`pip install gradio>=4.0,<5.0 pydantic>=2.5,<2.10`
## 启动前准备
1. **确保 `best_model.pth` 存在**
在项目根目录(`trash-division/`)下放置训练好的模型权重。如没有,先运行:
```bash
cd .. && python Train.py
```
2. **安装依赖**(如还未安装):
```bash
pip install -r ../requirements.txt
```
## 启动
`web/` 目录下运行:
```bash
python app.py
```
或者在项目根目录运行:
```bash
python web/app.py
```
启动后浏览器会自动打开 `http://127.0.0.1:7860`
## 配置说明
可在 `app.py` 底部 `demo.launch()` 中调整以下参数:
| 参数 | 默认值 | 说明 |
|---|---|---|
| `server_name` | `127.0.0.1` | 本机访问。如需局域网内其他设备访问,改为 `0.0.0.0` |
| `server_port` | `7860` | 端口号,冲突时可换 |
| `share` | `False` | 改为 `True` 可生成临时公网链接分享给同学 |
| `inbrowser` | `True` | 启动后自动打开浏览器 |
## 兼容性
| 项 | 说明 |
|---|---|
| Python | `>=3.9,<3.10`Gradio 5.x 需 Python 3.10+ |
| PyTorch | `>=1.10` |
| 设备 | 自动选择 CUDA > Intel XPU > Apple MPS > CPU |

104
web/app.py Normal file
View file

@ -0,0 +1,104 @@
import sys
import os
# 确保可以从 web/ 目录或项目根目录运行,都能找到 Model.py 和 best_model.pth
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
from Model import Net # 根据上传的 Model.py模型类名为 Net
# 项目根目录web/ 的上一级)
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# 1. 基础配置与类别映射
# 根据 Merge_classes.py1=厨余垃圾, 2=可回收物, 3=其他垃圾, 4=有害垃圾
class_names = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
# 设备自动选择逻辑,保持与 Train.py 和 Evaluate.py 一致
device = torch.device('cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'mps' if torch.mps.is_available() else 'cpu')
print(f"当前使用的推理设备: {device}")
# 2. 初始化模型并加载最佳权重
model = Net(num_classes=4)
try:
# 采用与 Evaluate.py 一致的健壮加载方式
model_path = os.path.join(PROJECT_ROOT, 'best_model.pth')
state_dict = torch.load(model_path, map_location=device)
if 'model_state_dict' in state_dict:
state_dict = state_dict['model_state_dict']
elif 'model' in state_dict:
state_dict = state_dict['model']
model.load_state_dict(state_dict)
model = model.to(device).eval()
print("✅ 成功加载 best_model.pth 权重")
except Exception as e:
print(f"⚠️ 模型加载失败,请确保目录下存在 best_model.pth。错误信息: {e}")
# 3. 定义数据预处理流程 (必须与 Evaluate.py 中的 val_transform 保持完全一致)
transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
# 4. 核心推理函数
def predict(image):
if image is None:
return None
# Gradio 传入的 pil 图像,确保转为 RGB 格式
image = image.convert('RGB')
# 预处理并增加 batch 维度
input_tensor = transform(image).unsqueeze(0).to(device)
with torch.no_grad():
outputs = model(input_tensor)
# 使用 Softmax 将 logits 转换为 0~1 的概率分布
probabilities = torch.softmax(outputs, dim=1)[0]
# 组装为 Gradio Label 组件需要的字典格式 { "类别名": 概率值 }
result_dict = {class_names[i]: float(probabilities[i]) for i in range(len(class_names))}
return result_dict
# 5. 构建与美化 Gradio 界面
with gr.Blocks(theme=gr.themes.Soft(), title="Trash Division 垃圾分类识别") as demo:
gr.Markdown(
"""
<div style="text-align: center; max-width: 800px; margin: 0 auto;">
<h1>🗑 Trash Division - 智能垃圾分类系统</h1>
<p>基于 <b>ResNet-34</b> 架构支持精准识别<b>厨余垃圾可回收物其他垃圾有害垃圾</b></p>
<p><i>同济大学 Python 人工智能程序设计课程小组作业</i></p>
</div>
"""
)
with gr.Row():
with gr.Column(scale=1):
# type="pil" 让 Gradio 直接传 PIL Image 对象给预测函数,配合 torchvision 最方便
image_input = gr.Image(type="pil", label="上传垃圾图片 (支持拍照)")
with gr.Row():
clear_btn = gr.Button("清空", variant="secondary")
submit_btn = gr.Button("开始识别", variant="primary")
with gr.Column(scale=1):
label_output = gr.Label(num_top_classes=4, label="预测结果与置信度")
# 绑定点击事件
submit_btn.click(fn=predict, inputs=image_input, outputs=label_output)
clear_btn.click(lambda: (None, None), inputs=None, outputs=[image_input, label_output])
if __name__ == "__main__":
# 启动 Web 界面
demo.launch(
server_name="127.0.0.1",
server_port=7860,
share=False, # 如果你想生成临时公网链接分享给同学测试,改为 True
inbrowser=True # 运行后自动在默认浏览器中打开
)