{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "83cff4a8-32f3-410f-b1cf-a5f406559d01", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "pygame 2.5.1 (SDL 2.28.2, Python 3.11.5)\n", "Hello from the pygame community. https://www.pygame.org/contribute.html\n" ] } ], "source": [ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from collections import namedtuple\n", "import queue\n", "from IPython.core.debugger import Pdb\n", "from IPython.display import display, clear_output\n", "\n", "from GameEngine import multiplayer\n", "from QTable import qtsnake\n", "\n", "Point = namedtuple('Point', 'x, y')" ] }, { "cell_type": "markdown", "id": "ea980d7c-d430-44b1-9cb4-9f21122005ff", "metadata": {}, "source": [ "### The 'get_viable_actions' function" ] }, { "cell_type": "markdown", "id": "55cd65c9-bd4c-4b7a-8484-5e0f94742967", "metadata": {}, "source": [ "#### Hamiltonian Cycles" ] }, { "cell_type": "code", "execution_count": 2, "id": "2640d368-eeb0-4dba-98dd-40aba5579b87", "metadata": {}, "outputs": [], "source": [ "# defines game window size and block size, in pixels\n", "WINDOW_WIDTH = 300\n", "WINDOW_HEIGHT = 300\n", "GAME_UNITS = 30\n", "S_SIZE = 3\n", "S_START = (0 + ((S_SIZE-1) * GAME_UNITS), 0 * GAME_UNITS) #FIXME" ] }, { "cell_type": "code", "execution_count": 3, "id": "68d5bf28-1a6f-4188-b078-6cb9e6e36612", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "visited = np.zeros((WINDOW_WIDTH // GAME_UNITS,\n", " WINDOW_HEIGHT // GAME_UNITS))\n", "visited" ] }, { "cell_type": "code", "execution_count": 4, "id": "ee4e92a3-842d-40bb-aa0b-45a2bd377e56", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "80.0" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "COVERAGE_THRESHOLD = .8 * visited.size\n", "COVERAGE_THRESHOLD" ] }, { "cell_type": "code", "execution_count": 5, "id": "0a1ba6af-88f0-4945-b16e-b396ccfeb88a", "metadata": {}, "outputs": [], "source": [ "game_engine = multiplayer.Playfield(window_width=WINDOW_WIDTH,\n", " window_height=WINDOW_HEIGHT,\n", " units=GAME_UNITS,\n", " g_speed=35,\n", " s_size=S_SIZE)" ] }, { "cell_type": "code", "execution_count": 6, "id": "ba66417b-90a3-42f5-b6b1-9d87361d1a7e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Game starting with 1 players.\n" ] } ], "source": [ "p1 = game_engine.add_player(S_START)\n", "game_engine.start_game()" ] }, { "cell_type": "code", "execution_count": 7, "id": "11393601-b055-4371-ac4f-4f0079513f8d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "game_engine.player_advance([1])" ] }, { "cell_type": "code", "execution_count": 8, "id": "93f5c3d9-4457-4fb9-9dc7-a1947a0859e5", "metadata": {}, "outputs": [], "source": [ "def get_visited(unreachable, width=WINDOW_WIDTH,\n", " height=WINDOW_HEIGHT, units=GAME_UNITS):\n", " '''\n", " given a numpy array corresponding to grid\n", " and a list of tails,\n", " marks unreachable places as visited\n", " '''\n", " visited = np.zeros((height//units, width//units))\n", " for node in unreachable:\n", " visited[node.y//units, node.x//units] = 1\n", " return visited" ] }, { "cell_type": "code", "execution_count": 9, "id": "bc03d77a-ff75-4fc7-9dfe-fd851e10077b", "metadata": {}, "outputs": [], "source": [ "def is_instant_death(visited, expansion):\n", " if (min(expansion) < 0 or\n", " expansion.y >= visited.shape[0] or\n", " expansion.x >= visited.shape[1] or\n", " visited[expansion.y, expansion.x] != 0):\n", " return True\n", " return False" ] }, { "cell_type": "code", "execution_count": 10, "id": "2fbc7f1c-f55f-4360-9dba-88d169a1ea0a", "metadata": {}, "outputs": [], "source": [ "def generate_successors(visited, frontier, units=GAME_UNITS):\n", " '''\n", " a generator function used to generate\n", " every new reachable state\n", "\n", " actions corresponding to displacement\n", " 0 1 2 3\n", " UP, RIGHT, DOWN, LEFT\n", " '''\n", " actions = [Point(0, -units), Point(units, 0),\n", " Point(0, units), Point(-units, 0)]\n", " all_actions = [0, 1, 2, 3]\n", " expansion = None\n", "\n", " for action in all_actions:\n", " ''' calculate new position '''\n", " expansion = Point((frontier.x + actions[action].x) // units,\n", " (frontier.y + actions[action].y) // units)\n", " if not is_instant_death(visited, expansion):\n", " yield expansion, action" ] }, { "cell_type": "code", "execution_count": 11, "id": "b3914dc1-e888-439e-b21e-1151f7e29ea9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[1., 0.],\n", " [0., 0.]])" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "visited = get_visited([Point(0,0)], 2, 2, 1)\n", "visited" ] }, { "cell_type": "code", "execution_count": 12, "id": "175ab587-f779-495d-9149-9e492b6872b5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(Point(x=1, y=1), 2)\n" ] } ], "source": [ "for successor in generate_successors(visited, Point(1,0), 1):\n", " print(successor)" ] }, { "cell_type": "code", "execution_count": 13, "id": "d3b9ce23-b7a8-4683-bbc0-58a5a9cce7da", "metadata": {}, "outputs": [], "source": [ "ACTION_NAMES = [\"UP\", \"RIGHT\", \"DOWN\", \"LEFT\"]" ] }, { "cell_type": "code", "execution_count": 14, "id": "b802a66f-b19b-427e-8065-420c397bc87a", "metadata": {}, "outputs": [], "source": [ "def breadth_first_search(visited, frontier, threshold=COVERAGE_THRESHOLD, units=1, verbose=False):\n", " '''\n", " A general search algorithm which expands nodes determined by the frontier.\n", " '''\n", "\n", " while not frontier.empty():\n", " curr_node = frontier.get()\n", " visited[curr_node.y, curr_node.x] = 1\n", " if np.count_nonzero(visited) > threshold:\n", " return True\n", " for child, action in generate_successors(visited, curr_node, units):\n", " if visited[child.y, child.x] == 0:\n", " visited[child.y, child.x] = 1\n", " if verbose:\n", " print(f'From frontier {curr_node}, found child {child} from action {ACTION_NAMES[action]}.')\n", " print(visited)\n", " frontier.put(child)\n", "\n", " if verbose:\n", " print(f'Could not expand enough children!')\n", " return False\t\t\t\t\t\t# return failure" ] }, { "cell_type": "code", "execution_count": 15, "id": "e762fc34-6376-428d-a1be-4ee8b68670a7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[1., 1., 0., 0.],\n", " [0., 0., 0., 0.],\n", " [0., 0., 0., 0.],\n", " [0., 0., 0., 0.]])" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "visited = get_visited([Point(0,0), Point(1,0)], 4, 4, 1)\n", "visited" ] }, { "cell_type": "code", "execution_count": 16, "id": "588e3802-65f3-4dc9-bfdc-da5d1ad59ae3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "From frontier Point(x=2, y=0), found child Point(x=3, y=0) from action RIGHT.\n", "[[1. 1. 1. 1.]\n", " [0. 0. 0. 0.]\n", " [0. 0. 0. 0.]\n", " [0. 0. 0. 0.]]\n", "From frontier Point(x=2, y=0), found child Point(x=2, y=1) from action DOWN.\n", "[[1. 1. 1. 1.]\n", " [0. 0. 1. 0.]\n", " [0. 0. 0. 0.]\n", " [0. 0. 0. 0.]]\n", "From frontier Point(x=3, y=0), found child Point(x=3, y=1) from action DOWN.\n", "[[1. 1. 1. 1.]\n", " [0. 0. 1. 1.]\n", " [0. 0. 0. 0.]\n", " [0. 0. 0. 0.]]\n", "From frontier Point(x=2, y=1), found child Point(x=2, y=2) from action DOWN.\n", "[[1. 1. 1. 1.]\n", " [0. 0. 1. 1.]\n", " [0. 0. 1. 0.]\n", " [0. 0. 0. 0.]]\n", "From frontier Point(x=2, y=1), found child Point(x=1, y=1) from action LEFT.\n", "[[1. 1. 1. 1.]\n", " [0. 1. 1. 1.]\n", " [0. 0. 1. 0.]\n", " [0. 0. 0. 0.]]\n" ] }, { "data": { "text/plain": [ "True" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "frontier = queue.Queue()\n", "frontier.put(Point(2,0))\n", "breadth_first_search(visited, frontier, threshold=6, units=1, verbose=True)" ] }, { "cell_type": "code", "execution_count": 17, "id": "24cd517a-05e4-49db-8967-e44af7fb0892", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[0., 0., 1., 0.],\n", " [0., 0., 1., 0.],\n", " [0., 0., 1., 0.],\n", " [0., 0., 1., 0.]])" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "visited = get_visited([Point(2,0), Point(2,1), Point(2,2), Point(2,3)], 4, 4, 1)\n", "visited" ] }, { "cell_type": "code", "execution_count": 18, "id": "0a904a60-f413-4351-aa8b-ecb1139a5b0d", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "From frontier Point(x=3, y=3), found child Point(x=3, y=2) from action UP.\n", "[[0. 0. 1. 0.]\n", " [0. 0. 1. 0.]\n", " [0. 0. 1. 1.]\n", " [0. 0. 1. 1.]]\n", "From frontier Point(x=3, y=2), found child Point(x=3, y=1) from action UP.\n", "[[0. 0. 1. 0.]\n", " [0. 0. 1. 1.]\n", " [0. 0. 1. 1.]\n", " [0. 0. 1. 1.]]\n", "From frontier Point(x=3, y=1), found child Point(x=3, y=0) from action UP.\n", "[[0. 0. 1. 1.]\n", " [0. 0. 1. 1.]\n", " [0. 0. 1. 1.]\n", " [0. 0. 1. 1.]]\n", "Could not expand enough children!\n", "action RIGHT: bad\n", "From frontier Point(x=1, y=3), found child Point(x=1, y=2) from action UP.\n", "[[0. 0. 1. 1.]\n", " [0. 0. 1. 1.]\n", " [0. 1. 1. 1.]\n", " [0. 1. 1. 1.]]\n", "From frontier Point(x=1, y=3), found child Point(x=0, y=3) from action LEFT.\n", "[[0. 0. 1. 1.]\n", " [0. 0. 1. 1.]\n", " [0. 1. 1. 1.]\n", " [1. 1. 1. 1.]]\n", "From frontier Point(x=1, y=2), found child Point(x=1, y=1) from action UP.\n", "[[0. 0. 1. 1.]\n", " [0. 1. 1. 1.]\n", " [0. 1. 1. 1.]\n", " [1. 1. 1. 1.]]\n", "From frontier Point(x=1, y=2), found child Point(x=0, y=2) from action LEFT.\n", "[[0. 0. 1. 1.]\n", " [0. 1. 1. 1.]\n", " [1. 1. 1. 1.]\n", " [1. 1. 1. 1.]]\n", "action LEFT: good\n" ] } ], "source": [ "head = Point(2,3)\n", "for successor, action in generate_successors(visited, head, units=1):\n", " frontier = queue.Queue()\n", " frontier.put(successor)\n", " if breadth_first_search(visited, frontier, threshold=12, units=1, verbose=True):\n", " print(f'action {ACTION_NAMES[action]}: good')\n", " else:\n", " print(f'action {ACTION_NAMES[action]}: bad')" ] }, { "cell_type": "code", "execution_count": 19, "id": "23d6abad-e052-41e7-9685-0ab22928cd8f", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[1., 0., 1., 1.],\n", " [0., 0., 1., 0.],\n", " [0., 0., 1., 0.],\n", " [0., 0., 0., 0.]])" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "visited = get_visited([Point(2,0), Point(2,1), Point(2,2), Point(3,0), Point(0,0)], 4, 4, 1)\n", "visited" ] }, { "cell_type": "code", "execution_count": 20, "id": "96e64e4a-1844-4ed6-a132-282e69afddb6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "From frontier Point(x=3, y=1), found child Point(x=3, y=2) from action DOWN.\n", "[[1. 0. 1. 1.]\n", " [0. 0. 1. 1.]\n", " [0. 0. 1. 1.]\n", " [0. 0. 0. 0.]]\n", "From frontier Point(x=3, y=2), found child Point(x=3, y=3) from action DOWN.\n", "[[1. 0. 1. 1.]\n", " [0. 0. 1. 1.]\n", " [0. 0. 1. 1.]\n", " [0. 0. 0. 1.]]\n", "From frontier Point(x=3, y=3), found child Point(x=2, y=3) from action LEFT.\n", "[[1. 0. 1. 1.]\n", " [0. 0. 1. 1.]\n", " [0. 0. 1. 1.]\n", " [0. 0. 1. 1.]]\n", "From frontier Point(x=2, y=3), found child Point(x=1, y=3) from action LEFT.\n", "[[1. 0. 1. 1.]\n", " [0. 0. 1. 1.]\n", " [0. 0. 1. 1.]\n", " [0. 1. 1. 1.]]\n", "From frontier Point(x=1, y=3), found child Point(x=1, y=2) from action UP.\n", "[[1. 0. 1. 1.]\n", " [0. 0. 1. 1.]\n", " [0. 1. 1. 1.]\n", " [0. 1. 1. 1.]]\n", "From frontier Point(x=1, y=3), found child Point(x=0, y=3) from action LEFT.\n", "[[1. 0. 1. 1.]\n", " [0. 0. 1. 1.]\n", " [0. 1. 1. 1.]\n", " [1. 1. 1. 1.]]\n", "From frontier Point(x=1, y=2), found child Point(x=1, y=1) from action UP.\n", "[[1. 0. 1. 1.]\n", " [0. 1. 1. 1.]\n", " [0. 1. 1. 1.]\n", " [1. 1. 1. 1.]]\n", "From frontier Point(x=1, y=2), found child Point(x=0, y=2) from action LEFT.\n", "[[1. 0. 1. 1.]\n", " [0. 1. 1. 1.]\n", " [1. 1. 1. 1.]\n", " [1. 1. 1. 1.]]\n", "From frontier Point(x=1, y=1), found child Point(x=1, y=0) from action UP.\n", "[[1. 1. 1. 1.]\n", " [0. 1. 1. 1.]\n", " [1. 1. 1. 1.]\n", " [1. 1. 1. 1.]]\n", "From frontier Point(x=1, y=1), found child Point(x=0, y=1) from action LEFT.\n", "[[1. 1. 1. 1.]\n", " [1. 1. 1. 1.]\n", " [1. 1. 1. 1.]\n", " [1. 1. 1. 1.]]\n", "action DOWN: good\n" ] } ], "source": [ "head = Point(3,0)\n", "for successor, action in generate_successors(visited, head, units=1):\n", " frontier = queue.Queue()\n", " frontier.put(successor)\n", " if breadth_first_search(visited, frontier, threshold=15, units=1, verbose=True):\n", " print(f'action {ACTION_NAMES[action]}: good')\n", " else:\n", " print(f'action {ACTION_NAMES[action]}: bad')" ] }, { "cell_type": "code", "execution_count": 21, "id": "79a82894-7d07-4e56-94e7-01765750f97e", "metadata": {}, "outputs": [], "source": [ "def get_viable_actions(id):\n", " heads, tails, _ = game_engine.get_heads_tails_and_goal()\n", " visited = get_visited(heads + tails)\n", "\n", " viable_actions = []\n", " valid_actions = []\n", "\n", " for successor, action in generate_successors(visited, heads[id]):\n", " valid_actions.append(action)\n", " frontier = queue.Queue()\n", " frontier.put(successor)\n", " if breadth_first_search(visited.copy(), frontier):\n", " viable_actions.append(action)\n", " if np.count_nonzero(visited) >= COVERAGE_THRESHOLD:\n", " print(f'Coverage reached!')\n", " print(f'No more smart actions left!')\n", " if len(viable_actions) == 0:\n", " return valid_actions\n", " return viable_actions" ] }, { "cell_type": "code", "execution_count": 22, "id": "a5c7e6a8-fd69-4e71-a016-d404a1b99906", "metadata": {}, "outputs": [], "source": [ "superior_table = qtsnake.load_q('superior_qt.npy')" ] }, { "cell_type": "code", "execution_count": 23, "id": "0b45d83d-6a8e-4af9-97a7-28c9d965e52a", "metadata": {}, "outputs": [], "source": [ "class QSnakeImproved(qtsnake.QSnake):\n", "\n", " def __init__(self, game_engine):\n", " super().__init__(game_engine)\n", "\n", " ''' override '''\n", " def pick_greedy_action(self, q, id, epsilon=0):\n", " viable_actions = get_viable_actions(id)\n", " state, rewards = self.index_actions(q, id)\n", "\n", " if np.random.uniform() < epsilon:\n", " return (state, np.random.choice(viable_actions)) if viable_actions.size > 0 else (state, 0)\n", " for action in self.argmin_gen(rewards):\n", " if action in viable_actions:\n", " return (state, action)\n", " return (state, 0) # death" ] }, { "cell_type": "code", "execution_count": 24, "id": "0c3e03aa-0c54-4db7-bbc8-99fe4841f310", "metadata": {}, "outputs": [], "source": [ "q_table = QSnakeImproved(game_engine)" ] }, { "cell_type": "code", "execution_count": 25, "id": "3b7551d3-d70b-4243-b445-e9dcf1c4e39c", "metadata": {}, "outputs": [], "source": [ "for step in range(2000):\n", " _, p1_action = q_table.pick_greedy_action(superior_table, p1)\n", " game_engine.player_advance([p1_action])" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 5 }