diff options
Diffstat (limited to 'three_revised_get_viable_actions.ipynb')
-rw-r--r-- | three_revised_get_viable_actions.ipynb | 693 |
1 files changed, 693 insertions, 0 deletions
diff --git a/three_revised_get_viable_actions.ipynb b/three_revised_get_viable_actions.ipynb new file mode 100644 index 0000000..b568ee2 --- /dev/null +++ b/three_revised_get_viable_actions.ipynb @@ -0,0 +1,693 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "83cff4a8-32f3-410f-b1cf-a5f406559d01", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "pygame 2.5.1 (SDL 2.28.2, Python 3.11.5)\n", + "Hello from the pygame community. https://www.pygame.org/contribute.html\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from collections import namedtuple\n", + "import queue\n", + "from IPython.core.debugger import Pdb\n", + "from IPython.display import display, clear_output\n", + "\n", + "from GameEngine import multiplayer\n", + "from QTable import qtsnake\n", + "\n", + "Point = namedtuple('Point', 'x, y')" + ] + }, + { + "cell_type": "markdown", + "id": "ea980d7c-d430-44b1-9cb4-9f21122005ff", + "metadata": {}, + "source": [ + "### The 'get_viable_actions' function\n", + "\n", + "The following notebook is my attempt to create an even more sophisticated get_viable_actions function. I used a breadth-first search to determine if each move allows the snake to reach at least COVERAGE_THRESHOLD% of the playfield. If it does, then the action is considered safe, and the greedy_action_selector is allowed to choose it.\n", + "\n", + "Often, this approach was defeated by the snake cutting the remaining reachable areas in half during an action, turning off the safe exploration before the snake had filled up the screen as I had intended.\n", + "\n", + "While this approach could likely be fixed in the future, I decided to leave it as is, as it would have only been marginally better.\n", + "\n", + "The best known solution to the snake problem is Hamiltonion cycles, not reinforcement learning. I see future iterations of reinforcement learning on the snake learning problem as mostly a dead-end." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "2640d368-eeb0-4dba-98dd-40aba5579b87", + "metadata": {}, + "outputs": [], + "source": [ + "# defines game window size and block size, in pixels\n", + "WINDOW_WIDTH = 300\n", + "WINDOW_HEIGHT = 300\n", + "GAME_UNITS = 30\n", + "S_SIZE = 3\n", + "S_START = (0 + ((S_SIZE-1) * GAME_UNITS), 0 * GAME_UNITS) #FIXME" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "68d5bf28-1a6f-4188-b078-6cb9e6e36612", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "visited = np.zeros((WINDOW_WIDTH // GAME_UNITS,\n", + " WINDOW_HEIGHT // GAME_UNITS))\n", + "visited" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ee4e92a3-842d-40bb-aa0b-45a2bd377e56", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "80.0" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "COVERAGE_THRESHOLD = .8 * visited.size\n", + "COVERAGE_THRESHOLD" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "0a1ba6af-88f0-4945-b16e-b396ccfeb88a", + "metadata": {}, + "outputs": [], + "source": [ + "game_engine = multiplayer.Playfield(window_width=WINDOW_WIDTH,\n", + " window_height=WINDOW_HEIGHT,\n", + " units=GAME_UNITS,\n", + " g_speed=35,\n", + " s_size=S_SIZE)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "ba66417b-90a3-42f5-b6b1-9d87361d1a7e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Game starting with 1 players.\n" + ] + } + ], + "source": [ + "p1 = game_engine.add_player(S_START)\n", + "game_engine.start_game()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "11393601-b055-4371-ac4f-4f0079513f8d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[<CollisionType.NONE: 2>]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "game_engine.player_advance([1])" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "93f5c3d9-4457-4fb9-9dc7-a1947a0859e5", + "metadata": {}, + "outputs": [], + "source": [ + "def get_visited(unreachable, width=WINDOW_WIDTH,\n", + " height=WINDOW_HEIGHT, units=GAME_UNITS):\n", + " '''\n", + " given a numpy array corresponding to grid\n", + " and a list of tails,\n", + " marks unreachable places as visited\n", + " '''\n", + " visited = np.zeros((height//units, width//units))\n", + " for node in unreachable:\n", + " visited[node.y//units, node.x//units] = 1\n", + " return visited" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "bc03d77a-ff75-4fc7-9dfe-fd851e10077b", + "metadata": {}, + "outputs": [], + "source": [ + "def is_instant_death(visited, expansion):\n", + " if (min(expansion) < 0 or\n", + " expansion.y >= visited.shape[0] or\n", + " expansion.x >= visited.shape[1] or\n", + " visited[expansion.y, expansion.x] != 0):\n", + " return True\n", + " return False" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "2fbc7f1c-f55f-4360-9dba-88d169a1ea0a", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_successors(visited, frontier, units=GAME_UNITS):\n", + " '''\n", + " a generator function used to generate\n", + " every new reachable state\n", + "\n", + " actions corresponding to displacement\n", + " 0 1 2 3\n", + " UP, RIGHT, DOWN, LEFT\n", + " '''\n", + " actions = [Point(0, -units), Point(units, 0),\n", + " Point(0, units), Point(-units, 0)]\n", + " all_actions = [0, 1, 2, 3]\n", + " expansion = None\n", + "\n", + " for action in all_actions:\n", + " ''' calculate new position '''\n", + " expansion = Point((frontier.x + actions[action].x) // units,\n", + " (frontier.y + actions[action].y) // units)\n", + " if not is_instant_death(visited, expansion):\n", + " yield expansion, action" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "b3914dc1-e888-439e-b21e-1151f7e29ea9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[1., 0.],\n", + " [0., 0.]])" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "visited = get_visited([Point(0,0)], 2, 2, 1)\n", + "visited" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "175ab587-f779-495d-9149-9e492b6872b5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(Point(x=1, y=1), 2)\n" + ] + } + ], + "source": [ + "for successor in generate_successors(visited, Point(1,0), 1):\n", + " print(successor)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "d3b9ce23-b7a8-4683-bbc0-58a5a9cce7da", + "metadata": {}, + "outputs": [], + "source": [ + "ACTION_NAMES = [\"UP\", \"RIGHT\", \"DOWN\", \"LEFT\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "b802a66f-b19b-427e-8065-420c397bc87a", + "metadata": {}, + "outputs": [], + "source": [ + "def breadth_first_search(visited, frontier, threshold=COVERAGE_THRESHOLD, units=1, verbose=False):\n", + " '''\n", + " A general search algorithm which expands nodes determined by the frontier.\n", + " '''\n", + "\n", + " while not frontier.empty():\n", + " curr_node = frontier.get()\n", + " visited[curr_node.y, curr_node.x] = 1\n", + " if np.count_nonzero(visited) > threshold:\n", + " return True\n", + " for child, action in generate_successors(visited, curr_node, units):\n", + " if visited[child.y, child.x] == 0:\n", + " visited[child.y, child.x] = 1\n", + " if verbose:\n", + " print(f'From frontier {curr_node}, found child {child} from action {ACTION_NAMES[action]}.')\n", + " print(visited)\n", + " frontier.put(child)\n", + "\n", + " if verbose:\n", + " print(f'Could not expand enough children!')\n", + " return False\t\t\t\t\t\t# return failure" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "e762fc34-6376-428d-a1be-4ee8b68670a7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[1., 1., 0., 0.],\n", + " [0., 0., 0., 0.],\n", + " [0., 0., 0., 0.],\n", + " [0., 0., 0., 0.]])" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "visited = get_visited([Point(0,0), Point(1,0)], 4, 4, 1)\n", + "visited" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "588e3802-65f3-4dc9-bfdc-da5d1ad59ae3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "From frontier Point(x=2, y=0), found child Point(x=3, y=0) from action RIGHT.\n", + "[[1. 1. 1. 1.]\n", + " [0. 0. 0. 0.]\n", + " [0. 0. 0. 0.]\n", + " [0. 0. 0. 0.]]\n", + "From frontier Point(x=2, y=0), found child Point(x=2, y=1) from action DOWN.\n", + "[[1. 1. 1. 1.]\n", + " [0. 0. 1. 0.]\n", + " [0. 0. 0. 0.]\n", + " [0. 0. 0. 0.]]\n", + "From frontier Point(x=3, y=0), found child Point(x=3, y=1) from action DOWN.\n", + "[[1. 1. 1. 1.]\n", + " [0. 0. 1. 1.]\n", + " [0. 0. 0. 0.]\n", + " [0. 0. 0. 0.]]\n", + "From frontier Point(x=2, y=1), found child Point(x=2, y=2) from action DOWN.\n", + "[[1. 1. 1. 1.]\n", + " [0. 0. 1. 1.]\n", + " [0. 0. 1. 0.]\n", + " [0. 0. 0. 0.]]\n", + "From frontier Point(x=2, y=1), found child Point(x=1, y=1) from action LEFT.\n", + "[[1. 1. 1. 1.]\n", + " [0. 1. 1. 1.]\n", + " [0. 0. 1. 0.]\n", + " [0. 0. 0. 0.]]\n" + ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "frontier = queue.Queue()\n", + "frontier.put(Point(2,0))\n", + "breadth_first_search(visited, frontier, threshold=6, units=1, verbose=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "24cd517a-05e4-49db-8967-e44af7fb0892", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[0., 0., 1., 0.],\n", + " [0., 0., 1., 0.],\n", + " [0., 0., 1., 0.],\n", + " [0., 0., 1., 0.]])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "visited = get_visited([Point(2,0), Point(2,1), Point(2,2), Point(2,3)], 4, 4, 1)\n", + "visited" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "0a904a60-f413-4351-aa8b-ecb1139a5b0d", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "From frontier Point(x=3, y=3), found child Point(x=3, y=2) from action UP.\n", + "[[0. 0. 1. 0.]\n", + " [0. 0. 1. 0.]\n", + " [0. 0. 1. 1.]\n", + " [0. 0. 1. 1.]]\n", + "From frontier Point(x=3, y=2), found child Point(x=3, y=1) from action UP.\n", + "[[0. 0. 1. 0.]\n", + " [0. 0. 1. 1.]\n", + " [0. 0. 1. 1.]\n", + " [0. 0. 1. 1.]]\n", + "From frontier Point(x=3, y=1), found child Point(x=3, y=0) from action UP.\n", + "[[0. 0. 1. 1.]\n", + " [0. 0. 1. 1.]\n", + " [0. 0. 1. 1.]\n", + " [0. 0. 1. 1.]]\n", + "Could not expand enough children!\n", + "action RIGHT: bad\n", + "From frontier Point(x=1, y=3), found child Point(x=1, y=2) from action UP.\n", + "[[0. 0. 1. 1.]\n", + " [0. 0. 1. 1.]\n", + " [0. 1. 1. 1.]\n", + " [0. 1. 1. 1.]]\n", + "From frontier Point(x=1, y=3), found child Point(x=0, y=3) from action LEFT.\n", + "[[0. 0. 1. 1.]\n", + " [0. 0. 1. 1.]\n", + " [0. 1. 1. 1.]\n", + " [1. 1. 1. 1.]]\n", + "From frontier Point(x=1, y=2), found child Point(x=1, y=1) from action UP.\n", + "[[0. 0. 1. 1.]\n", + " [0. 1. 1. 1.]\n", + " [0. 1. 1. 1.]\n", + " [1. 1. 1. 1.]]\n", + "From frontier Point(x=1, y=2), found child Point(x=0, y=2) from action LEFT.\n", + "[[0. 0. 1. 1.]\n", + " [0. 1. 1. 1.]\n", + " [1. 1. 1. 1.]\n", + " [1. 1. 1. 1.]]\n", + "action LEFT: good\n" + ] + } + ], + "source": [ + "head = Point(2,3)\n", + "for successor, action in generate_successors(visited, head, units=1):\n", + " frontier = queue.Queue()\n", + " frontier.put(successor)\n", + " if breadth_first_search(visited, frontier, threshold=12, units=1, verbose=True):\n", + " print(f'action {ACTION_NAMES[action]}: good')\n", + " else:\n", + " print(f'action {ACTION_NAMES[action]}: bad')" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "23d6abad-e052-41e7-9685-0ab22928cd8f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[1., 0., 1., 1.],\n", + " [0., 0., 1., 0.],\n", + " [0., 0., 1., 0.],\n", + " [0., 0., 0., 0.]])" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "visited = get_visited([Point(2,0), Point(2,1), Point(2,2), Point(3,0), Point(0,0)], 4, 4, 1)\n", + "visited" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "96e64e4a-1844-4ed6-a132-282e69afddb6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "From frontier Point(x=3, y=1), found child Point(x=3, y=2) from action DOWN.\n", + "[[1. 0. 1. 1.]\n", + " [0. 0. 1. 1.]\n", + " [0. 0. 1. 1.]\n", + " [0. 0. 0. 0.]]\n", + "From frontier Point(x=3, y=2), found child Point(x=3, y=3) from action DOWN.\n", + "[[1. 0. 1. 1.]\n", + " [0. 0. 1. 1.]\n", + " [0. 0. 1. 1.]\n", + " [0. 0. 0. 1.]]\n", + "From frontier Point(x=3, y=3), found child Point(x=2, y=3) from action LEFT.\n", + "[[1. 0. 1. 1.]\n", + " [0. 0. 1. 1.]\n", + " [0. 0. 1. 1.]\n", + " [0. 0. 1. 1.]]\n", + "From frontier Point(x=2, y=3), found child Point(x=1, y=3) from action LEFT.\n", + "[[1. 0. 1. 1.]\n", + " [0. 0. 1. 1.]\n", + " [0. 0. 1. 1.]\n", + " [0. 1. 1. 1.]]\n", + "From frontier Point(x=1, y=3), found child Point(x=1, y=2) from action UP.\n", + "[[1. 0. 1. 1.]\n", + " [0. 0. 1. 1.]\n", + " [0. 1. 1. 1.]\n", + " [0. 1. 1. 1.]]\n", + "From frontier Point(x=1, y=3), found child Point(x=0, y=3) from action LEFT.\n", + "[[1. 0. 1. 1.]\n", + " [0. 0. 1. 1.]\n", + " [0. 1. 1. 1.]\n", + " [1. 1. 1. 1.]]\n", + "From frontier Point(x=1, y=2), found child Point(x=1, y=1) from action UP.\n", + "[[1. 0. 1. 1.]\n", + " [0. 1. 1. 1.]\n", + " [0. 1. 1. 1.]\n", + " [1. 1. 1. 1.]]\n", + "From frontier Point(x=1, y=2), found child Point(x=0, y=2) from action LEFT.\n", + "[[1. 0. 1. 1.]\n", + " [0. 1. 1. 1.]\n", + " [1. 1. 1. 1.]\n", + " [1. 1. 1. 1.]]\n", + "From frontier Point(x=1, y=1), found child Point(x=1, y=0) from action UP.\n", + "[[1. 1. 1. 1.]\n", + " [0. 1. 1. 1.]\n", + " [1. 1. 1. 1.]\n", + " [1. 1. 1. 1.]]\n", + "From frontier Point(x=1, y=1), found child Point(x=0, y=1) from action LEFT.\n", + "[[1. 1. 1. 1.]\n", + " [1. 1. 1. 1.]\n", + " [1. 1. 1. 1.]\n", + " [1. 1. 1. 1.]]\n", + "action DOWN: good\n" + ] + } + ], + "source": [ + "head = Point(3,0)\n", + "for successor, action in generate_successors(visited, head, units=1):\n", + " frontier = queue.Queue()\n", + " frontier.put(successor)\n", + " if breadth_first_search(visited, frontier, threshold=15, units=1, verbose=True):\n", + " print(f'action {ACTION_NAMES[action]}: good')\n", + " else:\n", + " print(f'action {ACTION_NAMES[action]}: bad')" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "79a82894-7d07-4e56-94e7-01765750f97e", + "metadata": {}, + "outputs": [], + "source": [ + "def get_viable_actions(id):\n", + " heads, tails, _ = game_engine.get_heads_tails_and_goal()\n", + " visited = get_visited(heads + tails)\n", + "\n", + " viable_actions = []\n", + " valid_actions = []\n", + "\n", + " for successor, action in generate_successors(visited, heads[id]):\n", + " valid_actions.append(action)\n", + " frontier = queue.Queue()\n", + " frontier.put(successor)\n", + " if breadth_first_search(visited.copy(), frontier):\n", + " viable_actions.append(action)\n", + " if np.count_nonzero(visited) >= COVERAGE_THRESHOLD:\n", + " print(f'Coverage reached!')\n", + " print(f'No more smart actions left!')\n", + " if len(viable_actions) == 0:\n", + " return valid_actions\n", + " return viable_actions" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "a5c7e6a8-fd69-4e71-a016-d404a1b99906", + "metadata": {}, + "outputs": [], + "source": [ + "superior_table = qtsnake.load_q('superior_qt.npy')" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "0b45d83d-6a8e-4af9-97a7-28c9d965e52a", + "metadata": {}, + "outputs": [], + "source": [ + "class QSnakeImproved(qtsnake.QSnake):\n", + "\n", + " def __init__(self, game_engine):\n", + " super().__init__(game_engine)\n", + "\n", + " ''' override '''\n", + " def pick_greedy_action(self, q, id, epsilon=0):\n", + " viable_actions = get_viable_actions(id)\n", + " state, rewards = self.index_actions(q, id)\n", + "\n", + " if np.random.uniform() < epsilon:\n", + " return (state, np.random.choice(viable_actions)) if viable_actions.size > 0 else (state, 0)\n", + " for action in self.argmin_gen(rewards):\n", + " if action in viable_actions:\n", + " return (state, action)\n", + " return (state, 0) # death" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "0c3e03aa-0c54-4db7-bbc8-99fe4841f310", + "metadata": {}, + "outputs": [], + "source": [ + "q_table = QSnakeImproved(game_engine)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "3b7551d3-d70b-4243-b445-e9dcf1c4e39c", + "metadata": {}, + "outputs": [], + "source": [ + "for step in range(800):\n", + " _, p1_action = q_table.pick_greedy_action(superior_table, p1)\n", + " game_engine.player_advance([p1_action])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} |