diff --git a/.gitignore b/.gitignore index 3117cf2..2e1c1d6 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,8 @@ !/baseline/roc_comparison.png !/baseline/pr_comparison.png !/baseline/accuracy_bar.png +!/web/ +!/web/online.jpg !/training_log.csv !/confusion_matrix.png !/roc_curve.png diff --git a/README.md b/README.md index 10b86e6..c567308 100644 --- a/README.md +++ b/README.md @@ -6,11 +6,14 @@ 基于 ResNet-34 架构的 CNN 模型(约 21M 参数),将生活垃圾分为厨余垃圾、可回收物、其他垃圾、有害垃圾四个类别,输入为 256×256 RGB 图像。 +![demo](web/online.jpg) + --- ## 目录 - [项目特点](#项目特点) +- [Web 演示](#web-演示) - [模型架构](#模型架构) - [数据集](#数据集) - [环境要求](#环境要求) @@ -34,6 +37,36 @@ - **断点续训**:自动检测 `best_model.pth` 并加载继续训练 - **多设备支持**:自动选择 CUDA > Intel XPU > CPU +## Web 演示 + +本项目提供基于 Gradio 的 Web 界面,上传图片即可实时预测垃圾类别。 + +![demo](web/online.jpg) + +### 启动 + +1. 确保 `best_model.pth` 位于项目根目录(如没有,先运行 `python Train.py`) +2. 安装 Gradio 依赖: + ```bash + pip install "gradio>=4.0,<5.0" "pydantic>=2.5,<2.10" + ``` +3. 启动 Web 服务: + ```bash + python web/app.py + ``` +4. 浏览器打开 `http://127.0.0.1:7860` + +### 配置 + +可在 `web/app.py` 底部 `demo.launch()` 中调整: + +| 参数 | 默认值 | 说明 | +|---|---|---| +| `server_name` | `127.0.0.1` | 局域网访问改为 `0.0.0.0` | +| `server_port` | `7860` | 端口冲突时可换 | +| `share` | `False` | 改为 `True` 可生成临时公网链接 | +| `inbrowser` | `True` | 启动后自动打开浏览器 | + ## 模型架构 模型基于标准 ResNet-34 架构,使用 BasicBlock 构建。 @@ -132,6 +165,7 @@ pip install -r requirements.txt | `baseline/ResNet34_Pretrained_10pct.py` | ResNet-34 ImageNet 预训练 + 10% 数据微调 | | `baseline/HOG_Baseline.py` | HOG + 颜色直方图 + LogisticRegression(纯传统 CV) | | `baseline/compare_models.py` | 多模型对比(ROC / PR 曲线 + 准确率柱状图) | +| `web/app.py` | Gradio Web 前端,上传图片实时分类 | | `training_log.csv` | 训练日志,记录每轮 epoch 的 loss、f1、acc、lr | | `best_model.pth` | 训练好的最佳模型权重(约 125 MB,不纳入版本控制) | | `THIRD_PARTY_LICENSES.md` | 第三方数据集许可证声明 | @@ -148,6 +182,9 @@ trash-division/ │ ├── roc_comparison.png # 多模型 ROC 对比(compare_models.py 输出) │ ├── pr_comparison.png # 多模型 PR 对比(compare_models.py 输出) │ └── accuracy_bar.png # 多模型准确率对比(compare_models.py 输出) +├── web/ # Web 前端目录 +│ ├── app.py # Gradio 应用入口 +│ └── online.jpg # Web 演示截图 ├── best_model.pth # 最佳模型权重(不纳入版本控制) ├── Curve.py # 训练曲线绘制脚本 ├── Dataloader.py # 数据加载模块 diff --git a/web/online.jpg b/web/online.jpg new file mode 100644 index 0000000..ac4d441 Binary files /dev/null and b/web/online.jpg differ