JAX로 나만의 AI 프로젝트를 만들고 싶다면?
#인공지능 

Hugging Face 그리고 JAX의 만남

전세계에서 가장 핫한 AI 커뮤니티인 Hugging Face는 구글 딥마인드에서 관리하고 있는 라이브러리인 JAX를 기반으로 하는 다양한 활동을 하고 있습니다. 이번 블로그에서는 whisper-jax와 커뮤니티 스프린트에 대해 알아보겠습니다.

2023-05-19 | 이영빈

전세계에서 가장 핫한 AI 커뮤니티를 하나 뽑는다면 저는 Hugging Face라고 생각합니다. Hugging Face의 transformers는 자연어처리에서 가장 많이 사용하고 있는 프레임워크로 자리잡았으며 Diffusers로 Diffusion 모델도 쉽게 구동할 수 있어 많은 사람들에게 사랑받고 있습니다. Hugging Face는 오픈소스를 기반으로 하기 때문에 직원뿐만 아니라 다양한 오픈소스 컨트리뷰터들이 생태계에 기여하고 있고 그 생태계를 활성화하고자 Hugging Face는 자체적으로 혹은 커뮤니티적으로 행사를 진행합니다.

그중에서 Hugging Face가 많이 신경 쓰고 있는 분야는 다름아닌 JAX입니다. JAX는 이전에 블로그()에서 소개해드린것처럼 구글 딥마인드에서 많이 사용하고 있는 고성능 딥러닝 프레임워크입니다. JAX의 경우 다른 프레임워크와 달리 XLA를 기반으로 JIT 컴파일을 지원하기 때문에 학습 속도 및 추론 속도가 다른 프레임워크와 비교할 수 없을 정도로 빠릅니다. 거기다가 JAX는 구글 전용 칩인 TPU에 최적화되어 있어 사용하기 용이합니다. Hugging Face도 JAX의 장점을 알고 있기 때문에 사내 프로젝트 혹은 커뮤니티 프로젝트로 진행하고 있습니다. 이번에 소개할 프로젝트는 Hugging Face 직원인 Sanchit Gandhi가 만든 Whisper-JAX와 Hugging Face가 주최한 커뮤니티 프로젝트인 JAX/DIffusers community sprint입니다.

Whisper-JAX

Whisper JAX는 OpenAI가 만든 음성인식 모델인 Whisper를 JAX로 최적화한 모델입니다. 이 모델은 Sanchit Gandhi가 메인으로 만들었습니다. Gandhi는 Google Cloud TPU를 지원하는 TPU Resarch Cloud(TRC) 프로그램을 이용해서 TPU를 1달 지원받아서 만들었다고 합니다.

whisper-jax

Gandhi가 Whisper를 JAX로 바꾼 이유는 서론에서 설명했다시피 “속도”에 있습니다. Gandhi는 똑같은 A100 40GB 1대로 PyTorch 모델, Hugging Face Transformers, Whisper JAX를 비교했으며 Whisper JAX를 TPU로 추론한 결과를 10번 진행하고 평균을 내서 비교했습니다. OpenAI가 만든 모델은 1분을 텍스트로 변환하는데 평균13.8초, 10분은 평균 108초, 1시간은 평균 1001초가 걸렸습니다. 반면 Whisper JAX는 1분을 텍스트로 변환하는데 평균 1.72초, 10분은 9.38초, 1시간은 평균 75.3초가 걸렸습니다. 만일 TPU를 쓰게 된다면 OpenAI 모델이 1분짜리 음성을 생성할 때 1시간짜리 음성을 텍스트로 변환합니다.

현재 Whisper JAX는 transformers 라이브러리에서 작동하며 HuggingFace hub에서 직접 실행가능하며 만일 TPU를 사용하고 싶다면 Kaggle notebook에서 사용할 수 있습니다.  하단에 링크를 누르면 실제 실습할 수 있습니다.

JAX/Diffusers 커뮤니티 스프린트

Hugging Face가 주최하고 커뮤니티 행사로 진행한 것은 JAX/Diffusers 커뮤니티 스프린트입니다. 커뮤니티 스프린트는 특정 딥러닝 라이브러리를 사용해 공통의 목표를 달성하는 기간 한정 프로젝트입니다. Hugging Face가 이제까지 주최했던 스프린트는 scikit-learn, Keras 등을 진행했고 JAX의 경우 Transformers 라이브러리, NeRF 등 많은 프로젝트를 진행했습니다. 이번에 진행한 커뮤니티 스프린트는 JAX를 Diffusers에 적용해 다양한 모델을 만드는 주제로 진행했습니다.

해당 스프린트는 최신 TPU인 TPU v4-8을 4/14일부터 5/1일까지 사용가능하며 프로젝트를 시작하기 위해서 ControlNet을 사용하는 방법을 알려주고 Hugging Face Space를 데모페이지를 만들고 좋아요를 많이 받은 순서대로 점수를 매겨 시상합니다.

jax-diffusers-first-prize

이번 스프린트에서 1등한 Space는 ControlNet-Interior-Design입니다. 해당 모델은 Segmentation map과 프롬프트를 모델의 정보로 사용해 고품질의 인테리어 디자인 이미지를 생성하도록 만든 것입니다. Segmentation map이 있기에 유저는 어떤 오브젝트를 배치할지 세밀하게 제어가능하게 됩니다. 실제 가구 위치나 인테리어를 직접 꾸밀 수 있어 해당 모델은 사업적인 효과도 꽤 클거라는 생각이 듭니다.

jax-second-prize

2등한 Space는 흑백 이미지에 컬러를 넣는 ControlNet on Brightness이며 3등은 손의 포즈를 예측하고 그걸 기반으로 새로운 이미지를 만드는 Stable Diffusion with Hand Control 입니다. 2개 모두 훌륭한 성능을 보여주고 있고 재미있는 프로젝트였다고 생각합니다.

앞으로의 JAX와 HuggingFace 그리고 모두연은?

JAX/Flax LAB

Hugging Face는 지속적으로 성능이 확실하기 때문에 JAX를 밀고 있습니다. 물론 JAX의 경우 PyTorch에 비해서는 커뮤니티적인 성장이 필수적인 것은 사실입니다. 그러나 전세계에서 가장 핫한 AI 커뮤니티가 Sprint형식으로 밀어줬을 때 파급효과는 지속적으로 있을거라고 생각합니다.

모두의연구소에서는 JAX와 Flax를 활성화하기 위한 LAB이 있습니다. JAX/Flax LAB은 현재 JAX-KR 프로젝트를 진행하고 있으며 그 이외에도 모두팝 세미나 등 많은 곳에서 JAX를 알리기 위한 노력을 하고 있습니다. 모두의연구소 JAX/Flax LAB에 많은 관심 부탁드립니다.