111 lines
21 KiB
Plaintext
111 lines
21 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"from gymnasium.utils.env_checker import check_env as gym_check_env\n",
|
|
"from stable_baselines3 import TD3\n",
|
|
"from stable_baselines3.common.env_checker import check_env\n",
|
|
"from gymnasium.wrappers.jax_to_numpy import JaxToNumpy\n",
|
|
"from gymnasium.wrappers.vector import JaxToNumpy as VJaxToNumpy\n",
|
|
"from gymnax.wrappers.gym import GymnaxToVectorGymWrapper, GymnaxToGymWrapper\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import jax.numpy as jnp\n",
|
|
"from solarcarsim.simv2 import Snax\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"\n",
|
|
"env = Snax()\n",
|
|
"wrapped_env = GymnaxToGymWrapper(env)\n",
|
|
"vector_gym_env = GymnaxToVectorGymWrapper(env)\n",
|
|
"np_wrapper = JaxToNumpy(wrapped_env)\n",
|
|
"np_vec_wrapper = VJaxToNumpy(vector_gym_env)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Using cuda device\n",
|
|
"Wrapping the env with a `Monitor` wrapper\n",
|
|
"Wrapping the env in a DummyVecEnv.\n"
|
|
]
|
|
},
|
|
{
|
|
"ename": "ValueError",
|
|
"evalue": "Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>. The error was:\nTypeError: unhashable type: 'DynamicJaxprTracer'\n",
|
|
"output_type": "error",
|
|
"traceback": [
|
|
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
|
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
|
|
"Cell \u001b[0;32mIn[7], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m model \u001b[38;5;241m=\u001b[39m TD3(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mMlpPolicy\u001b[39m\u001b[38;5;124m\"\u001b[39m, np_wrapper, verbose\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlearn\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtotal_timesteps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1000\u001b[39;49m\u001b[43m)\u001b[49m\n",
|
|
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/td3/td3.py:222\u001b[0m, in \u001b[0;36mTD3.learn\u001b[0;34m(self, total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps, progress_bar)\u001b[0m\n\u001b[1;32m 213\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mlearn\u001b[39m(\n\u001b[1;32m 214\u001b[0m \u001b[38;5;28mself\u001b[39m: SelfTD3,\n\u001b[1;32m 215\u001b[0m total_timesteps: \u001b[38;5;28mint\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 220\u001b[0m progress_bar: \u001b[38;5;28mbool\u001b[39m \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m 221\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m SelfTD3:\n\u001b[0;32m--> 222\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlearn\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 223\u001b[0m \u001b[43m \u001b[49m\u001b[43mtotal_timesteps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtotal_timesteps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 224\u001b[0m \u001b[43m \u001b[49m\u001b[43mcallback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcallback\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 225\u001b[0m \u001b[43m \u001b[49m\u001b[43mlog_interval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlog_interval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 226\u001b[0m \u001b[43m \u001b[49m\u001b[43mtb_log_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtb_log_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 227\u001b[0m \u001b[43m \u001b[49m\u001b[43mreset_num_timesteps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mreset_num_timesteps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 228\u001b[0m \u001b[43m \u001b[49m\u001b[43mprogress_bar\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprogress_bar\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 229\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
|
|
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/common/off_policy_algorithm.py:328\u001b[0m, in \u001b[0;36mOffPolicyAlgorithm.learn\u001b[0;34m(self, total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps, progress_bar)\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrain_freq, TrainFreq) \u001b[38;5;66;03m# check done in _setup_learn()\u001b[39;00m\n\u001b[1;32m 327\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_timesteps \u001b[38;5;241m<\u001b[39m total_timesteps:\n\u001b[0;32m--> 328\u001b[0m rollout \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcollect_rollouts\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 329\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 330\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrain_freq\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_freq\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 331\u001b[0m \u001b[43m \u001b[49m\u001b[43maction_noise\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43maction_noise\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 332\u001b[0m \u001b[43m \u001b[49m\u001b[43mcallback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcallback\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 333\u001b[0m \u001b[43m \u001b[49m\u001b[43mlearning_starts\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlearning_starts\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 334\u001b[0m \u001b[43m \u001b[49m\u001b[43mreplay_buffer\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreplay_buffer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 335\u001b[0m \u001b[43m \u001b[49m\u001b[43mlog_interval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlog_interval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 336\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 338\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m rollout\u001b[38;5;241m.\u001b[39mcontinue_training:\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01mbreak\u001b[39;00m\n",
|
|
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/common/off_policy_algorithm.py:560\u001b[0m, in \u001b[0;36mOffPolicyAlgorithm.collect_rollouts\u001b[0;34m(self, env, callback, train_freq, replay_buffer, action_noise, learning_starts, log_interval)\u001b[0m\n\u001b[1;32m 557\u001b[0m actions, buffer_actions \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sample_action(learning_starts, action_noise, env\u001b[38;5;241m.\u001b[39mnum_envs)\n\u001b[1;32m 559\u001b[0m \u001b[38;5;66;03m# Rescale and perform action\u001b[39;00m\n\u001b[0;32m--> 560\u001b[0m new_obs, rewards, dones, infos \u001b[38;5;241m=\u001b[39m \u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mactions\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 562\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_timesteps \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m env\u001b[38;5;241m.\u001b[39mnum_envs\n\u001b[1;32m 563\u001b[0m num_collected_steps \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n",
|
|
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/common/vec_env/base_vec_env.py:206\u001b[0m, in \u001b[0;36mVecEnv.step\u001b[0;34m(self, actions)\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 200\u001b[0m \u001b[38;5;124;03mStep the environments with the given action\u001b[39;00m\n\u001b[1;32m 201\u001b[0m \n\u001b[1;32m 202\u001b[0m \u001b[38;5;124;03m:param actions: the action\u001b[39;00m\n\u001b[1;32m 203\u001b[0m \u001b[38;5;124;03m:return: observation, reward, done, information\u001b[39;00m\n\u001b[1;32m 204\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 205\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstep_async(actions)\n\u001b[0;32m--> 206\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep_wait\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
|
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py:58\u001b[0m, in \u001b[0;36mDummyVecEnv.step_wait\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 55\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mstep_wait\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m VecEnvStepReturn:\n\u001b[1;32m 56\u001b[0m \u001b[38;5;66;03m# Avoid circular imports\u001b[39;00m\n\u001b[1;32m 57\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m env_idx \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnum_envs):\n\u001b[0;32m---> 58\u001b[0m obs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbuf_rews[env_idx], terminated, truncated, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbuf_infos[env_idx] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menvs\u001b[49m\u001b[43m[\u001b[49m\u001b[43menv_idx\u001b[49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 59\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mactions\u001b[49m\u001b[43m[\u001b[49m\u001b[43menv_idx\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 60\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 61\u001b[0m \u001b[38;5;66;03m# convert to SB3 VecEnv api\u001b[39;00m\n\u001b[1;32m 62\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbuf_dones[env_idx] \u001b[38;5;241m=\u001b[39m terminated \u001b[38;5;129;01mor\u001b[39;00m truncated\n",
|
|
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/stable_baselines3/common/monitor.py:94\u001b[0m, in \u001b[0;36mMonitor.step\u001b[0;34m(self, action)\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mneeds_reset:\n\u001b[1;32m 93\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTried to step environment that needs reset\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 94\u001b[0m observation, reward, terminated, truncated, info \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43maction\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 95\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrewards\u001b[38;5;241m.\u001b[39mappend(\u001b[38;5;28mfloat\u001b[39m(reward))\n\u001b[1;32m 96\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m terminated \u001b[38;5;129;01mor\u001b[39;00m truncated:\n",
|
|
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/gymnasium/wrappers/jax_to_numpy.py:166\u001b[0m, in \u001b[0;36mJaxToNumpy.step\u001b[0;34m(self, action)\u001b[0m\n\u001b[1;32m 157\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Transforms the action to a jax array .\u001b[39;00m\n\u001b[1;32m 158\u001b[0m \n\u001b[1;32m 159\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[38;5;124;03m A tuple containing numpy versions of the next observation, reward, termination, truncation, and extra info.\u001b[39;00m\n\u001b[1;32m 164\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 165\u001b[0m jax_action \u001b[38;5;241m=\u001b[39m numpy_to_jax(action)\n\u001b[0;32m--> 166\u001b[0m obs, reward, terminated, truncated, info \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mjax_action\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 168\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m (\n\u001b[1;32m 169\u001b[0m jax_to_numpy(obs),\n\u001b[1;32m 170\u001b[0m \u001b[38;5;28mfloat\u001b[39m(reward),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 173\u001b[0m jax_to_numpy(info),\n\u001b[1;32m 174\u001b[0m )\n",
|
|
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/gymnax/wrappers/gym.py:70\u001b[0m, in \u001b[0;36mGymnaxToGymWrapper.step\u001b[0;34m(self, action)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Step environment, follow new step API.\"\"\"\u001b[39;00m\n\u001b[1;32m 69\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrng, step_key \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39msplit(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrng)\n\u001b[0;32m---> 70\u001b[0m o, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39menv_state, r, d, info \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_env\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 71\u001b[0m \u001b[43m \u001b[49m\u001b[43mstep_key\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv_state\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maction\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43menv_params\u001b[49m\n\u001b[1;32m 72\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m o, r, d, d, info\n",
|
|
" \u001b[0;31m[... skipping hidden 15 frame]\u001b[0m\n",
|
|
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/gymnax/environments/environment.py:45\u001b[0m, in \u001b[0;36mEnvironment.step\u001b[0;34m(self, key, state, action, params)\u001b[0m\n\u001b[1;32m 43\u001b[0m params \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdefault_params\n\u001b[1;32m 44\u001b[0m key, key_reset \u001b[38;5;241m=\u001b[39m jax\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39msplit(key)\n\u001b[0;32m---> 45\u001b[0m obs_st, state_st, reward, done, info \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep_env\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maction\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 46\u001b[0m obs_re, state_re \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreset_env(key_reset, params)\n\u001b[1;32m 47\u001b[0m \u001b[38;5;66;03m# Auto-reset environment based on termination\u001b[39;00m\n",
|
|
"File \u001b[0;32m~/Documents/Code/solarcarsim/src/solarcarsim/simv2.py:138\u001b[0m, in \u001b[0;36mSnax.step_env\u001b[0;34m(self, key, state, action, params)\u001b[0m\n\u001b[1;32m 136\u001b[0m theta \u001b[38;5;241m=\u001b[39m state\u001b[38;5;241m.\u001b[39mslope[pos]\n\u001b[1;32m 137\u001b[0m velocity \u001b[38;5;241m=\u001b[39m jnp\u001b[38;5;241m.\u001b[39marray([action \u001b[38;5;241m*\u001b[39m params\u001b[38;5;241m.\u001b[39mcar\u001b[38;5;241m.\u001b[39mmax_speed])\u001b[38;5;241m.\u001b[39msqueeze()\n\u001b[0;32m--> 138\u001b[0m dragf \u001b[38;5;241m=\u001b[39m \u001b[43msim\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdrag_force\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 139\u001b[0m \u001b[43m \u001b[49m\u001b[43mvelocity\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcar\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfrontal_area\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcar\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdrag_coeff\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1.184\u001b[39;49m\n\u001b[1;32m 140\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 141\u001b[0m rollf \u001b[38;5;241m=\u001b[39m sim\u001b[38;5;241m.\u001b[39mrolling_force(params\u001b[38;5;241m.\u001b[39mcar\u001b[38;5;241m.\u001b[39mmass, theta, params\u001b[38;5;241m.\u001b[39mcar\u001b[38;5;241m.\u001b[39mrolling_coeff)\n\u001b[1;32m 142\u001b[0m hillf \u001b[38;5;241m=\u001b[39m sim\u001b[38;5;241m.\u001b[39mdownslope_force(params\u001b[38;5;241m.\u001b[39mcar\u001b[38;5;241m.\u001b[39mmass, theta)\n",
|
|
" \u001b[0;31m[... skipping hidden 3 frame]\u001b[0m\n",
|
|
"File \u001b[0;32m~/Documents/Code/solarcarsim/.venv/lib/python3.12/site-packages/jax/_src/pjit.py:768\u001b[0m, in \u001b[0;36m_infer_params\u001b[0;34m(fun, ji, args, kwargs)\u001b[0m\n\u001b[1;32m 764\u001b[0m p, args_flat \u001b[38;5;241m=\u001b[39m _infer_params_impl(fun, ji, pjit_mesh, resource_env, args,\n\u001b[1;32m 765\u001b[0m kwargs, in_avals\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[1;32m 766\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m p, p\u001b[38;5;241m.\u001b[39mconsts \u001b[38;5;241m+\u001b[39m args_flat\n\u001b[0;32m--> 768\u001b[0m entry \u001b[38;5;241m=\u001b[39m \u001b[43m_infer_params_cached\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 769\u001b[0m \u001b[43m \u001b[49m\u001b[43mfun\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mji\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msignature\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mavals\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpjit_mesh\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mresource_env\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 770\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m entry\u001b[38;5;241m.\u001b[39mpjit_params \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 771\u001b[0m p, args_flat \u001b[38;5;241m=\u001b[39m _infer_params_impl(\n\u001b[1;32m 772\u001b[0m fun, ji, pjit_mesh, resource_env, args, kwargs, in_avals\u001b[38;5;241m=\u001b[39mavals)\n",
|
|
"\u001b[0;31mValueError\u001b[0m: Non-hashable static arguments are not supported. An error occurred while trying to hash an object of type <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace>. The error was:\nTypeError: unhashable type: 'DynamicJaxprTracer'\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"\n",
|
|
"model = TD3(\"MlpPolicy\", np_wrapper, verbose=1)\n",
|
|
"model.learn(total_timesteps=1000)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": ".venv",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.12.7"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|