import torch
import numpy as np
# Create a tensor directly from a nested Python list; the dtype is inferred
# from the data (this integer list yields torch.int64, as echoed below).
data = [[1,2], [3,4]]
t_data = torch.tensor(data)
t_data
tensor([[1, 2],
        [3, 4]])
t_data.dtype
torch.int64
t_data.size()
torch.Size([2, 2])

# Create a tensor from a NumPy array. torch.from_numpy shares memory with the
# source array instead of copying it (see the t.numpy()/from_numpy demo further down).
np_array = np.array(data)
t_np = torch.from_numpy(np_array)
t_np
tensor([[1, 2],
        [3, 4]])
# .size() and .shape report the same torch.Size.
t_np.size()
torch.Size([2, 2])
t_np.shape
torch.Size([2, 2])
# *_like constructors reuse the shape of an existing tensor; without a dtype
# argument the dtype is inherited too (x_ones prints as integers).
x_ones = torch.ones_like(t_data)
x_ones
tensor([[1, 1],
        [1, 1]])
# Override t_data's int64 dtype with float; the sampled values shown below lie in [0, 1).
x_rand = torch.rand_like(t_data, dtype=torch.float)
x_rand
tensor([[0.4474, 0.7949],
        [0.4205, 0.8668]])
# Constructors that take an explicit shape tuple (2 rows x 3 columns).
# The trailing comma in (2,3,) is legal and does not change the tuple.
shape = (2,3,)
rand_tensor = torch.rand(shape)
ones_tensor = torch.ones(shape)
zeros_tensor = torch.zeros(shape)
print(rand_tensor)
print(ones_tensor)
print(zeros_tensor)
tensor([[0.8367, 0.4293, 0.0447],
        [0.2943, 0.1218, 0.8405]])
tensor([[1., 1., 1.],
        [1., 1., 1.]])
tensor([[0., 0., 0.],
        [0., 0., 0.]])
