simv2
This commit is contained in:
parent
aaac5df404
commit
70a659f468
10
README.md
10
README.md
|
@ -1 +1,11 @@
|
|||
# solarcarsim
|
||||
|
||||
|
||||
|
||||
TODO:
|
||||
fix wind (velocity + wind for drag)
|
||||
make more functional
|
||||
cleanup sim code
|
||||
parameterize the environment
|
||||
vectorize
|
||||
cleanrl jax td3
|
65
notebooks/testing.ipynb
Normal file
65
notebooks/testing.ipynb
Normal file
|
@ -0,0 +1,65 @@
|
|||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"Array([ 0.0000000e+00, -2.5544850e-07, -1.4012958e-06, ...,\n",
|
||||
" -1.1142221e-02, -1.1067827e-02, -1.1001030e-02], dtype=float32)"
|
||||
]
|
||||
},
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import jax\n",
|
||||
"import jax.numpy as jnp\n",
|
||||
"from solarcarsim.physsim import CarParams, fractal_noise_1d\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"key = jax.random.key(0)\n",
|
||||
"\n",
|
||||
"slope = fractal_noise_1d(key, 10000, scale=1200, height_scale=0.08)\n",
|
||||
"\n",
|
||||
"slope"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# get an array of positions\n",
|
||||
"positions = jnp.array([1.1,2.2,3.3,5,200.0], dtype=jnp.float32)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
247
notebooks/v1gym.ipynb
Normal file
247
notebooks/v1gym.ipynb
Normal file
File diff suppressed because one or more lines are too long
|
@ -564,7 +564,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -576,7 +576,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
|
263
pdm.lock
263
pdm.lock
|
@ -5,7 +5,7 @@
|
|||
groups = ["default", "dev"]
|
||||
strategy = ["inherit_metadata"]
|
||||
lock_version = "4.5.0"
|
||||
content_hash = "sha256:81e26f71acf1a583b21280b235fa2ac16165ac824ae8483bd391b88406421aa4"
|
||||
content_hash = "sha256:a3b65f863c554725c33d452fd759776141740661fa3555d306ed08563a7e16e2"
|
||||
|
||||
[[metadata.targets]]
|
||||
requires_python = ">=3.12,<3.13"
|
||||
|
@ -152,7 +152,7 @@ version = "0.4.6"
|
|||
requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
|
||||
summary = "Cross-platform colored terminal text."
|
||||
groups = ["default", "dev"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\" and sys_platform == \"win32\""
|
||||
marker = "sys_platform == \"win32\" and python_version >= \"3.12\" and python_version < \"3.13\" or platform_system == \"Windows\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
|
||||
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
||||
|
@ -230,13 +230,30 @@ name = "decorator"
|
|||
version = "5.1.1"
|
||||
requires_python = ">=3.5"
|
||||
summary = "Decorators for Humans"
|
||||
groups = ["dev"]
|
||||
groups = ["default", "dev"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"},
|
||||
{file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "dm-tree"
|
||||
version = "0.1.8"
|
||||
summary = "Tree is a library for working with nested data structures."
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "dm-tree-0.1.8.tar.gz", hash = "sha256:0fcaabbb14e7980377439e7140bd05552739ca5e515ecb3119f234acee4b9430"},
|
||||
{file = "dm_tree-0.1.8-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ea9e59e0451e7d29aece402d9f908f2e2a80922bcde2ebfd5dcb07750fcbfee8"},
|
||||
{file = "dm_tree-0.1.8-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:94d3f0826311f45ee19b75f5b48c99466e4218a0489e81c0f0167bda50cacf22"},
|
||||
{file = "dm_tree-0.1.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:435227cf3c5dc63f4de054cf3d00183790bd9ead4c3623138c74dde7f67f521b"},
|
||||
{file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09964470f76a5201aff2e8f9b26842976de7889300676f927930f6285e256760"},
|
||||
{file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:75c5d528bb992981c20793b6b453e91560784215dffb8a5440ba999753c14ceb"},
|
||||
{file = "dm_tree-0.1.8-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0a94aba18a35457a1b5cd716fd7b46c5dafdc4cf7869b4bae665b91c4682a8e"},
|
||||
{file = "dm_tree-0.1.8-cp312-cp312-win_amd64.whl", hash = "sha256:96a548a406a6fb15fe58f6a30a57ff2f2aafbf25f05afab00c8f5e5977b6c715"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "etils"
|
||||
version = "1.11.0"
|
||||
|
@ -379,6 +396,47 @@ files = [
|
|||
{file = "fsspec-2024.10.0.tar.gz", hash = "sha256:eda2d8a4116d4f2429db8550f2457da57279247dd930bb12f821b58391359493"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gast"
|
||||
version = "0.6.0"
|
||||
requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,>=2.7"
|
||||
summary = "Python AST that abstracts the underlying Python version"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "gast-0.6.0-py3-none-any.whl", hash = "sha256:52b182313f7330389f72b069ba00f174cfe2a06411099547288839c6cbafbd54"},
|
||||
{file = "gast-0.6.0.tar.gz", hash = "sha256:88fc5300d32c7ac6ca7b515310862f71e6fdf2c029bbec7c66c0f5dd47b6b1fb"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gym"
|
||||
version = "0.26.2"
|
||||
requires_python = ">=3.6"
|
||||
summary = "Gym: A universal API for reinforcement learning environments"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"cloudpickle>=1.2.0",
|
||||
"dataclasses==0.8; python_version == \"3.6\"",
|
||||
"gym-notices>=0.0.4",
|
||||
"importlib-metadata>=4.8.0; python_version < \"3.10\"",
|
||||
"numpy>=1.18.0",
|
||||
]
|
||||
files = [
|
||||
{file = "gym-0.26.2.tar.gz", hash = "sha256:e0d882f4b54f0c65f203104c24ab8a38b039f1289986803c7d02cdbe214fbcc4"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gym-notices"
|
||||
version = "0.0.8"
|
||||
summary = "Notices for gym"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "gym-notices-0.0.8.tar.gz", hash = "sha256:ad25e200487cafa369728625fe064e88ada1346618526102659b4640f2b4b911"},
|
||||
{file = "gym_notices-0.0.8-py3-none-any.whl", hash = "sha256:e5f82e00823a166747b4c2a07de63b6560b1acb880638547e0cabf825a01e463"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gymnasium"
|
||||
version = "1.0.0"
|
||||
|
@ -417,6 +475,29 @@ files = [
|
|||
{file = "gymnasium-1.0.0.tar.gz", hash = "sha256:9d2b66f30c1b34fe3c2ce7fae65ecf365d0e9982d2b3d860235e773328a3b403"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gymnax"
|
||||
version = "0.0.8"
|
||||
requires_python = ">=3.10"
|
||||
summary = "JAX-compatible version of Open AI's gym environments"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"chex",
|
||||
"flax",
|
||||
"gym>=0.26",
|
||||
"gymnasium",
|
||||
"jax",
|
||||
"jaxlib",
|
||||
"matplotlib",
|
||||
"pyyaml",
|
||||
"seaborn",
|
||||
]
|
||||
files = [
|
||||
{file = "gymnax-0.0.8-py3-none-any.whl", hash = "sha256:0af7edde1b71d74be8007ffe1e6338f8ce66693b1b78ae479c0c0cd02b10de03"},
|
||||
{file = "gymnax-0.0.8.tar.gz", hash = "sha256:81defc17f52a30a84338b3daa574d7a3bb112f2656f45c783a71efe31eea68ff"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "humanize"
|
||||
version = "4.11.0"
|
||||
|
@ -541,76 +622,6 @@ files = [
|
|||
{file = "jax-0.4.37.tar.gz", hash = "sha256:7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jax-cuda12-pjrt"
|
||||
version = "0.4.36"
|
||||
summary = "JAX XLA PJRT Plugin for NVIDIA GPUs"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "jax_cuda12_pjrt-0.4.36-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1dfc0bec0820ba801b61e9421064b6e58238c430b4ad8f54043323d93c0217c6"},
|
||||
{file = "jax_cuda12_pjrt-0.4.36-py3-none-manylinux2014_x86_64.whl", hash = "sha256:e3c3705d8db7d63da9abfaebf06f5cd0667f5acb0748a5c5eb00d80041e922ed"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jax-cuda12-plugin"
|
||||
version = "0.4.36"
|
||||
requires_python = ">=3.10"
|
||||
summary = "JAX Plugin for NVIDIA GPUs"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"jax-cuda12-pjrt==0.4.36",
|
||||
]
|
||||
files = [
|
||||
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:6a0b0c2bdc1da2eea2c20723a1e8f39b3cda67d24c665de936647e8091f5790d"},
|
||||
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:5d4727fb519fedc06a9a984d5a0714804d81ef126a2cb60cefd5cbc4a3ea2627"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jax-cuda12-plugin"
|
||||
version = "0.4.36"
|
||||
extras = ["with_cuda"]
|
||||
requires_python = ">=3.10"
|
||||
summary = "JAX Plugin for NVIDIA GPUs"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"jax-cuda12-plugin==0.4.36",
|
||||
"nvidia-cublas-cu12>=12.1.3.1",
|
||||
"nvidia-cuda-cupti-cu12>=12.1.105",
|
||||
"nvidia-cuda-nvcc-cu12>=12.6.85",
|
||||
"nvidia-cuda-runtime-cu12>=12.1.105",
|
||||
"nvidia-cudnn-cu12<10.0,>=9.1",
|
||||
"nvidia-cufft-cu12>=11.0.2.54",
|
||||
"nvidia-cusolver-cu12>=11.4.5.107",
|
||||
"nvidia-cusparse-cu12>=12.1.0.106",
|
||||
"nvidia-nccl-cu12>=2.18.1",
|
||||
"nvidia-nvjitlink-cu12>=12.1.105",
|
||||
]
|
||||
files = [
|
||||
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:6a0b0c2bdc1da2eea2c20723a1e8f39b3cda67d24c665de936647e8091f5790d"},
|
||||
{file = "jax_cuda12_plugin-0.4.36-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:5d4727fb519fedc06a9a984d5a0714804d81ef126a2cb60cefd5cbc4a3ea2627"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jax"
|
||||
version = "0.4.37"
|
||||
extras = ["cuda12"]
|
||||
requires_python = ">=3.10"
|
||||
summary = "Differentiate, compile, and transform Numpy code."
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"jax-cuda12-plugin[with_cuda]<=0.4.37,>=0.4.36",
|
||||
"jax==0.4.37",
|
||||
"jaxlib==0.4.36",
|
||||
]
|
||||
files = [
|
||||
{file = "jax-0.4.37-py3-none-any.whl", hash = "sha256:bdc0686d7e5a944e2d38026eae632214d98dd2d91869cbcedbf1c11298ae3e3e"},
|
||||
{file = "jax-0.4.37.tar.gz", hash = "sha256:7774f3d9e23fe199c65589c680c5a5be87a183b89598421a632d8245222b637b"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jaxlib"
|
||||
version = "0.4.36"
|
||||
|
@ -926,7 +937,7 @@ version = "12.4.5.8"
|
|||
requires_python = ">=3"
|
||||
summary = "CUBLAS native runtime libraries"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl", hash = "sha256:0f8aa1706812e00b9f19dfe0cdb3999b092ccb8ca168c0db5b8ea712456fd9b3"},
|
||||
{file = "nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl", hash = "sha256:2fc8da60df463fdefa81e323eef2e36489e1c94335b5358bcb38360adf75ac9b"},
|
||||
|
@ -939,26 +950,13 @@ version = "12.4.127"
|
|||
requires_python = ">=3"
|
||||
summary = "CUDA profiling tools runtime libs."
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:79279b35cf6f91da114182a5ce1864997fd52294a87a16179ce275773799458a"},
|
||||
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:9dec60f5ac126f7bb551c055072b69d85392b13311fcc1bcda2202d172df30fb"},
|
||||
{file = "nvidia_cuda_cupti_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:5688d203301ab051449a2b1cb6690fbe90d2b372f411521c86018b950f3d7922"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nvidia-cuda-nvcc-cu12"
|
||||
version = "12.6.85"
|
||||
requires_python = ">=3"
|
||||
summary = "CUDA nvcc"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-manylinux1_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d75d9d74599f4d7c0865df19ed21b739e6cb77a6497a3f73d6f61e8038a765e4"},
|
||||
{file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5d2edd5531b13e3daac8ffee9fc2b70a147e6088b2af2565924773d63d36d294"},
|
||||
{file = "nvidia_cuda_nvcc_cu12-12.6.85-py3-none-win_amd64.whl", hash = "sha256:aa04742337973dcb5bcccabb590edc8834c60ebfaf971847888d24ffef6c46b5"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nvidia-cuda-nvrtc-cu12"
|
||||
version = "12.4.127"
|
||||
|
@ -978,7 +976,7 @@ version = "12.4.127"
|
|||
requires_python = ">=3"
|
||||
summary = "CUDA Runtime native Libraries"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:961fe0e2e716a2a1d967aab7caee97512f71767f852f67432d572e36cb3a11f3"},
|
||||
{file = "nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:64403288fa2136ee8e467cdc9c9427e0434110899d07c779f25b5c068934faa5"},
|
||||
|
@ -991,7 +989,7 @@ version = "9.1.0.70"
|
|||
requires_python = ">=3"
|
||||
summary = "cuDNN runtime libraries"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"nvidia-cublas-cu12",
|
||||
]
|
||||
|
@ -1006,7 +1004,7 @@ version = "11.2.1.3"
|
|||
requires_python = ">=3"
|
||||
summary = "CUFFT native runtime libraries"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"nvidia-nvjitlink-cu12",
|
||||
]
|
||||
|
@ -1035,7 +1033,7 @@ version = "11.6.1.9"
|
|||
requires_python = ">=3"
|
||||
summary = "CUDA solver native runtime libraries"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"nvidia-cublas-cu12",
|
||||
"nvidia-cusparse-cu12",
|
||||
|
@ -1053,7 +1051,7 @@ version = "12.3.1.170"
|
|||
requires_python = ">=3"
|
||||
summary = "CUSPARSE native runtime libraries"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"nvidia-nvjitlink-cu12",
|
||||
]
|
||||
|
@ -1069,7 +1067,7 @@ version = "2.21.5"
|
|||
requires_python = ">=3"
|
||||
summary = "NVIDIA Collective Communication Library (NCCL) Runtime"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:8579076d30a8c24988834445f8d633c697d42397e92ffc3f63fa26766d25e0a0"},
|
||||
]
|
||||
|
@ -1080,7 +1078,7 @@ version = "12.4.127"
|
|||
requires_python = ">=3"
|
||||
summary = "Nvidia JIT LTO Library"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
marker = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
files = [
|
||||
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl", hash = "sha256:4abe7fef64914ccfa909bc2ba39739670ecc9e820c83ccc7a6ed414122599b83"},
|
||||
{file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"},
|
||||
|
@ -1653,6 +1651,29 @@ files = [
|
|||
{file = "rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "sbx-rl"
|
||||
version = "0.18.0"
|
||||
requires_python = ">=3.8"
|
||||
summary = "Jax version of Stable Baselines, implementations of reinforcement learning algorithms."
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"flax",
|
||||
"jax",
|
||||
"jaxlib",
|
||||
"optax; python_version >= \"3.9.0\"",
|
||||
"optax<0.1.8; python_version < \"3.9.0\"",
|
||||
"rich",
|
||||
"stable-baselines3<3.0,>=2.4.0a4",
|
||||
"tensorflow-probability",
|
||||
"tqdm",
|
||||
]
|
||||
files = [
|
||||
{file = "sbx_rl-0.18.0-py3-none-any.whl", hash = "sha256:75ade634a33555ad4c4a81523bb0f99c89d1b3bc89fb74990ef87b22379abd9c"},
|
||||
{file = "sbx_rl-0.18.0.tar.gz", hash = "sha256:670f2bf095ec21ba6f8171602294baf0123787fe7be6811ebab276fb5010b8b3"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scipy"
|
||||
version = "1.14.1"
|
||||
|
@ -1695,6 +1716,23 @@ files = [
|
|||
{file = "scooby-0.10.0.tar.gz", hash = "sha256:7ea33c262c0cc6a33c6eeeb5648df787be4f22660e53c114e5fff1b811a8854f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "seaborn"
|
||||
version = "0.13.2"
|
||||
requires_python = ">=3.8"
|
||||
summary = "Statistical data visualization"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"matplotlib!=3.6.1,>=3.4",
|
||||
"numpy!=1.24.0,>=1.20",
|
||||
"pandas>=1.2",
|
||||
]
|
||||
files = [
|
||||
{file = "seaborn-0.13.2-py3-none-any.whl", hash = "sha256:636f8336facf092165e27924f223d3c62ca560b1f2bb5dff7ab7fad265361987"},
|
||||
{file = "seaborn-0.13.2.tar.gz", hash = "sha256:93e60a40988f4d65e9f4885df477e2fdaff6b73a9ded434c1ab356dd57eefff7"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "setuptools"
|
||||
version = "75.6.0"
|
||||
|
@ -1809,6 +1847,26 @@ files = [
|
|||
{file = "sympy-1.13.1.tar.gz", hash = "sha256:9cebf7e04ff162015ce31c9c6c9144daa34a93bd082f54fd8f12deca4f47515f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tensorflow-probability"
|
||||
version = "0.25.0"
|
||||
requires_python = ">=3.9"
|
||||
summary = "Probabilistic modeling and statistical inference in TensorFlow"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"absl-py",
|
||||
"cloudpickle>=1.3",
|
||||
"decorator",
|
||||
"dm-tree",
|
||||
"gast>=0.3.2",
|
||||
"numpy>=1.13.3",
|
||||
"six>=1.10.0",
|
||||
]
|
||||
files = [
|
||||
{file = "tensorflow_probability-0.25.0-py2.py3-none-any.whl", hash = "sha256:f3f4d6431656c0122906888afe1b67b4400e82bd7f254b45b92e6c5b84ea8e3e"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tensorstore"
|
||||
version = "0.1.71"
|
||||
|
@ -1899,6 +1957,21 @@ files = [
|
|||
{file = "tornado-6.4.2.tar.gz", hash = "sha256:92bad5b4746e9879fd7bf1eb21dce4e3fc5128d71601f80005afa39237ad620b"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tqdm"
|
||||
version = "4.67.1"
|
||||
requires_python = ">=3.7"
|
||||
summary = "Fast, Extensible Progress Meter"
|
||||
groups = ["default"]
|
||||
marker = "python_version >= \"3.12\" and python_version < \"3.13\""
|
||||
dependencies = [
|
||||
"colorama; platform_system == \"Windows\"",
|
||||
]
|
||||
files = [
|
||||
{file = "tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2"},
|
||||
{file = "tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "traitlets"
|
||||
version = "5.14.3"
|
||||
|
|
|
@ -5,7 +5,7 @@ description = "A solar car racing simulation library and GUI tool"
|
|||
authors = [
|
||||
{name = "saji", email = "saji@saji.dev"},
|
||||
]
|
||||
dependencies = ["pyqtgraph>=0.13.7", "jax[cuda12]>=0.4.37", "pytest>=8.3.3", "pyside6>=6.8.0.2", "matplotlib>=3.9.2", "gymnasium[jax]>=1.0.0", "pyvista>=0.44.2", "pyvistaqt>=0.11.1", "stable-baselines3>=2.4.0"]
|
||||
dependencies = ["pyqtgraph>=0.13.7", "jax>=0.4.37", "pytest>=8.3.3", "pyside6>=6.8.0.2", "matplotlib>=3.9.2", "gymnasium[jax]>=1.0.0", "pyvista>=0.44.2", "pyvistaqt>=0.11.1", "stable-baselines3>=2.4.0", "gymnax>=0.0.8", "sbx-rl>=0.18.0"]
|
||||
requires-python = ">=3.10,<3.13"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT"}
|
||||
|
|
136
report/report.tex
Normal file
136
report/report.tex
Normal file
|
@ -0,0 +1,136 @@
|
|||
\documentclass[11pt]{article}
|
||||
|
||||
% Essential packages
|
||||
\usepackage[utf8]{inputenc}
|
||||
\usepackage[T1]{fontenc}
|
||||
\usepackage{amsmath,amssymb}
|
||||
\usepackage{graphicx}
|
||||
\usepackage[margin=1in]{geometry}
|
||||
\usepackage{hyperref}
|
||||
\usepackage{algorithm}
|
||||
\usepackage{algorithmic}
|
||||
\usepackage{float}
|
||||
\usepackage{booktabs}
|
||||
\usepackage{caption}
|
||||
\usepackage{subcaption}
|
||||
|
||||
% Custom commands
|
||||
\newcommand{\sectionheading}[1]{\noindent\textbf{#1}}
|
||||
|
||||
% Title and author information
|
||||
\title{\Large{Your Project Title: A Study in Optimal Control and Reinforcement Learning}}
|
||||
\author{Your Name\\
|
||||
Course Name\\
|
||||
Institution}
|
||||
\date{\today}
|
||||
|
||||
\begin{document}
|
||||
|
||||
\maketitle
|
||||
|
||||
\begin{abstract}
|
||||
Solar Racing is a competition with the goal of creating highly efficient solar-assisted electric vehicles. Effective solar racing
|
||||
requires awareness and complex decision making to determine optimal speeds to exploit the environmental conditions, such as winds,
|
||||
cloud cover, and changes in elevation. We present an environment modelled on the dynamics involved for a race, including generated
|
||||
elevation and wind profiles. The model uses the \texttt{gymnasium} interface to allow it to be used by a variety of algorithms.
|
||||
We demonstrate a method of designing reward functions for multi-objective problems. The environment shows to be solvable by modern
|
||||
reinforcement learning algorithms.
|
||||
\end{abstract}
|
||||
|
||||
\section{Introduction}
|
||||
Start with a broad context of your problem area in optimal control/reinforcement learning. Then narrow down to your specific focus. Include:
|
||||
\begin{itemize}
|
||||
\item Problem motivation
|
||||
\item Brief overview of existing approaches
|
||||
\item Your specific contributions
|
||||
\item Paper organization
|
||||
\end{itemize}
|
||||
|
||||
Solar racing was invented in the early 90s as a technology incubator for high-efficiency motor vehicles. The first solar races were speed
|
||||
focused, however a style of race that focused on minimal energy use within a given route was developed to push focus towards vehicle efficiency.
|
||||
The goal of these races is to arrive at a destination within a given time frame, while using as little grid (non-solar) energy as possible.
|
||||
Aerodynamic drag is one of the most significant sources of energy consumption, along with elevation changes. The simplest policy to meet
|
||||
the constraints of on-time arrival is:
|
||||
$$
|
||||
V_{\text{avg}} = \frac{D}{T}
|
||||
$$
|
||||
Where $D$ is the distance needed to travel, and $T$ is the maximum allowed time.
|
||||
|
||||
\section{Background}
|
||||
Provide necessary background on:
|
||||
\begin{itemize}
|
||||
\item Your specific application domain
|
||||
\item Relevant algorithms and methods
|
||||
\item Previous work in this area
|
||||
\end{itemize}
|
||||
|
||||
\section{Methodology}
|
||||
Describe your approach in detail:
|
||||
\begin{itemize}
|
||||
\item Problem formulation
|
||||
\item Algorithm description
|
||||
\item Implementation details
|
||||
\end{itemize}
|
||||
|
||||
% Example of how to include an algorithm
|
||||
\begin{algorithm}[H]
|
||||
\caption{Your Algorithm Name}
|
||||
\begin{algorithmic}[1]
|
||||
\STATE Initialize parameters
|
||||
\WHILE{not converged}
|
||||
\STATE Update step
|
||||
\ENDWHILE
|
||||
\RETURN Result
|
||||
\end{algorithmic}
|
||||
\end{algorithm}
|
||||
|
||||
\section{Experiments and Results}
|
||||
Present your findings:
|
||||
\begin{itemize}
|
||||
\item Experimental setup
|
||||
\item Results and analysis
|
||||
\item Comparison with baselines (if applicable)
|
||||
\end{itemize}
|
||||
|
||||
% Example of how to include figures
|
||||
\begin{figure}[H]
|
||||
\centering
|
||||
\caption{Description of your figure}
|
||||
\label{fig:example}
|
||||
\end{figure}
|
||||
|
||||
% Example of how to include tables
|
||||
\begin{table}[H]
|
||||
\centering
|
||||
\caption{Your Table Caption}
|
||||
\begin{tabular}{lcc}
|
||||
\toprule
|
||||
Method & Metric 1 & Metric 2 \\
|
||||
\midrule
|
||||
Approach 1 & Value & Value \\
|
||||
Approach 2 & Value & Value \\
|
||||
\bottomrule
|
||||
\end{tabular}
|
||||
\label{tab:results}
|
||||
\end{table}
|
||||
|
||||
\section{Discussion}
|
||||
Analyze your results:
|
||||
\begin{itemize}
|
||||
\item Interpretation of findings
|
||||
\item Limitations and challenges
|
||||
\item Potential improvements
|
||||
\end{itemize}
|
||||
|
||||
\section{Conclusion}
|
||||
Summarize your work:
|
||||
\begin{itemize}
|
||||
\item Key contributions
|
||||
\item Practical implications
|
||||
\item Future work directions
|
||||
\end{itemize}
|
||||
|
||||
\bibliography{references}
|
||||
\bibliographystyle{plain}
|
||||
|
||||
\end{document}
|
|
@ -115,7 +115,6 @@ def bldc_power_draw(torque, velocity, params: MotorParams):
|
|||
return total_power
|
||||
|
||||
|
||||
# @partial(jit, static_argnames=['resistance', 'kt', 'kv', 'vmax', 'Cf'])
|
||||
@jit
|
||||
def bldc_torque(velocity, current_limit, resistance, kt, kv, vmax, Cf):
|
||||
bemf = velocity / kv
|
||||
|
@ -132,7 +131,6 @@ def bldc_torque(velocity, current_limit, resistance, kt, kv, vmax, Cf):
|
|||
@partial(
|
||||
jit,
|
||||
static_argnums=(
|
||||
1,
|
||||
2,
|
||||
),
|
||||
)
|
||||
|
@ -144,47 +142,12 @@ def battery_powerloss(current, cell_r, battery_shape: Tuple[int, int]):
|
|||
return jnp.sum(cell_Ploss)
|
||||
|
||||
|
||||
def forward(state, timestep, control, params: CarParams):
|
||||
# state is (position, time, energy)
|
||||
# control is velocity
|
||||
# timestep is >0 time to advance
|
||||
# params is the params dictionary.
|
||||
# returns the next state with (position', time + timestep, energy')
|
||||
# TODO: terrain, weather, solar
|
||||
|
||||
# determine the forces acting on the car.
|
||||
dragf = drag_force(control, params.frontal_area, params.drag_coeff, 1.184)
|
||||
rollf = rolling_force(params.mass, 0, params.rolling_coeff)
|
||||
hillforce = downslope_force(params.mass, 0)
|
||||
totalf = dragf + rollf + hillforce
|
||||
# determine the power needed to make this force
|
||||
tau = params.wheel_radius * totalf
|
||||
pdraw = bldc_power_draw(tau, control, params.motor)
|
||||
net_power = 0 - pdraw # watts aka j/s
|
||||
|
||||
# TODO: calculate battery-based power losses.
|
||||
# TODO: support regenerative braking when going downhill
|
||||
# TODO: delta x = cos(theta) * velocity * timestep
|
||||
|
||||
new_state = jnp.array(
|
||||
[
|
||||
state[0] + control * timestep,
|
||||
state[1] + timestep,
|
||||
state[2] + net_power * timestep,
|
||||
]
|
||||
)
|
||||
return new_state
|
||||
|
||||
|
||||
def make_environment(seed):
|
||||
"""Generate a race environment: terrain function, wind function, wrapped forward function."""
|
||||
key, subkey = jax.random.split(seed)
|
||||
wind = generate_wind_field(subkey, 10000, 600, spatial_scale=1000)
|
||||
key, subkey = jax.random.split(key)
|
||||
slope = fractal_noise_1d(subkey, 10000, scale=1200, height_scale=0.08)
|
||||
windkey, slopekey = jax.random.split(seed, 2)
|
||||
wind = generate_wind_field(windkey, 10000, 600, spatial_scale=1000)
|
||||
slope = fractal_noise_1d(slopekey, 10000, scale=1200, height_scale=0.08)
|
||||
elevation = jnp.cumsum(slope)
|
||||
# elevation = generate_elevation_profile(subkey, 10000, height_variation=40.0, scale=1200, octaves=5)
|
||||
# slope = jnp.arctan(jnp.diff(elevation, prepend=100.0)) # rise/run
|
||||
|
||||
return wind, elevation, slope
|
||||
|
||||
|
|
|
@ -4,11 +4,9 @@ import jax
|
|||
import jax.numpy as jnp
|
||||
from jax import jit
|
||||
from functools import partial
|
||||
from solarcarsim.physsim import drag_force, rolling_force, downslope_force, bldc_power_draw
|
||||
|
||||
|
||||
@partial(jit, static_argnames=["params"])
|
||||
def forwardv2(state, control, delta_time, wind, elevation, slope, params):
|
||||
def forward(state, control, delta_time, wind, elevation, slope, params: sim.CarParams):
|
||||
pos = jnp.astype(jnp.round(state[0]), "int32")
|
||||
time = jnp.astype(jnp.round(state[1]), "int32")
|
||||
theta = slope[pos]
|
||||
|
@ -23,23 +21,22 @@ def forwardv2(state, control, delta_time, wind, elevation, slope, params):
|
|||
totalf = dragf + rollf + hillforce
|
||||
# with the sum of forces, determine the needed torque at the wheels, and then power
|
||||
tau = params.wheel_radius * totalf
|
||||
pdraw = bldc_power_draw(tau, velocity, params.motor)
|
||||
pdraw = sim.bldc_power_draw(tau, velocity, params.motor)
|
||||
# determine the energy needed to do this power for the time step
|
||||
net_power = state[2] - delta_time * pdraw # joules
|
||||
|
||||
dpos = jnp.cos(theta) * velocity * delta_time
|
||||
dist_remaining = 10000.0 - dpos
|
||||
dpos = state[0] + jnp.cos(theta) * velocity * delta_time
|
||||
new_pos = jnp.maximum(dpos, 0)
|
||||
dist_remaining = 10000.0 - (state[0] + dpos)
|
||||
time_remaining = 600 - (state[1] + delta_time)
|
||||
return jnp.array(
|
||||
[dpos, state[1] + delta_time, net_power, dist_remaining, time_remaining]
|
||||
[new_pos, state[1] + delta_time, net_power, dist_remaining, time_remaining]
|
||||
)
|
||||
|
||||
|
||||
def reward(state, prev_energy):
|
||||
progress = state[0] / 10000 * 100
|
||||
energy_usage = 10 * (state[2] - prev_energy) # current energy < previous energy.
|
||||
time_factor = (1.0 - (state[1] / 600)) * 50
|
||||
reward = progress + energy_usage + time_factor
|
||||
def reward(state, prev_state):
|
||||
reward = 0
|
||||
reward += state[0]/8000
|
||||
return reward
|
||||
|
||||
class SolarRaceV1(gym.Env):
|
||||
|
@ -53,7 +50,6 @@ class SolarRaceV1(gym.Env):
|
|||
# self._state = jnp.array([np.array([x], dtype="float32") for x in (0,0,0, 10000.0, 600.0)])
|
||||
self._state = jnp.array([[0],[0],[0],[10000.0], [600.0]])
|
||||
# self._state = jnp.array([0, 0,0,10000.0, 600.0])
|
||||
|
||||
def _vision_function(self):
|
||||
# extract the vision results.
|
||||
def slookup(x):
|
||||
|
@ -81,7 +77,7 @@ class SolarRaceV1(gym.Env):
|
|||
self._reset_sim(jax.random.key(seed))
|
||||
self._timestep = timestep
|
||||
self._car = car
|
||||
self._simstep = forwardv2
|
||||
self._simstep = forward
|
||||
self._simreward = reward
|
||||
|
||||
self.observation_space = gym.spaces.Dict(
|
||||
|
@ -108,15 +104,20 @@ class SolarRaceV1(gym.Env):
|
|||
def step(self, action):
|
||||
wind, elevation, slope = self._environment
|
||||
|
||||
old_energy = self._state[2]
|
||||
old_state = self._state
|
||||
|
||||
self._state = self._simstep(self._state, action, self._timestep,wind, elevation, slope, self._car)
|
||||
reward = self._simreward(self._state, old_energy)[0]
|
||||
reward = self._simreward(self._state, old_state)[0]
|
||||
terminated = False
|
||||
truncated = False
|
||||
if jnp.all(self._state[0] > 10000):
|
||||
if jnp.all(self._state[0] > 8000):
|
||||
reward += 500
|
||||
terminated = True
|
||||
if self._state[1] > 600:
|
||||
# we want the time to be as close to 600 as possible
|
||||
reward -= 600 - self._state[1][0]
|
||||
reward += 1e-6 * (self._state[2][0]) # net energy is negative.
|
||||
if jnp.all(self._state[1] > 600):
|
||||
reward -= 50000
|
||||
truncated = True
|
||||
|
||||
return self._get_obs(), reward, terminated, truncated, {}
|
||||
|
|
|
@ -1,14 +1,190 @@
|
|||
"""Second-generation simulator. More functional, cleaner code, faster"""
|
||||
|
||||
from typing import NamedTuple
|
||||
from typing import NamedTuple, Optional, Tuple, Union, Dict, Any
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import chex
|
||||
from flax import struct
|
||||
from jax import lax
|
||||
from gymnax.environments import environment
|
||||
from gymnax.environments import spaces
|
||||
|
||||
from solarcarsim.physsim import CarParams, fractal_noise_1d
|
||||
import solarcarsim.physsim as sim
|
||||
|
||||
|
||||
class SimState(NamedTuple):
|
||||
position: float
|
||||
time: float
|
||||
energy: float
|
||||
distance_remaining: float
|
||||
time_remaining: float
|
||||
@struct.dataclass
|
||||
class SimState(environment.EnvState):
|
||||
position: jnp.ndarray
|
||||
velocity: jnp.ndarray
|
||||
realtime: jnp.ndarray
|
||||
energy: jnp.ndarray
|
||||
# distance_remaining: jnp.ndarray
|
||||
# time_remaining: jnp.ndarray
|
||||
slope: jnp.ndarray
|
||||
|
||||
|
||||
@struct.dataclass
|
||||
class SimParams(environment.EnvParams):
|
||||
car: CarParams = CarParams()
|
||||
goal_time: int = 600
|
||||
goal_dist: int = 8000
|
||||
map_size: int = 10000
|
||||
time_step: float = 1.0
|
||||
terrain_lookahead: int = 100
|
||||
# skip wind for now
|
||||
|
||||
|
||||
class Snax(environment.Environment[SimState, SimParams]):
|
||||
"""JAX version of the solar race simulator"""
|
||||
|
||||
@property
|
||||
def default_params(self) -> SimParams:
|
||||
return SimParams()
|
||||
|
||||
def action_space(self, params: Optional[SimParams] = None):
|
||||
return spaces.Box(low=-1.0, high=1.0, shape=(1,))
|
||||
|
||||
def observation_space(self, params: Optional[SimParams] = None) -> spaces.Box:
|
||||
if params is None:
|
||||
params = self.default_params
|
||||
# needs to be a box. it will be [pos, time, energy, dist_to_goal, time_remaining, terrain0, terrain1]
|
||||
shape = 5 + params.terrain_lookahead
|
||||
low = jnp.array(
|
||||
[0, 0, -1e11, 0, 0] + [-1.0] * params.terrain_lookahead, dtype=jnp.float32
|
||||
)
|
||||
high = jnp.array(
|
||||
[params.map_size, params.goal_time, 0, params.goal_dist, params.goal_time]
|
||||
+ [1.0] * params.terrain_lookahead,
|
||||
dtype=jnp.float32,
|
||||
)
|
||||
return spaces.Box(low, high, shape=(shape,))
|
||||
# return spaces.Dict(
|
||||
# {
|
||||
# "position": spaces.Box(0.0, params.map_size, (), jnp.float32),
|
||||
# "realtime": spaces.Box(0.0, params.goal_time + 100, (), jnp.float32),
|
||||
# "energy": spaces.Box(-1e11, 0.0, (), jnp.float32),
|
||||
# "dist_to_goal": spaces.Box(0.0, params.goal_dist, (), jnp.float32),
|
||||
# "time_remaining": spaces.Box(0.0, params.goal_time, (), jnp.float32),
|
||||
# "upcoming_terrain": spaces.Box(
|
||||
# -1.0, 1.0, shape=(100,), dtype=jnp.float32
|
||||
# ),
|
||||
# # skip wind for now
|
||||
# }
|
||||
# )
|
||||
|
||||
def state_space(self, params: Optional[SimParams] = None) -> spaces.Dict:
|
||||
if params is None:
|
||||
params = self.default_params
|
||||
return spaces.Dict(
|
||||
{
|
||||
"position": spaces.Box(0.0, params.map_size, (), jnp.float32),
|
||||
"realtime": spaces.Box(0.0, params.goal_time + 100, (), jnp.float32),
|
||||
"energy": spaces.Box(-1e11, 0.0, (), jnp.float32),
|
||||
# "dist_to_goal": spaces.Box(0.0, params.goal_dist, (), jnp.float32),
|
||||
# "time_remaining": spaces.Box(0.0, params.goal_time, (), jnp.float32),
|
||||
"slope": spaces.Box(
|
||||
-1.0, 1.0, shape=(params.map_size,), dtype=jnp.float32
|
||||
),
|
||||
"time": spaces.Discrete(int(params.goal_time / params.time_step)),
|
||||
}
|
||||
)
|
||||
|
||||
def reset_env(
|
||||
self, key: chex.PRNGKey, params: Optional[SimParams] = None
|
||||
) -> Tuple[chex.Array, SimState]:
|
||||
if params is None:
|
||||
params = self.default_params
|
||||
slope = fractal_noise_1d(key, 10000, scale=1200, height_scale=0.08)
|
||||
init_state = SimState(
|
||||
position=jnp.array(0.0),
|
||||
velocity=jnp.array(0.0),
|
||||
time=0,
|
||||
realtime=jnp.array(0.0),
|
||||
energy=jnp.array(0.0),
|
||||
# distance_remaining=jnp.array(params.goal_dist),
|
||||
# time_remaining=jnp.array(params.goal_time),
|
||||
slope=slope,
|
||||
)
|
||||
return self.get_obs(init_state, key, params), init_state
|
||||
|
||||
def get_obs(
|
||||
self, state: SimState, key: chex.PRNGKey, params: SimParams
|
||||
) -> chex.Array:
|
||||
if params is None:
|
||||
params = self.default_params
|
||||
|
||||
# get rounded position from state
|
||||
pos_int = jnp.astype(state.position, jnp.int32)
|
||||
|
||||
terrain_view = jax.lax.dynamic_slice(state.slope, (pos_int,), (100,))
|
||||
dist_to_goal = jnp.abs(params.goal_dist - state.position)
|
||||
time_remaining = jnp.abs(params.goal_time - state.realtime)
|
||||
main_state = jnp.array(
|
||||
[state.position, state.realtime, state.energy, dist_to_goal, time_remaining]
|
||||
)
|
||||
return jnp.concat([main_state, terrain_view]).squeeze()
|
||||
|
||||
def step_env(
|
||||
self,
|
||||
key: chex.PRNGKey,
|
||||
state: SimState,
|
||||
action: Union[int, float, chex.Array],
|
||||
params: SimParams,
|
||||
) -> Tuple[chex.Array, SimState, jnp.ndarray, jnp.ndarray, Dict[Any, Any]]:
|
||||
pos = jnp.astype(state.position, jnp.int32)
|
||||
theta = state.slope[pos]
|
||||
velocity = jnp.array([action * params.car.max_speed]).squeeze()
|
||||
dragf = sim.drag_force(
|
||||
velocity, params.car.frontal_area, params.car.drag_coeff, 1.184
|
||||
)
|
||||
rollf = sim.rolling_force(params.car.mass, theta, params.car.rolling_coeff)
|
||||
hillf = sim.downslope_force(params.car.mass, theta)
|
||||
total_f = dragf + rollf + hillf
|
||||
tau = params.car.wheel_radius * total_f / params.car.n_motors
|
||||
p_draw = (
|
||||
sim.bldc_power_draw(tau, velocity, params.car.motor) * params.car.n_motors
|
||||
)
|
||||
|
||||
new_energy = state.energy - params.time_step * p_draw
|
||||
new_position = state.position + jnp.cos(theta) * velocity * params.time_step
|
||||
new_state = SimState(
|
||||
position=new_position.squeeze(),
|
||||
velocity=velocity.squeeze(),
|
||||
realtime=state.realtime + params.time_step,
|
||||
energy=new_energy.squeeze(),
|
||||
slope=state.slope,
|
||||
time=state.time + 1,
|
||||
)
|
||||
|
||||
# compute reward
|
||||
# reward = new_state.position / params.goal_dist
|
||||
# if new_state.position >= params.goal_dist:
|
||||
# reward += 100
|
||||
# reward += params.goal_time - new_state.realtime
|
||||
# # penalize energy use
|
||||
# reward += 1e-7 * new_state.energy # energy is negative
|
||||
# if (
|
||||
# new_state.realtime >= params.goal_time
|
||||
# or new_state.time > params.max_steps_in_episode
|
||||
# ):
|
||||
# reward -= 500
|
||||
|
||||
# we have to vectorize that.
|
||||
reward = new_state.position / params.goal_dist + \
|
||||
(new_state.position >= params.goal_dist) * (100 + params.goal_time - new_state.realtime + 1e-7*new_state.energy) + \
|
||||
(new_state.realtime >= params.goal_time) * -500
|
||||
reward = reward.squeeze()
|
||||
terminal = self.is_terminal(state, params)
|
||||
return (
|
||||
lax.stop_gradient(self.get_obs(new_state, key, params)),
|
||||
lax.stop_gradient(new_state),
|
||||
reward,
|
||||
terminal,
|
||||
{},
|
||||
)
|
||||
|
||||
def is_terminal(self, state: SimState, params: SimParams) -> jnp.ndarray:
|
||||
finish = state.position >= params.goal_dist
|
||||
timeout = state.time >= params.max_steps_in_episode
|
||||
return jnp.logical_or(finish, timeout).squeeze()
|
||||
|
|
Loading…
Reference in a new issue