{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import gymnasium as gym\n", "from gymnasium.wrappers.jax_to_numpy import JaxToNumpy\n", "from gymnasium.wrappers.vector import JaxToNumpy as VJaxToNumpy\n", "from solarcarsim.simv1 import SolarRaceV1\n", "from stable_baselines3.common.env_checker import check_env\n", "from gymnasium.utils.env_checker import check_env as gym_check_env\n", "env = SolarRaceV1()\n", "wrapped_env = JaxToNumpy(env)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/saji/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/common/env_checker.py:271: UserWarning: Your observation wind has an unconventional shape (neither an image, nor a 1D vector). We recommend you to flatten the observation to have only a 1D vector or use a custom policy to properly process the data.\n", " warnings.warn(\n", "/home/saji/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/gymnasium/utils/env_checker.py:384: UserWarning: \u001b[33mWARN: The environment (>) is different from the unwrapped version (). This could effect the environment checker as the environment most likely has a wrapper applied to it. We recommend using the raw environment for `check_env` using `env.unwrapped`.\u001b[0m\n", " logger.warn(\n", "/home/saji/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/gymnasium/utils/env_checker.py:434: UserWarning: \u001b[33mWARN: Not able to test alternative render modes due to the environment not having a spec. 
Try instantiating the environment through `gymnasium.make`\u001b[0m\n", " logger.warn(\n" ] } ], "source": [ "env.reset()\n", "check_env(wrapped_env)\n", "gym_check_env(wrapped_env)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"from sbx import TD3\n",
"\n",
"TOTAL_TIMESTEPS = 30_000\n",
"\n",
"# Train on the numpy-converted env that passed the checkers above, not the raw JAX `env`.\n",
"# With the default buffer size, SB3 estimated the replay buffer at 80.85 GB for this\n",
"# observation space; a run of TOTAL_TIMESTEPS steps never stores more transitions than that.\n",
"model = TD3(\"MultiInputPolicy\", wrapped_env, buffer_size=TOTAL_TIMESTEPS, verbose=1)\n",
"model.learn(total_timesteps=TOTAL_TIMESTEPS)"
] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import matplotlib.pyplot as plt\n",
"import jax.numpy as jnp\n",
"\n",
"MAX_ROLLOUT_STEPS = 1000\n",
"\n",
"# Roll out the deterministic policy for one episode (capped at MAX_ROLLOUT_STEPS steps).\n",
"vec_env = model.get_env()\n",
"obs = vec_env.reset()\n",
"actions = []\n",
"obs_list = []\n",
"rewards = []\n",
"for _ in range(MAX_ROLLOUT_STEPS):\n",
"    action, _state = model.predict(obs, deterministic=True)\n",
"    actions.append(action)\n",
"    obs, reward, done, info = vec_env.step(action)\n",
"    rewards.append(reward)\n",
"    if done[0]:\n",
"        # The VecEnv has already auto-reset, so `obs` is the first observation of the\n",
"        # next episode; keep the real final observation instead and stop.\n",
"        obs_list.append(info[0][\"terminal_observation\"])\n",
"        break\n",
"    obs_list.append(obs)\n",
"\n",
"# The terminal observation is unbatched while VecEnv observations carry an env axis,\n",
"# so flatten each step before joining.\n",
"position = jnp.concatenate([jnp.ravel(x[\"position\"]) for x in obs_list])\n",
"energy = jnp.concatenate([jnp.ravel(x[\"energy\"]) for x in obs_list])\n",
"actions = jnp.array(actions).flatten()"
] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "fig, (ax1, ax2, ax3) = plt.subplots(3,1, figsize=(12,6))\n", "ax1.plot(position, label=\"position\")\n", "ax2.plot(actions, label=\"energy\")\n", "ax3.plot(rewards[0:250])\n", "# plt.legend()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Array([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., 
-1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", " -1., -1., -1.], dtype=float32)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "actions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": 
{ "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } }, "nbformat": 4, "nbformat_minor": 2 }