248 lines
52 KiB
Plaintext
248 lines
52 KiB
Plaintext
{
|
|
"cells": [
|
|
{
 "cell_type": "code",
 "execution_count": 1,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Third-party: Gymnasium wrappers/checkers and the Stable-Baselines3 env checker\n",
  "import gymnasium as gym\n",
  "from gymnasium.utils.env_checker import check_env as gym_check_env\n",
  "from gymnasium.wrappers.jax_to_numpy import JaxToNumpy\n",
  "from gymnasium.wrappers.vector import JaxToNumpy as VJaxToNumpy\n",
  "from stable_baselines3.common.env_checker import check_env\n",
  "\n",
  "# Local project: the JAX-based solar race environment\n",
  "from solarcarsim.simv1 import SolarRaceV1\n",
  "\n",
  "# `env` is the raw JAX environment; `wrapped_env` converts its JAX arrays\n",
  "# to NumPy at the API boundary so NumPy-based tooling (SB3) can consume it.\n",
  "env = SolarRaceV1()\n",
  "wrapped_env = JaxToNumpy(env)"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": 2,
 "metadata": {},
 "outputs": [
  {
   "name": "stderr",
   "output_type": "stream",
   "text": [
    "/home/saji/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/common/env_checker.py:271: UserWarning: Your observation wind has an unconventional shape (neither an image, nor a 1D vector). We recommend you to flatten the observation to have only a 1D vector or use a custom policy to properly process the data.\n",
    " warnings.warn(\n",
    "/home/saji/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/gymnasium/utils/env_checker.py:384: UserWarning: \u001b[33mWARN: The environment (<JaxToNumpy<SolarRaceV1 instance>>) is different from the unwrapped version (<SolarRaceV1 instance>). This could effect the environment checker as the environment most likely has a wrapper applied to it. We recommend using the raw environment for `check_env` using `env.unwrapped`.\u001b[0m\n",
    " logger.warn(\n",
    "/home/saji/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/gymnasium/utils/env_checker.py:434: UserWarning: \u001b[33mWARN: Not able to test alternative render modes due to the environment not having a spec. Try instantiating the environment through `gymnasium.make`\u001b[0m\n",
    " logger.warn(\n"
   ]
  }
 ],
 "source": [
  "# Reset the raw env once, then validate the NumPy-wrapped env (the object\n",
  "# handed to the RL library) against both the SB3 and the Gymnasium checkers.\n",
  "# The wrapped env is checked on purpose, so the Gymnasium \"use env.unwrapped\"\n",
  "# warning below is expected.\n",
  "env.reset()\n",
  "check_env(env=wrapped_env, warn=True)\n",
  "gym_check_env(env=wrapped_env)"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Train TD3 (the JAX implementation from sbx) on the NumPy-wrapped env.\n",
  "# NOTE(review): the previous run of this cell used stable_baselines3's TD3;\n",
  "# confirm that sbx's TD3 accepts \"MultiInputPolicy\" for this Dict observation space.\n",
  "from sbx import TD3\n",
  "\n",
  "SEED = 0\n",
  "TOTAL_TIMESTEPS = 30_000\n",
  "# The default buffer_size (1_000_000) needed ~80 GB for this observation space\n",
  "# (SB3 warned 80.85GB > 53.66GB available). A single learn() call never stores more\n",
  "# than TOTAL_TIMESTEPS transitions, so size the buffer to match.\n",
  "BUFFER_SIZE = TOTAL_TIMESTEPS\n",
  "\n",
  "model = TD3(\n",
  "    \"MultiInputPolicy\",\n",
  "    wrapped_env,  # the env validated above; the raw env emits JAX arrays, not NumPy\n",
  "    buffer_size=BUFFER_SIZE,\n",
  "    seed=SEED,\n",
  "    verbose=1,\n",
  ")\n",
  "model.learn(total_timesteps=TOTAL_TIMESTEPS);"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Roll out the trained policy deterministically for a single episode.\n",
  "import matplotlib.pyplot as plt\n",
  "import jax.numpy as jnp\n",
  "\n",
  "MAX_STEPS = 1000  # hard cap on rollout length\n",
  "\n",
  "vec_env = model.get_env()\n",
  "obs = vec_env.reset()\n",
  "action_list = []\n",
  "obs_list = []\n",
  "rewards = []\n",
  "for _ in range(MAX_STEPS):\n",
  "    action, _state = model.predict(obs, deterministic=True)\n",
  "    obs, reward, done, info = vec_env.step(action)\n",
  "    action_list.append(action)\n",
  "    rewards.append(reward)\n",
  "    if done[0]:\n",
  "        # VecEnvs auto-reset on episode end: `obs` is already the first observation\n",
  "        # of the NEXT episode; the real final observation is kept in the info dict.\n",
  "        obs_list.append(info[0][\"terminal_observation\"])\n",
  "        break\n",
  "    obs_list.append(obs)\n",
  "\n",
  "# The terminal observation has no batch axis, so ravel each step before joining.\n",
  "position = jnp.concatenate([jnp.ravel(x['position']) for x in obs_list])\n",
  "energy = jnp.concatenate([jnp.ravel(x['energy']) for x in obs_list])\n",
  "actions = jnp.concatenate([jnp.ravel(a) for a in action_list])"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Rollout trajectory: position, energy, action and reward per step, on a shared x-axis.\n",
  "fig, (ax_pos, ax_energy, ax_action, ax_reward) = plt.subplots(4, 1, figsize=(12, 8), sharex=True)\n",
  "ax_pos.plot(position)\n",
  "ax_pos.set_ylabel(\"position\")\n",
  "ax_energy.plot(energy)\n",
  "ax_energy.set_ylabel(\"energy\")\n",
  "ax_action.plot(actions)\n",
  "ax_action.set_ylabel(\"action\")\n",
  "ax_reward.plot(rewards)  # full episode, aligned with the panels above\n",
  "ax_reward.set_ylabel(\"reward\")\n",
  "ax_reward.set_xlabel(\"step\")\n",
  "fig.tight_layout()\n",
  "plt.show()"
 ]
},
|
|
{
 "cell_type": "code",
 "execution_count": null,
 "metadata": {},
 "outputs": [],
 "source": [
  "# Summarise the rollout's actions instead of dumping the full array into the notebook.\n",
  "# NOTE(review): a previous run showed every action equal to -1, i.e. the deterministic\n",
  "# policy produced a constant output - worth checking action bounds and training.\n",
  "{\n",
  "    \"n_steps\": int(actions.size),\n",
  "    \"min\": float(actions.min()),\n",
  "    \"max\": float(actions.max()),\n",
  "    \"first_10\": actions[:10],\n",
  "}"
 ]
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": ".venv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.12.7"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|