인프런 커뮤니티 질문&답변

vpdtlrdl님의 프로필 이미지
vpdtlrdl

작성한 질문수

처음하는 딥러닝과 파이토치(Pytorch) 부트캠프 (쉽게! 기본부터 챗GPT 핵심 트랜스포머까지) [데이터분석/과학 Part4]

RNN 과 LSTM 구현해보기2 (MNIST 데이터셋)

RNN과 LSTM 구현해보기2(MNIST 데이터셋) 강의에서 질문입니다

해결된 질문

작성

·

373

0

RNN과 LSTM 구현해보기2(MNIST 데이터셋) 강의의 15:04 부분에서 질문입니다.

 

강의에서는 다음과 같이 학습 과정에서 반복문을 작성했습니다.

        # |x_minibatch| = (128, 1, 28, 28)
        # |y_minibatch| = (128)
        for x_minibatch, y_minibatch in train_batches:
            
            x_minibatch = x_minibatch.reshape(-1, sequence_length, feature_size)
            y_minibatch_pred = model(x_minibatch)
            
            loss = loss_func(y_minibatch_pred, y_minibatch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()            
            train_losses.append(loss.item())

 

이때, 아래와 같이 loss_func를 적용하는 부분에서 궁금한 점이 있는데요,

loss = loss_func(y_minibatch_pred, y_minibatch)

y_minibatch_pred 는 model에 x_minibatch 를 넣어서 값을 예측한 것으로, 그 shape이 (128, 10) 과 같이 2차원으로 나온다고 이해하였습니다.

반면, y_miinibatch 는 (128) 과 같이 1차원으로 나오는 것을 확인했습니다.

 

이렇게 loss_func 안에 넣는 두 텐서의 다른 것으로 보이는데, y_minibatch의 shape을 변형해 줘야 하는 것은 아닌지 여쭙고 싶습니다..!

답변 2

1

안녕하세요. 답변 도우미입니다. 우선 답변이 조금 늦어져서 죄송합니다. 출장을 다녀와서 조금 늦었습니다.

혹시나 질문을 정확하게 이해한 것인지 조금 걱정이 되는데요. 그래도 최대한 답변을 드리면요.

PyTorch의 loss 함수는 이러한 차이점을 알아서 처리해줍니다. 예를 들어, nn.CrossEntropyLoss() 같은 함수는 target(정답 레이블)이 1차원 텐서고, 예측값이 2차원 텐서일 경우에도 잘 동작합니다.

  • y_minibatch_pred는 각 클래스에 대한 확률(또는 점수)을 담고 있는 2차원 텐서입니다. 그래서 그 shape은 [batch_size, num_classes]입니다.

  • y_minibatch는 정답 레이블을 담고 있는 1차원 텐서입니다. 여기서 각 원소는 해당 샘플의 클래스 인덱스를 나타냅니다. 그래서 그 shape은 [batch_size]입니다.

이 두 텐서를 nn.CrossEntropyLoss() 같은 loss 함수에 넣으면, 내부적으로 필요한 연산을 수행하여 loss를 계산해줍니다. 따라서 특별히 y_minibatch의 shape을 변형해줄 필요는 없습니다.

감사합니다.

0

안녕하세요, 인프런 AI 인턴이에요.

y_minibatch의 shape을 변형해 줘야 하는지 여쭤보셨는데, LSTM 모델에서는 상태가 없는 텐서를 다루는 경우가 많습니다. 그래서 y_minibatch의 shape을 변형할 필요가 없습니다.

LSTM 모델은 각 시점 별로 예측 값을 내놓기 때문에, 실제 값과 예측 값을 비교하기 위해서는 y_minibatch_pred와 y_minibatch의 shape이 같아야 합니다.

그래서 여기서는 y_minibatch_pred와 y_minibatch의 shape을 비교하여 손실 함수를 계산하고 있습니다.

따라서, y_minibatch의 shape을 변형할 필요 없이 그대로 사용하시면 됩니다.

추가적인 궁금한 사항이 있다면 언제든지 물어보세요. 감사합니다.

vpdtlrdl님의 프로필 이미지
vpdtlrdl
질문자

y_minibatch_pred와 y_minibatch의 shape이 동일하지 않다고 질문에 서술하였습니다.

vpdtlrdl님의 프로필 이미지
vpdtlrdl

작성한 질문수

질문하기