新增web推理前端:Gradio网页界面,支持上传图片实时分类
This commit is contained in:
parent
0e9bd58dce
commit
be45ad32eb
4 changed files with 77 additions and 2 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -8,7 +8,9 @@
|
||||||
!/requirements.txt
|
!/requirements.txt
|
||||||
!/THIRD_PARTY_LICENSES.md
|
!/THIRD_PARTY_LICENSES.md
|
||||||
!/Train.py
|
!/Train.py
|
||||||
!/app.py
|
!/web/
|
||||||
|
!/web/app.py
|
||||||
|
!/web/README.md
|
||||||
!/Baseline.py
|
!/Baseline.py
|
||||||
!/Finetune.py
|
!/Finetune.py
|
||||||
!/Curve.py
|
!/Curve.py
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
torch>=1.10
|
torch>=1.10
|
||||||
torchvision>=0.11
|
torchvision>=0.11
|
||||||
|
gradio>=4.0,<5.0
|
||||||
|
pydantic>=2.5,<2.10
|
||||||
tqdm
|
tqdm
|
||||||
matplotlib
|
matplotlib
|
||||||
pandas
|
pandas
|
||||||
|
|
|
||||||
62
web/README.md
Normal file
62
web/README.md
Normal file
|
|
@ -0,0 +1,62 @@
|
||||||
|
# Trash Division Web 前端
|
||||||
|
|
||||||
|
基于 Gradio 的垃圾分类识别 Web 应用,上传图片即可预测垃圾类别。
|
||||||
|
|
||||||
|
## 依赖
|
||||||
|
|
||||||
|
除项目根目录 `requirements.txt` 外,Web 前端额外依赖:
|
||||||
|
|
||||||
|
| 包 | 版本 | 说明 |
|
||||||
|
|---|---|---|
|
||||||
|
| `gradio` | `>=4.0,<5.0` | Web UI 框架 |
|
||||||
|
| `pydantic` | `>=2.5,<2.10` | gradio 4.x 兼容性约束(新版会报 `"const" in schema` 错误) |
|
||||||
|
|
||||||
|
> 安装:`pip install gradio>=4.0,<5.0 pydantic>=2.5,<2.10`
|
||||||
|
|
||||||
|
## 启动前准备
|
||||||
|
|
||||||
|
1. **确保 `best_model.pth` 存在**
|
||||||
|
在项目根目录(`trash-division/`)下放置训练好的模型权重。如没有,先运行:
|
||||||
|
```bash
|
||||||
|
cd .. && python Train.py
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **安装依赖**(如还未安装):
|
||||||
|
```bash
|
||||||
|
pip install -r ../requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
## 启动
|
||||||
|
|
||||||
|
在 `web/` 目录下运行:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python app.py
|
||||||
|
```
|
||||||
|
|
||||||
|
或者在项目根目录运行:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python web/app.py
|
||||||
|
```
|
||||||
|
|
||||||
|
启动后浏览器会自动打开 `http://127.0.0.1:7860`。
|
||||||
|
|
||||||
|
## 配置说明
|
||||||
|
|
||||||
|
可在 `app.py` 底部 `demo.launch()` 中调整以下参数:
|
||||||
|
|
||||||
|
| 参数 | 默认值 | 说明 |
|
||||||
|
|---|---|---|
|
||||||
|
| `server_name` | `127.0.0.1` | 本机访问。如需局域网内其他设备访问,改为 `0.0.0.0` |
|
||||||
|
| `server_port` | `7860` | 端口号,冲突时可换 |
|
||||||
|
| `share` | `False` | 改为 `True` 可生成临时公网链接分享给同学 |
|
||||||
|
| `inbrowser` | `True` | 启动后自动打开浏览器 |
|
||||||
|
|
||||||
|
## 兼容性
|
||||||
|
|
||||||
|
| 项 | 说明 |
|
||||||
|
|---|---|
|
||||||
|
| Python | `>=3.9,<3.10`(Gradio 5.x 需 Python 3.10+) |
|
||||||
|
| PyTorch | `>=1.10` |
|
||||||
|
| 设备 | 自动选择 CUDA > Intel XPU > Apple MPS > CPU |
|
||||||
|
|
@ -1,9 +1,17 @@
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
# 确保可以从 web/ 目录或项目根目录运行,都能找到 Model.py 和 best_model.pth
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import torch
|
import torch
|
||||||
from torchvision import transforms
|
from torchvision import transforms
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from Model import Net # 根据上传的 Model.py,模型类名为 Net
|
from Model import Net # 根据上传的 Model.py,模型类名为 Net
|
||||||
|
|
||||||
|
# 项目根目录(web/ 的上一级)
|
||||||
|
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
# 1. 基础配置与类别映射
|
# 1. 基础配置与类别映射
|
||||||
# 根据 Merge_classes.py,1=厨余垃圾, 2=可回收物, 3=其他垃圾, 4=有害垃圾
|
# 根据 Merge_classes.py,1=厨余垃圾, 2=可回收物, 3=其他垃圾, 4=有害垃圾
|
||||||
class_names = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
|
class_names = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾']
|
||||||
|
|
@ -16,7 +24,8 @@ print(f"当前使用的推理设备: {device}")
|
||||||
model = Net(num_classes=4)
|
model = Net(num_classes=4)
|
||||||
try:
|
try:
|
||||||
# 采用与 Evaluate.py 一致的健壮加载方式
|
# 采用与 Evaluate.py 一致的健壮加载方式
|
||||||
state_dict = torch.load('best_model.pth', map_location=device)
|
model_path = os.path.join(PROJECT_ROOT, 'best_model.pth')
|
||||||
|
state_dict = torch.load(model_path, map_location=device)
|
||||||
if 'model_state_dict' in state_dict:
|
if 'model_state_dict' in state_dict:
|
||||||
state_dict = state_dict['model_state_dict']
|
state_dict = state_dict['model_state_dict']
|
||||||
elif 'model' in state_dict:
|
elif 'model' in state_dict:
|
||||||
Loading…
Reference in a new issue