논문리뷰 – Neural Architecture Search with Reinforcement Learning
Neural Architecture Search(NAS)란?
Neural Architecture Search with Reinforcement Learning이라는 논문은 Google Brain(구글 브레인)에서 2016년에 발표한 논문으로 기본적인 아이디어는 적합한 신경망 구조를 예측하는 신경망을 자동적으로 만들어서, 기존에는 인간의 지식에 의해 설계되던 신경망 구조를 자동화된 방법으로 찾아내는 것입니다.
Auto ML을 제시
이는 AutoML이라는 새로운 연구방향을 제시한 선구적인 논문으로, 본 논문에는 좋은 아키텍쳐들을 자동적으로 찾기 위해 gradient-based 방법들을 Neural Architecture Search를 제안합니다. NAS의 기본 구조는 다음 그림 1과 같습니다.
Controller라는 RNN 네트워크가 다양한 하이퍼파라미터 값을 예측하고, 이를 기반으로 한 구조를 가진 Child Networks를 생성하여 특정 문제를 해결하도록 학습시킵니다. Controller Networks는 Child Networks가 문제 영역에서 보이는 성능에 따라 정확도 (Reward R)을 받습니다. 이 정확도 R을 리워드 신호로서 활용할 수 있고, 이 리워드를 토대로 강화학습(Policy Gradient 기법)을 통해 Controller RNN을 학습시킵니다.
수식과 함께 좀 더 자세히 설명하자면, 최적의 아키텍쳐를 찾기 위해서는, Controller가 예측된 리워드 R을 최대화하도록 할 수 있습니다.
리워드 신호 R은 미분할 수 없기 때문에, 우리는 theta를 반복적으로 갱신하기 위해 정책 경사법(policy gradient method)을 사용할 필요가 있스비다. 이 연구에서는 Williams (1992)의 REINFORCE 방법을 사용합니다. 다음은 REINFORCEMENT에 대한 설명을 한다기 보다는 Controller 역할을 RNN이 어떻게 예측을 이어나가는지를 확인해보겠습니다.
RNN을 이용한 CNN 모델 구조 생성
먼저 RNN Controller를 이용해 간단한 CNN 구조를 결정하는 하이퍼파라미터를 예측하는 경우를 생각해보겠습니다. 그림 2와 같은 RNN Controller를 통해 구조를 결정하는 하이퍼파라미터를 예측합니다. RNN의 각 time-step마다 예측하는 하이퍼파라미터는 다음과 같습니다:
- 필터의 높이 (Height)
- 필터의 너비 (Width)
- 스트라이드의 높이 (Height)
- 스트라이드의 너비 (Width)
- 필터의 개수 (Number of filters)
위 과정을 CNN 레이어 수만큼 반복합니다.
각각의 time-step에서 Softmax Classifier를 통한 예측이 수행되며, 이는 다음 time-step의 input으로 사용됩니다. (간단히 말해서, Char-RNN 구조와 유사합니다.)
특정 레이어 수 이상이 예측되면 예측을 중지하고, 예측된 값에 기반한 구조로 CNN을 만들고 학습시킵니다. 학습된 CNN의 검증 데이터에 대한 정확도를 기록하고, RNN은 다시 검증 정확도를 높이는 방향으로 Policy Gradient 기법을 이용해 학습됩니다.
Parallelism(병렬화)과 Asynchronous Update (비동기적 업데이트)를 통한 학습 가속화
RNN이 생성한 하나의 Child Network를 학습시키는 데도 몇 시간이 걸릴 수 있습니다. 따라서 빠른 학습과 탐색을 위해 병렬화와 비동기적 업데이트를 구현했습니다. 구체적으로는 파라미터 서버 구조를 사용했습니다. 파라미터 서버의 하이퍼파라미터는 다음과 같습니다:
- S: 파라미터 서버 shard의 개수
- K: 컨트롤러 복제본의 개수
- m: Child Network의 개수
아래 그림 3은 parameter-server구조를 나타낸 것입니다.
각각의 Child Network는 병렬로 학습되고, m개의 미니배치를 통해 계산된 각각의 컨트롤러의 그래디언트는 파라미터 서버로 모아집니다. 파라미터 서버는 이들을 종합하여 다시 각각의 컨트롤러 복제본의 파라미터를 갱신합니다.
Skip Connection 추가를 통한 모델 복잡도 향상
현대적인 CNN 구조는 ResNet에서 제안된 Skip Connection을 활용하여 더 복잡한 구조를 사용합니다. 따라서 RNN Controller가 예측하는 구조를 결정하는 하이퍼파라미터에 Skip Connection을 추가하여 더 복잡한 Child Network를 생성할 수 있습니다.
Neural Architecture Search를 이용한 RNN 구조 생성
NAS는 CNN뿐만 아니라 RNN을 생성하는 RNN Controller를 만들 수도 있습니다. RNN의 각 요소를 나눠서 생각하면 RNN은 트리 형태로 표현할 수 있습니다. 따라서 트리 구조를 기반으로 RNN 구조를 생성하는 RNN Controller는 다음과 같습니다. 마지막 두 개의 레이어는 LSTM에서 사용하는 cell state 와 을 트리의 어느 부분과 연결할지를 결정합니다.
Training & Experiments
CIFAR-10 데이터셋에 대해 Neural Architecture Search, NAS로 찾아낸 CNN 구조의 성능을 평가했습니다. 구체적으로 아래와 같이 구성하여 적합한 CNN 구조를 예측했습니다:
- ReLU 활성화 함수 사용
- 배치 정규화 사용
- Skip Connection 사용
- 필터 높이는 [1, 3, 5, 7] 중에서 예측
- 필터 너비는 [1, 3, 5, 7] 중에서 예측
- 필터 개수는 [24, 36, 48, 64] 중에서 예측
- 스트라이드 높이와 너비는 [1, 2, 3] 중에서 예측
학습을 위한 RNN 구조는 각 레이어마다 35개의 노드를 가진 2-layer LSTM으로 구성했습니다.
파라미터 서버를 위한 하이퍼파라미터는 다음과 같이 설정했습니다:
- S: 파라미터 서버 shard의 개수 – 20
- K: 컨트롤러 복제본의 개수 – 100
- m: Child Network의 개수 – 8
즉, 한 번에 800개의 구조를 800개의 GPU를 이용해 학습했습니다. 또한 Child Networks 하나 당 50번의 epoch을 학습시켜 CIFAR-10 분류에 적합한 값으로 파라미터를 갱신했습니다.
최종적으로 12,800개의 Child Networks를 탐색하고, 그중에서 가장 좋은 성능을 보인 구조들을 기존의 최첨단(SOTA) CNN 모델들과 비교한 결과는 다음과 같습니다.
표에서 볼 수 있듯이, 아무런 사전 지식 없이 기존의 SOTA 모델에 버금가는 성능을 가진 CNN 구조를 자동으로 찾아냈습니다.
또 다른 예시로 Neural Architecture Search를 이용해서 찾아낸 CNN중 하나를 그려보면 다음과 같습니다.
References
[1] https://arxiv.org/pdf/1611.01578
[2] http://solarisailab.com/archives/2691
[3] https://ysbsb.github.io/nas/2022/07/27/NAS.html