# Every tensor carries a shape, a dtype, and the device it lives on.
# Note: this rebinds the global name `tensor`, which later cells reuse.
tensor = torch.rand(3,4)
print(tensor.shape)
print(tensor.dtype)
print(tensor.device)
torch.Size([3, 4])
torch.float32
cpu
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): nothing visible in this notebook moves tensors to `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"
device
'cpu'
# Basic indexing and slicing on a 4x4 tensor of ones.
tensor = torch.ones(4,4)
print(tensor)
print("First row", tensor[0])
print("First column", tensor[:,0])
# `...` (Ellipsis) stands for all leading dimensions, so this selects the last column.
print("Last Column", tensor[..., -1])
# In-place slice assignment: zero out column 1 in every row.
tensor[:,1] = 0
print(tensor)
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
First row tensor([1., 1., 1., 1.])
First column tensor([1., 1., 1., 1.])
Last Column tensor([1., 1., 1., 1.])
tensor([[1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.]])
# torch.cat joins tensors along an existing dimension:
# dim=1 places the three 4x4 copies side by side (4x12).
t1 = torch.cat([tensor, tensor, tensor], dim=1)
print(t1)
tensor([[1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1.],
        [1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1.],
        [1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1.],
        [1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1.]])
# dim=0 stacks the copies vertically (12x4), as confirmed by t2.shape below.
t2 = torch.cat([tensor,tensor,tensor], dim=0)
t2
tensor([[1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.]])
t2.shape
torch.Size([12, 4])
# Three equivalent ways to compute the matrix product tensor @ tensor.T.
# Every row of `tensor` is [1, 0, 1, 1], so every entry of the product is 3.
y1 = tensor @ tensor.T
y2 = tensor.matmul(tensor.T)
# y3 starts as random values only to serve as the out= buffer; it is printed
# once before torch.matmul overwrites it with the product.
y3 = torch.rand_like(tensor)
print(y3)
torch.matmul(tensor, tensor.T, out=y3)

print(y1)
print(y2)
print(y3)
tensor([[0.3516, 0.7082, 0.2357, 0.0189],
        [0.5164, 0.8764, 0.1845, 0.9110],
        [0.4364, 0.2171, 0.0746, 0.9275],
        [0.3004, 0.5628, 0.6385, 0.0230]])
tensor([[3., 3., 3., 3.],
        [3., 3., 3., 3.],
        [3., 3., 3., 3.],
        [3., 3., 3., 3.]])
tensor([[3., 3., 3., 3.],
        [3., 3., 3., 3.],
        [3., 3., 3., 3.],
        [3., 3., 3., 3.]])
tensor([[3., 3., 3., 3.],
        [3., 3., 3., 3.],
        [3., 3., 3., 3.],
        [3., 3., 3., 3.]])
# Three equivalent ways to compute the element-wise product.
z1 = tensor * tensor
z2 = tensor.mul(tensor)
# z3's random initial contents are overwritten via out=.
z3 = torch.rand_like(tensor)
torch.mul(tensor, tensor, out=z3)
print(z1)
print(z2)
print(z3)
tensor([[1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.]])
tensor([[1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.]])
tensor([[1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.],
        [1., 0., 1., 1.]])
# Reduce all elements of `tensor` to a single-element tensor.
agg = tensor.sum()

# Cast to int32, then .item() extracts the value as a plain Python number (12 below).
agg = agg.type(torch.int32)
agg.item()
12
# Out-of-place arithmetic (no trailing underscore, unlike add_ further down):
# sub/add return a result, which is explicitly rebound to the name `tensor`.
tensor = tensor.sub(1)
tensor
tensor([[ 0., -1.,  0.,  0.],
        [ 0., -1.,  0.,  0.],
        [ 0., -1.,  0.,  0.],
        [ 0., -1.,  0.,  0.]])
tensor = tensor.add(2)
tensor
tensor([[2., 1., 2., 2.],
        [2., 1., 2., 2.],
        [2., 1., 2., 2.],
        [2., 1., 2., 2.]])
# NumPy bridge: t.numpy() and torch.from_numpy() share memory with their
# source here, so an in-place change to one is visible through the others.
t = torch.ones(5)
print(t)
print(type(t))
n = t.numpy()
print(n)
print(type(n))
at = torch.from_numpy(n)
print(at)
print(type(at))
tensor([1., 1., 1., 1., 1.])
<class 'torch.Tensor'>
[1. 1. 1. 1. 1.]
<class 'numpy.ndarray'>
tensor([1., 1., 1., 1., 1.])
<class 'torch.Tensor'>
# add_ (trailing underscore) modifies t in place; the printouts show n and at change too.
t.add_(1)
print(t)
print(n)
print(at)
tensor([2., 2., 2., 2., 2.])
[2. 2. 2. 2. 2.]
tensor([2., 2., 2., 2., 2.])
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
# Load the FashionMNIST training split into ./data, downloading it on first use.
# ToTensor() turns each image into a tensor; the sample printed later has shape
# [1, 28, 28] with values in [0, 1].
train_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

# Same dataset, held-out test split.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
  0%|          | 0/26421880 [00:00<?, ?it/s]
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
  0%|          | 0/29515 [00:00<?, ?it/s]
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
  0%|          | 0/4422102 [00:00<?, ?it/s]
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
  0%|          | 0/5148 [00:00<?, ?it/s]
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

# Human-readable names for the FashionMNIST class indices 0-9, in index order.
labels_map = dict(enumerate([
    "T-Shirt",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle Boot",
]))
# Plot a 4x4 grid of randomly chosen training images, each titled with its class name.
figure = plt.figure(figsize=(9,9))
cols, rows = 4,4
# Subplot positions are 1-based, hence range(1, cols*rows + 1).
for i in range(1, cols*rows + 1):
  # One independent random draw per subplot, so the same image may appear twice.
  sample_idx = torch.randint(len(train_data), size=(1,)).item()
  img, label = train_data[sample_idx]
  figure.add_subplot(rows,cols, i)
  plt.title(labels_map[label])
  print(img.shape, labels_map[label])
  plt.axis("off")
  # img is [1, 28, 28] (see printed shapes); squeeze() drops the size-1 axis for imshow.
  plt.imshow(img.squeeze(), cmap="gray")
plt.show()
torch.Size([1, 28, 28]) Sneaker
torch.Size([1, 28, 28]) T-Shirt
torch.Size([1, 28, 28]) Shirt
torch.Size([1, 28, 28]) Bag
torch.Size([1, 28, 28]) Pullover
torch.Size([1, 28, 28]) Bag
torch.Size([1, 28, 28]) Trouser
torch.Size([1, 28, 28]) Trouser
torch.Size([1, 28, 28]) T-Shirt
torch.Size([1, 28, 28]) Shirt
torch.Size([1, 28, 28]) Ankle Boot
torch.Size([1, 28, 28]) Coat
torch.Size([1, 28, 28]) Pullover
torch.Size([1, 28, 28]) Sandal
torch.Size([1, 28, 28]) Shirt
torch.Size([1, 28, 28]) Shirt

# Draw one random index into the training set and inspect that sample.
sample_idx = torch.randint(len(train_data), size=(1,)).item()
sample_idx
10819
len(train_data)
60000

# Indexing the dataset returns an (image_tensor, label) tuple, as printed below.
train_data[sample_idx]
(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0039, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039,
           0.0000, 0.0000, 0.0000, 0.0000, 0.8078, 0.8314, 0.8118, 1.0000,
           0.1529, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0392, 0.0000, 0.2588, 0.7255, 0.9804, 1.0000, 0.8549,
           0.4824, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0627,
           0.1059, 0.0000, 0.0000, 0.6000, 0.5608, 0.8196, 0.9333, 0.6667,
           0.7333, 0.3922, 0.0000, 0.0588, 0.0431, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.0980,
           0.0000, 0.0353, 0.8667, 0.9098, 0.8196, 0.7098, 0.6000, 0.8039,
           0.7961, 0.9647, 0.5647, 0.0000, 0.0863, 0.0745, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0235, 0.0784, 0.0000,
           0.0000, 0.8000, 0.8667, 0.7059, 0.7569, 0.7373, 0.6667, 0.7765,
           0.7216, 0.7294, 0.9686, 0.3608, 0.0000, 0.1059, 0.0314, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0588, 0.0863, 0.0000,
           0.3686, 0.9333, 0.7059, 0.7490, 0.7490, 0.7294, 0.6627, 0.7804,
           0.7725, 0.7490, 0.8196, 0.8706, 0.0000, 0.0314, 0.0863, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0510, 0.0824, 0.0000,
           0.6824, 0.7686, 0.7216, 0.7412, 0.7529, 0.7216, 0.6784, 0.7686,
           0.5882, 0.5608, 0.5098, 0.8824, 0.2000, 0.0000, 0.1176, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0549, 0.0510, 0.0196,
           0.8431, 0.7686, 0.7294, 0.7412, 0.7490, 0.7176, 0.6902, 0.7725,
           0.7255, 0.7804, 0.7216, 0.8824, 0.3490, 0.0000, 0.1059, 0.0039,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.0275, 0.1451,
           0.8235, 0.7804, 0.7569, 0.7529, 0.7647, 0.7333, 0.6941, 0.7725,
           0.7608, 0.7686, 0.7255, 0.9059, 0.2941, 0.0000, 0.0902, 0.0078,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0902, 0.0196, 0.0980,
           0.7216, 0.7843, 0.7608, 0.7529, 0.7569, 0.7451, 0.7098, 0.7765,
           0.7608, 0.8000, 0.7216, 0.8784, 0.2196, 0.0039, 0.0784, 0.0275,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0824, 0.0235, 0.1373,
           0.7255, 0.7608, 0.7608, 0.7529, 0.7569, 0.7451, 0.7137, 0.7765,
           0.7647, 0.7725, 0.7294, 0.8431, 0.2824, 0.0000, 0.0863, 0.0353,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0706, 0.0627, 0.1725,
           0.8431, 0.7294, 0.7608, 0.7529, 0.7529, 0.7451, 0.7176, 0.7765,
           0.7765, 0.7725, 0.7294, 0.8196, 0.4078, 0.0000, 0.1059, 0.0549,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1098, 0.0941, 0.1686,
           0.7529, 0.7725, 0.7569, 0.7569, 0.7569, 0.7490, 0.7176, 0.7843,
           0.7725, 0.7843, 0.7216, 0.7843, 0.5176, 0.0000, 0.1373, 0.0667,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1098, 0.0980, 0.2902,
           0.4706, 0.8078, 0.7529, 0.7569, 0.7569, 0.7529, 0.7216, 0.7804,
           0.7569, 0.7961, 0.7294, 0.7529, 0.7059, 0.0000, 0.1961, 0.0431,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0157, 0.0784, 0.1608, 0.3216,
           0.5176, 0.7804, 0.7647, 0.7529, 0.7608, 0.7647, 0.7294, 0.7725,
           0.7647, 0.7725, 0.7686, 0.7882, 0.6745, 0.0235, 0.2588, 0.0118,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0235, 0.0118, 0.2431, 0.2588,
           0.3255, 0.8196, 0.7569, 0.7529, 0.7608, 0.7725, 0.7255, 0.7647,
           0.7725, 0.7569, 0.7961, 0.7451, 0.3373, 0.1961, 0.2039, 0.0078,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0353, 0.0157, 0.1608, 0.4235,
           0.2078, 0.8745, 0.7412, 0.7529, 0.7608, 0.7686, 0.7059, 0.7686,
           0.7608, 0.7647, 0.8039, 0.6980, 0.3725, 0.3490, 0.1216, 0.0235,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0353, 0.0275, 0.0588, 0.6078,
           0.2902, 0.8549, 0.7412, 0.7569, 0.7608, 0.7647, 0.7098, 0.7804,
           0.7529, 0.7686, 0.8157, 0.6784, 0.3725, 0.4353, 0.0706, 0.0196,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0353, 0.0353, 0.0745, 0.5333,
           0.2902, 0.8588, 0.7451, 0.7529, 0.7608, 0.7608, 0.7098, 0.7843,
           0.7569, 0.7490, 0.8196, 0.6863, 0.3059, 0.5137, 0.0235, 0.0275,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0275, 0.0157, 0.1490, 0.3922,
           0.2824, 0.8784, 0.7373, 0.7569, 0.7569, 0.7608, 0.7020, 0.7804,
           0.7569, 0.7490, 0.8078, 0.6941, 0.2667, 0.5725, 0.0000, 0.0235,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0196, 0.0078, 0.2392, 0.2980,
           0.3216, 0.8980, 0.7412, 0.7686, 0.7686, 0.7725, 0.7020, 0.7843,
           0.7569, 0.7569, 0.7882, 0.7529, 0.2824, 0.5922, 0.0000, 0.0314,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0275, 0.0000, 0.3412, 0.2588,
           0.3569, 0.8784, 0.7255, 0.7529, 0.7569, 0.7608, 0.7020, 0.7686,
           0.7373, 0.7451, 0.7569, 0.7804, 0.2667, 0.5725, 0.0000, 0.0431,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0471, 0.0000, 0.4000, 0.1922,
           0.3843, 0.9294, 0.7647, 0.7804, 0.7686, 0.7608, 0.6902, 0.7765,
           0.7451, 0.7804, 0.8118, 0.8745, 0.2157, 0.5647, 0.0039, 0.0431,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.2588, 0.1059,
           0.3294, 0.8235, 0.7451, 0.7608, 0.7569, 0.7882, 0.7529, 0.7725,
           0.7608, 0.7529, 0.7451, 0.8549, 0.1373, 0.4824, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.1333, 0.3255, 0.6196, 0.0392,
           0.1647, 0.7647, 0.5608, 0.6784, 0.8706, 0.8549, 0.7412, 0.9294,
           0.8902, 0.6235, 0.6078, 0.6235, 0.0039, 0.6078, 0.2196, 0.2235,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.3882, 0.7922, 0.8549, 0.1569,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0078, 0.0000, 0.0000, 0.9765, 0.8706, 0.7765,
           0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.1098, 0.4941, 0.5098, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
           0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.4235, 0.5569, 0.2627,
           0.0000, 0.0000, 0.0000, 0.0000]]]), 4)
import os
import pandas as pd
from torchvision.io import read_image

class myDataset(Dataset):
  """Map-style dataset of image files listed in a two-column CSV.

  Each CSV row is ``file_name,label``. Column names are supplied explicitly,
  so pandas treats the first line as data (no header row expected). Images
  are loaded from ``os.path.join(img_dir, file_name)``.

  Args:
    annotations_file: Path to the CSV of (file_name, label) rows.
    img_dir: Directory the CSV file names are relative to.
    transform: Optional callable applied to each loaded image.
    target_transform: Optional callable applied to each label.
    loader: Optional callable mapping an image path to the loaded image.
      When omitted, ``read_image`` (torchvision.io) is used, matching the
      original behavior.
  """

  def __init__(self, annotations_file, img_dir, transform=None, target_transform=None, loader=None):
    self.img_labels = pd.read_csv(annotations_file, names=["file_name", "label"])
    self.img_dir = img_dir
    self.transform = transform
    self.target_transform = target_transform
    self.loader = loader

  def __len__(self):
    """Return the number of rows in the annotations file."""
    return len(self.img_labels)

  def __getitem__(self, idx):
    """Return ``(image, label)`` for row ``idx`` with any transforms applied."""
    img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
    # Resolve the default lazily so a custom loader works without torchvision.
    load = self.loader if self.loader is not None else read_image
    image = load(img_path)
    label = self.img_labels.iloc[idx, 1]
    if self.transform is not None:
      image = self.transform(image)
    # Must match the attribute name set in __init__ (previously misspelled,
    # which raised AttributeError on every item access).
    if self.target_transform is not None:
      label = self.target_transform(label)
    return image, label
from torch.utils.data import DataLoader
# Wrap each dataset in an iterable that yields batches of 64 samples;
# shuffle=True randomizes the sample order each time the loader is iterated.
# NOTE(review): shuffling the test loader does not affect evaluation metrics;
# shuffle=False would make test batches reproducible — confirm intent.
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

모델을 학습할 때는 일반적으로 샘플들을 “미니배치(minibatch)”로 전달하고, 매 에폭(epoch)마다 데이터를 다시 섞어서 과적합(overfit)을 줄이며, Python의 `multiprocessing`을 사용하여 데이터 검색 속도를 높이려고 합니다. `DataLoader`는 이러한 복잡한 과정을 간단한 API로 추상화한 순회 가능한 객체(iterable)입니다.

# Pull a single batch from the training loader and inspect its shapes.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")

# Display the first image of the batch; squeeze() drops the size-1 channel axis.
img = train_features[0].squeeze()
# label is a single element of the labels tensor (printed as a bare class index).
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])

Label: 9

댓글남기기