-
pytorch-4 stride conv, dilated conv, 가중치 행렬 시각화딥러닝/pytorch 2023. 7. 9. 15:18
1. stride conv
import torch import os from torch import nn import matplotlib.pyplot as plt os.environ["KMP_DUPLICATE_LIB_OK"] = "True" input_data = torch.randn(1, 1, 28, 28) conv = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=2, padding=1) output = conv(input_data) print(output.shape) plt.subplot(1, 2, 1) plt.imshow(input_data.squeeze(), cmap="gray") plt.title("input") plt.subplot(1, 2, 2) plt.imshow(output.squeeze().detach().numpy()[0], cmap="gray") plt.title("output") plt.tight_layout() plt.show()
1 채널의 28*28 이미지를 랜덤 하게 생성하고 nn.conv2d를 사용하여 스트라이드를 2칸으로 적용한다. conv에 데이터를 넣어주고 stride를 2칸씩 적용하기 이전과 이후의 이미지를 출력하여 비교하면 더욱 데이터가 옅어진 것을 확인할 수 있다.
2. dilated conv
from torch import nn import matplotlib.pyplot as plt from PIL import Image from torchvision import transforms import os os.environ["KMP_DUPLICATE_LIB_OK"] = "True" image_path = "./data/surprised_cat.jpg" image = Image.open(image_path).convert("L") input_data = transforms.ToTensor()(image).unsqueeze(0) conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, dilation=2) output = conv(input_data) plt.subplot(1, 2, 1) plt.imshow(image, cmap="gray") plt.title("input image") plt.subplot(1, 2, 2) plt.imshow(output.squeeze().detach().numpy(), cmap="gray") plt.title("output image") plt.tight_layout() plt.show()
dilattion 같을 조절 하여 이전 이미지와 비교해서 출력하면 마찬가지로 더 옅어진 데이터를 표현한다.
3. 가중치 행렬 시각화
import matplotlib.pyplot as plt import torch.nn as nn import os os.environ["KMP_DUPLICATE_LIB_OK"] = "True" input_size = 4 output_size = 2 dense_layer = nn.Linear(input_size, output_size) weights = dense_layer.weight.detach().numpy() plt.figure(figsize=(10, 6)) plt.imshow(weights, cmap="coolwarm", aspect="auto") plt.colorbar() plt.show()
nn.linear로 입력값 4, 출력값 2의 선형변환하는 연산을 수행한다. detach를 사용하여 가중치를 분리하고 numpy로 변환하여 그림으로 표시한다.
x축은 입력값이고 y축은 출력값이다. 입력값과 출력값의 가중치 행렬의 크기를 그래프로 확인할 수 있으며 빨간색으로 갈수록 높은 상관 값을 가지고 있다.
'딥러닝 > pytorch' 카테고리의 다른 글
pytorch-6 과적합 방지 방법 (0) 2023.07.10 pytorch-5 CNN (0) 2023.07.09 pytorch-3 ANN, RBM (0) 2023.07.09 pytorch-2 데이터로더 (0) 2023.07.03