비즈니스 문제를 해결하고 예측하는 데이터 사이언티스트가 되고 싶다면?
#인공지능 

고성능 딥러닝 프레임워크: JAX/Flax

JAX는 Google에서 만든 라이브러리입니다. JAX/Flax의 모두팝 발표 소개와 JAX에 대해 알아봅시다.

2023-02-22 | 유현아

안녕하세요? 오늘은 딥러닝계의 뉴진스! 핫한 JAX를 소개해 드릴게요.
JAX(잭스)는 Google에서 만든 라이브러리입니다. 기존에 있던 NumPy 대신 활용할 수 있고, CPU, GPU, TPU에서 코드 변경 없이 사용 가능하며 빠르다는 장점으로 머신러닝에서 활용합니다. 최근 Google Research에서 최신 모델을 JAX로 구현하는 등 활용도가 높아지고 있습니다. 지난 2월 7일 <Google과 Huggingface에서 밀고 있는 고성능 딥러닝 프레임워크: JAX/Flax>란 제목으로 박정현님께서 세미나를 하였고, 해당 영상은 모두의연구소 유튜브에서 보실 수 있습니다.

JAX란?

JAX는 XLA와 Autograd를 이용해 머신러닝 연구와 고성능 연산 작업을 위해 만든 프레임워크입니다. 자동미분(Autograd)이 가능하고, JIT 컴파일 방식으로 속도 측면에서 강점을 보입니다. 기존  LLM같이 거대 모델의 경우, 학습 속도에 효율적으로 활용될 수 있습니다. 실제 SOTA 모델들 가운데 활용이 늘고 있는 추세입니다. 국내 자료가 많지 않기 때문에 진입장벽이 높은 편입니다.

JAX의 장점과 단점

  • 장점
    • NumPy와 문법이 유사하여 활용이 쉽다.
    • CPU, GPU, TPU의 코드가 동일하다
    • 속도가 빠르다. JIT 컴파일 방식으로  더 빠른 속도로 실행한다.
  • 단점
    • JAX는 functional programming를 따르기 때문에 NumPy와 다르게  유의할 점이 많다.
    • 최신 모델들에 활용은 많지만, 전통적인 모델들에 대한 코드가 부족하다.

 

JAX/Flax를 소개한 모두팝 소개

아까 진입장벽이 높다고 하였는데, 그 벽을 낮추기 위해서 JAX/Flax LAB에서 활동하시는 박정현님께서 모두팝에서 발표를 하셨습니다. JAX에 대해 자세히 알고 싶은 분들은 🔗모두팝 영상을 시청 바랍니다.

JAX 101

JAX를 발표하시는 연사님

JAX를 입문하시는 분들을 대상으로 발표하였습니다. JAX를 입문하기 이전에는 기본적인 머신러닝 지식과 Python 활용 능력을 갖추셔야 합니다. 추가로 JAX를 활용한 프레임워크들에 대한 소개와 Flax 101까지 함께 설명하였습니다. 생태계를 이해하고, 좀 더 편리하게 사용할 수 있는 Flax를 알 수 있었습니다.

JAX/Flax LAB 소개

JAX/플랙스 랩

JAX/Flax LAB은 기존에 딥러닝을 연구하는 분들과 JAX와 Flax라는 새로운 프레임워크를 연구합니다. 영어로 작성된 튜토리얼을 함께 번역하거나 TensorFlow나 PyTorch를 이용하고 있는 분들께서 활용할 수 있도록 JAX/Flax로 변환하는 프로젝트를 진행하고 있습니다. 국내 JAX생태계를 구축하여 많은 사람이 JAX를 활용할 수 있도록 하고자 합니다. 잭스 Document 한국어 번역에서는 공식 문서를 번역하고 있습니다. 번역과정에서 그라운드 룰을 정하거나 업무 분담을 하고 용어를 정리하는 등 여러 검수 절차를 갖춰서 진지하게 번역에 임하고 있습니다. JAX/Flax 코드 변환에서는 최신 모델에 대한 변환은 많지만, 기본적인 모델들에 대한 Reference를 추가할 예정입니다. 앞으로의 활동이 궁금하신 분들은 모두의연구소 커뮤니티JAX KR을 방문해주세요.

JAX를 공부하면 좋은 자료 & Reference

 

📍모두의연구소 모두팝(MODUPOP)에서 양질의 세미나가 매주 진행됩니다. 재미난 주제에 관심 있으신 분은 festa를 확인해주세요.
     3/7 [MODUPOP] StableDiffusion과 ChatGPT
     신청링크 : https://festa.io/events/3163