20210731 TIL
2021.07.31
matrix_multiplication
x = torch.FloatTensor([[1, 2],
[3, 4],
[5, 6]])
y = torch.FloatTensor([[1, 2],
[1, 2]])
xx = torch.FloatTensor([1,2])
x * xx
x * xx를 계산하면 broadcasting이 되어 계산된다. 수학적인 행렬의 곱을 하려면 단순하게 곱하기가 아니라, torch.mm(x, y) 를 해야한다. 정말 수학에서 처럼 행렬의 사이즈가 맞아야만 하는데 이를 편하게 내부적으로 곱셉을 하게 만들어 주는 함수가 matmul() 이다. 다만, 행렬의 사이즈가 곱셈하기에 맞는 것이 아니더라도 알아서 잘 바꾸어서 계산해주기 때문에, 내가 어떤 형태의 행렬 또는 텐서를 넣어서 계산하는지 파악이 잘 안될 수 있다는 단점이 있다.
z = torch.matmul(x, y)
xxx = torch.FloatTensor([[1],[2]])
print(x.size())
print(xx.size())
print(xxx.size())
print(torch.matmul(x, xx))
print(torch.matmul(x, xxx))
torch.Size([3, 2]) torch.Size([2]) torch.Size([2, 1]) tensor([ 5., 11., 17.]) tensor([[ 5.], [11.], [17.]])
bmm
x = torch.FloatTensor([[[1, 2],
[3, 4],
[5, 6]],
[[7, 8],
[9, 10],
[11, 12]],
[[13, 14],
[15, 16],
[17, 18]]])
y = torch.FloatTensor([[[1, 2, 2],
[1, 2, 2]],
[[1, 3, 3],
[1, 3, 3]],
[[1, 4, 4],
[1, 4, 4]]])
(3, 3, 2) * (3, 2, 3) 을 계산하는 예제인데, (3, 2) 행렬, (2, 3) 행렬를 각 3개씩 batch 로 계산한다는 의미 이다. mm 처럼 수학에서의 행렬 곱셈 처럼 짝이 딱 맞아야만 실행된다.
(batch, p, n) * (batch, n, m) 형태여야 계산 가능. 결과는 (batch, p, m) 형태
위의 상황과 마찬가지로 matmul로 계산하느냐 mm 혹은 bmm으로 계산하느냐의 차이가 있으므로 거기서 기인하는 문제는 인식하면서 사용을 해야된다.
z = torch.bmm(x, y)
(3, 3, 2)와 (3, 2, 3)의 결과이므로 (3, 3, 3)의 형태를 얻게 된다.
댓글을 작성해보세요.