model_compression 부스트캠프 AITech lightweighting KD Knowledge Distillation Logit-based KD

[모델 최적화 및 경량화 / Knowledge Distillation] Knowledge Distillation, Logit-based KD

Kwangjin Park

Dec 26, 2024 · 7 min read

Follow

Knowledge Distillation

개요

  • 고성능의 Teacher 모델로부터 지식을 전달받음 → 이를 기반으로 Student 모델을 학습시키는 기법
  • 성능 저하를 최소화하면서 모델을 압축하는 방법
  • 이를 통해, 작고 가벼운 학생 모델을 사용
  • Teacher vs Student
    • Teacher(선생 모델): 일반적인, 고성능의 Pre-trained 모델
      • 크고, 무겁고, 느리고, Pre-trained된 모델
    • Student(학생 모델): 선생 모델의 증류된 지식을 받아서 학습할 모델
      • 작고, 가볍고, 빠르고, 학습할 대상인 모델

메서드에 따른 KD 기법 분류

  1. Knowledge: 증류할 지식의 종류에 따른 분류
    • Response-based

      • Logit-based → Teacher 모델의 logit값을 사용
        • Logit: 신경망의 마지막 레이어(Sigmoid, Softmax 등)에 들어가기 직전의 raw 벡터
        • “확률로 바꾸기 전 점수들”을 그대로 가져와서 Teacher 점수 벡터를 따라 하도록 Student를 학습시킴
        • 그렇기 때문에, 클래스들 사이 미세한 점수 차이까지 전달하기 용이함
      • Output-based → Teacher 모델의 output값을 사용
        • Output: 신경망에서 마지막 레이어까지 모두 통과한 확률분포 값(각 클래스가 정답일 확률값)
        • “Softmax 적용 이후 확률들”을 가져와서 Teacher 확률 분포를 따라 하도록 Student를 학습시킴
        • 실제 예측에서 보는 값과 바로 연결되는 정보를 전달함
    • Feature-based

      • Teacher 모델의 중간 레이어의 feature/representation 활용
  2. Transparency: 모델 내부 구조 / 파라미터 열람 가능 여부에 따른 분류
    • White-Box

      • Teacher 모델의 내부 구조, 파라미터 등을 알 수 있는 경우
    • Gray-box

      • Teacher 모델의 output, 최종 logit값 등 제한된 정보만 알 수 있는 경우
    • Black-box

      • Teacher 모델의 output만 알 수 있는 경우

Logit-based KD

개요

  • Teacher 모델의 logit값을 지식으로 활용한 KD 기법
  • Logit

    • (넓은 의미) Unnormalized prediction value → “아직 정규화되지 않은 확률분포”
    • (좁은 의미) Softmax 함수를 씌워 클래스 확률값으로 만들 수 있는 값들

