AIㆍML / 개발자

"가속기를 단 넘파이" 구글 JAX 시작하기

Martin Helle | InfoWorld 2022.07.20
인기 있는 오픈소스 텐서플로우 머신러닝 플랫폼을 이끄는 혁신 중에서 자동 미분(오토그라드)과 XLA(가속 선형 대수)가 있다. 딥러닝을 위한 컴파일러를 최적화하는 기술이다. 구글 JAX는 이 2가지를 결합하는 또 다른 프로젝트로, 속도와 성능 측면에서 상당한 이점이 있다. GPU 또는 TPU에서 실행하면 JAX는 넘파이를 호출하는 다른 프로그램을 대체할 수 있으며 실행 속도도 더 빠르다. 또한 신경망에 JAX를 사용하면 텐서플로우와 같은 큰 프레임워크를 확장하는 것보다 더 쉽게 새로운 기능을 추가할 수 있다.

여기서는 구글 JAX의 이점과 한계, 설치 방법을 포함한 개요를 소개하고 코랩(Colab)에서 구글 JAX를 간단하게 체험해 보자.
 

오토그라드의 정의

오토그라드(Autograd)는 라이언 애덤스의 하버드 지능형 확률 시스템 그룹(Harvard Intelligent Probabilistic Systems Group)의 연구 프로젝트로 시작된 자동 미분 엔진이다. 현재 이 엔진은 유지보수는 되고 있지만 활동적으로 개발은 이뤄지지 않고 있다. 대신 오토그라드 개발진은 XLA JIT 컴파일과 같은 추가 기능을 오토그라드와 결합하는 구글 JAX 개발에 참여하고 있다. 오토그라드 엔진은 네이티브 파이썬 및 넘파이 코드를 자동으로 미분한다. 주 응용 영역은 경사도 기반 최적화다.

텐서플로우의 tf.GradientTape?API는 오토그라드와 비슷한 개념을 기반으로 하지만 구현 방식은 다르다. 오토그라드는 완전히 파이썬으로 작성되며 함수에서 바로 경사도를 계산하는 반면 텐서플로우의 경사도 테이프 기능은 C++로 작성되며 파이썬 래퍼를 사용한다. 텐서플로우는 역전파를 사용해 손실 차이를 계산하고 손실의 경사도를 추정하고 최선의 다음 단계를 예측한다.
 

XLA의 정의

XLA는 텐서플로우가 개발한 선형 대수를 위한 도메인별 컴파일러다. 텐서플로우 문서에 따르면 XLA는 소스 코드 변경 없이 텐서플로우 모델을 가속화하고 속도와 메모리 사용을 개선할 수 있다. 한 예로, 2020년 구글 BERT MLPert 벤치마크에서 XLA를 사용한 8개의 볼타(Volta) V100 GPU가 성능 7배, 배치 크기는 5배 개선되는 결과를 냈다.

XLA는 텐서플로우 그래프를 주어진 모델에 맞추어 생성되는 계산 커널 시퀀스로 컴파일한다. 이러한 커널은 그 모델에 고유하므로 모델에 특화된 정보를 최적화에 활용할 수 있다. XLA는 텐서플로우 내에서 JIT(Just-In-Time) 컴파일러로도 불린다. 다음과 같이 @tf.function 파이썬 데코레이터에서 플래그를 사용해 활성화할 수 있다.
 

@tf.function(jit_compile=True)


또한 TF_XLA_FLAGS 환경 변수를 설정하거나 독립적인 tfcompile 툴을 실행하는 방법으로도 텐서플로우에서 XLA를 활성화할 수 있다. 텐서플로우 외에 XLA 프로그램을 생성할 수 있는 것은 구글 JAX, 줄리아(Julia), 파이토치, Nx 등이 있다.
 

구글 JAX 시작하기

필자가 살펴본 코랩의 JAX 퀵스타트는 기본적으로 GPU를 사용한다. TPU를 선호한다면 TPU를 사용하도록 선택할 수 있지만 TPU 월별 무료 사용량은 제한돼 있다. 또한 코랩 TPU를 구글 JAX에 사용하려면 별도의 초기화 작업을 해야 한다.

퀵스타트를 시작하려면 JAX의 병렬 평가 문서 페이지 상단에 있는 코랩에서 열기(Open in Colab)를 누른다. 그러면 라이브 노트북 환경으로 전환된다. 노트북의 연결(Connect) 버튼을 눌러 호스팅되는 런타임에 연결한다. GPU에서 퀵스타트를 실행한 결과 JAX가 행렬 대수와 선형 대수 연산을 얼마나 가속화할 수 있는지 명확하게 볼 수 있었다. 나중에 노트북에서 마이크로초 단위로 측정된 JIT 가속 시간을 확인했다. 코드를 보면 연상이 되겠지만 대부분이 딥러닝에서 사용되는 일반적인 함수를 표현한다.
 
구글 JAX 퀵스타터의 매트릭스 매스 예제
 

JAX 설치하기

JAX 설치는 운영체제, 선택한 CPU와 GPU, TPU 버전에 따라 다르다. CPU의 경우 간단하다. 예를 들어 노트북에서 JAX를 실행하려면 다음을 입력한다.
 

