final bit
This commit is contained in:
parent
70a659f468
commit
7ad7070129
110
notebooks/gymv2_jax.ipynb
Normal file
110
notebooks/gymv2_jax.ipynb
Normal file
|
@ -0,0 +1,110 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"from gymnasium.utils.env_checker import check_env as gym_check_env\n",
|
||||
"from stable_baselines3 import TD3\n",
|
||||
"from stable_baselines3.common.env_checker import check_env\n",
|
||||
"from gymnasium.wrappers.jax_to_numpy import JaxToNumpy\n",
|
||||
"from gymnasium.wrappers.vector import JaxToNumpy as VJaxToNumpy\n",
|
||||
"from gymnax.wrappers.gym import GymnaxToVectorGymWrapper, GymnaxToGymWrapper\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"from solarcarsim.simv2 import Snax\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"env = Snax()\n",
|
||||
"wrapped_env = GymnaxToGymWrapper(env)\n",
|
||||
"vector_gym_env = GymnaxToVectorGymWrapper(env)\n",
|
||||
"np_wrapper = JaxToNumpy(wrapped_env)\n",
|
||||
"np_vec_wrapper = VJaxToNumpy(vector_gym_env)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Using cuda device\n",
|
||||
"Wrapping the env with a `Monitor` wrapper\n",
|
||||
"Wrapping the env in a DummyVecEnv.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "ValueError",
|
||||
"evalue": "Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>. The error was:\nTypeError: unhashable type: 'DynamicJaxprTracer'\n",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[7], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m model \u001b[38;5;241m=\u001b[39m TD3(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMlpPolicy\u001b[39m\u001b[38;5;124m\"\u001b[39m, np_wrapper, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlearn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtotal_timesteps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1000\u001b[39;49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/td3/td3.py:222\u001b[0m, in \u001b[0;36mTD3.learn\u001b[0;34m(self, total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps, progress_bar)\u001b[0m\n\u001b[1;32m 213\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mlearn\u001b[39m(\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28mself\u001b[39m: SelfTD3,\n\u001b[1;32m 215\u001b[0m total_timesteps: \u001b[38;5;28mint\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 220\u001b[0m progress_bar: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 221\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m SelfTD3:\n\u001b[0;32m--> 222\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlearn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 223\u001b[0m \u001b[43m \u001b[49m\u001b[43mtotal_timesteps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtotal_timesteps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 224\u001b[0m \u001b[43m \u001b[49m\u001b[43mcallback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcallback\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 225\u001b[0m \u001b[43m \u001b[49m\u001b[43mlog_interval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlog_interval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 226\u001b[0m \u001b[43m \u001b[49m\u001b[43mtb_log_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtb_log_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 227\u001b[0m \u001b[43m \u001b[49m\u001b[43mreset_num_timesteps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreset_num_timesteps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 228\u001b[0m \u001b[43m \u001b[49m\u001b[43mprogress_bar\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprogress_bar\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 229\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/common/off_policy_algorithm.py:328\u001b[0m, in \u001b[0;36mOffPolicyAlgorithm.learn\u001b[0;34m(self, total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps, progress_bar)\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_freq, TrainFreq) \u001b[38;5;66;03m# check done in _setup_learn()\u001b[39;00m\n\u001b[1;32m 327\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_timesteps \u001b[38;5;241m<\u001b[39m total_timesteps:\n\u001b[0;32m--> 328\u001b[0m rollout \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcollect_rollouts\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 329\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 330\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_freq\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_freq\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 331\u001b[0m \u001b[43m \u001b[49m\u001b[43maction_noise\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maction_noise\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 332\u001b[0m \u001b[43m \u001b[49m\u001b[43mcallback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcallback\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 333\u001b[0m \u001b[43m \u001b[49m\u001b[43mlearning_starts\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlearning_starts\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 334\u001b[0m \u001b[43m \u001b[49m\u001b[43mreplay_buffer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreplay_buffer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 335\u001b[0m \u001b[43m \u001b[49m\u001b[43mlog_interval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlog_interval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 336\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 338\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m rollout\u001b[38;5;241m.\u001b[39mcontinue_training:\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/common/off_policy_algorithm.py:560\u001b[0m, in \u001b[0;36mOffPolicyAlgorithm.collect_rollouts\u001b[0;34m(self, env, callback, train_freq, replay_buffer, action_noise, learning_starts, log_interval)\u001b[0m\n\u001b[1;32m 557\u001b[0m actions, buffer_actions \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sample_action(learning_starts, action_noise, env\u001b[38;5;241m.\u001b[39mnum_envs)\n\u001b[1;32m 559\u001b[0m \u001b[38;5;66;03m# Rescale and perform action\u001b[39;00m\n\u001b[0;32m--> 560\u001b[0m new_obs, rewards, dones, infos \u001b[38;5;241m=\u001b[39m \u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mactions\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 562\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_timesteps \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m env\u001b[38;5;241m.\u001b[39mnum_envs\n\u001b[1;32m 563\u001b[0m num_collected_steps \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n",
|
||||
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/common/vec_env/base_vec_env.py:206\u001b[0m, in \u001b[0;36mVecEnv.step\u001b[0;34m(self, actions)\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;124;03mStep the environments with the given action\u001b[39;00m\n\u001b[1;32m 201\u001b[0m \n\u001b[1;32m 202\u001b[0m \u001b[38;5;124;03m:param actions: the action\u001b[39;00m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;124;03m:return: observation, reward, done, information\u001b[39;00m\n\u001b[1;32m 204\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 205\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstep_async(actions)\n\u001b[0;32m--> 206\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep_wait\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py:58\u001b[0m, in \u001b[0;36mDummyVecEnv.step_wait\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mstep_wait\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m VecEnvStepReturn:\n\u001b[1;32m 56\u001b[0m \u001b[38;5;66;03m# Avoid circular imports\u001b[39;00m\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m env_idx \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_envs):\n\u001b[0;32m---> 58\u001b[0m obs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbuf_rews[env_idx], terminated, truncated, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbuf_infos[env_idx] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menvs\u001b[49m\u001b[43m[\u001b[49m\u001b[43menv_idx\u001b[49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 59\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mactions\u001b[49m\u001b[43m[\u001b[49m\u001b[43menv_idx\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 60\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 61\u001b[0m \u001b[38;5;66;03m# convert to SB3 VecEnv api\u001b[39;00m\n\u001b[1;32m 62\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbuf_dones[env_idx] \u001b[38;5;241m=\u001b[39m terminated \u001b[38;5;129;01mor\u001b[39;00m truncated\n",
|
||||
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/common/monitor.py:94\u001b[0m, in \u001b[0;36mMonitor.step\u001b[0;34m(self, action)\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mneeds_reset:\n\u001b[1;32m 93\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTried to step environment that needs reset\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 94\u001b[0m observation, reward, terminated, truncated, info \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrewards\u001b[38;5;241m.\u001b[39mappend(\u001b[38;5;28mfloat\u001b[39m(reward))\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m terminated \u001b[38;5;129;01mor\u001b[39;00m truncated:\n",
|
||||
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/gymnasium/wrappers/jax_to_numpy.py:166\u001b[0m, in \u001b[0;36mJaxToNumpy.step\u001b[0;34m(self, action)\u001b[0m\n\u001b[1;32m 157\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Transforms the action to a jax array .\u001b[39;00m\n\u001b[1;32m 158\u001b[0m \n\u001b[1;32m 159\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[38;5;124;03m A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info.\u001b[39;00m\n\u001b[1;32m 164\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 165\u001b[0m jax_action \u001b[38;5;241m=\u001b[39m numpy_to_jax(action)\n\u001b[0;32m--> 166\u001b[0m obs, reward, terminated, truncated, info \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mjax_action\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\n\u001b[1;32m 169\u001b[0m jax_to_numpy(obs),\n\u001b[1;32m 170\u001b[0m \u001b[38;5;28mfloat\u001b[39m(reward),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 173\u001b[0m jax_to_numpy(info),\n\u001b[1;32m 174\u001b[0m )\n",
|
||||
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/gymnax/wrappers/gym.py:70\u001b[0m, in \u001b[0;36mGymnaxToGymWrapper.step\u001b[0;34m(self, action)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Step environment, follow new step API.\"\"\"\u001b[39;00m\n\u001b[1;32m 69\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrng, step_key \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrng)\n\u001b[0;32m---> 70\u001b[0m o, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39menv_state, r, d, info \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_env\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 71\u001b[0m \u001b[43m \u001b[49m\u001b[43mstep_key\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv_state\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maction\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv_params\u001b[49m\n\u001b[1;32m 72\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m o, r, d, d, info\n",
|
||||
" \u001b[0;31m[... skipping hidden 15 frame]\u001b[0m\n",
|
||||
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/gymnax/environments/environment.py:45\u001b[0m, in \u001b[0;36mEnvironment.step\u001b[0;34m(self, key, state, action, params)\u001b[0m\n\u001b[1;32m 43\u001b[0m params \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefault_params\n\u001b[1;32m 44\u001b[0m key, key_reset \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39msplit(key)\n\u001b[0;32m---> 45\u001b[0m obs_st, state_st, reward, done, info \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep_env\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maction\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 46\u001b[0m obs_re, state_re \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreset_env(key_reset, params)\n\u001b[1;32m 47\u001b[0m \u001b[38;5;66;03m# Auto-reset environment based on termination\u001b[39;00m\n",
|
||||
"File \u001b[0;32m~/Documents/Code/solarcarsim/src/solarcarsim/simv2.py:138\u001b[0m, in \u001b[0;36mSnax.step_env\u001b[0;34m(self, key, state, action, params)\u001b[0m\n\u001b[1;32m 136\u001b[0m theta \u001b[38;5;241m=\u001b[39m state\u001b[38;5;241m.\u001b[39mslope[pos]\n\u001b[1;32m 137\u001b[0m velocity \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39marray([action \u001b[38;5;241m*\u001b[39m params\u001b[38;5;241m.\u001b[39mcar\u001b[38;5;241m.\u001b[39mmax_speed])\u001b[38;5;241m.\u001b[39msqueeze()\n\u001b[0;32m--> 138\u001b[0m dragf \u001b[38;5;241m=\u001b[39m \u001b[43msim\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdrag_force\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 139\u001b[0m \u001b[43m \u001b[49m\u001b[43mvelocity\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcar\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrontal_area\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcar\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdrag_coeff\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1.184\u001b[39;49m\n\u001b[1;32m 140\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 141\u001b[0m rollf \u001b[38;5;241m=\u001b[39m sim\u001b[38;5;241m.\u001b[39mrolling_force(params\u001b[38;5;241m.\u001b[39mcar\u001b[38;5;241m.\u001b[39mmass, theta, params\u001b[38;5;241m.\u001b[39mcar\u001b[38;5;241m.\u001b[39mrolling_coeff)\n\u001b[1;32m 142\u001b[0m hillf \u001b[38;5;241m=\u001b[39m sim\u001b[38;5;241m.\u001b[39mdownslope_force(params\u001b[38;5;241m.\u001b[39mcar\u001b[38;5;241m.\u001b[39mmass, theta)\n",
|
||||
" \u001b[0;31m[... skipping hidden 3 frame]\u001b[0m\n",
|
||||
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/jax/_src/pjit.py:768\u001b[0m, in \u001b[0;36m_infer_params\u001b[0;34m(fun, ji, args, kwargs)\u001b[0m\n\u001b[1;32m 764\u001b[0m p, args_flat \u001b[38;5;241m=\u001b[39m _infer_params_impl(fun, ji, pjit_mesh, resource_env, args,\n\u001b[1;32m 765\u001b[0m kwargs, in_avals\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m 766\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m p, p\u001b[38;5;241m.\u001b[39mconsts \u001b[38;5;241m+\u001b[39m args_flat\n\u001b[0;32m--> 768\u001b[0m entry \u001b[38;5;241m=\u001b[39m \u001b[43m_infer_params_cached\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 769\u001b[0m \u001b[43m \u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mji\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msignature\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mavals\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpjit_mesh\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mresource_env\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 770\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m entry\u001b[38;5;241m.\u001b[39mpjit_params \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 771\u001b[0m p, args_flat \u001b[38;5;241m=\u001b[39m _infer_params_impl(\n\u001b[1;32m 772\u001b[0m fun, ji, pjit_mesh, resource_env, args, kwargs, in_avals\u001b[38;5;241m=\u001b[39mavals)\n",
|
||||
"\u001b[0;31mValueError\u001b[0m: Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>. The error was:\nTypeError: unhashable type: 'DynamicJaxprTracer'\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"\n",
|
||||
"model = TD3(\"MlpPolicy\", np_wrapper, verbose=1)\n",
|
||||
"model.learn(total_timesteps=1000)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
273
pdm.lock
273
pdm.lock
|
@ -5,7 +5,7 @@
|
|||
groups = ["default", "dev"]
|
||||
strategy = ["inherit_metadata"]
|
||||
lock_version = "4.5.0"
|
||||
content_hash = "sha256:a3b65f863c554725c33d452fd759776141740661fa3555d306ed08563a7e16e2"
|
||||
content_hash = "sha256:2f7c4bee801973a3b7856ba0707891eb01fd05659948707f44be4aa302e5dabd"
|
||||
|
||||
[[metadata.targets]]
|
||||
requires_python = ">=3.12,<3.13"
|
||||
|
@ -237,6 +237,27 @@ files = [
|
|||
{file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "distrax"
|
||||
version = "0.1.5"
|
||||
requires_python = ">=3.9"
|
||||
summary = "Distrax: Probability distributions in JAX."
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"absl-py>=0.9.0",
|
||||
"chex>=0.1.8",
|
||||
"jax>=0.1.55",
|
||||
"jaxlib>=0.1.67",
|
||||
"numpy>=1.23.0",
|
||||
"setuptools; python_version >= \"3.12\"",
|
||||
"tensorflow-probability>=0.15.0",
|
||||
]
|
||||
files = [
|
||||
{file = "distrax-0.1.5-py3-none-any.whl", hash = "sha256:5020f4b53a9a480d019c12e44292fbacb7de857cce478bc594dacf29519c61b7"},
|
||||
{file = "distrax-0.1.5.tar.gz", hash = "sha256:ec41522d389af69efedc8d475a7e6d8f229429c00f2140dcd641feacf7e21948"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dm-tree"
|
||||
version = "0.1.8"
|
||||
|
@ -254,6 +275,18 @@ files = [
|
|||
{file = "dm_tree-0.1.8-cp312-cp312-win_amd64.whl", hash = "sha256:96a548a406a6fb15fe58f6a30a57ff2f2aafbf25f05afab00c8f5e5977b6c715"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "docstring-parser"
|
||||
version = "0.16"
|
||||
requires_python = ">=3.6,<4.0"
|
||||
summary = "Parse Python docstrings in reST, Google and Numpydoc format"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "docstring_parser-0.16-py3-none-any.whl", hash = "sha256:bf0a1387354d3691d102edef7ec124f219ef639982d096e26e3b60aeffa90637"},
|
||||
{file = "docstring_parser-0.16.tar.gz", hash = "sha256:538beabd0af1e2db0146b6bd3caa526c35a34d61af9fd2887f3a8a27a739aa6e"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "etils"
|
||||
version = "1.11.0"
|
||||
|
@ -408,6 +441,26 @@ files = [
|
|||
{file = "gast-0.6.0.tar.gz", hash = "sha256:88fc5300d32c7ac6ca7b515310862f71e6fdf2c029bbec7c66c0f5dd47b6b1fb"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "grpcio"
|
||||
version = "1.68.1"
|
||||
requires_python = ">=3.8"
|
||||
summary = "HTTP/2-based RPC framework"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "grpcio-1.68.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:8829924fffb25386995a31998ccbbeaa7367223e647e0122043dfc485a87c666"},
|
||||
{file = "grpcio-1.68.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:3aed6544e4d523cd6b3119b0916cef3d15ef2da51e088211e4d1eb91a6c7f4f1"},
|
||||
{file = "grpcio-1.68.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:4efac5481c696d5cb124ff1c119a78bddbfdd13fc499e3bc0ca81e95fc573684"},
|
||||
{file = "grpcio-1.68.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ab2d912ca39c51f46baf2a0d92aa265aa96b2443266fc50d234fa88bf877d8e"},
|
||||
{file = "grpcio-1.68.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95c87ce2a97434dffe7327a4071839ab8e8bffd0054cc74cbe971fba98aedd60"},
|
||||
{file = "grpcio-1.68.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:e4842e4872ae4ae0f5497bf60a0498fa778c192cc7a9e87877abd2814aca9475"},
|
||||
{file = "grpcio-1.68.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:255b1635b0ed81e9f91da4fcc8d43b7ea5520090b9a9ad9340d147066d1d3613"},
|
||||
{file = "grpcio-1.68.1-cp312-cp312-win32.whl", hash = "sha256:7dfc914cc31c906297b30463dde0b9be48e36939575eaf2a0a22a8096e69afe5"},
|
||||
{file = "grpcio-1.68.1-cp312-cp312-win_amd64.whl", hash = "sha256:a0c8ddabef9c8f41617f213e527254c41e8b96ea9d387c632af878d05db9229c"},
|
||||
{file = "grpcio-1.68.1.tar.gz", hash = "sha256:44a8502dd5de653ae6a73e2de50a401d84184f0331d0ac3daeb044e66d5c5054"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gym"
|
||||
version = "0.26.2"
|
||||
|
@ -622,6 +675,76 @@ files = [
|
|||
{file = "jax-0.4.37.tar.gz", hash = "sha256:7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jax-cuda12-pjrt"
|
||||
version = "0.4.36"
|
||||
summary = "JAX XLA PJRT Plugin for NVIDIA GPUs"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "jax_cuda12_pjrt-0.4.36-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1dfc0bec0820ba801b61e9421064b6e58238c430b4ad8f54043323d93c0217c6"},
|
||||
{file = "jax_cuda12_pjrt-0.4.36-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e3c3705d8db7d63da9abfaebf06f5cd0667f5acb0748a5c5eb00d80041e922ed"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jax-cuda12-plugin"
|
||||
version = "0.4.36"
|
||||
requires_python = ">=3.10"
|
||||
summary = "JAX Plugin for NVIDIA GPUs"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"jax-cuda12-pjrt==0.4.36",
|
||||
]
|
||||
files = [
|
||||
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:6a0b0c2bdc1da2eea2c20723a1e8f39b3cda67d24c665de936647e8091f5790d"},
|
||||
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:5d4727fb519fedc06a9a984d5a0714804d81ef126a2cb60cefd5cbc4a3ea2627"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jax-cuda12-plugin"
|
||||
version = "0.4.36"
|
||||
extras = ["with_cuda"]
|
||||
requires_python = ">=3.10"
|
||||
summary = "JAX Plugin for NVIDIA GPUs"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"jax-cuda12-plugin==0.4.36",
|
||||
"nvidia-cublas-cu12>=12.1.3.1",
|
||||
"nvidia-cuda-cupti-cu12>=12.1.105",
|
||||
"nvidia-cuda-nvcc-cu12>=12.6.85",
|
||||
"nvidia-cuda-runtime-cu12>=12.1.105",
|
||||
"nvidia-cudnn-cu12<10.0,>=9.1",
|
||||
"nvidia-cufft-cu12>=11.0.2.54",
|
||||
"nvidia-cusolver-cu12>=11.4.5.107",
|
||||
"nvidia-cusparse-cu12>=12.1.0.106",
|
||||
"nvidia-nccl-cu12>=2.18.1",
|
||||
"nvidia-nvjitlink-cu12>=12.1.105",
|
||||
]
|
||||
files = [
|
||||
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:6a0b0c2bdc1da2eea2c20723a1e8f39b3cda67d24c665de936647e8091f5790d"},
|
||||
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:5d4727fb519fedc06a9a984d5a0714804d81ef126a2cb60cefd5cbc4a3ea2627"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jax"
|
||||
version = "0.4.37"
|
||||
extras = ["cuda12"]
|
||||
requires_python = ">=3.10"
|
||||
summary = "Differentiate, compile, and transform Numpy code."
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"jax-cuda12-plugin[with_cuda]<=0.4.37,>=0.4.36",
|
||||
"jax==0.4.37",
|
||||
"jaxlib==0.4.36",
|
||||
]
|
||||
files = [
|
||||
{file = "jax-0.4.37-py3-none-any.whl", hash = "sha256:bdc0686d7e5a944e2d38026eae632214d98dd2d91869cbcedbf1c11298ae3e3e"},
|
||||
{file = "jax-0.4.37.tar.gz", hash = "sha256:7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jaxlib"
|
||||
version = "0.4.36"
|
||||
|
@ -737,6 +860,21 @@ files = [
|
|||
{file = "kiwisolver-1.4.7.tar.gz", hash = "sha256:9893ff81bd7107f7b685d3017cc6583daadb4fc26e4a888350df530e41980a60"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "markdown"
|
||||
version = "3.7"
|
||||
requires_python = ">=3.8"
|
||||
summary = "Python implementation of John Gruber's Markdown."
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"importlib-metadata>=4.4; python_version < \"3.10\"",
|
||||
]
|
||||
files = [
|
||||
{file = "Markdown-3.7-py3-none-any.whl", hash = "sha256:7eb6df5690b81a1d7942992c97fad2938e956e79df20cbc6186e9c3a77b1c803"},
|
||||
{file = "markdown-3.7.tar.gz", hash = "sha256:2ae2471477cfd02dbbf038d5d9bc226d40def84b4fe2986e49b59b6b472bbed2"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "markdown-it-py"
|
||||
version = "3.0.0"
|
||||
|
@ -937,7 +1075,7 @@ version = "12.4.5.8"
|
|||
requires_python = ">=3"
|
||||
summary = "CUBLAS native runtime libraries"
|
||||
groups = ["default"]
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3"},
|
||||
{file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b"},
|
||||
|
@ -950,13 +1088,26 @@ version = "12.4.127"
|
|||
requires_python = ">=3"
|
||||
summary = "CUDA profiling tools runtime libs."
|
||||
groups = ["default"]
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a"},
|
||||
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb"},
|
||||
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nvidia-cuda-nvcc-cu12"
|
||||
version = "12.6.85"
|
||||
requires_python = ">=3"
|
||||
summary = "CUDA nvcc"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d75d9d74599f4d7c0865df19ed21b739e6cb77a6497a3f73d6f61e8038a765e4"},
|
||||
{file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5d2edd5531b13e3daac8ffee9fc2b70a147e6088b2af2565924773d63d36d294"},
|
||||
{file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-win_amd64.whl", hash = "sha256:aa04742337973dcb5bcccabb590edc8834c60ebfaf971847888d24ffef6c46b5"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nvidia-cuda-nvrtc-cu12"
|
||||
version = "12.4.127"
|
||||
|
@ -976,7 +1127,7 @@ version = "12.4.127"
|
|||
requires_python = ">=3"
|
||||
summary = "CUDA Runtime native Libraries"
|
||||
groups = ["default"]
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3"},
|
||||
{file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5"},
|
||||
|
@ -989,7 +1140,7 @@ version = "9.1.0.70"
|
|||
requires_python = ">=3"
|
||||
summary = "cuDNN runtime libraries"
|
||||
groups = ["default"]
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"nvidia-cublas-cu12",
|
||||
]
|
||||
|
@ -1004,7 +1155,7 @@ version = "11.2.1.3"
|
|||
requires_python = ">=3"
|
||||
summary = "CUFFT native runtime libraries"
|
||||
groups = ["default"]
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"nvidia-nvjitlink-cu12",
|
||||
]
|
||||
|
@ -1033,7 +1184,7 @@ version = "11.6.1.9"
|
|||
requires_python = ">=3"
|
||||
summary = "CUDA solver native runtime libraries"
|
||||
groups = ["default"]
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"nvidia-cublas-cu12",
|
||||
"nvidia-cusparse-cu12",
|
||||
|
@ -1051,7 +1202,7 @@ version = "12.3.1.170"
|
|||
requires_python = ">=3"
|
||||
summary = "CUSPARSE native runtime libraries"
|
||||
groups = ["default"]
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"nvidia-nvjitlink-cu12",
|
||||
]
|
||||
|
@ -1067,7 +1218,7 @@ version = "2.21.5"
|
|||
requires_python = ">=3"
|
||||
summary = "NVIDIA Collective Communication Library (NCCL) Runtime"
|
||||
groups = ["default"]
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0"},
|
||||
]
|
||||
|
@ -1078,7 +1229,7 @@ version = "12.4.127"
|
|||
requires_python = ">=3"
|
||||
summary = "Nvidia JIT LTO Library"
|
||||
groups = ["default"]
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"},
|
||||
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
|
||||
|
@ -1759,6 +1910,18 @@ files = [
|
|||
{file = "shiboken6-6.8.0.2-cp39-abi3-win_amd64.whl", hash = "sha256:b11e750e696bb565d897e0f5836710edfb86bd355f87b09988bd31b2aad404d3"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "shtab"
|
||||
version = "1.7.1"
|
||||
requires_python = ">=3.7"
|
||||
summary = "Automagic shell tab completion for Python CLI applications"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "shtab-1.7.1-py3-none-any.whl", hash = "sha256:32d3d2ff9022d4c77a62492b6ec875527883891e33c6b479ba4d41a51e259983"},
|
||||
{file = "shtab-1.7.1.tar.gz", hash = "sha256:4e4bcb02eeb82ec45920a5d0add92eac9c9b63b2804c9196c1f1fdc2d039243c"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "simplejson"
|
||||
version = "3.19.3"
|
||||
|
@ -1847,6 +2010,42 @@ files = [
|
|||
{file = "sympy-1.13.1.tar.gz", hash = "sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tensorboard"
|
||||
version = "2.18.0"
|
||||
requires_python = ">=3.9"
|
||||
summary = "TensorBoard lets you watch Tensors Flow"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"absl-py>=0.4",
|
||||
"grpcio>=1.48.2",
|
||||
"markdown>=2.6.8",
|
||||
"numpy>=1.12.0",
|
||||
"packaging",
|
||||
"protobuf!=4.24.0,>=3.19.6",
|
||||
"setuptools>=41.0.0",
|
||||
"six>1.9",
|
||||
"tensorboard-data-server<0.8.0,>=0.7.0",
|
||||
"werkzeug>=1.0.1",
|
||||
]
|
||||
files = [
|
||||
{file = "tensorboard-2.18.0-py3-none-any.whl", hash = "sha256:107ca4821745f73e2aefa02c50ff70a9b694f39f790b11e6f682f7d326745eab"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tensorboard-data-server"
|
||||
version = "0.7.2"
|
||||
requires_python = ">=3.7"
|
||||
summary = "Fast data loading for TensorBoard"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"},
|
||||
{file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"},
|
||||
{file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tensorflow-probability"
|
||||
version = "0.25.0"
|
||||
|
@ -1997,6 +2196,22 @@ files = [
|
|||
{file = "triton-3.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8182f42fd8080a7d39d666814fa36c5e30cc00ea7eeeb1a2983dbb4c99a0fdc"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typeguard"
|
||||
version = "4.4.1"
|
||||
requires_python = ">=3.9"
|
||||
summary = "Run-time type checker for Python"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"importlib-metadata>=3.6; python_version < \"3.10\"",
|
||||
"typing-extensions>=4.10.0",
|
||||
]
|
||||
files = [
|
||||
{file = "typeguard-4.4.1-py3-none-any.whl", hash = "sha256:9324ec07a27ec67fc54a9c063020ca4c0ae6abad5e9f0f9804ca59aee68c6e21"},
|
||||
{file = "typeguard-4.4.1.tar.gz", hash = "sha256:0d22a89d00b453b47c49875f42b6601b961757541a2e1e0ef517b6e24213c21b"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typing-extensions"
|
||||
version = "4.12.2"
|
||||
|
@ -2009,6 +2224,29 @@ files = [
|
|||
{file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tyro"
|
||||
version = "0.9.2"
|
||||
requires_python = ">=3.7"
|
||||
summary = "CLI interfaces & config objects, from types"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"backports-cached-property>=1.0.2; python_version < \"3.8\"",
|
||||
"colorama>=0.4.0; platform_system == \"Windows\"",
|
||||
"docstring-parser>=0.16",
|
||||
"eval-type-backport>=0.1.3; python_version < \"3.10\"",
|
||||
"rich>=11.1.0",
|
||||
"shtab>=1.5.6",
|
||||
"typeguard>=4.0.0",
|
||||
"typing-extensions>=4.7.0; python_version < \"3.8\"",
|
||||
"typing-extensions>=4.9.0; python_version >= \"3.8\"",
|
||||
]
|
||||
files = [
|
||||
{file = "tyro-0.9.2-py3-none-any.whl", hash = "sha256:f7c301b30b1ac7b18672f234e45013786c494d64c0e3621b25b8414637af8f90"},
|
||||
{file = "tyro-0.9.2.tar.gz", hash = "sha256:692687e07c1ed35cc3a841e8c4a188424023f16bdef37f2d9c23cbeb8a3b51aa"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tzdata"
|
||||
version = "2024.2"
|
||||
|
@ -2063,6 +2301,21 @@ files = [
|
|||
{file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "werkzeug"
|
||||
version = "3.1.3"
|
||||
requires_python = ">=3.9"
|
||||
summary = "The comprehensive WSGI web application library."
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"MarkupSafe>=2.1.1",
|
||||
]
|
||||
files = [
|
||||
{file = "werkzeug-3.1.3-py3-none-any.whl", hash = "sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e"},
|
||||
{file = "werkzeug-3.1.3.tar.gz", hash = "sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zipp"
|
||||
version = "3.21.0"
|
||||
|
|
|
@ -5,7 +5,7 @@ description = "A solar car racing simulation library and GUI tool"
|
|||
authors = [
|
||||
{name = "saji", email = "saji@saji.dev"},
|
||||
]
|
||||
dependencies = ["pyqtgraph>=0.13.7", "jax>=0.4.37", "pytest>=8.3.3", "pyside6>=6.8.0.2", "matplotlib>=3.9.2", "gymnasium[jax]>=1.0.0", "pyvista>=0.44.2", "pyvistaqt>=0.11.1", "stable-baselines3>=2.4.0", "gymnax>=0.0.8", "sbx-rl>=0.18.0"]
|
||||
dependencies = ["pyqtgraph>=0.13.7", "jax[cuda12]>=0.4.37", "pytest>=8.3.3", "pyside6>=6.8.0.2", "matplotlib>=3.9.2", "gymnasium[jax]>=1.0.0", "pyvista>=0.44.2", "pyvistaqt>=0.11.1", "stable-baselines3>=2.4.0", "gymnax>=0.0.8", "sbx-rl>=0.18.0", "tyro>=0.9.2", "tensorboard>=2.18.0", "distrax>=0.1.5"]
|
||||
requires-python = ">=3.10,<3.13"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT"}
|
||||
|
|
|
@ -7,21 +7,23 @@
|
|||
\usepackage{graphicx}
|
||||
\usepackage[margin=1in]{geometry}
|
||||
\usepackage{hyperref}
|
||||
\usepackage{pdfpages}
|
||||
\usepackage{algorithm}
|
||||
\usepackage{algorithmic}
|
||||
\usepackage{float}
|
||||
\usepackage{booktabs}
|
||||
\usepackage{caption}
|
||||
\usepackage{subcaption}
|
||||
\usepackage{tikz}
|
||||
|
||||
% Custom commands
|
||||
\newcommand{\sectionheading}[1]{\noindent\textbf{#1}}
|
||||
|
||||
% Title and author information
|
||||
\title{\Large{Your Project Title: A Study in Optimal Control and Reinforcement Learning}}
|
||||
\author{Your Name\\
|
||||
Course Name\\
|
||||
Institution}
|
||||
\title{\Large{Solarcarsim: A Solar Racing Environment for RL Agents}}
|
||||
\author{Saji Champlin\\
|
||||
EE5241
|
||||
}
|
||||
\date{\today}
|
||||
|
||||
\begin{document}
|
||||
|
@ -33,102 +35,273 @@ Solar Racing is a competition with the goal of creating highly efficient solar-a
|
|||
requires awareness and complex decision making to determine optimal speeds to exploit the environmental conditions, such as winds,
|
||||
cloud cover, and changes in elevation. We present an environment modelled on the dynamics involved for a race, including generated
|
||||
elevation and wind profiles. The model uses the \texttt{gymnasium} interface to allow it to be used by a variety of algorithms.
|
||||
We demonstrate a method of designing reward functions for multi-objective problems. The environment shows to be solvable by modern
|
||||
reinforcement learning algorithms.
|
||||
We demonstrate a method of designing reward functions for multi-objective problems. We show learning using an Jax-based PPO model.
|
||||
\end{abstract}
|
||||
|
||||
\section{Introduction}
|
||||
Start with a broad context of your problem area in optimal control/reinforcement learning. Then narrow down to your specific focus. Include:
|
||||
\begin{itemize}
|
||||
\item Problem motivation
|
||||
\item Brief overview of existing approaches
|
||||
\item Your specific contributions
|
||||
\item Paper organization
|
||||
\end{itemize}
|
||||
|
||||
Solar racing was invented in the early 90s as a technology incubator for high-efficiency motor vehicles. The first solar races were speed
|
||||
focused, however a style of race that focused on minimal energy use within a given route was developed to push focus towards vehicle efficiency.
|
||||
The goal of these races is to arrive at a destination within a given time frame, while using as little grid (non-solar) energy as possible.
|
||||
Aerodynamic drag is one of the most significant sources of energy consumption, along with elevation changes. The simplest policy to meet
|
||||
the constraints of on-time arrival is:
|
||||
$$
|
||||
V_{\text{avg}} = \frac{D}{T}
|
||||
$$
|
||||
Where $D$ is the distance needed to travel, and $T$ is the maximum allowed time.
|
||||
Optimal driving is a complex policy based on terrain slope, wind forecasts, and solar forecasts.
|
||||
|
||||
Direct solutions to find the global minima of the energy usage on a route
|
||||
segment are difficult to compute. Instead, we present a reinforcement learning
|
||||
environment that can be used to train RL agents to race efficiently given
|
||||
limited foresight. The environment simulates key components of the race, such
|
||||
as terrain and wind, as well as car dynamics. The simulator is written using
|
||||
the Jax\cite{jax2018github} library which enables computations to be offloaded to the GPU. We
|
||||
provide wrappers for the \texttt{gymnasium} API as well as a
|
||||
\texttt{purejaxrl}\cite{lu2022discovered} implementation which can train a PPO
|
||||
agent with millions of timesteps in several minutes. We present an exploration of reward function design with regards to sparsity,
|
||||
magnitude, and learning efficiency.
|
||||
|
||||
\section{Background}
|
||||
Provide necessary background on:
|
||||
\begin{itemize}
|
||||
\item Your specific application domain
|
||||
\item Relevant algorithms and methods
|
||||
\item Previous work in this area
|
||||
\end{itemize}
|
||||
Performance evaluation for solar races typically take the form of
|
||||
$$
|
||||
S = D/E \times T
|
||||
$$
|
||||
Where $S$ is the score, $D$ is the distance travelled, $E$ is the energy consumed, and $T$ is the speed derating.
|
||||
The speed derate is calculated based on a desired average speed throughout the race. If average velocity is at or above the target,
|
||||
$T=1$, however $T$ approaches $0$ exponentially as the average velocity goes below the target.
|
||||
Based on this metric we conclude that the optimal score:
|
||||
\begin{enumerate}
|
||||
\item Maintains an average speed $V_{\text{avg}}$ as required by the derate.
|
||||
\item Minimizes energy usage otherwise.
|
||||
\end{enumerate}
|
||||
The simplest control to meet the constraints of on-time arrival is:
|
||||
$$
|
||||
V_{\text{avg}} = \frac{D_{goal}}{T_{goal}}
|
||||
$$
|
||||
Where $D_{goal}$ is the distance needed to travel, and $T_{goal}$ is the maximum allowed time. The average speed is nearly-optimal in most cases, but
|
||||
is not a globally optimal solution. Consider a small - there is much more energy being used from the battery when going uphill,
|
||||
but the same energy is returned to the car going downhill. Losses in the vehicle dictate that it is more effective to drive slowly
|
||||
up the hill, and speed up on the descent. The decision is further complicated by wind and cloud cover, which can aid or neuter the
|
||||
performance of the car. It is therefore of great interest for solar racing teams to have advanced strategies that can effectively
|
||||
traverse the terrain while minimizing environmental resistances.
|
||||
|
||||
Existing research on this subject is limited, as advanced solar car strategy is a competitive differentiator and is usually kept secret.
|
||||
However, the author knows that most of the work on this subject involves use of Modelica or similar acausal system simulators, and
|
||||
non-linear solvers that use multi-starts to attempt to find the global optimum. Other methods include exhaustive search, genetic algorithms,
|
||||
and Big Bang-Big Crunch optimization\cite{heuristicsolar}.
|
||||
|
||||
We start by analyzing a simple force-based model of the car, and then connect this to an energy system using motor equations. We generate
|
||||
a simulated environment including terrain and wind. Then, we develop a reward system that encapsulates the goals of the environment.
|
||||
Finally, we train off-the-shelf RL models from Stable Baselines3 and purejaxrl to show learning on the environment.
|
||||
|
||||
\section{Methodology}
|
||||
Describe your approach in detail:
|
||||
\begin{itemize}
|
||||
\item Problem formulation
|
||||
\item Algorithm description
|
||||
\item Implementation details
|
||||
\end{itemize}
|
||||
% \begin{tikzpicture}[scale=1.5]
|
||||
% % Define slope angle
|
||||
% \def\angle{30}
|
||||
%
|
||||
% % Draw ground/slope
|
||||
% \draw[thick] (-3,0) -- (3,0);
|
||||
% \draw[thick] (-2,0) -- (2.5,2);
|
||||
%
|
||||
% % Draw angle arc and label
|
||||
% \draw (-,0) arc (0:\angle:0.5);
|
||||
% \node at (0.4,0.3) {$\theta$};
|
||||
%
|
||||
% % Draw simplified car (rectangle)
|
||||
% \begin{scope}[rotate=\angle,shift={(0,1)}]
|
||||
% \draw[thick] (-0.8,-0.4) rectangle (0.8,0.4);
|
||||
% % Add wheels
|
||||
% \fill[black] (-0.6,-0.4) circle (0.1);
|
||||
% \fill[black] (0.6,-0.4) circle (0.1);
|
||||
% \end{scope}
|
||||
%
|
||||
% % Draw forces
|
||||
% % Weight force
|
||||
% \draw[->,thick,red] (0,1) -- (0,0) node[midway,right] {$W$};
|
||||
%
|
||||
% % Normal force
|
||||
% \draw[->,thick,blue] (0,1) -- ({-sin(\angle)},{cos(\angle)}) node[midway,above left] {$N$};
|
||||
%
|
||||
% % Downslope component
|
||||
% \draw[->,thick,green] (0,1) -- ({cos(\angle)},{sin(\angle)}) node[midway,below right] {$W\sin\theta$};
|
||||
% \end{tikzpicture}
|
||||
|
||||
% Example of how to include an algorithm
|
||||
\begin{algorithm}[H]
|
||||
\caption{Your Algorithm Name}
|
||||
\begin{algorithmic}[1]
|
||||
\STATE Initialize parameters
|
||||
\WHILE{not converged}
|
||||
\STATE Update step
|
||||
\ENDWHILE
|
||||
\RETURN Result
|
||||
\end{algorithmic}
|
||||
\end{algorithm}
|
||||
|
||||
\section{Experiments and Results}
|
||||
Present your findings:
|
||||
\begin{itemize}
|
||||
\item Experimental setup
|
||||
\item Results and analysis
|
||||
\item Comparison with baselines (if applicable)
|
||||
\end{itemize}
|
||||
|
||||
% Example of how to include figures
|
||||
\begin{figure}[H]
|
||||
\begin{tikzpicture}[scale=1.5]
|
||||
% Define slope angle
|
||||
\def\angle{30}
|
||||
|
||||
% Calculate some points for consistent geometry
|
||||
\def\slopeStart{-2}
|
||||
\def\slopeEnd{2}
|
||||
\def\slopeHeight{2.309} % tan(30°) * 2
|
||||
|
||||
% Draw ground (horizontal line)
|
||||
\draw[thick] (-3,0) -- (3,0);
|
||||
|
||||
% Draw slope
|
||||
\draw[thick] (\slopeStart,0) -- (\slopeEnd,\slopeHeight);
|
||||
|
||||
|
||||
% Calculate car center position on slope
|
||||
\def\carX{0} % Center position along x-axis
|
||||
\def\carY{1.6} % tan(30°) * carX + appropriate offset
|
||||
|
||||
% Draw car (rectangle) exactly on slope
|
||||
\begin{scope}[shift={(\carX,\carY)}, rotate=\angle]
|
||||
\draw[thick] (-0.6,-0.3) rectangle (0.6,0.3);
|
||||
% Add wheels aligned with slope
|
||||
\fill[black] (-0.45,-0.3) circle (0.08);
|
||||
\fill[black] (0.45,-0.3) circle (0.08);
|
||||
\draw[->,thick] (0,0) -- ++(-0.8, 0) node[left] {$F_{slope} + F_{drag} + F_{rolling}$};
|
||||
\draw[->,thick] (0,0) -- ++(0.8, 0) node[right] {$F_{motor}$};
|
||||
\node at (0,0) [circle,fill,inner sep=1.5pt]{};
|
||||
\end{scope}
|
||||
|
||||
% Draw forces from center of car
|
||||
% Center point of car for forces
|
||||
\coordinate (carCenter) at (\carX,\carY);
|
||||
|
||||
|
||||
\end{tikzpicture}
|
||||
\centering
|
||||
\caption{Description of your figure}
|
||||
\label{fig:example}
|
||||
\caption{Free body diagram showing relevant forces on a 2-dimensional car}
|
||||
\label{fig:freebody}
|
||||
\end{figure}
|
||||
|
||||
% Example of how to include tables
|
||||
\begin{table}[H]
|
||||
To model the vehicle dynamics, we simplify the system to a 2d plane. As seen in Figure~\ref{fig:freebody}, the forces on the car
|
||||
are due to intrinsic vehicle properties, current velocity, and environment conditions like slope and wind. If the velocity is held
|
||||
constant, we can assert that the sum of the forces on the car is zero:
|
||||
\begin{align}
|
||||
F_{drag} + F_{slope} + F_{rolling} + F_{motor} &= 0 \\
|
||||
F_{drag} &= \frac{1}{2} \rho v^2 C_dA \\
|
||||
F_{slope} &= mg\sin {\theta} \\
|
||||
F_{rolling} &= mg\cos {\theta} C_{rr}
|
||||
\end{align}
|
||||
The $F_{motor}$ term is modulated by the driver. In our case, we give the agent a simpler control mechanism with a normalized
|
||||
velocity control instead of a force-based control. This is written as $v = \alpha v_{max}$ where $\alpha$ is the action taken
|
||||
by the agent in $\left[-1,1\right]$.
|
||||
From the velocity, and the forces acting on the car, we can determine
|
||||
the power of the car using a simple $K_t$ model:
|
||||
\begin{align}
|
||||
\tau &= \left(F_{drag} + F_{slope} + F_{rolling}\right) r \\
|
||||
P_{motor} &= \tau v + R_{motor} \left(\frac{\tau}{K_t}\right)^2
|
||||
\end{align}
|
||||
The torque of motor is the sum of the outstanding forces times the wheel radius. $K_t$ is a motor parameter, as is $R_{motor}$.
|
||||
Both can be extracted from physical motors to simulate them, but simple "rule-of-thumb" numbers were used during development.
|
||||
The power of the motor is given in watts. Based on the time-step of the simulation, we can determine the energy consumed in joules
|
||||
with $W \times s = J$. A time-step of 1 second was chosen to accelerate simulation. Lower values result in reduced integration
|
||||
errors over time at the cost of longer episodes.
|
||||
|
||||
|
||||
\subsection{Environment Generation}
|
||||
It is important that our agent learns not just the optimal policy for a fixed course, but an approximate optimal policy
|
||||
for any course. To this end we must be able to generate a wide variety of terrain and wind scenarios. Perlin noise
|
||||
is typically used in this context. We use a 1D Perlin noise to generate the slope of the terrain, and then integrate the slope to create
|
||||
the elevation profile. Currently the elevation profile is unused, but it can be important for drag force due to changes in air pressure.
|
||||
This was done because differentiated Perlin noise is not smooth, and is not an accurate representation of slope. The wind was
|
||||
generated with a 2D Perlin noise, where one axis was time, and the other was position. The noise was blurred in the time-axis
|
||||
to ease the changes in wind at any given point.
|
||||
An example of the environment can be seen in Figure~\ref{fig:env_vis}.
|
||||
|
||||
|
||||
\begin{figure}[H]
|
||||
\centering
|
||||
\caption{Your Table Caption}
|
||||
\begin{tabular}{lcc}
|
||||
\toprule
|
||||
Method & Metric 1 & Metric 2 \\
|
||||
\midrule
|
||||
Approach 1 & Value & Value \\
|
||||
Approach 2 & Value & Value \\
|
||||
\bottomrule
|
||||
\end{tabular}
|
||||
\label{tab:results}
|
||||
\end{table}
|
||||
\includegraphics[width=\textwidth]{environment.pdf}
|
||||
\caption{Visualization of the generated environment}
|
||||
\label{fig:env_vis}
|
||||
\end{figure}
|
||||
|
||||
\subsection{Performance Evaluation}
|
||||
|
||||
To quantify agent performance, we must produce a single value reward. While multi-objective learning is an interesting
|
||||
subject\footnote{I especially wanted to do meta-rl using a neural net to compute a single reward from inputs} it is out of the scope
|
||||
of this project. Additionally, sparse rewards can significantly slow down learning. A poor reward function can prevent
|
||||
agents from approaching optimal policy. With these factors in mind, we use the following:
|
||||
\[
|
||||
R = x/D_{goal} + (x > D_{goal}) * \left(100 - E - 10(t - T_{goal})\right) + (t > T_{goal}) * -500
|
||||
\]
|
||||
To understand this, there are three major components: the continuous reward, which is rewarded at every step, and is the position of the car
|
||||
relative to the goal distance. The victory reward is a constant, minus the energy used and the early arrival penalty.
|
||||
This was added to help guide the agent towards arriving with as little time left as possible. Finally, there's a penalty for the time
|
||||
going above the goal time, as after that point the car is disqualified from the race.
|
||||
|
||||
It took a few iterations to find a reward metric that promoted fast learning. Some of these issues were exacerbated by the initially low
|
||||
performance when using stable baselines. A crucial part of the improvement was the energy loss only being applied during wins.
|
||||
This allowed the model to quickly learn to go forward to finish, after which refinement of speed could take
|
||||
place\footnote{I looked into Q-initialization but couldn't figure out a way to implement it easily.}.
|
||||
|
||||
|
||||
\subsection{State and Observation Spaces}
|
||||
|
||||
The complete state of the simulator is the position, velocity, and energy of the car, as well as the entire environment.
|
||||
These parameters are sufficient for a deterministic snapshot of the simulator. However, one goal of the project
|
||||
was to enable partial-observation of the system. To this end, we separate the observation space into a small snippet
|
||||
of the upcoming wind and slope. This also simplifies the agent calculations since the view of the environment is
|
||||
relative to its current position. The size of the view can be controlled as a parameter.
|
||||
|
||||
|
||||
\section{Experiments and Results}
|
||||
|
||||
An implementation of the aforementioned simulator was developed with Jax. Jax was chosen as it enables
|
||||
vectorization and optimization to improve performance. Additionally, Jax allows for gradients of any function
|
||||
to be computed, which is useful for certain classes of reinforcement learning. In our case, we didn't
|
||||
use this as there seemed to be very little available off the shelf.
|
||||
|
||||
Initially Stable Baselines was used since it is one of the most popular implemntations of common RL algorithms.
|
||||
Stable Baselines3\cite{stable-baselines3} is written in PyTorch\cite{Ansel_PyTorch_2_Faster_2024}, and uses the Gym\cite{gymnasium} format for environments. A basic Gym wrapper
|
||||
was created to connect SB3 to our environment.
|
||||
PPO was chosen as the RL algorithm as it is very simple, while still being effective \cite{proximalpolicyoptimization}
|
||||
The performance and convergence was very bad. This made
|
||||
it difficult to diagnose as the model would need potentially millions of steps before it would learn anything interesting.
|
||||
The primary performance loss was in the Jax/Numpy/PyTorch conversion, as this requires a CPU roundtrip.
|
||||
To combat this I found a Jax-based implementation of PPO called \texttt{purejaxrl}. This library is
|
||||
written in the style of CleanRL but instead uses pure Jax and an environment library called \texttt{gymnax}\cite{gymnax2022github}.
|
||||
The primary advantage of writing everything in Jax is that both the RL agent and the environment can be offloaded to the GPU.
|
||||
Additionally, the infrastructure provided by \texttt{gymnax} allows for environments to be vectorized. The speedup from
|
||||
using this library cannot be understated. The SB3 PPO implementation ran at around 150 actions per second. After rewriting
|
||||
some of the code to make it work with \texttt{purejaxrl}, the effective action rate\footnote{I ran 2048 environments in parallel}
|
||||
was nearly$238000$ actions per second\footnote{It's likely that performance with SB3 could have been improved, but I was struggling to figure out exactly how.}.
|
||||
|
||||
|
||||
The episode returns after 50 million timesteps with a PPO model can be seen in Figure~\ref{fig:returns}. Each update step
|
||||
is performed after collecting minibatches of rollouts based on the current policy. We can see a clean ascent at the start of training,
|
||||
this is the agent learning to drive forward. After a certain point, the returns become noisy. This is likely due to energy scoring
|
||||
being random based on the terrain. A solution to this, which wasn't pursued due to lack of time, would be to compute the
|
||||
"nominal energy" use based on travelling at $v_{avg}$. Energy consumption that was above the nominal use would be penalized, and
|
||||
below would be heavily rewarded. Despite this, performance continued to improve, which is a good sign for the agent being
|
||||
able to learn the underlying dynamics.
|
||||
|
||||
\begin{figure}[H]
|
||||
\centering
|
||||
\includegraphics[width=0.8\textwidth]{PPO_results.pdf}
|
||||
\caption{Episodic Returns during PPO training}
|
||||
\label{fig:returns}
|
||||
\end{figure}
|
||||
|
||||
|
||||
Initially I thought that this was actually pretty impressive, but I looked at an individual level
|
||||
and it seemed to just drive forward too fast. Reworking the reward function might cause this to converge better.
|
||||
I wish I had a graph, but I keep running out of memory when I try to capture a rollout.
|
||||
|
||||
|
||||
\section{Discussion}
|
||||
Analyze your results:
|
||||
\begin{itemize}
|
||||
\item Interpretation of findings
|
||||
\item Limitations and challenges
|
||||
\item Potential improvements
|
||||
\end{itemize}
|
||||
|
||||
While the PPO performance was decent, it still had a significant amount of improvement on the table. Tuning the reward function
|
||||
would probably help it find a solution better. One strategy that would help significantly is to pre-tune the model to output the
|
||||
average speed by default, so the model doesn't have to learn that at the beginning. This is called Q-initialization and is a common
|
||||
trick for problem spaces where an initial estimate exists and is easy to define. Perhaps the most important takeaway from this
|
||||
work is the power of end-to-end Jax RL. \texttt{purejaxrl} is CleanRL levels of code clarity, with everything for an agent
|
||||
being contained in one file, but surpassing Stable Baselines3 significantly in terms of performance. One drawback is that
|
||||
the ecosystem is very new, so there was very little to reference when I was developing my simulator. Often there would be
|
||||
an opaque error message that would yield no results on search engines, and would require digging into the Jax source code to diagnose.
|
||||
Typically this was some misunderstanding about the inner works of Jax. Future work on this project would involve trying out
|
||||
other agents, and comparing different reward functions. Adjusting the actor-critic network would also be an interesting avenue,
|
||||
especially since a CNN will likely work well with wind and cloud information, which have both a spatial and temporal
|
||||
axis\footnote{You can probably tell that the quality dropped off near the end - bunch of life things got in the way, so this didn't go as well as I'd hoped. Learned a lot though.}.
|
||||
|
||||
|
||||
|
||||
\section{Conclusion}
|
||||
Summarize your work:
|
||||
\begin{itemize}
|
||||
\item Key contributions
|
||||
\item Practical implications
|
||||
\item Future work directions
|
||||
\end{itemize}
|
||||
|
||||
We outline the design of a physics based model of solar car races. We implement this model and create a simulation environment
|
||||
for use with popular RL algorithm packages. We demonstrate the performance and learning ability of these algorithms on our model.
|
||||
Further work includes more accurate modelling, improved reward functions, and hyperparameter tuning.
|
||||
|
||||
\bibliography{references}
|
||||
\bibliographystyle{plain}
|
||||
|
|
368
src/solarcarsim/cleanrl_td3_jax.py
Normal file
368
src/solarcarsim/cleanrl_td3_jax.py
Normal file
|
@ -0,0 +1,368 @@
|
|||
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/td3/#td3_continuous_action_jaxpy
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
import flax
|
||||
import flax.linen as nn
|
||||
import gymnasium as gym
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import optax
|
||||
import tyro
|
||||
from flax.training.train_state import TrainState
|
||||
from stable_baselines3.common.buffers import ReplayBuffer
|
||||
from torch.utils.tensorboard.writer import SummaryWriter
|
||||
|
||||
|
||||
@dataclass
|
||||
class Args:
|
||||
exp_name: str = os.path.basename(__file__)[: -len(".py")]
|
||||
"""the name of this experiment"""
|
||||
seed: int = 1
|
||||
"""seed of the experiment"""
|
||||
track: bool = False
|
||||
"""if toggled, this experiment will be tracked with Weights and Biases"""
|
||||
wandb_project_name: str = "cleanRL"
|
||||
"""the wandb's project name"""
|
||||
wandb_entity: str = None
|
||||
"""the entity (team) of wandb's project"""
|
||||
capture_video: bool = False
|
||||
"""whether to capture videos of the agent performances (check out `videos` folder)"""
|
||||
save_model: bool = False
|
||||
"""whether to save model into the `runs/{run_name}` folder"""
|
||||
upload_model: bool = False
|
||||
"""whether to upload the saved model to huggingface"""
|
||||
hf_entity: str = ""
|
||||
"""the user or org name of the model repository from the Hugging Face Hub"""
|
||||
|
||||
# Algorithm specific arguments
|
||||
env_id: str = "MountainCarContinuous-v0"
|
||||
"""the id of the environment"""
|
||||
total_timesteps: int = 1000000
|
||||
"""total timesteps of the experiments"""
|
||||
learning_rate: float = 3e-4
|
||||
"""the learning rate of the optimizer"""
|
||||
buffer_size: int = int(1e6)
|
||||
"""the replay memory buffer size"""
|
||||
gamma: float = 0.99
|
||||
"""the discount factor gamma"""
|
||||
tau: float = 0.005
|
||||
"""target smoothing coefficient (default: 0.005)"""
|
||||
batch_size: int = 256
|
||||
"""the batch size of sample from the reply memory"""
|
||||
policy_noise: float = 0.2
|
||||
"""the scale of policy noise"""
|
||||
exploration_noise: float = 0.1
|
||||
"""the scale of exploration noise"""
|
||||
learning_starts: int = 25e3
|
||||
"""timestep to start learning"""
|
||||
policy_frequency: int = 2
|
||||
"""the frequency of training policy (delayed)"""
|
||||
noise_clip: float = 0.5
|
||||
"""noise clip parameter of the Target Policy Smoothing Regularization"""
|
||||
|
||||
|
||||
def make_env(env_id, seed, idx, capture_video, run_name):
|
||||
def thunk():
|
||||
if capture_video and idx == 0:
|
||||
env = gym.make(env_id, render_mode="rgb_array")
|
||||
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
|
||||
else:
|
||||
env = gym.make(env_id)
|
||||
env = gym.wrappers.RecordEpisodeStatistics(env)
|
||||
env.action_space.seed(seed)
|
||||
return env
|
||||
|
||||
return thunk
|
||||
|
||||
|
||||
# ALGO LOGIC: initialize agent here:
|
||||
class QNetwork(nn.Module):
|
||||
@nn.compact
|
||||
def __call__(self, x: jnp.ndarray, a: jnp.ndarray):
|
||||
x = jnp.concatenate([x, a], -1)
|
||||
x = nn.Dense(256)(x)
|
||||
x = nn.relu(x)
|
||||
x = nn.Dense(256)(x)
|
||||
x = nn.relu(x)
|
||||
x = nn.Dense(1)(x)
|
||||
return x
|
||||
|
||||
|
||||
class Actor(nn.Module):
|
||||
action_dim: int
|
||||
action_scale: jnp.ndarray
|
||||
action_bias: jnp.ndarray
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, x):
|
||||
x = nn.Dense(256)(x)
|
||||
x = nn.relu(x)
|
||||
x = nn.Dense(256)(x)
|
||||
x = nn.relu(x)
|
||||
x = nn.Dense(self.action_dim)(x)
|
||||
x = nn.tanh(x)
|
||||
x = x * self.action_scale + self.action_bias
|
||||
return x
|
||||
|
||||
|
||||
class TrainState(TrainState):
|
||||
target_params: flax.core.FrozenDict
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import stable_baselines3 as sb3
|
||||
|
||||
if sb3.__version__ < "2.0":
|
||||
raise ValueError(
|
||||
"""Ongoing migration: run the following command to install the new dependencies:
|
||||
poetry run pip install "stable_baselines3==2.0.0a1"
|
||||
"""
|
||||
)
|
||||
args = tyro.cli(Args)
|
||||
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
|
||||
if args.track:
|
||||
import wandb
|
||||
|
||||
wandb.init(
|
||||
project=args.wandb_project_name,
|
||||
entity=args.wandb_entity,
|
||||
sync_tensorboard=True,
|
||||
config=vars(args),
|
||||
name=run_name,
|
||||
monitor_gym=True,
|
||||
save_code=True,
|
||||
)
|
||||
writer = SummaryWriter(f"runs/{run_name}")
|
||||
writer.add_text(
|
||||
"hyperparameters",
|
||||
"|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
|
||||
)
|
||||
|
||||
# TRY NOT TO MODIFY: seeding
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
key = jax.random.PRNGKey(args.seed)
|
||||
key, actor_key, qf1_key, qf2_key = jax.random.split(key, 4)
|
||||
|
||||
# env setup
|
||||
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
|
||||
assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"
|
||||
|
||||
max_action = float(envs.single_action_space.high[0])
|
||||
envs.single_observation_space.dtype = np.float32
|
||||
rb = ReplayBuffer(
|
||||
args.buffer_size,
|
||||
envs.single_observation_space,
|
||||
envs.single_action_space,
|
||||
device="cpu",
|
||||
handle_timeout_termination=False,
|
||||
)
|
||||
|
||||
# TRY NOT TO MODIFY: start the game
|
||||
obs, _ = envs.reset(seed=args.seed)
|
||||
|
||||
actor = Actor(
|
||||
action_dim=np.prod(envs.single_action_space.shape),
|
||||
action_scale=jnp.array((envs.action_space.high - envs.action_space.low) / 2.0),
|
||||
action_bias=jnp.array((envs.action_space.high + envs.action_space.low) / 2.0),
|
||||
)
|
||||
actor_state = TrainState.create(
|
||||
apply_fn=actor.apply,
|
||||
params=actor.init(actor_key, obs),
|
||||
target_params=actor.init(actor_key, obs),
|
||||
tx=optax.adam(learning_rate=args.learning_rate),
|
||||
)
|
||||
qf = QNetwork()
|
||||
qf1_state = TrainState.create(
|
||||
apply_fn=qf.apply,
|
||||
params=qf.init(qf1_key, obs, envs.action_space.sample()),
|
||||
target_params=qf.init(qf1_key, obs, envs.action_space.sample()),
|
||||
tx=optax.adam(learning_rate=args.learning_rate),
|
||||
)
|
||||
qf2_state = TrainState.create(
|
||||
apply_fn=qf.apply,
|
||||
params=qf.init(qf2_key, obs, envs.action_space.sample()),
|
||||
target_params=qf.init(qf2_key, obs, envs.action_space.sample()),
|
||||
tx=optax.adam(learning_rate=args.learning_rate),
|
||||
)
|
||||
actor.apply = jax.jit(actor.apply)
|
||||
qf.apply = jax.jit(qf.apply)
|
||||
|
||||
@jax.jit
|
||||
def update_critic(
|
||||
actor_state: TrainState,
|
||||
qf1_state: TrainState,
|
||||
qf2_state: TrainState,
|
||||
observations: np.ndarray,
|
||||
actions: np.ndarray,
|
||||
next_observations: np.ndarray,
|
||||
rewards: np.ndarray,
|
||||
terminations: np.ndarray,
|
||||
key: jnp.ndarray,
|
||||
):
|
||||
# TODO Maybe pre-generate a lot of random keys
|
||||
# also check https://jax.readthedocs.io/en/latest/jax.random.html
|
||||
key, noise_key = jax.random.split(key, 2)
|
||||
clipped_noise = (
|
||||
jnp.clip(
|
||||
(jax.random.normal(noise_key, actions.shape) * args.policy_noise),
|
||||
-args.noise_clip,
|
||||
args.noise_clip,
|
||||
)
|
||||
* actor.action_scale
|
||||
)
|
||||
next_state_actions = jnp.clip(
|
||||
actor.apply(actor_state.target_params, next_observations) + clipped_noise,
|
||||
envs.single_action_space.low,
|
||||
envs.single_action_space.high,
|
||||
)
|
||||
qf1_next_target = qf.apply(qf1_state.target_params, next_observations, next_state_actions).reshape(-1)
|
||||
qf2_next_target = qf.apply(qf2_state.target_params, next_observations, next_state_actions).reshape(-1)
|
||||
min_qf_next_target = jnp.minimum(qf1_next_target, qf2_next_target)
|
||||
next_q_value = (rewards + (1 - terminations) * args.gamma * (min_qf_next_target)).reshape(-1)
|
||||
|
||||
def mse_loss(params):
|
||||
qf_a_values = qf.apply(params, observations, actions).squeeze()
|
||||
return ((qf_a_values - next_q_value) ** 2).mean(), qf_a_values.mean()
|
||||
|
||||
(qf1_loss_value, qf1_a_values), grads1 = jax.value_and_grad(mse_loss, has_aux=True)(qf1_state.params)
|
||||
(qf2_loss_value, qf2_a_values), grads2 = jax.value_and_grad(mse_loss, has_aux=True)(qf2_state.params)
|
||||
qf1_state = qf1_state.apply_gradients(grads=grads1)
|
||||
qf2_state = qf2_state.apply_gradients(grads=grads2)
|
||||
|
||||
return (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key
|
||||
|
||||
@jax.jit
|
||||
def update_actor(
|
||||
actor_state: TrainState,
|
||||
qf1_state: TrainState,
|
||||
qf2_state: TrainState,
|
||||
observations: np.ndarray,
|
||||
):
|
||||
def actor_loss(params):
|
||||
return -qf.apply(qf1_state.params, observations, actor.apply(params, observations)).mean()
|
||||
|
||||
actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_state.params)
|
||||
actor_state = actor_state.apply_gradients(grads=grads)
|
||||
actor_state = actor_state.replace(
|
||||
target_params=optax.incremental_update(actor_state.params, actor_state.target_params, args.tau)
|
||||
)
|
||||
|
||||
qf1_state = qf1_state.replace(
|
||||
target_params=optax.incremental_update(qf1_state.params, qf1_state.target_params, args.tau)
|
||||
)
|
||||
qf2_state = qf2_state.replace(
|
||||
target_params=optax.incremental_update(qf2_state.params, qf2_state.target_params, args.tau)
|
||||
)
|
||||
return actor_state, (qf1_state, qf2_state), actor_loss_value
|
||||
|
||||
start_time = time.time()
|
||||
for global_step in range(args.total_timesteps):
|
||||
# ALGO LOGIC: put action logic here
|
||||
if global_step < args.learning_starts:
|
||||
actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
|
||||
else:
|
||||
actions = actor.apply(actor_state.params, obs)
|
||||
actions = np.array(
|
||||
[
|
||||
(
|
||||
jax.device_get(actions)[0]
|
||||
+ np.random.normal(0, max_action * args.exploration_noise, size=envs.single_action_space.shape)
|
||||
).clip(envs.single_action_space.low, envs.single_action_space.high)
|
||||
]
|
||||
)
|
||||
|
||||
# TRY NOT TO MODIFY: execute the game and log data.
|
||||
next_obs, rewards, terminations, truncations, infos = envs.step(actions)
|
||||
|
||||
# TRY NOT TO MODIFY: record rewards for plotting purposes
|
||||
if "final_info" in infos:
|
||||
for info in infos["final_info"]:
|
||||
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
|
||||
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
|
||||
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
|
||||
break
|
||||
|
||||
# TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
|
||||
real_next_obs = next_obs.copy()
|
||||
for idx, trunc in enumerate(truncations):
|
||||
if trunc:
|
||||
real_next_obs[idx] = infos["final_observation"][idx]
|
||||
rb.add(obs, real_next_obs, actions, rewards, terminations, infos)
|
||||
|
||||
# TRY NOT TO MODIFY: CRUCIAL step easy to overlook
|
||||
obs = next_obs
|
||||
|
||||
# ALGO LOGIC: training.
|
||||
if global_step > args.learning_starts:
|
||||
data = rb.sample(args.batch_size)
|
||||
|
||||
(qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key = update_critic(
|
||||
actor_state,
|
||||
qf1_state,
|
||||
qf2_state,
|
||||
data.observations.numpy(),
|
||||
data.actions.numpy(),
|
||||
data.next_observations.numpy(),
|
||||
data.rewards.flatten().numpy(),
|
||||
data.dones.flatten().numpy(),
|
||||
key,
|
||||
)
|
||||
|
||||
if global_step % args.policy_frequency == 0:
|
||||
actor_state, (qf1_state, qf2_state), actor_loss_value = update_actor(
|
||||
actor_state,
|
||||
qf1_state,
|
||||
qf2_state,
|
||||
data.observations.numpy(),
|
||||
)
|
||||
|
||||
if global_step % 100 == 0:
|
||||
writer.add_scalar("losses/qf1_loss", qf1_loss_value.item(), global_step)
|
||||
writer.add_scalar("losses/qf2_loss", qf2_loss_value.item(), global_step)
|
||||
writer.add_scalar("losses/qf1_values", qf1_a_values.item(), global_step)
|
||||
writer.add_scalar("losses/qf2_values", qf2_a_values.item(), global_step)
|
||||
writer.add_scalar("losses/actor_loss", actor_loss_value.item(), global_step)
|
||||
print("SPS:", int(global_step / (time.time() - start_time)))
|
||||
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
|
||||
|
||||
if args.save_model:
|
||||
model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
|
||||
with open(model_path, "wb") as f:
|
||||
f.write(
|
||||
flax.serialization.to_bytes(
|
||||
[
|
||||
actor_state.params,
|
||||
qf1_state.params,
|
||||
qf2_state.params,
|
||||
]
|
||||
)
|
||||
)
|
||||
print(f"model saved to {model_path}")
|
||||
from cleanrl_utils.evals.td3_jax_eval import evaluate
|
||||
|
||||
episodic_returns = evaluate(
|
||||
model_path,
|
||||
make_env,
|
||||
args.env_id,
|
||||
eval_episodes=10,
|
||||
run_name=f"{run_name}-eval",
|
||||
Model=(Actor, QNetwork),
|
||||
exploration_noise=args.exploration_noise,
|
||||
)
|
||||
for idx, episodic_return in enumerate(episodic_returns):
|
||||
writer.add_scalar("eval/episodic_return", episodic_return, idx)
|
||||
|
||||
if args.upload_model:
|
||||
from cleanrl_utils.huggingface import push_to_hub
|
||||
|
||||
repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
|
||||
repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
|
||||
push_to_hub(args, episodic_returns, repo_id, "TD3", f"runs/{run_name}", f"videos/{run_name}-eval")
|
||||
|
||||
envs.close()
|
||||
writer.close()
|
|
@ -66,12 +66,12 @@ def downslope_force(mass, theta):
|
|||
return mass * 9.8 * jnp.sin(theta)
|
||||
|
||||
|
||||
@partial(jit, static_argnames=["crr"])
|
||||
@jit
|
||||
def rolling_force(mass, theta, crr):
|
||||
return normal_force(mass, theta) * crr
|
||||
|
||||
|
||||
@partial(jit, static_argnames=["area", "cd", "rho"])
|
||||
@jit
|
||||
def drag_force(u, area, cd, rho):
|
||||
return 0.5 * rho * jnp.pow(u, 2) * cd * area
|
||||
|
||||
|
|
|
@ -93,7 +93,7 @@ class SolarRaceV1(gym.Env):
|
|||
}
|
||||
)
|
||||
|
||||
self.action_space = gym.spaces.Box(-1.0, 1.0, shape=(1,)) # velocity, m/s
|
||||
self.action_space = gym.spaces.Box(0.0, 1.0, shape=(1,)) # velocity, m/s
|
||||
|
||||
|
||||
def reset(self, *, seed = None, options = None):
|
||||
|
@ -117,7 +117,7 @@ class SolarRaceV1(gym.Env):
|
|||
reward -= 600 - self._state[1][0]
|
||||
reward += 1e-6 * (self._state[2][0]) # net energy is negative.
|
||||
if jnp.all(self._state[1] > 600):
|
||||
reward -= 50000
|
||||
reward -= 500
|
||||
truncated = True
|
||||
|
||||
return self._get_obs(), reward, terminated, truncated, {}
|
||||
|
|
|
@ -5,7 +5,7 @@ import jax
|
|||
import jax.numpy as jnp
|
||||
import chex
|
||||
from flax import struct
|
||||
from jax import lax
|
||||
from jax import lax, vmap
|
||||
from gymnax.environments import environment
|
||||
from gymnax.environments import spaces
|
||||
|
||||
|
@ -59,19 +59,6 @@ class Snax(environment.Environment[SimState, SimParams]):
|
|||
dtype=jnp.float32,
|
||||
)
|
||||
return spaces.Box(low, high, shape=(shape,))
|
||||
# return spaces.Dict(
|
||||
# {
|
||||
# "position": spaces.Box(0.0, params.map_size, (), jnp.float32),
|
||||
# "realtime": spaces.Box(0.0, params.goal_time + 100, (), jnp.float32),
|
||||
# "energy": spaces.Box(-1e11, 0.0, (), jnp.float32),
|
||||
# "dist_to_goal": spaces.Box(0.0, params.goal_dist, (), jnp.float32),
|
||||
# "time_remaining": spaces.Box(0.0, params.goal_time, (), jnp.float32),
|
||||
# "upcoming_terrain": spaces.Box(
|
||||
# -1.0, 1.0, shape=(100,), dtype=jnp.float32
|
||||
# ),
|
||||
# # skip wind for now
|
||||
# }
|
||||
# )
|
||||
|
||||
def state_space(self, params: Optional[SimParams] = None) -> spaces.Dict:
|
||||
if params is None:
|
||||
|
@ -170,9 +157,15 @@ class Snax(environment.Environment[SimState, SimParams]):
|
|||
# ):
|
||||
# reward -= 500
|
||||
|
||||
# we have to vectorize that.
|
||||
# # we have to vectorize that.
|
||||
# reward = new_state.position / params.goal_dist # constant reward for moving forward
|
||||
# # reward for finishing
|
||||
# reward += (new_state.position >= params.goal_dist) * (100 + params.goal_time - new_state.realtime + 1e-7*new_state.energy)
|
||||
# # reward for failure
|
||||
# reward += (new_state.realtime >= params.goal_time) * -500
|
||||
|
||||
reward = new_state.position / params.goal_dist + \
|
||||
(new_state.position >= params.goal_dist) * (100 + params.goal_time - new_state.realtime + 1e-7*new_state.energy) + \
|
||||
(new_state.position >= params.goal_dist) * (100 + params.goal_time - new_state.realtime + 1e-6*new_state.energy) + \
|
||||
(new_state.realtime >= params.goal_time) * -500
|
||||
reward = reward.squeeze()
|
||||
terminal = self.is_terminal(state, params)
|
||||
|
|
Loading…
Reference in a new issue