기타

Instance Normalization를 NumPy 및 PyTorch로 구현하는 방법!

안경잡이개발자 2021. 8. 18. 23:34
728x90
반응형

  Instance NormalizationSingle Feature Map에 대하여 Normalization을 수행하는 기법이다. Instance Normalization은 Style Transfer나 StyleGAN과 같이 다양한 기술 및 아키텍처에서 활용되기 때문에, 알아 두면 상당히 좋다.

 

※ Batch Normalization ※

 

  먼저 Batch Normalization에 대해 알아보자. 매우 많은 딥러닝 네트워크에서 활용되고 있는 정규화 레이어이며, Batch Normalization은 채널별로 현재 배치에 포함된 모든 이미지에 대하여 정규화를 수행한다. 그래서 각 채널별로 mean과 variance를 계산하는 것을 확인할 수 있다.

 

 

※ Instance Normalization ※

 

  Instance Normalization은 개별적인 이미지에 대하여 채널별로 정규화를 수행한다는 점이 특징이다. Batch Normalization과 상당히 유사하지만, Instance Normalization은 개별적인 이미지 인스턴스(instance) 단위로 채널별로 정규화를 수행한다.

 

 

※ NumPy로 Instance Normalization 구현하기 ※

 

  Numpy로 Instance Normalization을 구현해보자. 아래 코드를 보면, 각 이미지에 대하여 채널별로 variance와 mean을 계산하는 것을 알 수 있다.

 

import numpy as np


def calc_mean_std(feat, eps=1e-5):
    n, c, h, w = feat.shape
    feat_var = np.var(feat.reshape(n, c, -1), axis=2) + eps
    feat_std = np.sqrt(feat_var).reshape(n, c, 1, 1)
    feat_mean = np.mean(feat.reshape(n, c, -1), axis=2).reshape(n, c, 1, 1)
    
    return feat_mean, feat_std


X = np.asarray([
# image 1
[
    [[1, 2, 9, 2, 7],
    [5, 0, 3, 1, 8],
    [4, 1, 3, 0, 6],
    [2, 5, 2, 9, 5],
    [6, 5, 1, 3, 2]],

    [[4, 5, 7, 0, 8],
    [5, 8, 5, 3, 5],
    [4, 2, 1, 6, 5],
    [7, 3, 2, 1, 0],
    [6, 1, 2, 2, 6]],

    [[3, 7, 4, 5, 0],
    [5, 4, 6, 8, 9],
    [6, 1, 9, 1, 6],
    [9, 3, 0, 2, 4],
    [1, 2, 5, 5, 2]]
],
# image 2
[
    [[7, 2, 1, 4, 2],
    [5, 4, 6, 5, 0],
    [1, 2, 4, 2, 8],
    [5, 9, 0, 5, 1],
    [7, 6, 2, 4, 6]],

    [[5, 4, 2, 5, 7],
    [6, 1, 4, 0, 5],
    [8, 9, 4, 7, 6],
    [4, 5, 5, 6, 7],
    [1, 2, 7, 4, 1]],

    [[7, 4, 8, 9, 7],
    [5, 5, 8, 1, 4],
    [3, 2, 2, 5, 2],
    [1, 0, 3, 7, 6],
    [4, 5, 4, 5, 5]]
]
])
print('Images:', X.shape)
# size = X.shape

feat_mean, feat_std = calc_mean_std(X)
print('Std:', feat_std.shape)
print('Mean:', feat_mean.shape)

out = (X - feat_mean) / feat_std
# We can get the same result by the below code.
# out = (X - np.broadcast_to(feat_mean, size)) / np.broadcast_to(feat_std, size)
print(out)

 

  실행 결과는 다음과 같다.

 

