jax cuda, separate v1 sim

This commit is contained in:
saji 2024-12-13 19:40:15 -06:00
parent 104dfab637
commit aaac5df404
6 changed files with 6353 additions and 208 deletions

File diff suppressed because it is too large Load diff

128
pdm.lock
View file

@ -5,7 +5,7 @@
groups = ["default", "dev"]
strategy = ["inherit_metadata"]
lock_version = "4.5.0"
content_hash = "sha256:a1edba805cc867a6316cea6c754bc112f0f79046b604ee515c541505f9c546f7"
content_hash = "sha256:81e26f71acf1a583b21280b235fa2ac16165ac824ae8483bd391b88406421aa4"
[[metadata.targets]]
requires_python = ">=3.12,<3.13"
@ -522,13 +522,13 @@ files = [
[[package]]
name = "jax"
version = "0.4.35"
version = "0.4.37"
requires_python = ">=3.10"
summary = "Differentiate, compile, and transform Numpy code."
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"jaxlib<=0.4.35,>=0.4.34",
"jaxlib<=0.4.37,>=0.4.36",
"ml-dtypes>=0.4.0",
"numpy>=1.24",
"numpy>=1.26.0; python_version >= \"3.12\"",
@ -537,13 +537,83 @@ dependencies = [
"scipy>=1.11.1; python_version >= \"3.12\"",
]
files = [
{file = "jax-0.4.35-py3-none-any.whl", hash = "sha256:fa99e909a31424abfec750019a6dd36f6acc18a6e7d40e2c0086b932cc351325"},
{file = "jax-0.4.35.tar.gz", hash = "sha256:c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e"},
{file = "jax-0.4.37-py3-none-any.whl", hash = "sha256:bdc0686d7e5a944e2d38026eae632214d98dd2d91869cbcedbf1c11298ae3e3e"},
{file = "jax-0.4.37.tar.gz", hash = "sha256:7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b"},
]
[[package]]
name = "jax-cuda12-pjrt"
version = "0.4.36"
summary = "JAX XLA PJRT Plugin for NVIDIA GPUs"
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [
{file = "jax_cuda12_pjrt-0.4.36-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1dfc0bec0820ba801b61e9421064b6e58238c430b4ad8f54043323d93c0217c6"},
{file = "jax_cuda12_pjrt-0.4.36-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e3c3705d8db7d63da9abfaebf06f5cd0667f5acb0748a5c5eb00d80041e922ed"},
]
[[package]]
name = "jax-cuda12-plugin"
version = "0.4.36"
requires_python = ">=3.10"
summary = "JAX Plugin for NVIDIA GPUs"
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"jax-cuda12-pjrt==0.4.36",
]
files = [
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:6a0b0c2bdc1da2eea2c20723a1e8f39b3cda67d24c665de936647e8091f5790d"},
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:5d4727fb519fedc06a9a984d5a0714804d81ef126a2cb60cefd5cbc4a3ea2627"},
]
[[package]]
name = "jax-cuda12-plugin"
version = "0.4.36"
extras = ["with_cuda"]
requires_python = ">=3.10"
summary = "JAX Plugin for NVIDIA GPUs"
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"jax-cuda12-plugin==0.4.36",
"nvidia-cublas-cu12>=12.1.3.1",
"nvidia-cuda-cupti-cu12>=12.1.105",
"nvidia-cuda-nvcc-cu12>=12.6.85",
"nvidia-cuda-runtime-cu12>=12.1.105",
"nvidia-cudnn-cu12<10.0,>=9.1",
"nvidia-cufft-cu12>=11.0.2.54",
"nvidia-cusolver-cu12>=11.4.5.107",
"nvidia-cusparse-cu12>=12.1.0.106",
"nvidia-nccl-cu12>=2.18.1",
"nvidia-nvjitlink-cu12>=12.1.105",
]
files = [
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:6a0b0c2bdc1da2eea2c20723a1e8f39b3cda67d24c665de936647e8091f5790d"},
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:5d4727fb519fedc06a9a984d5a0714804d81ef126a2cb60cefd5cbc4a3ea2627"},
]
[[package]]
name = "jax"
version = "0.4.37"
extras = ["cuda12"]
requires_python = ">=3.10"
summary = "Differentiate, compile, and transform Numpy code."
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"jax-cuda12-plugin[with_cuda]<=0.4.37,>=0.4.36",
"jax==0.4.37",
"jaxlib==0.4.36",
]
files = [
{file = "jax-0.4.37-py3-none-any.whl", hash = "sha256:bdc0686d7e5a944e2d38026eae632214d98dd2d91869cbcedbf1c11298ae3e3e"},
{file = "jax-0.4.37.tar.gz", hash = "sha256:7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b"},
]
[[package]]
name = "jaxlib"
version = "0.4.35"
version = "0.4.36"
requires_python = ">=3.10"
summary = "XLA library for JAX"
groups = ["default"]
@ -555,16 +625,11 @@ dependencies = [
"scipy>=1.11.1; python_version >= \"3.12\"",
]
files = [
{file = "jaxlib-0.4.35-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:b44f3e6e9fb748bb43df914356cf9d0d0c9a6e446a12c21fe843db25ed0df65f"},
{file = "jaxlib-0.4.35-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:504d0a2e2117724359d99d7e3663022686dcdddd85aa14bdad02008d444481ad"},
{file = "jaxlib-0.4.35-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:187cb6929dc139b75d952d67c33118473c1b4105525a3e5607f064e7b8efdc74"},
{file = "jaxlib-0.4.35-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:04d1db3bf0050d120238bfb9b686b58fefcc4d9dd9e2d96aecd3f68a1f1f5e0a"},
{file = "jaxlib-0.4.35-cp312-cp312-win_amd64.whl", hash = "sha256:dddffce48d7e6057008999aed2d8a9daecc57a48c45a4f8c475e00880eb2e41d"},
{file = "jaxlib-0.4.35-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:14aeac3fea2ca1d5afb1878f72470b159cc89adb2633c5f0686f5d7c39f2ac18"},
{file = "jaxlib-0.4.35-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e8c9579e20d5ecdc4f61336cdd032710cb8c38d5ae9c4fce0cf9ea031cef21cb"},
{file = "jaxlib-0.4.35-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:7b11ad7c13f7f96f36efd303711ecac425f19ca2ddf65cf1be1541167a959ee5"},
{file = "jaxlib-0.4.35-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:0be3cf9df879d9ae1b5b92fc281f77d21f522fcbae1a48a02661026bbd9b9309"},
{file = "jaxlib-0.4.35-cp313-cp313-win_amd64.whl", hash = "sha256:330c090bb9af413f552d8a92d097e50baec6b75823430fb2966a49f5298d4c43"},
{file = "jaxlib-0.4.36-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:5972aa85f6d771ecc8cc72148c1fa64250ca33cbdf2bf24407cdee8a5299d25d"},
{file = "jaxlib-0.4.36-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5597908cd10418c0b42e9af807fc8112036703533cf501a5255a8fbf4011867e"},
{file = "jaxlib-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:fbbabaa287378a78a3cf9cbe4de30a1f6f19a99116feb4bd687ff256415cd442"},
{file = "jaxlib-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:be295abc209c980817db0488f21f1fbc0644f87326522895e2b9b64729106357"},
{file = "jaxlib-0.4.36-cp312-cp312-win_amd64.whl", hash = "sha256:d4bbb5d2970628dcd3dabc28a5b97a1125ad3e06a1be822d340fd9f06f7449b3"},
]
[[package]]
@ -861,7 +926,7 @@ version = "12.4.5.8"
requires_python = ">=3"
summary = "CUBLAS native runtime libraries"
groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [
{file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3"},
{file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b"},
@ -874,13 +939,26 @@ version = "12.4.127"
requires_python = ">=3"
summary = "CUDA profiling tools runtime libs."
groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a"},
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb"},
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922"},
]
[[package]]
name = "nvidia-cuda-nvcc-cu12"
version = "12.6.85"
requires_python = ">=3"
summary = "CUDA nvcc"
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [
{file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d75d9d74599f4d7c0865df19ed21b739e6cb77a6497a3f73d6f61e8038a765e4"},
{file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5d2edd5531b13e3daac8ffee9fc2b70a147e6088b2af2565924773d63d36d294"},
{file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-win_amd64.whl", hash = "sha256:aa04742337973dcb5bcccabb590edc8834c60ebfaf971847888d24ffef6c46b5"},
]
[[package]]
name = "nvidia-cuda-nvrtc-cu12"
version = "12.4.127"
@ -900,7 +978,7 @@ version = "12.4.127"
requires_python = ">=3"
summary = "CUDA Runtime native Libraries"
groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [
{file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3"},
{file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5"},
@ -913,7 +991,7 @@ version = "9.1.0.70"
requires_python = ">=3"
summary = "cuDNN runtime libraries"
groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"nvidia-cublas-cu12",
]
@ -928,7 +1006,7 @@ version = "11.2.1.3"
requires_python = ">=3"
summary = "CUFFT native runtime libraries"
groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"nvidia-nvjitlink-cu12",
]
@ -957,7 +1035,7 @@ version = "11.6.1.9"
requires_python = ">=3"
summary = "CUDA solver native runtime libraries"
groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"nvidia-cublas-cu12",
"nvidia-cusparse-cu12",
@ -975,7 +1053,7 @@ version = "12.3.1.170"
requires_python = ">=3"
summary = "CUSPARSE native runtime libraries"
groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"nvidia-nvjitlink-cu12",
]
@ -991,7 +1069,7 @@ version = "2.21.5"
requires_python = ">=3"
summary = "NVIDIA Collective Communication Library (NCCL) Runtime"
groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [
{file = "nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0"},
]
@ -1002,7 +1080,7 @@ version = "12.4.127"
requires_python = ">=3"
summary = "Nvidia JIT LTO Library"
groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"},
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},

View file

@ -5,7 +5,7 @@ description = "A solar car racing simulation library and GUI tool"
authors = [
{name = "saji", email = "saji@saji.dev"},
]
dependencies = ["pyqtgraph>=0.13.7", "jax>=0.4.35", "pytest>=8.3.3", "pyside6>=6.8.0.2", "matplotlib>=3.9.2", "gymnasium[jax]>=1.0.0", "pyvista>=0.44.2", "pyvistaqt>=0.11.1", "stable-baselines3>=2.4.0"]
dependencies = ["pyqtgraph>=0.13.7", "jax[cuda12]>=0.4.37", "pytest>=8.3.3", "pyside6>=6.8.0.2", "matplotlib>=3.9.2", "gymnasium[jax]>=1.0.0", "pyvista>=0.44.2", "pyvistaqt>=0.11.1", "stable-baselines3>=2.4.0"]
requires-python = ">=3.10,<3.13"
readme = "README.md"
license = {text = "MIT"}

View file

@ -1,3 +1,5 @@
"""Physical equations and models for building a simulation environment"""
import jax.numpy as jnp
import jax
from jax import grad, jit, vmap, lax
@ -5,7 +7,11 @@ from functools import partial
from typing import NamedTuple, Tuple
from solarcarsim.noise import fractal_noise_1d, generate_elevation_profile, generate_wind_field
from solarcarsim.noise import (
fractal_noise_1d,
generate_wind_field,
)
class MotorParams(NamedTuple):
kv: float
@ -16,33 +22,37 @@ class MotorParams(NamedTuple):
class BatteryParams(NamedTuple):
shape: Tuple[int, int] # (series,parallel) array of batteries
resistance: float # ohms
initial_energy: float # joules
shape: Tuple[int, int] # (series,parallel) array of batteries
resistance: float # ohms
initial_energy: float # joules
class CarParams(NamedTuple):
""" Physical Data for Solar Car Parameters """
mass: float = 800 # kg
frontal_area: float = 1.3 # m^2
drag_coeff: float = 0.18 # drag coefficient, dimensionless
rolling_coeff: float = 0.002 # rolling resistance.
moter_eff: float = 0.93 # 0 < x < 1 scaling factor
wheel_radius: float = 0.23 # wheel radius in meters
max_speed: float = 30.0 # m/s top speed
solar_area: float = 5.0 # m^2, typically 5.0
solar_eff: float = 0.20 # 0 < x < 1, typically ~.25
n_motors: int = 2 # how many motors we have.
motor: MotorParams = MotorParams(8.43, 1.1, 100.0, 0.001, 0.001) # mitsuba m2090 estimate
battery: BatteryParams = BatteryParams((36,19), 0.0126, 66.6e3) # freebasing 50s pack.
"""Physical Data for Solar Car Parameters"""
mass: float = 800 # kg
frontal_area: float = 1.3 # m^2
drag_coeff: float = 0.18 # drag coefficient, dimensionless
rolling_coeff: float = 0.002 # rolling resistance.
moter_eff: float = 0.93 # 0 < x < 1 scaling factor
wheel_radius: float = 0.23 # wheel radius in meters
max_speed: float = 30.0 # m/s top speed
solar_area: float = 5.0 # m^2, typically 5.0
solar_eff: float = 0.20 # 0 < x < 1, typically ~.25
n_motors: int = 2 # how many motors we have.
motor: MotorParams = MotorParams(
8.43, 1.1, 100.0, 0.001, 0.001
) # mitsuba m2090 estimate
battery: BatteryParams = BatteryParams(
(36, 19), 0.0126, 66.6e3
) # freebasing 50s pack.
def DefaultCar() -> CarParams:
""" Creates a basic car """
"""Creates a basic car"""
return CarParams(1000, 1.3, 0.18, 0.002, 0.85, 5.0, 0.23)
# some physics equations using jax
@ -50,26 +60,31 @@ def DefaultCar() -> CarParams:
def normal_force(mass, theta):
return mass * 9.8 * jnp.cos(theta)
@jit
def downslope_force(mass, theta):
return mass * 9.8 * jnp.sin(theta)
@partial(jit, static_argnames=['crr'])
@partial(jit, static_argnames=["crr"])
def rolling_force(mass, theta, crr):
return normal_force(mass, theta) * crr
@partial(jit, static_argnames=['area', 'cd', 'rho'])
@partial(jit, static_argnames=["area", "cd", "rho"])
def drag_force(u, area, cd, rho):
return 0.5 * rho * jnp.pow(u, 2) * cd * area
# we can use those forces above to determine what forces we have to overcome. Sum(F)=0
# @partial(jit, static_argnums=(2,))
@jit
def bldc_power_draw(torque, velocity, params: MotorParams):
"""
Approximates power draw of a BLDC motor outputting a torque at a given velocity
Args:
torq: Applied force in Newton/meters
velocity: Angular velocity in rad/s
@ -77,32 +92,32 @@ def bldc_power_draw(torque, velocity, params: MotorParams):
kt: Torque constant (Nm/A)
friction_coeff: Mechanical friction coefficient
iron_loss_coeff: Iron loss coefficient (core losses)
Returns:
Total electrical power draw in Watts
"""
# Current required for torque (simplified relationship)
current = torque / params.kt
# Copper losses (I²R)
copper_losses = params.resistance * current**2
copper_losses = params.resistance * current**2
# Mechanical friction losses
friction_losses = params.friction_coeff * velocity**2
friction_losses = params.friction_coeff * velocity**2
# Iron losses (simplified model - primarily dependent on speed)
iron_losses = params.iron_coeff * velocity**2
iron_losses = params.iron_coeff * velocity**2
# Mechanical power output
mechanical_power = torque * velocity
# Total electrical power input
total_power = mechanical_power + copper_losses + friction_losses + iron_losses
return total_power
# @partial(jit, static_argnames=['resistance', 'kt', 'kv', 'vmax', 'Cf'])
@jit
def bldc_torque(velocity, current_limit, resistance, kt, kv, vmax, Cf):
bemf = velocity / kv
v_avail = jnp.clip(vmax - bemf, 0.0, vmax)
current = jnp.clip(v_avail / resistance, 0.0, current_limit)
@ -113,8 +128,15 @@ def bldc_torque(velocity, current_limit, resistance, kt, kv, vmax, Cf):
stall_torque = kt * current_limit
return jnp.where(velocity < 0.01, stall_torque, net_torque)
@partial(jit, static_argnums=(1,2,))
def battery_powerloss(current,cell_r, battery_shape: Tuple[int,int]):
@partial(
jit,
static_argnums=(
1,
2,
),
)
def battery_powerloss(current, cell_r, battery_shape: Tuple[int, int]):
r_array = jnp.full(battery_shape, cell_r)
branch_current = current / battery_shape[1]
I_array = jnp.full(battery_shape, branch_current)
@ -122,7 +144,6 @@ def battery_powerloss(current,cell_r, battery_shape: Tuple[int,int]):
return jnp.sum(cell_Ploss)
def forward(state, timestep, control, params: CarParams):
# state is (position, time, energy)
# control is velocity
@ -130,7 +151,7 @@ def forward(state, timestep, control, params: CarParams):
# params is the params dictionary.
# returns the next state with (position', time + timestep, energy')
# TODO: terrain, weather, solar
# determine the forces acting on the car.
dragf = drag_force(control, params.frontal_area, params.drag_coeff, 1.184)
rollf = rolling_force(params.mass, 0, params.rolling_coeff)
@ -139,18 +160,24 @@ def forward(state, timestep, control, params: CarParams):
# determine the power needed to make this force
tau = params.wheel_radius * totalf
pdraw = bldc_power_draw(tau, control, params.motor)
net_power = 0 - pdraw # watts aka j/s
net_power = 0 - pdraw # watts aka j/s
# TODO: calculate battery-based power losses.
# TODO: support regenerative braking when going downhill
# TODO: delta x = cos(theta) * velocity * timestep
new_state = jnp.array([state[0] + control * timestep, state[1] + timestep, state[2] + net_power * timestep])
new_state = jnp.array(
[
state[0] + control * timestep,
state[1] + timestep,
state[2] + net_power * timestep,
]
)
return new_state
def make_environment(seed):
""" Generate a race environment: terrain function, wind function, wrapped forward function."""
"""Generate a race environment: terrain function, wind function, wrapped forward function."""
key, subkey = jax.random.split(seed)
wind = generate_wind_field(subkey, 10000, 600, spatial_scale=1000)
key, subkey = jax.random.split(key)
@ -161,46 +188,3 @@ def make_environment(seed):
return wind, elevation, slope
@partial(jit, static_argnames=['params'])
def forwardv2(state, control, delta_time, wind, elevation, slope, params):
pos = jnp.astype(jnp.round(state[0]), "int32")
time = jnp.astype(jnp.round(state[1]), "int32")
theta = slope[pos]
velocity = control * params.max_speed
# sum up the forces acting on the car
dragf = drag_force(velocity, params.frontal_area, params.drag_coeff, 1.184)
rollf = rolling_force(params.mass, theta, params.rolling_coeff)
hillforce = downslope_force(params.mass, theta)
windf = wind[pos, time]
totalf = dragf + rollf + hillforce + windf
# with the sum of forces, determine the needed torque at the wheels, and then power
tau = params.wheel_radius * totalf
pdraw = bldc_power_draw(tau, velocity, params.motor)
# determine the energy needed to do this power for the time step
net_power = state[2] - delta_time * pdraw # joules
dpos = jnp.cos(theta) * velocity * delta_time
dist_remaining = 10000.0 - dpos
time_remaining = 600 - (state[1] + delta_time)
return jnp.array([dpos, state[1] + delta_time, net_power, dist_remaining, time_remaining])
def reward(state):
progress = state[0] / 10000 * 100
energy_usage = -10 * state[2]
time_factor = (1.0 - (state[1] / 600)) * 50
reward = progress + energy_usage + time_factor
return reward
# now we have an environment tuned in.
# we want to take an environment, and bind it to the forward function
def make_simulator(params: CarParams, wind, elevation, slope):
def reward(state):
progress = state[0] / 10000 * 100
energy_usage = -10 * state[2]
time_factor = (1.0 - (state[1] / 600)) * 50
reward = progress + energy_usage + time_factor
return reward
return forwardv2, reward

View file

@ -2,10 +2,45 @@ import gymnasium as gym
import solarcarsim.physsim as sim
import jax
import jax.numpy as jnp
import numpy as np
from typing import Any
from jax import jit
from functools import partial
from jax import vmap
from solarcarsim.physsim import drag_force, rolling_force, downslope_force, bldc_power_draw
@partial(jit, static_argnames=["params"])
def forwardv2(state, control, delta_time, wind, elevation, slope, params):
pos = jnp.astype(jnp.round(state[0]), "int32")
time = jnp.astype(jnp.round(state[1]), "int32")
theta = slope[pos]
velocity = control * params.max_speed
# sum up the forces acting on the car
windspeed = wind[pos, time]
dragf = sim.drag_force(velocity + windspeed, params.frontal_area, params.drag_coeff, 1.184)
rollf = sim.rolling_force(params.mass, theta, params.rolling_coeff)
hillforce = sim.downslope_force(params.mass, theta)
totalf = dragf + rollf + hillforce
# with the sum of forces, determine the needed torque at the wheels, and then power
tau = params.wheel_radius * totalf
pdraw = bldc_power_draw(tau, velocity, params.motor)
# determine the energy needed to do this power for the time step
net_power = state[2] - delta_time * pdraw # joules
dpos = jnp.cos(theta) * velocity * delta_time
dist_remaining = 10000.0 - dpos
time_remaining = 600 - (state[1] + delta_time)
return jnp.array(
[dpos, state[1] + delta_time, net_power, dist_remaining, time_remaining]
)
def reward(state, prev_energy):
progress = state[0] / 10000 * 100
energy_usage = 10 * (state[2] - prev_energy) # current energy < previous energy.
time_factor = (1.0 - (state[1] / 600)) * 50
reward = progress + energy_usage + time_factor
return reward
class SolarRaceV1(gym.Env):
"""A primitive hill climber. Aims to solve the given route optimizing
@ -46,8 +81,8 @@ class SolarRaceV1(gym.Env):
self._reset_sim(jax.random.key(seed))
self._timestep = timestep
self._car = car
self._simstep = sim.forwardv2
self._simreward = sim.reward
self._simstep = forwardv2
self._simreward = reward
self.observation_space = gym.spaces.Dict(
{
@ -73,8 +108,10 @@ class SolarRaceV1(gym.Env):
def step(self, action):
wind, elevation, slope = self._environment
old_energy = self._state[2]
self._state = self._simstep(self._state, action, self._timestep,wind, elevation, slope, self._car)
reward = self._simreward(self._state)[0]
reward = self._simreward(self._state, old_energy)[0]
terminated = False
truncated = False
if jnp.all(self._state[0] > 10000):
@ -82,4 +119,6 @@ class SolarRaceV1(gym.Env):
if self._state[1] > 600:
truncated = True
return self._get_obs(), reward, terminated, truncated, {}
return self._get_obs(), reward, terminated, truncated, {}

14
src/solarcarsim/simv2.py Normal file
View file

@ -0,0 +1,14 @@
""" Second-generation simulator. More functional, cleaner code, faster """
from typing import NamedTuple
import jax
import jax.numpy as jnp
class SimState(NamedTuple):
position: float
time: float
energy: float
distance_remaining: float
time_remaining: float