diff options
author | bd-912 <bdunahu@gmail.com> | 2023-11-15 12:47:42 -0700 |
---|---|---|
committer | bd-912 <bdunahu@gmail.com> | 2023-11-15 12:47:42 -0700 |
commit | ced115881e2f60afe41373b6da899f5d2f5403e0 (patch) | |
tree | 99d323876a5c32b3b138102440b4a1c6c12ac7be /revised_snake_q_table_noise.ipynb | |
parent | a2b56742da7b30afa00f33c9a806fa6031be68a5 (diff) |
Added new notebook concerning improvements on get_viable_actions
Diffstat (limited to 'revised_snake_q_table_noise.ipynb')
-rw-r--r-- | revised_snake_q_table_noise.ipynb | 237 |
1 files changed, 235 insertions, 2 deletions
diff --git a/revised_snake_q_table_noise.ipynb b/revised_snake_q_table_noise.ipynb index 695875e..52455e8 100644 --- a/revised_snake_q_table_noise.ipynb +++ b/revised_snake_q_table_noise.ipynb @@ -21,16 +21,249 @@ "from IPython.core.debugger import Pdb\n", "\n", "from GameEngine import multiplayer\n", + "from QTable import qtsnake\n", "Point = namedtuple('Point', 'x, y')" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "9eec33d8-9a65-426a-8ad5-8ffb2cbe2541", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "# defines game window size and block size, in pixels\n", + "WINDOW_WIDTH = 640\n", + "WINDOW_HEIGHT = 480\n", + "GAME_UNITS = 80" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "77b2d951-72be-4bee-aa73-3b87ea0288b6", + "metadata": {}, + "outputs": [], + "source": [ + "game_engine = multiplayer.Playfield(window_width=WINDOW_WIDTH,\n", + " window_height=WINDOW_HEIGHT,\n", + " units=GAME_UNITS,\n", + " g_speed=35,\n", + " s_size=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "93ce92f3-b336-4aca-a8da-ebd913e9d16b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Game starting with 1 players.\n", + "Draw is now False.\n" + ] + } + ], + "source": [ + "p1 = game_engine.add_player()\n", + "game_engine.start_game()\n", + "game_engine.toggle_draw()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "78ad0c30-63e7-4689-9332-6e9540a67105", + "metadata": {}, + "outputs": [], + "source": [ + "danger_states = 16\n", + "goal_relations = 8\n", + "actions = 4\n", + "q = np.zeros((danger_states,\n", + " goal_relations,\n", + " actions))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "09804da3-ef32-4265-a63d-7d4ea7e59031", + "metadata": {}, + "outputs": [], + "source": [ + "def sense_danger(head, tail, units=GAME_UNITS, window_width=WINDOW_WIDTH, window_height=WINDOW_HEIGHT):\n", + " '''\n", + " 
Given a player's head and tail,\n", + " returns an integer (0-15) whose four bits flag\n", + " up, right, down, left as causing immediate death\n", + " '''\n", + " danger_array = np.array([\n", + " head.y-units < 0 or Point(head.x, head.y-units) in tail[1:], # up\n", + " head.x+units >= window_width or Point(head.x+units, head.y) in tail[1:], # right\n", + " head.y+units >= window_height or Point(head.x, head.y+units) in tail[1:], # down\n", + " head.x-units < 0 or Point(head.x-units, head.y) in tail[1:], # left\n", + " ])\n", + "\n", + " binary_string = \"\".join(map(str, map(int, danger_array)))\n", + "\n", + " return int(binary_string, base=2)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "bfca56cf-0cef-4523-8e4a-f3254180e397", + "metadata": {}, + "outputs": [], + "source": [ + "x_b = 5; y_b = 5;\n", + "head = Point(3,3); tail = [head]\n", + "assert sense_danger(head,tail,1,x_b,y_b) == 0\n", + "\n", + "head = Point(0,0); tail = [head, Point(0,1)]\n", + "assert sense_danger(head,tail,1,x_b,y_b) == 11\n", + "\n", + "head = Point(3,3); tail = [head, Point(3,4)]\n", + "assert sense_danger(head,tail,1,x_b,y_b) == 2\n", + "\n", + "head = Point(3,3); tail = [head, Point(2,3)]\n", + "assert sense_danger(head,tail,1,x_b,y_b) == 1" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "45d9645a-b217-4812-8b35-2a154518087d", + "metadata": {}, + "outputs": [], + "source": [ + "def index_actions(q, id):\n", + " '''\n", + " given q and a player id, reads the heads,\n", + " tails, and goal from game_engine, then\n", + " returns the state and the corresponding\n", + " expected reward of each action\n", + " '''\n", + " heads, tails, goal = game_engine.get_heads_tails_and_goal()\n", + " state = np.array([sense_danger(heads[id], tails), qtsnake.sense_goal(heads[id], goal)])\n", + " return state, q[state[0], state[1], :]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "b8fc49f8-fc8e-4c0b-92f2-a345200ee26d", + "metadata": {}, + "outputs": [], + "source": [ + "def 
pick_greedy_action(q, id, epsilon):\n", + " state, rewards = index_actions(q, id)\n", + "\n", + " if np.random.uniform() < epsilon:\n", + " return np.append(state, [np.random.choice([0, 1, 2, 3])])\n", + " action = np.argmin(rewards)\n", + " return np.append(state, [action])" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "9dba22f2-a8fa-400c-8c71-fec7e988489c", + "metadata": {}, + "outputs": [], + "source": [ + "def update_q(q, old_state_action, new_state_action, outcome, lr=0.05):\n", + " if outcome == multiplayer.CollisionType.GOAL:\n", + " q[new_state_action[0], new_state_action[1], new_state_action[2]] = 0\n", + " elif outcome == multiplayer.CollisionType.DEATH:\n", + " q[new_state_action[0], new_state_action[1], new_state_action[2]] = 500\n", + " else:\n", + " td_error = 1 + q[new_state_action[0], new_state_action[1], new_state_action[2]] - q[old_state_action[0], old_state_action[1], old_state_action[2]]\n", + " q[old_state_action[0], old_state_action[1], old_state_action[2]] += lr * td_error" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "bc22f0b7-e409-4fd8-8433-8e918d5632ad", + "metadata": {}, + "outputs": [], + "source": [ + "n_steps = 52500\n", + "epsilon = 1\n", + "final_epsilon = 0.03\n", + "epsilon_decay = np.exp(np.log(final_epsilon) / (n_steps))\n", + "lr = 0.15" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "c7c4417e-66f5-4028-8264-f4b732ba33a7", + "metadata": {}, + "outputs": [], + "source": [ + "p1_old_s_a = pick_greedy_action(q, p1, epsilon) # state, action\n", + "game_engine.player_advance([p1_old_s_a[-1]])\n", + "\n", + "for step in range(n_steps):\n", + " p1_new_s_a = pick_greedy_action(q, p1, epsilon) # state, action\n", + " outcome = game_engine.player_advance([p1_new_s_a[-1]], 0.05)\n", + "\n", + " update_q(q, p1_old_s_a, p1_new_s_a, outcome[p1], lr)\n", + "\n", + " epsilon *= epsilon_decay\n", + " p1_old_s_a = p1_new_s_a" + ] + }, + { + "cell_type": "code", + 
"execution_count": 13, + "id": "6538bcdc-4b4b-48ef-8c55-76834490b698", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Draw is now True.\n" + ] + } + ], + "source": [ + "game_engine.toggle_draw()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "6e50b723-f0a4-48cd-bb98-1666ec0f59ce", + "metadata": {}, + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[14], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m step \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(n_steps):\n\u001b[1;32m 3\u001b[0m s_a \u001b[38;5;241m=\u001b[39m pick_greedy_action(q, p1, epsilon)\n\u001b[0;32m----> 4\u001b[0m \u001b[43mgame_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mplayer_advance\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43ms_a\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Personal/roam/program_repository/cs498/multiagent_snake/GameEngine/multiplayer.py:118\u001b[0m, in \u001b[0;36mPlayfield.player_advance\u001b[0;34m(self, actions, noise)\u001b[0m\n\u001b[1;32m 116\u001b[0m collisions \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_player_collisions()\n\u001b[1;32m 117\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_draw_on:\n\u001b[0;32m--> 118\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_update_ui\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 119\u001b[0m 
\u001b[38;5;28;01mreturn\u001b[39;00m collisions\n", + "File \u001b[0;32m~/Personal/roam/program_repository/cs498/multiagent_snake/GameEngine/multiplayer.py:163\u001b[0m, in \u001b[0;36mPlayfield._update_ui\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_goal\u001b[38;5;241m.\u001b[39mdraw()\n\u001b[1;32m 162\u001b[0m pg\u001b[38;5;241m.\u001b[39mdisplay\u001b[38;5;241m.\u001b[39mflip() \u001b[38;5;66;03m# full screen update\u001b[39;00m\n\u001b[0;32m--> 163\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_clock\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtick\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_g_speed\u001b[49m\u001b[43m)\u001b[49m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "epsilon = 0\n", + "for step in range(n_steps):\n", + " s_a = pick_greedy_action(q, p1, epsilon)\n", + " game_engine.player_advance([s_a[-1]])" + ] } ], "metadata": { |