廣泛用于圖像分類(lèi)的數(shù)據(jù)集之一是手寫(xiě)數(shù)字的MNIST 數(shù)據(jù)集 (LeCun等人,1998 年) 。在 1990 年代發(fā)布時(shí),它對(duì)大多數(shù)機(jī)器學(xué)習(xí)算法提出了巨大挑戰(zhàn),其中包含 60,000 張圖像 28×28像素分辨率(加上 10,000 張圖像的測(cè)試數(shù)據(jù)集)。客觀地說(shuō),在 1995 年,配備高達(dá) 64MB RAM 和驚人的 5 MFLOPs 的 Sun SPARCStation 5 被認(rèn)為是 AT&T 貝爾實(shí)驗(yàn)室最先進(jìn)的機(jī)器學(xué)習(xí)設(shè)備。實(shí)現(xiàn)數(shù)字識(shí)別的高精度是一個(gè)1990 年代 USPS 自動(dòng)分揀信件的關(guān)鍵組件。深度網(wǎng)絡(luò),如 LeNet-5 (LeCun等人,1995 年)、具有不變性的支持向量機(jī) (Sch?lkopf等人,1996 年)和切線距離分類(lèi)器 (Simard等人,1998 年)都允許達(dá)到 1% 以下的錯(cuò)誤率。
十多年來(lái),MNIST 一直是比較機(jī)器學(xué)習(xí)算法的參考點(diǎn)。雖然它作為基準(zhǔn)數(shù)據(jù)集運(yùn)行良好,但即使是按照當(dāng)今標(biāo)準(zhǔn)的簡(jiǎn)單模型也能達(dá)到 95% 以上的分類(lèi)準(zhǔn)確率,這使得它不適合區(qū)分強(qiáng)模型和弱模型。更重要的是,數(shù)據(jù)集允許非常高的準(zhǔn)確性,這在許多分類(lèi)問(wèn)題中通常是看不到的。這種算法的發(fā)展偏向于可以利用干凈數(shù)據(jù)集的特定算法系列,例如活動(dòng)集方法和邊界搜索活動(dòng)集算法。今天,MNIST 更像是一種健全性檢查,而不是基準(zhǔn)。ImageNet ( Deng et al. , 2009 )提出了一個(gè)更相關(guān)的挑戰(zhàn)。不幸的是,對(duì)于本書(shū)中的許多示例和插圖來(lái)說(shuō),ImageNet 太大了,因?yàn)橛?xùn)練這些示例需要很長(zhǎng)時(shí)間才能使示例具有交互性。作為替代,我們將在接下來(lái)的部分中重點(diǎn)討論定性相似但規(guī)模小得多的 Fashion-MNIST 數(shù)據(jù)集(Xiao等人,2017 年),該數(shù)據(jù)集于 2017 年發(fā)布。它包含 10 類(lèi)服裝的圖像 28×28像素分辨率。
%matplotlib inline
import time
import jax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from jax import numpy as jnp
from d2l import jax as d2l
d2l.use_svg_display()
4.2.1. 加載數(shù)據(jù)集
由于它是一個(gè)經(jīng)常使用的數(shù)據(jù)集,所有主要框架都提供了它的預(yù)處理版本。我們可以使用內(nèi)置的框架實(shí)用程序?qū)?Fashion-MNIST 數(shù)據(jù)集下載并讀取到內(nèi)存中。
class FashionMNIST(d2l.DataModule): #@save
"""The Fashion-MNIST dataset."""
def __init__(self, batch_size=64, resize=(28, 28)):
super().__init__()
self.save_hyperparameters()
trans = transforms.Compose([transforms.Resize(resize),
transforms.ToTensor()])
self.train = torchvision.datasets.FashionMNIST(
root=self.root, train=True, transform=trans, download=True)
self.val = torchvision.datasets.FashionMNIST(
root=self.root, train=False, transform=trans, download=True)
class FashionMNIST(d2l.DataModule): #@save
"""The Fashion-MNIST dataset."""
def __init__(self, batch_size=64, resize=(28, 28)):
super().__init__()
self.save_hyperparameters()
trans = transforms.Compose([transforms.Resize(resize),
transforms.ToTensor()])
self.train = gluon.data.vision.FashionMNIST(
train=True).transform_first(trans)
self.val = gluon.data.vision.FashionMNIST(
train=False).transform_first(trans)
Fashion-MNIST 包含來(lái)自 10 個(gè)類(lèi)別的圖像,每個(gè)類(lèi)別在訓(xùn)練數(shù)據(jù)集中由 6,000 個(gè)圖像表示,在測(cè)試數(shù)據(jù)集中由 1,000 個(gè)圖像表示。測(cè)試 數(shù)據(jù)集用于評(píng)估模型性能(不得用于訓(xùn)練)。因此,訓(xùn)練集和測(cè)試集分別包含 60,000 和 10,000 張圖像。
圖像是灰度和放大到32×32分辨率以上的像素。這類(lèi)似于由(二進(jìn)制)黑白圖像組成的原始 MNIST 數(shù)據(jù)集。但請(qǐng)注意,大多數(shù)具有 3 個(gè)通道(紅色、綠色、藍(lán)色)的現(xiàn)代圖像數(shù)據(jù)和超過(guò) 100 個(gè)通道的高光譜圖像(HyMap 傳感器有 126 個(gè)通道)。按照慣例,我們將圖像存儲(chǔ)為 c×h×w張量,其中c是顏色通道數(shù),h是高度和w是寬度。
評(píng)論