[study] PyTorch tutorial: Tensors, Dataset, DataLoader
import torch
import numpy as np
data = [[1,2], [3,4]]
t_data = torch.tensor(data)
t_data
tensor([[1, 2], [3, 4]])
t_data.dtype
torch.int64
t_data.size()
torch.Size([2, 2])
np_array = np.array(data)
t_np = torch.from_numpy(np_array)
t_np
tensor([[1, 2], [3, 4]])
t_np.size()
torch.Size([2, 2])
t_np.shape
torch.Size([2, 2])
x_ones = torch.ones_like(t_data)
x_ones
tensor([[1, 1], [1, 1]])
x_rand = torch.rand_like(t_data, dtype=torch.float)
x_rand
tensor([[0.4474, 0.7949], [0.4205, 0.8668]])
shape = (2,3,)
rand_tensor = torch.rand(shape)
ones_tensor = torch.ones(shape)
zeros_tensor = torch.zeros(shape)
print(rand_tensor)
print(ones_tensor)
print(zeros_tensor)
tensor([[0.8367, 0.4293, 0.0447], [0.2943, 0.1218, 0.8405]])
tensor([[1., 1., 1.], [1., 1., 1.]])
tensor([[0., 0., 0.], [0., 0., 0.]])
tensor = torch.rand(3,4)
print(tensor.shape)
print(tensor.dtype)
print(tensor.device)
torch.Size([3, 4])
torch.float32
cpu
device = "cuda" if torch.cuda.is_available() else "cpu"
device
'cpu'
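Tensors are created on the CPU by default. Once a device string has been chosen, an existing tensor can be moved there explicitly with .to() -- a minimal sketch using the t_data tensor from above:
t_on_device = t_data.to(device)   # copies to the GPU when one is available, otherwise stays on the CPU
print(t_on_device.device)         # cpu (or cuda:0 on a machine with a GPU)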
tensor = torch.ones(4,4)
print(tensor)
print("First row", tensor[0])
print("First column", tensor[:,0])
print("Last Column", tensor[..., -1])
tensor[:,1] = 0
print(tensor)
tensor([[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.]])
First row tensor([1., 1., 1., 1.])
First column tensor([1., 1., 1., 1.])
Last Column tensor([1., 1., 1., 1.])
tensor([[1., 0., 1., 1.], [1., 0., 1., 1.], [1., 0., 1., 1.], [1., 0., 1., 1.]])
t1 = torch.cat([tensor, tensor, tensor], dim=1)
print(t1)
tensor([[1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1.], [1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1.], [1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1.], [1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1.]])
t2 = torch.cat([tensor,tensor,tensor], dim=0)
t2
tensor([[1., 0., 1., 1.], [1., 0., 1., 1.], [1., 0., 1., 1.], [1., 0., 1., 1.], [1., 0., 1., 1.], [1., 0., 1., 1.], [1., 0., 1., 1.], [1., 0., 1., 1.], [1., 0., 1., 1.], [1., 0., 1., 1.], [1., 0., 1., 1.], [1., 0., 1., 1.]])
t2.shape
torch.Size([12, 4])
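A related operation is torch.stack, which joins tensors along a new dimension instead of an existing one -- a quick sketch for comparison:
t3 = torch.stack([tensor, tensor, tensor], dim=0)   # adds a new leading dimension
t3.shape   # torch.Size([3, 4, 4]), whereas cat along dim=0 gave [12, 4]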
y1 = tensor @ tensor.T
y2 = tensor.matmul(tensor.T)
y3 = torch.rand_like(tensor)
print(y3)
torch.matmul(tensor, tensor.T, out=y3)
print(y1)
print(y2)
print(y3)
tensor([[0.3516, 0.7082, 0.2357, 0.0189], [0.5164, 0.8764, 0.1845, 0.9110], [0.4364, 0.2171, 0.0746, 0.9275], [0.3004, 0.5628, 0.6385, 0.0230]])
tensor([[3., 3., 3., 3.], [3., 3., 3., 3.], [3., 3., 3., 3.], [3., 3., 3., 3.]])
tensor([[3., 3., 3., 3.], [3., 3., 3., 3.], [3., 3., 3., 3.], [3., 3., 3., 3.]])
tensor([[3., 3., 3., 3.], [3., 3., 3., 3.], [3., 3., 3., 3.], [3., 3., 3., 3.]])
z1 = tensor * tensor
z2 = tensor.mul(tensor)
z3 = torch.rand_like(tensor)
torch.mul(tensor, tensor, out=z3)
print(z1)
print(z2)
print(z3)
tensor([[1., 0., 1., 1.], [1., 0., 1., 1.], [1., 0., 1., 1.], [1., 0., 1., 1.]])
tensor([[1., 0., 1., 1.], [1., 0., 1., 1.], [1., 0., 1., 1.], [1., 0., 1., 1.]])
tensor([[1., 0., 1., 1.], [1., 0., 1., 1.], [1., 0., 1., 1.], [1., 0., 1., 1.]])
agg = tensor.sum()
agg = agg.type(torch.int32)
agg.item()
12
tensor = tensor.sub(1)
tensor
tensor([[ 0., -1., 0., 0.], [ 0., -1., 0., 0.], [ 0., -1., 0., 0.], [ 0., -1., 0., 0.]])
tensor = tensor.add(2)
tensor
tensor([[2., 1., 2., 2.], [2., 1., 2., 2.], [2., 1., 2., 2.], [2., 1., 2., 2.]])
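sub and add above return new tensors. PyTorch also provides in-place variants whose names end with an underscore (the same convention as the t.add_(1) call used further below); a minimal sketch:
tensor.add_(1)   # in-place: modifies `tensor` itself instead of allocating a new one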
t = torch.ones(5)
print(t)
print(type(t))
n = t.numpy()
print(n)
print(type(n))
at = torch.from_numpy(n)
print(at)
print(type(at))
tensor([1., 1., 1., 1., 1.])
<class 'torch.Tensor'>
[1. 1. 1. 1. 1.]
<class 'numpy.ndarray'>
tensor([1., 1., 1., 1., 1.])
<class 'torch.Tensor'>
t.add_(1)
print(t)
print(n)
print(at)
tensor([2., 2., 2., 2., 2.])
[2. 2. 2. 2. 2.]
tensor([2., 2., 2., 2., 2.])
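The memory sharing works in the other direction as well: a change made through the NumPy array is visible in the tensor created with from_numpy. A quick check:
np.add(n, 1, out=n)   # modify the NumPy array in place
print(t)              # tensor([3., 3., 3., 3., 3.]) -- same underlying buffer
print(at)             # tensor([3., 3., 3., 3., 3.])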
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
train_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize=(9,9))
cols, rows = 4,4
for i in range(1, cols*rows + 1):
sample_idx = torch.randint(len(train_data), size=(1,)).item()
img, label = train_data[sample_idx]
figure.add_subplot(rows,cols, i)
plt.title(labels_map[label])
print(img.shape, labels_map[label])
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray")
plt.show()
torch.Size([1, 28, 28]) Sneaker
torch.Size([1, 28, 28]) T-Shirt
torch.Size([1, 28, 28]) Shirt
torch.Size([1, 28, 28]) Bag
torch.Size([1, 28, 28]) Pullover
torch.Size([1, 28, 28]) Bag
torch.Size([1, 28, 28]) Trouser
torch.Size([1, 28, 28]) Trouser
torch.Size([1, 28, 28]) T-Shirt
torch.Size([1, 28, 28]) Shirt
torch.Size([1, 28, 28]) Ankle Boot
torch.Size([1, 28, 28]) Coat
torch.Size([1, 28, 28]) Pullover
torch.Size([1, 28, 28]) Sandal
torch.Size([1, 28, 28]) Shirt
torch.Size([1, 28, 28]) Shirt
sample_idx = torch.randint(len(train_data), size=(1,)).item()
sample_idx
10819
len(train_data)
60000
train_data[sample_idx]
(<1×28×28 float image tensor with pixel values in [0, 1]>, 4)
import os
import pandas as pd
from torchvision.io import read_image
class myDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # annotations_file: CSV with one "file_name,label" row per sample
        self.img_labels = pd.read_csv(annotations_file, names=["file_name", "label"])
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # Number of samples = number of rows in the annotations file
        return len(self.img_labels)

    def __getitem__(self, idx):
        # Load the idx-th image from disk and return it together with its label
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
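A usage sketch for the custom dataset above. The annotation CSV and image directory below are hypothetical placeholders, not files that exist in this tutorial:
# Hypothetical paths, only to illustrate how the class would be constructed
custom_data = myDataset(annotations_file="data/custom/labels.csv",
                        img_dir="data/custom/images")
print(len(custom_data))        # number of rows in the annotations CSV
image, label = custom_data[0]  # image is the tensor returned by read_image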
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
In general, we want to pass samples in minibatches, reshuffle the data at every epoch to reduce overfitting, and use Python's multiprocessing to speed up data retrieval.
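The multiprocessing part is handled by the num_workers argument of DataLoader, which spawns background worker processes to prepare batches; a sketch (the worker count here is an arbitrary choice):
train_dataloader_mp = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=2)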
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 9