학습 알고리즘

  • Soft Label vs Hard Label

    • Hard Label: 정답 클래스 (실제 정답, 즉 가장 높은 확률을 갖는 클래스 1개만 선택, Similarity값 X)
      • Teacher, Student 두 모델의 학습에 모두 활용
    • Soft Label: 각 클래스에 대한 확률 (모델이 예측하고 있는 확률값)
      • Student 모델 학습에 활용 → Soft Label은 곧 Teacher 모델의 logit
  • 학습 개요
    • 일반적인 분류기 학습
    • 여기에 추가로, Teacher의 지식을 학습하는 과정 수행

      • Teacher의 지식은 곧 출력값의 확률분포
        • Teacher가 사용하는 지식
          • 주어진 Logit으로 계산한 클래스의 확률값
          • 클래스 간 유사도 정보가 간접적으로는 담겨 있음
            • Teacher 모델은 데이터를 학습하면서 클래스 간 유사도를 자연스레 파악함
      • Student는 자신의 클래스 예측이 Teacher의 클래스 예측과 유사해지도록 추가 학습 진행 → “정답 라벨을 맞추는 loss + Teacher를 따라가는 distillation loss” 두 가지를 동시에 줄이면서 학습
  • 학습과정에서 Student 및 Teacher 모델의 역할
    • Teacher
      • 분류 문제를 Cross-Entropy Loss로 사전에 학습(Hard Label)
      • 이러한 학습의 결과로 나온 클래스 확률값(Softmax 결과값)을 지식으로 Student 모델에 전달
    • Student
      • 분류 문제를 푸는 Cross-Entropy Loss는 기본(Hard Label), 여기에 더해
      • KL-divergence loss 추가 → Student 확률값을 Teacher의 확률값으로 모사하는 Loss값(Soft Label)
      • 확률값(확률분포)이 유사한 정도를 측정 → 일반적으로 KL-Divergence 활용
  • Teacher 및 Student 모델의 학습 과정 상세
    • 개략적 학습 과정
      1. 입력 이미지 x를 Teacher와 Student 모델에 동시에 넣어줌
        • 각 모델에 넣어주면, Teacher, Student 모델 모두 클래스에 대한 logit → Softmax → 확률분포 Output 생성
      2. Teacher 확률분포($P_t$)를 soft label로 보고, Student의 확률 분포($P_s$)와의 차이를 KL-Divergence로 계산해 distillation loss를 만들어줌
      3. Distillation loss 생성과 동시에, Student의 예측 $P_s$와 실제 정답 라벨 $y$ 사이의 차이를 cross-entropy loss를 통해 계산
      4. 최종 Loss값: student loss, distillation loss 두 값의 가중합으로 둠 → 최종 loss값을 줄여나가는 방향으로 Student의 가중치를 역전파로 업데이트해줌
      5. 1-4과정을 반복하면, Student 모델은 정답을 맞춤과 동시에 Teacher가 주는 부드러운 확률 분포의 패턴까지 흉내내어 자신의 클래스 예측이 선생님의 클래스 예측과 유사해 지도록 학습
    • Teacher 모델의 학습 과정

      1. 데이터와 정답 준비
        • 고양이/강아지 같은 이미지와, 각각에 대한 정답 라벨(벡터)이 주어진 상태에서 시작
      2. 입력을 Teacher 모델에 통과
        • 고양이/강아지 같은 이미지를 Teacher 네트워크에 넣으면, 마지막 층에서 각 클래스에 대한 logit값이 나옴
      3. logits → probabilities(확률)로 변환
        • 이 logits에 softmax 함수를 적용해서, 각 클래스가 정답일 “확률 분포”로 바꿔줌
        • ex) logits [5.1, 2.2, 0.7] → softmax 통과 → [0.9, 0.08, 0.02]처럼 확률 벡터로 변환(합이 1)
      4. Cross‑Entropy loss 계산
        • “위에서 얻은 확률 분포 ↔ 앞서 주어진 정답 라벨” 비교 → cross‑entropy loss 계산
        • 정답 클래스 확률이 높을수록 loss는 낮아지고, 낮을수록 loss가 높아짐
      5. 역전파로 가중치 업데이트
        • 계산된 Cross‑Entropy loss를 기준으로 역전파 수행 → Teacher 네트워크의 모든 가중치를 조금씩 수정
        • 이 과정을 전체 학습 데이터에 대해 여러 epoch 반복, Teacher가 점점 더 정답 라벨에 맞는 확률 분포(=출력)를 내도록 성능 개선
    • Student 모델의 학습 과정

      1. 입력을 Teacher / Student 모델에 동시에 넣기
        • 고양이 이미지 한 장을 Teacher, Student 두 모델에 동시에 넣어줌
        • 두 모델 모두 마지막 층에서 각 클래스에 대한 logits 출력
      2. logits → 확률(softmax)로 바꾸기
        • Teacher logits에 softmax(온도 $\tau$ 적용 가능) 적용하여 Teacher 확률 분포 $p_{tea}$ 얻음
        • Student logits에도 같은 방식으로 softmax를 적용해 Student 확률 분포 $p_{stu}$ 얻음
      3. Distillation loss(KL Divergence) 계산
        • Teacher 확률 분포 $p_{tea}$를 “soft target”이라고 보고, Student 분포 $p_{stu}$와 $p_{tea}$ 간 KL Divergence 계산 → “Distillation loss 생성”
        • 두 분포가 비슷할수록 KL 값이 작아짐 / 다를수록 커짐 → loss값을 줄여 Student 예측이 Teacher 예측과 점점 가까워지도록 만들어줌
      4. Cross-Entropy loss 계산(정답 라벨 기반)
        • 이와 동시에 Student 확률 분포와 실제 정답 라벨 간 cross‑entropy loss도 계산
          • 일반적인 분류 학습과 동일하게, Student가 정답을 잘 맞추도록 만드는 역할
      5. 두 loss를 섞어서 역전파 수행
        • 최종 loss는 보통 두 loss의 가중합으로 둠
          • L = $\alpha \cdot$CE(Student, hard label) + $\beta \cdot$KL(Teacher, Student) 형태
        • 이 최종 loss를 기준으로 역전파 수행 → Student의 가중치 업데이트
        • 데이터 전체에 대해 여러 epoch 반복
  • Temperature 인자
    • 확률분포를 날카롭게/완만하게 조절해주는 인자

      • $T<1$: 날카롭게
      • $T>1$: 완만하게
    • 모든 클래스에 대해, logit 값을 동일한 temperature로 나눠 softmax 적용
    • 균등 분포로 만들어주면, 연관도 및 유사도 파악이 쉬움

  • Teacher 및 Student 모델의 추론 확률값 및 Loss 수식 정리
    • 추론 클래스 확률값 ($n$번째 아이템, $k$번째 클래스)
      • Teacher: $q(z_n)^{(k)} = \frac{\exp(z_n^{(k)}/\mathcal{T})}{\sum_{m=1}^K \exp(z_n^{(m)} / \mathcal{T})}$
      • Student: $q(v_n)^{(k)} = \frac{\exp(v_n^{(k)}/\mathcal{T})}{\sum_{m=1}^K \exp(v_n^{(m)} / \mathcal{T})}$
    • Loss값
      • KL Loss(Soft Label): $\mathcal{L}{KL}(q(v_n) \vert q(z_n) = \sum{k=1}^K q(v_n)^{(k)}\log(\frac{q(v_n)^{(k)}}{q(z_n)^{(k)}})$
      • CE Loss(Hard Label): $\mathcal{L}{CE}(q(v_n), q(z_n) = - \sum{k=1}^K q(v_n)^{(k)}\log{q(z_n)^{(k)}}$
    • Teacher, Student의 Loss
      • Teacher: $\mathcal{L}_{CE}$
      • Student: $\lambda_{CE}\mathcal{L}{CE} + \lambda{KL}\mathcal{L}_{KL}$

성능

  • 데이터셋
    • CIFAR100 → 이미지를 100개 클래스로 분류하는 task(train: 50k, test: 10k)
  • 모델
    • Teacher: ResNet-56(Layer 층수) → 파라미터 약 860k개
    • Student: ResNet-20(Layer 층수) → 파라미터 약 270k개
  • 결과

Logit-based KD의 효과

  1. Label Smoothing
    • 일반적인 정답 라벨(One-hot encoding)은 매우 뾰족한 분포
    • Logit-based KD의 “soft label” → 오답에도 약간의 점수를 주어 부드러운 분포를 만들어줌
      • ex) Dog는 오답이지만, 약간의 점수를 부여함

    • 이렇게 부드러운 분포를 사용하게 된다면,
      • 모델이 한 클래스에만 과한 확신을 갖지 않아 overfitting을 방지할 수 있고
      • 이에 따라, 다양한 상황에 더 잘 버티는 robustness 및 일반화 성능이 향상됨
  2. Continuous Distribution
    • Discrete vs Continuous

      • One-hot Label이 0 또는 1로만 알려주는 discrete distribution를 보여주지만
      • soft label은 여러 클래스에 연속적인 확률값을 주는 continuous distribution임
    • 이러한 Continuous Distribution는 엔트로피가 높아, Student 모델이 데이터 간 미묘한 차이점까지 학습

      • 마치 더 어려운 문제를 풀게 되는 것과 같고, 표현력이 향상되게 됨
  3. Intra-class variance
    • 클래스 내부의 다양성 파악
    • “고양이” 클래스 안에도 털 색, 자세, 배경이 모두 다른 여러 인스턴스가 존재 → Teacher의 Soft label은 서로 다른 고양이 이미지에 대해 다양한 패턴을 만들어줌

      • 이로 인해, Student는 단순히 고양이면 무조건 1이 아니라
      • 고양이 내에서 다양한 모습들을 구분하는 능력을 갖춤과 함께, 모두가 다 고양이라는 공통점을 유지하는 더욱 풍부한 내부 표현을 배우게 됨
  4. Inter-class variance
    • 클래스 간 관계 파악
    • 서로 비슷한 클래스들(동물 클래스: 고양이, 개, 소 등) → 서로의 확률 값이 완전히 0이 되지 않고, Teacher 모델의 soft label에서 어느 정도 유사한 패턴으로 나타남

      • Student는 이 분포를 따라가면서, 클래스 간 간 관계까지 함께 학습
      • ex) “고양이는 개보다는 더 비슷하고, 소와는 덜 비슷하다.” 등
chat_bubble 0

chat_bubble 댓글남기기

댓글남기기