Images: (2, 3, 5, 5)
Std: (2, 3, 1, 1)
Mean: (2, 3, 1, 1)
[[[[-1.01167305 -0.6341831   2.0082465  -0.6341831   1.25326661]
   [ 0.49828672 -1.38916299 -0.25669316 -1.01167305  1.63075655]
   [ 0.12079678 -1.01167305 -0.25669316 -1.38916299  0.87577667]
   [-0.6341831   0.49828672 -0.6341831   2.0082465   0.49828672]
   [ 0.87577667  0.49828672 -1.01167305 -0.25669316 -0.6341831 ]]

  [[ 0.03335184  0.45024982  1.28404578 -1.63424008  1.70094375]
   [ 0.45024982  1.70094375  0.45024982 -0.38354614  0.45024982]
   [ 0.03335184 -0.80044412 -1.2173421   0.8671478   0.45024982]
   [ 1.28404578 -0.38354614 -0.80044412 -1.2173421  -1.63424008]
   [ 0.8671478  -1.2173421  -0.80044412 -0.80044412  0.8671478 ]]

  [[-0.46796399  0.99442348 -0.10236712  0.26322975 -1.5647546 ]
   [ 0.26322975 -0.10236712  0.62882661  1.36002035  1.72561722]
   [ 0.62882661 -1.19915773  1.72561722 -1.19915773  0.62882661]
   [ 1.72561722 -0.46796399 -1.5647546  -0.83356086 -0.10236712]
   [-1.19915773 -0.83356086  0.26322975  0.26322975 -0.83356086]]]


 [[[ 1.24161152 -0.77399159 -1.17711222  0.03224965 -0.77399159]
   [ 0.43537027  0.03224965  0.83849089  0.43537027 -1.58023284]
   [-1.17711222 -0.77399159  0.03224965 -0.77399159  1.64473214]
   [ 0.43537027  2.04785276 -1.58023284  0.43537027 -1.17711222]
   [ 1.24161152  0.83849089 -0.77399159  0.03224965  0.83849089]]

  [[ 0.17149843 -0.25724764 -1.11473978  0.17149843  1.02899057]
   [ 0.6002445  -1.54348585 -0.25724764 -1.97223192  0.17149843]
   [ 1.45773663  1.8864827  -0.25724764  1.02899057  0.6002445 ]
   [-0.25724764  0.17149843  0.17149843  0.6002445   1.02899057]
   [-1.54348585 -1.11473978  1.02899057 -0.25724764 -1.54348585]]

  [[ 1.07948803 -0.20561677  1.50785629  1.93622455  1.07948803]
   [ 0.2227515   0.2227515   1.50785629 -1.49072156 -0.20561677]
   [-0.63398503 -1.06235329 -1.06235329  0.2227515  -1.06235329]
   [-1.49072156 -1.91908982 -0.63398503  1.07948803  0.65111976]
   [-0.20561677  0.2227515  -0.20561677  0.2227515   0.2227515 ]]]]

 

※ PyTorch로 Instance Normalization 구현하기 ※

 

  이어서 PyTorch로 Instance Normalization을 구현하는 방법은 다음과 같다. 위 코드와 같은 로직이다.

 

import numpy as np
import torch


def calc_mean_std(feat, eps=1e-5):
    n, c, h, w = feat.shape
    feat_var = feat.view(n, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(n, c, 1, 1)
    feat_mean = feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1)

    return feat_mean, feat_std


