"石头-布-剪刀”三分类(Inception V3)

Admin
发布于 2026-05-28 / 0 阅读
0
0
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import os
import zipfile
from PIL import Image

# ==========================================
# 0. 解压本地 rps.zip 数据集
# ==========================================
zip_path = "rps.zip"
base_dir = "rps_dataset"

if not os.path.exists(base_dir):
    print(f"检测到压缩包 {zip_path},开始解压...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(base_dir)
    print("✅ 数据集解压完成!")
else:
    print("检测到已解压的数据集,跳过解压步骤。")

# rps.zip 解压后的标准目录结构通常是:
# rps_dataset/rps/rock/, rps_dataset/rps/paper/, rps_dataset/rps/scissors/
# 以及验证集:rps_dataset/rps-test-set/rock/ 等
# 我们需要根据实际解压情况指定路径
train_dir = os.path.join(base_dir, "rps")
# 如果验证集也在zip里,通常叫 rps-test-set
val_dir = os.path.join(base_dir, "rps-test-set") 

# 检查路径是否存在,如果不存在给出提示
if not os.path.exists(train_dir):
    print(f"❌ 找不到训练集路径: {train_dir}")
    print("请检查解压后的文件夹结构,并修改 train_dir 变量。")
    exit()

# ==========================================
# 1. 数据加载与预处理 (Inception V3 需要 299x299)
# ==========================================
train_transform = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.RandomHorizontalFlip(), # 数据增强
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = datasets.ImageFolder(root=train_dir, transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# 如果没有单独的验证集文件夹,可以暂时用训练集代替验证(仅作演示,不推荐实际操作)
if os.path.exists(val_dir):
    val_dataset = datasets.ImageFolder(root=val_dir, transform=val_transform)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
else:
    print("⚠️ 未找到独立验证集文件夹,将使用训练集作为验证集(仅供测试代码运行)。")
    val_loader = train_loader

# 获取类别映射字典 (非常重要,预测时用来把数字转回文字)
class_names = train_dataset.classes  # 例如:['paper', 'rock', 'scissors']
print(f"识别到的类别: {class_names}")

# ==========================================
# 2. 构建模型:Inception V3 迁移学习 (3分类)
# ==========================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 重新实例化干净模型,防止之前的报错
model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT, aux_logits=True)

# 冻结底层参数
for name, parameter in model.named_parameters():
    parameter.requires_grad = False
    if 'Mixed_7c' in name:
        break

# 修改主分类器
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(num_ftrs, 1024),
    nn.ReLU(),
    nn.Linear(1024, len(class_names)) # 输出层节点数 = 类别数 (3)
)

# 🔥 必须同步修改辅助分类器,否则训练计算损失时会报错
num_aux_ftrs = model.AuxLogits.fc.in_features
model.AuxLogits.fc = nn.Linear(num_aux_ftrs, len(class_names))

model = model.to(device)

# ==========================================
# 3. 定义损失函数和优化器
# ==========================================
# 3分类使用 CrossEntropyLoss
criterion = nn.CrossEntropyLoss()
# 只优化未冻结的全连接层参数
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

# ==========================================
# 4. 训练与验证逻辑
# ==========================================
def train_model(model, criterion, optimizer, train_loader, val_loader, num_epochs=5):
    for epoch in range(num_epochs):
        # --- 训练阶段 ---
        model.train()
        running_loss = 0.0
        running_corrects = 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # Inception V3 训练时返回元组 (主输出, 辅助输出)
            outputs = model(inputs)
            if isinstance(outputs, tuple):
                main_output, aux_output = outputs
                loss1 = criterion(main_output, labels)
                loss2 = criterion(aux_output, labels)
                loss = loss1 + 0.4 * loss2 # 辅助损失乘以0.4权重
            else:
                main_output = outputs
                loss = criterion(main_output, labels)

            _, preds = torch.max(main_output, 1) # 必须用主输出计算准确率

            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels).item()

        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_acc = running_corrects / len(train_loader.dataset)

        # --- 验证阶段 ---
        model.eval()
        val_loss = 0.0
        val_corrects = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                
                # eval 模式下只返回主输出
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                _, preds = torch.max(outputs, 1)

                val_loss += loss.item() * inputs.size(0)
                val_corrects += torch.sum(preds == labels).item()

        val_epoch_loss = val_loss / len(val_loader.dataset)
        val_epoch_acc = val_corrects / len(val_loader.dataset)

        print(f'Epoch {epoch+1}/{num_epochs} | '
              f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f} | '
              f'Val Loss: {val_epoch_loss:.4f} Acc: {val_epoch_acc:.4f}')

    # 保存模型
    torch.save(model.state_dict(), 'rps_inception_v3.pth')
    print("✅ 模型训练完成并已保存!")

# 开始训练
train_model(model, criterion, optimizer, train_loader, val_loader, num_epochs=5)

