新增web推理前端:Gradio网页界面,支持上传图片实时分类

This commit is contained in:
weikaiwen348-code 2026-06-03 20:52:28 +08:00
parent 0e9bd58dce
commit be45ad32eb
4 changed files with 77 additions and 2 deletions

4
.gitignore vendored
View file

@ -8,7 +8,9 @@
!/requirements.txt !/requirements.txt
!/THIRD_PARTY_LICENSES.md !/THIRD_PARTY_LICENSES.md
!/Train.py !/Train.py
!/app.py !/web/
!/web/app.py
!/web/README.md
!/Baseline.py !/Baseline.py
!/Finetune.py !/Finetune.py
!/Curve.py !/Curve.py

View file

@ -1,5 +1,7 @@
torch>=1.10 torch>=1.10
torchvision>=0.11 torchvision>=0.11
gradio>=4.0,<5.0
pydantic>=2.5,<2.10
tqdm tqdm
matplotlib matplotlib
pandas pandas

62
web/README.md Normal file
View file

@ -0,0 +1,62 @@
# Trash Division Web 前端
基于 Gradio 的垃圾分类识别 Web 应用,上传图片即可预测垃圾类别。
## 依赖
除项目根目录 `requirements.txt`Web 前端额外依赖:
| 包 | 版本 | 说明 |
|---|---|---|
| `gradio` | `>=4.0,<5.0` | Web UI 框架 |
| `pydantic` | `>=2.5,<2.10` | gradio 4.x 兼容性约束(新版会报 `"const" in schema` 错误) |
> 安装:`pip install gradio>=4.0,<5.0 pydantic>=2.5,<2.10`
## 启动前准备
1. **确保 `best_model.pth` 存在**
在项目根目录(`trash-division/`)下放置训练好的模型权重。如没有,先运行:
```bash
cd .. && python Train.py
```
2. **安装依赖**(如还未安装):
```bash
pip install -r ../requirements.txt
```
## 启动
`web/` 目录下运行:
```bash
python app.py
```
或者在项目根目录运行:
```bash
python web/app.py
```
启动后浏览器会自动打开 `http://127.0.0.1:7860`
## 配置说明
可在 `app.py` 底部 `demo.launch()` 中调整以下参数:
| 参数 | 默认值 | 说明 |
|---|---|---|
| `server_name` | `127.0.0.1` | 本机访问。如需局域网内其他设备访问,改为 `0.0.0.0` |
| `server_port` | `7860` | 端口号,冲突时可换 |
| `share` | `False` | 改为 `True` 可生成临时公网链接分享给同学 |
| `inbrowser` | `True` | 启动后自动打开浏览器 |
## 兼容性
| 项 | 说明 |
|---|---|
| Python | `>=3.9,<3.10`Gradio 5.x 需 Python 3.10+ |
| PyTorch | `>=1.10` |
| 设备 | 自动选择 CUDA > Intel XPU > Apple MPS > CPU |

View file

@ -1,9 +1,17 @@
import sys
import os
# 确保可以从 web/ 目录或项目根目录运行,都能找到 Model.py 和 best_model.pth
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import gradio as gr import gradio as gr
import torch import torch
from torchvision import transforms from torchvision import transforms
from PIL import Image from PIL import Image
from Model import Net # 根据上传的 Model.py模型类名为 Net from Model import Net # 根据上传的 Model.py模型类名为 Net
# 项目根目录web/ 的上一级)
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# 1. 基础配置与类别映射 # 1. 基础配置与类别映射
# 根据 Merge_classes.py1=厨余垃圾, 2=可回收物, 3=其他垃圾, 4=有害垃圾 # 根据 Merge_classes.py1=厨余垃圾, 2=可回收物, 3=其他垃圾, 4=有害垃圾
class_names = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾'] class_names = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
@ -16,7 +24,8 @@ print(f"当前使用的推理设备: {device}")
model = Net(num_classes=4) model = Net(num_classes=4)
try: try:
# 采用与 Evaluate.py 一致的健壮加载方式 # 采用与 Evaluate.py 一致的健壮加载方式
state_dict = torch.load('best_model.pth', map_location=device) model_path = os.path.join(PROJECT_ROOT, 'best_model.pth')
state_dict = torch.load(model_path, map_location=device)
if 'model_state_dict' in state_dict: if 'model_state_dict' in state_dict:
state_dict = state_dict['model_state_dict'] state_dict = state_dict['model_state_dict']
elif 'model' in state_dict: elif 'model' in state_dict: