안경잡이개발자

728x90
반응형

  최근에 GAN을 다시 공부하면서, GAN을 학습할 때 유의할 점에 대해서 다시 한번 정리하고 있다. 지금까지 GAN 네트워크를 수백 번 이상 학습을 해보았지만, 역시 학습 난이도가 높다. 그래서 경험적으로 GAN을 학습할 때, 어떤 테크닉을 사용하는 것이 유리한지 정리해 둘 필요가 있다.

 

※ Batch Size ※

 

  GAN에서는 배치 사이즈(batch size)에 의하여 결과가 많이 바뀌는 경향이 있다. 일반적으로 생성 모델(generative model)이 아닌 분류 모델에서는 배치 사이즈를 크게 설정하면, 그만큼 학습 속도가 빨라지게 되는 장점이 있다. 예를 들어 batch size가 64일 때보다 multi GPU를 활용하여 batch size 256으로 키워서 사용할 수 있으며, 그만큼 num_workers와 learning rate도 증가시킬 수 있다. 그러면 전체 학습 속도가 2~3배 빨라지는 것을 경험할 수 있다. 물론 분류 모델에서도 batch size가 1,024 이상으로 과도하게 커지는 경우 성능이 하락하는 현상을 볼 수도 있는데, GAN에서는 이러한 문제가 더 크게 발생하는 경향이 있다.

 

  실제로 필자가 GAN을 학습할 때는 상당수의 데이터셋에서 batch size를 64 이하로 설정하는 것이 결과적으로 성능이 좋았다. 필자는 예전에 분류 모델을 학습할 때 batch size를 증가시키고, 그만큼 learning rate을 크게 설정했던 것을 기억하여 GAN에서도 동일하게 적용해 보았으나, GAN에서는 batch size를 키웠을 때 성능 하락이 발생하는 문제를 자주 경험했다. 당연히 데이터셋마다 다르지만, 많은 경우에서 batch size가 64 이하일 때 합리적인 성능이 나왔다. 그렇기 때문에 사실상 multi GPU를 활용하기 어려운 경우가 많았다.

 

  또한 가끔은 잘 동작하는 코드에서 오직 multi GPU 설정만 적용했을 때에도 성능이 떨어지는 현상을 경험한 적이 있다. 다시 말해 batch size나 learning rate 등을 그대로 쓰는 상황에서 단지 multi GPU만 적용했을 뿐인데도 문제가 발생했다. 정확한 이유는 잘 모르겠지만, DCGAN처럼 convolution layer를 활용하는 경우에는 문제가 덜했다. 사실 batch size가 작을 때는 multi GPU를 사용할 이유가 없기 때문에, 단순히 1개의 GPU에서 작은 batch size로 학습을 진행하는 게 학습상의 효율성이 컸던 경우가 많다.

 

※ Conditional GAN ※

 

  학습이 잘 안 되거나 mode collapse의 문제가 발생하는 경우 conditional GAN을 활용하는 것이 GAN의 성능을 높이는 데에 도움을 주기도 한다. 기본적으로 GAN은 학습 난이도가 높기 때문에, 추가적인 guide가 존재하는 경우 성능이 향상될 수 있다. 따라서 단순히 데이터만 있는 것이 아니라, 레이블 정보가 같이 있다면 레이블 정보도 활용해 보는 것이 좋다. 필자의 경우 실제로 MNIST를 학습할 때, 단순히 GAN을 쓰는 것보다 conditional GAN을 쓸 때 mode collapse의 문제가 덜하고, 생성하고자 하는 이미지를 쉽게 컨트롤할 수 있었던 기억이 있다.

 

출처: https://arxiv.org/abs/1411.1784

 

  실제로 conditional GAN을 이용해 MNIST를 학습하면 다음과 같이 각 레이블(label)마다 이미지를 생성할 수 있게 된 것을 확인할 수 있다. 모드(mode)를 컨트롤할 수 있도록 해줌으로써, 상대적으로 수월하게 학습할 수 있을 뿐더러 학습 이후에 원하는 레이블의 숫자 이미지를 샘플링할 수 있다.

 

 

