diff options
author | bd-912 <bdunahu@gmail.com> | 2023-12-13 17:41:03 -0700 |
---|---|---|
committer | bd-912 <bdunahu@gmail.com> | 2023-12-13 17:41:03 -0700 |
commit | b1137269b269eed1207005828b7939efc9f557c2 (patch) | |
tree | 895338f3d2e32badfca55ee527739705135a7b1d /revised_get_viable_actions.ipynb | |
parent | ca527f085dc996c81854d8450252353986c6e82f (diff) |
Decreased size of training environments in notebooks
This allowed for more consistent training.
Also, final changes to notebook commentary.
Diffstat (limited to 'revised_get_viable_actions.ipynb')
-rw-r--r-- | revised_get_viable_actions.ipynb | 693 |
1 files changed, 0 insertions, 693 deletions
diff --git a/revised_get_viable_actions.ipynb b/revised_get_viable_actions.ipynb deleted file mode 100644 index 805d630..0000000 --- a/revised_get_viable_actions.ipynb +++ /dev/null @@ -1,693 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "83cff4a8-32f3-410f-b1cf-a5f406559d01", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "pygame 2.5.1 (SDL 2.28.2, Python 3.11.5)\n", - "Hello from the pygame community. https://www.pygame.org/contribute.html\n" - ] - } - ], - "source": [ - "import numpy as np\n", - "import matplotlib.pyplot as plt\n", - "from collections import namedtuple\n", - "import queue\n", - "from IPython.core.debugger import Pdb\n", - "from IPython.display import display, clear_output\n", - "\n", - "from GameEngine import multiplayer\n", - "from QTable import qtsnake\n", - "\n", - "Point = namedtuple('Point', 'x, y')" - ] - }, - { - "cell_type": "markdown", - "id": "ea980d7c-d430-44b1-9cb4-9f21122005ff", - "metadata": {}, - "source": [ - "### The 'get_viable_actions' function" - ] - }, - { - "cell_type": "markdown", - "id": "55cd65c9-bd4c-4b7a-8484-5e0f94742967", - "metadata": {}, - "source": [ - "#### Hamiltonian Cycles" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "2640d368-eeb0-4dba-98dd-40aba5579b87", - "metadata": {}, - "outputs": [], - "source": [ - "# defines game window size and block size, in pixels\n", - "WINDOW_WIDTH = 300\n", - "WINDOW_HEIGHT = 300\n", - "GAME_UNITS = 30\n", - "S_SIZE = 3\n", - "S_START = (0 + ((S_SIZE-1) * GAME_UNITS), 0 * GAME_UNITS) #FIXME" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "68d5bf28-1a6f-4188-b078-6cb9e6e36612", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n", - " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "visited = np.zeros((WINDOW_WIDTH // GAME_UNITS,\n", - " WINDOW_HEIGHT // GAME_UNITS))\n", - "visited" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "ee4e92a3-842d-40bb-aa0b-45a2bd377e56", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "80.0" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "COVERAGE_THRESHOLD = .8 * visited.size\n", - "COVERAGE_THRESHOLD" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "0a1ba6af-88f0-4945-b16e-b396ccfeb88a", - "metadata": {}, - "outputs": [], - "source": [ - "game_engine = multiplayer.Playfield(window_width=WINDOW_WIDTH,\n", - " window_height=WINDOW_HEIGHT,\n", - " units=GAME_UNITS,\n", - " g_speed=35,\n", - " s_size=S_SIZE)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "ba66417b-90a3-42f5-b6b1-9d87361d1a7e", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Game starting with 1 players.\n" - ] - } - ], - "source": [ - "p1 = game_engine.add_player(S_START)\n", - "game_engine.start_game()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "11393601-b055-4371-ac4f-4f0079513f8d", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[<CollisionType.NONE: 2>]" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "game_engine.player_advance([1])" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "93f5c3d9-4457-4fb9-9dc7-a1947a0859e5", - "metadata": {}, - "outputs": [], - "source": [ - "def get_visited(unreachable, width=WINDOW_WIDTH,\n", - " height=WINDOW_HEIGHT, units=GAME_UNITS):\n", - " '''\n", - " given a numpy array corresponding to grid\n", - " and a list of tails,\n", - " marks unreachable places as visited\n", - " '''\n", - " visited = np.zeros((height//units, width//units))\n", - " for node in unreachable:\n", - " visited[node.y//units, node.x//units] = 1\n", - " return visited" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "bc03d77a-ff75-4fc7-9dfe-fd851e10077b", - "metadata": {}, - "outputs": [], - "source": [ - "def is_instant_death(visited, expansion):\n", - " if (min(expansion) < 0 or\n", - " expansion.y >= visited.shape[0] or\n", - " expansion.x >= visited.shape[1] or\n", - " visited[expansion.y, expansion.x] != 0):\n", - " return True\n", - " return False" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "2fbc7f1c-f55f-4360-9dba-88d169a1ea0a", - "metadata": {}, - "outputs": [], - "source": [ - "def generate_successors(visited, frontier, units=GAME_UNITS):\n", - " '''\n", - " a generator function used to generate\n", - " every new reachable state\n", - "\n", - " actions corresponding to displacement\n", - " 0 1 2 3\n", - " UP, RIGHT, DOWN, LEFT\n", - " '''\n", - " actions = [Point(0, -units), Point(units, 0),\n", - " Point(0, units), Point(-units, 0)]\n", - " all_actions = [0, 1, 2, 3]\n", - " expansion = None\n", - "\n", - " for action in all_actions:\n", - " ''' calculate new position '''\n", - " expansion = Point((frontier.x + actions[action].x) // units,\n", - " (frontier.y + actions[action].y) // units)\n", - " if not is_instant_death(visited, expansion):\n", - " yield expansion, action" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "b3914dc1-e888-439e-b21e-1151f7e29ea9", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[1., 0.],\n", - " [0., 0.]])" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "visited = get_visited([Point(0,0)], 2, 2, 1)\n", - "visited" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "175ab587-f779-495d-9149-9e492b6872b5", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(Point(x=1, y=1), 2)\n" - ] - } - ], - "source": [ - "for successor in generate_successors(visited, Point(1,0), 1):\n", - " print(successor)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "d3b9ce23-b7a8-4683-bbc0-58a5a9cce7da", - "metadata": {}, - "outputs": [], - "source": [ - "ACTION_NAMES = [\"UP\", \"RIGHT\", \"DOWN\", \"LEFT\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "b802a66f-b19b-427e-8065-420c397bc87a", - "metadata": {}, - "outputs": [], - "source": [ - "def breadth_first_search(visited, frontier, threshold=COVERAGE_THRESHOLD, units=1, verbose=False):\n", - " '''\n", - " A general search algorithm which expands nodes determined by the frontier.\n", - " '''\n", - "\n", - " while not frontier.empty():\n", - " curr_node = frontier.get()\n", - " visited[curr_node.y, curr_node.x] = 1\n", - " if np.count_nonzero(visited) > threshold:\n", - " return True\n", - " for child, action in generate_successors(visited, curr_node, units):\n", - " if visited[child.y, child.x] == 0:\n", - " visited[child.y, child.x] = 1\n", - " if verbose:\n", - " print(f'From frontier {curr_node}, found child {child} from action {ACTION_NAMES[action]}.')\n", - " print(visited)\n", - " frontier.put(child)\n", - "\n", - " if verbose:\n", - " print(f'Could not expand enough children!')\n", - " return False\t\t\t\t\t\t# return failure" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "e762fc34-6376-428d-a1be-4ee8b68670a7", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[1., 1., 0., 0.],\n", - " [0., 0., 0., 0.],\n", - " [0., 0., 0., 0.],\n", - " [0., 0., 0., 0.]])" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "visited = get_visited([Point(0,0), Point(1,0)], 4, 4, 1)\n", - "visited" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "588e3802-65f3-4dc9-bfdc-da5d1ad59ae3", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "From frontier Point(x=2, y=0), found child Point(x=3, y=0) from action RIGHT.\n", - "[[1. 1. 1. 1.]\n", - " [0. 0. 0. 0.]\n", - " [0. 0. 0. 0.]\n", - " [0. 0. 0. 0.]]\n", - "From frontier Point(x=2, y=0), found child Point(x=2, y=1) from action DOWN.\n", - "[[1. 1. 1. 1.]\n", - " [0. 0. 1. 0.]\n", - " [0. 0. 0. 0.]\n", - " [0. 0. 0. 0.]]\n", - "From frontier Point(x=3, y=0), found child Point(x=3, y=1) from action DOWN.\n", - "[[1. 1. 1. 1.]\n", - " [0. 0. 1. 1.]\n", - " [0. 0. 0. 0.]\n", - " [0. 0. 0. 0.]]\n", - "From frontier Point(x=2, y=1), found child Point(x=2, y=2) from action DOWN.\n", - "[[1. 1. 1. 1.]\n", - " [0. 0. 1. 1.]\n", - " [0. 0. 1. 0.]\n", - " [0. 0. 0. 0.]]\n", - "From frontier Point(x=2, y=1), found child Point(x=1, y=1) from action LEFT.\n", - "[[1. 1. 1. 1.]\n", - " [0. 1. 1. 1.]\n", - " [0. 0. 1. 0.]\n", - " [0. 0. 0. 0.]]\n" - ] - }, - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "frontier = queue.Queue()\n", - "frontier.put(Point(2,0))\n", - "breadth_first_search(visited, frontier, threshold=6, units=1, verbose=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "24cd517a-05e4-49db-8967-e44af7fb0892", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[0., 0., 1., 0.],\n", - " [0., 0., 1., 0.],\n", - " [0., 0., 1., 0.],\n", - " [0., 0., 1., 0.]])" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "visited = get_visited([Point(2,0), Point(2,1), Point(2,2), Point(2,3)], 4, 4, 1)\n", - "visited" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "0a904a60-f413-4351-aa8b-ecb1139a5b0d", - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "From frontier Point(x=3, y=3), found child Point(x=3, y=2) from action UP.\n", - "[[0. 0. 1. 0.]\n", - " [0. 0. 1. 0.]\n", - " [0. 0. 1. 1.]\n", - " [0. 0. 1. 1.]]\n", - "From frontier Point(x=3, y=2), found child Point(x=3, y=1) from action UP.\n", - "[[0. 0. 1. 0.]\n", - " [0. 0. 1. 1.]\n", - " [0. 0. 1. 1.]\n", - " [0. 0. 1. 1.]]\n", - "From frontier Point(x=3, y=1), found child Point(x=3, y=0) from action UP.\n", - "[[0. 0. 1. 1.]\n", - " [0. 0. 1. 1.]\n", - " [0. 0. 1. 1.]\n", - " [0. 0. 1. 1.]]\n", - "Could not expand enough children!\n", - "action RIGHT: bad\n", - "From frontier Point(x=1, y=3), found child Point(x=1, y=2) from action UP.\n", - "[[0. 0. 1. 1.]\n", - " [0. 0. 1. 1.]\n", - " [0. 1. 1. 1.]\n", - " [0. 1. 1. 1.]]\n", - "From frontier Point(x=1, y=3), found child Point(x=0, y=3) from action LEFT.\n", - "[[0. 0. 1. 1.]\n", - " [0. 0. 1. 1.]\n", - " [0. 1. 1. 1.]\n", - " [1. 1. 1. 1.]]\n", - "From frontier Point(x=1, y=2), found child Point(x=1, y=1) from action UP.\n", - "[[0. 0. 1. 1.]\n", - " [0. 1. 1. 1.]\n", - " [0. 1. 1. 1.]\n", - " [1. 1. 1. 1.]]\n", - "From frontier Point(x=1, y=2), found child Point(x=0, y=2) from action LEFT.\n", - "[[0. 0. 1. 1.]\n", - " [0. 1. 1. 1.]\n", - " [1. 1. 1. 1.]\n", - " [1. 1. 1. 1.]]\n", - "action LEFT: good\n" - ] - } - ], - "source": [ - "head = Point(2,3)\n", - "for successor, action in generate_successors(visited, head, units=1):\n", - " frontier = queue.Queue()\n", - " frontier.put(successor)\n", - " if breadth_first_search(visited, frontier, threshold=12, units=1, verbose=True):\n", - " print(f'action {ACTION_NAMES[action]}: good')\n", - " else:\n", - " print(f'action {ACTION_NAMES[action]}: bad')" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "23d6abad-e052-41e7-9685-0ab22928cd8f", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[1., 0., 1., 1.],\n", - " [0., 0., 1., 0.],\n", - " [0., 0., 1., 0.],\n", - " [0., 0., 0., 0.]])" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "visited = get_visited([Point(2,0), Point(2,1), Point(2,2), Point(3,0), Point(0,0)], 4, 4, 1)\n", - "visited" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "96e64e4a-1844-4ed6-a132-282e69afddb6", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "From frontier Point(x=3, y=1), found child Point(x=3, y=2) from action DOWN.\n", - "[[1. 0. 1. 1.]\n", - " [0. 0. 1. 1.]\n", - " [0. 0. 1. 1.]\n", - " [0. 0. 0. 0.]]\n", - "From frontier Point(x=3, y=2), found child Point(x=3, y=3) from action DOWN.\n", - "[[1. 0. 1. 1.]\n", - " [0. 0. 1. 1.]\n", - " [0. 0. 1. 1.]\n", - " [0. 0. 0. 1.]]\n", - "From frontier Point(x=3, y=3), found child Point(x=2, y=3) from action LEFT.\n", - "[[1. 0. 1. 1.]\n", - " [0. 0. 1. 1.]\n", - " [0. 0. 1. 1.]\n", - " [0. 0. 1. 1.]]\n", - "From frontier Point(x=2, y=3), found child Point(x=1, y=3) from action LEFT.\n", - "[[1. 0. 1. 1.]\n", - " [0. 0. 1. 1.]\n", - " [0. 0. 1. 1.]\n", - " [0. 1. 1. 1.]]\n", - "From frontier Point(x=1, y=3), found child Point(x=1, y=2) from action UP.\n", - "[[1. 0. 1. 1.]\n", - " [0. 0. 1. 1.]\n", - " [0. 1. 1. 1.]\n", - " [0. 1. 1. 1.]]\n", - "From frontier Point(x=1, y=3), found child Point(x=0, y=3) from action LEFT.\n", - "[[1. 0. 1. 1.]\n", - " [0. 0. 1. 1.]\n", - " [0. 1. 1. 1.]\n", - " [1. 1. 1. 1.]]\n", - "From frontier Point(x=1, y=2), found child Point(x=1, y=1) from action UP.\n", - "[[1. 0. 1. 1.]\n", - " [0. 1. 1. 1.]\n", - " [0. 1. 1. 1.]\n", - " [1. 1. 1. 1.]]\n", - "From frontier Point(x=1, y=2), found child Point(x=0, y=2) from action LEFT.\n", - "[[1. 0. 1. 1.]\n", - " [0. 1. 1. 1.]\n", - " [1. 1. 1. 1.]\n", - " [1. 1. 1. 1.]]\n", - "From frontier Point(x=1, y=1), found child Point(x=1, y=0) from action UP.\n", - "[[1. 1. 1. 1.]\n", - " [0. 1. 1. 1.]\n", - " [1. 1. 1. 1.]\n", - " [1. 1. 1. 1.]]\n", - "From frontier Point(x=1, y=1), found child Point(x=0, y=1) from action LEFT.\n", - "[[1. 1. 1. 1.]\n", - " [1. 1. 1. 1.]\n", - " [1. 1. 1. 1.]\n", - " [1. 1. 1. 1.]]\n", - "action DOWN: good\n" - ] - } - ], - "source": [ - "head = Point(3,0)\n", - "for successor, action in generate_successors(visited, head, units=1):\n", - " frontier = queue.Queue()\n", - " frontier.put(successor)\n", - " if breadth_first_search(visited, frontier, threshold=15, units=1, verbose=True):\n", - " print(f'action {ACTION_NAMES[action]}: good')\n", - " else:\n", - " print(f'action {ACTION_NAMES[action]}: bad')" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "79a82894-7d07-4e56-94e7-01765750f97e", - "metadata": {}, - "outputs": [], - "source": [ - "def get_viable_actions(id):\n", - " heads, tails, _ = game_engine.get_heads_tails_and_goal()\n", - " visited = get_visited(heads + tails)\n", - "\n", - " viable_actions = []\n", - " valid_actions = []\n", - "\n", - " for successor, action in generate_successors(visited, heads[id]):\n", - " valid_actions.append(action)\n", - " frontier = queue.Queue()\n", - " frontier.put(successor)\n", - " if breadth_first_search(visited.copy(), frontier):\n", - " viable_actions.append(action)\n", - " if np.count_nonzero(visited) >= COVERAGE_THRESHOLD:\n", - " print(f'Coverage reached!')\n", - " print(f'No more smart actions left!')\n", - " if len(viable_actions) == 0:\n", - " return valid_actions\n", - " return viable_actions" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "a5c7e6a8-fd69-4e71-a016-d404a1b99906", - "metadata": {}, - "outputs": [], - "source": [ - "superior_table = qtsnake.load_q('superior_qt.npy')" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "0b45d83d-6a8e-4af9-97a7-28c9d965e52a", - "metadata": {}, - "outputs": [], - "source": [ - "class QSnakeImproved(qtsnake.QSnake):\n", - "\n", - " def __init__(self, game_engine):\n", - " super().__init__(game_engine)\n", - "\n", - " ''' override '''\n", - " def pick_greedy_action(self, q, id, epsilon=0):\n", - " viable_actions = get_viable_actions(id)\n", - " state, rewards = self.index_actions(q, id)\n", - "\n", - " if np.random.uniform() < epsilon:\n", - " return (state, np.random.choice(viable_actions)) if viable_actions.size > 0 else (state, 0)\n", - " for action in self.argmin_gen(rewards):\n", - " if action in viable_actions:\n", - " return (state, action)\n", - " return (state, 0) # death" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "0c3e03aa-0c54-4db7-bbc8-99fe4841f310", - "metadata": {}, - "outputs": [], - "source": [ - "q_table = QSnakeImproved(game_engine)" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "3b7551d3-d70b-4243-b445-e9dcf1c4e39c", - "metadata": {}, - "outputs": [], - "source": [ - "for step in range(2000):\n", - " _, p1_action = q_table.pick_greedy_action(superior_table, p1)\n", - " game_engine.player_advance([p1_action])" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} |