diff --git a/.gitignore b/.gitignore index fb1c79f..4dd24aa 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,10 @@ !/Finetune.py !/Curve.py !/Evaluate.py +!/baseline/ +!/baseline/__init__.py +!/baseline/VGG_KNN.py +!/baseline/compare_models.py !/training_log.csv !/confusion_matrix.png !/roc_curve.png diff --git a/AGENTS.md b/AGENTS.md index d7c3df2..02b7493 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -2,23 +2,25 @@ ## Project -CNN-based garbage classification (4 classes: 厨余垃圾/可回收物/其他垃圾/有害垃圾). ResNet-34 architecture, ~21M params, 256×256 RGB input, ~900 lines across 8 Python files. No package structure. +CNN-based garbage classification (4 classes: 厨余垃圾/可回收物/其他垃圾/有害垃圾). ResNet-34 architecture, ~21M params, 256×256 RGB input, ~900 lines across 11 Python files. No package structure. ## Pipeline (order matters) ```bash -python Merge_classes.py # merges 265 → 4 classes, creates ../trash_division_data/ultimate_4_class/ -python Train.py # trains the model, saves best_model.pth + training_log.csv -python Finetune.py # optional: freezes early layers, saves finetuned_model.pth + finetune_log.csv -python Evaluate.py # plots confusion matrix / ROC / PR curves from best_model.pth -python Curve.py # plots loss/f1/acc/lr curves from training_log.csv +python Merge_classes.py # merges 265 → 4 classes, creates ../trash_division_data/ultimate_4_class/ +python Train.py # trains the model, saves best_model.pth + training_log.csv +python Finetune.py # optional: freezes early layers, saves finetuned_model.pth + finetune_log.csv +python Evaluate.py # plots confusion matrix / ROC / PR curves from best_model.pth +python Curve.py # plots loss/f1/acc/lr curves from training_log.csv +python baseline/VGG_KNN.py # VGG16 feature extraction + KNN baseline +python baseline/compare_models.py # compares multiple models (ROC + accuracy bar chart) ``` Also usable standalone: `python Model.py` prints `torchsummary` parameter summary. ## Dependencies -No `requirements.txt` — install manually: `torch`, `torchvision`, `tqdm`, `matplotlib`, `pandas`, `Pillow`, `torchsummary`. `Evaluate.py` additionally needs `scikit-learn`. +No `requirements.txt` — install manually: `torch`, `torchvision`, `tqdm`, `matplotlib`, `pandas`, `Pillow`, `torchsummary`. `Evaluate.py` and `baseline/*.py` additionally need `scikit-learn`. ## Data setup @@ -26,7 +28,7 @@ Data expected **outside repo** at `../trash_division_data/` (sibling dir). `Merg ## .gitignore — whitelist pattern -`.gitignore` uses `*` (ignore everything) then un-ignores specific files with `!` patterns. **Any new file you add to the repo must be explicitly whitelisted** or it will be invisible to git. The current whitelist: `Dataloader.py`, `LICENSE`, `Merge_classes.py`, `Model.py`, `README.md`, `THIRD_PARTY_LICENSES.md`, `Train.py`, `.gitattributes`, `.gitignore`. +`.gitignore` uses `*` (ignore everything) then un-ignores specific files with `!` patterns. **Any new file you add to the repo must be explicitly whitelisted** or it will be invisible to git. The current whitelist includes: `Dataloader.py`, `LICENSE`, `Merge_classes.py`, `Model.py`, `README.md`, `THIRD_PARTY_LICENSES.md`, `Train.py`, `.gitattributes`, `.gitignore`, plus `Finetune.py`, `Curve.py`, `Evaluate.py`, `AGENTS.md`, 4× output PNG, `training_log.csv`, and `baseline/`. `best_model.pth` and `finetuned_model.pth` are **untracked** (~125 MB each) — back them up manually if needed. `Finetune.py`, `Curve.py`, `Evaluate.py`, `AGENTS.md`, `training_log*.csv`, and `finetune_log.csv` are also untracked (not in whitelist). @@ -60,6 +62,16 @@ Data expected **outside repo** at `../trash_division_data/` (sibling dir). `Merg - Saves `confusion_matrix.png`, `roc_curve.png`, `pr_curve.png` - Requires `scikit-learn` +### baseline/ (VGG_KNN.py + compare_models.py) + +- `baseline/VGG_KNN.py` can run standalone (`python baseline/VGG_KNN.py`) or be imported from `compare_models.py` +- Uses `sys.path.insert` at top so it can import root-level modules (`Model`, `Dataloader`) from subdirectory +- `compare_models.py` has a `MODELS` registry list — add new models by writing a `get_xxx_preds(train_loader, val_loader, device)` function and adding one line to the list; no plot code changes needed +- VGG16 feature dimension: 25088 (512 channels × 7×7 avgpool) +- KNN uses `predict_proba` (neighbor voting proportions) for ROC curves — coarse-grained but valid AUC +- Output: `baseline/roc_comparison.png`, `baseline/accuracy_bar.png`, `baseline/vgg_knn_confusion_matrix.png` +- Compare scripts output images to `baseline/` dir (not repo root) + ## Model architecture reference `Model.py` attribute names (for freezing / layer access): diff --git a/Baseline.py b/Baseline.py deleted file mode 100644 index 9be8a36..0000000 --- a/Baseline.py +++ /dev/null @@ -1,129 +0,0 @@ -""" -Baseline.py -VGG16 预训练模型特征提取 + KNN 四分类基线 -author: yukun-hh -date: 2026-5-14 -""" -import os -import numpy as np -import matplotlib.pyplot as plt -import matplotlib - -import torch -import torch.nn as nn -from torch.utils.data import DataLoader -from torchvision import models, transforms -from tqdm import tqdm - -from sklearn.neighbors import KNeighborsClassifier -from sklearn.metrics import ( - accuracy_score, f1_score, - confusion_matrix, ConfusionMatrixDisplay, - classification_report, -) - -from Dataloader import RobustImageFolder - -matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans'] -matplotlib.rcParams['axes.unicode_minus'] = False - -# ============================================================ -# ★★★ 可配置参数 ★★★ -# ============================================================ -DATA_ROOT = '../trash_division_data/ultimate_4_class/' -BATCH_SIZE = 32 -IMAGE_SIZE = 256 -NUM_WORKERS = 4 -K = 5 -# ============================================================ - -CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾'] - - -def load_vgg16_extractor(device): - try: - model = models.vgg16(weights='IMAGENET1K_V1') - except TypeError: - model = models.vgg16(pretrained=True) - model.classifier = nn.Identity() - model = model.to(device).eval() - for param in model.parameters(): - param.requires_grad = False - return model - - -def extract_features(model, loader, device): - model.eval() - all_features = [] - all_labels = [] - with torch.no_grad(): - for images, labels in tqdm(loader, desc='Extracting features'): - images = images.to(device) - feats = model(images) - all_features.append(feats.cpu().numpy()) - all_labels.append(labels.numpy()) - return np.concatenate(all_features), np.concatenate(all_labels) - - -if __name__ == '__main__': - val_transform = transforms.Compose([ - transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), - transforms.ToTensor(), - transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]), - ]) - - train_dataset = RobustImageFolder( - root=os.path.join(DATA_ROOT, 'train'), - transform=val_transform, - ) - val_dataset = RobustImageFolder( - root=os.path.join(DATA_ROOT, 'val'), - transform=val_transform, - ) - print(f"训练集: {len(train_dataset)} 验证集: {len(val_dataset)}") - - train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, - shuffle=False, num_workers=NUM_WORKERS, - pin_memory=True, drop_last=False) - val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, - shuffle=False, num_workers=NUM_WORKERS, - pin_memory=True, drop_last=False) - - device = torch.device('cuda' if torch.cuda.is_available() - else 'xpu' if hasattr(torch, 'xpu') and torch.xpu.is_available() - else 'cpu') - print(f"Device: {device}") - - extractor = load_vgg16_extractor(device) - print("VGG16 特征提取器加载完成 (classifier 已移除)") - - print("提取训练集特征 ...") - train_feats, train_labels = extract_features(extractor, train_loader, device) - print(f"训练特征: {train_feats.shape}") - - print("提取验证集特征 ...") - val_feats, val_labels = extract_features(extractor, val_loader, device) - print(f"验证特征: {val_feats.shape}") - - knn = KNeighborsClassifier(n_neighbors=K, n_jobs=-1) - knn.fit(train_feats, train_labels) - print(f"KNN (K={K}) 训练完成") - - val_preds = knn.predict(val_feats) - - acc = accuracy_score(val_labels, val_preds) - macro_f1 = f1_score(val_labels, val_preds, average='macro') - print(f"\n验证集 Accuracy: {acc:.4f}") - print(f"验证集 Macro-F1: {macro_f1:.4f}") - print(f"\n分类报告:\n{classification_report(val_labels, val_preds, target_names=CLASS_NAMES)}") - - cm = confusion_matrix(val_labels, val_preds) - fig, ax = plt.subplots(figsize=(8, 7)) - ConfusionMatrixDisplay(cm, display_labels=CLASS_NAMES).plot( - ax=ax, cmap='Blues', values_format='d', xticks_rotation=30) - ax.set_title(f'Baseline Confusion Matrix (VGG16 + KNN, K={K})', fontsize=14) - plt.tight_layout() - plt.savefig('baseline_confusion_matrix.png', dpi=150, bbox_inches='tight') - plt.show() - print("混淆矩阵已保存: baseline_confusion_matrix.png") diff --git a/README.md b/README.md index 1a85aec..6c47516 100644 --- a/README.md +++ b/README.md @@ -123,7 +123,9 @@ | 文件 | 功能 | |---|---| -| `Baseline.py` | 基线模型,VGG16 预训练特征提取 + KNN 四分类 | +| `Baseline.py` → `baseline/` | 基线模型目录,VGG16+KNN 及多模型对比 | +| `baseline/VGG_KNN.py` | VGG16 预训练特征提取 + KNN 四分类 | +| `baseline/compare_models.py` | 多模型 ROC 曲线与准确率柱状图对比 | | `Train.py` | 训练主脚本,包含训练循环、验证、评估 | | `Finetune.py` | 微调脚本,冻结浅层后微调深层网络 | | `Dataloader.py` | 数据加载模块,包含 RobustImageFolder 和 DataLoader 创建 | @@ -141,7 +143,9 @@ ``` trash-division/ ├── AGENTS.md # AI 助手指南 -├── Baseline.py # 基线模型脚本 +├── baseline/ # 基线模型目录 +│ ├── VGG_KNN.py # VGG16 + KNN 分类脚本 +│ └── compare_models.py # 多模型对比脚本 ├── best_model.pth # 最佳模型权重(不纳入版本控制) ├── Curve.py # 训练曲线绘制脚本 ├── Dataloader.py # 数据加载模块 diff --git a/baseline/VGG_KNN.py b/baseline/VGG_KNN.py new file mode 100644 index 0000000..2e9a466 --- /dev/null +++ b/baseline/VGG_KNN.py @@ -0,0 +1,145 @@ +""" +baseline/VGG_KNN.py +VGG16 预训练模型特征提取 + KNN 四分类基线 +可独立运行,也可被 compare_models.py 导入复用 +author: yukun-hh +date: 2026-5-14 +""" +import sys, os +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import numpy as np +import matplotlib.pyplot as plt +import matplotlib + +import torch +import torch.nn as nn +from torch.utils.data import DataLoader +from torchvision import models, transforms +from tqdm import tqdm + +from sklearn.neighbors import KNeighborsClassifier +from sklearn.metrics import ( + accuracy_score, f1_score, + confusion_matrix, ConfusionMatrixDisplay, + classification_report, +) + +from Dataloader import RobustImageFolder + +matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans'] +matplotlib.rcParams['axes.unicode_minus'] = False + + +CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾'] + + +def load_vgg16_extractor(device): + try: + model = models.vgg16(weights='IMAGENET1K_V1') + except TypeError: + model = models.vgg16(pretrained=True) + model.classifier = nn.Identity() + model = model.to(device).eval() + for param in model.parameters(): + param.requires_grad = False + return model + + +def extract_features(model, loader, device): + model.eval() + all_features = [] + all_labels = [] + with torch.no_grad(): + for images, labels in tqdm(loader, desc='Extracting features'): + images = images.to(device) + feats = model(images) + all_features.append(feats.cpu().numpy()) + all_labels.append(labels.numpy()) + return np.concatenate(all_features), np.concatenate(all_labels) + + +class VGGKNNBaseline: + def __init__(self, k=5, device='cpu', + data_root='../trash_division_data/ultimate_4_class/', + image_size=256, batch_size=32, num_workers=4): + self.k = k + self.device = device + self.data_root = data_root + self.image_size = image_size + self.batch_size = batch_size + self.num_workers = num_workers + self.extractor = load_vgg16_extractor(device) + self.knn = KNeighborsClassifier(n_neighbors=k, n_jobs=-1) + + def _get_loader(self, split): + transform = transforms.Compose([ + transforms.Resize((self.image_size, self.image_size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + ]) + dataset = RobustImageFolder( + root=os.path.join(self.data_root, split), + transform=transform, + ) + print(f" {split}: {len(dataset)} 张") + return DataLoader(dataset, batch_size=self.batch_size, + shuffle=False, num_workers=self.num_workers, + pin_memory=True, drop_last=False) + + def fit(self, train_loader=None): + if train_loader is None: + train_loader = self._get_loader('train') + print(" 提取训练集特征 ...") + train_feats, train_labels = extract_features(self.extractor, train_loader, self.device) + self.knn.fit(train_feats, train_labels) + + def predict(self, val_loader=None): + if val_loader is None: + val_loader = self._get_loader('val') + print(" 提取验证集特征 ...") + val_feats, val_labels = extract_features(self.extractor, val_loader, self.device) + preds = self.knn.predict(val_feats) + probs = self.knn.predict_proba(val_feats) + return val_labels, preds, probs + + +if __name__ == '__main__': + DATA_ROOT = '../trash_division_data/ultimate_4_class/' + BATCH_SIZE = 32 + IMAGE_SIZE = 256 + NUM_WORKERS = 4 + K = 5 + + device = torch.device('cuda' if torch.cuda.is_available() + else 'xpu' if hasattr(torch, 'xpu') and torch.xpu.is_available() + else 'cpu') + print(f"Device: {device}") + + baseline = VGGKNNBaseline(k=K, device=device, + data_root=DATA_ROOT, image_size=IMAGE_SIZE, + batch_size=BATCH_SIZE, num_workers=NUM_WORKERS) + + train_loader = baseline._get_loader('train') + val_loader = baseline._get_loader('val') + + baseline.fit(train_loader) + y_true, y_preds, y_probs = baseline.predict(val_loader) + + acc = accuracy_score(y_true, y_preds) + macro_f1 = f1_score(y_true, y_preds, average='macro') + print(f"\n验证集 Accuracy: {acc:.4f}") + print(f"验证集 Macro-F1: {macro_f1:.4f}") + print(f"\n分类报告:\n{classification_report(y_true, y_preds, target_names=CLASS_NAMES)}") + + cm = confusion_matrix(y_true, y_preds) + fig, ax = plt.subplots(figsize=(8, 7)) + ConfusionMatrixDisplay(cm, display_labels=CLASS_NAMES).plot( + ax=ax, cmap='Blues', values_format='d', xticks_rotation=30) + ax.set_title(f'Baseline Confusion Matrix (VGG16 + KNN, K={K})', fontsize=14) + plt.tight_layout() + out_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'vgg_knn_confusion_matrix.png') + plt.savefig(out_path, dpi=150, bbox_inches='tight') + plt.show() + print(f"混淆矩阵已保存: {out_path}") diff --git a/baseline/__init__.py b/baseline/__init__.py new file mode 100644 index 0000000..7471dfe --- /dev/null +++ b/baseline/__init__.py @@ -0,0 +1 @@ +# baseline package diff --git a/baseline/compare_models.py b/baseline/compare_models.py new file mode 100644 index 0000000..7175071 --- /dev/null +++ b/baseline/compare_models.py @@ -0,0 +1,180 @@ +""" +baseline/compare_models.py +多模型对比:ROC 曲线 + 准确率柱状图 +添加新模型只需在 MODELS 列表加一行,无需修改绘图代码 +author: yukun-hh +date: 2026-5-14 +""" +import sys, os +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import numpy as np +import matplotlib.pyplot as plt +import matplotlib + +import torch +from torch.utils.data import DataLoader +from torchvision import transforms +from tqdm import tqdm + +from sklearn.metrics import roc_curve, auc, accuracy_score + +from Model import Net +from Dataloader import RobustImageFolder +from baseline.VGG_KNN import VGGKNNBaseline + +matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans'] +matplotlib.rcParams['axes.unicode_minus'] = False + +# ============================================================ +# ★★★ 可配置参数 ★★★ +# ============================================================ +DATA_ROOT = '../trash_division_data/ultimate_4_class/' +BATCH_SIZE = 32 +IMAGE_SIZE = 256 +NUM_WORKERS = 4 +K_KNN = 5 +# ============================================================ + +CLASS_NAMES = ['厨余垃圾', '可回收物', '其他垃圾', '有害垃圾'] +NUM_CLASSES = 4 + +# ============================================================ +# 预测函数 — 每个函数签名: (train_loader, val_loader, device) -> (y_true, y_preds, y_probs) +# ============================================================ + +def get_resnet34_preds(train_loader, val_loader, device): + model = Net(num_classes=NUM_CLASSES) + state_dict = torch.load('best_model.pth', map_location='cpu') + if 'model_state_dict' in state_dict: + state_dict = state_dict['model_state_dict'] + elif 'model' in state_dict: + state_dict = state_dict['model'] + model.load_state_dict(state_dict) + model = model.to(device).eval() + + y_true, y_preds, y_probs = [], [], [] + with torch.no_grad(): + for images, labels in tqdm(val_loader, desc='ResNet-34'): + images, labels = images.to(device), labels + logits = model(images) + probs = torch.softmax(logits, dim=1) + preds = probs.argmax(dim=1) + y_true.append(labels.numpy()) + y_preds.append(preds.cpu().numpy()) + y_probs.append(probs.cpu().numpy()) + return np.concatenate(y_true), np.concatenate(y_preds), np.concatenate(y_probs) + + +def get_vgg_knn_preds(train_loader, val_loader, device): + baseline = VGGKNNBaseline(k=K_KNN, device=device) + baseline.fit(train_loader) + return baseline.predict(val_loader) + + +# ============================================================ +# ★ 模型注册表 — 添加新模型只需在这里加一行 ★ +# ============================================================ + +MODELS = [ + ('ResNet-34', get_resnet34_preds), + ('VGG16 + KNN (K=5)', get_vgg_knn_preds), + # 未来轻松扩展示例: + # ('ResNet-18 (pretrained)', get_resnet18_preds), + # ('ResNet-50 (pretrained)', get_resnet50_preds), + # ('ResNet-34 (finetuned)', get_finetuned_preds), +] + +# ============================================================ +# 调色板 (扩展时无需修改) +# ============================================================ +COLORS = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', + '#e377c2', '#7f7f7f', '#bcbd22', '#17becf'] + + +def compute_macro_roc(y_true, y_probs): + one_hot = np.eye(NUM_CLASSES)[y_true] + fpr_dict, tpr_dict = {}, {} + for c in range(NUM_CLASSES): + fpr_dict[c], tpr_dict[c], _ = roc_curve(one_hot[:, c], y_probs[:, c]) + all_fpr = np.unique(np.concatenate([fpr_dict[c] for c in range(NUM_CLASSES)])) + mean_tpr = np.zeros_like(all_fpr) + for c in range(NUM_CLASSES): + mean_tpr += np.interp(all_fpr, fpr_dict[c], tpr_dict[c]) + mean_tpr /= NUM_CLASSES + macro_auc = auc(all_fpr, mean_tpr) + return all_fpr, mean_tpr, macro_auc + + +if __name__ == '__main__': + out_dir = os.path.dirname(os.path.abspath(__file__)) + + device = torch.device('cuda' if torch.cuda.is_available() + else 'xpu' if hasattr(torch, 'xpu') and torch.xpu.is_available() + else 'cpu') + print(f"Device: {device}") + + val_transform = transforms.Compose([ + transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]), + ]) + + train_dataset = RobustImageFolder(root=os.path.join(DATA_ROOT, 'train'), + transform=val_transform) + val_dataset = RobustImageFolder(root=os.path.join(DATA_ROOT, 'val'), + transform=val_transform) + print(f"训练集: {len(train_dataset)} 验证集: {len(val_dataset)}") + + train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, + num_workers=NUM_WORKERS, pin_memory=True, drop_last=False) + val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, + num_workers=NUM_WORKERS, pin_memory=True, drop_last=False) + + # ———— 评估所有模型 ———— + results = {} + for name, func in MODELS: + print(f"\n{'='*50}") + print(f"评估: {name}") + y_true, y_preds, y_probs = func(train_loader, val_loader, device) + acc = accuracy_score(y_true, y_preds) + fpr, tpr, roc_auc = compute_macro_roc(y_true, y_probs) + results[name] = {'y_true': y_true, 'y_preds': y_preds, 'y_probs': y_probs, + 'acc': acc, 'fpr': fpr, 'tpr': tpr, 'auc': roc_auc} + print(f" Accuracy: {acc:.4f} | Macro-AUC: {roc_auc:.4f}") + + # ———— ROC 对比图 ———— + fig, ax = plt.subplots(figsize=(8, 7)) + for i, (name, r) in enumerate(results.items()): + color = COLORS[i % len(COLORS)] + ax.plot(r['fpr'], r['tpr'], color=color, lw=2, + label=f"{name} (AUC={r['auc']:.4f})") + ax.plot([0, 1], [0, 1], 'k--', lw=1, alpha=0.5) + ax.set_xlim(0, 1); ax.set_ylim(0, 1.05) + ax.set_xlabel('False Positive Rate'); ax.set_ylabel('True Positive Rate') + ax.set_title('ROC Curve Comparison (Macro-Average)', fontsize=14) + ax.legend(loc='lower right'); ax.grid(True, alpha=0.3) + plt.tight_layout() + roc_path = os.path.join(out_dir, 'roc_comparison.png') + plt.savefig(roc_path, dpi=150, bbox_inches='tight') + plt.show() + print(f"\nROC 对比图已保存: {roc_path}") + + # ———— 准确率柱状图 ———— + names = list(results.keys()) + accs = [results[n]['acc'] for n in names] + fig, ax = plt.subplots(figsize=(8, 5)) + bar_colors = [COLORS[i % len(COLORS)] for i in range(len(names))] + bars = ax.bar(names, accs, color=bar_colors, edgecolor='white', linewidth=1.2) + for bar, acc in zip(bars, accs): + ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005, + f'{acc:.4f}', ha='center', va='bottom', fontsize=12, fontweight='bold') + ax.set_ylim(0, max(accs) * 1.15) + ax.set_ylabel('Accuracy'); ax.set_title('Accuracy Comparison', fontsize=14) + ax.grid(True, alpha=0.3, axis='y') + plt.tight_layout() + bar_path = os.path.join(out_dir, 'accuracy_bar.png') + plt.savefig(bar_path, dpi=150, bbox_inches='tight') + plt.show() + print(f"准确率柱状图已保存: {bar_path}") diff --git a/training_log.csv b/training_log.csv index 9795ca2..6bd1cfe 100644 --- a/training_log.csv +++ b/training_log.csv @@ -1,81 +1 @@ epoch,train_loss,train_f1,train_acc,val_loss,val_f1,val_acc,lr,best -1,1.0409312975676923,0.4329540729522705,48.04254100337675,1.1043149583345566,0.4398210048675537,48.66548042704626,0.004998072590601808,best -2,0.9862563783744079,0.4695238769054413,52.5943680656054,0.9867177669753319,0.5062971115112305,58.397207774431976,0.00499229333433282,best -3,0.9462850892451784,0.49421826004981995,55.40279787747226,1.0144445673589866,0.4907984733581543,53.15494114426499,0.004982671142387316, -4,0.910117163585685,0.514958381652832,57.832097202122526,0.8787865286946395,0.5453917980194092,62.544483985765126,0.004969220851487844,best -5,0.8786031692946986,0.5320333242416382,59.74282440906898,1.0686318878927787,0.4803737998008728,52.73063235696688,0.004951963201008076, -6,0.8518873820889128,0.5481140613555908,61.51938615533044,0.7650798693964196,0.6073676347732544,68.98439638653161,0.004930924800994191,best -7,0.8256270786701512,0.5604796409606934,62.90249638205499,0.8796401012116773,0.5789190530776978,62.63345195729537,0.004906138091134118, -8,0.8003699506646013,0.5742803812026978,64.3014351181862,0.9246643470014378,0.5521833896636963,60.88831097727895,0.004877641290737884, -9,0.780536473588097,0.5827116966247559,65.25642185238785,0.8404132533719564,0.5876226425170898,65.89789214344374,0.00484547833980621, -10,0.7604798049209087,0.595557451248169,66.54079232995659,0.9228097118810533,0.564703643321991,60.77881193539557,0.004809698831278217, -11,0.7410275131047088,0.6043155789375305,67.35784491075735,0.7576604621266131,0.6295210123062134,69.83985765124555,0.0047703579345627035,best -12,0.7195374228732343,0.6127941608428955,68.07766521948867,0.9624476881507583,0.5630610585212708,61.59321105940323,0.00472751631047092, -13,0.6997139808122973,0.6210756301879883,68.85175470332851,0.7615812296349155,0.6177672147750854,69.12127018888584,0.004681240017681994, -14,0.6824904908837182,0.630592942237854,69.65448625180898,0.6715762626299165,0.6534035205841064,73.48754448398577,0.004631600410885231,best -15,0.6653590450468583,0.6379610300064087,70.39088880849012,0.694461440047988,0.6517682075500488,73.0906104571585,0.004578674030756364, -16,0.6514209577758935,0.6478185653686523,71.27879281234925,0.7036816360785346,0.6470745801925659,71.76977826444019,0.004522542485937369, -17,0.6330186040776395,0.6530008316040039,71.7935962373372,0.7222367905930823,0.6418735980987549,71.07856556255133,0.004463292327201863, -18,0.6166394593717968,0.6634106040000916,72.6038651712494,0.6067476719332303,0.6886636018753052,77.49110320284697,0.004401014914000078,best -19,0.5973944908975692,0.6721534729003906,73.47367945007235,0.6952472055509845,0.6622275114059448,71.79715302491103,0.004335806273589214, -20,0.5820678306183721,0.6758297681808472,73.73145803183792,0.7708474785342401,0.6217234134674072,68.74486723241172,0.004267766952966369, -21,0.5650806297110982,0.6851130723953247,74.5741377231066,0.7461620579141478,0.6384793519973755,71.35231316725978,0.004197001863832355, -22,0.5500074683958588,0.6915749311447144,75.099493487699,0.6420613189380593,0.672810435295105,74.9452504790583,0.00412362012082546, -23,0.5367840825001858,0.6979560852050781,75.66102870236372,0.6252713002082977,0.6949211359024048,75.32849712565014,0.0040477348732745845,best -24,0.5234906795055925,0.7052106857299805,76.26025084418717,0.7471277477021352,0.6447888016700745,69.70982753900904,0.003969463130731182, -25,0.5044557179049829,0.7132176160812378,76.91148094548963,0.6325626891507145,0.6857903003692627,75.52012044894607,0.0038889255825490052, -26,0.4938347232885195,0.7174828052520752,77.23784973468403,0.5635758755127437,0.70375657081604,78.50396934026827,0.003806246411789872,best -27,0.4793313278116239,0.7242900133132935,77.87475880366618,0.5505193975847648,0.7201660871505737,78.89405967697783,0.003721553103742388,best -28,0.46573570758837524,0.7336312532424927,78.60437771345876,0.640248272807638,0.6859503984451294,75.13003011223651,0.003634976249348867, -29,0.44927708754289913,0.737967312335968,78.9352689339122,0.6151526539644867,0.7065733075141907,74.69887763482069,0.00354664934384357, -30,0.4373503708129221,0.7443608045578003,79.40409430776653,0.5578661908719627,0.7272701263427734,78.20969066520668,0.0034567085809127244,best -31,0.42717206794400175,0.7488712072372437,79.7779486251809,0.58909761693554,0.7034546136856079,76.88201478237066,0.003365292642693732, -32,0.4100124511779706,0.7580570578575134,80.60630728412929,0.6458172935624336,0.6865078210830688,75.32849712565014,0.0032725424859373683, -33,0.3993677339991451,0.763430118560791,80.9876989869754,0.47995706558097007,0.754202127456665,81.65891048453327,0.003178601124662685,best -34,0.3858378949555808,0.7697042226791382,81.54772672455378,0.6427663844838523,0.6931804418563843,74.28141253764029,0.0030836134096397633, -35,0.37397055771404913,0.7764154672622681,81.9992161119151,0.6085299244000101,0.7046636343002319,77.08048179578428,0.0029877258050403205, -36,0.3597575335889638,0.7818952798843384,82.53587795465509,0.5254679805051781,0.7415529489517212,78.89405967697783,0.002891086162600577, -37,0.3487578573732404,0.7871347665786743,82.8999336710082,0.5125140355052995,0.748577356338501,80.1875171092253,0.002793843493644594, -38,0.3325814358527052,0.7965956330299377,83.59035817655571,0.5408413317834649,0.7290798425674438,79.87270736381056,0.002696147739319612, -39,0.3261546248608721,0.7988470792770386,83.75316570188133,0.5301555857539602,0.7376729249954224,80.06433068710649,0.002598149539397671, -40,0.30964642827472305,0.8070269823074341,84.52725518572117,0.5468305750249544,0.7365171313285828,78.73665480427046,0.0024999999999999996, -41,0.3009217412674191,0.8119726777076721,84.96140858658949,0.46898612490165015,0.7599539756774902,82.18587462359704,0.002401850460602329,best -42,0.28925693789887874,0.8200639486312866,85.6458031837916,0.5167866427621677,0.7465909123420715,80.83766767040788,0.0023038522606803878, -43,0.2707157838268379,0.8313596248626709,86.46360950313556,0.5156203284349763,0.7596548199653625,80.6049822064057,0.0022061565063554063, -44,0.2580273799019566,0.836384654045105,86.8412325132658,0.5318487190707494,0.746901273727417,80.27648508075555,0.0021089138373994237, -45,0.2504911308580703,0.8410984873771667,87.30779667149059,0.49164087495639364,0.763725221157074,81.82315904735833,0.00201227419495968,best -46,0.24104372076995695,0.8451772928237915,87.64094910757356,0.5290114263981752,0.7580969333648682,80.94032302217356,0.0019163865903602372, -47,0.22337641519549614,0.8570870161056519,88.56201760733236,0.43634469677838694,0.7913081049919128,85.10128661374213,0.0018213988753373142,best -48,0.2128122905210861,0.8645581603050232,89.1152616980222,0.43183545479456625,0.7972898483276367,85.49137695045168,0.001727457514062632,best -49,0.2003101470318182,0.8717849254608154,89.69187168355042,0.4289672785945178,0.806715726852417,85.7993430057487,0.0016347073573062686,best -50,0.1888495338707803,0.8796613216400146,90.34687047756874,0.4568697272132202,0.7956517338752747,84.64275937585546,0.0015432914190872762, -51,0.1756466486088274,0.886497437953949,90.97096599131693,0.4541556305781541,0.7938134670257568,84.45113605255953,0.001453350656156431, -52,0.16742963044469736,0.8907681703567505,91.3101483357453,0.4230425570913775,0.8147625923156738,86.100465370928,0.0013650237506511336,best -53,0.15311133117022804,0.9007841944694519,92.11212614568258,0.4146759752586969,0.8208259344100952,86.83274021352314,0.0012784468962576128,best -54,0.1423091164071722,0.9078108668327332,92.6714001447178,0.4709351422719488,0.8029968738555908,85.71721872433616,0.0011937535882101285, -55,0.13160189816137902,0.9135022163391113,93.10932223830197,0.40829240685941226,0.8264325857162476,87.42814125376403,0.0011110744174509947,best -56,0.12707359487800943,0.9174070358276367,93.4409671972986,0.42565728100299705,0.8230471611022949,87.65398302764851,0.0010305368692688178, -57,0.11291898237482914,0.925153374671936,94.10953328509407,0.43774206922898184,0.8247347474098206,87.5376402956474,0.0009522651267254161, -58,0.10370979833767516,0.9329074025154114,94.62584418716835,0.4068256767521223,0.8333848118782043,88.05091705447578,0.0008763798791745416,best -59,0.0946220946482491,0.9380815029144287,95.09768451519537,0.41103389083751746,0.8357677459716797,88.4889132220093,0.0008029981361676465,best -60,0.08804238645213414,0.9423004388809204,95.45194163048721,0.4207381762007377,0.8308929204940796,88.17410347659458,0.0007322330470336316, -61,0.07913849578165794,0.9495129585266113,95.9916184273999,0.4083157278823748,0.8420299291610718,89.00218998083767,0.0006641937264107861,best -62,0.06981624146470565,0.9559226036071777,96.49662325132658,0.41332166066585485,0.843511700630188,88.98850260060225,0.0005989850859999229,best -63,0.06394793773276639,0.9602090120315552,96.79811866859623,0.41052334102789995,0.8489691019058228,89.35121817684096,0.0005367076727981376,best -64,0.057007493751794744,0.9636315107345581,97.11016642547034,0.3970402057488879,0.8546013832092285,90.04927456884752,0.00047745751406263185,best -65,0.05285091761448427,0.967146635055542,97.40035576459238,0.4247235585867853,0.8433754444122314,89.34437448672324,0.0004213259692436376, -66,0.04614944799553407,0.9710246324539185,97.6912988422576,0.4035414461747053,0.8538572788238525,89.85080755543389,0.00036839958911476966, -67,0.042909558690492254,0.9727644920349121,97.8428002894356,0.41578795896298,0.8530543446540833,89.91924445661101,0.0003187599823180077, -68,0.03640206224887977,0.9769999980926514,98.15710926193921,0.42073161891477256,0.8551151156425476,89.94661921708185,0.0002724836895290806,best -69,0.034432517173010414,0.9790080785751343,98.34328268210324,0.4133664407223553,0.8576940298080444,90.28880372296743,0.00022964206543729668,best -70,0.030669637766023813,0.9817556142807007,98.5249336710082,0.418308269951463,0.8569411039352417,90.12455516014235,0.00019030116872178321, -71,0.028112305183924133,0.9827903509140015,98.606337433671,0.4151474575991667,0.8596312999725342,90.37092800437996,0.00015452166019378966,best -72,0.024704152367817256,0.9853801727294922,98.81135431741437,0.4153705465811558,0.8635820746421814,90.68573774979468,0.0001223587092621162,best -73,0.024846541488804174,0.9855506420135498,98.8369814278823,0.4177400436290088,0.8632140159606934,90.59676977826444,9.38619088658821e-05, -74,0.022639600746625622,0.9868491888046265,98.94702725518572,0.41732572613841307,0.8648342490196228,90.78154941144265,6.907519900580863e-05,best -75,0.02120214593173326,0.9878177642822266,99.0231548480463,0.4163925825270714,0.866214394569397,90.89789214344374,4.803679899192394e-05,best -76,0.019741657997631577,0.9883521795272827,99.06385672937772,0.42005763620917286,0.8647006750106812,90.82945524226663,3.077914851215586e-05, -77,0.019116416042495393,0.9889511466026306,99.10003617945007,0.4159400745789841,0.8657370805740356,90.84998631261976,1.7328857612684272e-05, -78,0.019259902796210714,0.9888157844543457,99.0962674867342,0.4192042892654481,0.8641382455825806,90.69942513003011,7.706665667180091e-06, -79,0.01933925595445387,0.9887675046920776,99.0759165460685,0.4180937778044573,0.8662786483764648,90.84998631261976,1.9274093981927482e-06,best -80,0.01922732148408437,0.9889604449272156,99.10078991799324,0.41794140280912484,0.864332914352417,90.82261155214891,0.0,