UNet Practice

UNet

import os
import sys
print(os.getcwd())
!ls
/content
sample_data
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
project_path = '/content/drive/MyDrive/PyTorch_YearDream/2022-01-18'
sys.path.insert(0, project_path)
# # To download the Pascal VOC 2007 dataset:
# from torchvision.datasets import VOCSegmentation
# VOCSegmentation(root=os.path.join(project_path, 'data'), year='2007', download=True)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import os
import albumentations as A
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from torchvision.transforms.functional import to_tensor
from torchvision.utils import save_image

class DatasetFromFolderVOC(Dataset):
    def __init__(self, root_dir=os.path.join(project_path,'data/VOCdevkit/VOC2007'), txt_file='train.txt', transforms=None):
        super(DatasetFromFolderVOC, self).__init__()
        with open(os.path.join(root_dir, 'ImageSets/Segmentation', txt_file), 'r') as f:
            self.filenames = f.readlines()
        self.filenames = [file.strip() for file in self.filenames]
        self.img_path = os.path.join(root_dir, "JPEGImages")
        self.gt_path = os.path.join(root_dir, "SegmentationClass")
        self.transforms = transforms
        self.pallete = Image.open(os.path.join(self.gt_path, f'{self.filenames[0]}.png')).getpalette()

    # __getitem__ returns the sample corresponding to the given index.
    def __getitem__(self, index):
        img = Image.open(os.path.join(self.img_path, f'{self.filenames[index]}.jpg')).convert('RGB')
        gt = Image.open(os.path.join(self.gt_path, f'{self.filenames[index]}.png'))

        # Albumentations takes numpy arrays and applies the same spatial
        # transforms to the image and its mask together.
        aug = self.transforms(image=np.array(img), mask=np.array(gt))
        img = to_tensor(aug['image'])
        gt = aug['mask']
        gt[gt > 20] = 0  # map the void/border label (255) to background (0)
        gt = torch.tensor(gt)

        return img, gt

    # __len__ returns the total number of samples in the dataset.
    def __len__(self):
        return len(self.filenames)
transform = A.Compose([
    A.Resize(512, 512),
    #A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

dataset = DatasetFromFolderVOC(transforms=transform)
img, gt = dataset[0]
pallete = dataset.pallete
tf = transforms.ToPILImage()
tf(img)

gt_img = tf(gt)
gt_img.putpalette(dataset.pallete)
gt_img

from torch.utils.data import DataLoader

transform = A.Compose([
    A.Resize(256, 256),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

train_set = DatasetFromFolderVOC(txt_file='train.txt', transforms=transform)
test_set = DatasetFromFolderVOC(txt_file='val.txt', transforms=transform)

# Put the training set and the test set into their own DataLoaders.
trainDataLoader = DataLoader(dataset=train_set, num_workers=2, batch_size=8, shuffle=True)
testDataLoader = DataLoader(dataset=test_set, num_workers=1, batch_size=1, shuffle=False)
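
A quick sanity check (a minimal sketch, assuming the loaders above) to confirm the batch shapes before training:

img_batch, gt_batch = next(iter(trainDataLoader))
print(img_batch.shape)  # torch.Size([8, 3, 256, 256]) -- NCHW float images
print(gt_batch.shape)   # torch.Size([8, 256, 256])    -- per-pixel class indices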

## Let's implement the layers ourselves!

UNet

Referring to the figure, let's implement a block made of two [Conv, BatchNorm, ReLU] units.

class ConvBlock(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""
    # Check the initialization arguments: mid_channels defaults to
    # out_channels unless explicitly given.
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super(ConvBlock, self).__init__()
        if mid_channels is None:
            mid_channels = out_channels
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1)
        self.batchnorm1 = nn.BatchNorm2d(mid_channels)
        self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1)
        self.batchnorm2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = F.relu(self.batchnorm1(self.conv1(x)))
        out = F.relu(self.batchnorm2(self.conv2(out)))
        return out


# Sanity check: run a dummy batch through ConvBlock.
debug = ConvBlock(3, 64)
input_tensor = torch.randn(2, 3, 64, 64)
output = debug(input_tensor)
print("Input shape:", input_tensor.shape)   # torch.Size([2, 3, 64, 64])
print("Output shape:", output.shape)        # torch.Size([2, 64, 64, 64])

Referring to the figure, let's implement a layer composed of [Downsampling, ConvBlock].

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),  # halve the spatial resolution
            ConvBlock(in_channels, out_channels)
        )

    def forward(self, x):
        x = self.maxpool_conv(x)
        return x
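
A quick shape check (a sketch, with made-up feature-map sizes) confirming that Down halves the spatial resolution while changing the channel count:

down = Down(64, 128)
feat = torch.randn(1, 64, 64, 64)
print(down(feat).shape)  # torch.Size([1, 128, 32, 32])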

Referring to the figure, let's implement a layer composed of [Upsampling, ConvBlock].

class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels):
        super(Up, self).__init__()
        # align_corners=True silences the UserWarning about the changed
        # default and matches the common bilinear UNet implementation.
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = ConvBlock(in_channels, out_channels, in_channels // 2)

    def forward(self, x1, x2):
        x1 = self.up(x1)                 # upsample the deeper feature map
        x = torch.cat([x2, x1], dim=1)   # concatenate with the skip connection
        x = self.conv(x)
        return x

Referring to the figure, let's implement the final layer.

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        # 1x1 convolution mapping features to per-class logits
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.conv(x)
        return x

UNet

Let's build the model using the layers implemented above.

Defining the UNet class

At first it seems odd that the first Up layer needs in_channels=1024 when the bottleneck only produces 512 channels. The reason is the skip connection: inside Up, the upsampled x5 (512 channels) is concatenated with x4 (512 channels) along the channel dimension, so the ConvBlock receives 512 + 512 = 1024 channels. (My earlier confusion came from an extra permute in forward that scrambled the dimensions; since to_tensor already returns NCHW tensors, no permute is needed.)
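
A small sketch (with made-up feature-map sizes) demonstrating the channel arithmetic:

up = Up(1024, 256)
x5 = torch.randn(1, 512, 16, 16)   # bottleneck output
x4 = torch.randn(1, 512, 32, 32)   # skip connection
print(up(x5, x4).shape)            # torch.Size([1, 256, 32, 32]); cat gives 512 + 512 = 1024 channels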

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super(UNet, self).__init__()
        self.one = ConvBlock(n_channels, 64)
        self.two = Down(64, 128)
        self.three = Down(128, 256)
        self.four = Down(256, 512)
        # bilinear upsampling keeps the channel count, so the bottleneck
        # stays at 512 channels instead of 1024
        self.five = Down(512, 512)
        self.six = Up(1024, 256)
        self.seven = Up(512, 128)
        self.eight = Up(256, 64)
        self.nine = Up(128, 64)
        self.out = OutConv(64, n_classes)

    def forward(self, x):
        # No permute needed: to_tensor already returns NCHW tensors.
        # (An extra x.permute(0, 3, 1, 2) here was what tangled things up.)
        x1 = self.one(x)
        x2 = self.two(x1)
        x3 = self.three(x2)
        x4 = self.four(x3)
        x5 = self.five(x4)
        x = self.six(x5, x4)
        x = self.seven(x, x3)
        x = self.eight(x, x2)
        x = self.nine(x, x1)
        logits = self.out(x)
        return logits
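
An end-to-end shape check (a minimal sketch with a dummy input) before training:

with torch.no_grad():
    dummy = torch.randn(1, 3, 256, 256)
    out = UNet(3, 21)(dummy)
print(out.shape)  # torch.Size([1, 21, 256, 256]) -- per-pixel logits for 21 classes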

num_classes = 21  # 20 + background
model = UNet(3, num_classes).cuda()
model = model.train()
criterion = nn.CrossEntropyLoss()
# VOC_CLASSES = [
#     "background",
#     "aeroplane",
#     "bicycle",
#     "bird",
#     "boat",
#     "bottle",
#     "bus",
#     "car",
#     "cat",
#     "chair",
#     "cow",
#     "diningtable",
#     "dog",
#     "horse",
#     "motorbike",
#     "person",
#     "potted plant",
#     "sheep",
#     "sofa",
#     "train",
#     "tv/monitor",
# ]
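
As a quick reminder (a minimal sketch with dummy tensors): nn.CrossEntropyLoss for segmentation expects (N, C, H, W) logits and (N, H, W) long targets, which is why gt is cast with .long() in the training loop below.

dummy_pred = torch.randn(2, num_classes, 256, 256)       # model logits
dummy_gt = torch.randint(0, num_classes, (2, 256, 256))  # class-index mask (long)
print(criterion(dummy_pred, dummy_gt))                   # scalar loss tensor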

Let's complete the training code.


import torch.optim as optim
from tqdm import tqdm


# setup optimizer
# optimizer = optim.SGD(model.parameters(), lr=1e-5, momentum=0.9)
optimizer = optim.Adam(model.parameters(), lr=1e-5, betas=(0.9, 0.999))

total_epoch = 3

# Training
for epoch in range(total_epoch):
    for i, (img, gt) in enumerate(trainDataLoader):
        img = img.cuda()
        gt = gt.long().cuda()

        # forward pass and loss
        pred = model(img)
        loss = criterion(pred, gt)

        # backprop + optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 10 == 0:
            print(f'[{epoch+1}][{i+1}/{len(trainDataLoader)}] Loss: {loss.item():.3f}')
[1][1/27] Loss: 2.987
[1][11/27] Loss: 2.884
[1][21/27] Loss: 2.907
[2][1/27] Loss: 2.838
[2][11/27] Loss: 2.737
[2][21/27] Loss: 2.706
[3][1/27] Loss: 2.751
[3][11/27] Loss: 2.736
[3][21/27] Loss: 2.560
# Visualization
model.eval()
num_imgs = 4
preds_train, gts_train = None, None
preds_test, gts_test = None, None
for i, (img, gt) in enumerate(trainDataLoader):
    if i * trainDataLoader.batch_size >= num_imgs:
        break
    with torch.no_grad():
        # forward
        img = img.cuda()
        gt = gt.long().cuda()
        pred = model(img)
        output_pred = pred.argmax(1)
        for j in range(min(trainDataLoader.batch_size, num_imgs)):
            out_img = Image.fromarray(np.array(output_pred[j].cpu()).astype('uint8'))
            out_img.putpalette(train_set.pallete)
            out_img = to_tensor(out_img.convert('RGB')).unsqueeze(0)
            preds_train = out_img if i == 0 and j == 0 else torch.cat([preds_train, out_img])
            out_img = Image.fromarray(np.array(gt[j].cpu()).astype('uint8'))
            out_img.putpalette(train_set.pallete)
            out_img = to_tensor(out_img.convert('RGB')).unsqueeze(0)
            gts_train = out_img if i == 0 and j == 0 else torch.cat([gts_train, out_img])
            
for i, (img, gt) in enumerate(testDataLoader):
    if i * testDataLoader.batch_size >= num_imgs:
        break
    with torch.no_grad():
        # forward
        img = img.cuda()
        gt = gt.long().cuda()
        pred = model(img)
        output_pred = pred.argmax(1)
        for j in range(min(testDataLoader.batch_size, num_imgs)):
            out_img = Image.fromarray(np.array(output_pred[j].cpu()).astype('uint8'))
            out_img.putpalette(test_set.pallete)
            out_img = to_tensor(out_img.convert('RGB')).unsqueeze(0)
            preds_test = out_img if i == 0 and j == 0 else torch.cat([preds_test, out_img])
            out_img = Image.fromarray(np.array(gt[j].cpu()).astype('uint8'))
            out_img.putpalette(test_set.pallete)
            out_img = to_tensor(out_img.convert('RGB')).unsqueeze(0)
            gts_test = out_img if i == 0 and j == 0 else torch.cat([gts_test, out_img])

train_samples = torch.cat([preds_train.cpu(), gts_train])
train_samples = torchvision.utils.make_grid(train_samples, nrow=num_imgs)
train_samples = train_samples.permute(1, 2, 0)

test_samples = torch.cat([preds_test.cpu(), gts_test])
test_samples = torchvision.utils.make_grid(test_samples, nrow=num_imgs)
test_samples = test_samples.permute(1, 2, 0)

plt.rcParams["figure.figsize"] = (30, 15)
plt.subplot(2, 1, 1)
plt.imshow(train_samples)
plt.subplot(2, 1, 2)
plt.imshow(test_samples)
<matplotlib.image.AxesImage at 0x7fb4a524bf50>