# ==========================================
# 5. 预测本地单张图片
# ==========================================
def predict_image(image_path, model, device, transform, class_names):
    model.eval()
    try:
        image = Image.open(image_path).convert('RGB')
        image_tensor = transform(image).unsqueeze(0).to(device)
        
        with torch.no_grad():
            outputs = model(image_tensor)
            _, preds = torch.max(outputs, 1)
            predicted_class = class_names[preds.item()]
            
            print(f"图片: {image_path} -> 预测结果: 【{predicted_class}】")
    except Exception as e:
        print(f"预测失败: {e}")

# 取消注释以测试预测:
# predict_image("test_rock.jpg", model, device, val_transform, class_names)

这段代码实现了一个完整的基于深度学习的**图像分类任务**,具体来说是对“石头-布-剪刀”三分类数据集进行训练和预测。它使用了 PyTorch 框架,并采用了 **Inception V3** 模型进行**迁移学习**。

下面我将从实现原理、用途和注意事项三个方面为您详细解释这段代码:

### 一、 实现原理

代码的执行流程可以分为以下五个核心步骤:

1. **数据准备与解压 (第0部分)**

- 使用 zipfile 将本地的 rps.zip 解压到 rps_dataset 目录。

- 自动推断训练集和验证集的路径torchvision.datasets.ImageFolder 会自动根据子文件夹名称(如 rock, paper, scissors)为图片打标签,生成类别映射。

2. **数据预处理与增强 (第1部分)**

- **尺寸调整**:Inception V3 模型的原始设计要求输入图像尺寸为 **299x299**,这与 ResNet 等常用的 224x224 不同,因此必须 Resize((299, 299))

- **数据增强**:训练集加入了随机水平翻转和随机旋转,以增加数据多样性,防止模型过拟合。

- **标准化**:使用 ImageNet 数据集的均值和标准差对图像张量进行归一化,这是迁移学习的标准操作,使得输入数据分布与模型预训练时的分布一致。

3. **构建迁移学习模型 (第2部分)**

- **加载预训练权重**:使用 models.Inception_V3_Weights.DEFAULT 加载在 ImageNet 上预训练的模型。

- **冻结底层参数**:遍历模型参数,将底层(特征提取层)的 requires_grad 设为 False。在训练时,这些层的权重将不会更新,从而保留预训练提取通用特征的能力,同时大幅减少显存占用和计算量。代码中设置冻结到 Mixed_7c 层。

- **替换分类头**:

- **主分类器**:将最后的全连接层替换为自定义结构(包含一个 1024 节点的隐藏层和 ReLU 激活,最后输出节点数为 3,对应 3 个类别)。

- **辅助分类器🔥**:这是 Inception V3 的特殊之处。Inception V3 在网络中间层有一个辅助输出分支以帮助梯度回传。**必须同时修改 model.AuxLogits.fc 的输出维度为 3**,否则后续计算损失函数时会因为维度不匹配而报错。

4. **模型训练与验证 (第3、4部分)**

- **损失计算**:由于 Inception V3 在 train 模式下返回的是一个元组 (主输出, 辅助输出),代码对此做了特殊处理:分别计算两部分的损失,并按 总损失 = 主损失 + 0.4 * 辅助损失 进行加权求和。这个 0.4 是原论文中推荐的权重。

- **优化器**:只将 model.fc.parameters() 传入优化器,意味着只对未冻结的新分类层进行参数更新。

- **验证阶段**:模型处于 eval 模式时,Inception V3 会自动关闭辅助分支,只返回主输出,因此验证时的损失计算与常规模型无异。

5. **单张图片预测 (第5部分)**

- 加载图片 -> 转为 RGB 模式 -> 应用验证集的 transform -> 增加 batch 维度 unsqueeze(0)) -> 送入模型 -> 取 argmax 获取类别索引 -> 映射回类别名称。

---

### 二、 用途

1. **岩石-布-剪刀手势识别**:这是代码的直接用途,可以部署到摄像头应用或机器人视觉中,实现人机猜拳游戏。

2. **Inception V3 迁移学习模板**:这段代码非常标准地展示了如何处理带有辅助分类器的特殊网络(如 Inception 系列)。如果你需要用 Inception V3 去做其他分类任务(如猫狗分类、医学图像分类),只需修改数据路径和最后的输出维度,这套代码可以直接复用。

3. **小样本数据集训练方案**:通过冻结底层、只训练顶层的方式,代码展示了如何在数据量不大的情况下(几千张图片),利用大模型在 ImageNet 上学到的特征快速达到高精度。

---

### 三、 注意事项

1. **Inception V3 的输入尺寸必须是 299x299**

如果将 Resize((299, 299)) 改为其他尺寸(如 224x224),模型在运行时会抛出维度不匹配的 RuntimeError。

2. **辅助分类器必须修改**

很多初学者在迁移学习时只修改了 model.fc,而忘记了 model.AuxLogits.fc。这会导致训练时 loss = criterion(aux_output, labels) 报错,因为辅助分类器默认输出 1000 类,而标签是 3 类。

3. **优化器需与冻结策略匹配**

