diff options
author | bd-912 <bdunahu@gmail.com> | 2023-11-15 12:47:42 -0700 |
---|---|---|
committer | bd-912 <bdunahu@gmail.com> | 2023-11-15 12:47:42 -0700 |
commit | ced115881e2f60afe41373b6da899f5d2f5403e0 (patch) | |
tree | 99d323876a5c32b3b138102440b4a1c6c12ac7be /revised_get_viable_actions.ipynb | |
parent | a2b56742da7b30afa00f33c9a806fa6031be68a5 (diff) |
Added new notebook concerning improvements on get_viable_actions
Diffstat (limited to 'revised_get_viable_actions.ipynb')
-rw-r--r-- | revised_get_viable_actions.ipynb | 693 |
1 files changed, 693 insertions, 0 deletions
# %%
# Notebook-wide imports: numeric/plotting stack, stdlib helpers, IPython
# debugging/display utilities, then the project-local game engine and
# Q-table modules.
import numpy as np
import matplotlib.pyplot as plt
from collections import namedtuple
import queue
from IPython.core.debugger import Pdb
from IPython.display import display, clear_output

from GameEngine import multiplayer
from QTable import qtsnake

# Lightweight 2-D coordinate with .x / .y fields. The notebook uses it both
# for pixel positions (multiples of GAME_UNITS) and for grid-cell positions.
Point = namedtuple('Point', 'x, y')

# %% [markdown]
# ### The 'get_viable_actions' function

# %% [markdown]
# #### Hamiltonian Cycles

