In [1]:
import numpy as np
from collections import namedtuple
from IPython.core.debugger import Pdb

from GameEngine import multiplayer
from QTable import qtsnake
Point = namedtuple('Point', 'x, y')

pygame 2.5.1 (SDL 2.28.2, Python 3.11.5)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# defines game window size and block size, in pixels
WINDOW_WIDTH = 640
WINDOW_HEIGHT = 480
GAME_UNITS = 80

In [3]:
game_engine = multiplayer.Playfield(window_width=WINDOW_WIDTH,
                                    window_height=WINDOW_HEIGHT,
                                    units=GAME_UNITS,
                                    g_speed=35,
                                    s_size=1)

In [4]:
p1 = game_engine.add_player()
game_engine.start_game()
game_engine.toggle_draw()

Game starting with 1 players.
Draw is now False.


In [5]:
danger_states = 16
goal_relations = 8
actions = 4
q = np.zeros((danger_states,
              goal_relations,
              actions))

In [6]:
def sense_danger(head, tail, units=GAME_UNITS, window_width=WINDOW_WIDTH, window_height=WINDOW_HEIGHT):
    '''
    Given a player's head and tail,
    returns a list of actions that does
    not result in immediate death
    '''
    danger_array = np.array([
        head.y-units <  0             or Point(head.x, head.y-units) in tail[1:], # up
        head.x+units >= window_width  or Point(head.x+units, head.y) in tail[1:], # right
        head.y+units >= window_height or Point(head.x, head.y+units) in tail[1:], # down
        head.x-units < 0              or Point(head.x-units, head.y) in tail[1:], # left
    ])

    binary_string = "".join(map(str, map(int, danger_array)))

    return int(binary_string, base=2)

In [7]:
x_b = 5; y_b = 5;
head = Point(3,3); tail = [head]
assert sense_danger(head,tail,1,x_b,y_b) == 0

head = Point(0,0); tail = [head, Point(0,1)]
assert sense_danger(head,tail,1,x_b,y_b) == 11

head = Point(3,3); tail = [head, Point(3,4)]
assert sense_danger(head,tail,1,x_b,y_b) == 2

head = Point(3,3); tail = [head, Point(2,3)]
assert sense_danger(head,tail,1,x_b,y_b) == 1

In [8]:
def index_actions(q, id):
    '''
    given q, player_id, an array of heads,
    and the goal position,
    indexes into the corresponding expected
    reward of each action
    '''
    heads, tails, goal = game_engine.get_heads_tails_and_goal()
    state = np.array([sense_danger(heads[id], tails), qtsnake.sense_goal(heads[id], goal)])
    return state, q[state[0], state[1], :]

In [9]:
def pick_greedy_action(q, id, epsilon):
    state, rewards = index_actions(q, id)

    if np.random.uniform() < epsilon:
        return np.append(state, [np.random.choice([0, 1, 2, 3])])
    action = np.argmin(rewards)
    return np.append(state, [action])

In [10]:
def update_q(q, old_state_action, new_state_action, outcome, lr=0.05):
    if outcome == multiplayer.CollisionType.GOAL:
        q[new_state_action[0], new_state_action[1], new_state_action[2]] = 0
    elif outcome == multiplayer.CollisionType.DEATH:
        q[new_state_action[0], new_state_action[1], new_state_action[2]] = 500
    else:
        td_error = 1 + q[new_state_action[0], new_state_action[1], new_state_action[2]] - q[old_state_action[0], old_state_action[1], old_state_action[2]]
        q[old_state_action[0], old_state_action[1], old_state_action[2]] += lr * td_error

In [11]:
n_steps = 52500
epsilon = 1
final_epsilon = 0.03
epsilon_decay =  np.exp(np.log(final_epsilon) / (n_steps))
lr = 0.15

In [12]:
p1_old_s_a = pick_greedy_action(q, p1, epsilon) # state, action
game_engine.player_advance([p1_old_s_a[-1]])

for step in range(n_steps):
    p1_new_s_a = pick_greedy_action(q, p1, epsilon) # state, action
    outcome = game_engine.player_advance([p1_new_s_a[-1]], 0.05)

    update_q(q, p1_old_s_a, p1_new_s_a, outcome[p1], lr)

    epsilon *= epsilon_decay
    p1_old_s_a = p1_new_s_a

In [13]:
game_engine.toggle_draw()

Draw is now True.


In [14]:
epsilon = 0
for step in range(n_steps):
    s_a = pick_greedy_action(q, p1, epsilon)
    game_engine.player_advance([s_a[-1]])

KeyboardInterrupt: 