jax cuda, separate v1 sim
This commit is contained in:
parent
104dfab637
commit
aaac5df404
File diff suppressed because it is too large
Load diff
128
pdm.lock
128
pdm.lock
|
@ -5,7 +5,7 @@
|
|||
groups = ["default", "dev"]
|
||||
strategy = ["inherit_metadata"]
|
||||
lock_version = "4.5.0"
|
||||
content_hash = "sha256:a1edba805cc867a6316cea6c754bc112f0f79046b604ee515c541505f9c546f7"
|
||||
content_hash = "sha256:81e26f71acf1a583b21280b235fa2ac16165ac824ae8483bd391b88406421aa4"
|
||||
|
||||
[[metadata.targets]]
|
||||
requires_python = ">=3.12,<3.13"
|
||||
|
@ -522,13 +522,13 @@ files = [
|
|||
|
||||
[[package]]
|
||||
name = "jax"
|
||||
version = "0.4.35"
|
||||
version = "0.4.37"
|
||||
requires_python = ">=3.10"
|
||||
summary = "Differentiate, compile, and transform Numpy code."
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"jaxlib<=0.4.35,>=0.4.34",
|
||||
"jaxlib<=0.4.37,>=0.4.36",
|
||||
"ml-dtypes>=0.4.0",
|
||||
"numpy>=1.24",
|
||||
"numpy>=1.26.0; python_version >= \"3.12\"",
|
||||
|
@ -537,13 +537,83 @@ dependencies = [
|
|||
"scipy>=1.11.1; python_version >= \"3.12\"",
|
||||
]
|
||||
files = [
|
||||
{file = "jax-0.4.35-py3-none-any.whl", hash = "sha256:fa99e909a31424abfec750019a6dd36f6acc18a6e7d40e2c0086b932cc351325"},
|
||||
{file = "jax-0.4.35.tar.gz", hash = "sha256:c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e"},
|
||||
{file = "jax-0.4.37-py3-none-any.whl", hash = "sha256:bdc0686d7e5a944e2d38026eae632214d98dd2d91869cbcedbf1c11298ae3e3e"},
|
||||
{file = "jax-0.4.37.tar.gz", hash = "sha256:7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jax-cuda12-pjrt"
|
||||
version = "0.4.36"
|
||||
summary = "JAX XLA PJRT Plugin for NVIDIA GPUs"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "jax_cuda12_pjrt-0.4.36-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1dfc0bec0820ba801b61e9421064b6e58238c430b4ad8f54043323d93c0217c6"},
|
||||
{file = "jax_cuda12_pjrt-0.4.36-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e3c3705d8db7d63da9abfaebf06f5cd0667f5acb0748a5c5eb00d80041e922ed"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jax-cuda12-plugin"
|
||||
version = "0.4.36"
|
||||
requires_python = ">=3.10"
|
||||
summary = "JAX Plugin for NVIDIA GPUs"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"jax-cuda12-pjrt==0.4.36",
|
||||
]
|
||||
files = [
|
||||
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:6a0b0c2bdc1da2eea2c20723a1e8f39b3cda67d24c665de936647e8091f5790d"},
|
||||
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:5d4727fb519fedc06a9a984d5a0714804d81ef126a2cb60cefd5cbc4a3ea2627"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jax-cuda12-plugin"
|
||||
version = "0.4.36"
|
||||
extras = ["with_cuda"]
|
||||
requires_python = ">=3.10"
|
||||
summary = "JAX Plugin for NVIDIA GPUs"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"jax-cuda12-plugin==0.4.36",
|
||||
"nvidia-cublas-cu12>=12.1.3.1",
|
||||
"nvidia-cuda-cupti-cu12>=12.1.105",
|
||||
"nvidia-cuda-nvcc-cu12>=12.6.85",
|
||||
"nvidia-cuda-runtime-cu12>=12.1.105",
|
||||
"nvidia-cudnn-cu12<10.0,>=9.1",
|
||||
"nvidia-cufft-cu12>=11.0.2.54",
|
||||
"nvidia-cusolver-cu12>=11.4.5.107",
|
||||
"nvidia-cusparse-cu12>=12.1.0.106",
|
||||
"nvidia-nccl-cu12>=2.18.1",
|
||||
"nvidia-nvjitlink-cu12>=12.1.105",
|
||||
]
|
||||
files = [
|
||||
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:6a0b0c2bdc1da2eea2c20723a1e8f39b3cda67d24c665de936647e8091f5790d"},
|
||||
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:5d4727fb519fedc06a9a984d5a0714804d81ef126a2cb60cefd5cbc4a3ea2627"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jax"
|
||||
version = "0.4.37"
|
||||
extras = ["cuda12"]
|
||||
requires_python = ">=3.10"
|
||||
summary = "Differentiate, compile, and transform Numpy code."
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"jax-cuda12-plugin[with_cuda]<=0.4.37,>=0.4.36",
|
||||
"jax==0.4.37",
|
||||
"jaxlib==0.4.36",
|
||||
]
|
||||
files = [
|
||||
{file = "jax-0.4.37-py3-none-any.whl", hash = "sha256:bdc0686d7e5a944e2d38026eae632214d98dd2d91869cbcedbf1c11298ae3e3e"},
|
||||
{file = "jax-0.4.37.tar.gz", hash = "sha256:7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jaxlib"
|
||||
version = "0.4.35"
|
||||
version = "0.4.36"
|
||||
requires_python = ">=3.10"
|
||||
summary = "XLA library for JAX"
|
||||
groups = ["default"]
|
||||
|
@ -555,16 +625,11 @@ dependencies = [
|
|||
"scipy>=1.11.1; python_version >= \"3.12\"",
|
||||
]
|
||||
files = [
|
||||
{file = "jaxlib-0.4.35-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:b44f3e6e9fb748bb43df914356cf9d0d0c9a6e446a12c21fe843db25ed0df65f"},
|
||||
{file = "jaxlib-0.4.35-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:504d0a2e2117724359d99d7e3663022686dcdddd85aa14bdad02008d444481ad"},
|
||||
{file = "jaxlib-0.4.35-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:187cb6929dc139b75d952d67c33118473c1b4105525a3e5607f064e7b8efdc74"},
|
||||
{file = "jaxlib-0.4.35-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:04d1db3bf0050d120238bfb9b686b58fefcc4d9dd9e2d96aecd3f68a1f1f5e0a"},
|
||||
{file = "jaxlib-0.4.35-cp312-cp312-win_amd64.whl", hash = "sha256:dddffce48d7e6057008999aed2d8a9daecc57a48c45a4f8c475e00880eb2e41d"},
|
||||
{file = "jaxlib-0.4.35-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:14aeac3fea2ca1d5afb1878f72470b159cc89adb2633c5f0686f5d7c39f2ac18"},
|
||||
{file = "jaxlib-0.4.35-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e8c9579e20d5ecdc4f61336cdd032710cb8c38d5ae9c4fce0cf9ea031cef21cb"},
|
||||
{file = "jaxlib-0.4.35-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:7b11ad7c13f7f96f36efd303711ecac425f19ca2ddf65cf1be1541167a959ee5"},
|
||||
{file = "jaxlib-0.4.35-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:0be3cf9df879d9ae1b5b92fc281f77d21f522fcbae1a48a02661026bbd9b9309"},
|
||||
{file = "jaxlib-0.4.35-cp313-cp313-win_amd64.whl", hash = "sha256:330c090bb9af413f552d8a92d097e50baec6b75823430fb2966a49f5298d4c43"},
|
||||
{file = "jaxlib-0.4.36-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:5972aa85f6d771ecc8cc72148c1fa64250ca33cbdf2bf24407cdee8a5299d25d"},
|
||||
{file = "jaxlib-0.4.36-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5597908cd10418c0b42e9af807fc8112036703533cf501a5255a8fbf4011867e"},
|
||||
{file = "jaxlib-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:fbbabaa287378a78a3cf9cbe4de30a1f6f19a99116feb4bd687ff256415cd442"},
|
||||
{file = "jaxlib-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:be295abc209c980817db0488f21f1fbc0644f87326522895e2b9b64729106357"},
|
||||
{file = "jaxlib-0.4.36-cp312-cp312-win_amd64.whl", hash = "sha256:d4bbb5d2970628dcd3dabc28a5b97a1125ad3e06a1be822d340fd9f06f7449b3"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -861,7 +926,7 @@ version = "12.4.5.8"
|
|||
requires_python = ">=3"
|
||||
summary = "CUBLAS native runtime libraries"
|
||||
groups = ["default"]
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3"},
|
||||
{file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b"},
|
||||
|
@ -874,13 +939,26 @@ version = "12.4.127"
|
|||
requires_python = ">=3"
|
||||
summary = "CUDA profiling tools runtime libs."
|
||||
groups = ["default"]
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a"},
|
||||
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb"},
|
||||
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nvidia-cuda-nvcc-cu12"
|
||||
version = "12.6.85"
|
||||
requires_python = ">=3"
|
||||
summary = "CUDA nvcc"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d75d9d74599f4d7c0865df19ed21b739e6cb77a6497a3f73d6f61e8038a765e4"},
|
||||
{file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5d2edd5531b13e3daac8ffee9fc2b70a147e6088b2af2565924773d63d36d294"},
|
||||
{file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-win_amd64.whl", hash = "sha256:aa04742337973dcb5bcccabb590edc8834c60ebfaf971847888d24ffef6c46b5"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nvidia-cuda-nvrtc-cu12"
|
||||
version = "12.4.127"
|
||||
|
@ -900,7 +978,7 @@ version = "12.4.127"
|
|||
requires_python = ">=3"
|
||||
summary = "CUDA Runtime native Libraries"
|
||||
groups = ["default"]
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3"},
|
||||
{file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5"},
|
||||
|
@ -913,7 +991,7 @@ version = "9.1.0.70"
|
|||
requires_python = ">=3"
|
||||
summary = "cuDNN runtime libraries"
|
||||
groups = ["default"]
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"nvidia-cublas-cu12",
|
||||
]
|
||||
|
@ -928,7 +1006,7 @@ version = "11.2.1.3"
|
|||
requires_python = ">=3"
|
||||
summary = "CUFFT native runtime libraries"
|
||||
groups = ["default"]
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"nvidia-nvjitlink-cu12",
|
||||
]
|
||||
|
@ -957,7 +1035,7 @@ version = "11.6.1.9"
|
|||
requires_python = ">=3"
|
||||
summary = "CUDA solver native runtime libraries"
|
||||
groups = ["default"]
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"nvidia-cublas-cu12",
|
||||
"nvidia-cusparse-cu12",
|
||||
|
@ -975,7 +1053,7 @@ version = "12.3.1.170"
|
|||
requires_python = ">=3"
|
||||
summary = "CUSPARSE native runtime libraries"
|
||||
groups = ["default"]
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"nvidia-nvjitlink-cu12",
|
||||
]
|
||||
|
@ -991,7 +1069,7 @@ version = "2.21.5"
|
|||
requires_python = ">=3"
|
||||
summary = "NVIDIA Collective Communication Library (NCCL) Runtime"
|
||||
groups = ["default"]
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0"},
|
||||
]
|
||||
|
@ -1002,7 +1080,7 @@ version = "12.4.127"
|
|||
requires_python = ">=3"
|
||||
summary = "Nvidia JIT LTO Library"
|
||||
groups = ["default"]
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"},
|
||||
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
|
||||
|
|
|
@ -5,7 +5,7 @@ description = "A solar car racing simulation library and GUI tool"
|
|||
authors = [
|
||||
{name = "saji", email = "saji@saji.dev"},
|
||||
]
|
||||
dependencies = ["pyqtgraph>=0.13.7", "jax>=0.4.35", "pytest>=8.3.3", "pyside6>=6.8.0.2", "matplotlib>=3.9.2", "gymnasium[jax]>=1.0.0", "pyvista>=0.44.2", "pyvistaqt>=0.11.1", "stable-baselines3>=2.4.0"]
|
||||
dependencies = ["pyqtgraph>=0.13.7", "jax[cuda12]>=0.4.37", "pytest>=8.3.3", "pyside6>=6.8.0.2", "matplotlib>=3.9.2", "gymnasium[jax]>=1.0.0", "pyvista>=0.44.2", "pyvistaqt>=0.11.1", "stable-baselines3>=2.4.0"]
|
||||
requires-python = ">=3.10,<3.13"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT"}
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
"""Physical equations and models for building a simulation environment"""
|
||||
|
||||
import jax.numpy as jnp
|
||||
import jax
|
||||
from jax import grad, jit, vmap, lax
|
||||
|
@ -5,7 +7,11 @@ from functools import partial
|
|||
|
||||
from typing import NamedTuple, Tuple
|
||||
|
||||
from solarcarsim.noise import fractal_noise_1d, generate_elevation_profile, generate_wind_field
|
||||
from solarcarsim.noise import (
|
||||
fractal_noise_1d,
|
||||
generate_wind_field,
|
||||
)
|
||||
|
||||
|
||||
class MotorParams(NamedTuple):
|
||||
kv: float
|
||||
|
@ -16,33 +22,37 @@ class MotorParams(NamedTuple):
|
|||
|
||||
|
||||
class BatteryParams(NamedTuple):
|
||||
shape: Tuple[int, int] # (series,parallel) array of batteries
|
||||
resistance: float # ohms
|
||||
initial_energy: float # joules
|
||||
shape: Tuple[int, int] # (series,parallel) array of batteries
|
||||
resistance: float # ohms
|
||||
initial_energy: float # joules
|
||||
|
||||
|
||||
class CarParams(NamedTuple):
|
||||
""" Physical Data for Solar Car Parameters """
|
||||
mass: float = 800 # kg
|
||||
frontal_area: float = 1.3 # m^2
|
||||
drag_coeff: float = 0.18 # drag coefficient, dimensionless
|
||||
rolling_coeff: float = 0.002 # rolling resistance.
|
||||
moter_eff: float = 0.93 # 0 < x < 1 scaling factor
|
||||
wheel_radius: float = 0.23 # wheel radius in meters
|
||||
max_speed: float = 30.0 # m/s top speed
|
||||
solar_area: float = 5.0 # m^2, typically 5.0
|
||||
solar_eff: float = 0.20 # 0 < x < 1, typically ~.25
|
||||
n_motors: int = 2 # how many motors we have.
|
||||
motor: MotorParams = MotorParams(8.43, 1.1, 100.0, 0.001, 0.001) # mitsuba m2090 estimate
|
||||
battery: BatteryParams = BatteryParams((36,19), 0.0126, 66.6e3) # freebasing 50s pack.
|
||||
"""Physical Data for Solar Car Parameters"""
|
||||
|
||||
mass: float = 800 # kg
|
||||
frontal_area: float = 1.3 # m^2
|
||||
drag_coeff: float = 0.18 # drag coefficient, dimensionless
|
||||
rolling_coeff: float = 0.002 # rolling resistance.
|
||||
moter_eff: float = 0.93 # 0 < x < 1 scaling factor
|
||||
wheel_radius: float = 0.23 # wheel radius in meters
|
||||
max_speed: float = 30.0 # m/s top speed
|
||||
solar_area: float = 5.0 # m^2, typically 5.0
|
||||
solar_eff: float = 0.20 # 0 < x < 1, typically ~.25
|
||||
n_motors: int = 2 # how many motors we have.
|
||||
motor: MotorParams = MotorParams(
|
||||
8.43, 1.1, 100.0, 0.001, 0.001
|
||||
) # mitsuba m2090 estimate
|
||||
battery: BatteryParams = BatteryParams(
|
||||
(36, 19), 0.0126, 66.6e3
|
||||
) # freebasing 50s pack.
|
||||
|
||||
|
||||
def DefaultCar() -> CarParams:
|
||||
""" Creates a basic car """
|
||||
"""Creates a basic car"""
|
||||
return CarParams(1000, 1.3, 0.18, 0.002, 0.85, 5.0, 0.23)
|
||||
|
||||
|
||||
|
||||
# some physics equations using jax
|
||||
|
||||
|
||||
|
@ -50,26 +60,31 @@ def DefaultCar() -> CarParams:
|
|||
def normal_force(mass, theta):
|
||||
return mass * 9.8 * jnp.cos(theta)
|
||||
|
||||
|
||||
@jit
|
||||
def downslope_force(mass, theta):
|
||||
return mass * 9.8 * jnp.sin(theta)
|
||||
|
||||
@partial(jit, static_argnames=['crr'])
|
||||
|
||||
@partial(jit, static_argnames=["crr"])
|
||||
def rolling_force(mass, theta, crr):
|
||||
return normal_force(mass, theta) * crr
|
||||
|
||||
@partial(jit, static_argnames=['area', 'cd', 'rho'])
|
||||
|
||||
@partial(jit, static_argnames=["area", "cd", "rho"])
|
||||
def drag_force(u, area, cd, rho):
|
||||
return 0.5 * rho * jnp.pow(u, 2) * cd * area
|
||||
|
||||
|
||||
# we can use those forces above to determine what forces we have to overcome. Sum(F)=0
|
||||
|
||||
|
||||
# @partial(jit, static_argnums=(2,))
|
||||
@jit
|
||||
def bldc_power_draw(torque, velocity, params: MotorParams):
|
||||
"""
|
||||
Approximates power draw of a BLDC motor outputting a torque at a given velocity
|
||||
|
||||
|
||||
Args:
|
||||
torq: Applied force in Newton/meters
|
||||
velocity: Angular velocity in rad/s
|
||||
|
@ -77,32 +92,32 @@ def bldc_power_draw(torque, velocity, params: MotorParams):
|
|||
kt: Torque constant (N⋅m/A)
|
||||
friction_coeff: Mechanical friction coefficient
|
||||
iron_loss_coeff: Iron loss coefficient (core losses)
|
||||
|
||||
|
||||
Returns:
|
||||
Total electrical power draw in Watts
|
||||
"""
|
||||
|
||||
|
||||
# Current required for torque (simplified relationship)
|
||||
current = torque / params.kt
|
||||
|
||||
|
||||
# Copper losses (I²R)
|
||||
copper_losses = params.resistance * current**2
|
||||
copper_losses = params.resistance * current**2
|
||||
# Mechanical friction losses
|
||||
friction_losses = params.friction_coeff * velocity**2
|
||||
friction_losses = params.friction_coeff * velocity**2
|
||||
# Iron losses (simplified model - primarily dependent on speed)
|
||||
iron_losses = params.iron_coeff * velocity**2
|
||||
iron_losses = params.iron_coeff * velocity**2
|
||||
# Mechanical power output
|
||||
mechanical_power = torque * velocity
|
||||
|
||||
|
||||
# Total electrical power input
|
||||
total_power = mechanical_power + copper_losses + friction_losses + iron_losses
|
||||
|
||||
|
||||
return total_power
|
||||
|
||||
|
||||
# @partial(jit, static_argnames=['resistance', 'kt', 'kv', 'vmax', 'Cf'])
|
||||
@jit
|
||||
def bldc_torque(velocity, current_limit, resistance, kt, kv, vmax, Cf):
|
||||
|
||||
bemf = velocity / kv
|
||||
v_avail = jnp.clip(vmax - bemf, 0.0, vmax)
|
||||
current = jnp.clip(v_avail / resistance, 0.0, current_limit)
|
||||
|
@ -113,8 +128,15 @@ def bldc_torque(velocity, current_limit, resistance, kt, kv, vmax, Cf):
|
|||
stall_torque = kt * current_limit
|
||||
return jnp.where(velocity < 0.01, stall_torque, net_torque)
|
||||
|
||||
@partial(jit, static_argnums=(1,2,))
|
||||
def battery_powerloss(current,cell_r, battery_shape: Tuple[int,int]):
|
||||
|
||||
@partial(
|
||||
jit,
|
||||
static_argnums=(
|
||||
1,
|
||||
2,
|
||||
),
|
||||
)
|
||||
def battery_powerloss(current, cell_r, battery_shape: Tuple[int, int]):
|
||||
r_array = jnp.full(battery_shape, cell_r)
|
||||
branch_current = current / battery_shape[1]
|
||||
I_array = jnp.full(battery_shape, branch_current)
|
||||
|
@ -122,7 +144,6 @@ def battery_powerloss(current,cell_r, battery_shape: Tuple[int,int]):
|
|||
return jnp.sum(cell_Ploss)
|
||||
|
||||
|
||||
|
||||
def forward(state, timestep, control, params: CarParams):
|
||||
# state is (position, time, energy)
|
||||
# control is velocity
|
||||
|
@ -130,7 +151,7 @@ def forward(state, timestep, control, params: CarParams):
|
|||
# params is the params dictionary.
|
||||
# returns the next state with (position', time + timestep, energy')
|
||||
# TODO: terrain, weather, solar
|
||||
|
||||
|
||||
# determine the forces acting on the car.
|
||||
dragf = drag_force(control, params.frontal_area, params.drag_coeff, 1.184)
|
||||
rollf = rolling_force(params.mass, 0, params.rolling_coeff)
|
||||
|
@ -139,18 +160,24 @@ def forward(state, timestep, control, params: CarParams):
|
|||
# determine the power needed to make this force
|
||||
tau = params.wheel_radius * totalf
|
||||
pdraw = bldc_power_draw(tau, control, params.motor)
|
||||
net_power = 0 - pdraw # watts aka j/s
|
||||
|
||||
net_power = 0 - pdraw # watts aka j/s
|
||||
|
||||
# TODO: calculate battery-based power losses.
|
||||
# TODO: support regenerative braking when going downhill
|
||||
# TODO: delta x = cos(theta) * velocity * timestep
|
||||
|
||||
new_state = jnp.array([state[0] + control * timestep, state[1] + timestep, state[2] + net_power * timestep])
|
||||
new_state = jnp.array(
|
||||
[
|
||||
state[0] + control * timestep,
|
||||
state[1] + timestep,
|
||||
state[2] + net_power * timestep,
|
||||
]
|
||||
)
|
||||
return new_state
|
||||
|
||||
|
||||
def make_environment(seed):
|
||||
""" Generate a race environment: terrain function, wind function, wrapped forward function."""
|
||||
"""Generate a race environment: terrain function, wind function, wrapped forward function."""
|
||||
key, subkey = jax.random.split(seed)
|
||||
wind = generate_wind_field(subkey, 10000, 600, spatial_scale=1000)
|
||||
key, subkey = jax.random.split(key)
|
||||
|
@ -161,46 +188,3 @@ def make_environment(seed):
|
|||
|
||||
return wind, elevation, slope
|
||||
|
||||
@partial(jit, static_argnames=['params'])
|
||||
def forwardv2(state, control, delta_time, wind, elevation, slope, params):
|
||||
pos = jnp.astype(jnp.round(state[0]), "int32")
|
||||
time = jnp.astype(jnp.round(state[1]), "int32")
|
||||
theta = slope[pos]
|
||||
|
||||
velocity = control * params.max_speed
|
||||
|
||||
# sum up the forces acting on the car
|
||||
dragf = drag_force(velocity, params.frontal_area, params.drag_coeff, 1.184)
|
||||
rollf = rolling_force(params.mass, theta, params.rolling_coeff)
|
||||
hillforce = downslope_force(params.mass, theta)
|
||||
windf = wind[pos, time]
|
||||
totalf = dragf + rollf + hillforce + windf
|
||||
# with the sum of forces, determine the needed torque at the wheels, and then power
|
||||
tau = params.wheel_radius * totalf
|
||||
pdraw = bldc_power_draw(tau, velocity, params.motor)
|
||||
# determine the energy needed to do this power for the time step
|
||||
net_power = state[2] - delta_time * pdraw # joules
|
||||
|
||||
dpos = jnp.cos(theta) * velocity * delta_time
|
||||
dist_remaining = 10000.0 - dpos
|
||||
time_remaining = 600 - (state[1] + delta_time)
|
||||
return jnp.array([dpos, state[1] + delta_time, net_power, dist_remaining, time_remaining])
|
||||
|
||||
def reward(state):
|
||||
progress = state[0] / 10000 * 100
|
||||
energy_usage = -10 * state[2]
|
||||
time_factor = (1.0 - (state[1] / 600)) * 50
|
||||
reward = progress + energy_usage + time_factor
|
||||
return reward
|
||||
# now we have an environment tuned in.
|
||||
# we want to take an environment, and bind it to the forward function
|
||||
def make_simulator(params: CarParams, wind, elevation, slope):
|
||||
def reward(state):
|
||||
progress = state[0] / 10000 * 100
|
||||
energy_usage = -10 * state[2]
|
||||
time_factor = (1.0 - (state[1] / 600)) * 50
|
||||
reward = progress + energy_usage + time_factor
|
||||
return reward
|
||||
return forwardv2, reward
|
||||
|
||||
|
||||
|
|
|
@ -2,10 +2,45 @@ import gymnasium as gym
|
|||
import solarcarsim.physsim as sim
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from typing import Any
|
||||
from jax import jit
|
||||
from functools import partial
|
||||
from jax import vmap
|
||||
from solarcarsim.physsim import drag_force, rolling_force, downslope_force, bldc_power_draw
|
||||
|
||||
|
||||
@partial(jit, static_argnames=["params"])
|
||||
def forwardv2(state, control, delta_time, wind, elevation, slope, params):
|
||||
pos = jnp.astype(jnp.round(state[0]), "int32")
|
||||
time = jnp.astype(jnp.round(state[1]), "int32")
|
||||
theta = slope[pos]
|
||||
|
||||
velocity = control * params.max_speed
|
||||
|
||||
# sum up the forces acting on the car
|
||||
windspeed = wind[pos, time]
|
||||
dragf = sim.drag_force(velocity + windspeed, params.frontal_area, params.drag_coeff, 1.184)
|
||||
rollf = sim.rolling_force(params.mass, theta, params.rolling_coeff)
|
||||
hillforce = sim.downslope_force(params.mass, theta)
|
||||
totalf = dragf + rollf + hillforce
|
||||
# with the sum of forces, determine the needed torque at the wheels, and then power
|
||||
tau = params.wheel_radius * totalf
|
||||
pdraw = bldc_power_draw(tau, velocity, params.motor)
|
||||
# determine the energy needed to do this power for the time step
|
||||
net_power = state[2] - delta_time * pdraw # joules
|
||||
|
||||
dpos = jnp.cos(theta) * velocity * delta_time
|
||||
dist_remaining = 10000.0 - dpos
|
||||
time_remaining = 600 - (state[1] + delta_time)
|
||||
return jnp.array(
|
||||
[dpos, state[1] + delta_time, net_power, dist_remaining, time_remaining]
|
||||
)
|
||||
|
||||
|
||||
def reward(state, prev_energy):
|
||||
progress = state[0] / 10000 * 100
|
||||
energy_usage = 10 * (state[2] - prev_energy) # current energy < previous energy.
|
||||
time_factor = (1.0 - (state[1] / 600)) * 50
|
||||
reward = progress + energy_usage + time_factor
|
||||
return reward
|
||||
|
||||
class SolarRaceV1(gym.Env):
|
||||
"""A primitive hill climber. Aims to solve the given route optimizing
|
||||
|
@ -46,8 +81,8 @@ class SolarRaceV1(gym.Env):
|
|||
self._reset_sim(jax.random.key(seed))
|
||||
self._timestep = timestep
|
||||
self._car = car
|
||||
self._simstep = sim.forwardv2
|
||||
self._simreward = sim.reward
|
||||
self._simstep = forwardv2
|
||||
self._simreward = reward
|
||||
|
||||
self.observation_space = gym.spaces.Dict(
|
||||
{
|
||||
|
@ -73,8 +108,10 @@ class SolarRaceV1(gym.Env):
|
|||
def step(self, action):
|
||||
wind, elevation, slope = self._environment
|
||||
|
||||
old_energy = self._state[2]
|
||||
|
||||
self._state = self._simstep(self._state, action, self._timestep,wind, elevation, slope, self._car)
|
||||
reward = self._simreward(self._state)[0]
|
||||
reward = self._simreward(self._state, old_energy)[0]
|
||||
terminated = False
|
||||
truncated = False
|
||||
if jnp.all(self._state[0] > 10000):
|
||||
|
@ -82,4 +119,6 @@ class SolarRaceV1(gym.Env):
|
|||
if self._state[1] > 600:
|
||||
truncated = True
|
||||
|
||||
return self._get_obs(), reward, terminated, truncated, {}
|
||||
return self._get_obs(), reward, terminated, truncated, {}
|
||||
|
||||
|
14
src/solarcarsim/simv2.py
Normal file
14
src/solarcarsim/simv2.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
""" Second-generation simulator. More functional, cleaner code, faster """
|
||||
|
||||
from typing import NamedTuple
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
class SimState(NamedTuple):
|
||||
position: float
|
||||
time: float
|
||||
energy: float
|
||||
distance_remaining: float
|
||||
time_remaining: float
|
||||
|
Loading…
Reference in a new issue