forward에서 h_t
219
작성한 질문수 1
BasicRNN의 forward함수에서 self.rnn(x, h_0)로부터 나온 아웃풋인 x[:,-1,:]를 h_t로 설정하셨는데,
여기서 [:,-1,:]이 무슨 의미인지 모르겠습니다.
.size()함수를 이용하여 확인해봤더니 x가 [100,779,256] 이런식으로 나오고 h_t가 [100,256]으로 나와서 h_t=x.view(100,256) 이렇게 바꿔서 돌렸더니 "RuntimeError: shape '[100, 256]' is invalid for input of size 25395200"으로 뜹니다
질문1. [:,-1,:]이 .view(배치사이즈,히든사이즈)과 어떤 차이가 있는지 궁금합니다.
질문2. rnn의 아웃풋이 아닌 히든으로 logit을 구해도 무방한가요? ex) x, hidden = self.rnn(x, h_0)
h_t=hidden.view(100,256)
답변 1
1
안녕하세요.
Justin 입니다.
self.rnn = nn.RNN(embed_dim, self.hidden_dim, num_layers = self.n_layers, batch_first = True) 의 아웃풋은 (batch, seq, feature) 값의 텐서와, (num_layers * num_directions , batch, hidden_size) 값의 텐서 2개로 출력됩니다.
x[:, -1, :]은 self.rnn의 출력값 중 sequecnce의 가장 마지막 위치에서 출력되는 h_n값을 통해서 레이블값과 비교하기 위해서 설정하는 것 입니다. 시퀀스 내 가장 마지막에서 출력되는 값이 시퀀스 앞에 있는 정보를 반영하고 있기 때문입니다.
오류가 발생한 이유는, [100, 256]은 단순히 배치 사이즈와, RNN 셀에서 계산되어 출력되는 텐서값의 차원을 의미하는데, 이는 sequence 길이를 반영하고 있지 않기 때문에 입력값으로 이용되는 데이터와 연산이 불가능하여 발생되는 에러입니다.
정리하자면 다음과 같습니다.
1. [:, -1, :] 은 각 RNN 셀에서 계산된 값 중 가장 마지막 값에 접근하기 위함입니다.
2. rnn의 아웃풋이 아닌 hidden으로 logit을 구해도 상관은 없습니다만, 통상적으로 rnn의 출력값에는 hidden 값으로 계산된 결과값에 가중치값을 곱하여 비선형 함수를 거친 결과값을 활용하여 logit값을 계산합니다.
감사합니다.
소리가 겹쳐서 들려요
0
16
2
[개정판] 딥러닝 컴퓨터 비전 완벽 가이드 먼저? 구현하며 배우는 Transformer 먼저?
0
27
1
Continual Learning 과 Transfer Learning 의 차이점
0
1473
1
Deep Learning 정의에 나온 Graphical representation learning에 대해서
0
603
1
학습시간 줄이는 방법에 대하여 문의 드리겠습니다.
0
1195
3
cross_entropy
0
2401
1
Mnist 데이터 실습 관련 질문입니다!
0
301
1
CNN_MNIST 실습예제 질문입니다.
0
260
1
프리트레인 질문있습니다
0
277
1
ResNet 클래스의 _make_layer 메서드 부분의 설명이 이해가 되지 않습니다
0
761
2
CNN(강의자료 38 39 페이지 질문)
0
218
1
[실습] MLP를 이용한 MNIST 숫자분류 - 테스트 데이터 셋에 라벨 제거
1
242
1
[실습] MLP를 이용한 MNIST 숫자분류 - 테스트 데이터 셋에 라벨이 붙어있어요
0
410
0
[실습] MLP를 이용한 MNIST 숫자분류 - 테스트 데이터 셋에 라벨이 붙어있어요
0
215
1
[실습] MLP를 이용한 MNIST 숫자분류 - 형태가 달라요
1
476
3
코드 오류
0
256
1
RNN 실습_ cuda 관련 질문 드립니다!
0
303
1
torchtext
0
675
3
LSTM 원리가 궁금합니다.
0
343
2
Autoencoder 계산
0
205
1
Pytorch 실습 진입장벽
0
238
1
AutoEncoder 질문 드립니다.
0
341
1
"딥러닝"이라는 제목의 강의 슬라이드 15페이지
0
184
1
DataLoader에서 num_workers 개념 설명 다시 한 번 부탁드려요
0
4994
1





