[ICLR 2022] Online Coreset Selection for Rehearsal-based Continual Learning
Title
Online Coreset Selection for Rehearsal-based Continual Learning
1. Problem Definition
Static한 setting에 맞춰져 있는 현재의 Learning process는 현실의 상황과는 거리가 멀다.
- Sequence of tasks에 continuously 적용될 수 있는 model 고안하는 것이 본 논문의 주된 목적이다.
- Sequence of tasks에 continuously 적용될 수 있는 model을 고안하는 것이 본 논문의 주된 목적이다.
- Continual Learning에서 발생하는 주된 문제인 catastrophic forgetting 문제도 보완한다.
- 더욱 specific한 setting으로 imbalanced / noisy한 data 상황에서도 높은 accuracy와 적은 catastrophic forgetting을 목표한다.
2. Motivation
2.1 Continual Learning과 Catastrophic Forgetting
Continual Learning은 많은 관심을 받고 있는 연구 분야이며, 눈에 띌만한 성장세를 보이고 있다. 현재까지의 Learning Scenario는 static한 setting에 초점이 맞춰져 개발되었다. 하지만 현실에서의 setting은 dataset이 고정되어 있지 않고, 새로운 data / class 등이 끊임없이 추가된다. 이러한 상황에서 model은 정확성을 지속적으로 유지할 수 있어야 한다. 그렇다면 이러한 setting에서 새로운 task까지 잘 해내는 모델을 학습해야 한다면 어떻게 해야할까? 당연히 모델을 retraining 시켜야한다. 모델을 retraining 시키기 위해 아래 두 가지 방법을 쉽게 떠올려 볼 수 있다. 첫째, 기존 데이터에 새로운 데이터까지 추가해서 모델을 처음부터 다시 학습하는 방법이다. 이 방법이 직관적일 수 있지만, 새로운 데이터가 수집될 때마다 전체 데이터셋에 대하여 모델의 모든 가중치값들을 학습하는 것은 시간과 computational cost 측면에서 큰 손실이다. 그렇다면, 모델을 새로운 데이터로만 retraining 시키면 어떻게 될까? 이전에 학습했던 데이터와 유사한 데이터셋을 학습하더라도 아래의 그림처럼 이전의 데이터셋에 대한 정보를 잊어버리게 될 것이다. 이 문제를 일컬어 Catastrophic Forgetting 이라고 부른다.
Catastrophic Forgetting : Single task에 대해서 뛰어난 성능을 보인 모델을 활용하여 다른 task를 위해 학습했을 때 이전에 학습했던 task에 대한 성능이 현저하게 떨어지는 현상
Catastrophic forgetting은 neural network의 더욱 general한 problem인 “stability-plasticity” dilema의 결과이다. 이 때, stability는 previously acquired knowledge의 보존을 의미하고, plasticity는 new knowledge를 integrate하는 능력을 의미한다.
2.2 Limitation
- Training instances are not equally useful!
- Current task를 학습하는데 더 representative / informative한 data들이 있다.
- 그렇지 않은 data를 쓸 경우 오히려 성능을 떨어뜨릴 가능성이 있다.
- Imbalanced / Noisy instances
- Real-world에서는 data가 imbalanced / noisy한 경우가 있다.
2.3 Purpose
- Real-world에서는 data가 imbalanced / noisy한 경우가 있다.
- 새로운 task를 학습할 때 이전 task에 대한 catastrophic forgetting 방지.
- 새로운 task 학습을 용이하게 하기 위해 이전 task의 knowledge를 사용.
- Online Corset Selection (OCS) 방법론을 고안하여 representative하고 diverse한 subset을 선정하여 buffer에 저장하고 새로운 task 학습에 함께 사용.
- Current task에 대해서도 모든 data를 사용하는 것이 아닌 이전 task의 buffer들과 high affinity를 갖는 data를 선정하여 함께 training 시킴.
- Online Corset Selection (OCS) 방법론을 고안하여 representative하고 diverse한 subset을 선정하여 buffer에 저장하고 새로운 task 학습에 함께 사용.
- Current task에 대해서도 모든 data를 사용하는 것이 아닌 이전 task의 buffer들과 high affinity를 갖는 data를 선정하여 함께 training 시킴.
- buffer : 이전 task의 dataset 중 현재의 task dataset과 함께 training 시키기 위해 저장시키는 data의 subset
- affinity : 현재의 task dataset이 이전 task의 buffer들과 갖는 친밀도(유사도)를 의미한다.
2.4 Contributions
- Class-imbalanced / noisy setting이 존재하는 continual scenario를 다루었다.
- OCS (Online Coreset Selection)이라는 simple하지만 effective한 방법론을 개발하였고, 이는 similarity (대표성) & diversity (overfitting 방지)를 함께 고려하여 replay시킬 data point를 select한다.
- 다양한 세팅으로 이루어진 여러 실험을 통해 각 component와 제안한 model의 효과를 입증하였다.
3. Method
3.1 Online Coreset Selection
이 부분에서는 주어진 task에서 어떠한 기준으로 replay 시킬 data를 선정하는지에 대해 설명합니다.
크게 두가지의 기준을 적용하는데, “similarity”와 “diversity”입니다.
3.1.1 Minibatch similarity
$b_{t,n} = {x_{t,n}, y_{t,n}} \in B_t$는 data point의 n-th pair를 의미하고, 분모의 좌측에 있는 식은 해당 datapoint의 gradient를 의미한다. 또한, 분모의 우측에 있는 식은 집합 $B_t$내에 있는 data들의 gradient의 평균을 의미한다.
즉, 이 식은 특정 data point의 gradient와 집합 $B_t$내의 data들의 gradient의 평균 간의 similarity를 나타낸 식이다.
3.1.2 Sample diversity
본 식에서는 특별히 새롭게 설명할 notation은 없을 것이다. 본 식은 특정 data point $b_{t,n}$과 subset 내의 다른 datapoint $b_{t,p}$ 간의 dissimilarity의 평균이다. 따라서 값이 클수록 subset 내의 다른 data와 다른, 즉 다양성을 갖는 data point라는 것이다.
3.2 Online Coreset Selection for Current Task Adaptation
이제 위의 section 3.1에서 다룬 두 가지 기준 “similarity”와 “diversity”를 고려하여 replay 시킬 data를 뽑아야 할 것이다.
Similarity와 diversity 값을 더하여 그 값이 가장 큰 top k개를 선정한 $u^{*}$집합을 선정한다.
그 이후 아래와 같이 replay할 data를 갖고 loss가 최소가 되도록 model을 training 시키는 간단한 방법론을 제시하였다.
3.3 Online Coreset Selection for Continual Learning
지금부터는 저자가 제시한 OCS (Online Coreset Selection) 방법론에 대해 구체적으로 다룰 것이다.
OCS 방법론의 목적은 previous task의 지식을 앞서 다룬 similarity와 diversity의 관점에서 고려하여 현재 task에서 활용도가 높은 coreset을 찾는 것이다.
더 직관적으로 설명하자면, 현재 task에 대해서는 모든 dataset을 사용할 수 있는 것 아닌가라는 의문이 들 수 있다. 하지만 늘 그렇듯 real-world dataset에는 noise가 있기도 하고, 틀리지 않은 data 이지만 이전 task가 지향하는 방향과는 방향성이 다를 수 있다. 이에, 저자는 현재 task 이더라도, continual한 세팅에서 sequential한 학습에 도움이 되는 data subset을 선정하여 그 data들에 대해서만 training을 진행한다.
3.3.1 Coreset Affinity
위의 similarity 수식과 굉장히 유사하다. 분모의 우측에 있는 식이 의미하는 것은 coreset C로부터 randomly sampled 된 subset $B_c$에 대한 gradient의 평균이다. 따라서 이는 현재 task의 data distribution만 고려하는 것이 아니라 이전 task의 coreset과의 similarity도 고려한다는 의미이다.
그렇다면 새로운 data selection equation은 아래와 같이 구성된다.
그리고, 마찬가지로 아래와 같은 수식을 통해 current task의 coreset과 이전 task들에서 replay된 data들의 loss를 최소화하는 parameter를 찾는 방향으로 model이 training된다.
3.4 Algorithm
위의 방법론을 하나의 algorithm으로 정리하면 아래와 같다.
4. Experiment
4.1 Experiment setup
4.1.1 Dataset
- Domain Incremental
- Task간에 겹치는 class가 있다. 즉, 모든 task에 class가 섞여서 존재한다.
- Rotated MNIST
- Task Incremental
- Task간에 겹치는 class가 없는 setting이다. Test 상황에서 data point가 속한 task의 class에 대해 test한다. (task 정보 있음)
- Split CIFAR-100 / Multiple Datasets (a sequence of five datasets)
- Class Incremental
- Task간에 겹치는 class가 없는 setting이다. Test 상황에서 전체 class에 대해 test한다. (task 정보 없음)
- Balanced and ?Imbalanced Split CIFAR-100
4.1.2 baseline
OCS과의 비교를 위해 continual setting에서 아래의 모델들과 비교하였다.
~~~
- EWC
- Stable SGD
- A-GEM
- ER-Reservior
- Uniform Sampling & k-means features
- k-means Embeddings
- iCaRL
- Grad Matching
- GSS
- ER-MIR
- Bilevel Optim
~~~
4.1.3 Evaluation Metric
본 논문의 주된 목적은 continual learning에서 고질적으로 발생하는 문제인 catastrophic forgetting을 줄이기 위함이므로 이에 알맞은 evaluation metric을 저자는 제안한다.
- Average Accuracy : 일반적인 accuracy value이다.
- Average Forgetting : 이후 task를 학습하고 난 뒤, task의 accuracy가 떨어지는 정도를 측정한 값이다.
4.2 Result
4.2.1 Quantitative Analysis for Continual Learning
- Baseline model 모두 일정 수준의 catastrophic forgetting은 발생하는 것을 관찰할 수 있다.
- Balanced continual learning setting에서 random replay based methods (A-GEM & ER-Reservoir)과 비교하면 OCS는 average accuracy 관점에서 약 19%의 gain이 있다.
- 마찬가지로, balanced continual learning setting에서 forgetting average도 다른 baseline보다 현저히 낮은 수치가 관찰된다.
- OCS는 imbalance setting에서 balanced setting에서보다 더욱 큰 강점을 보였다.
- Accuracy와 forgetting 측면에서 baseline model들보다 훨씬 좋은 성능을 보였고, 이는 baseline model에서는 imbalance 상황에서 current task에 대해 coreset을 select하는 과정이 없으므로 biased estimation이 진행되어 performance degenerate이 일어났다고 볼 수 있다.
4.2.2 Noisy Continual Learning
- Gaussian noise를 적용하여 Rotated MNIST dataset을 noise하게 setting하였다.
- 위의 table을 보면, noise는 모든 baseline의 성능을 상당히 저하시키는 것을 관찰할 수 있다.
- 하지만 저자가 제안한 OCS의 경우, noise rate이 증가함에 따라 accuracy와 forgetting이 심각하게 저하되지는 않는 것으로 보여진다. 이는 task 내에서 similarity와 diversity를 고려하여 coreset을 선정하는 과정이 noise data를 상당부분 제외시키는 것으로 해석 가능하다.
- 하지만 저자가 제안한 OCS의 경우, noise rate이 증가함에 따라 accuracy와 forgetting이 심각하게 저하되지는 않는 것으로 보여진다.
- 이는 과거 task에 대해 similarity와 diversity를 고려하여 coreset을 선정하고, 현재 task에 대해서도 affinity를 고려하여 coreset을 선정하기 때문에 noise data를 상당부분 제외시킨 채로 학습을 진행하기 때문인 것으로 해석이 가능하다.
4.2.3 Ablation Studies
- 본 실험은 gradient 활용의 효과를 검증한 실험이다.
- Gradient를 활용하여 coreset selection을 한 경우와 raw input (Input-OCS), feature-representations (Feat-OCS)를 활용하여 coreset selection을 한 경우를 비교하였는데, balanced / imbalanced CL setting에서 모두 gradient가 다른 두 방법에 비해 좋은 성능을 보였다.
- Coreset에 들어갈 top k개의 data를 고르는 과정에서 “Minibatch similarity”, “Sample diversity”, “Coreset affinity”라는 세가지 식이 적용된다.
- 본 실험은 이러한 세가지 component의 효과를 관찰하기 위해 component를 제외해보며 abalation study를 진행하였다.
- Similarity와 diversity를 혼자만 사용하는 것은 성능 저하가 상당했다. Similarity만 사용할 경우, 중복되는 data를 선정할 가능성이 있고, diversity만 고려할 경우 representative한 data point를 선정하는데에 한계가 있을 것이다.
- 따라서 similarity와 diversity의 고려 비율을 적절히 interpolate해야 좋은 성능이 도출될 것이고, 아래의 그림은 “noisy rotated MNIST”, “multiple dataset”에서 각각의 interpolate ratio에서의 average test accuracy를 나타낸다.
- Affinity는 이전 task의 coreset과 유사한 gradient direction을 가진 candidates를 선정하기 때문에 forgetting을 방지한다.
- 하지만, affinity는 앞선 task의 replay buffer의 quality에 dependent하기 때문에 홀로 고려될 경우 performance에 큰 도움을 주지 못한다.
- 따라서 위의 table을 보면, multiple dataset에서 affinity가 홀로 고려될 경우에는 performance가 낮지만, task마다 동일한 class set이 혼재되어있는 domain-incremental setting (Noisy Rot-MNIST)에서는 꽤나 좋은 성능이 관찰된다.
5. Conclusion
5.1 Summary
- Vision 분야의 continual learning problem에서 catastrophic forgetting을 방지할 수 있는 framework를 제안함.
- Continual learning의 큰 줄기 중에서 replay 방식을 채택하였고, online coreset selection을 접목시켜 similarity (대표성), diversity (overfitting 방지)를 동시에 고려하여 가장 영향력이 높은 node를 buffer에 저장하도록 함.
- 다양한 실험을 통해 targeting problem의 예시를 실제로 보임.
5.2 Discussion
- 본 논문의 가장 주요한 novelty는 replay를 할 때, similarity (대표성), diversity (overfitting 방지)를 함께 접목시킨 것이다.
- 또한, 저자는 current task 내에서도 이전 task의 replay buffer와의 관계를 고려하여 data를 select하기 때문에 sequential task의 순서가 바뀌더라도 robust하게 대응할 수 있는 order-robustness의 특성을 가진다고 주장한다.
- Method에 담은 algorithm에서의 task의 개수를 미리 알고 있다는 가정은 real-world setting에 맞지 않은 strong assumption일 수 있겠다.
-
이렇듯 replay 방식에서는 buffer에 넣을 node를 선택하는 방법론이 매우 주요할 것이고, experience selection strategy에 대하여 연구해보는 것도 좋은 future research topic이 될 것 같다.
Author Information
- Seungyoon Choi
- Affiliation : DSAIL@KAIST
- Research Topic : GNN, Continual Learning, Active Learning
6. Reference & Additional materials
6.1 Github Implementation
- https://github.com/jaehong31/OCS
6.2 Reference
- Yoon, Madaan, Yang, Hwang. “Online Coreset Selection for Rehearsal-based Continual Learning.”
- Yoon, Yang, Lee, Hwang. “Lifelong Learning with Dynamically Expandable Networks.”