{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "85da5df2-c926-417c-bd7b-d214ad31ebe1", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pygame 2.5.1 (SDL 2.28.2, Python 3.11.5)\n", "Hello from the pygame community. https://www.pygame.org/contribute.html\n" ] } ], "source": [ "import numpy as np\n", "from collections import namedtuple\n", "from IPython.core.debugger import Pdb\n", "\n", "from GameEngine import multiplayer\n", "from QTable import qtsnake\n", "Point = namedtuple('Point', 'x, y')" ] }, { "cell_type": "code", "execution_count": 2, "id": "9eec33d8-9a65-426a-8ad5-8ffb2cbe2541", "metadata": {}, "outputs": [], "source": [ "# defines game window size and block size, in pixels\n", "WINDOW_WIDTH = 640\n", "WINDOW_HEIGHT = 480\n", "GAME_UNITS = 80" ] }, { "cell_type": "code", "execution_count": 3, "id": "77b2d951-72be-4bee-aa73-3b87ea0288b6", "metadata": {}, "outputs": [], "source": [ "game_engine = multiplayer.Playfield(window_width=WINDOW_WIDTH,\n", " window_height=WINDOW_HEIGHT,\n", " units=GAME_UNITS,\n", " g_speed=35,\n", " s_size=1)" ] }, { "cell_type": "code", "execution_count": 4, "id": "93ce92f3-b336-4aca-a8da-ebd913e9d16b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Game starting with 1 players.\n", "Draw is now False.\n" ] } ], "source": [ "p1 = game_engine.add_player()\n", "game_engine.start_game()\n", "game_engine.toggle_draw()" ] }, { "cell_type": "code", "execution_count": 5, "id": "78ad0c30-63e7-4689-9332-6e9540a67105", "metadata": {}, "outputs": [], "source": [ "danger_states = 16\n", "goal_relations = 8\n", "actions = 4\n", "q = np.zeros((danger_states,\n", " goal_relations,\n", " actions))" ] }, { "cell_type": "code", "execution_count": 6, "id": "09804da3-ef32-4265-a63d-7d4ea7e59031", "metadata": {}, "outputs": [], "source": [ "def sense_danger(head, tail, units=GAME_UNITS, window_width=WINDOW_WIDTH, window_height=WINDOW_HEIGHT):\n", " '''\n", " Given a player's head and tail,\n", " returns a list of actions that does\n", " not result in immediate death\n", " '''\n", " danger_array = np.array([\n", " head.y-units < 0 or Point(head.x, head.y-units) in tail[1:], # up\n", " head.x+units >= window_width or Point(head.x+units, head.y) in tail[1:], # right\n", " head.y+units >= window_height or Point(head.x, head.y+units) in tail[1:], # down\n", " head.x-units < 0 or Point(head.x-units, head.y) in tail[1:], # left\n", " ])\n", "\n", " binary_string = \"\".join(map(str, map(int, danger_array)))\n", "\n", " return int(binary_string, base=2)" ] }, { "cell_type": "code", "execution_count": 7, "id": "bfca56cf-0cef-4523-8e4a-f3254180e397", "metadata": {}, "outputs": [], "source": [ "x_b = 5; y_b = 5;\n", "head = Point(3,3); tail = [head]\n", "assert sense_danger(head,tail,1,x_b,y_b) == 0\n", "\n", "head = Point(0,0); tail = [head, Point(0,1)]\n", "assert sense_danger(head,tail,1,x_b,y_b) == 11\n", "\n", "head = Point(3,3); tail = [head, Point(3,4)]\n", "assert sense_danger(head,tail,1,x_b,y_b) == 2\n", "\n", "head = Point(3,3); tail = [head, Point(2,3)]\n", "assert sense_danger(head,tail,1,x_b,y_b) == 1" ] }, { "cell_type": "code", "execution_count": 8, "id": "45d9645a-b217-4812-8b35-2a154518087d", "metadata": {}, "outputs": [], "source": [ "def index_actions(q, id):\n", " '''\n", " given q, player_id, an array of heads,\n", " and the goal position,\n", " indexes into the corresponding expected\n", " reward of each action\n", " '''\n", " heads, tails, goal = game_engine.get_heads_tails_and_goal()\n", " state = np.array([sense_danger(heads[id], tails), qtsnake.sense_goal(heads[id], goal)])\n", " return state, q[state[0], state[1], :]" ] }, { "cell_type": "code", "execution_count": 9, "id": "b8fc49f8-fc8e-4c0b-92f2-a345200ee26d", "metadata": {}, "outputs": [], "source": [ "def pick_greedy_action(q, id, epsilon):\n", " state, rewards = index_actions(q, id)\n", "\n", " if np.random.uniform() < epsilon:\n", " return np.append(state, [np.random.choice([0, 1, 2, 3])])\n", " action = np.argmin(rewards)\n", " return np.append(state, [action])" ] }, { "cell_type": "code", "execution_count": 10, "id": "9dba22f2-a8fa-400c-8c71-fec7e988489c", "metadata": {}, "outputs": [], "source": [ "def update_q(q, old_state_action, new_state_action, outcome, lr=0.05):\n", " if outcome == multiplayer.CollisionType.GOAL:\n", " q[new_state_action[0], new_state_action[1], new_state_action[2]] = 0\n", " elif outcome == multiplayer.CollisionType.DEATH:\n", " q[new_state_action[0], new_state_action[1], new_state_action[2]] = 500\n", " else:\n", " td_error = 1 + q[new_state_action[0], new_state_action[1], new_state_action[2]] - q[old_state_action[0], old_state_action[1], old_state_action[2]]\n", " q[old_state_action[0], old_state_action[1], old_state_action[2]] += lr * td_error" ] }, { "cell_type": "code", "execution_count": 11, "id": "bc22f0b7-e409-4fd8-8433-8e918d5632ad", "metadata": {}, "outputs": [], "source": [ "n_steps = 52500\n", "epsilon = 1\n", "final_epsilon = 0.03\n", "epsilon_decay = np.exp(np.log(final_epsilon) / (n_steps))\n", "lr = 0.15" ] }, { "cell_type": "code", "execution_count": 12, "id": "c7c4417e-66f5-4028-8264-f4b732ba33a7", "metadata": {}, "outputs": [], "source": [ "p1_old_s_a = pick_greedy_action(q, p1, epsilon) # state, action\n", "game_engine.player_advance([p1_old_s_a[-1]])\n", "\n", "for step in range(n_steps):\n", " p1_new_s_a = pick_greedy_action(q, p1, epsilon) # state, action\n", " outcome = game_engine.player_advance([p1_new_s_a[-1]], 0.05)\n", "\n", " update_q(q, p1_old_s_a, p1_new_s_a, outcome[p1], lr)\n", "\n", " epsilon *= epsilon_decay\n", " p1_old_s_a = p1_new_s_a" ] }, { "cell_type": "code", "execution_count": 13, "id": "6538bcdc-4b4b-48ef-8c55-76834490b698", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Draw is now True.\n" ] } ], "source": [ "game_engine.toggle_draw()" ] }, { "cell_type": "code", "execution_count": 14, "id": "6e50b723-f0a4-48cd-bb98-1666ec0f59ce", "metadata": {}, "outputs": [ { "ename": "KeyboardInterrupt", "evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[14], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m step \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(n_steps):\n\u001b[1;32m 3\u001b[0m s_a \u001b[38;5;241m=\u001b[39m pick_greedy_action(q, p1, epsilon)\n\u001b[0;32m----> 4\u001b[0m \u001b[43mgame_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mplayer_advance\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43ms_a\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n", "File \u001b[0;32m~/Personal/roam/program_repository/cs498/multiagent_snake/GameEngine/multiplayer.py:118\u001b[0m, in \u001b[0;36mPlayfield.player_advance\u001b[0;34m(self, actions, noise)\u001b[0m\n\u001b[1;32m 116\u001b[0m collisions \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_player_collisions()\n\u001b[1;32m 117\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_draw_on:\n\u001b[0;32m--> 118\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_update_ui\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 119\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m collisions\n", "File \u001b[0;32m~/Personal/roam/program_repository/cs498/multiagent_snake/GameEngine/multiplayer.py:163\u001b[0m, in \u001b[0;36mPlayfield._update_ui\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_goal\u001b[38;5;241m.\u001b[39mdraw()\n\u001b[1;32m 162\u001b[0m pg\u001b[38;5;241m.\u001b[39mdisplay\u001b[38;5;241m.\u001b[39mflip() \u001b[38;5;66;03m# full screen update\u001b[39;00m\n\u001b[0;32m--> 163\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_clock\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtick\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_g_speed\u001b[49m\u001b[43m)\u001b[49m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "epsilon = 0\n", "for step in range(n_steps):\n", " s_a = pick_greedy_action(q, p1, epsilon)\n", " game_engine.player_advance([s_a[-1]])" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 5 }