안경잡이개발자

728x90
반응형

  최근에 데이터셋을 직접 구축하여, 내가 만든 데이터셋으로 학습(Training)을 해야 하는 일이 생겼다.

 

  PyTorch에서는 ImageFolder라는 라이브러리를 제공한다. 이는 다음과 같은 계층적인 폴더 구조를 가지고 있는 데이터셋을 불러올 때 사용할 수 있다. 다시 말해 다음과 같이 각 이미지들이 자신의 레이블(Label) 이름으로 된 폴더 안에 들어가 있는 구조라면, ImageFolder 라이브러리를 이용하여 이를 바로 불러와 객체로 만들면 된다.

 

dataset/
	0/
		0.jpg
		1.jpg
        	...
	1/
		0.jpg
		1.jpg
		...
	...
	9/
		0.jpg
		1.jpg
		...

 

  한 번 연습을 위해서 기존에 존재하는 CIFAR-10 데이터셋을 불러와서, 이를 계층적인 폴더 구조가 되도록 이미지를 저장하는 소스코드를 만들어 보자. 그 다음에 다시 ImageFolder 라이브러리로 동일한 CIFAR-10 데이터셋을 불러오면 성공이다.

 

  먼저 다음과 같이 기본적으로 PyTorch에서 제공하고 있는 CIFAR-10 데이터셋을 불러와보자.

 

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.image as image
import numpy as np

transform_train = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)

 

  CIFAR-10의 경우 10개의 레이블로 구성된 데이터셋이므로, 각 레이블의 이미지가 몇 번 등장했는지를 기록해주는 변수를 선언하자.

 

import os

num_classes = 10
number_per_class = {}

for i in range(num_classes):
    number_per_class[i] = 0

 

  이후에 이미지 Torch 객체레이블 정수 값이 들어왔을 때, 이를 실제 폴더에 저장해주는 함수를 작성하자.

 

def custom_imsave(img, label):
    path = 'dataset/' + str(label) + '/'
    if not os.path.exists(path):
        os.makedirs(path)
    
    img = img.numpy()
    img = np.transpose(img, (1, 2, 0))
    image.imsave(path + str(number_per_class[label]) + '.jpg', img)
    number_per_class[label] += 1

 

  이제 만들어진 도구들을 이용하여 CIFAR-10 데이터셋에서 데이터를 배치 단위로 읽으며, 배치에 포함된 각 이미지를 하나씩 정확한 폴더에 저장될 수 있도록 하자.

 

def process():
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        print("[ Current Batch Index: " + str(batch_idx) + " ]")
        for i in range(inputs.size(0)):
            custom_imsave(inputs[i], targets[i].item())

process()

 

  이후에 한 번 0번 레이블(비행기)의 첫 번째 이미지를 출력하도록 해보자. 정상적으로 잘 출력된다.

 

from PIL import Image
from matplotlib.pyplot import imshow

img = Image.open('dataset/0/0.jpg')
imshow(np.asarray(img))

 

 

  이제 ImageFolder 라이브러리를 이용해서, 우리가 저장한 이미지들을 이용해 다시 PyTorch 데이터셋 객체로 불러올 수 있는지 확인해보도록 하자.

 

from torchvision.datasets import ImageFolder

train_dataset = ImageFolder(root='./dataset', transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)

 

  이미지 출력용 함수는 다음과 같다. PyTorch의 경우 [Batch Size, Channel, Width, Height]의 구조를 가지고 있어서, 이를 matplotlib로 출력하기 위해서는 [Width, Height, Channel]의 순서로 변경해주어야 한다.

 

def custom_imshow(img):
    img = img.numpy()
    plt.imshow(np.transpose(img, (1, 2, 0)))
    plt.show()

 

  이제 이미지를 하나씩 출력하도록 해보자.

 

def process():
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        custom_imshow(inputs[0])

process()

 

  실행 결과, 다음과 같이 정상적으로 데이터셋이 구성되었다는 사실을 알 수 있다.

 

728x90
반응형

Comment +1

  • aa 2021.05.17 19:35

    torch는 b, c, h, w 아닌가요?
    matplotlib도 ndarray 형태로 imshow하면 h, w, c 순서 아닌가요?