diff --git a/.gitignore b/.gitignore index c650b04..3117cf2 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,9 @@ !/requirements.txt !/THIRD_PARTY_LICENSES.md !/Train.py +!/web/ +!/web/app.py +!/web/README.md !/Baseline.py !/Finetune.py !/Curve.py diff --git a/requirements.txt b/requirements.txt index f612349..3f85381 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,7 @@ torch>=1.10 torchvision>=0.11 +gradio>=4.0,<5.0 +pydantic>=2.5,<2.10 tqdm matplotlib pandas diff --git a/web/README.md b/web/README.md new file mode 100644 index 0000000..b2934d7 --- /dev/null +++ b/web/README.md @@ -0,0 +1,62 @@ +# Trash Division Web 前端 + +基于 Gradio 的垃圾分类识别 Web 应用,上传图片即可预测垃圾类别。 + +## 依赖 + +除项目根目录 `requirements.txt` 外,Web 前端额外依赖: + +| 包 | 版本 | 说明 | +|---|---|---| +| `gradio` | `>=4.0,<5.0` | Web UI 框架 | +| `pydantic` | `>=2.5,<2.10` | gradio 4.x 兼容性约束(新版会报 `"const" in schema` 错误) | + +> 安装:`pip install gradio>=4.0,<5.0 pydantic>=2.5,<2.10` + +## 启动前准备 + +1. **确保 `best_model.pth` 存在** + 在项目根目录(`trash-division/`)下放置训练好的模型权重。如没有,先运行: + ```bash + cd .. && python Train.py + ``` + +2. **安装依赖**(如还未安装): + ```bash + pip install -r ../requirements.txt + ``` + +## 启动 + +在 `web/` 目录下运行: + +```bash +python app.py +``` + +或者在项目根目录运行: + +```bash +python web/app.py +``` + +启动后浏览器会自动打开 `http://127.0.0.1:7860`。 + +## 配置说明 + +可在 `app.py` 底部 `demo.launch()` 中调整以下参数: + +| 参数 | 默认值 | 说明 | +|---|---|---| +| `server_name` | `127.0.0.1` | 本机访问。如需局域网内其他设备访问,改为 `0.0.0.0` | +| `server_port` | `7860` | 端口号,冲突时可换 | +| `share` | `False` | 改为 `True` 可生成临时公网链接分享给同学 | +| `inbrowser` | `True` | 启动后自动打开浏览器 | + +## 兼容性 + +| 项 | 说明 | +|---|---| +| Python | `>=3.9,<3.10`(Gradio 5.x 需 Python 3.10+) | +| PyTorch | `>=1.10` | +| 设备 | 自动选择 CUDA > Intel XPU > Apple MPS > CPU | diff --git a/web/app.py b/web/app.py new file mode 100644 index 0000000..fe3471a --- /dev/null +++ b/web/app.py @@ -0,0 +1,104 @@ +import sys +import os +# 确保可以从 web/ 目录或项目根目录运行,都能找到 Model.py 和 best_model.pth +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import gradio as gr +import torch +from torchvision import transforms +from PIL import Image +from Model import Net # 根据上传的 Model.py,模型类名为 Net + +# 项目根目录(web/ 的上一级) +PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +# 1. 基础配置与类别映射 +# 根据 Merge_classes.py,1=厨余垃圾, 2=可回收物, 3=其他垃圾, 4=有害垃圾 +class_names = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾'] + +# 设备自动选择逻辑,保持与 Train.py 和 Evaluate.py 一致 +device = torch.device('cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'mps' if torch.mps.is_available() else 'cpu') +print(f"当前使用的推理设备: {device}") + +# 2. 初始化模型并加载最佳权重 +model = Net(num_classes=4) +try: + # 采用与 Evaluate.py 一致的健壮加载方式 + model_path = os.path.join(PROJECT_ROOT, 'best_model.pth') + state_dict = torch.load(model_path, map_location=device) + if 'model_state_dict' in state_dict: + state_dict = state_dict['model_state_dict'] + elif 'model' in state_dict: + state_dict = state_dict['model'] + + model.load_state_dict(state_dict) + model = model.to(device).eval() + print("✅ 成功加载 best_model.pth 权重") +except Exception as e: + print(f"⚠️ 模型加载失败,请确保目录下存在 best_model.pth。错误信息: {e}") + +# 3. 定义数据预处理流程 (必须与 Evaluate.py 中的 val_transform 保持完全一致) +transform = transforms.Compose([ + transforms.Resize((256, 256)), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ) +]) + +# 4. 核心推理函数 +def predict(image): + if image is None: + return None + + # Gradio 传入的 pil 图像,确保转为 RGB 格式 + image = image.convert('RGB') + + # 预处理并增加 batch 维度 + input_tensor = transform(image).unsqueeze(0).to(device) + + with torch.no_grad(): + outputs = model(input_tensor) + # 使用 Softmax 将 logits 转换为 0~1 的概率分布 + probabilities = torch.softmax(outputs, dim=1)[0] + + # 组装为 Gradio Label 组件需要的字典格式 { "类别名": 概率值 } + result_dict = {class_names[i]: float(probabilities[i]) for i in range(len(class_names))} + return result_dict + +# 5. 构建与美化 Gradio 界面 +with gr.Blocks(theme=gr.themes.Soft(), title="Trash Division 垃圾分类识别") as demo: + gr.Markdown( + """ +
基于 ResNet-34 架构,支持精准识别:厨余垃圾、可回收物、其他垃圾、有害垃圾。
+同济大学 Python 人工智能程序设计课程小组作业
+