《PyTorch深度学习实践》8. 加载数据集

加载数据集

Dataset和Dataloader的用法

两个极端

  1. 直接用全部的数据集训练(Batch):最大化利用向量计算优势,但消耗资源过大

  2. 每次只用一个样本训练模型:具有较好的随机性,有助于跨越鞍点,但并行化差,计算效率太低(存疑:batch_size过小,每个mini-batch的样本数据将没有统计意义)

因此引入Mini-Batch,可以均衡训练的速度和训练效果

1
2
3
4
# Training cycle
for epoch in range(training_epochs):
# Loop over all batches
for i in range(total_batch):

一些概念

  • Epoch:所有样本都参与了一次训练
  • Batch-Size:一次forward-backward pass中用的样本数量
  • Iteration:内层迭代一共进行了多少次,即pass的数目

DataLoader用法

这是训练的四个步骤中Prepare dataset这一步

  • Prepare dataset

    Dataset and Dataloader

  • Design model using Class

    • Inherit from nn.Module
  • Construct loss and optimizer

  • Training cycle

    • forward, backward, update
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
from torch.utils.data import Dataset # Dataset是一个抽象类,不能直接实例化,必须被继承
from torch.utils.data import DataLoader

class DiabetesDataset(Dataset):
def __init__(self):
pass
def __getitem__(self, index): # 这是一个magic function, dataset[index]会调用这个函数
pass
def __len__(self): # 这个也是magic function,在调用len(dataset)的时候被调用
pass

dataset = DiabetesDataset()
train_loader = DataLoader(dataset=dataset,
batch_size=32,
shuffle=True, # 是否打乱数据
num_workers=2) # 几个进程来读数据

for epoch in range(100):
for i, data in enumerate(train_loader, 0): # 从train_loader中读取mini-batch
# 0表示从0开始枚举
# 1. Prepare data
inputs, labels = data
# 2. Forward
# 3. Backward
# 4. Update

一些现成数据集

torchvision包中带有众多已有数据集

在Colab上运行

课程来源:《PyTorch深度学习实践》完结合集