import urllib.request
import zipfile
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# ================= 1. 数据下载与解压 =================
url_train = "https://storage.googleapis.com/learning-datasets/horse-or-human.zip"
file_train = "horse-or-human.zip"
training_dir = 'horse-or-human/training/' # 训练集目录
url_val = "https://storage.googleapis.com/learning-datasets/validation-horse-or-human.zip"
file_val = "validation-horse-or-human.zip"
validation_dir = 'horse-or-human/validation/' # 验证集目录
def download_and_extract(url, filename, target_dir):
os.makedirs(target_dir, exist_ok=True)
if os.path.exists(filename):
print(f"✅ 检测到本地已存在 {filename},跳过下载。")
else:
print(f"⬇️ 本地未找到 {filename},开始下载...")
urllib.request.urlretrieve(url, filename)
print(f"📂 正在解压 {filename} 到 {target_dir} ...")
with zipfile.ZipFile(filename, 'r') as zip_ref:
zip_ref.extractall(target_dir)
print(f"✅ 解压完成。\n")
# 执行下载解压
download_and_extract(url_train, file_train, training_dir)
download_and_extract(url_val, file_val, validation_dir)
# ================= 2. 数据变换与加载 =================
# 训练集变换:包含数据增强 + 标准化
train_transform = transforms.Compose([
transforms.Resize((150, 150)),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(20),
transforms.RandomAffine(
degrees=0,
translate=(0.2, 0.2),
scale=(0.8, 1.2),
shear=20,
),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
# ⚠️ 验证集变换:绝不能包含数据增强!只做缩放和标准化
val_transform = transforms.Compose([
transforms.Resize((150, 150)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
# 加载数据集,传入对应的目录和变换规则
train_dataset = datasets.ImageFolder(root=training_dir, transform=train_transform)
val_dataset = datasets.ImageFolder(root=validation_dir, transform=val_transform)
# 数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# 验证集不需要打乱顺序
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
# ================= 3. 验证加载结果 =================
print(f"训练集样本数: {len(train_dataset)}")
print(f"验证集样本数: {len(val_dataset)}")
print(f"类别映射关系: {train_dataset.class_to_idx}") # ImageFolder会自动将子文件夹名映射为0, 1...
这段代码是一个非常规范的**计算机视觉(CV)二分类任务的数据准备流程**,主要用于构建一个“马与人”的图像分类模型。
下面我将从**实现原理、用途、注意事项**三个方面为你详细拆解这段代码:
---
### 一、 代码用途
这段代码的核心目的是**自动化获取并构建用于PyTorch模型训练的数据管道**。
具体来说,它:
1. 从网络下载“马与人”的训练集和验证集压缩包。
2. 将其解压到本地指定目录。
3. 对图像进行预处理(训练集包含数据增强,验证集仅做基础缩放和标准化)。
4. 将数据打包成PyTorch的 DataLoader,以便在模型训练时按批次高效读取。
---
### 二、 实现原理(代码分步解析)
#### 1. 数据下载与解压
* **断点续传/跳过机制**:通过 os.path.exists(filename) 检查本地是否已有压缩包。如果有,则跳过下载,这在进行多次实验时能节省大量时间。
* **目录创建**os.makedirs(target_dir, exist_ok=True) 确保目标文件夹存在exist_ok=True 避免了目录已存在时报错。
* **解压**:使用Python内置的 zipfile 模块将图片解压到对应的训练或验证目录中。
#### 2. 数据变换与加载
这是这段代码最核心的部分,使用了 torchvision.transforms 和 datasets.ImageFolder。
* **训练集变换 train_transform)**:
* Resize((150, 150)):将大小不一的图片统一缩放到 150x150 像素,满足神经网络固定输入维度的要求。
* **数据增强**RandomHorizontalFlip(随机水平翻转)RandomRotation(随机旋转20度)RandomAffine(仿射变换,包含平移、缩放和剪切)。这些操作能让模型看到更多形态的图片,防止过拟合,提升泛化能力。
* ToTensor():将 PIL Image 转换为 PyTorch 张量,并把像素值范围从 [0, 255] 归一化到 [0.0, 1.0]。
* Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]):将 [0.0, 1.0] 的数据进一步映射到 [-1.0, 1.0],这能加速模型收敛。
* **验证集变换 val_transform)**:
* **仅包含 Resize、ToTensor 和 Normalize**。**绝不能**加入数据增强,因为验证集的目的是客观评估模型在未见过的原始数据上的表现,数据增强会破坏这种客观性。
* **数据集加载 datasets.ImageFolder)**:
* ImageFolder 是PyTorch提供的一个极其方便的API,它会自动根据子文件夹的名字来分类。例如,你的 training_dir 下有 horses 和 humans 两个文件夹,它会自动将前者标记为类别 0,后者标记为类别 1。
* **数据加载器 DataLoader)**:
* 将数据集封装成迭代器batch_size=32 表示每次向模型喂入32张图片。
* 训练集 shuffle=True:打乱数据顺序,防止模型按顺序学习到某种规律。
* 验证集 shuffle=False:不需要打乱,节省算力且便于追踪特定样本的预测结果。
---
### 三、 注意事项与潜在改进点
1. **解压逻辑的冗余执行**:
* **当前问题**:代码中只要执行 download_and_extract,就会无条件执行解压操作zip_ref.extractall)。即使文件已经解压过,每次运行脚本都会重新解压覆盖,这会浪费时间。
* **改进建议**:可以增加一个判断,如果目标目录下已经有文件,则跳过解压步骤。
2. **数据增强的边界风险**:
* **当前问题**RandomAffine 中的 translate=(0.2, 0.2) 表示图片在水平和垂直方向最多平移 20%。对于 150x150 的图片,平移 20% 可能会导致目标(马或人)移出画面,而背景被填充为黑色(默认值)。如果模型频繁看到“只有背景没有目标”的图片被标记为“马”或“人”,会干扰学习。
* **改进建议**:可以适当降低平移比例(如 0.1),或在 RandomAffine 中设置 fill=255(用白色填充)等。
3. **归一化参数的准确性**:
* **当前问题**:代码使用了 mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5],这是一个通用的快速近似值。但对于特定的数据集(如本例中的马和人),RGB三个通道的真实均值和方差可能并不是 0.5。
* **改进建议**:在工业级任务中,通常会先跑一个脚本计算整个训练集的 RGB 均值和标准差,然后再填入 Normalize,这样模型收敛最快、效果最好。
4. **并发加载(num_workers)**:
* **当前问题**DataLoader 默认 num_workers=0,意味着数据加载在主进程中进行。如果模型训练速度很快,CPU读取数据的速度可能跟不上GPU的计算速度,导致GPU闲置。
* **改进建议**:在 DataLoader 中添加 num_workers=4(根据CPU核心数调整),利用多进程并发加载数据,提升训练效率。
---
**总结**:这是一段结构清晰、逻辑正确的深度学习数据准备代码,特别强调了训练集与验证集在数据增强上的区别,非常适合作为图像分类任务的起点模板。只需注意上述提到的解压冗余和数据增强越界的小细节即可。
✅ 检测到本地已存在 horse-or-human.zip,跳过下载。
📂 正在解压 horse-or-human.zip 到 horse-or-human/training/ ...
✅ 解压完成。
✅ 检测到本地已存在 validation-horse-or-human.zip,跳过下载。
📂 正在解压 validation-horse-or-human.zip 到 horse-or-human/validation/ ...
✅ 解压完成。
训练集样本数: 1027
验证集样本数: 256
类别映射关系: {'horses': 0, 'humans': 1}