猫狗图像二分类(ResNet18模型)

Admin
发布于 2026-05-28 / 0 阅读
0
0
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import os
import zipfile
import shutil
from PIL import Image

# ==========================================
# 0. 直接解压本地压缩包并整理数据集 (保持不变)
# ==========================================
zip_path = "kagglecatsanddogs_5340.zip"
base_dir = "cats_dogs_dataset"

if not os.path.exists(base_dir):
    print(f"检测到压缩包 {zip_path},开始解压并整理数据集...")
    extract_dir = "temp_extracted"
    os.makedirs(extract_dir, exist_ok=True)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)
    
    train_cats_dir = os.path.join(base_dir, 'train', 'cats')
    train_dogs_dir = os.path.join(base_dir, 'train', 'dogs')
    val_cats_dir = os.path.join(base_dir, 'val', 'cats')
    val_dogs_dir = os.path.join(base_dir, 'val', 'dogs')
    for d in [train_cats_dir, train_dogs_dir, val_cats_dir, val_dogs_dir]:
        os.makedirs(d, exist_ok=True)

    original_cat_dir = os.path.join(extract_dir, 'PetImages', 'Cat')
    original_dog_dir = os.path.join(extract_dir, 'PetImages', 'Dog')

    def move_files(src_dir, train_dst, val_dst, train_count=10000, val_count=2500):
        moved, val_moved = 0, 0
        for fname in os.listdir(src_dir):
            if fname.lower().endswith(('.jpg', '.png', '.jpeg')):
                src_path = os.path.join(src_dir, fname)
                try:
                    with Image.open(src_path) as img: img.verify()
                    with Image.open(src_path) as img: img.load()
                    if moved < train_count:
                        shutil.copy(src_path, train_dst); moved += 1
                    elif val_moved < val_count:
                        shutil.copy(src_path, val_dst); val_moved += 1
                    if moved >= train_count and val_moved >= val_count: break
                except Exception: pass

    print("整理猫咪图片..."); move_files(original_cat_dir, train_cats_dir, val_cats_dir)
    print("整理狗狗图片..."); move_files(original_dog_dir, train_dogs_dir, val_dogs_dir)
    shutil.rmtree(extract_dir)
    print("✅ 数据集整理完成!")
else:
    print("检测到已整理的数据集,跳过解压步骤。")

# ==========================================
# 1. 🔥 核心升级:使用预训练 ResNet18 替代自定义小 CNN
# ==========================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 加载在 ImageNet 上预训练的 ResNet18
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

# 🔥 冻结卷积基底的参数,不参与梯度更新(防止小数据集破坏预训练特征)
for param in model.parameters():
    param.requires_grad = False

# 🔥 替换最后的全连接层,适应我们的二分类任务
# ResNet18 最后全连接层的输入特征数是 512
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 1) 

model = model.to(device)

# ==========================================
# 2. 🔥 核心升级:加入数据增强
# ==========================================
# 训练集增强:加入随机翻转和旋转
train_transform = transforms.Compose([
    transforms.Resize((224, 224)), # ResNet 标准输入是 224x224
    transforms.RandomHorizontalFlip(), # 随机水平翻转
    transforms.RandomRotation(15),     # 随机旋转 +/- 15 度
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet 标准归一化
])

# 验证集不增强,只做缩放和归一化
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = datasets.ImageFolder(root=os.path.join(base_dir, 'train'), transform=train_transform)
val_dataset = datasets.ImageFolder(root=os.path.join(base_dir, 'val'), transform=val_transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

# ==========================================
# 3. 训练与评估逻辑
# ==========================================
criterion = nn.BCEWithLogitsLoss()
# 🔥 只传入全连接层的参数进行优化,学习率设小一点
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

def evaluate(model, data_loader, device):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device).float()
            outputs = model(images).view(-1)
            predicted = outputs > 0.0  
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

