[NIPS 2021] ABC: Auxiliary Balanced Classifier for Class-Imbalanced Semi-Supervised Learning
ABC: Auxiliary Balanced Classifier for Class-Imbalanced Semi-Supervised Learning
1. Problem Definition
Class imbalanced setting에서의 classification을 수행한다.
현존하는 Semi-supervised learning(SSL) 알고리즘은 class balanced dataset을 가정한다. 하지만, 현실에 존재하는 많은 dataset들은 class distribution이 imbalanced 되어있다. Imbalanced set에서 학습된 모델은 majority class에 편향된 모습을 보인다. SSL 알고리즘에서는 Unlabeled data에 대한 pseudo-label을 training에 사용하기 때문에 이 문제가 더욱 두드러진다. 따라서 본 논문에서는, Class imbalanced semi-supervised setting에서 이 biasedness를 완화해줄 새로운 방법을 제시한다. SSL 알고리즘 ReMixMatch가 majority class에 biased하게 학습함을 (a),(b)를 통해 확인 할 수 있다. 이에 반해 논문에서 제시하는 알고리즘을 사용하면 비교적 balanced한 prediction을 할 수 있다.
2. Motivation
Representation과 Classifier의 학습을 decoupling하여 balanced classifier를 얻을 수 있다.
DECOUPLING REPRESENTATION AND CLASSIFIER FOR LONG-TAILED RECOGNITION(ICLR2020) Link for Paper 논문에서는 representation과 classifier의 학습을 decoupling 했을 때, class imbalanced setting에서 성능이 향상된다고 주장한다. 이에 따라 본 논문은 SSL backbone의 representation layer 뒤에 auxiliary classifier를 덧붙여 representation과 auxiliary classifier의 학습을 decoupling 했다.
Training 후 test에서는 이 auxiliary classifier만 사용된다.
이 auxiliary classifier를 학습하는 과정에서 해당 sample의 학습 여부를 결정하는 mask를 사용한다.
덕분에 모든 sample을 backbone 학습에 사용할 수 있고, 기존의 Class re-balancing 기법 중 Oversampling이나 Undersampling의 단점인
overfitting, information loss를 극복하며 high-quality의 representation을 얻을 수 있다.
더불어 SSL 알고리즘은 decision boundary를 low-density region에 놓기 위해 unlabeled data를 활용하는데, 이 consistency regularization 과정에서 mask를 사용하기 때문에 더욱 balanced한 방향으로 consistency regularization을 할 수 있다.
3. Method
Problem setting
Labeled dataset : $\chi = \lbrace(x_ {n},y_ {n}) : n \in (1,..,N)\rbrace$ , $x_ {n} \in R^{d}$ is the $n$th labeled data point and $y_ {n} \in \lbrace1,…,L\rbrace$ is the corresponding label.
Unlabeled dataset : $\mu = \lbrace(u_ {m}) : m \in (1,..,M)\rbrace$, $u_ {m} \in R^{d}$ is the $m$th unlabeled data point.
$\beta = \frac{N}{M+N}$ : ratio of the amount of labeled data (generally $\beta<0.5$)
$N_l$ : number of labeled data of class $l$, So $\displaystyle\sum_ {l=1}^{L} {N_ l} = N$, assume $N_ {1} \ge N_ {2} \ge … \ge N_ {L}$
$\gamma = \frac{N_ 1}{N_ L}$ : ratio of class imbalance, $\gamma » 1$.
Assume that $\chi$ and $\mu$ share the same class distribution, i.e., both datasets are class-imbalanced to the same extent.
$M\beta_ {\chi} = \lbrace(x_ {b},y_ {b}) : b \in (1,..,B)\rbrace \subset \chi$ and $M\beta_ {\mu} = \lbrace(u_ {b}) : b \in (1,..,B)\rbrace \subset \mu$ are minibatches generated by $\chi$ and $\mu$, and $B$ is the minibatch size.
Using these minibatches for training, we aim to learn a model $f:R^{d} \to \lbrace1,…,L\rbrace$
Backbone SSL algorithm
ABC는 backbone의 representation layer에 붙어있으므로, backbone에 의해 학습된 high-quality의 representation을 사용할 수 있다. Fixmatch 혹은 Remixmatch를 backbone으로 사용하며, 이들은 당시 SSL 알고리즘 중에서 SOTA 모델이었다.
FixMatch는 weakly augmented labeled data point $\alpha(x_ b)$를 classification loss 계산에 사용한다. 그리고 consistency regularization loss는 weakly augmented unlabeled data $\alpha(u_b)$와 strongly augmented unlabeled data point $A(u_ b)$ 를 이용해 계산한다.
RemixMatch는 weakly augmented unlabeled data $\alpha(u_ b)$의 label을 distribution alignment와 sharpening으로 예측하고, strongly augmented unlabeled data $A(u_ b)$에 label을 부여한다. $A(u_ b)$와 $A(x_ b)$가 mixup regularization을 위해 활용된다. ReMixMatch는 FixMatch와 비슷한 방식으로 consistency regularization이 진행되는데, imgae의 rotation을 활용하여 self-supervised learning을 한다.
두 알고리즘은 SSL performance의 향상에 큰 기여를 하였지만, Imbalanced setting에서는 majority class에 편향되어있다.
ABC for class-imbalanced Semi-supervised learning
ABC를 balanced하게 학습하기 위해, 먼저 $M(x_ b)$ mask를 생성한다. Labeled data $x_ b$에 해당 class의 data 개수와 inversely proportional하게 parameter를 설정하고, Bernoulli distribution $\beta()$ 를 이용한다. 즉, data의 개수가 적은 minority class의 mask는 1이 될 확률이 높고, 반대로 data 개수가 많은 majority class의 mask는 0이 될 확률이 높다.
이 mask는 classification loss에 곱해지고, 이로 인해 ABC는 balanced classification loss로 학습된다. mask를 곱해주는 것은 minority class에 대해서는 oversampling, majority class에 대해서는 under sampling이라고 할 수 있다. 하지만 representation learning에서는 모든 sample이 사용되고 ABC를 학습하는 과정에서만 class re-balancing이 일어나기 때문에 앞서 설명한 re-balancing의 단점을 극복할 수 있다.
ABC의 classification loss는 0/1 mask $M()$을 이용하여 다음과 같이 나타낼 수 있다.
$L_{cls} = \frac{1}{B} \displaystyle\sum_ {b=1}^{B} M(x_ b)H(p_ {s}(y \vert \alpha(x_ b)),p_ {b})$
where $M(x_ b) = \beta(\frac{N_ L}{N_{y_ b}})$ , $H$ is the standard cross-entropy loss, $p_ s$ is the predicted class distribution using ABC for $\alpha(x_ b)$, and $p_ b$ : one-hot label for $x_ b$
위의 내용을 Figure2를 통해 직관적으로 이해할 수 있다.
Consistency Regularization for ABC
Unlabeled data를 이용하여 decision boundary의 margin을 증가시키기 위해 Consistency regularization을 한다. FixMatch와 비슷하게 먼저 predicted class distribution $p_ {s}(y \vert \alpha(u_ b))$ 를 먼저 구하고 이를 soft pseudo-label $q_ b$로 사용한다. 그 후 strongly augmented unlabeled data $A_ {1}(u_ b)$, $A_ {2}(u_ b)$에 대해 $p_ {s}(y \vert A_ {1}(u_ b))$ , $p_ {s}(y \vert A_ {2}(u_ b))$를 $q_ {b}$에 가깝도록 학습한다.
Class imbalanced setting에서는 많은 unlabeled data가 majority class로 예측될 수 있고, consistency regularization이 majority class 중심으로 일어날 수 있다. 이는 classifier의 bias를 야기한다. 이를 방지하기 위해, FixMatch와는 다른 방식으로 consistency regularization을 한다. Fixmatch에서는 entropy minimization을 위해 weakly augment point의 predicted class distribution을 one-hot pseudo-label로 변환하지만, 이는 bias를 가속화할 수 있기 때문에 본 논문에서는 soft-pseudo-label을 그대로 사용한다. 대신 0/1 mask $M()$를 활용하여 balanced consistency regularization을 하게 한다. Consistency regularization loss는 다음과 같다. $L_ {con} = \frac{1}{B}\displaystyle\sum_ {b=1}^{B}\displaystyle\sum_ {k=1}^{2}M(u_ b)I(max(q_ b)\ge\tau)H(p_ {s}(y \vert A_ {k}(u_ b)),q_ {b}),$ where $M(u_ b) = \beta(\frac{N_ L}{N_ {\hat{q_ {b}}}})$ and $I$ is the indicator function, max$(q_ b)$ is the highest predicted assignment probability for any class, $\tau$ is the threshold.
부정확한 soft pseudo-label이 consistency regularization에 영향을 주지 않게 하기 위해서 threshold를 부여하여 confident unlabeled sample만 선별하기 위해 threshold를 추가했다. $\hat{q_ b}$은 one hot vector of pseudo label from $q_ b$ 이며, 초반에는 threshold를 넘는 unlabeled sample이 거의 없기 때문에 Bernoulli distribution $\beta()$의 parameter를 1로 시작하여 점차 $\frac{N_ L}{N_ {\hat{q_ {b}}}}$로 줄여나간다.
End to End training
다른 Class imbalanced learning model은 representation learning이후에 classifier를 finetune하는데(즉, representation과 classifier의 학습을 decouple) 반해, 본 논문은 end-to-end training을 하면서도 balanced classifier를 학습할 수 있었다.
모델 전체의 Loss function은 $L_ {total} = L_ {cls} + L_ {con} + L_ {back}$ 이며, test sample을 test할 시에는 Auxiliary Balanced Classifier, ABC만 사용한다. $L_ {back}$은 backbone의 loss로 SSL algorithm의 loss를 그대로 사용한다.
4. Experiment
Experiment setup
-
Dataset CIFAR-10, CIFAR-100 and SVHN가 class imbalanced version으로 수정 후 사용되었다. 다양한 class imbalance ratio $\gamma$ 와 labeled data의 비율 $\beta$를 사용했고, $N_ {k} = N_ {1} * \gamma^{-\frac{k-1}{L-1}}$, where $\gamma = \frac{N_ 1}{N_ L}$ 이다. Main setting은 $\gamma=100, N_ {1}=1000, \beta=20%$ for CIFAR-10 & SVHN $\gamma = 20, N_ {1}=200, \beta=40%$ for CIFAR-100이다. CIFAR-100의 $\gamma$가 상대적으로 작은 이유는 CIFAR-100은 각 class마다 500개의 sample만 존재하기 때문이다. Large-scale dataset에서도 실험을 하기 위해 LSUN에서 7.5M개의 256*256image에 대해서도 실험하였다.
- baseline
- Deep CNN (vanilla algorithm): Cross entropy loss를 사용하여 labeled data로만 학습하였다. (WideResNet-28-2를 사용)
- BALMS (CIL algorithm): Unlabeled data를 사용하지 않는 class imbalanced learning 알고리즘 중 SOTA 모델이다.
- VAT, ReMixMatch, and FixMatch (SSL algorithms): 이들은 SOTA SSL 알고리즘이지만, clas imbalance를 고려하지 않는다.
- FixMatch+CReST+PDA and ReMixMatch+CReST+PDA (CISSL algorithms): CReST+PDA는 minority class로 판단된 unlabeled data를 majority class로 판단된 unlabeled data보다 더욱 많이 사용하여 class imbalance를 완화한다.
- ReMixMatch+DARP and FixMatch+DARP (CISSL algorithms): DARP를 이용하여 unlabeled data의 pseudo label의 분포를 refine한다.
- ReMixMatch+DARP+cRT and FixMatch+DARP+cRT (CISSL algorithms): 위의 알고리즘에 더하여, cRT를 이용해 classifier를 finetuing한다.
- Evaluation Metric
Result
각 setting과 dataset에 대한 실험 결과이다.
저자는 ABC가 backbone에 의해 학습된 high-quality의 representation을 이용할 수 있다고 주장한다. 다음의 t-SNE는 이를 뒷받침하고 있는데, (a)는 ABC가 SSL backbone없이 학습되었을 시의 representation이다. 0/1mask에 의해 충분한 data로 feature를 학습하지 못했기 때문에, (b)와 (c)에 비해서 discriminative하지 못한 모습을 보이고 있다.
다음은 CIFAR-10에서의 confusion matrix이다. ABC를 이용했을 때 minority class에서의 accuracy가 많이 향상된 모습을 볼 수 있다.
5. Conclusion
기존의 SSL 알고리즘에 보조적인 Classifier를 덧붙여 학습시켜 SSL 알고리즘이 주는 high-quality의 representation을 이용함과 동시에 balanced한 classifier를 얻을 수 있었다. 더불어 consistency regularization에서도 0/1mask를 이용해 balancing을 추구하였다. 덕분에 Class imbalanced Semi Supervised Learning에서 SOTA를 달성하였다.
다만 이 논문은 unlabeled data와 labeled data가 같은 정도로 편향되었다고 가정하기 때문에 두 data의 class imbalace ratio가 다를 경우에는 다른 알고리즘보다 낮은 accuracy를 보인다. 이는 본 논문에서도 지적하고 있는 내용이며, future work의 핵심 problem이 될 수 있다.
Author Information
-
Author : Park Tae Min
-
Affiliation : iStatLab in KAIST
-
Research Topic : Class imbalanced Semi supervised learning
6. Reference & Additional materials
Please write the reference. If paper provides the public code or other materials, refer them.