X = torch.from_numpy(np.asarray([
# image 1
[
    [[1, 2, 9, 2, 7],
    [5, 0, 3, 1, 8],
    [4, 1, 3, 0, 6],
    [2, 5, 2, 9, 5],
    [6, 5, 1, 3, 2]],

    [[4, 5, 7, 0, 8],
    [5, 8, 5, 3, 5],
    [4, 2, 1, 6, 5],
    [7, 3, 2, 1, 0],
    [6, 1, 2, 2, 6]],

    [[3, 7, 4, 5, 0],
    [5, 4, 6, 8, 9],
    [6, 1, 9, 1, 6],
    [9, 3, 0, 2, 4],
    [1, 2, 5, 5, 2]]
],
# image 2
[
    [[7, 2, 1, 4, 2],
    [5, 4, 6, 5, 0],
    [1, 2, 4, 2, 8],
    [5, 9, 0, 5, 1],
    [7, 6, 2, 4, 6]],

    [[5, 4, 2, 5, 7],
    [6, 1, 4, 0, 5],
    [8, 9, 4, 7, 6],
    [4, 5, 5, 6, 7],
    [1, 2, 7, 4, 1]],

    [[7, 4, 8, 9, 7],
    [5, 5, 8, 1, 4],
    [3, 2, 2, 5, 2],
    [1, 0, 3, 7, 6],
    [4, 5, 4, 5, 5]]
]
], dtype=np.float64))
print('Images:', X.shape)
# size = X.shape

feat_mean, feat_std = calc_mean_std(X)
print('Mean:', feat_mean.shape)
print('Std:', feat_std.shape)

out = (X - feat_mean) / feat_std
# out = (X - feat_mean.expand(size)) / feat_std.expand(size)
print(out)

 

  실행 결과는 다음과 같다.

 

Images: torch.Size([2, 3, 5, 5])
Mean: torch.Size([2, 3, 1, 1])
Std: torch.Size([2, 3, 1, 1])
tensor([[[[-0.9912, -0.6214,  1.9677, -0.6214,  1.2279],
          [ 0.4882, -1.3611, -0.2515, -0.9912,  1.5978],
          [ 0.1184, -0.9912, -0.2515, -1.3611,  0.8581],
          [-0.6214,  0.4882, -0.6214,  1.9677,  0.4882],
          [ 0.8581,  0.4882, -0.9912, -0.2515, -0.6214]],

         [[ 0.0327,  0.4412,  1.2581, -1.6012,  1.6666],
          [ 0.4412,  1.6666,  0.4412, -0.3758,  0.4412],
          [ 0.0327, -0.7843, -1.1927,  0.8496,  0.4412],
          [ 1.2581, -0.3758, -0.7843, -1.1927, -1.6012],
          [ 0.8496, -1.1927, -0.7843, -0.7843,  0.8496]],

         [[-0.4585,  0.9743, -0.1003,  0.2579, -1.5331],
          [ 0.2579, -0.1003,  0.6161,  1.3325,  1.6908],
          [ 0.6161, -1.1749,  1.6908, -1.1749,  0.6161],
          [ 1.6908, -0.4585, -1.5331, -0.8167, -0.1003],
          [-1.1749, -0.8167,  0.2579,  0.2579, -0.8167]]],


        [[[ 1.2165, -0.7584, -1.1533,  0.0316, -0.7584],
          [ 0.4266,  0.0316,  0.8215,  0.4266, -1.5483],
          [-1.1533, -0.7584,  0.0316, -0.7584,  1.6115],
          [ 0.4266,  2.0065, -1.5483,  0.4266, -1.1533],
          [ 1.2165,  0.8215, -0.7584,  0.0316,  0.8215]],

         [[ 0.1680, -0.2521, -1.0922,  0.1680,  1.0082],
          [ 0.5881, -1.5123, -0.2521, -1.9324,  0.1680],
          [ 1.4283,  1.8484, -0.2521,  1.0082,  0.5881],
          [-0.2521,  0.1680,  0.1680,  0.5881,  1.0082],
          [-1.5123, -1.0922,  1.0082, -0.2521, -1.5123]],

         [[ 1.0577, -0.2015,  1.4774,  1.8971,  1.0577],
          [ 0.2183,  0.2183,  1.4774, -1.4606, -0.2015],
          [-0.6212, -1.0409, -1.0409,  0.2183, -1.0409],
          [-1.4606, -1.8803, -0.6212,  1.0577,  0.6380],
          [-0.2015,  0.2183, -0.2015,  0.2183,  0.2183]]]],
       dtype=torch.float64)
728x90
반응형