# %%
# defines game window size and block size, in pixels
WINDOW_WIDTH = 300
WINDOW_HEIGHT = 300
GAME_UNITS = 30
# Passed to the engine as `s_size`; presumably the snake's starting length in
# blocks -- confirm against GameEngine.multiplayer.Playfield.
S_SIZE = 3
# Start position handed to `add_player`; evaluates to (60, 0) with the values
# above. The original author flagged this expression with FIXME without
# saying what is wrong with it.
S_START = (0 + ((S_SIZE-1) * GAME_UNITS), 0 * GAME_UNITS) #FIXME
# %%
# Grid dimensions in cells, as (rows, cols) -- the same axis order that
# `get_visited` uses, so the demo grid and the helper agree on non-square boards.
GRID_SHAPE = (WINDOW_HEIGHT // GAME_UNITS, WINDOW_WIDTH // GAME_UNITS)
visited = np.zeros(GRID_SHAPE)
visited

# %%
# Number of marked cells a search must exceed for a move to count as viable.
# Derived from the grid dimensions rather than from `visited`, because later
# cells rebind `visited` to small demo grids.
COVERAGE_THRESHOLD = .8 * (GRID_SHAPE[0] * GRID_SHAPE[1])
COVERAGE_THRESHOLD

# %%
game_engine = multiplayer.Playfield(window_width=WINDOW_WIDTH,
                                    window_height=WINDOW_HEIGHT,
                                    units=GAME_UNITS,
                                    g_speed=35,
                                    s_size=S_SIZE)

# %%
p1 = game_engine.add_player(S_START)
game_engine.start_game()

# %%
# Advance the single player once with action 1 (RIGHT in the 0-3 encoding
# documented on `generate_successors`).
game_engine.player_advance([1])

# %%
def get_visited(unreachable, width=WINDOW_WIDTH,
                height=WINDOW_HEIGHT, units=GAME_UNITS):
    '''
    Build an occupancy grid for the playfield.

    unreachable: iterable of Points in pixel coordinates (heads, tails, ...)
    width, height: playfield size in pixels
    units: size of one grid cell in pixels

    Returns a (height//units, width//units) float array, indexed
    [row, col] == [y, x], with 1 at every cell occupied by a node in
    `unreachable` and 0 elsewhere. Nodes that fall outside the playfield
    are skipped: indexing with them would either wrap around to the
    opposite edge (negative coordinates) or raise IndexError.
    '''
    rows, cols = height // units, width // units
    visited = np.zeros((rows, cols))
    for node in unreachable:
        row, col = node.y // units, node.x // units
        if 0 <= row < rows and 0 <= col < cols:
            visited[row, col] = 1
    return visited

# %%
def is_instant_death(visited, expansion):
    '''
    Return True when stepping onto grid cell `expansion` is immediately
    fatal: the cell lies outside `visited` or is already marked non-zero.
    '''
    # The bounds checks come first so the array lookup never sees a negative
    # (wrapping) or out-of-range index.
    return bool(min(expansion) < 0 or
                expansion.y >= visited.shape[0] or
                expansion.x >= visited.shape[1] or
                visited[expansion.y, expansion.x] != 0)

# %%
def generate_successors(visited, frontier, units=GAME_UNITS):
    '''
    a generator function used to generate
    every new reachable state

    `frontier` is a position expressed in multiples of `units`; each
    yielded successor is a grid-cell Point (the displaced position
    floor-divided by `units`) paired with the action that reaches it.
    Cells rejected by `is_instant_death` are not yielded.

    actions corresponding to displacement
    0    1      2     3
    UP, RIGHT, DOWN, LEFT
    '''
    displacements = (Point(0, -units), Point(units, 0),
                     Point(0, units), Point(-units, 0))

    for action, step in enumerate(displacements):
        # calculate new position, converted to grid coordinates
        expansion = Point((frontier.x + step.x) // units,
                          (frontier.y + step.y) // units)
        if not is_instant_death(visited, expansion):
            yield expansion, action

# %%
visited = get_visited([Point(0,0)], 2, 2, 1)
visited

# %%
for successor in generate_successors(visited, Point(1,0), 1):
    print(successor)

# %%
ACTION_NAMES = ["UP", "RIGHT", "DOWN", "LEFT"]

# %%
def breadth_first_search(visited, frontier, threshold=COVERAGE_THRESHOLD, units=1, verbose=False):
    '''
    Flood-fill outward from the nodes queued in `frontier` and report
    whether the number of marked cells ever exceeds `threshold`.

    visited: occupancy grid; MUTATED IN PLACE -- pass a copy when the same
             grid is reused for several searches
    frontier: queue.Queue of grid-cell Points to expand
    threshold: the search succeeds once the count of non-zero cells is
               strictly greater than this
    units: forwarded to `generate_successors`
    verbose: print every discovered child and the grid after marking it

    Returns True as soon as the threshold is exceeded, False when the
    frontier empties first.
    '''
    # Count once, then keep the tally up to date incrementally instead of
    # rescanning the whole grid on every dequeue.
    covered = int(np.count_nonzero(visited))

    while not frontier.empty():
        curr_node = frontier.get()
        if visited[curr_node.y, curr_node.x] == 0:
            covered += 1
        visited[curr_node.y, curr_node.x] = 1
        if covered > threshold:
            return True
        # generate_successors only yields cells that are still 0, so every
        # child is new and can be marked without re-checking.
        for child, action in generate_successors(visited, curr_node, units):
            visited[child.y, child.x] = 1
            covered += 1
            if verbose:
                print(f'From frontier {curr_node}, found child {child} from action {ACTION_NAMES[action]}.')
                print(visited)
            frontier.put(child)

    if verbose:
        print('Could not expand enough children!')
    return False  # return failure
# %%
visited = get_visited([Point(0,0), Point(1,0)], 4, 4, 1)
visited

# %%
# Single search from (2, 0). The grid is passed directly on purpose so the
# verbose trace shows it filling in.
frontier = queue.Queue()
frontier.put(Point(2,0))
breadth_first_search(visited, frontier, threshold=6, units=1, verbose=True)

# %%
# A full-height wall in column 2 splits the board into a 2-column region on
# the left and a 1-column pocket on the right.
visited = get_visited([Point(2,0), Point(2,1), Point(2,2), Point(2,3)], 4, 4, 1)
visited

# %%
# Evaluate every non-fatal move from the head independently.
# breadth_first_search marks cells in place, so each search gets its own copy
# of the grid; sharing one array lets cells marked while scoring one action
# count towards the next. On isolated grids RIGHT reaches 4 + 4 = 8 marked
# cells and LEFT reaches 4 + 8 = 12, and neither is strictly greater than
# threshold=12. A threshold of 11 would separate the two.
head = Point(2,3)
for successor, action in generate_successors(visited, head, units=1):
    frontier = queue.Queue()
    frontier.put(successor)
    if breadth_first_search(visited.copy(), frontier, threshold=12, units=1, verbose=True):
        print(f'action {ACTION_NAMES[action]}: good')
    else:
        print(f'action {ACTION_NAMES[action]}: bad')

# %%
visited = get_visited([Point(2,0), Point(2,1), Point(2,2), Point(3,0), Point(0,0)], 4, 4, 1)
visited
# %%
# Score each non-fatal move from the head on its own copy of the grid, since
# breadth_first_search marks cells in place.
head = Point(3,0)
for successor, action in generate_successors(visited, head, units=1):
    frontier = queue.Queue()
    frontier.put(successor)
    if breadth_first_search(visited.copy(), frontier, threshold=15, units=1, verbose=True):
        print(f'action {ACTION_NAMES[action]}: good')
    else:
        print(f'action {ACTION_NAMES[action]}: bad')

# %%
def get_viable_actions(id):
    '''
    Return the actions player `id` can take without being boxed in.

    Reads heads and tails from the module-level `game_engine`, marks them
    on an occupancy grid, and for each move that is not immediately fatal
    runs `breadth_first_search` (default threshold COVERAGE_THRESHOLD) on
    a copy of that grid.

    Returns the list of actions whose search succeeded; when none did,
    falls back to every non-fatal action. The list is empty only when
    every move is immediately fatal.
    '''
    heads, tails, _ = game_engine.get_heads_tails_and_goal()
    visited = get_visited(heads + tails)

    viable_actions = []
    valid_actions = []

    for successor, action in generate_successors(visited, heads[id]):
        valid_actions.append(action)
        frontier = queue.Queue()
        frontier.put(successor)
        # search on a copy so one action's marks never affect the next
        if breadth_first_search(visited.copy(), frontier):
            viable_actions.append(action)

    # `visited` itself is never modified above, so this check is independent
    # of the loop; report it once instead of once per successor.
    if valid_actions and np.count_nonzero(visited) >= COVERAGE_THRESHOLD:
        print('Coverage reached!')
        print('No more smart actions left!')
    return viable_actions or valid_actions

# %%
superior_table = qtsnake.load_q('superior_qt.npy')

# %%
class QSnakeImproved(qtsnake.QSnake):
    '''
    QSnake variant whose action choice is restricted to the moves returned
    by `get_viable_actions`.
    '''

    def __init__(self, game_engine):
        super().__init__(game_engine)

    # override
    def pick_greedy_action(self, q, id, epsilon=0):
        '''
        Choose an action for player `id` from the viable moves.

        q: Q-table handed to the inherited `index_actions`
        id: player index
        epsilon: probability of picking a random viable action

        Returns (state, action). Action 0 is returned when no viable
        action exists.
        '''
        viable_actions = get_viable_actions(id)
        state, rewards = self.index_actions(q, id)

        if np.random.uniform() < epsilon:
            # `viable_actions` is a plain list, so test it for emptiness
            # directly; it has no `.size` attribute.
            if viable_actions:
                return (state, np.random.choice(viable_actions))
            return (state, 0)
        # NOTE(review): `argmin_gen` is inherited from qtsnake.QSnake;
        # presumably it yields action indices in preference order -- confirm.
        for action in self.argmin_gen(rewards):
            if action in viable_actions:
                return (state, action)
        return (state, 0)  # death

# %%
q_table = QSnakeImproved(game_engine)

# %%
# Drive player 1 for 2000 steps with the loaded table.
for step in range(2000):
    _, p1_action = q_table.pick_greedy_action(superior_table, p1)
    game_engine.player_advance([p1_action])