def train_model(max_epochs, target_accuracy):
    for epoch in range(max_epochs):
        model.train()
        running_loss = 0.0
        
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device).float()
            
            optimizer.zero_grad()
            outputs = model(images).view(-1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()

        avg_loss = running_loss / len(train_loader)
        val_accuracy = evaluate(model, val_loader, device)
        
        print(f'Epoch {epoch + 1}/{max_epochs}, Loss: {avg_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%')

        if val_accuracy >= target_accuracy:
            print(f'\n🎉 验证准确率 {val_accuracy:.2f}% 已达到目标 {target_accuracy}%,停止训练!')
            torch.save(model.state_dict(), 'resnet_cats_dogs_best.pth')
            print('模型权重已保存至 resnet_cats_dogs_best.pth')
            break

# 🔥 目标提升至 95%
train_model(max_epochs=10, target_accuracy=95.0)

# ==========================================
# 4. 预测本地图片
# ==========================================
def predict(image_path, model, device, transform):
    model.eval()
    try:
        image = Image.open(image_path).convert('RGB')
        image = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            output = model(image)
            prediction = output > 0.0
            class_name = "狗" if prediction.item() else "猫" 
            print(f"图片: {image_path} -> 预测结果: 这是一只【{class_name}】")
            print(f"模型原始输出值: {output.item():.4f}\n")
    except Exception as e:
        print(f"预测失败: {e}")

# predict("test_cat1.jpg", model, device, val_transform)
# predict("test_dog1.jpg", model, device, val_transform)

这段代码实现了一个基于深度学习的**猫狗图像二分类任务**,其核心亮点在于使用了**迁移学习**(基于预训练的ResNet18模型)和**数据增强**技术,使得模型能够在少量计算资源和较少训练轮次下,快速达到极高的准确率(目标95%)。

下面我将从实现原理、用途和注意事项三个方面为您详细解释这段代码:

---

### 一、 实现原理

代码的执行流程可以分为四个主要阶段:

#### 1. 数据准备与清洗(第0部分)

* **解压与整理**:从本地读取Kaggle猫狗数据集的压缩包,将其解压到临时目录。

* **划分数据集**:将猫狗图片分别按 10000张训练集 和 2500张验证集 的比例,复制到规范化的目录结构中train/cats, train/dogs, val/cats, val/dogs)。

* **数据清洗**:使用 PIL.Image.verify()load() 来验证图片的完整性,自动跳过损坏的图片文件(Kaggle该数据集中存在少量损坏图片)。

#### 2. 模型构建与迁移学习(第1部分)

* **加载预训练模型**:使用 models.resnet18(weights=...) 加载在ImageNet(包含百万级图片,1000类)上训练好的ResNet18模型。该模型已经具备了极强的底层图像特征提取能力(如边缘、纹理、形状)。

* **冻结参数**:通过 param.requires_grad = False 冻结了ResNet18所有卷积层的参数。在反向传播时,这些层的权重不会更新。这既节省了显存,又防止了小数据集训练导致预训练特征被破坏(过拟合)。

* **替换分类头**:ResNet18原始的最后一层是输出1000个类别的全连接层。代码获取其输入特征数(512),将其替换为 nn.Linear(512, 1),输出一个单一的 logits 值,用于二分类。

#### 3. 数据增强与预处理(第2部分)

* **训练集增强**:除了常规的缩放到224x224(ResNet标准输入),还加入了 RandomHorizontalFlip(随机水平翻转)和 RandomRotation(随机旋转15度)。这相当于变相增加了数据量,能极大提升模型的泛化能力。

* **ImageNet归一化**:使用 mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 进行标准化。**这是关键一步**,因为预训练模型是在这样的数据分布上训练的,输入数据必须与之对齐。

* **验证集处理**:只做缩放和归一化,不做增强,以保证评估结果的客观性。

#### 4. 训练与评估逻辑(第3、4部分)

* **损失函数**:使用 nn.BCEWithLogitsLoss()。这是结合了Sigmoid激活和二元交叉熵损失的函数,数值稳定性更好。模型输出单个 logits 值,大于0即为正类(狗),小于0即为负类(猫)。

* **优化器**optim.Adam **仅传入 model.fc.parameters()**,即只训练最后替换的全连接层,学习率设为0.001。

* **早停机制**:每个Epoch结束后计算验证集准确率,一旦达到95%的目标,就保存模型并停止训练。

---

### 二、 用途

1. **快速图像二分类原型开发**:适用于任何需要区分两类物体的计算机视觉任务(如医疗影像的阴阳性、工业缺陷检测、动植物分类等),只需替换数据集即可。

2. **迁移学习最佳实践演示**:代码展示了标准的“冻结基底+替换分类头+对齐归一化”的迁移学习范式,是学习和参考的优秀模板。

3. **资源受限环境下的模型训练**:因为冻结了卷积层,只训练一层全连接层,参数量极小,即使没有高端GPU,在普通CPU上也能快速跑完训练。

