summaryrefslogtreecommitdiff
path: root/revised_snake_q_table_noise.ipynb
diff options
context:
space:
mode:
authorbd-912 <bdunahu@gmail.com>2023-11-15 12:47:42 -0700
committerbd-912 <bdunahu@gmail.com>2023-11-15 12:47:42 -0700
commitced115881e2f60afe41373b6da899f5d2f5403e0 (patch)
tree99d323876a5c32b3b138102440b4a1c6c12ac7be /revised_snake_q_table_noise.ipynb
parenta2b56742da7b30afa00f33c9a806fa6031be68a5 (diff)
Added new notebook concerning improvements on get_viable_actions
Diffstat (limited to 'revised_snake_q_table_noise.ipynb')
-rw-r--r--revised_snake_q_table_noise.ipynb237
1 files changed, 235 insertions, 2 deletions
diff --git a/revised_snake_q_table_noise.ipynb b/revised_snake_q_table_noise.ipynb
index 695875e..52455e8 100644
--- a/revised_snake_q_table_noise.ipynb
+++ b/revised_snake_q_table_noise.ipynb
@@ -21,16 +21,249 @@
"from IPython.core.debugger import Pdb\n",
"\n",
"from GameEngine import multiplayer\n",
+ "from QTable import qtsnake\n",
"Point = namedtuple('Point', 'x, y')"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 2,
"id": "9eec33d8-9a65-426a-8ad5-8ffb2cbe2541",
"metadata": {},
"outputs": [],
- "source": []
+ "source": [
+ "# defines game window size and block size, in pixels\n",
+ "WINDOW_WIDTH = 640\n",
+ "WINDOW_HEIGHT = 480\n",
+ "GAME_UNITS = 80"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "77b2d951-72be-4bee-aa73-3b87ea0288b6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "game_engine = multiplayer.Playfield(window_width=WINDOW_WIDTH,\n",
+ " window_height=WINDOW_HEIGHT,\n",
+ " units=GAME_UNITS,\n",
+ " g_speed=35,\n",
+ " s_size=1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "93ce92f3-b336-4aca-a8da-ebd913e9d16b",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Game starting with 1 players.\n",
+ "Draw is now False.\n"
+ ]
+ }
+ ],
+ "source": [
+ "p1 = game_engine.add_player()\n",
+ "game_engine.start_game()\n",
+ "game_engine.toggle_draw()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "78ad0c30-63e7-4689-9332-6e9540a67105",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "danger_states = 16\n",
+ "goal_relations = 8\n",
+ "actions = 4\n",
+ "q = np.zeros((danger_states,\n",
+ " goal_relations,\n",
+ " actions))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "09804da3-ef32-4265-a63d-7d4ea7e59031",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def sense_danger(head, tail, units=GAME_UNITS, window_width=WINDOW_WIDTH, window_height=WINDOW_HEIGHT):\n",
+ " '''\n",
+ " Given a player's head and tail,\n",
+ " returns a list of actions that does\n",
+ " not result in immediate death\n",
+ " '''\n",
+ " danger_array = np.array([\n",
+ " head.y-units < 0 or Point(head.x, head.y-units) in tail[1:], # up\n",
+ " head.x+units >= window_width or Point(head.x+units, head.y) in tail[1:], # right\n",
+ " head.y+units >= window_height or Point(head.x, head.y+units) in tail[1:], # down\n",
+ " head.x-units < 0 or Point(head.x-units, head.y) in tail[1:], # left\n",
+ " ])\n",
+ "\n",
+ " binary_string = \"\".join(map(str, map(int, danger_array)))\n",
+ "\n",
+ " return int(binary_string, base=2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "bfca56cf-0cef-4523-8e4a-f3254180e397",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x_b = 5; y_b = 5;\n",
+ "head = Point(3,3); tail = [head]\n",
+ "assert sense_danger(head,tail,1,x_b,y_b) == 0\n",
+ "\n",
+ "head = Point(0,0); tail = [head, Point(0,1)]\n",
+ "assert sense_danger(head,tail,1,x_b,y_b) == 11\n",
+ "\n",
+ "head = Point(3,3); tail = [head, Point(3,4)]\n",
+ "assert sense_danger(head,tail,1,x_b,y_b) == 2\n",
+ "\n",
+ "head = Point(3,3); tail = [head, Point(2,3)]\n",
+ "assert sense_danger(head,tail,1,x_b,y_b) == 1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "45d9645a-b217-4812-8b35-2a154518087d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def index_actions(q, id):\n",
+ " '''\n",
+ " given q, player_id, an array of heads,\n",
+ " and the goal position,\n",
+ " indexes into the corresponding expected\n",
+ " reward of each action\n",
+ " '''\n",
+ " heads, tails, goal = game_engine.get_heads_tails_and_goal()\n",
+ " state = np.array([sense_danger(heads[id], tails), qtsnake.sense_goal(heads[id], goal)])\n",
+ " return state, q[state[0], state[1], :]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "b8fc49f8-fc8e-4c0b-92f2-a345200ee26d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def pick_greedy_action(q, id, epsilon):\n",
+ " state, rewards = index_actions(q, id)\n",
+ "\n",
+ " if np.random.uniform() < epsilon:\n",
+ " return np.append(state, [np.random.choice([0, 1, 2, 3])])\n",
+ " action = np.argmin(rewards)\n",
+ " return np.append(state, [action])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "9dba22f2-a8fa-400c-8c71-fec7e988489c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def update_q(q, old_state_action, new_state_action, outcome, lr=0.05):\n",
+ " if outcome == multiplayer.CollisionType.GOAL:\n",
+ " q[new_state_action[0], new_state_action[1], new_state_action[2]] = 0\n",
+ " elif outcome == multiplayer.CollisionType.DEATH:\n",
+ " q[new_state_action[0], new_state_action[1], new_state_action[2]] = 500\n",
+ " else:\n",
+ " td_error = 1 + q[new_state_action[0], new_state_action[1], new_state_action[2]] - q[old_state_action[0], old_state_action[1], old_state_action[2]]\n",
+ " q[old_state_action[0], old_state_action[1], old_state_action[2]] += lr * td_error"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "bc22f0b7-e409-4fd8-8433-8e918d5632ad",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "n_steps = 52500\n",
+ "epsilon = 1\n",
+ "final_epsilon = 0.03\n",
+ "epsilon_decay = np.exp(np.log(final_epsilon) / (n_steps))\n",
+ "lr = 0.15"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "c7c4417e-66f5-4028-8264-f4b732ba33a7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "p1_old_s_a = pick_greedy_action(q, p1, epsilon) # state, action\n",
+ "game_engine.player_advance([p1_old_s_a[-1]])\n",
+ "\n",
+ "for step in range(n_steps):\n",
+ " p1_new_s_a = pick_greedy_action(q, p1, epsilon) # state, action\n",
+ " outcome = game_engine.player_advance([p1_new_s_a[-1]], 0.05)\n",
+ "\n",
+ " update_q(q, p1_old_s_a, p1_new_s_a, outcome[p1], lr)\n",
+ "\n",
+ " epsilon *= epsilon_decay\n",
+ " p1_old_s_a = p1_new_s_a"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "6538bcdc-4b4b-48ef-8c55-76834490b698",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Draw is now True.\n"
+ ]
+ }
+ ],
+ "source": [
+ "game_engine.toggle_draw()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "6e50b723-f0a4-48cd-bb98-1666ec0f59ce",
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "KeyboardInterrupt",
+ "evalue": "",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[14], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m step \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(n_steps):\n\u001b[1;32m 3\u001b[0m s_a \u001b[38;5;241m=\u001b[39m pick_greedy_action(q, p1, epsilon)\n\u001b[0;32m----> 4\u001b[0m \u001b[43mgame_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mplayer_advance\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43ms_a\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n",
+ "File \u001b[0;32m~/Personal/roam/program_repository/cs498/multiagent_snake/GameEngine/multiplayer.py:118\u001b[0m, in \u001b[0;36mPlayfield.player_advance\u001b[0;34m(self, actions, noise)\u001b[0m\n\u001b[1;32m 116\u001b[0m collisions \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_player_collisions()\n\u001b[1;32m 117\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_draw_on:\n\u001b[0;32m--> 118\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_update_ui\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 119\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m collisions\n",
+ "File \u001b[0;32m~/Personal/roam/program_repository/cs498/multiagent_snake/GameEngine/multiplayer.py:163\u001b[0m, in \u001b[0;36mPlayfield._update_ui\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_goal\u001b[38;5;241m.\u001b[39mdraw()\n\u001b[1;32m 162\u001b[0m pg\u001b[38;5;241m.\u001b[39mdisplay\u001b[38;5;241m.\u001b[39mflip() \u001b[38;5;66;03m# full screen update\u001b[39;00m\n\u001b[0;32m--> 163\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_clock\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtick\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_g_speed\u001b[49m\u001b[43m)\u001b[49m\n",
+ "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
+ ]
+ }
+ ],
+ "source": [
+ "epsilon = 0\n",
+ "for step in range(n_steps):\n",
+ " s_a = pick_greedy_action(q, p1, epsilon)\n",
+ " game_engine.player_advance([s_a[-1]])"
+ ]
}
],
"metadata": {