代码中冻结了底层参数,只训练 model.fc。如果优化器写成了 optim.Adam(model.parameters(), lr=0.001),虽然不会报错(因为 requires_grad=False 的参数即使传给优化器也不会更新),但会无谓地占用计算资源。

4. **训练与评估模式的切换 model.train() vs model.eval())**

- 对于 Inception V3train 模式启用 Dropout 和 Batch Norm 的训练模式,且**返回两个输出**。

- eval 模式关闭 Dropout,使用 Batch Norm 的运行时统计量,且**只返回主输出**。

- 如果在验证时忘记调用 model.eval(),会导致验证输出变成元组,后续 torch.max(outputs, 1) 会报错。

5. **数据集目录结构要求**

代码使用了 datasets.ImageFolder,这要求数据集必须严格遵循如下文件夹结构:

```text

rps_dataset/

└── rps/

├── paper/

├── rock/

└── scissors/

```

如果你的压缩包解压后的结构不是这样ImageFolder 将无法正确识别类别。

6. **冻结层截断点的选择**

代码中使用 if 'Mixed_7c' in name: break 来解冻顶层。Inception V3 的结构比较复杂,如果你想微调更多的层以获得更高精度(前提是数据量足够),可以修改这个截断条件,例如改为 Mixed_7aMixed_6e,并注意适当降低学习率。

7. **预测时的 Transform**

预测单张图片时,使用的是 val_transform,这非常重要。不能使用 train_transform,因为预测时不需要数据增强(如随机翻转、旋转),否则同一张图每次预测的结果可能不同。

import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import os

# ==========================================
# 1. 重新定义模型结构 (必须和训练时完全一致!)
# ==========================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 实例化不带预训练权重的 Inception V3 结构
model = models.inception_v3(weights=None, aux_logits=True)

# 冻结底层参数 (结构需要保留,虽然预测时不需要梯度,但结构必须对齐才能加载权重)
for name, parameter in model.named_parameters():
    parameter.requires_grad = False
    if 'Mixed_7c' in name:
        break

# 🔥 必须定义和训练时一模一样的自定义分类头
# 类别数是 3 (rock, paper, scissors)
num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(num_ftrs, 1024),
    nn.ReLU(),
    nn.Linear(1024, 3) 
)

# 🔥 辅助分类器也必须修改,否则加载权重会报错
num_aux_ftrs = model.AuxLogits.fc.in_features
model.AuxLogits.fc = nn.Linear(num_aux_ftrs, 3)

model = model.to(device)

# ==========================================
# 2. 加载训练好的模型权重
# ==========================================
weight_path = 'rps_inception_v3.pth' 

if not os.path.exists(weight_path):
    print(f"❌ 找不到权重文件 {weight_path},请确保文件在同一目录下!")
    exit()

# 将权重加载到模型中
model.load_state_dict(torch.load(weight_path, map_location=device))
print(f"✅ 成功加载模型权重: {weight_path}")

# 🔥 切换到评估模式!这会关闭 Dropout 和 BatchNorm 的训练行为
# 并且让 Inception V3 只输出主结果,不输出辅助结果
model.eval()

# ==========================================
# 3. 图像预处理 (必须和训练时的验证集预处理一致,不能有数据增强)
# ==========================================
transform = transforms.Compose([
    transforms.Resize((299, 299)), # Inception V3 标准输入
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 🔥 类别映射字典 (根据 ImageFolder 按首字母自动排序的规则:paper=0, rock=1, scissors=2)
class_names = ['paper', 'rock', 'scissors']

# ==========================================
# 4. 预测函数
# ==========================================
def predict_image(image_path, model, device, transform, class_names):
    try:
        image = Image.open(image_path).convert('RGB')
        image_tensor = transform(image).unsqueeze(0).to(device) # 增加 batch 维度并移动到设备
        
        with torch.no_grad(): # 预测时不计算梯度,节省内存和算力
            outputs = model(image_tensor)
            # outputs 现在是主分类器的输出 (因为 model.eval())
            _, preds = torch.max(outputs, 1) 
            predicted_class = class_names[preds.item()]
            
            # 获取各个类别的概率
            probabilities = torch.nn.functional.softmax(outputs[0], dim=0)
            confidence = probabilities[preds.item()].item() * 100
            
            print(f"图片: {image_path} -> 预测结果: 【{predicted_class}】 (置信度: {confidence:.2f}%)")
    except Exception as e:
        print(f"预测失败 {image_path}: {e}")

# ==========================================
# 5. 执行预测
# ==========================================
# 请确保这些测试图片存在于当前目录
predict_image("test_rock.jpg", model, device, transform, class_names)
predict_image("test_paper.jpg", model, device, transform, class_names)
predict_image("test_scissors1.jpg", model, device, transform, class_names)
✅ 成功加载模型权重: rps_inception_v3.pth
图片: test_rock.jpg -> 预测结果: 【paper】 (置信度: 91.51%)
图片: test_paper.jpg -> 预测结果: 【paper】 (置信度: 95.79%)
图片: test_scissors1.jpg -> 预测结果: 【scissors】 (置信度: 100.00%)


评论