模型早停机制

Admin
发布于 2026-05-26 / 1 阅读
0
0
# EARLY STOPPING VERSION
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Load the dataset
transform = transforms.Compose([
    transforms.ToTensor()  # Automatically converts to tensor and scales to [0, 1]
])
train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Define the model
class FashionMNISTModel(nn.Module):
    def __init__(self):
        super(FashionMNISTModel, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = FashionMNISTModel()

# Define the loss function and optimizer
loss_function = nn.NLLLoss()
optimizer = optim.Adam(model.parameters())

# Function to calculate accuracy
def get_accuracy(pred, labels):
    _, predictions = torch.max(pred, 1)
    correct = (predictions == labels).float().sum()
    accuracy = correct / labels.shape[0]
    return accuracy

# Train the model with accuracy reporting
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    total_loss, total_accuracy = 0, 0

    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)
        accuracy = get_accuracy(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_accuracy += accuracy.item()

        if batch % 100 == 0:
            current = batch * len(X)
            avg_loss = total_loss / (batch + 1)
            avg_accuracy = total_accuracy / (batch + 1) * 100
            print(f"Batch {batch}, Loss: {avg_loss:>7f}, Accuracy: {avg_accuracy:>0.2f}% [{current:>5d}/{size:>5d}]")

    # Early stopping condition
    if avg_accuracy >= 95:
        print("Reached 95% accuracy, stopping training.")
        return True  # Stop training

# Training process
epochs = 50
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    if train(train_loader, model, loss_function, optimizer):  # Check for the early stopping signal
        print("Early stopping triggered.")
        break
print("Done!")

这段代码实现了一个基于 PyTorch 的图像分类任务,使用了 FashionMNIST 数据集(包含衣服、鞋子等时尚单品的灰度图像),并**特别引入了早停机制**——当训练集准确率达到 95% 时自动终止训练。

下面我将从实现原理、用途和注意事项三个方面为您详细解释:

### 一、 实现原理

代码的执行流程可以分为以下几个核心步骤:

1. **数据准备**:

- 使用 transforms.ToTensor() 将 PIL 图像转换为 PyTorch 张量,并自动将像素值从 [0, 255] 归一化到 [0, 1]

- 通过 DataLoader 将数据打包成批次(batch size = 64),并打乱训练数据shuffle=True)以增加随机性,防止模型陷入局部最优。

2. **模型构建**:

- FashionMNISTModel 是一个简单的前馈神经网络。

- nn.Flatten():将 28x28 的二维图像展平为 784 维的一维向量。

- nn.Linear(28*28, 128):输入层到隐藏层的全连接,将 784 维映射到 128 维。

- nn.ReLU():激活函数,引入非线性。

- nn.Linear(128, 10):隐藏层到输出层的全连接,映射到 10 个分类。

- nn.LogSoftmax(dim=1):对输出取对数概率。这是为了配合后面的 NLLLoss(负对数似然损失)使用。

3. **损失函数与优化器**:

- nn.NLLLoss():负对数似然损失,通常与 LogSoftmax 搭配使用,等价于交叉熵损失CrossEntropyLoss)。

- optim.Adam:自适应矩估计优化器,能够根据梯度的一阶和二阶矩动态调整学习率。

4. **训练与早停**:

- 在 train 函数中,模型执行前向传播、计算损失、反向传播和参数更新。

- **早停逻辑**:在每个 epoch 结束时,计算该 epoch 的平均准确率 avg_accuracy。如果准确率 $\ge 95\%$,函数返回 True。主循环接收到 True 后,触发 break 跳出循环,结束训练。

---

### 二、 用途

1. **图像分类**:这是计算机视觉中最基础的任务,用于识别输入图像属于哪个类别。

2. **深度学习入门与原型验证**:FashionMNIST 是 MNIST 的升级版,比手写数字更具挑战性,且不会因为太简单而掩盖代码逻辑错误,非常适合用来测试新的网络结构或训练策略。

3. **节省计算资源**:通过早停机制,在模型达到预期性能时提前终止训练,避免不必要的计算开销,这在训练大型模型或资源受限时尤为重要。

---

