inflearn logo
강의

강의

N
챌린지

챌린지

멘토링

멘토링

N
클립

클립

로드맵

로드맵

지식공유

[PyTorch] 쉽고 빠르게 배우는 딥러닝

[실습] RNN을 이용한 영화 리뷰 예측 모델 만들기

forward에서 h_t

219

최성빈

작성한 질문수 1

0

BasicRNN의 forward함수에서 self.rnn(x, h_0)로부터 나온 아웃풋인 x[:,-1,:]를 h_t로 설정하셨는데,

여기서 [:,-1,:]이 무슨 의미인지 모르겠습니다.

.size()함수를 이용하여 확인해봤더니 x가 [100,779,256] 이런식으로 나오고 h_t가 [100,256]으로 나와서 h_t=x.view(100,256) 이렇게 바꿔서 돌렸더니 "RuntimeError: shape '[100, 256]' is invalid for input of size 25395200"으로 뜹니다

질문1. [:,-1,:]이  .view(배치사이즈,히든사이즈)과 어떤 차이가 있는지 궁금합니다.

질문2. rnn의 아웃풋이 아닌 히든으로 logit을 구해도 무방한가요? ex) x, hidden = self.rnn(x, h_0)

h_t=hidden.view(100,256)

pytorch 딥러닝 인공신경망

답변 1

1

Justin

안녕하세요.

Justin 입니다.

self.rnn = nn.RNN(embed_dim, self.hidden_dim, num_layers = self.n_layers, batch_first = True) 의 아웃풋은 (batch, seq, feature) 값의 텐서와, (num_layers * num_directions , batch, hidden_size) 값의 텐서 2개로 출력됩니다.

x[:, -1, :]은 self.rnn의 출력값 중 sequecnce의 가장 마지막 위치에서 출력되는 h_n값을 통해서 레이블값과 비교하기 위해서 설정하는 것 입니다. 시퀀스 내 가장 마지막에서 출력되는 값이 시퀀스 앞에 있는 정보를 반영하고 있기 때문입니다.

오류가 발생한 이유는, [100, 256]은 단순히 배치 사이즈와, RNN 셀에서 계산되어 출력되는 텐서값의 차원을 의미하는데, 이는 sequence 길이를 반영하고 있지 않기 때문에 입력값으로 이용되는 데이터와 연산이 불가능하여 발생되는 에러입니다.

정리하자면 다음과 같습니다.

1. [:, -1, :] 은 각 RNN 셀에서 계산된 값 중 가장 마지막 값에 접근하기 위함입니다.

2. rnn의 아웃풋이 아닌 hidden으로 logit을 구해도 상관은 없습니다만, 통상적으로 rnn의 출력값에는 hidden 값으로 계산된 결과값에 가중치값을 곱하여 비선형 함수를 거친 결과값을 활용하여 logit값을 계산합니다.

감사합니다.

[개정판] 딥러닝 컴퓨터 비전 완벽 가이드 먼저? 구현하며 배우는 Transformer 먼저?

0

25

1

전 강의와 전혀 이어지지가 않음

0

26

1

Continual Learning 과 Transfer Learning 의 차이점

0

1473

1

Deep Learning 정의에 나온 Graphical representation learning에 대해서

0

602

1

학습시간 줄이는 방법에 대하여 문의 드리겠습니다.

0

1195

3

cross_entropy

0

2401

1

Mnist 데이터 실습 관련 질문입니다!

0

301

1

CNN_MNIST 실습예제 질문입니다.

0

260

1

프리트레인 질문있습니다

0

277

1

ResNet 클래스의 _make_layer 메서드 부분의 설명이 이해가 되지 않습니다

0

761

2

CNN(강의자료 38 39 페이지 질문)

0

218

1

[실습] MLP를 이용한 MNIST 숫자분류 - 테스트 데이터 셋에 라벨 제거

1

242

1

[실습] MLP를 이용한 MNIST 숫자분류 - 테스트 데이터 셋에 라벨이 붙어있어요

0

410

0

[실습] MLP를 이용한 MNIST 숫자분류 - 테스트 데이터 셋에 라벨이 붙어있어요

0

215

1

[실습] MLP를 이용한 MNIST 숫자분류 - 형태가 달라요

1

476

3

코드 오류

0

256

1

RNN 실습_ cuda 관련 질문 드립니다!

0

303

1

torchtext

0

675

3

LSTM 원리가 궁금합니다.

0

342

2

Autoencoder 계산

0

205

1

Pytorch 실습 진입장벽

0

238

1

AutoEncoder 질문 드립니다.

0

341

1

"딥러닝"이라는 제목의 강의 슬라이드 15페이지

0

184

1

DataLoader에서 num_workers 개념 설명 다시 한 번 부탁드려요

0

4994

1