计算机视觉图像二分类预测脚本

Admin
发布于 2026-05-27 / 4 阅读
0
0
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

# ==========================================
# 1. 定义 HorsesHumansCNN 模型 (必须和训练时一致)
# ==========================================
class HorsesHumansCNN(nn.Module):
    def __init__(self):
        super(HorsesHumansCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 18 * 18, 512)
        self.fc2 = nn.Linear(512, 1)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))
        x = x.view(x.size(0), -1)
        x = self.dropout(torch.relu(self.fc1(x)))
        x = self.fc2(x)
        return x

# ==========================================
# 2. 图像预处理与加载
# ==========================================
transform = transforms.Compose([
    transforms.Resize((150, 150)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

def load_image(image_path, transform):
    image = Image.open(image_path).convert('RGB')
    image = transform(image)
    image = image.unsqueeze(0)  # 增加 batch 维度: [1, 3, 150, 150]
    return image

# ==========================================
# 3. 预测函数
# ==========================================
def predict(image_path, model, device, transform):
    model.eval()  # 切换到评估模式,关闭 Dropout
    image = load_image(image_path, transform)
    image = image.to(device)
    
    with torch.no_grad():
        output = model(image)
        
        # 🔥 关键点:因为训练时用的是 BCEWithLogitsLoss,模型输出的是 Logits
        # Logits > 0.0 等价于 Sigmoid概率 > 0.5
        prediction = output > 0.0  
        
        is_human = prediction.item()
        class_name = "人" if is_human else "马"
        
        print(f"图片: {image_path}")
        print(f"预测结果: 这是一张【{class_name}】的图片")
        # 打印原始输出值,帮助理解模型判断的置信度(越远离0表示越确信)
        print(f"模型原始输出值: {output.item():.4f}\n") 

# ==========================================
# 4. 主程序执行
# ==========================================
if __name__ == "__main__":
    # 设置设备
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # 实例化模型并移动到设备
    model = HorsesHumansCNN().to(device)
    
    # 🔥 加载训练好的权重
    # 请确保 'best_model.pth' 文件存在于同一目录下!
    # 如果你之前保存的名字不同,请在这里修改
    weight_path = 'best_model.pth' 
    try:
        model.load_state_dict(torch.load(weight_path, map_location=device))
        print(f"成功加载模型权重: {weight_path}\n")
    except FileNotFoundError:
        print(f"❌ 找不到权重文件 {weight_path},请先训练模型并保存权重!")
        exit()

    # 预测本地图片
    predict("1.png", model, device, transform)
    predict("2.png", model, device, transform)

这段代码是一个基于PyTorch实现的**计算机视觉图像二分类预测脚本**,专门用于识别输入图片中是“人”还是“马”。

下面我将从**实现原理、用途、注意事项**三个方面为你详细解释这段代码:

---

### 一、 实现原理

代码的执行流程可以分为四个核心步骤:

#### 1. 模型定义 HorsesHumansCNN)

这是一个典型的卷积神经网络(CNN),专门为处理150x150的RGB图像设计:

* **特征提取部分**:包含3个卷积层Conv2d)和3个最大池化层MaxPool2d)。卷积层用于提取图像的边缘、纹理等特征,通道数从3(RGB)逐渐增加到16、32、64;池化层用于下采样,缩小特征图尺寸(每次池化尺寸减半),减少计算量并增加平移不变性。

* **分类头部分**:包含2个全连接层Linear)。将三维的特征图展平x.view)为一维向量后,经过Dropout(防止过拟合)和ReLU激活函数,最终映射到**1个输出节点**。

* **输出含义**:输出的是一个单一的**Logit值**,而不是概率值。

#### 2. 图像预处理 transform & load_image)

深度学习模型对输入数据的格式有严格要求,预处理旨在将任意输入图片转化为模型期望的张量:

