import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
# ==========================================
# 1. 定义 HorsesHumansCNN 模型 (必须和训练时一致)
# ==========================================
class HorsesHumansCNN(nn.Module):
def __init__(self):
super(HorsesHumansCNN, self).__init__()
self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
self.pool = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(64 * 18 * 18, 512)
self.fc2 = nn.Linear(512, 1)
self.dropout = nn.Dropout(0.5)
def forward(self, x):
x = self.pool(torch.relu(self.conv1(x)))
x = self.pool(torch.relu(self.conv2(x)))
x = self.pool(torch.relu(self.conv3(x)))
x = x.view(x.size(0), -1)
x = self.dropout(torch.relu(self.fc1(x)))
x = self.fc2(x)
return x
# ==========================================
# 2. 图像预处理与加载
# ==========================================
transform = transforms.Compose([
transforms.Resize((150, 150)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
def load_image(image_path, transform):
image = Image.open(image_path).convert('RGB')
image = transform(image)
image = image.unsqueeze(0) # 增加 batch 维度: [1, 3, 150, 150]
return image
# ==========================================
# 3. 预测函数
# ==========================================
def predict(image_path, model, device, transform):
model.eval() # 切换到评估模式,关闭 Dropout
image = load_image(image_path, transform)
image = image.to(device)
with torch.no_grad():
output = model(image)
# 🔥 关键点:因为训练时用的是 BCEWithLogitsLoss,模型输出的是 Logits
# Logits > 0.0 等价于 Sigmoid概率 > 0.5
prediction = output > 0.0
is_human = prediction.item()
class_name = "人" if is_human else "马"
print(f"图片: {image_path}")
print(f"预测结果: 这是一张【{class_name}】的图片")
# 打印原始输出值,帮助理解模型判断的置信度(越远离0表示越确信)
print(f"模型原始输出值: {output.item():.4f}\n")
# ==========================================
# 4. 主程序执行
# ==========================================
if __name__ == "__main__":
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 实例化模型并移动到设备
model = HorsesHumansCNN().to(device)
# 🔥 加载训练好的权重
# 请确保 'best_model.pth' 文件存在于同一目录下!
# 如果你之前保存的名字不同,请在这里修改
weight_path = 'best_model.pth'
try:
model.load_state_dict(torch.load(weight_path, map_location=device))
print(f"成功加载模型权重: {weight_path}\n")
except FileNotFoundError:
print(f"❌ 找不到权重文件 {weight_path},请先训练模型并保存权重!")
exit()
# 预测本地图片
predict("1.png", model, device, transform)
predict("2.png", model, device, transform)这段代码是一个基于PyTorch实现的**计算机视觉图像二分类预测脚本**,专门用于识别输入图片中是“人”还是“马”。
下面我将从**实现原理、用途、注意事项**三个方面为你详细解释这段代码:
---
### 一、 实现原理
代码的执行流程可以分为四个核心步骤:
#### 1. 模型定义 HorsesHumansCNN)
这是一个典型的卷积神经网络(CNN),专门为处理150x150的RGB图像设计:
* **特征提取部分**:包含3个卷积层Conv2d)和3个最大池化层MaxPool2d)。卷积层用于提取图像的边缘、纹理等特征,通道数从3(RGB)逐渐增加到16、32、64;池化层用于下采样,缩小特征图尺寸(每次池化尺寸减半),减少计算量并增加平移不变性。
* **分类头部分**:包含2个全连接层Linear)。将三维的特征图展平x.view)为一维向量后,经过Dropout(防止过拟合)和ReLU激活函数,最终映射到**1个输出节点**。
* **输出含义**:输出的是一个单一的**Logit值**,而不是概率值。
#### 2. 图像预处理 transform & load_image)
深度学习模型对输入数据的格式有严格要求,预处理旨在将任意输入图片转化为模型期望的张量:
Resize((150, 150)):强制缩放图片尺寸,与模型全连接层输入维度(6418*18)相匹配(150经过3次池化变为 150/2/2/2 = 18.75,取整为18)。
* ToTensor():将PIL图片转换为PyTorch张量,并将像素值从 [0, 255] 归一化到 [0.0, 1.0]。
* Normalize(...):使用均值和标准差0.5进行标准化,将像素范围进一步映射到 [-1.0, 1.0],加速模型收敛。
* unsqueeze(0):PyTorch模型默认按批次处理数据,此操作在维度0增加了一个批次维度,将形状从 [3, 150, 150] 变为 [1, 3, 150, 150]。
#### 3. 预测逻辑 predict)
model.eval():将模型切换到评估模式,这会*关闭Dropout层**,确保预测结果的确定性和一致性。
* torch.no_grad():禁用梯度计算,减少内存消耗并加快推理速度。
* **🔥 关键逻辑 output > 0.0**:由于模型在训练时使用的是 BCEWithLogitsLoss(将Sigmoid和二元交叉熵结合的损失函数),模型最后一层没有加Sigmoid激活函数,直接输出Logits。在数学上,Sigmoid函数在输入为0时输出0.5(即分类阈值)。因此,**Logits > 0 等价于 Sigmoid概率 > 0.5**。代码直接判断输出是否大于0,大于0判定为“人”,小于等于0判定为“马”。
#### 4. 权重加载与执行
* 使用 torch.load 加载之前训练好的模型参数.pth文件),并通过 model.load_state_dict 将参数注入到模型中。
* 使用 map_location=device 确保即使模型在GPU上训练,也能在CPU上正确加载。
---
### 二、 用途
1. **图像二分类推理**:这是最直接的用途,用于部署或测试已训练好的“人/马”分类模型。
2. **迁移学习模板**:这段代码结构非常标准,只需修改模型的全连接层输出维度和 predict 函数中的类别映射,即可轻松迁移到其他二分类任务(如猫狗分类、口罩检测等)。
3. **模型验证**:在训练过程中或训练完成后,开发者可以使用此类脚本快速验证模型在本地新图片上的表现,通过打印原始Logit值来评估模型的置信度。
---
### 三、 注意事项
1. **模型结构与训练时必须绝对一致**:
代码中的 HorsesHumansCNN 类的定义必须与保存 best_model.pth 时的结构**完全一致**。即使只改动了一层的一个参数load_state_dict 也会报错导致无法加载权重。
2. **文件路径与依赖**:
* 运行前必须确保 best_model.pth 存在于当前目录。
* 预测的图片(如 1.png, 2.png)也需要存在,否则 Image.open 会报错。
3. **输入尺寸的固定性**:
该模型由于包含了全连接层fc1 的输入特征数被硬编码为 64 18 18),因此**只能**接受 150x150 尺寸的图片。如果输入其他尺寸,在执行 x.view(x.size(0), -1) 后,一维向量的长度不匹配,会导致运行报错。
4. **Logit与概率的区别**:
代码中打印的 模型原始输出值 是 Logit 而不是概率。它的范围是 (-∞, +∞)。值越接近0,说明模型越犹豫(置信度低);值越远离0(例如 +15 或 -15),说明模型对判断越确信。如果需要输出概率,可以手动加上 prob = torch.sigmoid(output).item()。
5. **标签映射的对应关系**:
代码中 is_human = prediction.item(),当为 True 时对应“人”。这依赖于**训练时数据集的标签设定**。如果在训练时,标签1代表“马”,标签0代表“人”,那么这里的判断逻辑就完全反了。必须确保推理时的类别映射与训练时的标签编码一致。
如果你有任何关于代码修改或扩展的需求,随时可以问我!
成功加载模型权重: best_model.pth
图片: 1.png
预测结果: 这是一张【人】的图片
模型原始输出值: 4.2016
图片: 2.png
预测结果: 这是一张【马】的图片
模型原始输出值: -3.7493