final bit

saji 2024-12-20 16:58:51 -06:00
parent 70a659f468
commit 7ad7070129
11 changed files with 6448 additions and 213 deletions

notebooks/gymv2_jax.ipynb (new file, 110 lines)

@ -0,0 +1,110 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"\n",
"from gymnasium.utils.env_checker import check_env as gym_check_env\n",
"from stable_baselines3 import TD3\n",
"from stable_baselines3.common.env_checker import check_env\n",
"from gymnasium.wrappers.jax_to_numpy import JaxToNumpy\n",
"from gymnasium.wrappers.vector import JaxToNumpy as VJaxToNumpy\n",
"from gymnax.wrappers.gym import GymnaxToVectorGymWrapper, GymnaxToGymWrapper\n",
"import matplotlib.pyplot as plt\n",
"import jax.numpy as jnp\n",
"from solarcarsim.simv2 import Snax\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"\n",
"env = Snax()\n",
"wrapped_env = GymnaxToGymWrapper(env)\n",
"vector_gym_env = GymnaxToVectorGymWrapper(env)\n",
"np_wrapper = JaxToNumpy(wrapped_env)\n",
"np_vec_wrapper = VJaxToNumpy(vector_gym_env)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using cuda device\n",
"Wrapping the env with a `Monitor` wrapper\n",
"Wrapping the env in a DummyVecEnv.\n"
]
},
{
"ename": "ValueError",
"evalue": "Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>. The error was:\nTypeError: unhashable type: 'DynamicJaxprTracer'\n",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[7], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m model \u001b[38;5;241m=\u001b[39m TD3(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMlpPolicy\u001b[39m\u001b[38;5;124m\"\u001b[39m, np_wrapper, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlearn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtotal_timesteps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1000\u001b[39;49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/td3/td3.py:222\u001b[0m, in \u001b[0;36mTD3.learn\u001b[0;34m(self, total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps, progress_bar)\u001b[0m\n\u001b[1;32m 213\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mlearn\u001b[39m(\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28mself\u001b[39m: SelfTD3,\n\u001b[1;32m 215\u001b[0m total_timesteps: \u001b[38;5;28mint\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 220\u001b[0m progress_bar: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 221\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m SelfTD3:\n\u001b[0;32m--> 222\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlearn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 223\u001b[0m \u001b[43m \u001b[49m\u001b[43mtotal_timesteps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtotal_timesteps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 224\u001b[0m \u001b[43m \u001b[49m\u001b[43mcallback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcallback\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 225\u001b[0m \u001b[43m \u001b[49m\u001b[43mlog_interval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlog_interval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 226\u001b[0m \u001b[43m \u001b[49m\u001b[43mtb_log_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtb_log_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 227\u001b[0m \u001b[43m \u001b[49m\u001b[43mreset_num_timesteps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreset_num_timesteps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 228\u001b[0m \u001b[43m \u001b[49m\u001b[43mprogress_bar\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprogress_bar\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 229\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/common/off_policy_algorithm.py:328\u001b[0m, in \u001b[0;36mOffPolicyAlgorithm.learn\u001b[0;34m(self, total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps, progress_bar)\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_freq, TrainFreq) \u001b[38;5;66;03m# check done in _setup_learn()\u001b[39;00m\n\u001b[1;32m 327\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_timesteps \u001b[38;5;241m<\u001b[39m total_timesteps:\n\u001b[0;32m--> 328\u001b[0m rollout \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcollect_rollouts\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 329\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 330\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_freq\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_freq\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 331\u001b[0m \u001b[43m \u001b[49m\u001b[43maction_noise\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maction_noise\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 332\u001b[0m \u001b[43m \u001b[49m\u001b[43mcallback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcallback\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 333\u001b[0m \u001b[43m \u001b[49m\u001b[43mlearning_starts\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlearning_starts\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 334\u001b[0m \u001b[43m \u001b[49m\u001b[43mreplay_buffer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreplay_buffer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 335\u001b[0m \u001b[43m \u001b[49m\u001b[43mlog_interval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlog_interval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 336\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 338\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m rollout\u001b[38;5;241m.\u001b[39mcontinue_training:\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n",
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/common/off_policy_algorithm.py:560\u001b[0m, in \u001b[0;36mOffPolicyAlgorithm.collect_rollouts\u001b[0;34m(self, env, callback, train_freq, replay_buffer, action_noise, learning_starts, log_interval)\u001b[0m\n\u001b[1;32m 557\u001b[0m actions, buffer_actions \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sample_action(learning_starts, action_noise, env\u001b[38;5;241m.\u001b[39mnum_envs)\n\u001b[1;32m 559\u001b[0m \u001b[38;5;66;03m# Rescale and perform action\u001b[39;00m\n\u001b[0;32m--> 560\u001b[0m new_obs, rewards, dones, infos \u001b[38;5;241m=\u001b[39m \u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mactions\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 562\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_timesteps \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m env\u001b[38;5;241m.\u001b[39mnum_envs\n\u001b[1;32m 563\u001b[0m num_collected_steps \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n",
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/common/vec_env/base_vec_env.py:206\u001b[0m, in \u001b[0;36mVecEnv.step\u001b[0;34m(self, actions)\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;124;03mStep the environments with the given action\u001b[39;00m\n\u001b[1;32m 201\u001b[0m \n\u001b[1;32m 202\u001b[0m \u001b[38;5;124;03m:param actions: the action\u001b[39;00m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;124;03m:return: observation, reward, done, information\u001b[39;00m\n\u001b[1;32m 204\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 205\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstep_async(actions)\n\u001b[0;32m--> 206\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep_wait\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py:58\u001b[0m, in \u001b[0;36mDummyVecEnv.step_wait\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mstep_wait\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m VecEnvStepReturn:\n\u001b[1;32m 56\u001b[0m \u001b[38;5;66;03m# Avoid circular imports\u001b[39;00m\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m env_idx \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_envs):\n\u001b[0;32m---> 58\u001b[0m obs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbuf_rews[env_idx], terminated, truncated, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbuf_infos[env_idx] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menvs\u001b[49m\u001b[43m[\u001b[49m\u001b[43menv_idx\u001b[49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 59\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mactions\u001b[49m\u001b[43m[\u001b[49m\u001b[43menv_idx\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 60\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 61\u001b[0m \u001b[38;5;66;03m# convert to SB3 VecEnv api\u001b[39;00m\n\u001b[1;32m 62\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbuf_dones[env_idx] \u001b[38;5;241m=\u001b[39m terminated \u001b[38;5;129;01mor\u001b[39;00m truncated\n",
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/common/monitor.py:94\u001b[0m, in \u001b[0;36mMonitor.step\u001b[0;34m(self, action)\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mneeds_reset:\n\u001b[1;32m 93\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTried to step environment that needs reset\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 94\u001b[0m observation, reward, terminated, truncated, info \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrewards\u001b[38;5;241m.\u001b[39mappend(\u001b[38;5;28mfloat\u001b[39m(reward))\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m terminated \u001b[38;5;129;01mor\u001b[39;00m truncated:\n",
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/gymnasium/wrappers/jax_to_numpy.py:166\u001b[0m, in \u001b[0;36mJaxToNumpy.step\u001b[0;34m(self, action)\u001b[0m\n\u001b[1;32m 157\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Transforms the action to a jax array .\u001b[39;00m\n\u001b[1;32m 158\u001b[0m \n\u001b[1;32m 159\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[38;5;124;03m A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info.\u001b[39;00m\n\u001b[1;32m 164\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 165\u001b[0m jax_action \u001b[38;5;241m=\u001b[39m numpy_to_jax(action)\n\u001b[0;32m--> 166\u001b[0m obs, reward, terminated, truncated, info \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mjax_action\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\n\u001b[1;32m 169\u001b[0m jax_to_numpy(obs),\n\u001b[1;32m 170\u001b[0m \u001b[38;5;28mfloat\u001b[39m(reward),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 173\u001b[0m jax_to_numpy(info),\n\u001b[1;32m 174\u001b[0m )\n",
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/gymnax/wrappers/gym.py:70\u001b[0m, in \u001b[0;36mGymnaxToGymWrapper.step\u001b[0;34m(self, action)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Step environment, follow new step API.\"\"\"\u001b[39;00m\n\u001b[1;32m 69\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrng, step_key \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrng)\n\u001b[0;32m---> 70\u001b[0m o, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39menv_state, r, d, info \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_env\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 71\u001b[0m \u001b[43m \u001b[49m\u001b[43mstep_key\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv_state\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maction\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv_params\u001b[49m\n\u001b[1;32m 72\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m o, r, d, d, info\n",
" \u001b[0;31m[... skipping hidden 15 frame]\u001b[0m\n",
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/gymnax/environments/environment.py:45\u001b[0m, in \u001b[0;36mEnvironment.step\u001b[0;34m(self, key, state, action, params)\u001b[0m\n\u001b[1;32m 43\u001b[0m params \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefault_params\n\u001b[1;32m 44\u001b[0m key, key_reset \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39msplit(key)\n\u001b[0;32m---> 45\u001b[0m obs_st, state_st, reward, done, info \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep_env\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maction\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 46\u001b[0m obs_re, state_re \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreset_env(key_reset, params)\n\u001b[1;32m 47\u001b[0m \u001b[38;5;66;03m# Auto-reset environment based on termination\u001b[39;00m\n",
"File \u001b[0;32m~/Documents/Code/solarcarsim/src/solarcarsim/simv2.py:138\u001b[0m, in \u001b[0;36mSnax.step_env\u001b[0;34m(self, key, state, action, params)\u001b[0m\n\u001b[1;32m 136\u001b[0m theta \u001b[38;5;241m=\u001b[39m state\u001b[38;5;241m.\u001b[39mslope[pos]\n\u001b[1;32m 137\u001b[0m velocity \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39marray([action \u001b[38;5;241m*\u001b[39m params\u001b[38;5;241m.\u001b[39mcar\u001b[38;5;241m.\u001b[39mmax_speed])\u001b[38;5;241m.\u001b[39msqueeze()\n\u001b[0;32m--> 138\u001b[0m dragf \u001b[38;5;241m=\u001b[39m \u001b[43msim\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdrag_force\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 139\u001b[0m \u001b[43m \u001b[49m\u001b[43mvelocity\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcar\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrontal_area\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcar\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdrag_coeff\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1.184\u001b[39;49m\n\u001b[1;32m 140\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 141\u001b[0m rollf \u001b[38;5;241m=\u001b[39m sim\u001b[38;5;241m.\u001b[39mrolling_force(params\u001b[38;5;241m.\u001b[39mcar\u001b[38;5;241m.\u001b[39mmass, theta, params\u001b[38;5;241m.\u001b[39mcar\u001b[38;5;241m.\u001b[39mrolling_coeff)\n\u001b[1;32m 142\u001b[0m hillf \u001b[38;5;241m=\u001b[39m sim\u001b[38;5;241m.\u001b[39mdownslope_force(params\u001b[38;5;241m.\u001b[39mcar\u001b[38;5;241m.\u001b[39mmass, theta)\n",
" \u001b[0;31m[... skipping hidden 3 frame]\u001b[0m\n",
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/jax/_src/pjit.py:768\u001b[0m, in \u001b[0;36m_infer_params\u001b[0;34m(fun, ji, args, kwargs)\u001b[0m\n\u001b[1;32m 764\u001b[0m p, args_flat \u001b[38;5;241m=\u001b[39m _infer_params_impl(fun, ji, pjit_mesh, resource_env, args,\n\u001b[1;32m 765\u001b[0m kwargs, in_avals\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m 766\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m p, p\u001b[38;5;241m.\u001b[39mconsts \u001b[38;5;241m+\u001b[39m args_flat\n\u001b[0;32m--> 768\u001b[0m entry \u001b[38;5;241m=\u001b[39m \u001b[43m_infer_params_cached\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 769\u001b[0m \u001b[43m \u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mji\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msignature\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mavals\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpjit_mesh\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mresource_env\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 770\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m entry\u001b[38;5;241m.\u001b[39mpjit_params \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 771\u001b[0m p, args_flat \u001b[38;5;241m=\u001b[39m _infer_params_impl(\n\u001b[1;32m 772\u001b[0m fun, ji, pjit_mesh, resource_env, args, kwargs, in_avals\u001b[38;5;241m=\u001b[39mavals)\n",
"\u001b[0;31mValueError\u001b[0m: Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>. The error was:\nTypeError: unhashable type: 'DynamicJaxprTracer'\n"
]
}
],
"source": [
"\n",
"model = TD3(\"MlpPolicy\", np_wrapper, verbose=1)\n",
"model.learn(total_timesteps=1000)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

pdm.lock (273 lines changed)

@ -5,7 +5,7 @@
groups = ["default", "dev"]
strategy = ["inherit_metadata"]
lock_version = "4.5.0"
content_hash = "sha256:a3b65f863c554725c33d452fd759776141740661fa3555d306ed08563a7e16e2"
content_hash = "sha256:2f7c4bee801973a3b7856ba0707891eb01fd05659948707f44be4aa302e5dabd"
[[metadata.targets]]
requires_python = ">=3.12,<3.13"
@ -237,6 +237,27 @@ files = [
{file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"},
]
[[package]]
name = "distrax"
version = "0.1.5"
requires_python = ">=3.9"
summary = "Distrax: Probability distributions in JAX."
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"absl-py>=0.9.0",
"chex>=0.1.8",
"jax>=0.1.55",
"jaxlib>=0.1.67",
"numpy>=1.23.0",
"setuptools; python_version >= \"3.12\"",
"tensorflow-probability>=0.15.0",
]
files = [
{file = "distrax-0.1.5-py3-none-any.whl", hash = "sha256:5020f4b53a9a480d019c12e44292fbacb7de857cce478bc594dacf29519c61b7"},
{file = "distrax-0.1.5.tar.gz", hash = "sha256:ec41522d389af69efedc8d475a7e6d8f229429c00f2140dcd641feacf7e21948"},
]
[[package]]
name = "dm-tree"
version = "0.1.8"
@ -254,6 +275,18 @@ files = [
{file = "dm_tree-0.1.8-cp312-cp312-win_amd64.whl", hash = "sha256:96a548a406a6fb15fe58f6a30a57ff2f2aafbf25f05afab00c8f5e5977b6c715"},
]
[[package]]
name = "docstring-parser"
version = "0.16"
requires_python = ">=3.6,<4.0"
summary = "Parse Python docstrings in reST, Google and Numpydoc format"
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [
{file = "docstring_parser-0.16-py3-none-any.whl", hash = "sha256:bf0a1387354d3691d102edef7ec124f219ef639982d096e26e3b60aeffa90637"},
{file = "docstring_parser-0.16.tar.gz", hash = "sha256:538beabd0af1e2db0146b6bd3caa526c35a34d61af9fd2887f3a8a27a739aa6e"},
]
[[package]]
name = "etils"
version = "1.11.0"
@ -408,6 +441,26 @@ files = [
{file = "gast-0.6.0.tar.gz", hash = "sha256:88fc5300d32c7ac6ca7b515310862f71e6fdf2c029bbec7c66c0f5dd47b6b1fb"},
]
[[package]]
name = "grpcio"
version = "1.68.1"
requires_python = ">=3.8"
summary = "HTTP/2-based RPC framework"
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [
{file = "grpcio-1.68.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:8829924fffb25386995a31998ccbbeaa7367223e647e0122043dfc485a87c666"},
{file = "grpcio-1.68.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:3aed6544e4d523cd6b3119b0916cef3d15ef2da51e088211e4d1eb91a6c7f4f1"},
{file = "grpcio-1.68.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:4efac5481c696d5cb124ff1c119a78bddbfdd13fc499e3bc0ca81e95fc573684"},
{file = "grpcio-1.68.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ab2d912ca39c51f46baf2a0d92aa265aa96b2443266fc50d234fa88bf877d8e"},
{file = "grpcio-1.68.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95c87ce2a97434dffe7327a4071839ab8e8bffd0054cc74cbe971fba98aedd60"},
{file = "grpcio-1.68.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:e4842e4872ae4ae0f5497bf60a0498fa778c192cc7a9e87877abd2814aca9475"},
{file = "grpcio-1.68.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:255b1635b0ed81e9f91da4fcc8d43b7ea5520090b9a9ad9340d147066d1d3613"},
{file = "grpcio-1.68.1-cp312-cp312-win32.whl", hash = "sha256:7dfc914cc31c906297b30463dde0b9be48e36939575eaf2a0a22a8096e69afe5"},
{file = "grpcio-1.68.1-cp312-cp312-win_amd64.whl", hash = "sha256:a0c8ddabef9c8f41617f213e527254c41e8b96ea9d387c632af878d05db9229c"},
{file = "grpcio-1.68.1.tar.gz", hash = "sha256:44a8502dd5de653ae6a73e2de50a401d84184f0331d0ac3daeb044e66d5c5054"},
]
[[package]]
name = "gym"
version = "0.26.2"
@ -622,6 +675,76 @@ files = [
{file = "jax-0.4.37.tar.gz", hash = "sha256:7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b"},
]
[[package]]
name = "jax-cuda12-pjrt"
version = "0.4.36"
summary = "JAX XLA PJRT Plugin for NVIDIA GPUs"
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [
{file = "jax_cuda12_pjrt-0.4.36-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1dfc0bec0820ba801b61e9421064b6e58238c430b4ad8f54043323d93c0217c6"},
{file = "jax_cuda12_pjrt-0.4.36-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e3c3705d8db7d63da9abfaebf06f5cd0667f5acb0748a5c5eb00d80041e922ed"},
]
[[package]]
name = "jax-cuda12-plugin"
version = "0.4.36"
requires_python = ">=3.10"
summary = "JAX Plugin for NVIDIA GPUs"
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"jax-cuda12-pjrt==0.4.36",
]
files = [
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:6a0b0c2bdc1da2eea2c20723a1e8f39b3cda67d24c665de936647e8091f5790d"},
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:5d4727fb519fedc06a9a984d5a0714804d81ef126a2cb60cefd5cbc4a3ea2627"},
]
[[package]]
name = "jax-cuda12-plugin"
version = "0.4.36"
extras = ["with_cuda"]
requires_python = ">=3.10"
summary = "JAX Plugin for NVIDIA GPUs"
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"jax-cuda12-plugin==0.4.36",
"nvidia-cublas-cu12>=12.1.3.1",
"nvidia-cuda-cupti-cu12>=12.1.105",
"nvidia-cuda-nvcc-cu12>=12.6.85",
"nvidia-cuda-runtime-cu12>=12.1.105",
"nvidia-cudnn-cu12<10.0,>=9.1",
"nvidia-cufft-cu12>=11.0.2.54",
"nvidia-cusolver-cu12>=11.4.5.107",
"nvidia-cusparse-cu12>=12.1.0.106",
"nvidia-nccl-cu12>=2.18.1",
"nvidia-nvjitlink-cu12>=12.1.105",
]
files = [
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:6a0b0c2bdc1da2eea2c20723a1e8f39b3cda67d24c665de936647e8091f5790d"},
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:5d4727fb519fedc06a9a984d5a0714804d81ef126a2cb60cefd5cbc4a3ea2627"},
]
[[package]]
name = "jax"
version = "0.4.37"
extras = ["cuda12"]
requires_python = ">=3.10"
summary = "Differentiate, compile, and transform Numpy code."
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"jax-cuda12-plugin[with_cuda]<=0.4.37,>=0.4.36",
"jax==0.4.37",
"jaxlib==0.4.36",
]
files = [
{file = "jax-0.4.37-py3-none-any.whl", hash = "sha256:bdc0686d7e5a944e2d38026eae632214d98dd2d91869cbcedbf1c11298ae3e3e"},
{file = "jax-0.4.37.tar.gz", hash = "sha256:7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b"},
]
[[package]]
name = "jaxlib"
version = "0.4.36"
@ -737,6 +860,21 @@ files = [
{file = "kiwisolver-1.4.7.tar.gz", hash = "sha256:9893ff81bd7107f7b685d3017cc6583daadb4fc26e4a888350df530e41980a60"},
]
[[package]]
name = "markdown"
version = "3.7"
requires_python = ">=3.8"
summary = "Python implementation of John Gruber's Markdown."
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"importlib-metadata>=4.4; python_version < \"3.10\"",
]
files = [
{file = "Markdown-3.7-py3-none-any.whl", hash = "sha256:7eb6df5690b81a1d7942992c97fad2938e956e79df20cbc6186e9c3a77b1c803"},
{file = "markdown-3.7.tar.gz", hash = "sha256:2ae2471477cfd02dbbf038d5d9bc226d40def84b4fe2986e49b59b6b472bbed2"},
]
[[package]]
name = "markdown-it-py"
version = "3.0.0"
@ -937,7 +1075,7 @@ version = "12.4.5.8"
requires_python = ">=3"
summary = "CUBLAS native runtime libraries"
groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [
{file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3"},
{file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b"},
@ -950,13 +1088,26 @@ version = "12.4.127"
requires_python = ">=3"
summary = "CUDA profiling tools runtime libs."
groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a"},
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb"},
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922"},
]
[[package]]
name = "nvidia-cuda-nvcc-cu12"
version = "12.6.85"
requires_python = ">=3"
summary = "CUDA nvcc"
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [
{file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d75d9d74599f4d7c0865df19ed21b739e6cb77a6497a3f73d6f61e8038a765e4"},
{file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5d2edd5531b13e3daac8ffee9fc2b70a147e6088b2af2565924773d63d36d294"},
{file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-win_amd64.whl", hash = "sha256:aa04742337973dcb5bcccabb590edc8834c60ebfaf971847888d24ffef6c46b5"},
]
[[package]]
name = "nvidia-cuda-nvrtc-cu12"
version = "12.4.127"
@ -976,7 +1127,7 @@ version = "12.4.127"
requires_python = ">=3"
summary = "CUDA Runtime native Libraries"
groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [
{file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3"},
{file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5"},
@ -989,7 +1140,7 @@ version = "9.1.0.70"
requires_python = ">=3"
summary = "cuDNN runtime libraries"
groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"nvidia-cublas-cu12",
]
@ -1004,7 +1155,7 @@ version = "11.2.1.3"
requires_python = ">=3"
summary = "CUFFT native runtime libraries"
groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"nvidia-nvjitlink-cu12",
]
@ -1033,7 +1184,7 @@ version = "11.6.1.9"
requires_python = ">=3"
summary = "CUDA solver native runtime libraries"
groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"nvidia-cublas-cu12",
"nvidia-cusparse-cu12",
@ -1051,7 +1202,7 @@ version = "12.3.1.170"
requires_python = ">=3"
summary = "CUSPARSE native runtime libraries"
groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"nvidia-nvjitlink-cu12",
]
@ -1067,7 +1218,7 @@ version = "2.21.5"
requires_python = ">=3"
summary = "NVIDIA Collective Communication Library (NCCL) Runtime"
groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [
{file = "nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0"},
]
@ -1078,7 +1229,7 @@ version = "12.4.127"
requires_python = ">=3"
summary = "Nvidia JIT LTO Library"
groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"},
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
@ -1759,6 +1910,18 @@ files = [
{file = "shiboken6-6.8.0.2-cp39-abi3-win_amd64.whl", hash = "sha256:b11e750e696bb565d897e0f5836710edfb86bd355f87b09988bd31b2aad404d3"},
]
[[package]]
name = "shtab"
version = "1.7.1"
requires_python = ">=3.7"
summary = "Automagic shell tab completion for Python CLI applications"
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [
{file = "shtab-1.7.1-py3-none-any.whl", hash = "sha256:32d3d2ff9022d4c77a62492b6ec875527883891e33c6b479ba4d41a51e259983"},
{file = "shtab-1.7.1.tar.gz", hash = "sha256:4e4bcb02eeb82ec45920a5d0add92eac9c9b63b2804c9196c1f1fdc2d039243c"},
]
[[package]]
name = "simplejson"
version = "3.19.3"
@ -1847,6 +2010,42 @@ files = [
{file = "sympy-1.13.1.tar.gz", hash = "sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f"},
]
[[package]]
name = "tensorboard"
version = "2.18.0"
requires_python = ">=3.9"
summary = "TensorBoard lets you watch Tensors Flow"
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"absl-py>=0.4",
"grpcio>=1.48.2",
"markdown>=2.6.8",
"numpy>=1.12.0",
"packaging",
"protobuf!=4.24.0,>=3.19.6",
"setuptools>=41.0.0",
"six>1.9",
"tensorboard-data-server<0.8.0,>=0.7.0",
"werkzeug>=1.0.1",
]
files = [
{file = "tensorboard-2.18.0-py3-none-any.whl", hash = "sha256:107ca4821745f73e2aefa02c50ff70a9b694f39f790b11e6f682f7d326745eab"},
]
[[package]]
name = "tensorboard-data-server"
version = "0.7.2"
requires_python = ">=3.7"
summary = "Fast data loading for TensorBoard"
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [
{file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"},
{file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"},
{file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"},
]
[[package]]
name = "tensorflow-probability"
version = "0.25.0"
@ -1997,6 +2196,22 @@ files = [
{file = "triton-3.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8182f42fd8080a7d39d666814fa36c5e30cc00ea7eeeb1a2983dbb4c99a0fdc"},
]
[[package]]
name = "typeguard"
version = "4.4.1"
requires_python = ">=3.9"
summary = "Run-time type checker for Python"
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"importlib-metadata>=3.6; python_version < \"3.10\"",
"typing-extensions>=4.10.0",
]
files = [
{file = "typeguard-4.4.1-py3-none-any.whl", hash = "sha256:9324ec07a27ec67fc54a9c063020ca4c0ae6abad5e9f0f9804ca59aee68c6e21"},
{file = "typeguard-4.4.1.tar.gz", hash = "sha256:0d22a89d00b453b47c49875f42b6601b961757541a2e1e0ef517b6e24213c21b"},
]
[[package]]
name = "typing-extensions"
version = "4.12.2"
@ -2009,6 +2224,29 @@ files = [
{file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"},
]
[[package]]
name = "tyro"
version = "0.9.2"
requires_python = ">=3.7"
summary = "CLI interfaces & config objects, from types"
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"backports-cached-property>=1.0.2; python_version < \"3.8\"",
"colorama>=0.4.0; platform_system == \"Windows\"",
"docstring-parser>=0.16",
"eval-type-backport>=0.1.3; python_version < \"3.10\"",
"rich>=11.1.0",
"shtab>=1.5.6",
"typeguard>=4.0.0",
"typing-extensions>=4.7.0; python_version < \"3.8\"",
"typing-extensions>=4.9.0; python_version >= \"3.8\"",
]
files = [
{file = "tyro-0.9.2-py3-none-any.whl", hash = "sha256:f7c301b30b1ac7b18672f234e45013786c494d64c0e3621b25b8414637af8f90"},
{file = "tyro-0.9.2.tar.gz", hash = "sha256:692687e07c1ed35cc3a841e8c4a188424023f16bdef37f2d9c23cbeb8a3b51aa"},
]
[[package]]
name = "tzdata"
version = "2024.2"
@ -2063,6 +2301,21 @@ files = [
{file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"},
]
[[package]]
name = "werkzeug"
version = "3.1.3"
requires_python = ">=3.9"
summary = "The comprehensive WSGI web application library."
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"MarkupSafe>=2.1.1",
]
files = [
{file = "werkzeug-3.1.3-py3-none-any.whl", hash = "sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e"},
{file = "werkzeug-3.1.3.tar.gz", hash = "sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746"},
]
[[package]]
name = "zipp"
version = "3.21.0"


@ -5,7 +5,7 @@ description = "A solar car racing simulation library and GUI tool"
authors = [
{name = "saji", email = "saji@saji.dev"},
]
dependencies = ["pyqtgraph>=0.13.7", "jax>=0.4.37", "pytest>=8.3.3", "pyside6>=6.8.0.2", "matplotlib>=3.9.2", "gymnasium[jax]>=1.0.0", "pyvista>=0.44.2", "pyvistaqt>=0.11.1", "stable-baselines3>=2.4.0", "gymnax>=0.0.8", "sbx-rl>=0.18.0"]
dependencies = ["pyqtgraph>=0.13.7", "jax[cuda12]>=0.4.37", "pytest>=8.3.3", "pyside6>=6.8.0.2", "matplotlib>=3.9.2", "gymnasium[jax]>=1.0.0", "pyvista>=0.44.2", "pyvistaqt>=0.11.1", "stable-baselines3>=2.4.0", "gymnax>=0.0.8", "sbx-rl>=0.18.0", "tyro>=0.9.2", "tensorboard>=2.18.0", "distrax>=0.1.5"]
requires-python = ">=3.10,<3.13"
readme = "README.md"
license = {text = "MIT"}


@ -7,21 +7,23 @@
\usepackage{graphicx}
\usepackage[margin=1in]{geometry}
\usepackage{hyperref}
\usepackage{pdfpages}
\usepackage{algorithm}
\usepackage{algorithmic}
\usepackage{float}
\usepackage{booktabs}
\usepackage{caption}
\usepackage{subcaption}
\usepackage{tikz}
% Custom commands
\newcommand{\sectionheading}[1]{\noindent\textbf{#1}}
% Title and author information
\title{\Large{Your Project Title: A Study in Optimal Control and Reinforcement Learning}}
\author{Your Name\\
Course Name\\
Institution}
\title{\Large{Solarcarsim: A Solar Racing Environment for RL Agents}}
\author{Saji Champlin\\
EE5241
}
\date{\today}
\begin{document}
@ -33,102 +35,273 @@ Solar Racing is a competition with the goal of creating highly efficient solar-a
requires awareness and complex decision making to determine optimal speeds to exploit the environmental conditions, such as winds,
cloud cover, and changes in elevation. We present an environment modelled on the dynamics involved for a race, including generated
elevation and wind profiles. The model uses the \texttt{gymnasium} interface to allow it to be used by a variety of algorithms.
We demonstrate a method of designing reward functions for multi-objective problems. The environment shows to be solvable by modern
reinforcement learning algorithms.
We demonstrate a method of designing reward functions for multi-objective problems. We show learning using a Jax-based PPO model.
\end{abstract}
\section{Introduction}
Start with a broad context of your problem area in optimal control/reinforcement learning. Then narrow down to your specific focus. Include:
\begin{itemize}
\item Problem motivation
\item Brief overview of existing approaches
\item Your specific contributions
\item Paper organization
\end{itemize}
Solar racing was invented in the early 1990s as a technology incubator for high-efficiency motor vehicles. The first solar races were speed
focused; later, a style of race centered on minimal energy use over a given route was developed to push the focus toward vehicle efficiency.
The goal of these races is to arrive at a destination within a given time frame, while using as little grid (non-solar) energy as possible.
Aerodynamic drag is one of the most significant sources of energy consumption, along with elevation changes. The simplest policy to meet
the constraints of on-time arrival is:
$$
V_{\text{avg}} = \frac{D}{T}
$$
Where $D$ is the distance needed to travel, and $T$ is the maximum allowed time.
Optimal driving is a complex policy based on terrain slope, wind forecasts, and solar forecasts.
Direct solutions that find the global minimum of energy usage over a route
segment are difficult to compute. Instead, we present a reinforcement learning
environment that can be used to train RL agents to race efficiently given
limited foresight. The environment simulates key components of the race, such
as terrain and wind, as well as car dynamics. The simulator is written using
the Jax\cite{jax2018github} library which enables computations to be offloaded to the GPU. We
provide wrappers for the \texttt{gymnasium} API as well as a
\texttt{purejaxrl}\cite{lu2022discovered} implementation which can train a PPO
agent with millions of timesteps in several minutes. We present an exploration of reward function design with regards to sparsity,
magnitude, and learning efficiency.
\section{Background}
Provide necessary background on:
\begin{itemize}
\item Your specific application domain
\item Relevant algorithms and methods
\item Previous work in this area
\end{itemize}
Performance evaluation for solar races typically takes the form of
$$
S = D/E \times T
$$
Where $S$ is the score, $D$ is the distance travelled, $E$ is the energy consumed, and $T$ is the speed derating.
The speed derate is calculated from a desired average speed throughout the race. If the average velocity is at or above the target,
$T=1$; otherwise $T$ approaches $0$ exponentially as the average velocity falls below the target.
Based on this metric, we conclude that the optimal policy:
\begin{enumerate}
\item Maintains an average speed $V_{\text{avg}}$ as required by the derate.
\item Minimizes energy usage otherwise.
\end{enumerate}
The simplest control to meet the constraints of on-time arrival is:
$$
V_{\text{avg}} = \frac{D_{goal}}{T_{goal}}
$$
Where $D_{goal}$ is the distance needed to travel, and $T_{goal}$ is the maximum allowed time. The average speed is nearly-optimal in most cases, but
is not a globally optimal solution. Consider a small hill: more energy is drawn from the battery when going uphill,
but the same energy is returned to the car on the way down. Losses in the vehicle dictate that it is more effective to drive slowly
up the hill and speed up on the descent. The decision is further complicated by wind and cloud cover, which can aid or hinder the
performance of the car. It is therefore of great interest for solar racing teams to have advanced strategies that can effectively
traverse the terrain while minimizing environmental resistances.
Existing research on this subject is limited, as advanced solar car strategy is a competitive differentiator and is usually kept secret.
However, to the author's knowledge, most of the work on this subject involves Modelica or similar acausal system simulators, together with
non-linear solvers that use multi-starts to attempt to find the global optimum. Other methods include exhaustive search, genetic algorithms,
and Big Bang-Big Crunch optimization\cite{heuristicsolar}.
We start by analyzing a simple force-based model of the car, and then connect this to an energy system using motor equations. We generate
a simulated environment including terrain and wind. Then, we develop a reward system that encapsulates the goals of the environment.
Finally, we train off-the-shelf RL models from Stable Baselines3 and purejaxrl to show learning on the environment.
\section{Methodology}
Describe your approach in detail:
\begin{itemize}
\item Problem formulation
\item Algorithm description
\item Implementation details
\end{itemize}
% \begin{tikzpicture}[scale=1.5]
% % Define slope angle
% \def\angle{30}
%
% % Draw ground/slope
% \draw[thick] (-3,0) -- (3,0);
% \draw[thick] (-2,0) -- (2.5,2);
%
% % Draw angle arc and label
% \draw (-,0) arc (0:\angle:0.5);
% \node at (0.4,0.3) {$\theta$};
%
% % Draw simplified car (rectangle)
% \begin{scope}[rotate=\angle,shift={(0,1)}]
% \draw[thick] (-0.8,-0.4) rectangle (0.8,0.4);
% % Add wheels
% \fill[black] (-0.6,-0.4) circle (0.1);
% \fill[black] (0.6,-0.4) circle (0.1);
% \end{scope}
%
% % Draw forces
% % Weight force
% \draw[->,thick,red] (0,1) -- (0,0) node[midway,right] {$W$};
%
% % Normal force
% \draw[->,thick,blue] (0,1) -- ({-sin(\angle)},{cos(\angle)}) node[midway,above left] {$N$};
%
% % Downslope component
% \draw[->,thick,green] (0,1) -- ({cos(\angle)},{sin(\angle)}) node[midway,below right] {$W\sin\theta$};
% \end{tikzpicture}
% Example of how to include an algorithm
\begin{algorithm}[H]
\caption{Your Algorithm Name}
\begin{algorithmic}[1]
\STATE Initialize parameters
\WHILE{not converged}
\STATE Update step
\ENDWHILE
\RETURN Result
\end{algorithmic}
\end{algorithm}
\section{Experiments and Results}
Present your findings:
\begin{itemize}
\item Experimental setup
\item Results and analysis
\item Comparison with baselines (if applicable)
\end{itemize}
% Example of how to include figures
\begin{figure}[H]
\begin{tikzpicture}[scale=1.5]
% Define slope angle
\def\angle{30}
% Calculate some points for consistent geometry
\def\slopeStart{-2}
\def\slopeEnd{2}
\def\slopeHeight{2.309} % tan(30°) * 2
% Draw ground (horizontal line)
\draw[thick] (-3,0) -- (3,0);
% Draw slope
\draw[thick] (\slopeStart,0) -- (\slopeEnd,\slopeHeight);
% Calculate car center position on slope
\def\carX{0} % Center position along x-axis
\def\carY{1.6} % tan(30°) * carX + appropriate offset
% Draw car (rectangle) exactly on slope
\begin{scope}[shift={(\carX,\carY)}, rotate=\angle]
\draw[thick] (-0.6,-0.3) rectangle (0.6,0.3);
% Add wheels aligned with slope
\fill[black] (-0.45,-0.3) circle (0.08);
\fill[black] (0.45,-0.3) circle (0.08);
\draw[->,thick] (0,0) -- ++(-0.8, 0) node[left] {$F_{slope} + F_{drag} + F_{rolling}$};
\draw[->,thick] (0,0) -- ++(0.8, 0) node[right] {$F_{motor}$};
\node at (0,0) [circle,fill,inner sep=1.5pt]{};
\end{scope}
% Draw forces from center of car
% Center point of car for forces
\coordinate (carCenter) at (\carX,\carY);
\end{tikzpicture}
\centering
\caption{Description of your figure}
\label{fig:example}
\caption{Free body diagram showing relevant forces on a 2-dimensional car}
\label{fig:freebody}
\end{figure}
% Example of how to include tables
\begin{table}[H]
To model the vehicle dynamics, we simplify the system to a 2d plane. As seen in Figure~\ref{fig:freebody}, the forces on the car
are due to intrinsic vehicle properties, current velocity, and environment conditions like slope and wind. If the velocity is held
constant, we can assert that the sum of the forces on the car is zero:
\begin{align}
F_{drag} + F_{slope} + F_{rolling} + F_{motor} &= 0 \\
F_{drag} &= \frac{1}{2} \rho v^2 C_dA \\
F_{slope} &= mg\sin {\theta} \\
F_{rolling} &= mg\cos {\theta} C_{rr}
\end{align}
The $F_{motor}$ term is modulated by the driver. In our case, we give the agent a simpler control mechanism with a normalized
velocity control instead of a force-based control. This is written as $v = \alpha v_{max}$ where $\alpha$ is the action taken
by the agent in $\left[-1,1\right]$.
From the velocity, and the forces acting on the car, we can determine
the power of the car using a simple $K_t$ model:
\begin{align}
\tau &= \left(F_{drag} + F_{slope} + F_{rolling}\right) r \\
P_{motor} &= \tau v + R_{motor} \left(\frac{\tau}{K_t}\right)^2
\end{align}
The torque of the motor is the sum of the outstanding resistive forces times the wheel radius. $K_t$ is a motor parameter, as is $R_{motor}$.
Both can be extracted from physical motors to simulate them, but simple "rule-of-thumb" numbers were used during development.
The power of the motor is given in watts. Based on the time-step of the simulation, we can determine the energy consumed in joules
with $W \times s = J$. A time-step of 1 second was chosen to accelerate simulation. Lower values result in reduced integration
errors over time at the cost of longer episodes.
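These relations translate directly into a few pure functions. The following is a minimal Jax sketch of one simulation step; the parameter values and default constants here are illustrative assumptions, not the exact constants or API of \texttt{solarcarsim}:
\begin{verbatim}
import jax.numpy as jnp

RHO_AIR = 1.184  # kg/m^3, air density used above

def drag_force(v, frontal_area, c_d, rho=RHO_AIR):
    # F_drag = 1/2 * rho * v^2 * C_d * A
    return 0.5 * rho * v**2 * c_d * frontal_area

def downslope_force(mass, theta):
    # F_slope = m * g * sin(theta)
    return mass * 9.81 * jnp.sin(theta)

def rolling_force(mass, theta, c_rr):
    # F_rolling = m * g * cos(theta) * C_rr
    return mass * 9.81 * jnp.cos(theta) * c_rr

def motor_power(v, f_resist, wheel_radius, k_t, r_motor):
    # K_t model from the report: P = tau*v + R_motor*(tau/K_t)^2
    tau = f_resist * wheel_radius
    return tau * v + r_motor * (tau / k_t) ** 2

def energy_step(alpha, theta, dt=1.0, v_max=30.0, mass=300.0,
                frontal_area=1.0, c_d=0.14, c_rr=0.005,
                wheel_radius=0.28, k_t=0.5, r_motor=0.1):
    # Normalized action in [-1, 1] sets the velocity directly.
    v = alpha * v_max
    f_resist = (drag_force(v, frontal_area, c_d)
                + downslope_force(mass, theta)
                + rolling_force(mass, theta, c_rr))
    # Energy in joules: power (W) times the 1-second time-step (s).
    return motor_power(v, f_resist, wheel_radius, k_t, r_motor) * dt
\end{verbatim}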
\subsection{Environment Generation}
It is important that our agent learns not just the optimal policy for a fixed course, but an approximate optimal policy
for any course. To this end we must be able to generate a wide variety of terrain and wind scenarios. Perlin noise
is typically used in this context. We use 1D Perlin noise to generate the slope of the terrain directly, and then integrate the slope to create
the elevation profile; generating an elevation profile first and differentiating it would not yield a smooth, realistic slope.
Currently the elevation profile itself is unused, but it can matter for the drag force through changes in air pressure. The wind was
generated with 2D Perlin noise, where one axis is time and the other is position. The noise was blurred along the time axis
to smooth out the changes in wind at any given point.
An example of the environment can be seen in Figure~\ref{fig:env_vis}.
\begin{figure}[H]
\centering
\caption{Your Table Caption}
\begin{tabular}{lcc}
\toprule
Method & Metric 1 & Metric 2 \\
\midrule
Approach 1 & Value & Value \\
Approach 2 & Value & Value \\
\bottomrule
\end{tabular}
\label{tab:results}
\end{table}
\includegraphics[width=\textwidth]{environment.pdf}
\caption{Visualization of the generated environment}
\label{fig:env_vis}
\end{figure}
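The generation procedure above can be sketched in a few lines of Jax. True Perlin noise is more involved; the snippet below substitutes interpolated, smoothed uniform noise as a stand-in, and all sizes are illustrative assumptions rather than the simulator's defaults:
\begin{verbatim}
import jax
import jax.numpy as jnp

def smooth_noise_1d(key, n, octave=32):
    # Stand-in for 1D Perlin noise: coarse random values, linearly interpolated.
    coarse = jax.random.uniform(key, (n // octave + 2,), minval=-1.0, maxval=1.0)
    x = jnp.arange(n) / octave
    i = x.astype(jnp.int32)
    frac = x - i
    return coarse[i] * (1.0 - frac) + coarse[i + 1] * frac

def make_terrain(key, n_segments=1000, max_slope=0.05):
    # Generate the slope directly, then integrate (cumulative sum) to get
    # elevation, rather than differentiating a noisy elevation profile.
    slope = max_slope * smooth_noise_1d(key, n_segments)
    elevation = jnp.cumsum(slope)  # assumes unit segment length
    return slope, elevation

def make_wind(key, n_steps=600, n_segments=1000, kernel=15):
    # 2D noise over (time, position), blurred along the time axis so the
    # wind at any fixed point changes gradually.
    raw = jax.random.uniform(key, (n_steps, n_segments), minval=-1.0, maxval=1.0)
    k = jnp.ones((kernel,)) / kernel
    blur = lambda col: jnp.convolve(col, k, mode="same")
    return jax.vmap(blur, in_axes=1, out_axes=1)(raw)
\end{verbatim}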
\subsection{Performance Evaluation}
To quantify agent performance, we must produce a single scalar reward. While multi-objective learning is an interesting
subject\footnote{I especially wanted to do meta-RL using a neural net to compute a single reward from inputs.}, it is beyond the scope
of this project. Additionally, sparse rewards can significantly slow down learning, and a poor reward function can prevent
agents from approaching the optimal policy. With these factors in mind, we use the following:
\[
R = \frac{x}{D_{goal}} + (x > D_{goal})\left(100 - E - 10(t - T_{goal})\right) - 500\,(t > T_{goal})
\]
To understand this, note there are three major components. The continuous reward is given at every step and is the position of the car
relative to the goal distance. The victory reward is a constant, minus the energy used and the arrival-time term,
which guides the agent towards arriving with as little spare time as possible. Finally, there is a penalty for the time
going above the goal time, as after that point the car is disqualified from the race.
It took a few iterations to find a reward metric that promoted fast learning, and some of these issues were exacerbated by the initially low
training throughput when using Stable Baselines. A crucial part of the improvement was applying the energy loss only on wins.
This allowed the model to quickly learn to drive forward and finish, after which refinement of speed could take
place\footnote{I looked into Q-initialization but couldn't figure out a way to implement it easily.}.
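Transcribed into Jax, the reward is a handful of masked terms. This sketch assumes the energy $E$ is already expressed in units that make the constant $100$ meaningful, which is an assumption about the implementation:
\begin{verbatim}
import jax.numpy as jnp

def reward(x, t, energy, d_goal, t_goal):
    # Dense progress term, terminal bonus on reaching the goal, and a large
    # penalty once the time limit is exceeded.
    progress = x / d_goal
    finished = (x > d_goal).astype(jnp.float32)
    late = (t > t_goal).astype(jnp.float32)
    win_bonus = 100.0 - energy - 10.0 * (t - t_goal)
    return progress + finished * win_bonus - 500.0 * late
\end{verbatim}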
\subsection{State and Observation Spaces}
The complete state of the simulator is the position, velocity, and energy of the car, as well as the entire environment.
These parameters are sufficient for a deterministic snapshot of the simulator. However, one goal of the project
was to enable partial-observation of the system. To this end, we separate the observation space into a small snippet
of the upcoming wind and slope. This also simplifies the agent calculations since the view of the environment is
relative to its current position. The size of the view can be controlled as a parameter.
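A partial observation of this kind is easy to express with a dynamic slice relative to the car's position. The names and window size below are illustrative, not the simulator's exact observation layout:
\begin{verbatim}
import jax
import jax.numpy as jnp

def observe(pos, vel, energy, slope, wind_now, view=16):
    # Car state plus a fixed-size window of upcoming slope and wind, taken
    # relative to the current position (dynamic_slice clamps at the end of
    # the course, so the arrays may need padding in practice).
    idx = jnp.asarray(pos).astype(jnp.int32)
    upcoming_slope = jax.lax.dynamic_slice(slope, (idx,), (view,))
    upcoming_wind = jax.lax.dynamic_slice(wind_now, (idx,), (view,))
    return jnp.concatenate(
        [jnp.array([pos, vel, energy]), upcoming_slope, upcoming_wind]
    )
\end{verbatim}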
\section{Experiments and Results}
An implementation of the aforementioned simulator was developed with Jax. Jax was chosen as it enables
vectorization and optimization to improve performance. Additionally, Jax allows for gradients of any function
to be computed, which is useful for certain classes of reinforcement learning. In our case, we did not
use this capability, as there seemed to be very little off-the-shelf tooling that exploits it.
Initially, Stable Baselines was used since it is one of the most popular implementations of common RL algorithms.
Stable Baselines3\cite{stable-baselines3} is written in PyTorch\cite{Ansel_PyTorch_2_Faster_2024}, and uses the Gym\cite{gymnasium} format for environments. A basic Gym wrapper
was created to connect SB3 to our environment.
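The wrapping itself is short; the sketch below mirrors the imports used in the notebook included in this commit, swapping in SB3's PPO, and the timestep budget is arbitrary. Depending on the simulator version, this path can still hit jit tracing errors like the one captured in the notebook above:
\begin{verbatim}
from gymnasium.wrappers.jax_to_numpy import JaxToNumpy
from gymnax.wrappers.gym import GymnaxToGymWrapper
from stable_baselines3 import PPO

from solarcarsim.simv2 import Snax

# Expose the gymnax environment through the Gymnasium API, then convert
# the Jax arrays it returns into the NumPy arrays SB3 expects.
env = JaxToNumpy(GymnaxToGymWrapper(Snax()))
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=100_000)
\end{verbatim}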
PPO was chosen as the RL algorithm because it is very simple while still being effective\cite{proximalpolicyoptimization}.
The performance and convergence were poor. This made the setup
difficult to diagnose, as the model could need millions of steps before it learned anything interesting.
The primary performance loss was in the Jax/NumPy/PyTorch conversion, as this requires a CPU round trip.
To combat this I found a Jax-based implementation of PPO called \texttt{purejaxrl}. This library is
written in the style of CleanRL but instead uses pure Jax and an environment library called \texttt{gymnax}\cite{gymnax2022github}.
The primary advantage of writing everything in Jax is that both the RL agent and the environment can be offloaded to the GPU.
Additionally, the infrastructure provided by \texttt{gymnax} allows for environments to be vectorized. The speedup from
using this library cannot be overstated. The SB3 PPO implementation ran at around 150 actions per second. After rewriting
some of the code to make it work with \texttt{purejaxrl}, the effective action rate\footnote{I ran 2048 environments in parallel.}
was nearly $238{,}000$ actions per second\footnote{It's likely that performance with SB3 could have been improved, but I was struggling to figure out exactly how.}.
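The vectorization that makes this possible is just \texttt{jax.vmap} over the environment's pure reset and step functions; a rough sketch (the action shape here is an assumption about \texttt{Snax}):
\begin{verbatim}
import jax
import jax.numpy as jnp
from solarcarsim.simv2 import Snax

env = Snax()
params = env.default_params
n_envs = 2048

# gymnax environments are pure functions of (key, state, action, params),
# so a whole batch can be stepped with a single vmap + jit.
v_reset = jax.jit(jax.vmap(env.reset, in_axes=(0, None)))
v_step = jax.jit(jax.vmap(env.step, in_axes=(0, 0, 0, None)))

obs, state = v_reset(jax.random.split(jax.random.PRNGKey(0), n_envs), params)
actions = jnp.zeros((n_envs, 1))  # placeholder policy output
step_keys = jax.random.split(jax.random.PRNGKey(1), n_envs)
obs, state, rew, done, info = v_step(step_keys, state, actions, params)
\end{verbatim}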
The episode returns after 50 million timesteps with a PPO model can be seen in Figure~\ref{fig:returns}. Each update step
is performed after collecting minibatches of rollouts based on the current policy. We can see a clean ascent at the start of training;
this is the agent learning to drive forward. After a certain point, the returns become noisy, most likely because the energy component
of the reward varies with the randomly generated terrain. A solution to this, which wasn't pursued due to lack of time, would be to compute the
``nominal energy'' use for travelling at $v_{avg}$. Energy consumption above the nominal use would be penalized, and consumption
below it would be heavily rewarded. Despite this, performance continued to improve, which is a good sign that the agent is
able to learn the underlying dynamics.
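A hypothetical sketch of the nominal-energy shaping described above, with made-up scaling factors:
\begin{verbatim}
import jax.numpy as jnp

def shaped_energy_term(energy, nominal_energy, bonus=2.0, penalty=1.0):
    # Reward energy savings relative to the constant-speed baseline more
    # heavily than overconsumption is penalized.
    delta = nominal_energy - energy  # positive when the agent beats baseline
    return jnp.where(delta >= 0.0, bonus * delta, penalty * delta)
\end{verbatim}
Here \texttt{nominal\_energy} would be obtained by rolling the dynamics model forward at $v_{avg}$ over the same terrain.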
\begin{figure}[H]
\centering
\includegraphics[width=0.8\textwidth]{PPO_results.pdf}
\caption{Episodic Returns during PPO training}
\label{fig:returns}
\end{figure}
Initially I thought that this was fairly impressive, but when I looked at individual rollouts the agent seemed to simply drive forward too fast.
Reworking the reward function might cause this to converge better.
I wish I had a plot of a rollout, but I kept running out of memory when trying to capture one.
\section{Discussion}
Analyze your results:
\begin{itemize}
\item Interpretation of findings
\item Limitations and challenges
\item Potential improvements
\end{itemize}
While the PPO performance was decent, it still left a significant amount of improvement on the table. Tuning the reward function
would probably help it find a better solution. One strategy that would help significantly is to pre-tune the model to output the
average speed by default, so the model doesn't have to learn that at the beginning. This is called Q-initialization and is a common
trick for problem spaces where an initial estimate exists and is easy to define. Perhaps the most important takeaway from this
work is the power of end-to-end Jax RL. \texttt{purejaxrl} is CleanRL levels of code clarity, with everything for an agent
being contained in one file, but surpassing Stable Baselines3 significantly in terms of performance. One drawback is that
the ecosystem is very new, so there was very little to reference when I was developing my simulator. Often there would be
an opaque error message that would yield no results on search engines, and would require digging into the Jax source code to diagnose.
Typically this was some misunderstanding about the inner works of Jax. Future work on this project would involve trying out
other agents, and comparing different reward functions. Adjusting the actor-critic network would also be an interesting avenue,
especially since a CNN will likely work well with wind and cloud information, which have both a spatial and temporal
axis\footnote{You can probably tell that the quality dropped off near the end - bunch of life things got in the way, so this didn't go as well as I'd hoped. Learned a lot though.}.
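A minimal version of the pre-tuning mentioned above, assuming a Flax actor whose final layer feeds a \texttt{tanh} (as in the CleanRL-style actors used here) and assuming $v_{avg}$ has been normalized into the action range, is to initialize the last layer so the untrained policy already outputs the average speed:
\begin{verbatim}
# Sketch: initialize a Flax/Linen actor so its default output is the average speed.
import jax.numpy as jnp
import flax.linen as nn

class BiasedActor(nn.Module):
    action_dim: int = 1
    v_avg_normalized: float = 0.5   # placeholder: v_avg mapped into the action range

    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(256)(x))
        x = nn.relu(nn.Dense(256)(x))
        # Final layer: zero weights plus a bias of arctanh(v_avg), so that before
        # any training tanh(output) ~= v_avg regardless of the observation.
        x = nn.Dense(
            self.action_dim,
            kernel_init=nn.initializers.zeros,
            bias_init=nn.initializers.constant(jnp.arctanh(self.v_avg_normalized)),
        )(x)
        return nn.tanh(x)
\end{verbatim}
This only fixes the policy's starting point; pre-setting the critic's value estimates toward the return of that default behaviour would be the complementary half of the trick.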
\section{Conclusion}
We outline the design of a physics-based model of solar car races. We implement this model and create a simulation environment
for use with popular RL algorithm packages, and we demonstrate the performance and learning ability of these algorithms on our model.
Further work includes more accurate modelling, improved reward functions, and hyperparameter tuning.
\bibliography{references}
\bibliographystyle{plain}


@@ -0,0 +1,368 @@
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/td3/#td3_continuous_action_jaxpy
import os
import random
import time
from dataclasses import dataclass
import flax
import flax.linen as nn
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tyro
from flax.training.train_state import TrainState
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard.writer import SummaryWriter
@dataclass
class Args:
exp_name: str = os.path.basename(__file__)[: -len(".py")]
"""the name of this experiment"""
seed: int = 1
"""seed of the experiment"""
track: bool = False
"""if toggled, this experiment will be tracked with Weights and Biases"""
wandb_project_name: str = "cleanRL"
"""the wandb's project name"""
wandb_entity: str = None
"""the entity (team) of wandb's project"""
capture_video: bool = False
"""whether to capture videos of the agent performances (check out `videos` folder)"""
save_model: bool = False
"""whether to save model into the `runs/{run_name}` folder"""
upload_model: bool = False
"""whether to upload the saved model to huggingface"""
hf_entity: str = ""
"""the user or org name of the model repository from the Hugging Face Hub"""
# Algorithm specific arguments
env_id: str = "MountainCarContinuous-v0"
"""the id of the environment"""
total_timesteps: int = 1000000
"""total timesteps of the experiments"""
learning_rate: float = 3e-4
"""the learning rate of the optimizer"""
buffer_size: int = int(1e6)
"""the replay memory buffer size"""
gamma: float = 0.99
"""the discount factor gamma"""
tau: float = 0.005
"""target smoothing coefficient (default: 0.005)"""
batch_size: int = 256
"""the batch size of sample from the reply memory"""
policy_noise: float = 0.2
"""the scale of policy noise"""
exploration_noise: float = 0.1
"""the scale of exploration noise"""
learning_starts: int = 25_000
"""timestep to start learning"""
policy_frequency: int = 2
"""the frequency of training policy (delayed)"""
noise_clip: float = 0.5
"""noise clip parameter of the Target Policy Smoothing Regularization"""
def make_env(env_id, seed, idx, capture_video, run_name):
def thunk():
if capture_video and idx == 0:
env = gym.make(env_id, render_mode="rgb_array")
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
else:
env = gym.make(env_id)
env = gym.wrappers.RecordEpisodeStatistics(env)
env.action_space.seed(seed)
return env
return thunk
# ALGO LOGIC: initialize agent here:
class QNetwork(nn.Module):
@nn.compact
def __call__(self, x: jnp.ndarray, a: jnp.ndarray):
x = jnp.concatenate([x, a], -1)
x = nn.Dense(256)(x)
x = nn.relu(x)
x = nn.Dense(256)(x)
x = nn.relu(x)
x = nn.Dense(1)(x)
return x
class Actor(nn.Module):
action_dim: int
action_scale: jnp.ndarray
action_bias: jnp.ndarray
@nn.compact
def __call__(self, x):
x = nn.Dense(256)(x)
x = nn.relu(x)
x = nn.Dense(256)(x)
x = nn.relu(x)
x = nn.Dense(self.action_dim)(x)
x = nn.tanh(x)
x = x * self.action_scale + self.action_bias
return x
class TrainState(TrainState):
target_params: flax.core.FrozenDict
if __name__ == "__main__":
import stable_baselines3 as sb3
if sb3.__version__ < "2.0":
raise ValueError(
"""Ongoing migration: run the following command to install the new dependencies:
poetry run pip install "stable_baselines3==2.0.0a1"
"""
)
args = tyro.cli(Args)
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
import wandb
wandb.init(
project=args.wandb_project_name,
entity=args.wandb_entity,
sync_tensorboard=True,
config=vars(args),
name=run_name,
monitor_gym=True,
save_code=True,
)
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
"hyperparameters",
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
)
# TRY NOT TO MODIFY: seeding
random.seed(args.seed)
np.random.seed(args.seed)
key = jax.random.PRNGKey(args.seed)
key, actor_key, qf1_key, qf2_key = jax.random.split(key, 4)
# env setup
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"
max_action = float(envs.single_action_space.high[0])
envs.single_observation_space.dtype = np.float32
rb = ReplayBuffer(
args.buffer_size,
envs.single_observation_space,
envs.single_action_space,
device="cpu",
handle_timeout_termination=False,
)
# TRY NOT TO MODIFY: start the game
obs, _ = envs.reset(seed=args.seed)
actor = Actor(
action_dim=np.prod(envs.single_action_space.shape),
action_scale=jnp.array((envs.action_space.high - envs.action_space.low) / 2.0),
action_bias=jnp.array((envs.action_space.high + envs.action_space.low) / 2.0),
)
actor_state = TrainState.create(
apply_fn=actor.apply,
params=actor.init(actor_key, obs),
target_params=actor.init(actor_key, obs),
tx=optax.adam(learning_rate=args.learning_rate),
)
qf = QNetwork()
qf1_state = TrainState.create(
apply_fn=qf.apply,
params=qf.init(qf1_key, obs, envs.action_space.sample()),
target_params=qf.init(qf1_key, obs, envs.action_space.sample()),
tx=optax.adam(learning_rate=args.learning_rate),
)
qf2_state = TrainState.create(
apply_fn=qf.apply,
params=qf.init(qf2_key, obs, envs.action_space.sample()),
target_params=qf.init(qf2_key, obs, envs.action_space.sample()),
tx=optax.adam(learning_rate=args.learning_rate),
)
actor.apply = jax.jit(actor.apply)
qf.apply = jax.jit(qf.apply)
@jax.jit
def update_critic(
actor_state: TrainState,
qf1_state: TrainState,
qf2_state: TrainState,
observations: np.ndarray,
actions: np.ndarray,
next_observations: np.ndarray,
rewards: np.ndarray,
terminations: np.ndarray,
key: jnp.ndarray,
):
# TODO Maybe pre-generate a lot of random keys
# also check https://jax.readthedocs.io/en/latest/jax.random.html
key, noise_key = jax.random.split(key, 2)
clipped_noise = (
jnp.clip(
(jax.random.normal(noise_key, actions.shape) * args.policy_noise),
-args.noise_clip,
args.noise_clip,
)
* actor.action_scale
)
next_state_actions = jnp.clip(
actor.apply(actor_state.target_params, next_observations) + clipped_noise,
envs.single_action_space.low,
envs.single_action_space.high,
)
qf1_next_target = qf.apply(qf1_state.target_params, next_observations, next_state_actions).reshape(-1)
qf2_next_target = qf.apply(qf2_state.target_params, next_observations, next_state_actions).reshape(-1)
min_qf_next_target = jnp.minimum(qf1_next_target, qf2_next_target)
next_q_value = (rewards + (1 - terminations) * args.gamma * (min_qf_next_target)).reshape(-1)
def mse_loss(params):
qf_a_values = qf.apply(params, observations, actions).squeeze()
return ((qf_a_values - next_q_value) ** 2).mean(), qf_a_values.mean()
(qf1_loss_value, qf1_a_values), grads1 = jax.value_and_grad(mse_loss, has_aux=True)(qf1_state.params)
(qf2_loss_value, qf2_a_values), grads2 = jax.value_and_grad(mse_loss, has_aux=True)(qf2_state.params)
qf1_state = qf1_state.apply_gradients(grads=grads1)
qf2_state = qf2_state.apply_gradients(grads=grads2)
return (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key
@jax.jit
def update_actor(
actor_state: TrainState,
qf1_state: TrainState,
qf2_state: TrainState,
observations: np.ndarray,
):
def actor_loss(params):
return -qf.apply(qf1_state.params, observations, actor.apply(params, observations)).mean()
actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_state.params)
actor_state = actor_state.apply_gradients(grads=grads)
actor_state = actor_state.replace(
target_params=optax.incremental_update(actor_state.params, actor_state.target_params, args.tau)
)
qf1_state = qf1_state.replace(
target_params=optax.incremental_update(qf1_state.params, qf1_state.target_params, args.tau)
)
qf2_state = qf2_state.replace(
target_params=optax.incremental_update(qf2_state.params, qf2_state.target_params, args.tau)
)
return actor_state, (qf1_state, qf2_state), actor_loss_value
start_time = time.time()
for global_step in range(args.total_timesteps):
# ALGO LOGIC: put action logic here
if global_step < args.learning_starts:
actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
else:
actions = actor.apply(actor_state.params, obs)
actions = np.array(
[
(
jax.device_get(actions)[0]
+ np.random.normal(0, max_action * args.exploration_noise, size=envs.single_action_space.shape)
).clip(envs.single_action_space.low, envs.single_action_space.high)
]
)
# TRY NOT TO MODIFY: execute the game and log data.
next_obs, rewards, terminations, truncations, infos = envs.step(actions)
# TRY NOT TO MODIFY: record rewards for plotting purposes
if "final_info" in infos:
for info in infos["final_info"]:
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
break
# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
real_next_obs = next_obs.copy()
for idx, trunc in enumerate(truncations):
if trunc:
real_next_obs[idx] = infos["final_observation"][idx]
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)
# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
obs = next_obs
# ALGO LOGIC: training.
if global_step > args.learning_starts:
data = rb.sample(args.batch_size)
(qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key = update_critic(
actor_state,
qf1_state,
qf2_state,
data.observations.numpy(),
data.actions.numpy(),
data.next_observations.numpy(),
data.rewards.flatten().numpy(),
data.dones.flatten().numpy(),
key,
)
if global_step % args.policy_frequency == 0:
actor_state, (qf1_state, qf2_state), actor_loss_value = update_actor(
actor_state,
qf1_state,
qf2_state,
data.observations.numpy(),
)
if global_step % 100 == 0:
writer.add_scalar("losses/qf1_loss", qf1_loss_value.item(), global_step)
writer.add_scalar("losses/qf2_loss", qf2_loss_value.item(), global_step)
writer.add_scalar("losses/qf1_values", qf1_a_values.item(), global_step)
writer.add_scalar("losses/qf2_values", qf2_a_values.item(), global_step)
writer.add_scalar("losses/actor_loss", actor_loss_value.item(), global_step)
print("SPS:", int(global_step / (time.time() - start_time)))
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
if args.save_model:
model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
with open(model_path, "wb") as f:
f.write(
flax.serialization.to_bytes(
[
actor_state.params,
qf1_state.params,
qf2_state.params,
]
)
)
print(f"model saved to {model_path}")
from cleanrl_utils.evals.td3_jax_eval import evaluate
episodic_returns = evaluate(
model_path,
make_env,
args.env_id,
eval_episodes=10,
run_name=f"{run_name}-eval",
Model=(Actor, QNetwork),
exploration_noise=args.exploration_noise,
)
for idx, episodic_return in enumerate(episodic_returns):
writer.add_scalar("eval/episodic_return", episodic_return, idx)
if args.upload_model:
from cleanrl_utils.huggingface import push_to_hub
repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
push_to_hub(args, episodic_returns, repo_id, "TD3", f"runs/{run_name}", f"videos/{run_name}-eval")
envs.close()
writer.close()


@@ -66,12 +66,12 @@ def downslope_force(mass, theta):
return mass * 9.8 * jnp.sin(theta)
-@partial(jit, static_argnames=["crr"])
+@jit
def rolling_force(mass, theta, crr):
return normal_force(mass, theta) * crr
-@partial(jit, static_argnames=["area", "cd", "rho"])
+@jit
def drag_force(u, area, cd, rho):
return 0.5 * rho * jnp.pow(u, 2) * cd * area


@@ -93,7 +93,7 @@ class SolarRaceV1(gym.Env):
}
)
-self.action_space = gym.spaces.Box(-1.0, 1.0, shape=(1,)) # velocity, m/s
+self.action_space = gym.spaces.Box(0.0, 1.0, shape=(1,)) # velocity, m/s
def reset(self, *, seed = None, options = None):
@@ -117,7 +117,7 @@ class SolarRaceV1(gym.Env):
reward -= 600 - self._state[1][0]
reward += 1e-6 * (self._state[2][0]) # net energy is negative.
if jnp.all(self._state[1] > 600):
-reward -= 50000
+reward -= 500
truncated = True
return self._get_obs(), reward, terminated, truncated, {}


@@ -5,7 +5,7 @@ import jax
import jax.numpy as jnp
import chex
from flax import struct
-from jax import lax
+from jax import lax, vmap
from gymnax.environments import environment
from gymnax.environments import spaces
@@ -59,19 +59,6 @@ class Snax(environment.Environment[SimState, SimParams]):
dtype=jnp.float32,
)
return spaces.Box(low, high, shape=(shape,))
-# return spaces.Dict(
-# {
-# "position": spaces.Box(0.0, params.map_size, (), jnp.float32),
-# "realtime": spaces.Box(0.0, params.goal_time + 100, (), jnp.float32),
-# "energy": spaces.Box(-1e11, 0.0, (), jnp.float32),
-# "dist_to_goal": spaces.Box(0.0, params.goal_dist, (), jnp.float32),
-# "time_remaining": spaces.Box(0.0, params.goal_time, (), jnp.float32),
-# "upcoming_terrain": spaces.Box(
-# -1.0, 1.0, shape=(100,), dtype=jnp.float32
-# ),
-# # skip wind for now
-# }
-# )
def state_space(self, params: Optional[SimParams] = None) -> spaces.Dict:
if params is None:
@@ -170,9 +157,15 @@ class Snax(environment.Environment[SimState, SimParams]):
# ):
# reward -= 500
-# we have to vectorize that.
+# # we have to vectorize that.
+# reward = new_state.position / params.goal_dist # constant reward for moving forward
+# # reward for finishing
+# reward += (new_state.position >= params.goal_dist) * (100 + params.goal_time - new_state.realtime + 1e-7*new_state.energy)
+# # reward for failure
+# reward += (new_state.realtime >= params.goal_time) * -500
reward = new_state.position / params.goal_dist + \
-(new_state.position >= params.goal_dist) * (100 + params.goal_time - new_state.realtime + 1e-7*new_state.energy) + \
+(new_state.position >= params.goal_dist) * (100 + params.goal_time - new_state.realtime + 1e-6*new_state.energy) + \
(new_state.realtime >= params.goal_time) * -500
reward = reward.squeeze()
terminal = self.is_terminal(state, params)