• 카테고리

    질문 & 답변
  • 세부 분야

    컴퓨터 비전

  • 해결 여부

    미해결

Faster RCNN 문의드립니다.

22.05.25 09:46 작성 조회수 123

0

안녕하세요. 수업을 너무 감사한 마음으로 잘 듣고 있습니다. 

아직 잘 이해되지 않아서 문의드립니다. 

1. RPN 입력 

 - VGG등 CNN을 통과한 feature map. 

2. RPN output

 - output 1: 3x3 conv을 통과한 feature 맵을 1x1x9(anchor 개수)로 통과해서  각 anchor마다 object인지 여부를 판별하는 확률

- output2 :  위의 3x3 conv을 통과한 feature맵을 1x1x4x9 (x,y, w, h x anchor 개수) 

3. training 

- output1과 target_class의 cross entropy

- output2와 target_class의 좌표간의 regression entropy

--------------------------------

이렇게 이해했는데요.

그럼 train할때 학습시키는 target에 바로 9개의 bbox size를 적용해서 image상에서 9ea bbox와 그 것과 실제 object 와의 IOU를 계산해서 positive / negative 인지 여부와 각 bbox의 좌표를 나타내는 건가요?

답변 1

답변을 작성해보세요.

1

안녕하십니까, 

수업 잘듣고 계셔서 저도 기분이 좋습니다. 

RPN 네트웍 학습과 메인 네트웍 학습을 번갈하서 수행하는 부분이 직관적으로 이해하기 어려운 부분일수 있습니다.

Faster RCNN은 RPN 네트웍 학습과 메인 네트웍 학습을 번갈하서 반복적으로 수행합니다. 

그러니까, 배치 사이즈만큼의 학습 데이터를 순차적으로 RPN학습-> Faster RCNN 메인 네트웍 학습 -> 다시 RPN학습-> Faster RCNN 메인 네트웍 학습으로 Iteration하면서 학습을 합니다. 

먼저 말씀하신 첫번째 배치 사이즈만큼의 학습 데이터를 기반으로 RPN을 첫번째 학습을 합니다. 그리고 이렇게 학습된 RPN기반으로 Object가 있을 만한 ROI영역을 예측하고  이 ROI 영역이 학습 데이터 Object 영역과 얼마나 다른지 RPN Loss를 계산합니다.

다음으로 이 ROI 영역을 기반으로 메인 네트웍에서 오브젝트 클래스와 위치를 학습합니다. 그리고 실제 학습데이터와 학습/예측한 오브젝트 클래스와 위치가 얼마나 다른지 메인 네트웍 Loss를 계산합니다. 

그리고 이 RPN 네트웍과 메인 네트웍 모두에 해당하는 Faster rcnn 전체 네트웍의 loss를 RPN Loss +  메인 네트웍 loss로 계산하면서 전체 네트웍의 loss를 감소 시키는 방향으로 학습을 진행합니다.   이런 방식으로 학습 데이터를 순차적으로 학습해 나가면서 RPN의 Loss와 전체 네트웍의 loss를 감소 시키면서 Faster RCNN 모델을 학습 시키게 됩니다. 

감사합니다.