diff --git a/README.md b/README.md index 1f5ea90..4bb7924 100644 --- a/README.md +++ b/README.md @@ -1 +1,11 @@ # solarcarsim + + + +TODO: +fix wind (velocity + wind for drag) +make more functional +cleanup sim code +parameterize the environment +vectorize +cleanrl jax td3 \ No newline at end of file diff --git a/notebooks/testing.ipynb b/notebooks/testing.ipynb new file mode 100644 index 0000000..ef441e7 --- /dev/null +++ b/notebooks/testing.ipynb @@ -0,0 +1,65 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([ 0.0000000e+00, -2.5544850e-07, -1.4012958e-06, ...,\n", + " -1.1142221e-02, -1.1067827e-02, -1.1001030e-02], dtype=float32)" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "from solarcarsim.physsim import CarParams, fractal_noise_1d\n", + "\n", + "\n", + "key = jax.random.key(0)\n", + "\n", + "slope = fractal_noise_1d(key, 10000, scale=1200, height_scale=0.08)\n", + "\n", + "slope" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# get an array of positions\n", + "positions = jnp.array([1.1,2.2,3.3,5,200.0], dtype=jnp.float32)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/v1gym.ipynb b/notebooks/v1gym.ipynb new file mode 100644 index 0000000..fb0c02d --- /dev/null +++ b/notebooks/v1gym.ipynb @@ -0,0 +1,247 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, 
+ "outputs": [], + "source": [ + "import gymnasium as gym\n", + "from gymnasium.wrappers.jax_to_numpy import JaxToNumpy\n", + "from gymnasium.wrappers.vector import JaxToNumpy as VJaxToNumpy\n", + "from solarcarsim.simv1 import SolarRaceV1\n", + "from stable_baselines3.common.env_checker import check_env\n", + "from gymnasium.utils.env_checker import check_env as gym_check_env\n", + "env = SolarRaceV1()\n", + "wrapped_env = JaxToNumpy(env)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/saji/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/common/env_checker.py:271: UserWarning: Your observation wind has an unconventional shape (neither an image, nor a 1D vector). We recommend you to flatten the observation to have only a 1D vector or use a custom policy to properly process the data.\n", + " warnings.warn(\n", + "/home/saji/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/gymnasium/utils/env_checker.py:384: UserWarning: \u001b[33mWARN: The environment (>) is different from the unwrapped version (). This could effect the environment checker as the environment most likely has a wrapper applied to it. We recommend using the raw environment for `check_env` using `env.unwrapped`.\u001b[0m\n", + " logger.warn(\n", + "/home/saji/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/gymnasium/utils/env_checker.py:434: UserWarning: \u001b[33mWARN: Not able to test alternative render modes due to the environment not having a spec. 
Try instantiating the environment through `gymnasium.make`\u001b[0m\n", + " logger.warn(\n" + ] + } + ], + "source": [ + "env.reset()\n", + "check_env(wrapped_env)\n", + "gym_check_env(wrapped_env)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/saji/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/common/buffers.py:605: UserWarning: This system does not have apparently enough memory to store the complete replay buffer 80.85GB > 53.66GB\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using cuda device\n", + "Wrapping the env with a `Monitor` wrapper\n", + "Wrapping the env in a DummyVecEnv.\n" + ] + }, + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[3], line 4\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mstable_baselines3\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m TD3\n\u001b[1;32m 3\u001b[0m model \u001b[38;5;241m=\u001b[39m TD3(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMultiInputPolicy\u001b[39m\u001b[38;5;124m\"\u001b[39m, wrapped_env, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m----> 4\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlearn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtotal_timesteps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m30_000\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/td3/td3.py:222\u001b[0m, in \u001b[0;36mTD3.learn\u001b[0;34m(self, total_timesteps, callback, log_interval, 
tb_log_name, reset_num_timesteps, progress_bar)\u001b[0m\n\u001b[1;32m 213\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mlearn\u001b[39m(\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28mself\u001b[39m: SelfTD3,\n\u001b[1;32m 215\u001b[0m total_timesteps: \u001b[38;5;28mint\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 220\u001b[0m progress_bar: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 221\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m SelfTD3:\n\u001b[0;32m--> 222\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlearn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 223\u001b[0m \u001b[43m \u001b[49m\u001b[43mtotal_timesteps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtotal_timesteps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 224\u001b[0m \u001b[43m \u001b[49m\u001b[43mcallback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcallback\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 225\u001b[0m \u001b[43m \u001b[49m\u001b[43mlog_interval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlog_interval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 226\u001b[0m \u001b[43m \u001b[49m\u001b[43mtb_log_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtb_log_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 227\u001b[0m \u001b[43m \u001b[49m\u001b[43mreset_num_timesteps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreset_num_timesteps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 228\u001b[0m \u001b[43m \u001b[49m\u001b[43mprogress_bar\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprogress_bar\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 229\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n", + "File 
\u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/common/off_policy_algorithm.py:347\u001b[0m, in \u001b[0;36mOffPolicyAlgorithm.learn\u001b[0;34m(self, total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps, progress_bar)\u001b[0m\n\u001b[1;32m 345\u001b[0m \u001b[38;5;66;03m# Special case when the user passes `gradient_steps=0`\u001b[39;00m\n\u001b[1;32m 346\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m gradient_steps \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m--> 347\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbatch_size\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mgradient_steps\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 349\u001b[0m callback\u001b[38;5;241m.\u001b[39mon_training_end()\n\u001b[1;32m 351\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\n", + "File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/td3/td3.py:184\u001b[0m, in \u001b[0;36mTD3.train\u001b[0;34m(self, gradient_steps, batch_size)\u001b[0m\n\u001b[1;32m 182\u001b[0m critic_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28msum\u001b[39m(F\u001b[38;5;241m.\u001b[39mmse_loss(current_q, target_q_values) \u001b[38;5;28;01mfor\u001b[39;00m current_q \u001b[38;5;129;01min\u001b[39;00m current_q_values)\n\u001b[1;32m 183\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(critic_loss, th\u001b[38;5;241m.\u001b[39mTensor)\n\u001b[0;32m--> 184\u001b[0m 
critic_losses\u001b[38;5;241m.\u001b[39mappend(\u001b[43mcritic_loss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mitem\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 186\u001b[0m \u001b[38;5;66;03m# Optimize the critics\u001b[39;00m\n\u001b[1;32m 187\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcritic\u001b[38;5;241m.\u001b[39moptimizer\u001b[38;5;241m.\u001b[39mzero_grad()\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "# import a model and try it out!\n", + "from sbx import TD3\n", + "model = TD3(\"MultiInputPolicy\", env, verbose=1)\n", + "model.learn(total_timesteps=30_000)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "vec_env = model.get_env()\n", + "import matplotlib.pyplot as plt\n", + "import jax.numpy as jnp\n", + "obs = vec_env.reset()\n", + "actions = []\n", + "obs_list = []\n", + "rewards = []\n", + "for i in range(1000):\n", + " action, _state = model.predict(obs, deterministic=True)\n", + " actions.append(action)\n", + " obs, reward, done, info = vec_env.step(action)\n", + " obs_list.append(obs)\n", + " rewards.append(reward)\n", + "\n", + " \n", + " # VecEnv resets automatically\n", + " if done:\n", + " break\n", + " # obs = vec_env.reset()\n", + "\n", + "position = jnp.array([x['position'] for x in obs_list]).flatten()\n", + "energy = jnp.array([x['energy'] for x in obs_list]).flatten()\n", + "actions = jnp.array(actions).flatten()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, (ax1, ax2, ax3) = plt.subplots(3,1, figsize=(12,6))\n", + "ax1.plot(position, label=\"position\")\n", + "ax2.plot(actions, label=\"energy\")\n", + "ax3.plot(rewards[0:250])\n", + "# plt.legend()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Array([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., 
-1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1., -1.,\n", + " -1., -1., -1.], dtype=float32)" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "actions" + ] + }, + { + "cell_type": "code", + "execution_count": null, 
+ "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/v1sim.ipynb b/notebooks/v1sim.ipynb index bb2d9fe..d076f90 100644 --- a/notebooks/v1sim.ipynb +++ b/notebooks/v1sim.ipynb @@ -564,7 +564,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -576,7 +576,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 2, "metadata": {}, "outputs": [ { diff --git a/pdm.lock b/pdm.lock index cc7b888..011a0cf 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["inherit_metadata"] lock_version = "4.5.0" -content_hash = "sha256:81e26f71acf1a583b21280b235fa2ac16165ac824ae8483bd391b88406421aa4" +content_hash = "sha256:a3b65f863c554725c33d452fd759776141740661fa3555d306ed08563a7e16e2" [[metadata.targets]] requires_python = ">=3.12,<3.13" @@ -152,7 +152,7 @@ version = "0.4.6" requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" summary = "Cross-platform colored terminal text." 
groups = ["default", "dev"] -marker = "python_version >= \"3.12\" and python_version < \"3.13\" and sys_platform == \"win32\"" +marker = "sys_platform == \"win32\" and python_version >= \"3.12\" and python_version < \"3.13\" or platform_system == \"Windows\" and python_version >= \"3.12\" and python_version < \"3.13\"" files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, @@ -230,13 +230,30 @@ name = "decorator" version = "5.1.1" requires_python = ">=3.5" summary = "Decorators for Humans" -groups = ["dev"] +groups = ["default", "dev"] marker = "python_version >= \"3.12\" and python_version < \"3.13\"" files = [ {file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"}, {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, ] +[[package]] +name = "dm-tree" +version = "0.1.8" +summary = "Tree is a library for working with nested data structures." 
+groups = ["default"] +marker = "python_version >= \"3.12\" and python_version < \"3.13\"" +files = [ + {file = "dm-tree-0.1.8.tar.gz", hash = "sha256:0fcaabbb14e7980377439e7140bd05552739ca5e515ecb3119f234acee4b9430"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ea9e59e0451e7d29aece402d9f908f2e2a80922bcde2ebfd5dcb07750fcbfee8"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:94d3f0826311f45ee19b75f5b48c99466e4218a0489e81c0f0167bda50cacf22"}, + {file = "dm_tree-0.1.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:435227cf3c5dc63f4de054cf3d00183790bd9ead4c3623138c74dde7f67f521b"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09964470f76a5201aff2e8f9b26842976de7889300676f927930f6285e256760"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:75c5d528bb992981c20793b6b453e91560784215dffb8a5440ba999753c14ceb"}, + {file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0a94aba18a35457a1b5cd716fd7b46c5dafdc4cf7869b4bae665b91c4682a8e"}, + {file = "dm_tree-0.1.8-cp312-cp312-win_amd64.whl", hash = "sha256:96a548a406a6fb15fe58f6a30a57ff2f2aafbf25f05afab00c8f5e5977b6c715"}, +] + [[package]] name = "etils" version = "1.11.0" @@ -379,6 +396,47 @@ files = [ {file = "fsspec-2024.10.0.tar.gz", hash = "sha256:eda2d8a4116d4f2429db8550f2457da57279247dd930bb12f821b58391359493"}, ] +[[package]] +name = "gast" +version = "0.6.0" +requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7" +summary = "Python AST that abstracts the underlying Python version" +groups = ["default"] +marker = "python_version >= \"3.12\" and python_version < \"3.13\"" +files = [ + {file = "gast-0.6.0-py3-none-any.whl", hash = "sha256:52b182313f7330389f72b069ba00f174cfe2a06411099547288839c6cbafbd54"}, + {file = "gast-0.6.0.tar.gz", hash = 
"sha256:88fc5300d32c7ac6ca7b515310862f71e6fdf2c029bbec7c66c0f5dd47b6b1fb"}, +] + +[[package]] +name = "gym" +version = "0.26.2" +requires_python = ">=3.6" +summary = "Gym: A universal API for reinforcement learning environments" +groups = ["default"] +marker = "python_version >= \"3.12\" and python_version < \"3.13\"" +dependencies = [ + "cloudpickle>=1.2.0", + "dataclasses==0.8; python_version == \"3.6\"", + "gym-notices>=0.0.4", + "importlib-metadata>=4.8.0; python_version < \"3.10\"", + "numpy>=1.18.0", +] +files = [ + {file = "gym-0.26.2.tar.gz", hash = "sha256:e0d882f4b54f0c65f203104c24ab8a38b039f1289986803c7d02cdbe214fbcc4"}, +] + +[[package]] +name = "gym-notices" +version = "0.0.8" +summary = "Notices for gym" +groups = ["default"] +marker = "python_version >= \"3.12\" and python_version < \"3.13\"" +files = [ + {file = "gym-notices-0.0.8.tar.gz", hash = "sha256:ad25e200487cafa369728625fe064e88ada1346618526102659b4640f2b4b911"}, + {file = "gym_notices-0.0.8-py3-none-any.whl", hash = "sha256:e5f82e00823a166747b4c2a07de63b6560b1acb880638547e0cabf825a01e463"}, +] + [[package]] name = "gymnasium" version = "1.0.0" @@ -417,6 +475,29 @@ files = [ {file = "gymnasium-1.0.0.tar.gz", hash = "sha256:9d2b66f30c1b34fe3c2ce7fae65ecf365d0e9982d2b3d860235e773328a3b403"}, ] +[[package]] +name = "gymnax" +version = "0.0.8" +requires_python = ">=3.10" +summary = "JAX-compatible version of Open AI's gym environments" +groups = ["default"] +marker = "python_version >= \"3.12\" and python_version < \"3.13\"" +dependencies = [ + "chex", + "flax", + "gym>=0.26", + "gymnasium", + "jax", + "jaxlib", + "matplotlib", + "pyyaml", + "seaborn", +] +files = [ + {file = "gymnax-0.0.8-py3-none-any.whl", hash = "sha256:0af7edde1b71d74be8007ffe1e6338f8ce66693b1b78ae479c0c0cd02b10de03"}, + {file = "gymnax-0.0.8.tar.gz", hash = "sha256:81defc17f52a30a84338b3daa574d7a3bb112f2656f45c783a71efe31eea68ff"}, +] + [[package]] name = "humanize" version = "4.11.0" @@ -541,76 +622,6 @@ files = [ {file = 
"jax-0.4.37.tar.gz", hash = "sha256:7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b"}, ] -[[package]] -name = "jax-cuda12-pjrt" -version = "0.4.36" -summary = "JAX XLA PJRT Plugin for NVIDIA GPUs" -groups = ["default"] -marker = "python_version >= \"3.12\" and python_version < \"3.13\"" -files = [ - {file = "jax_cuda12_pjrt-0.4.36-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1dfc0bec0820ba801b61e9421064b6e58238c430b4ad8f54043323d93c0217c6"}, - {file = "jax_cuda12_pjrt-0.4.36-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e3c3705d8db7d63da9abfaebf06f5cd0667f5acb0748a5c5eb00d80041e922ed"}, -] - -[[package]] -name = "jax-cuda12-plugin" -version = "0.4.36" -requires_python = ">=3.10" -summary = "JAX Plugin for NVIDIA GPUs" -groups = ["default"] -marker = "python_version >= \"3.12\" and python_version < \"3.13\"" -dependencies = [ - "jax-cuda12-pjrt==0.4.36", -] -files = [ - {file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:6a0b0c2bdc1da2eea2c20723a1e8f39b3cda67d24c665de936647e8091f5790d"}, - {file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:5d4727fb519fedc06a9a984d5a0714804d81ef126a2cb60cefd5cbc4a3ea2627"}, -] - -[[package]] -name = "jax-cuda12-plugin" -version = "0.4.36" -extras = ["with_cuda"] -requires_python = ">=3.10" -summary = "JAX Plugin for NVIDIA GPUs" -groups = ["default"] -marker = "python_version >= \"3.12\" and python_version < \"3.13\"" -dependencies = [ - "jax-cuda12-plugin==0.4.36", - "nvidia-cublas-cu12>=12.1.3.1", - "nvidia-cuda-cupti-cu12>=12.1.105", - "nvidia-cuda-nvcc-cu12>=12.6.85", - "nvidia-cuda-runtime-cu12>=12.1.105", - "nvidia-cudnn-cu12<10.0,>=9.1", - "nvidia-cufft-cu12>=11.0.2.54", - "nvidia-cusolver-cu12>=11.4.5.107", - "nvidia-cusparse-cu12>=12.1.0.106", - "nvidia-nccl-cu12>=2.18.1", - "nvidia-nvjitlink-cu12>=12.1.105", -] -files = [ - {file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = 
"sha256:6a0b0c2bdc1da2eea2c20723a1e8f39b3cda67d24c665de936647e8091f5790d"}, - {file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:5d4727fb519fedc06a9a984d5a0714804d81ef126a2cb60cefd5cbc4a3ea2627"}, -] - -[[package]] -name = "jax" -version = "0.4.37" -extras = ["cuda12"] -requires_python = ">=3.10" -summary = "Differentiate, compile, and transform Numpy code." -groups = ["default"] -marker = "python_version >= \"3.12\" and python_version < \"3.13\"" -dependencies = [ - "jax-cuda12-plugin[with_cuda]<=0.4.37,>=0.4.36", - "jax==0.4.37", - "jaxlib==0.4.36", -] -files = [ - {file = "jax-0.4.37-py3-none-any.whl", hash = "sha256:bdc0686d7e5a944e2d38026eae632214d98dd2d91869cbcedbf1c11298ae3e3e"}, - {file = "jax-0.4.37.tar.gz", hash = "sha256:7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b"}, -] - [[package]] name = "jaxlib" version = "0.4.36" @@ -926,7 +937,7 @@ version = "12.4.5.8" requires_python = ">=3" summary = "CUBLAS native runtime libraries" groups = ["default"] -marker = "python_version >= \"3.12\" and python_version < \"3.13\"" +marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\"" files = [ {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3"}, {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b"}, @@ -939,26 +950,13 @@ version = "12.4.127" requires_python = ">=3" summary = "CUDA profiling tools runtime libs." 
groups = ["default"] -marker = "python_version >= \"3.12\" and python_version < \"3.13\"" +marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\"" files = [ {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a"}, {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb"}, {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922"}, ] -[[package]] -name = "nvidia-cuda-nvcc-cu12" -version = "12.6.85" -requires_python = ">=3" -summary = "CUDA nvcc" -groups = ["default"] -marker = "python_version >= \"3.12\" and python_version < \"3.13\"" -files = [ - {file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d75d9d74599f4d7c0865df19ed21b739e6cb77a6497a3f73d6f61e8038a765e4"}, - {file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5d2edd5531b13e3daac8ffee9fc2b70a147e6088b2af2565924773d63d36d294"}, - {file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-win_amd64.whl", hash = "sha256:aa04742337973dcb5bcccabb590edc8834c60ebfaf971847888d24ffef6c46b5"}, -] - [[package]] name = "nvidia-cuda-nvrtc-cu12" version = "12.4.127" @@ -978,7 +976,7 @@ version = "12.4.127" requires_python = ">=3" summary = "CUDA Runtime native Libraries" groups = ["default"] -marker = "python_version >= \"3.12\" and python_version < \"3.13\"" +marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\"" files = [ {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = 
"sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3"}, {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5"}, @@ -991,7 +989,7 @@ version = "9.1.0.70" requires_python = ">=3" summary = "cuDNN runtime libraries" groups = ["default"] -marker = "python_version >= \"3.12\" and python_version < \"3.13\"" +marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\"" dependencies = [ "nvidia-cublas-cu12", ] @@ -1006,7 +1004,7 @@ version = "11.2.1.3" requires_python = ">=3" summary = "CUFFT native runtime libraries" groups = ["default"] -marker = "python_version >= \"3.12\" and python_version < \"3.13\"" +marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\"" dependencies = [ "nvidia-nvjitlink-cu12", ] @@ -1035,7 +1033,7 @@ version = "11.6.1.9" requires_python = ">=3" summary = "CUDA solver native runtime libraries" groups = ["default"] -marker = "python_version >= \"3.12\" and python_version < \"3.13\"" +marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\"" dependencies = [ "nvidia-cublas-cu12", "nvidia-cusparse-cu12", @@ -1053,7 +1051,7 @@ version = "12.3.1.170" requires_python = ">=3" summary = "CUSPARSE native runtime libraries" groups = ["default"] -marker = "python_version >= \"3.12\" and python_version < \"3.13\"" +marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\"" dependencies = [ "nvidia-nvjitlink-cu12", ] @@ -1069,7 +1067,7 @@ version = "2.21.5" requires_python = ">=3" summary = "NVIDIA Collective Communication Library (NCCL) Runtime" groups = ["default"] -marker = "python_version >= \"3.12\" and python_version < 
\"3.13\"" +marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\"" files = [ {file = "nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0"}, ] @@ -1080,7 +1078,7 @@ version = "12.4.127" requires_python = ">=3" summary = "Nvidia JIT LTO Library" groups = ["default"] -marker = "python_version >= \"3.12\" and python_version < \"3.13\"" +marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\"" files = [ {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"}, {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"}, @@ -1653,6 +1651,29 @@ files = [ {file = "rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098"}, ] +[[package]] +name = "sbx-rl" +version = "0.18.0" +requires_python = ">=3.8" +summary = "Jax version of Stable Baselines, implementations of reinforcement learning algorithms." 
+groups = ["default"] +marker = "python_version >= \"3.12\" and python_version < \"3.13\"" +dependencies = [ + "flax", + "jax", + "jaxlib", + "optax; python_version >= \"3.9.0\"", + "optax<0.1.8; python_version < \"3.9.0\"", + "rich", + "stable-baselines3<3.0,>=2.4.0a4", + "tensorflow-probability", + "tqdm", +] +files = [ + {file = "sbx_rl-0.18.0-py3-none-any.whl", hash = "sha256:75ade634a33555ad4c4a81523bb0f99c89d1b3bc89fb74990ef87b22379abd9c"}, + {file = "sbx_rl-0.18.0.tar.gz", hash = "sha256:670f2bf095ec21ba6f8171602294baf0123787fe7be6811ebab276fb5010b8b3"}, +] + [[package]] name = "scipy" version = "1.14.1" @@ -1695,6 +1716,23 @@ files = [ {file = "scooby-0.10.0.tar.gz", hash = "sha256:7ea33c262c0cc6a33c6eeeb5648df787be4f22660e53c114e5fff1b811a8854f"}, ] +[[package]] +name = "seaborn" +version = "0.13.2" +requires_python = ">=3.8" +summary = "Statistical data visualization" +groups = ["default"] +marker = "python_version >= \"3.12\" and python_version < \"3.13\"" +dependencies = [ + "matplotlib!=3.6.1,>=3.4", + "numpy!=1.24.0,>=1.20", + "pandas>=1.2", +] +files = [ + {file = "seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987"}, + {file = "seaborn-0.13.2.tar.gz", hash = "sha256:93e60a40988f4d65e9f4885df477e2fdaff6b73a9ded434c1ab356dd57eefff7"}, +] + [[package]] name = "setuptools" version = "75.6.0" @@ -1809,6 +1847,26 @@ files = [ {file = "sympy-1.13.1.tar.gz", hash = "sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f"}, ] +[[package]] +name = "tensorflow-probability" +version = "0.25.0" +requires_python = ">=3.9" +summary = "Probabilistic modeling and statistical inference in TensorFlow" +groups = ["default"] +marker = "python_version >= \"3.12\" and python_version < \"3.13\"" +dependencies = [ + "absl-py", + "cloudpickle>=1.3", + "decorator", + "dm-tree", + "gast>=0.3.2", + "numpy>=1.13.3", + "six>=1.10.0", +] +files = [ + {file = 
"tensorflow_probability-0.25.0-py2.py3-none-any.whl", hash = "sha256:f3f4d6431656c0122906888afe1b67b4400e82bd7f254b45b92e6c5b84ea8e3e"}, +] + [[package]] name = "tensorstore" version = "0.1.71" @@ -1899,6 +1957,21 @@ files = [ {file = "tornado-6.4.2.tar.gz", hash = "sha256:92bad5b4746e9879fd7bf1eb21dce4e3fc5128d71601f80005afa39237ad620b"}, ] +[[package]] +name = "tqdm" +version = "4.67.1" +requires_python = ">=3.7" +summary = "Fast, Extensible Progress Meter" +groups = ["default"] +marker = "python_version >= \"3.12\" and python_version < \"3.13\"" +dependencies = [ + "colorama; platform_system == \"Windows\"", +] +files = [ + {file = "tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2"}, + {file = "tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"}, +] + [[package]] name = "traitlets" version = "5.14.3" diff --git a/pyproject.toml b/pyproject.toml index 3ac6c98..1f5c82b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ description = "A solar car racing simulation library and GUI tool" authors = [ {name = "saji", email = "saji@saji.dev"}, ] -dependencies = ["pyqtgraph>=0.13.7", "jax[cuda12]>=0.4.37", "pytest>=8.3.3", "pyside6>=6.8.0.2", "matplotlib>=3.9.2", "gymnasium[jax]>=1.0.0", "pyvista>=0.44.2", "pyvistaqt>=0.11.1", "stable-baselines3>=2.4.0"] +dependencies = ["pyqtgraph>=0.13.7", "jax>=0.4.37", "pytest>=8.3.3", "pyside6>=6.8.0.2", "matplotlib>=3.9.2", "gymnasium[jax]>=1.0.0", "pyvista>=0.44.2", "pyvistaqt>=0.11.1", "stable-baselines3>=2.4.0", "gymnax>=0.0.8", "sbx-rl>=0.18.0"] requires-python = ">=3.10,<3.13" readme = "README.md" license = {text = "MIT"} diff --git a/report/report.tex b/report/report.tex new file mode 100644 index 0000000..4cddc6c --- /dev/null +++ b/report/report.tex @@ -0,0 +1,136 @@ +\documentclass[11pt]{article} + +% Essential packages +\usepackage[utf8]{inputenc} +\usepackage[T1]{fontenc} 
+\usepackage{amsmath,amssymb} +\usepackage{graphicx} +\usepackage[margin=1in]{geometry} +\usepackage{hyperref} +\usepackage{algorithm} +\usepackage{algorithmic} +\usepackage{float} +\usepackage{booktabs} +\usepackage{caption} +\usepackage{subcaption} + +% Custom commands +\newcommand{\sectionheading}[1]{\noindent\textbf{#1}} + +% Title and author information +\title{\Large{Your Project Title: A Study in Optimal Control and Reinforcement Learning}} +\author{Your Name\\ + Course Name\\ + Institution} +\date{\today} + +\begin{document} + +\maketitle + +\begin{abstract} +Solar racing is a competition with the goal of creating highly efficient solar-assisted electric vehicles. Effective solar racing +requires awareness and complex decision-making to determine optimal speeds to exploit the environmental conditions, such as winds, +cloud cover, and changes in elevation. We present an environment modelled on the dynamics involved in a race, including generated +elevation and wind profiles. The model uses the \texttt{gymnasium} interface to allow it to be used by a variety of algorithms. +We demonstrate a method of designing reward functions for multi-objective problems. The environment is shown to be solvable by modern +reinforcement learning algorithms. +\end{abstract} + +\section{Introduction} +Start with a broad context of your problem area in optimal control/reinforcement learning. Then narrow down to your specific focus. Include: +\begin{itemize} + \item Problem motivation + \item Brief overview of existing approaches + \item Your specific contributions + \item Paper organization +\end{itemize} + +Solar racing was invented in the early 1990s as a technology incubator for high-efficiency motor vehicles. The first solar races were +speed-focused; however, a style of race that emphasizes minimal energy use over a given route was developed to shift the focus towards vehicle efficiency.
+The goal of these races is to arrive at a destination within a given time frame, while using as little grid (non-solar) energy as possible. +Aerodynamic drag is one of the most significant sources of energy consumption, along with elevation changes. The simplest policy that meets +the on-time arrival constraint is to hold a constant speed equal to the required average: +\[ +V_{\text{avg}} = \frac{D}{T} +\] +where $D$ is the distance to be travelled and $T$ is the maximum allowed time. + +\section{Background} +Provide necessary background on: +\begin{itemize} + \item Your specific application domain + \item Relevant algorithms and methods + \item Previous work in this area +\end{itemize} + +\section{Methodology} +Describe your approach in detail: +\begin{itemize} + \item Problem formulation + \item Algorithm description + \item Implementation details +\end{itemize} + +% Example of how to include an algorithm +\begin{algorithm}[H] +\caption{Your Algorithm Name} +\begin{algorithmic}[1] +\STATE Initialize parameters +\WHILE{not converged} + \STATE Update step +\ENDWHILE +\RETURN Result +\end{algorithmic} +\end{algorithm} + +\section{Experiments and Results} +Present your findings: +\begin{itemize} + \item Experimental setup + \item Results and analysis + \item Comparison with baselines (if applicable) +\end{itemize} + +% Example of how to include figures +\begin{figure}[H] +\centering +\caption{Description of your figure} +\label{fig:example} +\end{figure} + +% Example of how to include tables +\begin{table}[H] +\centering +\caption{Your Table Caption} +\begin{tabular}{lcc} +\toprule +Method & Metric 1 & Metric 2 \\ +\midrule +Approach 1 & Value & Value \\ +Approach 2 & Value & Value \\ +\bottomrule +\end{tabular} +\label{tab:results} +\end{table} + +\section{Discussion} +Analyze your results: +\begin{itemize} + \item Interpretation of findings + \item Limitations and challenges + \item Potential improvements +\end{itemize} + +\section{Conclusion} +Summarize your work: +\begin{itemize} + \item Key contributions + \item
Practical implications + \item Future work directions +\end{itemize} + +\bibliography{references} +\bibliographystyle{plain} + +\end{document} diff --git a/src/solarcarsim/physsim.py b/src/solarcarsim/physsim.py index 368f6db..da14ba1 100644 --- a/src/solarcarsim/physsim.py +++ b/src/solarcarsim/physsim.py @@ -115,7 +115,6 @@ def bldc_power_draw(torque, velocity, params: MotorParams): return total_power -# @partial(jit, static_argnames=['resistance', 'kt', 'kv', 'vmax', 'Cf']) @jit def bldc_torque(velocity, current_limit, resistance, kt, kv, vmax, Cf): bemf = velocity / kv @@ -132,7 +131,6 @@ def bldc_torque(velocity, current_limit, resistance, kt, kv, vmax, Cf): @partial( jit, static_argnums=( - 1, 2, ), ) @@ -144,47 +142,12 @@ def battery_powerloss(current, cell_r, battery_shape: Tuple[int, int]): return jnp.sum(cell_Ploss) -def forward(state, timestep, control, params: CarParams): - # state is (position, time, energy) - # control is velocity - # timestep is >0 time to advance - # params is the params dictionary. - # returns the next state with (position', time + timestep, energy') - # TODO: terrain, weather, solar - - # determine the forces acting on the car. - dragf = drag_force(control, params.frontal_area, params.drag_coeff, 1.184) - rollf = rolling_force(params.mass, 0, params.rolling_coeff) - hillforce = downslope_force(params.mass, 0) - totalf = dragf + rollf + hillforce - # determine the power needed to make this force - tau = params.wheel_radius * totalf - pdraw = bldc_power_draw(tau, control, params.motor) - net_power = 0 - pdraw # watts aka j/s - - # TODO: calculate battery-based power losses. 
- # TODO: support regenerative braking when going downhill - # TODO: delta x = cos(theta) * velocity * timestep - - new_state = jnp.array( - [ - state[0] + control * timestep, - state[1] + timestep, - state[2] + net_power * timestep, - ] - ) - return new_state - - def make_environment(seed): """Generate a race environment: terrain function, wind function, wrapped forward function.""" - key, subkey = jax.random.split(seed) - wind = generate_wind_field(subkey, 10000, 600, spatial_scale=1000) - key, subkey = jax.random.split(key) - slope = fractal_noise_1d(subkey, 10000, scale=1200, height_scale=0.08) + windkey, slopekey = jax.random.split(seed, 2) + wind = generate_wind_field(windkey, 10000, 600, spatial_scale=1000) + slope = fractal_noise_1d(slopekey, 10000, scale=1200, height_scale=0.08) elevation = jnp.cumsum(slope) - # elevation = generate_elevation_profile(subkey, 10000, height_variation=40.0, scale=1200, octaves=5) - # slope = jnp.arctan(jnp.diff(elevation, prepend=100.0)) # rise/run return wind, elevation, slope diff --git a/src/solarcarsim/simv1.py b/src/solarcarsim/simv1.py index dee29b8..47fe60c 100644 --- a/src/solarcarsim/simv1.py +++ b/src/solarcarsim/simv1.py @@ -4,11 +4,9 @@ import jax import jax.numpy as jnp from jax import jit from functools import partial -from solarcarsim.physsim import drag_force, rolling_force, downslope_force, bldc_power_draw - @partial(jit, static_argnames=["params"]) -def forwardv2(state, control, delta_time, wind, elevation, slope, params): +def forward(state, control, delta_time, wind, elevation, slope, params: sim.CarParams): pos = jnp.astype(jnp.round(state[0]), "int32") time = jnp.astype(jnp.round(state[1]), "int32") theta = slope[pos] @@ -23,23 +21,22 @@ def forwardv2(state, control, delta_time, wind, elevation, slope, params): totalf = dragf + rollf + hillforce # with the sum of forces, determine the needed torque at the wheels, and then power tau = params.wheel_radius * totalf - pdraw = bldc_power_draw(tau, velocity, 
params.motor) + pdraw = sim.bldc_power_draw(tau, velocity, params.motor) # determine the energy needed to do this power for the time step net_power = state[2] - delta_time * pdraw # joules - dpos = jnp.cos(theta) * velocity * delta_time - dist_remaining = 10000.0 - dpos + dpos = state[0] + jnp.cos(theta) * velocity * delta_time + new_pos = jnp.maximum(dpos, 0) + dist_remaining = 10000.0 - (state[0] + dpos) time_remaining = 600 - (state[1] + delta_time) return jnp.array( - [dpos, state[1] + delta_time, net_power, dist_remaining, time_remaining] + [new_pos, state[1] + delta_time, net_power, dist_remaining, time_remaining] ) -def reward(state, prev_energy): - progress = state[0] / 10000 * 100 - energy_usage = 10 * (state[2] - prev_energy) # current energy < previous energy. - time_factor = (1.0 - (state[1] / 600)) * 50 - reward = progress + energy_usage + time_factor +def reward(state, prev_state): + reward = 0 + reward += state[0]/8000 return reward class SolarRaceV1(gym.Env): @@ -53,7 +50,6 @@ class SolarRaceV1(gym.Env): # self._state = jnp.array([np.array([x], dtype="float32") for x in (0,0,0, 10000.0, 600.0)]) self._state = jnp.array([[0],[0],[0],[10000.0], [600.0]]) # self._state = jnp.array([0, 0,0,10000.0, 600.0]) - def _vision_function(self): # extract the vision results. 
def slookup(x): @@ -81,7 +77,7 @@ class SolarRaceV1(gym.Env): self._reset_sim(jax.random.key(seed)) self._timestep = timestep self._car = car - self._simstep = forwardv2 + self._simstep = forward self._simreward = reward self.observation_space = gym.spaces.Dict( @@ -108,15 +104,20 @@ class SolarRaceV1(gym.Env): def step(self, action): wind, elevation, slope = self._environment - old_energy = self._state[2] + old_state = self._state self._state = self._simstep(self._state, action, self._timestep,wind, elevation, slope, self._car) - reward = self._simreward(self._state, old_energy)[0] + reward = self._simreward(self._state, old_state)[0] terminated = False truncated = False - if jnp.all(self._state[0] > 10000): + if jnp.all(self._state[0] > 8000): + reward += 500 terminated = True - if self._state[1] > 600: + # we want the time to be as close to 600 as possible + reward -= 600 - self._state[1][0] + reward += 1e-6 * (self._state[2][0]) # net energy is negative. + if jnp.all(self._state[1] > 600): + reward -= 50000 truncated = True return self._get_obs(), reward, terminated, truncated, {} diff --git a/src/solarcarsim/simv2.py b/src/solarcarsim/simv2.py index 60a66d2..fffe8c0 100644 --- a/src/solarcarsim/simv2.py +++ b/src/solarcarsim/simv2.py @@ -1,14 +1,190 @@ -""" Second-generation simulator. More functional, cleaner code, faster """ +"""Second-generation simulator. 
More functional, cleaner code, faster""" -from typing import NamedTuple +from typing import NamedTuple, Optional, Tuple, Union, Dict, Any import jax import jax.numpy as jnp +import chex +from flax import struct +from jax import lax +from gymnax.environments import environment +from gymnax.environments import spaces + +from solarcarsim.physsim import CarParams, fractal_noise_1d +import solarcarsim.physsim as sim -class SimState(NamedTuple): - position: float - time: float - energy: float - distance_remaining: float - time_remaining: float +@struct.dataclass +class SimState(environment.EnvState): + position: jnp.ndarray + velocity: jnp.ndarray + realtime: jnp.ndarray + energy: jnp.ndarray + # distance_remaining: jnp.ndarray + # time_remaining: jnp.ndarray + slope: jnp.ndarray + +@struct.dataclass +class SimParams(environment.EnvParams): + car: CarParams = CarParams() + goal_time: int = 600 + goal_dist: int = 8000 + map_size: int = 10000 + time_step: float = 1.0 + terrain_lookahead: int = 100 + # skip wind for now + + +class Snax(environment.Environment[SimState, SimParams]): + """JAX version of the solar race simulator""" + + @property + def default_params(self) -> SimParams: + return SimParams() + + def action_space(self, params: Optional[SimParams] = None): + return spaces.Box(low=-1.0, high=1.0, shape=(1,)) + + def observation_space(self, params: Optional[SimParams] = None) -> spaces.Box: + if params is None: + params = self.default_params + # needs to be a box. 
it will be [pos, time, energy, dist_to_goal, time_remaining, terrain0, terrain1] + shape = 5 + params.terrain_lookahead + low = jnp.array( + [0, 0, -1e11, 0, 0] + [-1.0] * params.terrain_lookahead, dtype=jnp.float32 + ) + high = jnp.array( + [params.map_size, params.goal_time, 0, params.goal_dist, params.goal_time] + + [1.0] * params.terrain_lookahead, + dtype=jnp.float32, + ) + return spaces.Box(low, high, shape=(shape,)) + # return spaces.Dict( + # { + # "position": spaces.Box(0.0, params.map_size, (), jnp.float32), + # "realtime": spaces.Box(0.0, params.goal_time + 100, (), jnp.float32), + # "energy": spaces.Box(-1e11, 0.0, (), jnp.float32), + # "dist_to_goal": spaces.Box(0.0, params.goal_dist, (), jnp.float32), + # "time_remaining": spaces.Box(0.0, params.goal_time, (), jnp.float32), + # "upcoming_terrain": spaces.Box( + # -1.0, 1.0, shape=(100,), dtype=jnp.float32 + # ), + # # skip wind for now + # } + # ) + + def state_space(self, params: Optional[SimParams] = None) -> spaces.Dict: + if params is None: + params = self.default_params + return spaces.Dict( + { + "position": spaces.Box(0.0, params.map_size, (), jnp.float32), + "realtime": spaces.Box(0.0, params.goal_time + 100, (), jnp.float32), + "energy": spaces.Box(-1e11, 0.0, (), jnp.float32), + # "dist_to_goal": spaces.Box(0.0, params.goal_dist, (), jnp.float32), + # "time_remaining": spaces.Box(0.0, params.goal_time, (), jnp.float32), + "slope": spaces.Box( + -1.0, 1.0, shape=(params.map_size,), dtype=jnp.float32 + ), + "time": spaces.Discrete(int(params.goal_time / params.time_step)), + } + ) + + def reset_env( + self, key: chex.PRNGKey, params: Optional[SimParams] = None + ) -> Tuple[chex.Array, SimState]: + if params is None: + params = self.default_params + slope = fractal_noise_1d(key, 10000, scale=1200, height_scale=0.08) + init_state = SimState( + position=jnp.array(0.0), + velocity=jnp.array(0.0), + time=0, + realtime=jnp.array(0.0), + energy=jnp.array(0.0), + # 
distance_remaining=jnp.array(params.goal_dist), + # time_remaining=jnp.array(params.goal_time), + slope=slope, + ) + return self.get_obs(init_state, key, params), init_state + + def get_obs( + self, state: SimState, key: chex.PRNGKey, params: SimParams + ) -> chex.Array: + if params is None: + params = self.default_params + + # get rounded position from state + pos_int = jnp.astype(state.position, jnp.int32) + + terrain_view = jax.lax.dynamic_slice(state.slope, (pos_int,), (100,)) + dist_to_goal = jnp.abs(params.goal_dist - state.position) + time_remaining = jnp.abs(params.goal_time - state.realtime) + main_state = jnp.array( + [state.position, state.realtime, state.energy, dist_to_goal, time_remaining] + ) + return jnp.concat([main_state, terrain_view]).squeeze() + + def step_env( + self, + key: chex.PRNGKey, + state: SimState, + action: Union[int, float, chex.Array], + params: SimParams, + ) -> Tuple[chex.Array, SimState, jnp.ndarray, jnp.ndarray, Dict[Any, Any]]: + pos = jnp.astype(state.position, jnp.int32) + theta = state.slope[pos] + velocity = jnp.array([action * params.car.max_speed]).squeeze() + dragf = sim.drag_force( + velocity, params.car.frontal_area, params.car.drag_coeff, 1.184 + ) + rollf = sim.rolling_force(params.car.mass, theta, params.car.rolling_coeff) + hillf = sim.downslope_force(params.car.mass, theta) + total_f = dragf + rollf + hillf + tau = params.car.wheel_radius * total_f / params.car.n_motors + p_draw = ( + sim.bldc_power_draw(tau, velocity, params.car.motor) * params.car.n_motors + ) + + new_energy = state.energy - params.time_step * p_draw + new_position = state.position + jnp.cos(theta) * velocity * params.time_step + new_state = SimState( + position=new_position.squeeze(), + velocity=velocity.squeeze(), + realtime=state.realtime + params.time_step, + energy=new_energy.squeeze(), + slope=state.slope, + time=state.time + 1, + ) + + # compute reward + # reward = new_state.position / params.goal_dist + # if new_state.position >= 
params.goal_dist: + # reward += 100 + # reward += params.goal_time - new_state.realtime + # # penalize energy use + # reward += 1e-7 * new_state.energy # energy is negative + # if ( + # new_state.realtime >= params.goal_time + # or new_state.time > params.max_steps_in_episode + # ): + # reward -= 500 + + # we have to vectorize that. + reward = new_state.position / params.goal_dist + \ + (new_state.position >= params.goal_dist) * (100 + params.goal_time - new_state.realtime + 1e-7*new_state.energy) + \ + (new_state.realtime >= params.goal_time) * -500 + reward = reward.squeeze() + terminal = self.is_terminal(state, params) + return ( + lax.stop_gradient(self.get_obs(new_state, key, params)), + lax.stop_gradient(new_state), + reward, + terminal, + {}, + ) + + def is_terminal(self, state: SimState, params: SimParams) -> jnp.ndarray: + finish = state.position >= params.goal_dist + timeout = state.time >= params.max_steps_in_episode + return jnp.logical_or(finish, timeout).squeeze()