※ GAN의 Loss ※

 

  GAN에는 정말 다양한 종류의 손실(loss) 함수가 있다. 많은 논문에서 일반적으로 사용하는 loss로는 WGAN-GP가 있다. 기본적인 original GAN loss와 비교하여 WGAN-GP는 더 좋은 결과를 낼 때가 많다. 하지만 데이터셋에 따라서 더 좋은 loss 세팅이 존재하는 경우가 있다. 실제로 어떤 GAN 아키텍처를 사용하느냐에 따라서 WGAN-GP보다 오히려 LSGAN이나 일반적인 GAN loss가 더 좋은 FID를 보이기도 한다. 따라서 데이터셋마다 많은 시간을 투자해 하이퍼 파라미터(learning rate, optimizer 등)를 튜닝해보는 것이 좋다.

 

  필자의 경우 흑백 사진 데이터셋에 대하여 MLP 아키텍처를 사용할 때, 기본적인 original GAN loss보다 WGAN-GP 손실을 사용할 때 훨씬 성능이 개선되었다. 반면에 DCGAN의 경우에는 WGAN-GP를 사용할 때보다 LSGAN을 사용할 때가 오히려 성능이 더 좋았던 경험이 있다. 그래서 WGAN-GP를 무작정 사용하기보다는, 잘 동작한다고 알려진 네트워크를 사용할 필요가 있다. 예를 들어 ResNet 기반의 네트워크가 WGAN-GP와 함께 쓰이는 것을 자주 확인할 수 있다. 또한 CelebA나 FFHQ와 같은 데이터셋은 WGAN-GP loss가 효과적이라고 알려져 있다.

 

  또한 WGAP-GP의 경우 판별자(discriminator)를 여러 번 업데이트하고, 생성자(generator)를 한 번 업데이트하는 방식을 자주 사용한다. 그래서 실제로 학습이 되는 것을 보면서 d를 업데이트하는 횟수를 설정할 필요가 있다. 만약에 D를 너무 많이 업데이트하고 (n_critic이 너무 높은 경우), G를 한 번 업데이트하는 경우 D가 상대적으로 너무 강해져서, 진짜/가짜 이미지를 너무 쉽게 판별하게 될 수 있다. 이 경우 G또한 더 이상 정상적으로 업데이트되지 못한다. 배치 사이즈가 클 때도 D가 너무 강해지는 경향이 있으며, G와 D의 네트워크 capacity 밸런스 또한 맞추어 줄 필요가 있다.

 

출처: https://arxiv.org/abs/1704.00028

 

  그리고 GAN loss마다 사용하면 안 되는 레이어 유형이 존재하기도 한다. 예를 들어 WGAN-GP의 경우 discriminator에서 batch normalization 대신에 instance normalization을 사용하는 편이 더 유리하다. 실제로 batch normalization이 포함된다는 점에서 수렴이 잘 안 되어 노이즈와 같은 결과 이미지만 내보내던 GAN 네트워크에 대하여 batch normalization을 제거하고, 다른 normalization 레이어를 추가했을 때 비로소 수렴하여 좋은 이미지를 만드는 것을 확인했던 경험이 있다.

 

출처: https://towardsdatascience.com/gan-objective-functions-gans-and-their-variations-ad77340bce3c

 

  또한 GAN을 활용한 image-to-image translation 기법에서는 ground-truth 이미지와 유사해질 수 있도록 하기 위해서 L1을 활용한 loss가 자주 사용된다. 일반적으로 L1이나 L2 모두 blurry한 결과를 만들 수 있다는 단점이 있지만, 그나마 L1을 사용하면서 GAN과 함께 적용했을 때 결과가 조금 더 선명하게 나온다는 점을 언급한 논문들(Pix2Pix 등)이 있다.


※ 최적화 ※

  학습이 잘 되지 않거나, mode collapse가 발생하는 경우 학습률(learning rate)을 조절해 볼 필요가 있다. 또한 일반적으로 GAN에서는 Adam optimizer가 많이 사용된다. 하지만 이 또한 어떠한 loss를 사용하는지에 따라서 조금씩 차이가 있다. 그리고 충분히 수렴한 뒤에는 epoch을 증가시켜도 생성되는 이미지의 quality가 더 이상 개선되지 않는다. 따라서 epoch이 증가함에 따라서 FID를 계산해보는 방식으로 학습의 진행 과정을 확인해 볼 수도 있다.

 

※ Feature Matching ※

 

  feature matching을 사용해 GAN 모델의 불안정성을 완화하고 좋은 결과를 낼 수 있다.

728x90
반응형

Comment +0