jax cuda, separate v1 sim

This commit is contained in:
saji 2024-12-13 19:40:15 -06:00
parent 104dfab637
commit aaac5df404
6 changed files with 6353 additions and 208 deletions

File diff suppressed because it is too large Load diff

128
pdm.lock
View file

@ -5,7 +5,7 @@
groups = ["default", "dev"] groups = ["default", "dev"]
strategy = ["inherit_metadata"] strategy = ["inherit_metadata"]
lock_version = "4.5.0" lock_version = "4.5.0"
content_hash = "sha256:a1edba805cc867a6316cea6c754bc112f0f79046b604ee515c541505f9c546f7" content_hash = "sha256:81e26f71acf1a583b21280b235fa2ac16165ac824ae8483bd391b88406421aa4"
[[metadata.targets]] [[metadata.targets]]
requires_python = ">=3.12,<3.13" requires_python = ">=3.12,<3.13"
@ -522,13 +522,13 @@ files = [
[[package]] [[package]]
name = "jax" name = "jax"
version = "0.4.35" version = "0.4.37"
requires_python = ">=3.10" requires_python = ">=3.10"
summary = "Differentiate, compile, and transform Numpy code." summary = "Differentiate, compile, and transform Numpy code."
groups = ["default"] groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\"" marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [ dependencies = [
"jaxlib<=0.4.35,>=0.4.34", "jaxlib<=0.4.37,>=0.4.36",
"ml-dtypes>=0.4.0", "ml-dtypes>=0.4.0",
"numpy>=1.24", "numpy>=1.24",
"numpy>=1.26.0; python_version >= \"3.12\"", "numpy>=1.26.0; python_version >= \"3.12\"",
@ -537,13 +537,83 @@ dependencies = [
"scipy>=1.11.1; python_version >= \"3.12\"", "scipy>=1.11.1; python_version >= \"3.12\"",
] ]
files = [ files = [
{file = "jax-0.4.35-py3-none-any.whl", hash = "sha256:fa99e909a31424abfec750019a6dd36f6acc18a6e7d40e2c0086b932cc351325"}, {file = "jax-0.4.37-py3-none-any.whl", hash = "sha256:bdc0686d7e5a944e2d38026eae632214d98dd2d91869cbcedbf1c11298ae3e3e"},
{file = "jax-0.4.35.tar.gz", hash = "sha256:c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e"}, {file = "jax-0.4.37.tar.gz", hash = "sha256:7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b"},
]
[[package]]
name = "jax-cuda12-pjrt"
version = "0.4.36"
summary = "JAX XLA PJRT Plugin for NVIDIA GPUs"
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [
{file = "jax_cuda12_pjrt-0.4.36-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1dfc0bec0820ba801b61e9421064b6e58238c430b4ad8f54043323d93c0217c6"},
{file = "jax_cuda12_pjrt-0.4.36-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e3c3705d8db7d63da9abfaebf06f5cd0667f5acb0748a5c5eb00d80041e922ed"},
]
[[package]]
name = "jax-cuda12-plugin"
version = "0.4.36"
requires_python = ">=3.10"
summary = "JAX Plugin for NVIDIA GPUs"
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"jax-cuda12-pjrt==0.4.36",
]
files = [
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:6a0b0c2bdc1da2eea2c20723a1e8f39b3cda67d24c665de936647e8091f5790d"},
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:5d4727fb519fedc06a9a984d5a0714804d81ef126a2cb60cefd5cbc4a3ea2627"},
]
[[package]]
name = "jax-cuda12-plugin"
version = "0.4.36"
extras = ["with_cuda"]
requires_python = ">=3.10"
summary = "JAX Plugin for NVIDIA GPUs"
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"jax-cuda12-plugin==0.4.36",
"nvidia-cublas-cu12>=12.1.3.1",
"nvidia-cuda-cupti-cu12>=12.1.105",
"nvidia-cuda-nvcc-cu12>=12.6.85",
"nvidia-cuda-runtime-cu12>=12.1.105",
"nvidia-cudnn-cu12<10.0,>=9.1",
"nvidia-cufft-cu12>=11.0.2.54",
"nvidia-cusolver-cu12>=11.4.5.107",
"nvidia-cusparse-cu12>=12.1.0.106",
"nvidia-nccl-cu12>=2.18.1",
"nvidia-nvjitlink-cu12>=12.1.105",
]
files = [
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:6a0b0c2bdc1da2eea2c20723a1e8f39b3cda67d24c665de936647e8091f5790d"},
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:5d4727fb519fedc06a9a984d5a0714804d81ef126a2cb60cefd5cbc4a3ea2627"},
]
[[package]]
name = "jax"
version = "0.4.37"
extras = ["cuda12"]
requires_python = ">=3.10"
summary = "Differentiate, compile, and transform Numpy code."
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [
"jax-cuda12-plugin[with_cuda]<=0.4.37,>=0.4.36",
"jax==0.4.37",
"jaxlib==0.4.36",
]
files = [
{file = "jax-0.4.37-py3-none-any.whl", hash = "sha256:bdc0686d7e5a944e2d38026eae632214d98dd2d91869cbcedbf1c11298ae3e3e"},
{file = "jax-0.4.37.tar.gz", hash = "sha256:7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b"},
] ]
[[package]] [[package]]
name = "jaxlib" name = "jaxlib"
version = "0.4.35" version = "0.4.36"
requires_python = ">=3.10" requires_python = ">=3.10"
summary = "XLA library for JAX" summary = "XLA library for JAX"
groups = ["default"] groups = ["default"]
@ -555,16 +625,11 @@ dependencies = [
"scipy>=1.11.1; python_version >= \"3.12\"", "scipy>=1.11.1; python_version >= \"3.12\"",
] ]
files = [ files = [
{file = "jaxlib-0.4.35-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:b44f3e6e9fb748bb43df914356cf9d0d0c9a6e446a12c21fe843db25ed0df65f"}, {file = "jaxlib-0.4.36-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:5972aa85f6d771ecc8cc72148c1fa64250ca33cbdf2bf24407cdee8a5299d25d"},
{file = "jaxlib-0.4.35-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:504d0a2e2117724359d99d7e3663022686dcdddd85aa14bdad02008d444481ad"}, {file = "jaxlib-0.4.36-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5597908cd10418c0b42e9af807fc8112036703533cf501a5255a8fbf4011867e"},
{file = "jaxlib-0.4.35-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:187cb6929dc139b75d952d67c33118473c1b4105525a3e5607f064e7b8efdc74"}, {file = "jaxlib-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:fbbabaa287378a78a3cf9cbe4de30a1f6f19a99116feb4bd687ff256415cd442"},
{file = "jaxlib-0.4.35-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:04d1db3bf0050d120238bfb9b686b58fefcc4d9dd9e2d96aecd3f68a1f1f5e0a"}, {file = "jaxlib-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:be295abc209c980817db0488f21f1fbc0644f87326522895e2b9b64729106357"},
{file = "jaxlib-0.4.35-cp312-cp312-win_amd64.whl", hash = "sha256:dddffce48d7e6057008999aed2d8a9daecc57a48c45a4f8c475e00880eb2e41d"}, {file = "jaxlib-0.4.36-cp312-cp312-win_amd64.whl", hash = "sha256:d4bbb5d2970628dcd3dabc28a5b97a1125ad3e06a1be822d340fd9f06f7449b3"},
{file = "jaxlib-0.4.35-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:14aeac3fea2ca1d5afb1878f72470b159cc89adb2633c5f0686f5d7c39f2ac18"},
{file = "jaxlib-0.4.35-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e8c9579e20d5ecdc4f61336cdd032710cb8c38d5ae9c4fce0cf9ea031cef21cb"},
{file = "jaxlib-0.4.35-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:7b11ad7c13f7f96f36efd303711ecac425f19ca2ddf65cf1be1541167a959ee5"},
{file = "jaxlib-0.4.35-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:0be3cf9df879d9ae1b5b92fc281f77d21f522fcbae1a48a02661026bbd9b9309"},
{file = "jaxlib-0.4.35-cp313-cp313-win_amd64.whl", hash = "sha256:330c090bb9af413f552d8a92d097e50baec6b75823430fb2966a49f5298d4c43"},
] ]
[[package]] [[package]]
@ -861,7 +926,7 @@ version = "12.4.5.8"
requires_python = ">=3" requires_python = ">=3"
summary = "CUBLAS native runtime libraries" summary = "CUBLAS native runtime libraries"
groups = ["default"] groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\"" marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [ files = [
{file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3"}, {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3"},
{file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b"}, {file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b"},
@ -874,13 +939,26 @@ version = "12.4.127"
requires_python = ">=3" requires_python = ">=3"
summary = "CUDA profiling tools runtime libs." summary = "CUDA profiling tools runtime libs."
groups = ["default"] groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\"" marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [ files = [
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a"}, {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a"},
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb"}, {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb"},
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922"}, {file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922"},
] ]
[[package]]
name = "nvidia-cuda-nvcc-cu12"
version = "12.6.85"
requires_python = ">=3"
summary = "CUDA nvcc"
groups = ["default"]
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [
{file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d75d9d74599f4d7c0865df19ed21b739e6cb77a6497a3f73d6f61e8038a765e4"},
{file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5d2edd5531b13e3daac8ffee9fc2b70a147e6088b2af2565924773d63d36d294"},
{file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-win_amd64.whl", hash = "sha256:aa04742337973dcb5bcccabb590edc8834c60ebfaf971847888d24ffef6c46b5"},
]
[[package]] [[package]]
name = "nvidia-cuda-nvrtc-cu12" name = "nvidia-cuda-nvrtc-cu12"
version = "12.4.127" version = "12.4.127"
@ -900,7 +978,7 @@ version = "12.4.127"
requires_python = ">=3" requires_python = ">=3"
summary = "CUDA Runtime native Libraries" summary = "CUDA Runtime native Libraries"
groups = ["default"] groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\"" marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [ files = [
{file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3"}, {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3"},
{file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5"}, {file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5"},
@ -913,7 +991,7 @@ version = "9.1.0.70"
requires_python = ">=3" requires_python = ">=3"
summary = "cuDNN runtime libraries" summary = "cuDNN runtime libraries"
groups = ["default"] groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\"" marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [ dependencies = [
"nvidia-cublas-cu12", "nvidia-cublas-cu12",
] ]
@ -928,7 +1006,7 @@ version = "11.2.1.3"
requires_python = ">=3" requires_python = ">=3"
summary = "CUFFT native runtime libraries" summary = "CUFFT native runtime libraries"
groups = ["default"] groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\"" marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [ dependencies = [
"nvidia-nvjitlink-cu12", "nvidia-nvjitlink-cu12",
] ]
@ -957,7 +1035,7 @@ version = "11.6.1.9"
requires_python = ">=3" requires_python = ">=3"
summary = "CUDA solver native runtime libraries" summary = "CUDA solver native runtime libraries"
groups = ["default"] groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\"" marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [ dependencies = [
"nvidia-cublas-cu12", "nvidia-cublas-cu12",
"nvidia-cusparse-cu12", "nvidia-cusparse-cu12",
@ -975,7 +1053,7 @@ version = "12.3.1.170"
requires_python = ">=3" requires_python = ">=3"
summary = "CUSPARSE native runtime libraries" summary = "CUSPARSE native runtime libraries"
groups = ["default"] groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\"" marker = "python_version >= \"3.12\" and python_version < \"3.13\""
dependencies = [ dependencies = [
"nvidia-nvjitlink-cu12", "nvidia-nvjitlink-cu12",
] ]
@ -991,7 +1069,7 @@ version = "2.21.5"
requires_python = ">=3" requires_python = ">=3"
summary = "NVIDIA Collective Communication Library (NCCL) Runtime" summary = "NVIDIA Collective Communication Library (NCCL) Runtime"
groups = ["default"] groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\"" marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [ files = [
{file = "nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0"}, {file = "nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0"},
] ]
@ -1002,7 +1080,7 @@ version = "12.4.127"
requires_python = ">=3" requires_python = ">=3"
summary = "Nvidia JIT LTO Library" summary = "Nvidia JIT LTO Library"
groups = ["default"] groups = ["default"]
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\"" marker = "python_version >= \"3.12\" and python_version < \"3.13\""
files = [ files = [
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"}, {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"},
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"}, {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},

View file

@ -5,7 +5,7 @@ description = "A solar car racing simulation library and GUI tool"
authors = [ authors = [
{name = "saji", email = "saji@saji.dev"}, {name = "saji", email = "saji@saji.dev"},
] ]
dependencies = ["pyqtgraph>=0.13.7", "jax>=0.4.35", "pytest>=8.3.3", "pyside6>=6.8.0.2", "matplotlib>=3.9.2", "gymnasium[jax]>=1.0.0", "pyvista>=0.44.2", "pyvistaqt>=0.11.1", "stable-baselines3>=2.4.0"] dependencies = ["pyqtgraph>=0.13.7", "jax[cuda12]>=0.4.37", "pytest>=8.3.3", "pyside6>=6.8.0.2", "matplotlib>=3.9.2", "gymnasium[jax]>=1.0.0", "pyvista>=0.44.2", "pyvistaqt>=0.11.1", "stable-baselines3>=2.4.0"]
requires-python = ">=3.10,<3.13" requires-python = ">=3.10,<3.13"
readme = "README.md" readme = "README.md"
license = {text = "MIT"} license = {text = "MIT"}

View file

@ -1,3 +1,5 @@
"""Physical equations and models for building a simulation environment"""
import jax.numpy as jnp import jax.numpy as jnp
import jax import jax
from jax import grad, jit, vmap, lax from jax import grad, jit, vmap, lax
@ -5,7 +7,11 @@ from functools import partial
from typing import NamedTuple, Tuple from typing import NamedTuple, Tuple
from solarcarsim.noise import fractal_noise_1d, generate_elevation_profile, generate_wind_field from solarcarsim.noise import (
fractal_noise_1d,
generate_wind_field,
)
class MotorParams(NamedTuple): class MotorParams(NamedTuple):
kv: float kv: float
@ -16,33 +22,37 @@ class MotorParams(NamedTuple):
class BatteryParams(NamedTuple): class BatteryParams(NamedTuple):
shape: Tuple[int, int] # (series,parallel) array of batteries shape: Tuple[int, int] # (series,parallel) array of batteries
resistance: float # ohms resistance: float # ohms
initial_energy: float # joules initial_energy: float # joules
class CarParams(NamedTuple): class CarParams(NamedTuple):
""" Physical Data for Solar Car Parameters """ """Physical Data for Solar Car Parameters"""
mass: float = 800 # kg
frontal_area: float = 1.3 # m^2
drag_coeff: float = 0.18 # drag coefficient, dimensionless
rolling_coeff: float = 0.002 # rolling resistance.
moter_eff: float = 0.93 # 0 < x < 1 scaling factor
wheel_radius: float = 0.23 # wheel radius in meters
max_speed: float = 30.0 # m/s top speed
solar_area: float = 5.0 # m^2, typically 5.0
solar_eff: float = 0.20 # 0 < x < 1, typically ~.25
n_motors: int = 2 # how many motors we have.
motor: MotorParams = MotorParams(8.43, 1.1, 100.0, 0.001, 0.001) # mitsuba m2090 estimate
battery: BatteryParams = BatteryParams((36,19), 0.0126, 66.6e3) # freebasing 50s pack.
mass: float = 800 # kg
frontal_area: float = 1.3 # m^2
drag_coeff: float = 0.18 # drag coefficient, dimensionless
rolling_coeff: float = 0.002 # rolling resistance.
moter_eff: float = 0.93 # 0 < x < 1 scaling factor
wheel_radius: float = 0.23 # wheel radius in meters
max_speed: float = 30.0 # m/s top speed
solar_area: float = 5.0 # m^2, typically 5.0
solar_eff: float = 0.20 # 0 < x < 1, typically ~.25
n_motors: int = 2 # how many motors we have.
motor: MotorParams = MotorParams(
8.43, 1.1, 100.0, 0.001, 0.001
) # mitsuba m2090 estimate
battery: BatteryParams = BatteryParams(
(36, 19), 0.0126, 66.6e3
) # freebasing 50s pack.
def DefaultCar() -> CarParams: def DefaultCar() -> CarParams:
""" Creates a basic car """ """Creates a basic car"""
return CarParams(1000, 1.3, 0.18, 0.002, 0.85, 5.0, 0.23) return CarParams(1000, 1.3, 0.18, 0.002, 0.85, 5.0, 0.23)
# some physics equations using jax # some physics equations using jax
@ -50,20 +60,25 @@ def DefaultCar() -> CarParams:
def normal_force(mass, theta): def normal_force(mass, theta):
return mass * 9.8 * jnp.cos(theta) return mass * 9.8 * jnp.cos(theta)
@jit @jit
def downslope_force(mass, theta): def downslope_force(mass, theta):
return mass * 9.8 * jnp.sin(theta) return mass * 9.8 * jnp.sin(theta)
@partial(jit, static_argnames=['crr'])
@partial(jit, static_argnames=["crr"])
def rolling_force(mass, theta, crr): def rolling_force(mass, theta, crr):
return normal_force(mass, theta) * crr return normal_force(mass, theta) * crr
@partial(jit, static_argnames=['area', 'cd', 'rho'])
@partial(jit, static_argnames=["area", "cd", "rho"])
def drag_force(u, area, cd, rho): def drag_force(u, area, cd, rho):
return 0.5 * rho * jnp.pow(u, 2) * cd * area return 0.5 * rho * jnp.pow(u, 2) * cd * area
# we can use those forces above to determine what forces we have to overcome. Sum(F)=0 # we can use those forces above to determine what forces we have to overcome. Sum(F)=0
# @partial(jit, static_argnums=(2,)) # @partial(jit, static_argnums=(2,))
@jit @jit
def bldc_power_draw(torque, velocity, params: MotorParams): def bldc_power_draw(torque, velocity, params: MotorParams):
@ -99,10 +114,10 @@ def bldc_power_draw(torque, velocity, params: MotorParams):
return total_power return total_power
# @partial(jit, static_argnames=['resistance', 'kt', 'kv', 'vmax', 'Cf']) # @partial(jit, static_argnames=['resistance', 'kt', 'kv', 'vmax', 'Cf'])
@jit @jit
def bldc_torque(velocity, current_limit, resistance, kt, kv, vmax, Cf): def bldc_torque(velocity, current_limit, resistance, kt, kv, vmax, Cf):
bemf = velocity / kv bemf = velocity / kv
v_avail = jnp.clip(vmax - bemf, 0.0, vmax) v_avail = jnp.clip(vmax - bemf, 0.0, vmax)
current = jnp.clip(v_avail / resistance, 0.0, current_limit) current = jnp.clip(v_avail / resistance, 0.0, current_limit)
@ -113,8 +128,15 @@ def bldc_torque(velocity, current_limit, resistance, kt, kv, vmax, Cf):
stall_torque = kt * current_limit stall_torque = kt * current_limit
return jnp.where(velocity < 0.01, stall_torque, net_torque) return jnp.where(velocity < 0.01, stall_torque, net_torque)
@partial(jit, static_argnums=(1,2,))
def battery_powerloss(current,cell_r, battery_shape: Tuple[int,int]): @partial(
jit,
static_argnums=(
1,
2,
),
)
def battery_powerloss(current, cell_r, battery_shape: Tuple[int, int]):
r_array = jnp.full(battery_shape, cell_r) r_array = jnp.full(battery_shape, cell_r)
branch_current = current / battery_shape[1] branch_current = current / battery_shape[1]
I_array = jnp.full(battery_shape, branch_current) I_array = jnp.full(battery_shape, branch_current)
@ -122,7 +144,6 @@ def battery_powerloss(current,cell_r, battery_shape: Tuple[int,int]):
return jnp.sum(cell_Ploss) return jnp.sum(cell_Ploss)
def forward(state, timestep, control, params: CarParams): def forward(state, timestep, control, params: CarParams):
# state is (position, time, energy) # state is (position, time, energy)
# control is velocity # control is velocity
@ -139,18 +160,24 @@ def forward(state, timestep, control, params: CarParams):
# determine the power needed to make this force # determine the power needed to make this force
tau = params.wheel_radius * totalf tau = params.wheel_radius * totalf
pdraw = bldc_power_draw(tau, control, params.motor) pdraw = bldc_power_draw(tau, control, params.motor)
net_power = 0 - pdraw # watts aka j/s net_power = 0 - pdraw # watts aka j/s
# TODO: calculate battery-based power losses. # TODO: calculate battery-based power losses.
# TODO: support regenerative braking when going downhill # TODO: support regenerative braking when going downhill
# TODO: delta x = cos(theta) * velocity * timestep # TODO: delta x = cos(theta) * velocity * timestep
new_state = jnp.array([state[0] + control * timestep, state[1] + timestep, state[2] + net_power * timestep]) new_state = jnp.array(
[
state[0] + control * timestep,
state[1] + timestep,
state[2] + net_power * timestep,
]
)
return new_state return new_state
def make_environment(seed): def make_environment(seed):
""" Generate a race environment: terrain function, wind function, wrapped forward function.""" """Generate a race environment: terrain function, wind function, wrapped forward function."""
key, subkey = jax.random.split(seed) key, subkey = jax.random.split(seed)
wind = generate_wind_field(subkey, 10000, 600, spatial_scale=1000) wind = generate_wind_field(subkey, 10000, 600, spatial_scale=1000)
key, subkey = jax.random.split(key) key, subkey = jax.random.split(key)
@ -161,46 +188,3 @@ def make_environment(seed):
return wind, elevation, slope return wind, elevation, slope
@partial(jit, static_argnames=['params'])
def forwardv2(state, control, delta_time, wind, elevation, slope, params):
pos = jnp.astype(jnp.round(state[0]), "int32")
time = jnp.astype(jnp.round(state[1]), "int32")
theta = slope[pos]
velocity = control * params.max_speed
# sum up the forces acting on the car
dragf = drag_force(velocity, params.frontal_area, params.drag_coeff, 1.184)
rollf = rolling_force(params.mass, theta, params.rolling_coeff)
hillforce = downslope_force(params.mass, theta)
windf = wind[pos, time]
totalf = dragf + rollf + hillforce + windf
# with the sum of forces, determine the needed torque at the wheels, and then power
tau = params.wheel_radius * totalf
pdraw = bldc_power_draw(tau, velocity, params.motor)
# determine the energy needed to do this power for the time step
net_power = state[2] - delta_time * pdraw # joules
dpos = jnp.cos(theta) * velocity * delta_time
dist_remaining = 10000.0 - dpos
time_remaining = 600 - (state[1] + delta_time)
return jnp.array([dpos, state[1] + delta_time, net_power, dist_remaining, time_remaining])
def reward(state):
progress = state[0] / 10000 * 100
energy_usage = -10 * state[2]
time_factor = (1.0 - (state[1] / 600)) * 50
reward = progress + energy_usage + time_factor
return reward
# now we have an environment tuned in.
# we want to take an environment, and bind it to the forward function
def make_simulator(params: CarParams, wind, elevation, slope):
def reward(state):
progress = state[0] / 10000 * 100
energy_usage = -10 * state[2]
time_factor = (1.0 - (state[1] / 600)) * 50
reward = progress + energy_usage + time_factor
return reward
return forwardv2, reward

View file

@ -2,10 +2,45 @@ import gymnasium as gym
import solarcarsim.physsim as sim import solarcarsim.physsim as sim
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
import numpy as np from jax import jit
from typing import Any
from functools import partial from functools import partial
from jax import vmap from solarcarsim.physsim import drag_force, rolling_force, downslope_force, bldc_power_draw
@partial(jit, static_argnames=["params"])
def forwardv2(state, control, delta_time, wind, elevation, slope, params):
pos = jnp.astype(jnp.round(state[0]), "int32")
time = jnp.astype(jnp.round(state[1]), "int32")
theta = slope[pos]
velocity = control * params.max_speed
# sum up the forces acting on the car
windspeed = wind[pos, time]
dragf = sim.drag_force(velocity + windspeed, params.frontal_area, params.drag_coeff, 1.184)
rollf = sim.rolling_force(params.mass, theta, params.rolling_coeff)
hillforce = sim.downslope_force(params.mass, theta)
totalf = dragf + rollf + hillforce
# with the sum of forces, determine the needed torque at the wheels, and then power
tau = params.wheel_radius * totalf
pdraw = bldc_power_draw(tau, velocity, params.motor)
# determine the energy needed to do this power for the time step
net_power = state[2] - delta_time * pdraw # joules
dpos = jnp.cos(theta) * velocity * delta_time
dist_remaining = 10000.0 - dpos
time_remaining = 600 - (state[1] + delta_time)
return jnp.array(
[dpos, state[1] + delta_time, net_power, dist_remaining, time_remaining]
)
def reward(state, prev_energy):
progress = state[0] / 10000 * 100
energy_usage = 10 * (state[2] - prev_energy) # current energy < previous energy.
time_factor = (1.0 - (state[1] / 600)) * 50
reward = progress + energy_usage + time_factor
return reward
class SolarRaceV1(gym.Env): class SolarRaceV1(gym.Env):
"""A primitive hill climber. Aims to solve the given route optimizing """A primitive hill climber. Aims to solve the given route optimizing
@ -46,8 +81,8 @@ class SolarRaceV1(gym.Env):
self._reset_sim(jax.random.key(seed)) self._reset_sim(jax.random.key(seed))
self._timestep = timestep self._timestep = timestep
self._car = car self._car = car
self._simstep = sim.forwardv2 self._simstep = forwardv2
self._simreward = sim.reward self._simreward = reward
self.observation_space = gym.spaces.Dict( self.observation_space = gym.spaces.Dict(
{ {
@ -73,8 +108,10 @@ class SolarRaceV1(gym.Env):
def step(self, action): def step(self, action):
wind, elevation, slope = self._environment wind, elevation, slope = self._environment
old_energy = self._state[2]
self._state = self._simstep(self._state, action, self._timestep,wind, elevation, slope, self._car) self._state = self._simstep(self._state, action, self._timestep,wind, elevation, slope, self._car)
reward = self._simreward(self._state)[0] reward = self._simreward(self._state, old_energy)[0]
terminated = False terminated = False
truncated = False truncated = False
if jnp.all(self._state[0] > 10000): if jnp.all(self._state[0] > 10000):
@ -83,3 +120,5 @@ class SolarRaceV1(gym.Env):
truncated = True truncated = True
return self._get_obs(), reward, terminated, truncated, {} return self._get_obs(), reward, terminated, truncated, {}

14
src/solarcarsim/simv2.py Normal file
View file

@ -0,0 +1,14 @@
""" Second-generation simulator. More functional, cleaner code, faster """
from typing import NamedTuple
import jax
import jax.numpy as jnp
class SimState(NamedTuple):
position: float
time: float
energy: float
distance_remaining: float
time_remaining: float