summaryrefslogtreecommitdiff
path: root/revised_get_viable_actions.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'revised_get_viable_actions.ipynb')
-rw-r--r--revised_get_viable_actions.ipynb693
1 files changed, 693 insertions, 0 deletions
diff --git a/revised_get_viable_actions.ipynb b/revised_get_viable_actions.ipynb
new file mode 100644
index 0000000..805d630
--- /dev/null
+++ b/revised_get_viable_actions.ipynb
@@ -0,0 +1,693 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "83cff4a8-32f3-410f-b1cf-a5f406559d01",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "pygame 2.5.1 (SDL 2.28.2, Python 3.11.5)\n",
+ "Hello from the pygame community. https://www.pygame.org/contribute.html\n"
+ ]
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "import matplotlib.pyplot as plt\n",
+ "from collections import namedtuple\n",
+ "import queue\n",
+ "from IPython.core.debugger import Pdb\n",
+ "from IPython.display import display, clear_output\n",
+ "\n",
+ "from GameEngine import multiplayer\n",
+ "from QTable import qtsnake\n",
+ "\n",
+ "Point = namedtuple('Point', 'x, y')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ea980d7c-d430-44b1-9cb4-9f21122005ff",
+ "metadata": {},
+ "source": [
+ "### The 'get_viable_actions' function"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "55cd65c9-bd4c-4b7a-8484-5e0f94742967",
+ "metadata": {},
+ "source": [
+ "#### Hamiltonian Cycles"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "2640d368-eeb0-4dba-98dd-40aba5579b87",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# defines game window size and block size, in pixels\n",
+ "WINDOW_WIDTH = 300\n",
+ "WINDOW_HEIGHT = 300\n",
+ "GAME_UNITS = 30\n",
+ "S_SIZE = 3\n",
+ "S_START = (0 + ((S_SIZE-1) * GAME_UNITS), 0 * GAME_UNITS) #FIXME"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "68d5bf28-1a6f-4188-b078-6cb9e6e36612",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "visited = np.zeros((WINDOW_WIDTH // GAME_UNITS,\n",
+ " WINDOW_HEIGHT // GAME_UNITS))\n",
+ "visited"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "ee4e92a3-842d-40bb-aa0b-45a2bd377e56",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "80.0"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "COVERAGE_THRESHOLD = .8 * visited.size\n",
+ "COVERAGE_THRESHOLD"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "0a1ba6af-88f0-4945-b16e-b396ccfeb88a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "game_engine = multiplayer.Playfield(window_width=WINDOW_WIDTH,\n",
+ " window_height=WINDOW_HEIGHT,\n",
+ " units=GAME_UNITS,\n",
+ " g_speed=35,\n",
+ " s_size=S_SIZE)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "ba66417b-90a3-42f5-b6b1-9d87361d1a7e",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Game starting with 1 players.\n"
+ ]
+ }
+ ],
+ "source": [
+ "p1 = game_engine.add_player(S_START)\n",
+ "game_engine.start_game()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "11393601-b055-4371-ac4f-4f0079513f8d",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[<CollisionType.NONE: 2>]"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "game_engine.player_advance([1])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "93f5c3d9-4457-4fb9-9dc7-a1947a0859e5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_visited(unreachable, width=WINDOW_WIDTH,\n",
+ " height=WINDOW_HEIGHT, units=GAME_UNITS):\n",
+ " '''\n",
+ " given a numpy array corresponding to grid\n",
+ " and a list of tails,\n",
+ " marks unreachable places as visited\n",
+ " '''\n",
+ " visited = np.zeros((height//units, width//units))\n",
+ " for node in unreachable:\n",
+ " visited[node.y//units, node.x//units] = 1\n",
+ " return visited"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "bc03d77a-ff75-4fc7-9dfe-fd851e10077b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def is_instant_death(visited, expansion):\n",
+ " if (min(expansion) < 0 or\n",
+ " expansion.y >= visited.shape[0] or\n",
+ " expansion.x >= visited.shape[1] or\n",
+ " visited[expansion.y, expansion.x] != 0):\n",
+ " return True\n",
+ " return False"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "2fbc7f1c-f55f-4360-9dba-88d169a1ea0a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def generate_successors(visited, frontier, units=GAME_UNITS):\n",
+ " '''\n",
+ " a generator function used to generate\n",
+ " every new reachable state\n",
+ "\n",
+ " actions corresponding to displacement\n",
+ " 0 1 2 3\n",
+ " UP, RIGHT, DOWN, LEFT\n",
+ " '''\n",
+ " actions = [Point(0, -units), Point(units, 0),\n",
+ " Point(0, units), Point(-units, 0)]\n",
+ " all_actions = [0, 1, 2, 3]\n",
+ " expansion = None\n",
+ "\n",
+ " for action in all_actions:\n",
+ " ''' calculate new position '''\n",
+ " expansion = Point((frontier.x + actions[action].x) // units,\n",
+ " (frontier.y + actions[action].y) // units)\n",
+ " if not is_instant_death(visited, expansion):\n",
+ " yield expansion, action"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "b3914dc1-e888-439e-b21e-1151f7e29ea9",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[1., 0.],\n",
+ " [0., 0.]])"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "visited = get_visited([Point(0,0)], 2, 2, 1)\n",
+ "visited"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "175ab587-f779-495d-9149-9e492b6872b5",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(Point(x=1, y=1), 2)\n"
+ ]
+ }
+ ],
+ "source": [
+ "for successor in generate_successors(visited, Point(1,0), 1):\n",
+ " print(successor)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "d3b9ce23-b7a8-4683-bbc0-58a5a9cce7da",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ACTION_NAMES = [\"UP\", \"RIGHT\", \"DOWN\", \"LEFT\"]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "b802a66f-b19b-427e-8065-420c397bc87a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def breadth_first_search(visited, frontier, threshold=COVERAGE_THRESHOLD, units=1, verbose=False):\n",
+ " '''\n",
+ " A general search algorithm which expands nodes determined by the frontier.\n",
+ " '''\n",
+ "\n",
+ " while not frontier.empty():\n",
+ " curr_node = frontier.get()\n",
+ " visited[curr_node.y, curr_node.x] = 1\n",
+ " if np.count_nonzero(visited) > threshold:\n",
+ " return True\n",
+ " for child, action in generate_successors(visited, curr_node, units):\n",
+ " if visited[child.y, child.x] == 0:\n",
+ " visited[child.y, child.x] = 1\n",
+ " if verbose:\n",
+ " print(f'From frontier {curr_node}, found child {child} from action {ACTION_NAMES[action]}.')\n",
+ " print(visited)\n",
+ " frontier.put(child)\n",
+ "\n",
+ " if verbose:\n",
+ " print(f'Could not expand enough children!')\n",
+ " return False\t\t\t\t\t\t# return failure"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "e762fc34-6376-428d-a1be-4ee8b68670a7",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[1., 1., 0., 0.],\n",
+ " [0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0.],\n",
+ " [0., 0., 0., 0.]])"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "visited = get_visited([Point(0,0), Point(1,0)], 4, 4, 1)\n",
+ "visited"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "588e3802-65f3-4dc9-bfdc-da5d1ad59ae3",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "From frontier Point(x=2, y=0), found child Point(x=3, y=0) from action RIGHT.\n",
+ "[[1. 1. 1. 1.]\n",
+ " [0. 0. 0. 0.]\n",
+ " [0. 0. 0. 0.]\n",
+ " [0. 0. 0. 0.]]\n",
+ "From frontier Point(x=2, y=0), found child Point(x=2, y=1) from action DOWN.\n",
+ "[[1. 1. 1. 1.]\n",
+ " [0. 0. 1. 0.]\n",
+ " [0. 0. 0. 0.]\n",
+ " [0. 0. 0. 0.]]\n",
+ "From frontier Point(x=3, y=0), found child Point(x=3, y=1) from action DOWN.\n",
+ "[[1. 1. 1. 1.]\n",
+ " [0. 0. 1. 1.]\n",
+ " [0. 0. 0. 0.]\n",
+ " [0. 0. 0. 0.]]\n",
+ "From frontier Point(x=2, y=1), found child Point(x=2, y=2) from action DOWN.\n",
+ "[[1. 1. 1. 1.]\n",
+ " [0. 0. 1. 1.]\n",
+ " [0. 0. 1. 0.]\n",
+ " [0. 0. 0. 0.]]\n",
+ "From frontier Point(x=2, y=1), found child Point(x=1, y=1) from action LEFT.\n",
+ "[[1. 1. 1. 1.]\n",
+ " [0. 1. 1. 1.]\n",
+ " [0. 0. 1. 0.]\n",
+ " [0. 0. 0. 0.]]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "True"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "frontier = queue.Queue()\n",
+ "frontier.put(Point(2,0))\n",
+ "breadth_first_search(visited, frontier, threshold=6, units=1, verbose=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "24cd517a-05e4-49db-8967-e44af7fb0892",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[0., 0., 1., 0.],\n",
+ " [0., 0., 1., 0.],\n",
+ " [0., 0., 1., 0.],\n",
+ " [0., 0., 1., 0.]])"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "visited = get_visited([Point(2,0), Point(2,1), Point(2,2), Point(2,3)], 4, 4, 1)\n",
+ "visited"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "0a904a60-f413-4351-aa8b-ecb1139a5b0d",
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "From frontier Point(x=3, y=3), found child Point(x=3, y=2) from action UP.\n",
+ "[[0. 0. 1. 0.]\n",
+ " [0. 0. 1. 0.]\n",
+ " [0. 0. 1. 1.]\n",
+ " [0. 0. 1. 1.]]\n",
+ "From frontier Point(x=3, y=2), found child Point(x=3, y=1) from action UP.\n",
+ "[[0. 0. 1. 0.]\n",
+ " [0. 0. 1. 1.]\n",
+ " [0. 0. 1. 1.]\n",
+ " [0. 0. 1. 1.]]\n",
+ "From frontier Point(x=3, y=1), found child Point(x=3, y=0) from action UP.\n",
+ "[[0. 0. 1. 1.]\n",
+ " [0. 0. 1. 1.]\n",
+ " [0. 0. 1. 1.]\n",
+ " [0. 0. 1. 1.]]\n",
+ "Could not expand enough children!\n",
+ "action RIGHT: bad\n",
+ "From frontier Point(x=1, y=3), found child Point(x=1, y=2) from action UP.\n",
+ "[[0. 0. 1. 1.]\n",
+ " [0. 0. 1. 1.]\n",
+ " [0. 1. 1. 1.]\n",
+ " [0. 1. 1. 1.]]\n",
+ "From frontier Point(x=1, y=3), found child Point(x=0, y=3) from action LEFT.\n",
+ "[[0. 0. 1. 1.]\n",
+ " [0. 0. 1. 1.]\n",
+ " [0. 1. 1. 1.]\n",
+ " [1. 1. 1. 1.]]\n",
+ "From frontier Point(x=1, y=2), found child Point(x=1, y=1) from action UP.\n",
+ "[[0. 0. 1. 1.]\n",
+ " [0. 1. 1. 1.]\n",
+ " [0. 1. 1. 1.]\n",
+ " [1. 1. 1. 1.]]\n",
+ "From frontier Point(x=1, y=2), found child Point(x=0, y=2) from action LEFT.\n",
+ "[[0. 0. 1. 1.]\n",
+ " [0. 1. 1. 1.]\n",
+ " [1. 1. 1. 1.]\n",
+ " [1. 1. 1. 1.]]\n",
+ "action LEFT: good\n"
+ ]
+ }
+ ],
+ "source": [
+ "head = Point(2,3)\n",
+ "for successor, action in generate_successors(visited, head, units=1):\n",
+ " frontier = queue.Queue()\n",
+ " frontier.put(successor)\n",
+ " if breadth_first_search(visited, frontier, threshold=12, units=1, verbose=True):\n",
+ " print(f'action {ACTION_NAMES[action]}: good')\n",
+ " else:\n",
+ " print(f'action {ACTION_NAMES[action]}: bad')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "23d6abad-e052-41e7-9685-0ab22928cd8f",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([[1., 0., 1., 1.],\n",
+ " [0., 0., 1., 0.],\n",
+ " [0., 0., 1., 0.],\n",
+ " [0., 0., 0., 0.]])"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "visited = get_visited([Point(2,0), Point(2,1), Point(2,2), Point(3,0), Point(0,0)], 4, 4, 1)\n",
+ "visited"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "96e64e4a-1844-4ed6-a132-282e69afddb6",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "From frontier Point(x=3, y=1), found child Point(x=3, y=2) from action DOWN.\n",
+ "[[1. 0. 1. 1.]\n",
+ " [0. 0. 1. 1.]\n",
+ " [0. 0. 1. 1.]\n",
+ " [0. 0. 0. 0.]]\n",
+ "From frontier Point(x=3, y=2), found child Point(x=3, y=3) from action DOWN.\n",
+ "[[1. 0. 1. 1.]\n",
+ " [0. 0. 1. 1.]\n",
+ " [0. 0. 1. 1.]\n",
+ " [0. 0. 0. 1.]]\n",
+ "From frontier Point(x=3, y=3), found child Point(x=2, y=3) from action LEFT.\n",
+ "[[1. 0. 1. 1.]\n",
+ " [0. 0. 1. 1.]\n",
+ " [0. 0. 1. 1.]\n",
+ " [0. 0. 1. 1.]]\n",
+ "From frontier Point(x=2, y=3), found child Point(x=1, y=3) from action LEFT.\n",
+ "[[1. 0. 1. 1.]\n",
+ " [0. 0. 1. 1.]\n",
+ " [0. 0. 1. 1.]\n",
+ " [0. 1. 1. 1.]]\n",
+ "From frontier Point(x=1, y=3), found child Point(x=1, y=2) from action UP.\n",
+ "[[1. 0. 1. 1.]\n",
+ " [0. 0. 1. 1.]\n",
+ " [0. 1. 1. 1.]\n",
+ " [0. 1. 1. 1.]]\n",
+ "From frontier Point(x=1, y=3), found child Point(x=0, y=3) from action LEFT.\n",
+ "[[1. 0. 1. 1.]\n",
+ " [0. 0. 1. 1.]\n",
+ " [0. 1. 1. 1.]\n",
+ " [1. 1. 1. 1.]]\n",
+ "From frontier Point(x=1, y=2), found child Point(x=1, y=1) from action UP.\n",
+ "[[1. 0. 1. 1.]\n",
+ " [0. 1. 1. 1.]\n",
+ " [0. 1. 1. 1.]\n",
+ " [1. 1. 1. 1.]]\n",
+ "From frontier Point(x=1, y=2), found child Point(x=0, y=2) from action LEFT.\n",
+ "[[1. 0. 1. 1.]\n",
+ " [0. 1. 1. 1.]\n",
+ " [1. 1. 1. 1.]\n",
+ " [1. 1. 1. 1.]]\n",
+ "From frontier Point(x=1, y=1), found child Point(x=1, y=0) from action UP.\n",
+ "[[1. 1. 1. 1.]\n",
+ " [0. 1. 1. 1.]\n",
+ " [1. 1. 1. 1.]\n",
+ " [1. 1. 1. 1.]]\n",
+ "From frontier Point(x=1, y=1), found child Point(x=0, y=1) from action LEFT.\n",
+ "[[1. 1. 1. 1.]\n",
+ " [1. 1. 1. 1.]\n",
+ " [1. 1. 1. 1.]\n",
+ " [1. 1. 1. 1.]]\n",
+ "action DOWN: good\n"
+ ]
+ }
+ ],
+ "source": [
+ "head = Point(3,0)\n",
+ "for successor, action in generate_successors(visited, head, units=1):\n",
+ " frontier = queue.Queue()\n",
+ " frontier.put(successor)\n",
+ " if breadth_first_search(visited, frontier, threshold=15, units=1, verbose=True):\n",
+ " print(f'action {ACTION_NAMES[action]}: good')\n",
+ " else:\n",
+ " print(f'action {ACTION_NAMES[action]}: bad')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "79a82894-7d07-4e56-94e7-01765750f97e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_viable_actions(id):\n",
+ " heads, tails, _ = game_engine.get_heads_tails_and_goal()\n",
+ " visited = get_visited(heads + tails)\n",
+ "\n",
+ " viable_actions = []\n",
+ " valid_actions = []\n",
+ "\n",
+ " for successor, action in generate_successors(visited, heads[id]):\n",
+ " valid_actions.append(action)\n",
+ " frontier = queue.Queue()\n",
+ " frontier.put(successor)\n",
+ " if breadth_first_search(visited.copy(), frontier):\n",
+ " viable_actions.append(action)\n",
+ " if np.count_nonzero(visited) >= COVERAGE_THRESHOLD:\n",
+ " print(f'Coverage reached!')\n",
+ " print(f'No more smart actions left!')\n",
+ " if len(viable_actions) == 0:\n",
+ " return valid_actions\n",
+ " return viable_actions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "a5c7e6a8-fd69-4e71-a016-d404a1b99906",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "superior_table = qtsnake.load_q('superior_qt.npy')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "0b45d83d-6a8e-4af9-97a7-28c9d965e52a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class QSnakeImproved(qtsnake.QSnake):\n",
+ "\n",
+ " def __init__(self, game_engine):\n",
+ " super().__init__(game_engine)\n",
+ "\n",
+ " ''' override '''\n",
+ " def pick_greedy_action(self, q, id, epsilon=0):\n",
+ " viable_actions = get_viable_actions(id)\n",
+ " state, rewards = self.index_actions(q, id)\n",
+ "\n",
+ " if np.random.uniform() < epsilon:\n",
+ " return (state, np.random.choice(viable_actions)) if viable_actions.size > 0 else (state, 0)\n",
+ " for action in self.argmin_gen(rewards):\n",
+ " if action in viable_actions:\n",
+ " return (state, action)\n",
+ " return (state, 0) # death"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "0c3e03aa-0c54-4db7-bbc8-99fe4841f310",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "q_table = QSnakeImproved(game_engine)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "3b7551d3-d70b-4243-b445-e9dcf1c4e39c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for step in range(2000):\n",
+ " _, p1_action = q_table.pick_greedy_action(superior_table, p1)\n",
+ " game_engine.player_advance([p1_action])"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.5"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}