Instance Normalization을 NumPy 및 PyTorch로 구현하는 방법!
Instance Normalization은 이미지 한 장(instance)의 각 채널, 즉 단일 feature map 단위로 정규화(normalization)를 수행하는 기법이다. Instance Normalization은 Style Transfer나 StyleGAN 등 다양한 기술과 아키텍처에서 활용되기 때문에, 알아 두면 상당히 유용하다.
※ Batch Normalization ※
먼저 Batch Normalization에 대해 알아보자. Batch Normalization은 매우 많은 딥러닝 네트워크에서 활용되고 있는 정규화 레이어로, 채널별로 현재 배치에 포함된 모든 이미지에 대하여 정규화를 수행한다. 즉 입력 텐서의 형태가 (N, C, H, W)라면 N, H, W 축에 걸쳐 평균과 분산을 구하므로, 채널마다 하나씩 총 C개의 mean과 variance가 계산된다.
※ Instance Normalization ※
Instance Normalization은 개별 이미지에 대하여 채널별로 정규화를 수행한다는 점이 특징이다. Batch Normalization과 상당히 유사하지만, 배치 전체가 아니라 개별 이미지 인스턴스(instance) 단위로 채널별 정규화를 수행한다. 즉 입력 텐서의 형태가 (N, C, H, W)라면 H, W 축에 대해서만 평균과 분산을 구한다. 따라서 이미지와 채널의 조합마다 하나씩, 총 N×C개의 mean과 variance가 계산된다.
※ NumPy로 Instance Normalization 구현하기 ※
NumPy로 Instance Normalization을 구현해 보자. 아래 코드를 보면, 각 이미지에 대하여 채널별로 variance와 mean을 계산하는 것을 알 수 있다. 분산에는 작은 값 eps를 더한 뒤 제곱근을 취하여 0으로 나누는 것을 방지한다. 참고로 np.var()는 기본적으로 N으로 나누는 (편향) 분산을 계산한다.
import numpy as np
def calc_mean_std(feat, eps=1e-5):
    """Compute the per-instance, per-channel mean and standard deviation.

    Statistics are taken over all spatial positions of each channel of each
    sample independently, i.e. the statistics Instance Normalization uses.

    Args:
        feat: NumPy array of shape (N, C, *spatial), e.g. (N, C, H, W) for a
            batch of images. Any number (>= 1) of trailing spatial axes is
            accepted.
        eps: Small constant added to the variance before the square root so
            that a constant channel does not cause a division by zero.

    Returns:
        (feat_mean, feat_std): two arrays of shape (N, C, 1, ..., 1) with
        one singleton axis per spatial axis of ``feat`` ((N, C, 1, 1) for
        4-D input), so they broadcast directly against ``feat``.
        ``feat_std`` is based on the biased (divide-by-N) variance, which
        is ``np.var``'s default (``ddof=0``).

    Raises:
        ValueError: if ``feat`` has fewer than 3 dimensions.
    """
    if feat.ndim < 3:
        raise ValueError(
            f"expected an array of shape (N, C, *spatial), got shape {feat.shape}"
        )
    n, c = feat.shape[:2]
    # One singleton axis per spatial axis keeps the result broadcastable.
    stat_shape = (n, c) + (1,) * (feat.ndim - 2)
    # Flatten all spatial axes so each (sample, channel) pair is one row.
    flat = feat.reshape(n, c, -1)
    feat_var = np.var(flat, axis=2) + eps
    feat_std = np.sqrt(feat_var).reshape(stat_shape)
    feat_mean = np.mean(flat, axis=2).reshape(stat_shape)
    return feat_mean, feat_std
# Two 3-channel 5x5 "images" stacked as shape (N=2, C=3, H=5, W=5).
# The dtype is made explicit; the integer literals would otherwise produce
# an integer array.
X = np.asarray([
    # image 1
    [
        [[1, 2, 9, 2, 7],
         [5, 0, 3, 1, 8],
         [4, 1, 3, 0, 6],
         [2, 5, 2, 9, 5],
         [6, 5, 1, 3, 2]],
        [[4, 5, 7, 0, 8],
         [5, 8, 5, 3, 5],
         [4, 2, 1, 6, 5],
         [7, 3, 2, 1, 0],
         [6, 1, 2, 2, 6]],
        [[3, 7, 4, 5, 0],
         [5, 4, 6, 8, 9],
         [6, 1, 9, 1, 6],
         [9, 3, 0, 2, 4],
         [1, 2, 5, 5, 2]]
    ],
    # image 2
    [
        [[7, 2, 1, 4, 2],
         [5, 4, 6, 5, 0],
         [1, 2, 4, 2, 8],
         [5, 9, 0, 5, 1],
         [7, 6, 2, 4, 6]],
        [[5, 4, 2, 5, 7],
         [6, 1, 4, 0, 5],
         [8, 9, 4, 7, 6],
         [4, 5, 5, 6, 7],
         [1, 2, 7, 4, 1]],
        [[7, 4, 8, 9, 7],
         [5, 5, 8, 1, 4],
         [3, 2, 2, 5, 2],
         [1, 0, 3, 7, 6],
         [4, 5, 4, 5, 5]]
    ]
], dtype=np.float64)
print('Images:', X.shape)
feat_mean, feat_std = calc_mean_std(X)
print('Std:', feat_std.shape)
print('Mean:', feat_mean.shape)
# Normalize every (image, channel) feature map independently.
# feat_mean / feat_std have shape (N, C, 1, 1) and broadcast over H and W,
# so an explicit np.broadcast_to(..., X.shape) is unnecessary.
out = (X - feat_mean) / feat_std
print(out)
실행 결과는 다음과 같다.
Images: (2, 3, 5, 5)
Std: (2, 3, 1, 1)
Mean: (2, 3, 1, 1)
[[[[-1.01167305 -0.6341831 2.0082465 -0.6341831 1.25326661]
[ 0.49828672 -1.38916299 -0.25669316 -1.01167305 1.63075655]
[ 0.12079678 -1.01167305 -0.25669316 -1.38916299 0.87577667]
[-0.6341831 0.49828672 -0.6341831 2.0082465 0.49828672]
[ 0.87577667 0.49828672 -1.01167305 -0.25669316 -0.6341831 ]]
[[ 0.03335184 0.45024982 1.28404578 -1.63424008 1.70094375]
[ 0.45024982 1.70094375 0.45024982 -0.38354614 0.45024982]
[ 0.03335184 -0.80044412 -1.2173421 0.8671478 0.45024982]
[ 1.28404578 -0.38354614 -0.80044412 -1.2173421 -1.63424008]
[ 0.8671478 -1.2173421 -0.80044412 -0.80044412 0.8671478 ]]
[[-0.46796399 0.99442348 -0.10236712 0.26322975 -1.5647546 ]
[ 0.26322975 -0.10236712 0.62882661 1.36002035 1.72561722]
[ 0.62882661 -1.19915773 1.72561722 -1.19915773 0.62882661]
[ 1.72561722 -0.46796399 -1.5647546 -0.83356086 -0.10236712]
[-1.19915773 -0.83356086 0.26322975 0.26322975 -0.83356086]]]
[[[ 1.24161152 -0.77399159 -1.17711222 0.03224965 -0.77399159]
[ 0.43537027 0.03224965 0.83849089 0.43537027 -1.58023284]
[-1.17711222 -0.77399159 0.03224965 -0.77399159 1.64473214]
[ 0.43537027 2.04785276 -1.58023284 0.43537027 -1.17711222]
[ 1.24161152 0.83849089 -0.77399159 0.03224965 0.83849089]]
[[ 0.17149843 -0.25724764 -1.11473978 0.17149843 1.02899057]
[ 0.6002445 -1.54348585 -0.25724764 -1.97223192 0.17149843]
[ 1.45773663 1.8864827 -0.25724764 1.02899057 0.6002445 ]
[-0.25724764 0.17149843 0.17149843 0.6002445 1.02899057]
[-1.54348585 -1.11473978 1.02899057 -0.25724764 -1.54348585]]
[[ 1.07948803 -0.20561677 1.50785629 1.93622455 1.07948803]
[ 0.2227515 0.2227515 1.50785629 -1.49072156 -0.20561677]
[-0.63398503 -1.06235329 -1.06235329 0.2227515 -1.06235329]
[-1.49072156 -1.91908982 -0.63398503 1.07948803 0.65111976]
[-0.20561677 0.2227515 -0.20561677 0.2227515 0.2227515 ]]]]
※ PyTorch로 Instance Normalization 구현하기 ※
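이어서 PyTorch로 Instance Normalization을 구현하는 방법은 다음과 같다. 위 코드와 거의 같은 로직이지만 한 가지 차이가 있다. PyTorch의 Tensor.var()는 기본값(unbiased=True)으로 N-1로 나누는 불편 분산을 계산하는 반면, np.var()는 N으로 나눈다. 그래서 아래 실행 결과는 NumPy 결과보다 절댓값이 약간 작다(5×5 feature map에서는 √(24/25)배). nn.InstanceNorm2d와 동일하게 N으로 나누는 분산을 사용하려면 var(dim=2, unbiased=False)를 사용하면 된다.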
import numpy as np
import torch
def calc_mean_std(feat, eps=1e-5, *, unbiased=True):
    """Compute the per-instance, per-channel mean and standard deviation.

    Statistics are taken over all spatial positions of each channel of each
    sample independently, i.e. the statistics Instance Normalization uses.

    Args:
        feat: Tensor of shape (N, C, *spatial), e.g. (N, C, H, W). Any number
            (>= 1) of trailing spatial dims is accepted, and the tensor does
            not need to be contiguous.
        eps: Small constant added to the variance before the square root so
            that a constant channel does not cause a division by zero.
        unbiased: If True (the default, kept for backward compatibility),
            use the N-1 variance, which is ``Tensor.var``'s default; a map
            with a single spatial position then yields NaN. Pass False for
            the divide-by-N variance used by ``np.var`` and
            ``torch.nn.InstanceNorm2d``.

    Returns:
        (feat_mean, feat_std): tensors of shape (N, C, 1, ..., 1) with one
        singleton dim per spatial dim of ``feat`` ((N, C, 1, 1) for 4-D
        input), so they broadcast directly against ``feat``.

    Raises:
        ValueError: if ``feat`` has fewer than 3 dimensions.
    """
    if feat.dim() < 3:
        raise ValueError(
            f"expected a tensor of shape (N, C, *spatial), got shape {tuple(feat.shape)}"
        )
    n, c = feat.shape[:2]
    # One singleton dim per spatial dim keeps the result broadcastable.
    stat_shape = (n, c) + (1,) * (feat.dim() - 2)
    # reshape (not view): view raises on non-contiguous inputs such as the
    # result of permute()/transpose().
    flat = feat.reshape(n, c, -1)
    feat_var = flat.var(dim=2, unbiased=unbiased) + eps
    feat_std = feat_var.sqrt().reshape(stat_shape)
    feat_mean = flat.mean(dim=2).reshape(stat_shape)
    return feat_mean, feat_std
# The same two 3-channel 5x5 images as the NumPy example, as a float64
# tensor of shape (N=2, C=3, H=5, W=5).
X = torch.from_numpy(np.asarray([
    # image 1
    [
        [[1, 2, 9, 2, 7],
         [5, 0, 3, 1, 8],
         [4, 1, 3, 0, 6],
         [2, 5, 2, 9, 5],
         [6, 5, 1, 3, 2]],
        [[4, 5, 7, 0, 8],
         [5, 8, 5, 3, 5],
         [4, 2, 1, 6, 5],
         [7, 3, 2, 1, 0],
         [6, 1, 2, 2, 6]],
        [[3, 7, 4, 5, 0],
         [5, 4, 6, 8, 9],
         [6, 1, 9, 1, 6],
         [9, 3, 0, 2, 4],
         [1, 2, 5, 5, 2]]
    ],
    # image 2
    [
        [[7, 2, 1, 4, 2],
         [5, 4, 6, 5, 0],
         [1, 2, 4, 2, 8],
         [5, 9, 0, 5, 1],
         [7, 6, 2, 4, 6]],
        [[5, 4, 2, 5, 7],
         [6, 1, 4, 0, 5],
         [8, 9, 4, 7, 6],
         [4, 5, 5, 6, 7],
         [1, 2, 7, 4, 1]],
        [[7, 4, 8, 9, 7],
         [5, 5, 8, 1, 4],
         [3, 2, 2, 5, 2],
         [1, 0, 3, 7, 6],
         [4, 5, 4, 5, 5]]
    ]
], dtype=np.float64))
print('Images:', X.shape)
feat_mean, feat_std = calc_mean_std(X)
print('Mean:', feat_mean.shape)
print('Std:', feat_std.shape)
# feat_mean / feat_std have shape (N, C, 1, 1) and broadcast over H and W,
# so an explicit .expand(X.shape) is unnecessary.
# NOTE: Tensor.var() defaults to the unbiased (N - 1) estimator. That
# default gives a slightly larger std than np.var (ddof=0), so the values
# printed here are smaller in magnitude than the NumPy output by a factor of
# sqrt(24/25) for 5x5 feature maps.
out = (X - feat_mean) / feat_std
print(out)
실행 결과는 다음과 같다.
Images: torch.Size([2, 3, 5, 5])
Mean: torch.Size([2, 3, 1, 1])
Std: torch.Size([2, 3, 1, 1])
tensor([[[[-0.9912, -0.6214, 1.9677, -0.6214, 1.2279],
[ 0.4882, -1.3611, -0.2515, -0.9912, 1.5978],
[ 0.1184, -0.9912, -0.2515, -1.3611, 0.8581],
[-0.6214, 0.4882, -0.6214, 1.9677, 0.4882],
[ 0.8581, 0.4882, -0.9912, -0.2515, -0.6214]],
[[ 0.0327, 0.4412, 1.2581, -1.6012, 1.6666],
[ 0.4412, 1.6666, 0.4412, -0.3758, 0.4412],
[ 0.0327, -0.7843, -1.1927, 0.8496, 0.4412],
[ 1.2581, -0.3758, -0.7843, -1.1927, -1.6012],
[ 0.8496, -1.1927, -0.7843, -0.7843, 0.8496]],
[[-0.4585, 0.9743, -0.1003, 0.2579, -1.5331],
[ 0.2579, -0.1003, 0.6161, 1.3325, 1.6908],
[ 0.6161, -1.1749, 1.6908, -1.1749, 0.6161],
[ 1.6908, -0.4585, -1.5331, -0.8167, -0.1003],
[-1.1749, -0.8167, 0.2579, 0.2579, -0.8167]]],
[[[ 1.2165, -0.7584, -1.1533, 0.0316, -0.7584],
[ 0.4266, 0.0316, 0.8215, 0.4266, -1.5483],
[-1.1533, -0.7584, 0.0316, -0.7584, 1.6115],
[ 0.4266, 2.0065, -1.5483, 0.4266, -1.1533],
[ 1.2165, 0.8215, -0.7584, 0.0316, 0.8215]],
[[ 0.1680, -0.2521, -1.0922, 0.1680, 1.0082],
[ 0.5881, -1.5123, -0.2521, -1.9324, 0.1680],
[ 1.4283, 1.8484, -0.2521, 1.0082, 0.5881],
[-0.2521, 0.1680, 0.1680, 0.5881, 1.0082],
[-1.5123, -1.0922, 1.0082, -0.2521, -1.5123]],
[[ 1.0577, -0.2015, 1.4774, 1.8971, 1.0577],
[ 0.2183, 0.2183, 1.4774, -1.4606, -0.2015],
[-0.6212, -1.0409, -1.0409, 0.2183, -1.0409],
[-1.4606, -1.8803, -0.6212, 1.0577, 0.6380],
[-0.2015, 0.2183, -0.2015, 0.2183, 0.2183]]]],
dtype=torch.float64)