pip install --upgrade pip
pip install --upgrade "jax[cpu]"


GPU의 경우 CUDACuDNN이 설치돼 있어야 하고 호환되는 엔비디아 드라이버도 필요하다. 두 가지 모두 어느 정도 최신 버전이 필요하다. 최신 버전의 CUDA와 CuDNN이 포함된 리눅스에서는 사전 구축된 CUDA 호환 휠을 설치할 수 있다. 그 외의 경우 소스에서 빌드해야 한다. JAX는 구글 클라우드 TPU를 위해 사전 구축된 휠도 제공한다. 클라우드 TPU는 코랩 TPU보다 더 최신이며 하위 호환되지 않지만 코랩 환경에는 이미 JAX와 적절한 TPU 지원이 포함돼 있다.
 

JAX API

JAX API에는 3가지 계층이 있다. 최상위 계층의 JAX는 넘파이 API인 jax.numpy의 미러를 구현한다. numpy로 할 수 있는 거의 모든 일을 jax.numpy로 할 수 있다. jax.numpy의 제약은 넘파이 배열과 달리 JAX 배열은 불변성이란 점이다. 즉, 일단 생성되면 그 내용을 변경할 수 없다.

중간 계층의 JAX API는 넘파이 계층에 비해 더 엄격하고 많은 경우 더 강력한 jax.lax다. jax.numpy의 모든 연산은 최종적으로 jax.lax에 정의된 함수 측면에서 표현된다. jax.numpy는 혼합된 데이터 유형 간의 연산을 허용하도록 암시적으로 인수를 프로모션하지만 jax.lax는 그렇지 않고 명시적인 프로모션 함수를 제공한다. API의 최하위 계층은 XLA다. 모든 jax.lax 연산은 XLA에서의 연산을 위한 파이썬 래퍼다. 모든 JAX 연산은 최종적으로 이러한 근본적인 XLA 연산의 측면에서 표현되고, 이로써 JIT 컴파일이 실현된다.
 

JAX의 한계

JAX 변환과 컴파일은 함수적으로 순수한 파이썬 함수에서만 작동한다. 함수에 부수적 효과가 있는 경우 print()문과 같이 단순한 것이라 해도 코드를 여러 번 실행하면 다른 부수적 효과가 나타난다. 이후 실행에서 print()는 다른 것을 출력하거나 아예 아무것도 출력하지 않을 수 있다. 그 외의 JAX의 제약은 제자리(in-place) 변이가 허용되지 않는다는 것이다(배열이 불변성이므로). 이 제약은 외부(out-of-place) 배열 업데이트를 허용하는 방법으로 우회할 수 있다.
 

updated_array = jax_array.at[1, :].set(1.0)


또한 넘파이는 기본적으로 배정밀도(float64)로 설정되지만 JAX는 기본 단정밀도 숫자(float32)로 설정된다. 배정밀도가 꼭 필요한 경우 JAX를 jax_enable_x64로 설정할 수 있다.
 

가속 신경망에 JAX 사용하기

지금까지 살펴본 내용을 종합하면 확실히 JAX로 가속 신경망을 구현할 수 있다. 하지만 반대로 생각해보면 굳이 새로운 방법을 사용할 필요가 있을까? 구글 리서치 그룹과 딥마인드는 JAX 기반의 여러 신경망 라이브러리를 오픈소스화했다. 플랙스(Flax)는 신경망 학습을 위한 온전한 라이브러리이며 예제와 사용법 가이드를 함께 제공한다. 하이쿠(Haiku)는 신경망 모듈에 사용되며 옵택스(Optax)는 경사도 처리와 최적화, RLax는 RL(강화 학습) 알고리즘, 첵스(chex)는 안정적인 코드 및 테스트에 사용된다.
 

JAX에 대해 더 알아보기

JAX에는 JAX 퀵스타트 외에 코랩에서 실행할 수 있는 다양한 자습서가 있다. 첫 번째 자습서는 jax.numpy 함수, grad 및 value_and_grad 함수와 @jit 데코레이터를 사용하는 방법을 보여준다. 그다음 자습서는 JIT 컴파일에 대해 더 심층적으로 다루고, 마지막 자습서에서는 단일 및 다중 호스트 환경에서 함수를 컴파일하고 자동으로 분할하는 방법을 배울 수 있다.

또한 JAX 참조 문서도 읽을 수 있고, 읽어야 하며(FAQ부터 시작) 코랩에서 고급 자습서(오토디프 쿡북(Autodiff Cookbook)부터 시작)를 실행할 수 있다. 마지막으로, 주 JAX 패키지를 시작으로 API 문서도 읽어볼 만하다.
editor@itworld.co.kr

회사명 : 한국IDG | 제호: ITWorld | 주소 : 서울시 중구 세종대로 23, 4층 우)04512
| 등록번호 : 서울 아00743 등록발행일자 : 2009년 01월 19일

발행인 : 박형미 | 편집인 : 박재곤 | 청소년보호책임자 : 한정규
| 사업자 등록번호 : 214-87-22467 Tel : 02-558-6950

Copyright © 2024 International Data Group. All rights reserved.