Resize((150, 150)):强制缩放图片尺寸,与模型全连接层输入维度(6418*18)相匹配(150经过3次池化变为 150/2/2/2 = 18.75,取整为18)。

* ToTensor():将PIL图片转换为PyTorch张量,并将像素值从 [0, 255] 归一化到 [0.0, 1.0]

* Normalize(...):使用均值和标准差0.5进行标准化,将像素范围进一步映射到 [-1.0, 1.0],加速模型收敛。

* unsqueeze(0):PyTorch模型默认按批次处理数据,此操作在维度0增加了一个批次维度,将形状从 [3, 150, 150] 变为 [1, 3, 150, 150]

#### 3. 预测逻辑 predict)

model.eval():将模型切换到评估模式,这会*关闭Dropout层**,确保预测结果的确定性和一致性。

* torch.no_grad():禁用梯度计算,减少内存消耗并加快推理速度。

* **🔥 关键逻辑 output > 0.0**:由于模型在训练时使用的是 BCEWithLogitsLoss(将Sigmoid和二元交叉熵结合的损失函数),模型最后一层没有加Sigmoid激活函数,直接输出Logits。在数学上,Sigmoid函数在输入为0时输出0.5(即分类阈值)。因此,**Logits > 0 等价于 Sigmoid概率 > 0.5**。代码直接判断输出是否大于0,大于0判定为“人”,小于等于0判定为“马”。

#### 4. 权重加载与执行

* 使用 torch.load 加载之前训练好的模型参数.pth文件),并通过 model.load_state_dict 将参数注入到模型中。

* 使用 map_location=device 确保即使模型在GPU上训练,也能在CPU上正确加载。

---

### 二、 用途

1. **图像二分类推理**:这是最直接的用途,用于部署或测试已训练好的“人/马”分类模型。

2. **迁移学习模板**:这段代码结构非常标准,只需修改模型的全连接层输出维度和 predict 函数中的类别映射,即可轻松迁移到其他二分类任务(如猫狗分类、口罩检测等)。

3. **模型验证**:在训练过程中或训练完成后,开发者可以使用此类脚本快速验证模型在本地新图片上的表现,通过打印原始Logit值来评估模型的置信度。

---

### 三、 注意事项

1. **模型结构与训练时必须绝对一致**:

代码中的 HorsesHumansCNN 类的定义必须与保存 best_model.pth 时的结构**完全一致**。即使只改动了一层的一个参数load_state_dict 也会报错导致无法加载权重。

2. **文件路径与依赖**:

* 运行前必须确保 best_model.pth 存在于当前目录。

* 预测的图片(如 1.png, 2.png)也需要存在,否则 Image.open 会报错。

3. **输入尺寸的固定性**:

该模型由于包含了全连接层fc1 的输入特征数被硬编码为 64 18 18),因此**只能**接受 150x150 尺寸的图片。如果输入其他尺寸,在执行 x.view(x.size(0), -1) 后,一维向量的长度不匹配,会导致运行报错。

4. **Logit与概率的区别**:

代码中打印的 模型原始输出值 是 Logit 而不是概率。它的范围是 (-∞, +∞)。值越接近0,说明模型越犹豫(置信度低);值越远离0(例如 +15 或 -15),说明模型对判断越确信。如果需要输出概率,可以手动加上 prob = torch.sigmoid(output).item()

5. **标签映射的对应关系**:

代码中 is_human = prediction.item(),当为 True 时对应“人”。这依赖于**训练时数据集的标签设定**。如果在训练时,标签1代表“马”,标签0代表“人”,那么这里的判断逻辑就完全反了。必须确保推理时的类别映射与训练时的标签编码一致。

如果你有任何关于代码修改或扩展的需求,随时可以问我!

成功加载模型权重: best_model.pth

图片: 1.png
预测结果: 这是一张【人】的图片
模型原始输出值: 4.2016

图片: 2.png
预测结果: 这是一张【马】的图片
模型原始输出值: -3.7493


评论