---

### 三、 注意事项

1. **首次运行耗时较长**:第0部分的解压、复制和图片验证过程较慢(尤其是逐张验证图片)。如果重复运行,代码通过 if not os.path.exists(base_dir) 做了缓存优化,不要轻易删除生成好的 cats_dogs_dataset 文件夹。

2. **类别映射关系**torchvision.datasets.ImageFolder 会根据文件夹名称的字母顺序自动分配标签cats (c) 在前dogs (d) 在后,因此 **0代表猫,1代表狗**。预测时 output > 0.0 预测为1(狗),否则为0(猫),代码中的 class_name = "狗" if prediction.item() else "猫" 是正确的,但如果更换了数据集,需注意文件夹命名顺序。

3. **目标准确率可能无法达成**:虽然迁移学习很强,但仅训练最后一层、且仅训练10个Epoch,在某些情况下可能无法稳定达到95%的验证准确率。如果发现准确率卡在某个值上不去,可以考虑:

* 取消冻结部分深层卷积参数(微调/Fine-tuning)。

* 增加训练轮次 max_epochs

* 调整学习率。

4. **预测时的预处理一致性**:在第4部分预测本地图片时,使用的是 val_transform,这是完全正确的。切勿在预测时使用带 Random 的增强操作,否则同一张图每次预测结果可能不同。

5. **BCEWithLogitsLoss 的标签格式**:代码中 labels = labels.to(device).float() 将标签转为 float 类型,这是因为 BCEWithLogitsLoss 要求标签和输出同维度同类型,如果不转为 float 会报错。这是一个非常细节且正确的写法。

如果您有关于这段代码的修改需求(比如想要加入微调逻辑,或者改为多分类任务),随时可以告诉我!

import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import os

# ==========================================
# 1. 🔥 重新定义 ResNet18 模型结构 (必须和训练时一致)
# ==========================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 实例化 ResNet18 架构
model = models.resnet18(weights=None) # 预测时不需要下载预训练权重,我们用自己的

# 替换最后的全连接层 (和训练时一样)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 1)

model = model.to(device)

# ==========================================
# 2. 加载训练好的模型权重
# ==========================================
weight_path = 'resnet_cats_dogs_best.pth' 

if not os.path.exists(weight_path):
    print(f"❌ 找不到权重文件 {weight_path},请确保文件在同一目录下!")
    exit()

# 加载权重字典并应用到模型上
model.load_state_dict(torch.load(weight_path, map_location=device))
print(f"✅ 成功加载模型权重: {weight_path}")

# 🔥 必须切换到评估模式!关闭 Dropout 和 BatchNorm 的训练行为
model.eval()

# ==========================================
# 3. 图像预处理 (必须和训练时的验证集预处理一致,不能有数据增强)
# ==========================================
transform = transforms.Compose([
    transforms.Resize((224, 224)), # ResNet 标准输入尺寸
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet 标准归一化
])

# ==========================================
# 4. 预测函数
# ==========================================
def predict(image_path, model, device, transform):
    try:
        image = Image.open(image_path).convert('RGB')
        image = transform(image).unsqueeze(0).to(device) # 增加 batch 维度并移动到设备
        
        with torch.no_grad(): # 预测时不计算梯度,节省内存和算力
            output = model(image)
            prediction = output > 0.0  # Logits > 0 视为正类 (狗)
            class_name = "狗" if prediction.item() else "猫" 
            
            print(f"图片: {image_path} -> 预测结果: 这是一只【{class_name}】")
            print(f"模型原始输出值: {output.item():.4f} (越正越像狗,越负越像猫)\n")
    except Exception as e:
        print(f"预测失败 {image_path}: {e}")

# ==========================================
# 5. 执行预测
# ==========================================
# 请确保 test_cat1.jpg 和 test_dog1.jpg 在同一目录下
predict("test_dog1.png", model, device, transform)
predict("test_dog563.png", model, device, transform)
✅ 成功加载模型权重: resnet_cats_dogs_best.pth
图片: test_dog1.png -> 预测结果: 这是一只【狗】
模型原始输出值: 6.0138 (越正越像狗,越负越像猫)

图片: test_dog568.png -> 预测结果: 这是一只【狗】
模型原始输出值: 0.6955 (越正越像狗,越负越像猫)


评论