summaryrefslogtreecommitdiff
path: root/revised_snake_q_table_noise.ipynb
diff options
context:
space:
mode:
authorbd-912 <bdunahu@gmail.com>2023-12-12 21:14:04 -0700
committerbd-912 <bdunahu@gmail.com>2023-12-12 21:14:04 -0700
commitca527f085dc996c81854d8450252353986c6e82f (patch)
tree028d5ac730030fd4d25a3f39067bf3c2601a5540 /revised_snake_q_table_noise.ipynb
parent09701f70bed858f51e927ef4a63a6f5345a370b9 (diff)
Many commentary changes in q_network notebook
Diffstat (limited to 'revised_snake_q_table_noise.ipynb')
-rw-r--r--revised_snake_q_table_noise.ipynb290
1 files changed, 0 insertions, 290 deletions
diff --git a/revised_snake_q_table_noise.ipynb b/revised_snake_q_table_noise.ipynb
deleted file mode 100644
index 52455e8..0000000
--- a/revised_snake_q_table_noise.ipynb
+++ /dev/null
@@ -1,290 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "id": "85da5df2-c926-417c-bd7b-d214ad31ebe1",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "pygame 2.5.1 (SDL 2.28.2, Python 3.11.5)\n",
- "Hello from the pygame community. https://www.pygame.org/contribute.html\n"
- ]
- }
- ],
- "source": [
- "import numpy as np\n",
- "from collections import namedtuple\n",
- "from IPython.core.debugger import Pdb\n",
- "\n",
- "from GameEngine import multiplayer\n",
- "from QTable import qtsnake\n",
- "Point = namedtuple('Point', 'x, y')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "9eec33d8-9a65-426a-8ad5-8ffb2cbe2541",
- "metadata": {},
- "outputs": [],
- "source": [
- "# defines game window size and block size, in pixels\n",
- "WINDOW_WIDTH = 640\n",
- "WINDOW_HEIGHT = 480\n",
- "GAME_UNITS = 80"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "id": "77b2d951-72be-4bee-aa73-3b87ea0288b6",
- "metadata": {},
- "outputs": [],
- "source": [
- "game_engine = multiplayer.Playfield(window_width=WINDOW_WIDTH,\n",
- " window_height=WINDOW_HEIGHT,\n",
- " units=GAME_UNITS,\n",
- " g_speed=35,\n",
- " s_size=1)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "93ce92f3-b336-4aca-a8da-ebd913e9d16b",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Game starting with 1 players.\n",
- "Draw is now False.\n"
- ]
- }
- ],
- "source": [
- "p1 = game_engine.add_player()\n",
- "game_engine.start_game()\n",
- "game_engine.toggle_draw()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "id": "78ad0c30-63e7-4689-9332-6e9540a67105",
- "metadata": {},
- "outputs": [],
- "source": [
- "danger_states = 16\n",
- "goal_relations = 8\n",
- "actions = 4\n",
- "q = np.zeros((danger_states,\n",
- " goal_relations,\n",
- " actions))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "id": "09804da3-ef32-4265-a63d-7d4ea7e59031",
- "metadata": {},
- "outputs": [],
- "source": [
- "def sense_danger(head, tail, units=GAME_UNITS, window_width=WINDOW_WIDTH, window_height=WINDOW_HEIGHT):\n",
- " '''\n",
- " Given a player's head and tail,\n",
- " returns a list of actions that does\n",
- " not result in immediate death\n",
- " '''\n",
- " danger_array = np.array([\n",
- " head.y-units < 0 or Point(head.x, head.y-units) in tail[1:], # up\n",
- " head.x+units >= window_width or Point(head.x+units, head.y) in tail[1:], # right\n",
- " head.y+units >= window_height or Point(head.x, head.y+units) in tail[1:], # down\n",
- " head.x-units < 0 or Point(head.x-units, head.y) in tail[1:], # left\n",
- " ])\n",
- "\n",
- " binary_string = \"\".join(map(str, map(int, danger_array)))\n",
- "\n",
- " return int(binary_string, base=2)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "id": "bfca56cf-0cef-4523-8e4a-f3254180e397",
- "metadata": {},
- "outputs": [],
- "source": [
- "x_b = 5; y_b = 5;\n",
- "head = Point(3,3); tail = [head]\n",
- "assert sense_danger(head,tail,1,x_b,y_b) == 0\n",
- "\n",
- "head = Point(0,0); tail = [head, Point(0,1)]\n",
- "assert sense_danger(head,tail,1,x_b,y_b) == 11\n",
- "\n",
- "head = Point(3,3); tail = [head, Point(3,4)]\n",
- "assert sense_danger(head,tail,1,x_b,y_b) == 2\n",
- "\n",
- "head = Point(3,3); tail = [head, Point(2,3)]\n",
- "assert sense_danger(head,tail,1,x_b,y_b) == 1"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "id": "45d9645a-b217-4812-8b35-2a154518087d",
- "metadata": {},
- "outputs": [],
- "source": [
- "def index_actions(q, id):\n",
- " '''\n",
- " given q, player_id, an array of heads,\n",
- " and the goal position,\n",
- " indexes into the corresponding expected\n",
- " reward of each action\n",
- " '''\n",
- " heads, tails, goal = game_engine.get_heads_tails_and_goal()\n",
- " state = np.array([sense_danger(heads[id], tails), qtsnake.sense_goal(heads[id], goal)])\n",
- " return state, q[state[0], state[1], :]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "id": "b8fc49f8-fc8e-4c0b-92f2-a345200ee26d",
- "metadata": {},
- "outputs": [],
- "source": [
- "def pick_greedy_action(q, id, epsilon):\n",
- " state, rewards = index_actions(q, id)\n",
- "\n",
- " if np.random.uniform() < epsilon:\n",
- " return np.append(state, [np.random.choice([0, 1, 2, 3])])\n",
- " action = np.argmin(rewards)\n",
- " return np.append(state, [action])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "id": "9dba22f2-a8fa-400c-8c71-fec7e988489c",
- "metadata": {},
- "outputs": [],
- "source": [
- "def update_q(q, old_state_action, new_state_action, outcome, lr=0.05):\n",
- " if outcome == multiplayer.CollisionType.GOAL:\n",
- " q[new_state_action[0], new_state_action[1], new_state_action[2]] = 0\n",
- " elif outcome == multiplayer.CollisionType.DEATH:\n",
- " q[new_state_action[0], new_state_action[1], new_state_action[2]] = 500\n",
- " else:\n",
- " td_error = 1 + q[new_state_action[0], new_state_action[1], new_state_action[2]] - q[old_state_action[0], old_state_action[1], old_state_action[2]]\n",
- " q[old_state_action[0], old_state_action[1], old_state_action[2]] += lr * td_error"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "id": "bc22f0b7-e409-4fd8-8433-8e918d5632ad",
- "metadata": {},
- "outputs": [],
- "source": [
- "n_steps = 52500\n",
- "epsilon = 1\n",
- "final_epsilon = 0.03\n",
- "epsilon_decay = np.exp(np.log(final_epsilon) / (n_steps))\n",
- "lr = 0.15"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "id": "c7c4417e-66f5-4028-8264-f4b732ba33a7",
- "metadata": {},
- "outputs": [],
- "source": [
- "p1_old_s_a = pick_greedy_action(q, p1, epsilon) # state, action\n",
- "game_engine.player_advance([p1_old_s_a[-1]])\n",
- "\n",
- "for step in range(n_steps):\n",
- " p1_new_s_a = pick_greedy_action(q, p1, epsilon) # state, action\n",
- " outcome = game_engine.player_advance([p1_new_s_a[-1]], 0.05)\n",
- "\n",
- " update_q(q, p1_old_s_a, p1_new_s_a, outcome[p1], lr)\n",
- "\n",
- " epsilon *= epsilon_decay\n",
- " p1_old_s_a = p1_new_s_a"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "id": "6538bcdc-4b4b-48ef-8c55-76834490b698",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Draw is now True.\n"
- ]
- }
- ],
- "source": [
- "game_engine.toggle_draw()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "id": "6e50b723-f0a4-48cd-bb98-1666ec0f59ce",
- "metadata": {},
- "outputs": [
- {
- "ename": "KeyboardInterrupt",
- "evalue": "",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
- "Cell \u001b[0;32mIn[14], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m step \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(n_steps):\n\u001b[1;32m 3\u001b[0m s_a \u001b[38;5;241m=\u001b[39m pick_greedy_action(q, p1, epsilon)\n\u001b[0;32m----> 4\u001b[0m \u001b[43mgame_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mplayer_advance\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43ms_a\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n",
- "File \u001b[0;32m~/Personal/roam/program_repository/cs498/multiagent_snake/GameEngine/multiplayer.py:118\u001b[0m, in \u001b[0;36mPlayfield.player_advance\u001b[0;34m(self, actions, noise)\u001b[0m\n\u001b[1;32m 116\u001b[0m collisions \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_player_collisions()\n\u001b[1;32m 117\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_draw_on:\n\u001b[0;32m--> 118\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_update_ui\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 119\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m collisions\n",
- "File \u001b[0;32m~/Personal/roam/program_repository/cs498/multiagent_snake/GameEngine/multiplayer.py:163\u001b[0m, in \u001b[0;36mPlayfield._update_ui\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_goal\u001b[38;5;241m.\u001b[39mdraw()\n\u001b[1;32m 162\u001b[0m pg\u001b[38;5;241m.\u001b[39mdisplay\u001b[38;5;241m.\u001b[39mflip() \u001b[38;5;66;03m# full screen update\u001b[39;00m\n\u001b[0;32m--> 163\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_clock\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtick\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_g_speed\u001b[49m\u001b[43m)\u001b[49m\n",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
- ]
- }
- ],
- "source": [
- "epsilon = 0\n",
- "for step in range(n_steps):\n",
- " s_a = pick_greedy_action(q, p1, epsilon)\n",
- " game_engine.player_advance([s_a[-1]])"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3 (ipykernel)",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.11.5"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}