일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
- 딥러닝
- 큐
- numpy
- 알고리즘
- flame
- cv2
- 논문 구현
- Knowledge Distillation
- 파이썬
- Deeplearning
- re-identification
- attention
- 프로그래머스
- transformer
- point cloud
- Threshold
- reconstruction
- 3D
- 자료구조
- NLP
- Python
- Computer Vision
- center loss
- 스택
- Object Tracking
- 임계처리
- level2
- OpenCV
- Object Detection
- deep learning
- Today
- Total
공돌이 공룡의 서재
[논문 리뷰] Learning Efficient Object Detection Models with Knowledge Distillation 본문
[논문 리뷰] Learning Efficient Object Detection Models with Knowledge Distillation
구름위의공룡 2022. 3. 29. 21:24NIPS 2017
1. Introduction
Knowledge distillation 에 대한 연구들이 많이 진행되어 왔지만, 대부분이 Image Classification에 대해 적용되어 왔다. 그렇다면 다른 Task를 수행하는 모델들에 적용할 수 있을까? 논문 저자들은 Multi-Class Object Detection task에 대해 거의 처음으로 성공했다고 한다.
왜 Object detection 에 적용하기가 좀 더 힘들까? 내용을 정리하면 다음과 같다.
- Detectition은 bounding box 좌표도 구해야 하고, box 내 물체가 어떤 물체인지 classification도 해야한다. 즉, 더 복잡한 task라 볼 수 있고 Image classification 보다 많은 연산량을 필요로 하는데 compression 시 성능 저하가 더 눈에 띄게 드러나게 된다.
- Detection의 경우 Background가 다른 class들보다 더 많이 고려되는 class 간 불균형 문제도 있다.
이 논문의 Contribution은 다음과 같이 정리할 수 있다.
- Object detection task 에 대해 Knowledge Distillation을 처음으로 성공적으로 접목시켰다.
- class imbalance (위에서 언급한 2번 문제점)을 해결할 새로운 loss를 제안했다.
- Teacher bounded loss 를 제시해서 bounding regression 문제를 해결했다.
- hint learning을 사용했다.
- 다양한 benchmark를 사용해서 많은 evaluation을 했고 저자들의 방법을 증명했다.
3. Method
Object Detection 모델로는 Faster-RCNN을 사용했다. (뒤에 나오겠지만 Loss function이 2 stage 기반으로 설명이 꼭 2 stage model을 쓰라고 나와있진 않으니 Detection model로는 뭘 선택해도 상관 없을 듯 하다.)
3.1 Overall Structure
위 Figure에 대한 설명에 해당하는 내용들이다.
- Student model이 Teacher model의 feature representation을 학습할 수 있도록 Intermediate layer에 대한 Hint learning을 도입했다.
- RPN (Region proposal network) 과 RCN (classification 과 box regression을 하는 head network 에 해당한다.) 각각에 대해 knowledge distillation에 관해 loss term을 도입했다.
- class imbalance를 해결하기 위해 weighted cross entropy loss를 사용했다.
- Regression의 경우 Teacher regression output이 upper bound가 되도록 했다.
전체적인 loss는 다음과 같다.
cls는 classification loss로 Ground Truth와 Soft label을 사용하는 loss tem이다.
reg는 regression loss로 Ground Truth에 대해서는 smoothed L1 loss, Teacher model 의 box regression 결과에 대해서는 L2 regression loss를 사용한 loss term이다.
아래 섹션들에서는 loss term들에 대한 구체적인 설명들이 이뤄진다.
3.2 Knowledge Distillation for Classification with Imbalanced Classes
기존 연구들에선 Image Classification할 때는 Teacher model 의 prediction 에 softmax를 취하고, Student model의 prediction에도 softmax를 취해서 둘을 비교했다. 즉, cross entropy loss 였다.
이 접근을 그대로 적용하면 앞서 말한 background class 가 가장 dominant 한 class 라는 점 때문에 적용이 잘 안된다. 따라서 저자들은 아래와 같이 class 별로 weight term 을 추가해서 distillation loss를 만들었다.
Ps는 student model prediction, Pt는 Teacher model prediction 이다. Wc는 class weight로 background 일 때는 1.5로 두고 나머지 class 일 때는 1로 두었다. (PASCAL dataset 기준)
3.3 Knowledge Distillation for Regression with Teacher Bounds
Teacher model의 regression 결과는 Student model 에게 틀린 guide를 줄 수 있다. (regression output이 unbounded 하기 때문이다.) 따라서 이 논문에서는 Teacher model의 box 결과를 그대로 사용하기 보단, student model이 달성해야 할 upper bound로 이용한다.
이를 나타낸 것이 $L_b$이다. 수식을 풀어보자. Student model 의 box 와 GT box 차이 값이 Teacher model 의 box와 GT box 차이 값을 비교했을 때, 둘의 차이가 특정 값 (margin) 보다 작아야지만 loss term이 작아진다. 전자가 후자에서 margin 을 뺀 값보다 작아야 (=Teacher model 보다도 어느정도 더 좋은 box 결과가 나와야) 0 이 되는 식이라서 upper bound라는 표현을 쓴 것 같다. GT와는 smoothed L1 loss 로 loss term을 구한다. ( $L_{sL1}$)
3.4 Hint learning with Feature Adaptation
Teacher model 의 중간 hidden layer output을 Student model 의 중간 hidden layer가 따라갈 수 있도록 학습하는 것이다. Hint learning을 처음 제시한 FitNet에서는 두 output 결과의 차이를 L2 term으로 loss를 정의했으나 이 논문에서는 아래와 같이 L1 term으로 정의하고 있다.
(왜 L1 loss 를 사용했는지는 잘 모르겠다.)
4. Experiment
여러 데이터셋들에 대해 테스트해본 결과를 정리한 표다. MS COCO, ImageNet(ILSVRC), PASCAL, 등 Object Detection에서 주로 사용하는 데이터셋들은 모두 사용했다.
결과를 보면, Teacher model이 좀 더 크고 성능이 좋은 모델일수록 Knowledge Distillation 의 효과가 좀 더 크게 나타났음을 알 수 있다. Object detection task에서도 지식을 전이하는 것이 유효함을 보였다.
정리하며
- Object detection 에 Knowledge distillation을 적용하고자 할 때 box regression을 어떻게 할 지가 궁금했었다. Teacher model의 output을 upper bound로 사용한 것이 흥미로웠다.
- Weighted Cross entropy, Smoothed L1 loss, 등 잘 안 보이는 loss들인데 Task에 맞게 잘 사용한 것 같다.
- 실험을 충분히 하고 실험에 대한 설명도 충분해서 방법의 유효함을 잘 보였다 생각한다.
'딥러닝 > Model' 카테고리의 다른 글
[논문 리뷰] Distilling the Knowledge in a Neural Network (0) | 2022.02.13 |
---|---|
[논문 리뷰] CBAM : Convolutional Block Attention Module (0) | 2021.08.25 |