### 三、 注意事项(包含代码中的潜在问题)

这段代码虽然可以运行,但在实际应用中存在几个需要注意和改进的地方:

1. **早停逻辑的缺陷(仅监控训练集)**:

- **问题**:代码中的早停是基于**训练集准确率**train_loader)判断的。这非常危险!训练集准确率高并不代表模型泛化能力强,模型可能已经严重过拟合。

- **改进**:标准的早停机制应该监控**验证集/测试集**的损失或准确率。当验证集的指标连续几个 epoch 不再下降(或不再上升)时,才触发早停。

2. **变量未定义风险**:

- **问题**:在 train 函数末尾判断 if avg_accuracy >= 95:,但如果 dataloader 为空for 循环不会执行avg_accuracy 变量将未被定义,从而引发 NameError

- **改进**:应在 for 循环前初始化 avg_accuracy = 0,或者确保数据加载器不为空。

3. **损失函数的冗余搭配**:

- **问题**nn.LogSoftmax + nn.NLLLoss() 在数学上等价于 nn.CrossEntropyLoss()

- **改进**:在 PyTorch 中,更推荐直接使用 nn.CrossEntropyLoss(),并删除模型最后一层的 nn.LogSoftmax(dim=1)CrossEntropyLoss 内部集成了 LogSoftmax,数值稳定性更好,且代码更简洁。

4. **缺少模型评估模式 model.eval())**:

- 虽然当前模型只有全连接层和 ReLU,没有 Dropout 和 BatchNorm,所以 model.train()model.eval() 切换与否影响不大。但如果未来要在该框架上增加 Dropout 等层,必须要在验证/测试时调用 model.eval(),否则会导致推理结果错误。

5. **设备兼容性 (CPU/GPU)**:

- 代码默认在 CPU 上运行。如果要在 GPU 上训练,需要显式地将模型和数据移动到 GPU 设备上(如 device = "cuda" if torch.cuda.is_available() else "cpu")。

**总结**:这段代码是一个典型的 PyTorch 基础训练脚本,演示了早停机制的基本实现思路。但在生产环境或严谨的实验中,建议将早停条件改为监控验证集指标,并使用更标准的 CrossEntropyLoss

Epoch 1

-------------------------------

Batch 0, Loss: 2.283812, Accuracy: 15.62% [ 0/60000]

Batch 100, Loss: 0.976695, Accuracy: 67.20% [ 6400/60000]

Batch 200, Loss: 0.795045, Accuracy: 72.68% [12800/60000]

Batch 300, Loss: 0.712192, Accuracy: 75.42% [19200/60000]

Batch 400, Loss: 0.663087, Accuracy: 77.08% [25600/60000]

Batch 500, Loss: 0.626909, Accuracy: 78.38% [32000/60000]

Batch 600, Loss: 0.603203, Accuracy: 79.15% [38400/60000]

Batch 700, Loss: 0.579861, Accuracy: 79.88% [44800/60000]

Batch 800, Loss: 0.561982, Accuracy: 80.44% [51200/60000]

Batch 900, Loss: 0.546929, Accuracy: 80.99% [57600/60000]

Epoch 35

-------------------------------

Batch 0, Loss: 0.170784, Accuracy: 90.62% [ 0/60000]

Batch 100, Loss: 0.124102, Accuracy: 95.51% [ 6400/60000]

Batch 200, Loss: 0.131458, Accuracy: 95.31% [12800/60000]

Batch 300, Loss: 0.130267, Accuracy: 95.23% [19200/60000]

Batch 400, Loss: 0.131773, Accuracy: 95.09% [25600/60000]

Batch 500, Loss: 0.133363, Accuracy: 95.10% [32000/60000]

Batch 600, Loss: 0.132709, Accuracy: 95.10% [38400/60000]

Batch 700, Loss: 0.133067, Accuracy: 95.11% [44800/60000]

Batch 800, Loss: 0.133494, Accuracy: 95.07% [51200/60000]

Batch 900, Loss: 0.134263, Accuracy: 95.02% [57600/60000]

Reached 95% accuracy, stopping training.

Early stopping triggered.

Done!


评论