import sys import os # 确保可以从 web/ 目录或项目根目录运行,都能找到 Model.py 和 best_model.pth sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import gradio as gr import torch from torchvision import transforms from PIL import Image from Model import Net # 根据上传的 Model.py,模型类名为 Net # 项目根目录(web/ 的上一级) PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # 1. 基础配置与类别映射 # 根据 Merge_classes.py,1=厨余垃圾, 2=可回收物, 3=其他垃圾, 4=有害垃圾 class_names = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾'] # 设备自动选择逻辑,保持与 Train.py 和 Evaluate.py 一致 device = torch.device('cuda' if torch.cuda.is_available() else 'xpu' if torch.xpu.is_available() else 'mps' if torch.mps.is_available() else 'cpu') print(f"当前使用的推理设备: {device}") # 2. 初始化模型并加载最佳权重 model = Net(num_classes=4) try: # 采用与 Evaluate.py 一致的健壮加载方式 model_path = os.path.join(PROJECT_ROOT, 'best_model.pth') state_dict = torch.load(model_path, map_location=device) if 'model_state_dict' in state_dict: state_dict = state_dict['model_state_dict'] elif 'model' in state_dict: state_dict = state_dict['model'] model.load_state_dict(state_dict) model = model.to(device).eval() print("✅ 成功加载 best_model.pth 权重") except Exception as e: print(f"⚠️ 模型加载失败,请确保目录下存在 best_model.pth。错误信息: {e}") # 3. 定义数据预处理流程 (必须与 Evaluate.py 中的 val_transform 保持完全一致) transform = transforms.Compose([ transforms.Resize((256, 256)), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) ]) # 4. 核心推理函数 def predict(image): if image is None: return None # Gradio 传入的 pil 图像,确保转为 RGB 格式 image = image.convert('RGB') # 预处理并增加 batch 维度 input_tensor = transform(image).unsqueeze(0).to(device) with torch.no_grad(): outputs = model(input_tensor) # 使用 Softmax 将 logits 转换为 0~1 的概率分布 probabilities = torch.softmax(outputs, dim=1)[0] # 组装为 Gradio Label 组件需要的字典格式 { "类别名": 概率值 } result_dict = {class_names[i]: float(probabilities[i]) for i in range(len(class_names))} return result_dict # 5. 构建与美化 Gradio 界面 with gr.Blocks(theme=gr.themes.Soft(), title="Trash Division 垃圾分类识别") as demo: gr.Markdown( """
基于 ResNet-34 架构,支持精准识别:厨余垃圾、可回收物、其他垃圾、有害垃圾。
同济大学 Python 人工智能程序设计课程小组作业