월드모델이란?
현재 상태와 행동이 주어졌을 때 다음 상태에 대한 확률분포를 만드는 생성모델로, 무작위로 이동하면서 그에 따른 환경적 변화를 학습하게 되면 모델이 새로운 작업에 대해 처음부터 스스로 훈련할 수 있게 되지 않을까?가 이 논문의 핵심입니다.
모델이 스스로 생성한 꿈속 세상(world)에서 수행한 실험을 통해 특정 작업을 처리하는 방법을 배우는 것을 보여주어, 생성모델링을 강화학습과 같은 다른 머신러닝 모델과 함께 적용했을 때 실용적인 문제를 해결하는 방법을 보여주는 훌륭한 사례입니다.
0. 사전 개념 : 강화 학습
월드 모델은 다양한 머신러닝 기법들이 섞여 있어서 하나하나씩 설명해보려고 합니다.
기본 개념
강화학습은 주어진 환경에서 에이전트가 특정 목적과 관련해서 최적의 성능을 발휘하는 것을 목표로 하는 머신러닝 기법입니다.
대부분의 머신러닝 모델은 손실함수를 최소화하는 것이 목적이지만 강화 학습은 주어진 환경에서 에이전트가 받을 장기간 보상을 최대화하는 것이 목적입니다.
이러한 강화학습의 핵심용어는 아래와 같습니다.
- 환경 : 에이전트가 작동하는 세상입니다. 주어진 행동이 다음 게임 상태에 미치는 영향을 결정하는 규칙으로 환경이 구성된다.
- 에이전트 : 환경에서 실제 행동을 수행하는 주체입니다.
- 상태 : 에이전트가 경험할 수 있는 특정 상황을 표현하는 데이터입니다.
- 행동 : 에이전트가 수행할 수 있는 움직임입니다.
- 보상 : 행동이 수행되고 나서 에이전트가 환경으로부터 받을 값입니다. (이 보상의 합을 최대화하는 것이 목표입니다.)
- 에피소드 : 에이전트가 1회 실행하는 것을 의미합니다.
- 타임스텝 : 이산적인 이벤트 환경을 위해 t로 표현합니다.
CarRacing 환경
강화학습은 일종의 시뮬레이션이 동작되어야하기 때문에 이런 환경을 제공하는 라이브러리가 있습니다. Gymnasium이라는 패키지에서 제공하는 CarRacing 환경을 책에서는 사용합니다.
해당 환경에서의 상태, 행동, 보상, 에피소드는 아래와 같이 정의됩니다.
- 상태 : 트랙과 자동차의 항공뷰를 나타낸 64X64픽셀 크기의 RGB 이미지입니다.
- 행동 : 방향(-1에서 1까지), 가속(0에서 1까지), 브레이크(0에서 1까지) 값의 집합입니다.
- 보상 : 타임스텝마다 음수 페널티 -0.1과 새로운 트랙 타일에 진입하면 1000/N의 양수 보상을 받습니다. N은 트랙을 구성하는 타일의 총 갯수입니다.
- 에피소드 : 자동차가 환경끝까지 주행하여 트랙을 완주하거나 3000 타임 스텝이 지나면 에피소드는 끝이 납니다.
1. 월드 모델 구조
전체적인 구조는 위의 그림과 같으며, 초록색 부분이 월드 모델이고 강화모델 관점에서 보면 행동을 수행하는 에이전트가 어떤 행동을 할지를 월드모델이 신경망을 통해 학습시킨다고 보면 됩니다.
A. VAE
우리는 운전 중에 벌어지는 상황을 결정할 때 보이는 모든 상황을 분석하지 않습니다.(기계 입장에서 보면 픽셀을 하나하나 다 보지 않습니다.) 대신 시각 정보 중에 운전에 도움이 되는 몇가지 요소를 추려서 그걸 기반으로 다음 행동을 선택합니다.
VAE 역시도 encoder 과정이 있기 때문에 유사한 작동을 하길 기다리며 월드 모델에는 해당 역할을 목적으로 VAE를 넣었습니다.
B. MDN - RNN
우리가 운전을 하는 상황을 생각해보면 현재의 도로 상황 정보로 10분 뒤의 도로 주행에 활용하지는 않습니다. 즉 직전 정보가 매우 중요하기에 이런 특성을 고려해 RNN을 활용합니다.
따라서 이 네트워크는 VAE가 전달해주는 핵심 운전정보를 기반으로 다음 상태를 예측하는 모델로 동작합니다.
C. 컨트롤러
컨트롤러는 Fully Connected Layer입니다. 앞의 VAE가 만들어낸 벡터와, MDN-RNN이 출력한 벡터를 두개 입력으로 받습니다. 그리고 회전, 가속, 브레이크라는 3개의 출력을 가집니다.
지도학습 데이터가 없기 때문에 강화학습을 사용하여 반복 실험을 통해 스스로 좋은 행동과 나쁜 행동을 찾습니다.
D. 훈련 과정
1. 랜덤한 롤아웃 데이터(시나리오)를 모읍니다. 처음 에이전트는 수행하는 작업에 상관없이 랜덤한 행동으로 환경을 탐험하고 이 때마다 상태, 행동, 보상이 저장됩니다.
2. 이를 활용해 VAE가 상태를 잠재 벡터로 표현하도록 훈련합니다. 실제 입력, 출력을 시켜보면 원본 이미지 중에 일부가 blur처리되는듯이 보여서 중요한 부분만 본다고 인지할 수 있습니다.
- VAE는 인코딩하는 대상이 정규분포에 사용할 평균과 분산입니다. 이렇게 학습시킨 평균과 분산을 그대로 디코더에 넣는게 아니라 이 두 모수를 활용해 정규분포를 만들고 여기서 샘플링을 한 값을 디코더에 넣습니다.
- 실제로 월드 모델에서 인코딩해서 나온 벡터 2개로 선형보간을 해보면 각 벡터가 어떻게 학습하도록 유도하는지를 살펴볼 수 있습니다.
3. MDN-RNN을 훈련시키기 위해 VAE를 훈련하고 수집한 샘플과, 행동과 보상을 통해 훈련합니다.
- VAE로부터 인코딩된 Z벡터(길이 32) + 현재 행동(길이 3) + 이전 보상(길이 1) 을 연결한 입력으로 학습합니다.
- 단순 RNN이 아니라 LSTM이고 이 모델의 아웃풋에 Fully Connected Layer를 붙이는 형태입니다.
- 이렇게 붙인 부분을 MDN (Mixture Density Network) 이라고 하는데 각 시퀀스 포인트에 대한 출력을 확률 분포의 혼합으로 표현합니다. 이는 단일 출력 값 대신 여러 가능한 결과를 표현할 수 있어, 불확실성이 높거나 다중 모드가 있는 데이터에 적합합니다.
- 이게 왜 들어있는지 이해하는데 좀 시간이 걸렸는데 결론적으로 RNN은 결정론적으로 특정 상태를 확신하는 반면에 시뮬레이션 상황에서는 어떻게 될지 알 수가 없다. 따라서 확률적으로 표현하기 위해서 이러한 구조를 추가했다고 이해하면 되고 밑에 나오는 보상을 계산할 때에도 다음 액션에 대해서 이 수치를 활용한다.
- 이렇게 붙인 부분을 MDN (Mixture Density Network) 이라고 하는데 각 시퀀스 포인트에 대한 출력을 확률 분포의 혼합으로 표현합니다. 이는 단일 출력 값 대신 여러 가능한 결과를 표현할 수 있어, 불확실성이 높거나 다중 모드가 있는 데이터에 적합합니다.
- Loss Function은 z벡터 재구성 손실과 보상 손실의 합으로 동작한다.
4. 컨트롤러는 2,3에서 주어진 인풋과 실제 롤아웃에서 나온 결과들을 종합해서 훈련시킵니다. 여기서 CMA-ES라는 진화 알고리즘을 사용하는데 이와 관련된 설명은 이 블로그 글 에서 보실 수 있습니다.
- 은닉층이 없는 완전 연결 신경망으로 z벡터(길이 32)와 현재 LSTM의 은닉상태(MDN의 샘플링 값이 아님)를 연결해 총 길이가 288인 벡터를 인풋으로 받습니다.
- 하지만 정확한 답(이렇게 행동하세요)이 없기 때문에 CMA-ES를 통해서 학습합니다.
- 에이전트의 개체군을 만들고 파라미터를 랜덤하게 초기화합니다.
- 다음 과정을 반복합니다.
- 환경 안에서 여러 에피소드를 수행하여 얻은 평균적인 보상으로 각 에이전트를 평가합니다.
- 가장 좋은 점수의 에이전트를 사용해 새롭게 개체를 만듭니다.
- 새로운 개체의 파라미터에 무작위성을 주입합니다.
- 새로 만든 에이전트를 추가하고 나쁜 성능의 에이전트를 제거하는 식으로 개체군을 업데이트합니다.
- 하지만 시뮬레이션이 독립적이기 때문에 병렬로 돌릴 수 있어서 속도를 빠르게 할 수 있다는 장점이 있습니다.
2. 꿈속에서 훈련하기
앞서서는 이미 주어져있는 데이터를 활용해서 훈련하는 것을 배웠다면 이제 현재 상태에서 다음 상태로 시뮬레이션을 이동하는 STEP 메소드가 필요하다.
현재 z벡터와 선택한 행동이 주어지면 MDN-RNN으로부터 샘플링하여 다음 z벡터와 보상에 대한 예측을 출력합니다. MDN-RNN은 랜덤하게 이동한 원본 데이터셋에서 실제 환경의 물리적 특성을 충분히 학습했기에 이를 가상 환경으로 만들어서 사용할 수도 잇습니다.
이 때 샘플링하는 z가 너무 같은 환경이 나오면 오버피팅이 발생할 수 있어서, temperature라는 변수를 추가해서, 분산을 크게 만들어 더 변덕스러운 상황을 자주 보여줘서 학습시켜서 이런 문제를 해결합니다.
실제 이 논문의 저자가 모델이 실제로 어떻게 돌아가는지 GIF 형태로 잘 정리해높은 사이트를 아래에서 볼 수 있다.
https://worldmodels.github.io/
'Machine Learning > 유튜브, 책, 아티클 정리' 카테고리의 다른 글
[book] 머신러닝 디자인 패턴 1. 데이터 표현 디자인 패턴 (0) | 2024.01.08 |
---|---|
만들면서 배우는 생성 AI 13장 - 멀티모달 모델 (1) | 2024.01.03 |
만들면서 배우는 생성 AI 11장 - 음악생성 (1) | 2023.12.15 |
만들면서 배우는 생성AI 10장 : 고급 GAN (0) | 2023.12.14 |
만들면서 배우는 생성AI 9장 : 트랜스포머 모델 (0) | 2023.12.12 |