PyTorch 나만의 데이터셋을 만들고, 이를 ImageFolder로 불러오기
최근에 데이터셋을 직접 구축하여, 내가 만든 데이터셋으로 학습(Training)을 해야 하는 일이 생겼다.
PyTorch에서는 ImageFolder라는 라이브러리를 제공한다. 이는 다음과 같은 계층적인 폴더 구조를 가지고 있는 데이터셋을 불러올 때 사용할 수 있다. 다시 말해 다음과 같이 각 이미지들이 자신의 레이블(Label) 이름으로 된 폴더 안에 들어가 있는 구조라면, ImageFolder 라이브러리를 이용하여 이를 바로 불러와 객체로 만들면 된다.
dataset/
0/
0.jpg
1.jpg
...
1/
0.jpg
1.jpg
...
...
9/
0.jpg
1.jpg
...
한 번 연습을 위해서 기존에 존재하는 CIFAR-10 데이터셋을 불러와서, 이를 계층적인 폴더 구조가 되도록 이미지를 저장하는 소스코드를 만들어 보자. 그 다음에 다시 ImageFolder 라이브러리로 동일한 CIFAR-10 데이터셋을 불러오면 성공이다.
먼저 다음과 같이 기본적으로 PyTorch에서 제공하고 있는 CIFAR-10 데이터셋을 불러와보자.
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.image as image
import numpy as np
transform_train = transforms.Compose([
transforms.ToTensor(),
])
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
CIFAR-10의 경우 10개의 레이블로 구성된 데이터셋이므로, 각 레이블의 이미지가 몇 번 등장했는지를 기록해주는 변수를 선언하자.
import os
num_classes = 10
number_per_class = {}
for i in range(num_classes):
number_per_class[i] = 0
이후에 이미지 Torch 객체와 레이블 정수 값이 들어왔을 때, 이를 실제 폴더에 저장해주는 함수를 작성하자.
def custom_imsave(img, label):
path = 'dataset/' + str(label) + '/'
if not os.path.exists(path):
os.makedirs(path)
img = img.numpy()
img = np.transpose(img, (1, 2, 0))
image.imsave(path + str(number_per_class[label]) + '.jpg', img)
number_per_class[label] += 1
이제 만들어진 도구들을 이용하여 CIFAR-10 데이터셋에서 데이터를 배치 단위로 읽으며, 배치에 포함된 각 이미지를 하나씩 정확한 폴더에 저장될 수 있도록 하자.
def process():
for batch_idx, (inputs, targets) in enumerate(train_loader):
print("[ Current Batch Index: " + str(batch_idx) + " ]")
for i in range(inputs.size(0)):
custom_imsave(inputs[i], targets[i].item())
process()
이후에 한 번 0번 레이블(비행기)의 첫 번째 이미지를 출력하도록 해보자. 정상적으로 잘 출력된다.
from PIL import Image
from matplotlib.pyplot import imshow
img = Image.open('dataset/0/0.jpg')
imshow(np.asarray(img))
이제 ImageFolder 라이브러리를 이용해서, 우리가 저장한 이미지들을 이용해 다시 PyTorch 데이터셋 객체로 불러올 수 있는지 확인해보도록 하자.
from torchvision.datasets import ImageFolder
train_dataset = ImageFolder(root='./dataset', transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
이미지 출력용 함수는 다음과 같다. PyTorch의 경우 [Batch Size, Channel, Width, Height]의 구조를 가지고 있어서, 이를 matplotlib로 출력하기 위해서는 [Width, Height, Channel]의 순서로 변경해주어야 한다.
def custom_imshow(img):
img = img.numpy()
plt.imshow(np.transpose(img, (1, 2, 0)))
plt.show()
이제 이미지를 하나씩 출력하도록 해보자.
def process():
for batch_idx, (inputs, targets) in enumerate(train_loader):
custom_imshow(inputs[0])
process()
실행 결과, 다음과 같이 정상적으로 데이터셋이 구성되었다는 사실을 알 수 있다.
'인공지능' 카테고리의 다른 글
PyTorch의 전이 학습(Transfer Learning)에서 Freezing 여부에 따른 성능 차이 및 유의점 (0) | 2021.02.24 |
---|---|
CNN (Convolutional Neural Network) 요약 정리 (0) | 2020.10.20 |
PyTorch에서 특정 Dataset을 열어 이미지 출력하기 (3) | 2020.04.10 |
Google CoLab으로 머신러닝 공부 편하게 시작하기 (0) | 2019.06.05 |
파이썬(Python) Matplotlib 라이브러리 다루기 (0) | 2018.12.08 |