jax cuda, separate v1 sim
This commit is contained in:
parent
104dfab637
commit
aaac5df404
File diff suppressed because it is too large
Load diff
128
pdm.lock
128
pdm.lock
|
@ -5,7 +5,7 @@
|
||||||
groups = ["default", "dev"]
|
groups = ["default", "dev"]
|
||||||
strategy = ["inherit_metadata"]
|
strategy = ["inherit_metadata"]
|
||||||
lock_version = "4.5.0"
|
lock_version = "4.5.0"
|
||||||
content_hash = "sha256:a1edba805cc867a6316cea6c754bc112f0f79046b604ee515c541505f9c546f7"
|
content_hash = "sha256:81e26f71acf1a583b21280b235fa2ac16165ac824ae8483bd391b88406421aa4"
|
||||||
|
|
||||||
[[metadata.targets]]
|
[[metadata.targets]]
|
||||||
requires_python = ">=3.12,<3.13"
|
requires_python = ">=3.12,<3.13"
|
||||||
|
@ -522,13 +522,13 @@ files = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "jax"
|
name = "jax"
|
||||||
version = "0.4.35"
|
version = "0.4.37"
|
||||||
requires_python = ">=3.10"
|
requires_python = ">=3.10"
|
||||||
summary = "Differentiate, compile, and transform Numpy code."
|
summary = "Differentiate, compile, and transform Numpy code."
|
||||||
groups = ["default"]
|
groups = ["default"]
|
||||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"jaxlib<=0.4.35,>=0.4.34",
|
"jaxlib<=0.4.37,>=0.4.36",
|
||||||
"ml-dtypes>=0.4.0",
|
"ml-dtypes>=0.4.0",
|
||||||
"numpy>=1.24",
|
"numpy>=1.24",
|
||||||
"numpy>=1.26.0; python_version >= \"3.12\"",
|
"numpy>=1.26.0; python_version >= \"3.12\"",
|
||||||
|
@ -537,13 +537,83 @@ dependencies = [
|
||||||
"scipy>=1.11.1; python_version >= \"3.12\"",
|
"scipy>=1.11.1; python_version >= \"3.12\"",
|
||||||
]
|
]
|
||||||
files = [
|
files = [
|
||||||
{file = "jax-0.4.35-py3-none-any.whl", hash = "sha256:fa99e909a31424abfec750019a6dd36f6acc18a6e7d40e2c0086b932cc351325"},
|
{file = "jax-0.4.37-py3-none-any.whl", hash = "sha256:bdc0686d7e5a944e2d38026eae632214d98dd2d91869cbcedbf1c11298ae3e3e"},
|
||||||
{file = "jax-0.4.35.tar.gz", hash = "sha256:c0c986993026b10bf6f607fecb7417377460254640766ce40f1fef3fd139c12e"},
|
{file = "jax-0.4.37.tar.gz", hash = "sha256:7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "jax-cuda12-pjrt"
|
||||||
|
version = "0.4.36"
|
||||||
|
summary = "JAX XLA PJRT Plugin for NVIDIA GPUs"
|
||||||
|
groups = ["default"]
|
||||||
|
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||||
|
files = [
|
||||||
|
{file = "jax_cuda12_pjrt-0.4.36-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1dfc0bec0820ba801b61e9421064b6e58238c430b4ad8f54043323d93c0217c6"},
|
||||||
|
{file = "jax_cuda12_pjrt-0.4.36-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e3c3705d8db7d63da9abfaebf06f5cd0667f5acb0748a5c5eb00d80041e922ed"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "jax-cuda12-plugin"
|
||||||
|
version = "0.4.36"
|
||||||
|
requires_python = ">=3.10"
|
||||||
|
summary = "JAX Plugin for NVIDIA GPUs"
|
||||||
|
groups = ["default"]
|
||||||
|
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||||
|
dependencies = [
|
||||||
|
"jax-cuda12-pjrt==0.4.36",
|
||||||
|
]
|
||||||
|
files = [
|
||||||
|
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:6a0b0c2bdc1da2eea2c20723a1e8f39b3cda67d24c665de936647e8091f5790d"},
|
||||||
|
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:5d4727fb519fedc06a9a984d5a0714804d81ef126a2cb60cefd5cbc4a3ea2627"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "jax-cuda12-plugin"
|
||||||
|
version = "0.4.36"
|
||||||
|
extras = ["with_cuda"]
|
||||||
|
requires_python = ">=3.10"
|
||||||
|
summary = "JAX Plugin for NVIDIA GPUs"
|
||||||
|
groups = ["default"]
|
||||||
|
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||||
|
dependencies = [
|
||||||
|
"jax-cuda12-plugin==0.4.36",
|
||||||
|
"nvidia-cublas-cu12>=12.1.3.1",
|
||||||
|
"nvidia-cuda-cupti-cu12>=12.1.105",
|
||||||
|
"nvidia-cuda-nvcc-cu12>=12.6.85",
|
||||||
|
"nvidia-cuda-runtime-cu12>=12.1.105",
|
||||||
|
"nvidia-cudnn-cu12<10.0,>=9.1",
|
||||||
|
"nvidia-cufft-cu12>=11.0.2.54",
|
||||||
|
"nvidia-cusolver-cu12>=11.4.5.107",
|
||||||
|
"nvidia-cusparse-cu12>=12.1.0.106",
|
||||||
|
"nvidia-nccl-cu12>=2.18.1",
|
||||||
|
"nvidia-nvjitlink-cu12>=12.1.105",
|
||||||
|
]
|
||||||
|
files = [
|
||||||
|
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:6a0b0c2bdc1da2eea2c20723a1e8f39b3cda67d24c665de936647e8091f5790d"},
|
||||||
|
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:5d4727fb519fedc06a9a984d5a0714804d81ef126a2cb60cefd5cbc4a3ea2627"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "jax"
|
||||||
|
version = "0.4.37"
|
||||||
|
extras = ["cuda12"]
|
||||||
|
requires_python = ">=3.10"
|
||||||
|
summary = "Differentiate, compile, and transform Numpy code."
|
||||||
|
groups = ["default"]
|
||||||
|
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||||
|
dependencies = [
|
||||||
|
"jax-cuda12-plugin[with_cuda]<=0.4.37,>=0.4.36",
|
||||||
|
"jax==0.4.37",
|
||||||
|
"jaxlib==0.4.36",
|
||||||
|
]
|
||||||
|
files = [
|
||||||
|
{file = "jax-0.4.37-py3-none-any.whl", hash = "sha256:bdc0686d7e5a944e2d38026eae632214d98dd2d91869cbcedbf1c11298ae3e3e"},
|
||||||
|
{file = "jax-0.4.37.tar.gz", hash = "sha256:7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b"},
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "jaxlib"
|
name = "jaxlib"
|
||||||
version = "0.4.35"
|
version = "0.4.36"
|
||||||
requires_python = ">=3.10"
|
requires_python = ">=3.10"
|
||||||
summary = "XLA library for JAX"
|
summary = "XLA library for JAX"
|
||||||
groups = ["default"]
|
groups = ["default"]
|
||||||
|
@ -555,16 +625,11 @@ dependencies = [
|
||||||
"scipy>=1.11.1; python_version >= \"3.12\"",
|
"scipy>=1.11.1; python_version >= \"3.12\"",
|
||||||
]
|
]
|
||||||
files = [
|
files = [
|
||||||
{file = "jaxlib-0.4.35-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:b44f3e6e9fb748bb43df914356cf9d0d0c9a6e446a12c21fe843db25ed0df65f"},
|
{file = "jaxlib-0.4.36-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:5972aa85f6d771ecc8cc72148c1fa64250ca33cbdf2bf24407cdee8a5299d25d"},
|
||||||
{file = "jaxlib-0.4.35-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:504d0a2e2117724359d99d7e3663022686dcdddd85aa14bdad02008d444481ad"},
|
{file = "jaxlib-0.4.36-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5597908cd10418c0b42e9af807fc8112036703533cf501a5255a8fbf4011867e"},
|
||||||
{file = "jaxlib-0.4.35-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:187cb6929dc139b75d952d67c33118473c1b4105525a3e5607f064e7b8efdc74"},
|
{file = "jaxlib-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:fbbabaa287378a78a3cf9cbe4de30a1f6f19a99116feb4bd687ff256415cd442"},
|
||||||
{file = "jaxlib-0.4.35-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:04d1db3bf0050d120238bfb9b686b58fefcc4d9dd9e2d96aecd3f68a1f1f5e0a"},
|
{file = "jaxlib-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:be295abc209c980817db0488f21f1fbc0644f87326522895e2b9b64729106357"},
|
||||||
{file = "jaxlib-0.4.35-cp312-cp312-win_amd64.whl", hash = "sha256:dddffce48d7e6057008999aed2d8a9daecc57a48c45a4f8c475e00880eb2e41d"},
|
{file = "jaxlib-0.4.36-cp312-cp312-win_amd64.whl", hash = "sha256:d4bbb5d2970628dcd3dabc28a5b97a1125ad3e06a1be822d340fd9f06f7449b3"},
|
||||||
{file = "jaxlib-0.4.35-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:14aeac3fea2ca1d5afb1878f72470b159cc89adb2633c5f0686f5d7c39f2ac18"},
|
|
||||||
{file = "jaxlib-0.4.35-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e8c9579e20d5ecdc4f61336cdd032710cb8c38d5ae9c4fce0cf9ea031cef21cb"},
|
|
||||||
{file = "jaxlib-0.4.35-cp313-cp313-manylinux2014_aarch64.whl", hash = "sha256:7b11ad7c13f7f96f36efd303711ecac425f19ca2ddf65cf1be1541167a959ee5"},
|
|
||||||
{file = "jaxlib-0.4.35-cp313-cp313-manylinux2014_x86_64.whl", hash = "sha256:0be3cf9df879d9ae1b5b92fc281f77d21f522fcbae1a48a02661026bbd9b9309"},
|
|
||||||
{file = "jaxlib-0.4.35-cp313-cp313-win_amd64.whl", hash = "sha256:330c090bb9af413f552d8a92d097e50baec6b75823430fb2966a49f5298d4c43"},
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -861,7 +926,7 @@ version = "12.4.5.8"
|
||||||
requires_python = ">=3"
|
requires_python = ">=3"
|
||||||
summary = "CUBLAS native runtime libraries"
|
summary = "CUBLAS native runtime libraries"
|
||||||
groups = ["default"]
|
groups = ["default"]
|
||||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||||
files = [
|
files = [
|
||||||
{file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3"},
|
{file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3"},
|
||||||
{file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b"},
|
{file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b"},
|
||||||
|
@ -874,13 +939,26 @@ version = "12.4.127"
|
||||||
requires_python = ">=3"
|
requires_python = ">=3"
|
||||||
summary = "CUDA profiling tools runtime libs."
|
summary = "CUDA profiling tools runtime libs."
|
||||||
groups = ["default"]
|
groups = ["default"]
|
||||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||||
files = [
|
files = [
|
||||||
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a"},
|
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a"},
|
||||||
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb"},
|
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb"},
|
||||||
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922"},
|
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "nvidia-cuda-nvcc-cu12"
|
||||||
|
version = "12.6.85"
|
||||||
|
requires_python = ">=3"
|
||||||
|
summary = "CUDA nvcc"
|
||||||
|
groups = ["default"]
|
||||||
|
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||||
|
files = [
|
||||||
|
{file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d75d9d74599f4d7c0865df19ed21b739e6cb77a6497a3f73d6f61e8038a765e4"},
|
||||||
|
{file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5d2edd5531b13e3daac8ffee9fc2b70a147e6088b2af2565924773d63d36d294"},
|
||||||
|
{file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-win_amd64.whl", hash = "sha256:aa04742337973dcb5bcccabb590edc8834c60ebfaf971847888d24ffef6c46b5"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "nvidia-cuda-nvrtc-cu12"
|
name = "nvidia-cuda-nvrtc-cu12"
|
||||||
version = "12.4.127"
|
version = "12.4.127"
|
||||||
|
@ -900,7 +978,7 @@ version = "12.4.127"
|
||||||
requires_python = ">=3"
|
requires_python = ">=3"
|
||||||
summary = "CUDA Runtime native Libraries"
|
summary = "CUDA Runtime native Libraries"
|
||||||
groups = ["default"]
|
groups = ["default"]
|
||||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||||
files = [
|
files = [
|
||||||
{file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3"},
|
{file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3"},
|
||||||
{file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5"},
|
{file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5"},
|
||||||
|
@ -913,7 +991,7 @@ version = "9.1.0.70"
|
||||||
requires_python = ">=3"
|
requires_python = ">=3"
|
||||||
summary = "cuDNN runtime libraries"
|
summary = "cuDNN runtime libraries"
|
||||||
groups = ["default"]
|
groups = ["default"]
|
||||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"nvidia-cublas-cu12",
|
"nvidia-cublas-cu12",
|
||||||
]
|
]
|
||||||
|
@ -928,7 +1006,7 @@ version = "11.2.1.3"
|
||||||
requires_python = ">=3"
|
requires_python = ">=3"
|
||||||
summary = "CUFFT native runtime libraries"
|
summary = "CUFFT native runtime libraries"
|
||||||
groups = ["default"]
|
groups = ["default"]
|
||||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"nvidia-nvjitlink-cu12",
|
"nvidia-nvjitlink-cu12",
|
||||||
]
|
]
|
||||||
|
@ -957,7 +1035,7 @@ version = "11.6.1.9"
|
||||||
requires_python = ">=3"
|
requires_python = ">=3"
|
||||||
summary = "CUDA solver native runtime libraries"
|
summary = "CUDA solver native runtime libraries"
|
||||||
groups = ["default"]
|
groups = ["default"]
|
||||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"nvidia-cublas-cu12",
|
"nvidia-cublas-cu12",
|
||||||
"nvidia-cusparse-cu12",
|
"nvidia-cusparse-cu12",
|
||||||
|
@ -975,7 +1053,7 @@ version = "12.3.1.170"
|
||||||
requires_python = ">=3"
|
requires_python = ">=3"
|
||||||
summary = "CUSPARSE native runtime libraries"
|
summary = "CUSPARSE native runtime libraries"
|
||||||
groups = ["default"]
|
groups = ["default"]
|
||||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"nvidia-nvjitlink-cu12",
|
"nvidia-nvjitlink-cu12",
|
||||||
]
|
]
|
||||||
|
@ -991,7 +1069,7 @@ version = "2.21.5"
|
||||||
requires_python = ">=3"
|
requires_python = ">=3"
|
||||||
summary = "NVIDIA Collective Communication Library (NCCL) Runtime"
|
summary = "NVIDIA Collective Communication Library (NCCL) Runtime"
|
||||||
groups = ["default"]
|
groups = ["default"]
|
||||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||||
files = [
|
files = [
|
||||||
{file = "nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0"},
|
{file = "nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0"},
|
||||||
]
|
]
|
||||||
|
@ -1002,7 +1080,7 @@ version = "12.4.127"
|
||||||
requires_python = ">=3"
|
requires_python = ">=3"
|
||||||
summary = "Nvidia JIT LTO Library"
|
summary = "Nvidia JIT LTO Library"
|
||||||
groups = ["default"]
|
groups = ["default"]
|
||||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||||
files = [
|
files = [
|
||||||
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"},
|
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"},
|
||||||
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
|
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
|
||||||
|
|
|
@ -5,7 +5,7 @@ description = "A solar car racing simulation library and GUI tool"
|
||||||
authors = [
|
authors = [
|
||||||
{name = "saji", email = "saji@saji.dev"},
|
{name = "saji", email = "saji@saji.dev"},
|
||||||
]
|
]
|
||||||
dependencies = ["pyqtgraph>=0.13.7", "jax>=0.4.35", "pytest>=8.3.3", "pyside6>=6.8.0.2", "matplotlib>=3.9.2", "gymnasium[jax]>=1.0.0", "pyvista>=0.44.2", "pyvistaqt>=0.11.1", "stable-baselines3>=2.4.0"]
|
dependencies = ["pyqtgraph>=0.13.7", "jax[cuda12]>=0.4.37", "pytest>=8.3.3", "pyside6>=6.8.0.2", "matplotlib>=3.9.2", "gymnasium[jax]>=1.0.0", "pyvista>=0.44.2", "pyvistaqt>=0.11.1", "stable-baselines3>=2.4.0"]
|
||||||
requires-python = ">=3.10,<3.13"
|
requires-python = ">=3.10,<3.13"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = {text = "MIT"}
|
license = {text = "MIT"}
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
"""Physical equations and models for building a simulation environment"""
|
||||||
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import jax
|
import jax
|
||||||
from jax import grad, jit, vmap, lax
|
from jax import grad, jit, vmap, lax
|
||||||
|
@ -5,7 +7,11 @@ from functools import partial
|
||||||
|
|
||||||
from typing import NamedTuple, Tuple
|
from typing import NamedTuple, Tuple
|
||||||
|
|
||||||
from solarcarsim.noise import fractal_noise_1d, generate_elevation_profile, generate_wind_field
|
from solarcarsim.noise import (
|
||||||
|
fractal_noise_1d,
|
||||||
|
generate_wind_field,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MotorParams(NamedTuple):
|
class MotorParams(NamedTuple):
|
||||||
kv: float
|
kv: float
|
||||||
|
@ -20,8 +26,10 @@ class BatteryParams(NamedTuple):
|
||||||
resistance: float # ohms
|
resistance: float # ohms
|
||||||
initial_energy: float # joules
|
initial_energy: float # joules
|
||||||
|
|
||||||
|
|
||||||
class CarParams(NamedTuple):
|
class CarParams(NamedTuple):
|
||||||
""" Physical Data for Solar Car Parameters """
|
"""Physical Data for Solar Car Parameters"""
|
||||||
|
|
||||||
mass: float = 800 # kg
|
mass: float = 800 # kg
|
||||||
frontal_area: float = 1.3 # m^2
|
frontal_area: float = 1.3 # m^2
|
||||||
drag_coeff: float = 0.18 # drag coefficient, dimensionless
|
drag_coeff: float = 0.18 # drag coefficient, dimensionless
|
||||||
|
@ -32,17 +40,19 @@ class CarParams(NamedTuple):
|
||||||
solar_area: float = 5.0 # m^2, typically 5.0
|
solar_area: float = 5.0 # m^2, typically 5.0
|
||||||
solar_eff: float = 0.20 # 0 < x < 1, typically ~.25
|
solar_eff: float = 0.20 # 0 < x < 1, typically ~.25
|
||||||
n_motors: int = 2 # how many motors we have.
|
n_motors: int = 2 # how many motors we have.
|
||||||
motor: MotorParams = MotorParams(8.43, 1.1, 100.0, 0.001, 0.001) # mitsuba m2090 estimate
|
motor: MotorParams = MotorParams(
|
||||||
battery: BatteryParams = BatteryParams((36,19), 0.0126, 66.6e3) # freebasing 50s pack.
|
8.43, 1.1, 100.0, 0.001, 0.001
|
||||||
|
) # mitsuba m2090 estimate
|
||||||
|
battery: BatteryParams = BatteryParams(
|
||||||
|
(36, 19), 0.0126, 66.6e3
|
||||||
|
) # freebasing 50s pack.
|
||||||
|
|
||||||
|
|
||||||
def DefaultCar() -> CarParams:
|
def DefaultCar() -> CarParams:
|
||||||
""" Creates a basic car """
|
"""Creates a basic car"""
|
||||||
return CarParams(1000, 1.3, 0.18, 0.002, 0.85, 5.0, 0.23)
|
return CarParams(1000, 1.3, 0.18, 0.002, 0.85, 5.0, 0.23)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# some physics equations using jax
|
# some physics equations using jax
|
||||||
|
|
||||||
|
|
||||||
|
@ -50,20 +60,25 @@ def DefaultCar() -> CarParams:
|
||||||
def normal_force(mass, theta):
|
def normal_force(mass, theta):
|
||||||
return mass * 9.8 * jnp.cos(theta)
|
return mass * 9.8 * jnp.cos(theta)
|
||||||
|
|
||||||
|
|
||||||
@jit
|
@jit
|
||||||
def downslope_force(mass, theta):
|
def downslope_force(mass, theta):
|
||||||
return mass * 9.8 * jnp.sin(theta)
|
return mass * 9.8 * jnp.sin(theta)
|
||||||
|
|
||||||
@partial(jit, static_argnames=['crr'])
|
|
||||||
|
@partial(jit, static_argnames=["crr"])
|
||||||
def rolling_force(mass, theta, crr):
|
def rolling_force(mass, theta, crr):
|
||||||
return normal_force(mass, theta) * crr
|
return normal_force(mass, theta) * crr
|
||||||
|
|
||||||
@partial(jit, static_argnames=['area', 'cd', 'rho'])
|
|
||||||
|
@partial(jit, static_argnames=["area", "cd", "rho"])
|
||||||
def drag_force(u, area, cd, rho):
|
def drag_force(u, area, cd, rho):
|
||||||
return 0.5 * rho * jnp.pow(u, 2) * cd * area
|
return 0.5 * rho * jnp.pow(u, 2) * cd * area
|
||||||
|
|
||||||
|
|
||||||
# we can use those forces above to determine what forces we have to overcome. Sum(F)=0
|
# we can use those forces above to determine what forces we have to overcome. Sum(F)=0
|
||||||
|
|
||||||
|
|
||||||
# @partial(jit, static_argnums=(2,))
|
# @partial(jit, static_argnums=(2,))
|
||||||
@jit
|
@jit
|
||||||
def bldc_power_draw(torque, velocity, params: MotorParams):
|
def bldc_power_draw(torque, velocity, params: MotorParams):
|
||||||
|
@ -99,10 +114,10 @@ def bldc_power_draw(torque, velocity, params: MotorParams):
|
||||||
|
|
||||||
return total_power
|
return total_power
|
||||||
|
|
||||||
|
|
||||||
# @partial(jit, static_argnames=['resistance', 'kt', 'kv', 'vmax', 'Cf'])
|
# @partial(jit, static_argnames=['resistance', 'kt', 'kv', 'vmax', 'Cf'])
|
||||||
@jit
|
@jit
|
||||||
def bldc_torque(velocity, current_limit, resistance, kt, kv, vmax, Cf):
|
def bldc_torque(velocity, current_limit, resistance, kt, kv, vmax, Cf):
|
||||||
|
|
||||||
bemf = velocity / kv
|
bemf = velocity / kv
|
||||||
v_avail = jnp.clip(vmax - bemf, 0.0, vmax)
|
v_avail = jnp.clip(vmax - bemf, 0.0, vmax)
|
||||||
current = jnp.clip(v_avail / resistance, 0.0, current_limit)
|
current = jnp.clip(v_avail / resistance, 0.0, current_limit)
|
||||||
|
@ -113,8 +128,15 @@ def bldc_torque(velocity, current_limit, resistance, kt, kv, vmax, Cf):
|
||||||
stall_torque = kt * current_limit
|
stall_torque = kt * current_limit
|
||||||
return jnp.where(velocity < 0.01, stall_torque, net_torque)
|
return jnp.where(velocity < 0.01, stall_torque, net_torque)
|
||||||
|
|
||||||
@partial(jit, static_argnums=(1,2,))
|
|
||||||
def battery_powerloss(current,cell_r, battery_shape: Tuple[int,int]):
|
@partial(
|
||||||
|
jit,
|
||||||
|
static_argnums=(
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def battery_powerloss(current, cell_r, battery_shape: Tuple[int, int]):
|
||||||
r_array = jnp.full(battery_shape, cell_r)
|
r_array = jnp.full(battery_shape, cell_r)
|
||||||
branch_current = current / battery_shape[1]
|
branch_current = current / battery_shape[1]
|
||||||
I_array = jnp.full(battery_shape, branch_current)
|
I_array = jnp.full(battery_shape, branch_current)
|
||||||
|
@ -122,7 +144,6 @@ def battery_powerloss(current,cell_r, battery_shape: Tuple[int,int]):
|
||||||
return jnp.sum(cell_Ploss)
|
return jnp.sum(cell_Ploss)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def forward(state, timestep, control, params: CarParams):
|
def forward(state, timestep, control, params: CarParams):
|
||||||
# state is (position, time, energy)
|
# state is (position, time, energy)
|
||||||
# control is velocity
|
# control is velocity
|
||||||
|
@ -145,12 +166,18 @@ def forward(state, timestep, control, params: CarParams):
|
||||||
# TODO: support regenerative braking when going downhill
|
# TODO: support regenerative braking when going downhill
|
||||||
# TODO: delta x = cos(theta) * velocity * timestep
|
# TODO: delta x = cos(theta) * velocity * timestep
|
||||||
|
|
||||||
new_state = jnp.array([state[0] + control * timestep, state[1] + timestep, state[2] + net_power * timestep])
|
new_state = jnp.array(
|
||||||
|
[
|
||||||
|
state[0] + control * timestep,
|
||||||
|
state[1] + timestep,
|
||||||
|
state[2] + net_power * timestep,
|
||||||
|
]
|
||||||
|
)
|
||||||
return new_state
|
return new_state
|
||||||
|
|
||||||
|
|
||||||
def make_environment(seed):
|
def make_environment(seed):
|
||||||
""" Generate a race environment: terrain function, wind function, wrapped forward function."""
|
"""Generate a race environment: terrain function, wind function, wrapped forward function."""
|
||||||
key, subkey = jax.random.split(seed)
|
key, subkey = jax.random.split(seed)
|
||||||
wind = generate_wind_field(subkey, 10000, 600, spatial_scale=1000)
|
wind = generate_wind_field(subkey, 10000, 600, spatial_scale=1000)
|
||||||
key, subkey = jax.random.split(key)
|
key, subkey = jax.random.split(key)
|
||||||
|
@ -161,46 +188,3 @@ def make_environment(seed):
|
||||||
|
|
||||||
return wind, elevation, slope
|
return wind, elevation, slope
|
||||||
|
|
||||||
@partial(jit, static_argnames=['params'])
|
|
||||||
def forwardv2(state, control, delta_time, wind, elevation, slope, params):
|
|
||||||
pos = jnp.astype(jnp.round(state[0]), "int32")
|
|
||||||
time = jnp.astype(jnp.round(state[1]), "int32")
|
|
||||||
theta = slope[pos]
|
|
||||||
|
|
||||||
velocity = control * params.max_speed
|
|
||||||
|
|
||||||
# sum up the forces acting on the car
|
|
||||||
dragf = drag_force(velocity, params.frontal_area, params.drag_coeff, 1.184)
|
|
||||||
rollf = rolling_force(params.mass, theta, params.rolling_coeff)
|
|
||||||
hillforce = downslope_force(params.mass, theta)
|
|
||||||
windf = wind[pos, time]
|
|
||||||
totalf = dragf + rollf + hillforce + windf
|
|
||||||
# with the sum of forces, determine the needed torque at the wheels, and then power
|
|
||||||
tau = params.wheel_radius * totalf
|
|
||||||
pdraw = bldc_power_draw(tau, velocity, params.motor)
|
|
||||||
# determine the energy needed to do this power for the time step
|
|
||||||
net_power = state[2] - delta_time * pdraw # joules
|
|
||||||
|
|
||||||
dpos = jnp.cos(theta) * velocity * delta_time
|
|
||||||
dist_remaining = 10000.0 - dpos
|
|
||||||
time_remaining = 600 - (state[1] + delta_time)
|
|
||||||
return jnp.array([dpos, state[1] + delta_time, net_power, dist_remaining, time_remaining])
|
|
||||||
|
|
||||||
def reward(state):
|
|
||||||
progress = state[0] / 10000 * 100
|
|
||||||
energy_usage = -10 * state[2]
|
|
||||||
time_factor = (1.0 - (state[1] / 600)) * 50
|
|
||||||
reward = progress + energy_usage + time_factor
|
|
||||||
return reward
|
|
||||||
# now we have an environment tuned in.
|
|
||||||
# we want to take an environment, and bind it to the forward function
|
|
||||||
def make_simulator(params: CarParams, wind, elevation, slope):
|
|
||||||
def reward(state):
|
|
||||||
progress = state[0] / 10000 * 100
|
|
||||||
energy_usage = -10 * state[2]
|
|
||||||
time_factor = (1.0 - (state[1] / 600)) * 50
|
|
||||||
reward = progress + energy_usage + time_factor
|
|
||||||
return reward
|
|
||||||
return forwardv2, reward
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,10 +2,45 @@ import gymnasium as gym
|
||||||
import solarcarsim.physsim as sim
|
import solarcarsim.physsim as sim
|
||||||
import jax
|
import jax
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import numpy as np
|
from jax import jit
|
||||||
from typing import Any
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from jax import vmap
|
from solarcarsim.physsim import drag_force, rolling_force, downslope_force, bldc_power_draw
|
||||||
|
|
||||||
|
|
||||||
|
@partial(jit, static_argnames=["params"])
|
||||||
|
def forwardv2(state, control, delta_time, wind, elevation, slope, params):
|
||||||
|
pos = jnp.astype(jnp.round(state[0]), "int32")
|
||||||
|
time = jnp.astype(jnp.round(state[1]), "int32")
|
||||||
|
theta = slope[pos]
|
||||||
|
|
||||||
|
velocity = control * params.max_speed
|
||||||
|
|
||||||
|
# sum up the forces acting on the car
|
||||||
|
windspeed = wind[pos, time]
|
||||||
|
dragf = sim.drag_force(velocity + windspeed, params.frontal_area, params.drag_coeff, 1.184)
|
||||||
|
rollf = sim.rolling_force(params.mass, theta, params.rolling_coeff)
|
||||||
|
hillforce = sim.downslope_force(params.mass, theta)
|
||||||
|
totalf = dragf + rollf + hillforce
|
||||||
|
# with the sum of forces, determine the needed torque at the wheels, and then power
|
||||||
|
tau = params.wheel_radius * totalf
|
||||||
|
pdraw = bldc_power_draw(tau, velocity, params.motor)
|
||||||
|
# determine the energy needed to do this power for the time step
|
||||||
|
net_power = state[2] - delta_time * pdraw # joules
|
||||||
|
|
||||||
|
dpos = jnp.cos(theta) * velocity * delta_time
|
||||||
|
dist_remaining = 10000.0 - dpos
|
||||||
|
time_remaining = 600 - (state[1] + delta_time)
|
||||||
|
return jnp.array(
|
||||||
|
[dpos, state[1] + delta_time, net_power, dist_remaining, time_remaining]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def reward(state, prev_energy):
|
||||||
|
progress = state[0] / 10000 * 100
|
||||||
|
energy_usage = 10 * (state[2] - prev_energy) # current energy < previous energy.
|
||||||
|
time_factor = (1.0 - (state[1] / 600)) * 50
|
||||||
|
reward = progress + energy_usage + time_factor
|
||||||
|
return reward
|
||||||
|
|
||||||
class SolarRaceV1(gym.Env):
|
class SolarRaceV1(gym.Env):
|
||||||
"""A primitive hill climber. Aims to solve the given route optimizing
|
"""A primitive hill climber. Aims to solve the given route optimizing
|
||||||
|
@ -46,8 +81,8 @@ class SolarRaceV1(gym.Env):
|
||||||
self._reset_sim(jax.random.key(seed))
|
self._reset_sim(jax.random.key(seed))
|
||||||
self._timestep = timestep
|
self._timestep = timestep
|
||||||
self._car = car
|
self._car = car
|
||||||
self._simstep = sim.forwardv2
|
self._simstep = forwardv2
|
||||||
self._simreward = sim.reward
|
self._simreward = reward
|
||||||
|
|
||||||
self.observation_space = gym.spaces.Dict(
|
self.observation_space = gym.spaces.Dict(
|
||||||
{
|
{
|
||||||
|
@ -73,8 +108,10 @@ class SolarRaceV1(gym.Env):
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
wind, elevation, slope = self._environment
|
wind, elevation, slope = self._environment
|
||||||
|
|
||||||
|
old_energy = self._state[2]
|
||||||
|
|
||||||
self._state = self._simstep(self._state, action, self._timestep,wind, elevation, slope, self._car)
|
self._state = self._simstep(self._state, action, self._timestep,wind, elevation, slope, self._car)
|
||||||
reward = self._simreward(self._state)[0]
|
reward = self._simreward(self._state, old_energy)[0]
|
||||||
terminated = False
|
terminated = False
|
||||||
truncated = False
|
truncated = False
|
||||||
if jnp.all(self._state[0] > 10000):
|
if jnp.all(self._state[0] > 10000):
|
||||||
|
@ -83,3 +120,5 @@ class SolarRaceV1(gym.Env):
|
||||||
truncated = True
|
truncated = True
|
||||||
|
|
||||||
return self._get_obs(), reward, terminated, truncated, {}
|
return self._get_obs(), reward, terminated, truncated, {}
|
||||||
|
|
||||||
|
|
14
src/solarcarsim/simv2.py
Normal file
14
src/solarcarsim/simv2.py
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
""" Second-generation simulator. More functional, cleaner code, faster """
|
||||||
|
|
||||||
|
from typing import NamedTuple
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
|
||||||
|
class SimState(NamedTuple):
|
||||||
|
position: float
|
||||||
|
time: float
|
||||||
|
energy: float
|
||||||
|
distance_remaining: float
|
||||||
|
time_remaining: float
|
||||||
|
|
Loading…
Reference in a new issue