지난 시간에는 보이는 게임화면을 디지털화(?)하는 데 성공했다.
이제 우리는 이 단순화된 화면을 가지고 다양한 방법의 알고리즘을 만들고 돌리려고 한다.
이전까지 하던 것은 약간 복잡한? 전처리였다.
이제는 진짜 머리를 써야하는 부분을 들어가야 할 것이다.
첫번째로는 강화학습(Reinforcement Learning)을 돌려보려고 한다.
강화학습을 위해서는 다양한 라이브러리가 사용될 수 있다.
본인은 pytorch에 매우 익숙해져있기 때문에
pytorch를 사용하여 reinforcement learning을 구축한 라이브러리인 stable baselines 3를 사용하려고 한다.
무엇이든지, 처음 보는 라이브러리를 잘 사용하기 위해서는 공식 홈페이지를 가야한다.
https://stable-baselines3.readthedocs.io/en/master/
Stable-Baselines3 Docs - Reliable Reinforcement Learning Implementations — Stable Baselines3 1.2.0a0 documentation
© Copyright 2020, Stable Baselines3 Revision 2fa06ae8.
stable-baselines3.readthedocs.io
공식 홈페이지의 문서가 영어고 너무 많아서 보기 힘들다면,
최소한 Getting Started와 Examples만 보고 넘어가도록 하자.
나머지는 쓰다가 모르는 것이 나올 때 보면 될 일이다.
강화학습이란 강화학습 모델(agent)이 환경(environment)과 소통하면서 reward가 높아지는 방향으로 학습하는 방법이다.
위 그림처럼 agent는 특정 행동(action)을 선택하고, Environment는 그 행동에 따라서 변화된 상태(state)와 보상(reward)
를 agent에게 알려준다.
Bellman Equation은 모든 강화학습이 기초로 하고 있는 공식이다.
s는 상태이고 s'은 action 에 의해 변화하는 다음 상태이다.
R은 현재 상태 s 에서 a라는 행동을 취했을 때 받을 수 있는 보상이다.
max_a는 괄호 안의 값들을 가능한 모든 action에 대해서 계산하였을 때 가장 큰 값을 의미한다.
V는 특정 상태 s의 가치(?)라고 할 수 있다.
즉, 저 공식은 특정 상태의 가치를 정의하기 위한 식이다.
모든 상태에 대한 가치와 특정 상태에서 어떤 action을 취했을 때 어떤 상태로 가는 지(state transition)에 대해서 완벽히 알 수만 있다면,
우리는 그 환경에 대해서 보상을 최대로 할 수 있는 행동을 취할 수 있는 것이다.
보통은 모든 s에 대한 V 함수값을 완벽히 구하는 것이 불가능하기에, Q라는 함수를 이용하여 간접적으로 구한다.
Q 라는 함수는 가치를 추정하는 함수로서, s,a를 인자로 받는다. 특정상황에서 특정 행동을 취했을 때 받을 수 있을 것으로 예상되는 가치라고 생각하면 된다.
Q 함수를 찾아가는 강화학습 방법을 Q-Learning이라고 한다.
위 공식에 따라 Q 함수 값들을 조금씩 업데이트 해가면 함수값들은 점점 수렴하게 되고,
결국 특정 상황에서 취해야되는 최적의 행동을 알 수 있게 되는 것이다.
강화학습 기본은 이쯤으로 마치는 게 좋겠다. 이 이상 설명하기에는 너무 길어진다.
Value Iteration, Policy Iteration, model-free, policy gradient 등에 대한 공부는 추후에 하도록 하자..
강화학습에는 위처럼 본디 Environment라는 것이 필수적이다!
따라서 본인은 MarioEnv라는 Environment class를 만들어볼 요량이다.
3개의 메소드를 필수적으로 구현해야한다.
init에서는 Observation과 action의 타입과 크기를 정의해야하고,
step에서는 특정 action이 들어왔을 때 보상, 변화된 상태(observation), 게임이 끝났는지(done) 등을 리턴해야한다.
reset에서는 게임이 끝났을 떄 Environment를 재시작하기 위해 필요한 과정을 구현해야 한다.
render 메소드는 Environment의 현재 상황을 시각화 해주는 거라고 보면 된다. 그냥 마음대로 구현하자.
현재까지 진행된 MarioEnv 클래스는 아래와 같다.
import gym
from gym import spaces
import numpy as np
from MarioMemory import MarioMemory
import cv2
class MarioEnv(gym.Env):
"""Custom Environment that follows gym interface"""
metadata = {'render.modes': ['human']}
def __init__(self):
super(MarioEnv, self).__init__()
self.mm = MarioMemory()
self.action_space = spaces.MultiDiscrete([5,2,2])
self.observation_space = spaces.Box(low=0, high=255,
shape=(3, 13*16, 256), dtype=np.uint8)
self.prev_mario_x = 0
def step(self, action):
self.mm.control(action)
self.mm.advance_frame()
observation = self.mm.get_scene()
observation = observation.transpose(2,0,1)
reward = 0
xs = self.mm.get_xs()
mario_x = xs[0]
d = mario_x - self.prev_mario_x
reward += d
self.prev_mario_x = mario_x
t = -1
reward += t
death = self.mm.get_death()
if death:
reward -= 15
done = death
info = {}
self.action = action
self.observation = observation
self.reward = reward
self.render()
return observation, reward, done, info
def reset(self):
self.mm.reset()
observation = self.mm.get_scene()
observation = observation.transpose(2,0,1)
return observation # reward, done, info can't be included
def render(self, mode='human'):
print(self.action)
scene = self.mm.get_scene()
cv2.imshow('', scene)
cv2.waitKey(1)
action space는 MultiDiscrete로 화살표, A키, B키의 입력을 모두 받을 수 있도록 하였다.
step에서는 한 프레임 이동 후 observation과 reward를 계산한다.
observation은 위에서 디지털화 했던 화면을 리턴하고
reward는 총 3개의 값의 합으로 계산하였다.
1. 죽음 페널티
2. 마리오가 오른쪽으로 이동한 정도
3. 시간이 감소한 정도.
시간은 항상 1프레임당 1씩 감소하므로 -1,
죽음 페널티는 -15
마리오의 이동정도는 현재 x 값 - 이전 x 값으로 계산하였다.
위 reward는 gym-super-mario-bros 를 참고하였다.
사실상 강화학습에 필요한건 Env 가 전부 이므로 구현은 끝났다고 보면 된다.
이제 실행해보자.
action과 observation그리고 게임 화면을 순서대로 놓아보았다.
마리오가 앞으로 전진하는 것처럼 보이지만 강화학습 초기의 random action에 불과하다.
이론상 그냥 이대로 계속 두면 언젠가 학습은 되긴 할 것이다...
근데 인간적으로 너무 느리다.
약 4 FPS밖에 안되는 것으로 나오는 데, 이는 1프레임 앞으로 전진시키는 커맨드가 VirtuaNES에 입력되는 시간이 너무 길었기 때문에 발생한 일이다.
VirtuaNES에서는 스페이스바를 누르면 1프레임 앞으로 이동한다.
하지만 스페이스바를 1초에 10번씩 누른다고해서 10프레임이 이동하지 않는다.
10번 누른 것들 중 4개만 인식된다고 보면 된다.
약 두시간에 걸친 10000 step의 학습 완료 후
오른쪽 버튼과 A, B 버튼을 누르는 상태가 되어버렸다. 어느정도 앞으로 가는데 성공했지만 첫번째 파이프에서 막혀버린다.
5부에서 계속...
'기타 잡 코딩' 카테고리의 다른 글
iptables 명령어란? (0) | 2023.03.19 |
---|---|
슈퍼마리오 1을 플레이하는 AI를 만드는 방법(3) (0) | 2021.07.10 |
슈퍼마리오 1을 플레이하는 AI를 만드는 방법(2) (0) | 2021.07.09 |
슈퍼마리오 1을 플레이하는 AI를 만드는 방법 (0) | 2021.07.07 |
폴란드 테트리스를 플레이하는 AI를 만드는 방법(2) (0) | 2021.07.06 |