고성능 딥러닝 프레임워크: JAX/Flax
JAX는 Google에서 만든 라이브러리입니다. JAX/Flax의 모두팝 발표 소개와 JAX에 대해 알아봅시다.
안녕하세요? 오늘은 딥러닝계의 뉴진스! 핫한 JAX를 소개해 드릴게요.
JAX(잭스)는 Google에서 만든 라이브러리입니다. 기존에 있던 NumPy 대신 활용할 수 있고, CPU, GPU, TPU에서 코드 변경 없이 사용 가능하며 빠르다는 장점으로 머신러닝에서 활용합니다. 최근 Google Research에서 최신 모델을 JAX로 구현하는 등 활용도가 높아지고 있습니다. 지난 2월 7일 <Google과 Huggingface에서 밀고 있는 고성능 딥러닝 프레임워크: JAX/Flax>란 제목으로 박정현님께서 세미나를 하였고, 해당 영상은 모두의연구소 유튜브에서 보실 수 있습니다.
JAX란?
JAX는 XLA와 Autograd를 이용해 머신러닝 연구와 고성능 연산 작업을 위해 만든 프레임워크입니다. 자동미분(Autograd)이 가능하고, JIT 컴파일 방식으로 속도 측면에서 강점을 보입니다. 기존 LLM같이 거대 모델의 경우, 학습 속도에 효율적으로 활용될 수 있습니다. 실제 SOTA 모델들 가운데 활용이 늘고 있는 추세입니다. 국내 자료가 많지 않기 때문에 진입장벽이 높은 편입니다.
⭐JAX의 장점과 단점
- 장점
- NumPy와 문법이 유사하여 활용이 쉽다.
- CPU, GPU, TPU의 코드가 동일하다
- 속도가 빠르다. JIT 컴파일 방식으로 더 빠른 속도로 실행한다.
- 단점
- JAX는 functional programming를 따르기 때문에 NumPy와 다르게 유의할 점이 많다.
- 최신 모델들에 활용은 많지만, 전통적인 모델들에 대한 코드가 부족하다.
JAX/Flax를 소개한 모두팝 소개
아까 진입장벽이 높다고 하였는데, 그 벽을 낮추기 위해서 JAX/Flax LAB에서 활동하시는 박정현님께서 모두팝에서 발표를 하셨습니다. JAX에 대해 자세히 알고 싶은 분들은 🔗모두팝 영상을 시청 바랍니다.
⭐ JAX 101
JAX를 입문하시는 분들을 대상으로 발표하였습니다. JAX를 입문하기 이전에는 기본적인 머신러닝 지식과 Python 활용 능력을 갖추셔야 합니다. 추가로 JAX를 활용한 프레임워크들에 대한 소개와 Flax 101까지 함께 설명하였습니다. 생태계를 이해하고, 좀 더 편리하게 사용할 수 있는 Flax를 알 수 있었습니다.
⭐JAX/Flax LAB 소개
JAX/Flax LAB은 기존에 딥러닝을 연구하는 분들과 JAX와 Flax라는 새로운 프레임워크를 연구합니다. 영어로 작성된 튜토리얼을 함께 번역하거나 TensorFlow나 PyTorch를 이용하고 있는 분들께서 활용할 수 있도록 JAX/Flax로 변환하는 프로젝트를 진행하고 있습니다. 국내 JAX생태계를 구축하여 많은 사람이 JAX를 활용할 수 있도록 하고자 합니다. 잭스 Document 한국어 번역에서는 공식 문서를 번역하고 있습니다. 번역과정에서 그라운드 룰을 정하거나 업무 분담을 하고 용어를 정리하는 등 여러 검수 절차를 갖춰서 진지하게 번역에 임하고 있습니다. JAX/Flax 코드 변환에서는 최신 모델에 대한 변환은 많지만, 기본적인 모델들에 대한 Reference를 추가할 예정입니다. 앞으로의 활동이 궁금하신 분들은 모두의연구소 커뮤니티나 JAX KR을 방문해주세요.
JAX를 공부하면 좋은 자료 & Reference
- 공식문서 : https://jax.readthedocs.io
- 깃허브 : https://github.com/google/jax
- 추천자료 : Get started with JAX
- 강력추천자료 : 고성능 딥러닝 프레임워크 JAX/Flax
📍모두의연구소 모두팝(MODUPOP)에서 양질의 세미나가 매주 진행됩니다. 재미난 주제에 관심 있으신 분은 festa를 확인해주세요.
3/7 [MODUPOP] StableDiffusion과 ChatGPT
신청링크 : https://festa.io/events/3163