commit ce827bd337 (parent 7ad7070129)

    things
@@ -5,8 +5,8 @@ description = "A solar car racing simulation library and GUI tool"
 authors = [
     {name = "saji", email = "saji@saji.dev"},
 ]
-dependencies = ["pyqtgraph>=0.13.7", "jax[cuda12]>=0.4.37", "pytest>=8.3.3", "pyside6>=6.8.0.2", "matplotlib>=3.9.2", "gymnasium[jax]>=1.0.0", "pyvista>=0.44.2", "pyvistaqt>=0.11.1", "stable-baselines3>=2.4.0", "gymnax>=0.0.8", "sbx-rl>=0.18.0", "tyro>=0.9.2", "tensorboard>=2.18.0", "distrax>=0.1.5"]
-requires-python = ">=3.10,<3.13"
+dependencies = ["pyqtgraph>=0.13.7", "jax[cuda12]>=0.4.37", "pytest>=8.3.3", "pyside6>=6.8.0.2", "matplotlib>=3.9.2", "gymnasium[jax]>=1.0.0", "pyvista>=0.44.2", "pyvistaqt>=0.11.1", "stable-baselines3>=2.4.0", "gymnax>=0.0.8", "sbx-rl>=0.18.0", "tyro>=0.9.2", "tensorboard>=2.18.0", "distrax>=0.1.5", "satpy>=0.53.0", "cartopy>=0.24.1", "xarray>=2024.11.0"]
+requires-python = ">=3.12,<3.13"
 readme = "README.md"
 license = {text = "MIT"}
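The new satpy, cartopy, and xarray pins back the satellaview package introduced later in this commit. As a hedged illustration only (the file path, reader, and band choice are assumptions, not part of this commit), this is the kind of pipeline those dependencies enable:

# Hypothetical sketch: load a GOES ABI band with satpy and plot it on its
# native geostationary projection via cartopy. Paths are placeholders.
from glob import glob

import matplotlib.pyplot as plt
from satpy import Scene

files = glob("/data/goes16/OR_ABI-L1b-RadC*.nc")  # assumed local files
scn = Scene(reader="abi_l1b", filenames=files)
scn.load(["C02"])  # 0.64 um visible band
c02 = scn["C02"]   # an xarray.DataArray with area/geolocation metadata

crs = c02.attrs["area"].to_cartopy_crs()  # pyresample area -> cartopy CRS
ax = plt.axes(projection=crs)
ax.coastlines()
ax.imshow(c02, transform=crs, extent=crs.bounds, origin="upper")
plt.savefig("goes_c02.png")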
BIN  report/PPO_results.pdf  (new file; binary file not shown)
BIN  report/environment.pdf  (new file; binary file not shown)
93  report/references.bib  (new file)
@@ -0,0 +1,93 @@
@article{lu2022discovered,
  title = {Discovered policy optimisation},
  author = {Lu, Chris and Kuba, Jakub and Letcher, Alistair and Metz, Luke and Schroeder de Witt, Christian and Foerster, Jakob},
  journal = {Advances in Neural Information Processing Systems},
  volume = {35},
  pages = {16455--16468},
  year = {2022}
}

@software{jax2018github,
  author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
  title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
  url = {http://github.com/jax-ml/jax},
  version = {0.3.13},
  year = {2018}
}

@article{doi:10.1518/106480407X312374,
  author = {Antony Hilliard and Greg A. Jamieson},
  title = {Winning Solar Races with Interface Design},
  journal = {Ergonomics in Design},
  volume = {16},
  number = {2},
  pages = {6--11},
  year = {2008},
  doi = {10.1518/106480407X312374},
  url = {https://doi.org/10.1518/106480407X312374},
  eprint = {https://doi.org/10.1518/106480407X312374},
  abstract = {Solar car racing is both a highly competitive sport and a test arena for tomorrow's renewable-energy applications. This article describes the design of a graphical interface for solar car race strategy planning. The coupling, unpredictability, and size of the solar car racing environment present tough challenges to racing strategy teams. Representation-aiding techniques provide a useful approach for managing this complexity, translating difficult problems into visual analogues that are better suited to human information processing.}
}

@article{heuristicsolar,
  author = {Betancur, Esteban and Osorio-Gómez, Gilberto and Rivera, Juan Carlos},
  title = {Heuristic Optimization for the Energy Management and Race Strategy of a Solar Car},
  journal = {Sustainability},
  volume = {9},
  year = {2017},
  number = {10},
  article-number = {1576},
  url = {https://www.mdpi.com/2071-1050/9/10/1576},
  issn = {2071-1050},
  abstract = {Solar cars are known for their energy efficiency, and different races are designed to measure their performance under certain conditions. For these races, in addition to an efficient vehicle, a competition strategy is required to define the optimal speed, with the objective of finishing the race in the shortest possible time using the energy available. Two heuristic optimization methods are implemented to solve this problem, a convergence and performance comparison of both methods is presented. A computational model of the race is developed, including energy input, consumption and storage systems. Based on this model, the different optimization methods are tested on the optimization of the World Solar Challenge 2015 race strategy under two different environmental conditions. A suitable method for solar car racing strategy is developed with the vehicle specifications taken as an independent input to permit the simulation of different solar or electric vehicles.},
  doi = {10.3390/su9101576}
}

@misc{gymnasium,
  title = {Gymnasium: A Standard Interface for Reinforcement Learning Environments},
  author = {Mark Towers and Ariel Kwiatkowski and Jordan Terry and John U. Balis and Gianluca De Cola and Tristan Deleu and Manuel Goulão and Andreas Kallinteris and Markus Krimmel and Arjun KG and Rodrigo Perez-Vicente and Andrea Pierré and Sander Schulhoff and Jun Jet Tai and Hannah Tan and Omar G. Younis},
  year = {2024},
  eprint = {2407.17032},
  archivePrefix = {arXiv},
  primaryClass = {cs.LG},
  url = {https://arxiv.org/abs/2407.17032}
}

@inproceedings{Ansel_PyTorch_2_Faster_2024,
  author = {Ansel, Jason and Yang, Edward and He, Horace and Gimelshein, Natalia and Jain, Animesh and Voznesensky, Michael and Bao, Bin and Bell, Peter and Berard, David and Burovski, Evgeni and Chauhan, Geeta and Chourdia, Anjali and Constable, Will and Desmaison, Alban and DeVito, Zachary and Ellison, Elias and Feng, Will and Gong, Jiong and Gschwind, Michael and Hirsh, Brian and Huang, Sherlock and Kalambarkar, Kshiteej and Kirsch, Laurent and Lazos, Michael and Lezcano, Mario and Liang, Yanbo and Liang, Jason and Lu, Yinghai and Luk, CK and Maher, Bert and Pan, Yunjie and Puhrsch, Christian and Reso, Matthias and Saroufim, Mark and Siraichi, Marcos Yukio and Suk, Helen and Suo, Michael and Tillet, Phil and Wang, Eikan and Wang, Xiaodong and Wen, William and Zhang, Shunting and Zhao, Xu and Zhou, Keren and Zou, Richard and Mathews, Ajit and Chanan, Gregory and Wu, Peng and Chintala, Soumith},
  booktitle = {29th ACM International Conference on Architectural Support for Programming Languages and Operating Systems, Volume 2 (ASPLOS '24)},
  doi = {10.1145/3620665.3640366},
  month = apr,
  publisher = {ACM},
  title = {{PyTorch 2: Faster Machine Learning Through Dynamic Python Bytecode Transformation and Graph Compilation}},
  url = {https://pytorch.org/assets/pytorch2-2.pdf},
  year = {2024}
}

@article{stable-baselines3,
  author = {Antonin Raffin and Ashley Hill and Adam Gleave and Anssi Kanervisto and Maximilian Ernestus and Noah Dormann},
  title = {Stable-Baselines3: Reliable Reinforcement Learning Implementations},
  journal = {Journal of Machine Learning Research},
  year = {2021},
  volume = {22},
  number = {268},
  pages = {1--8},
  url = {http://jmlr.org/papers/v22/20-1364.html}
}

@software{gymnax2022github,
  author = {Robert Tjarko Lange},
  title = {{gymnax}: A {JAX}-based Reinforcement Learning Environment Library},
  url = {http://github.com/RobertTLange/gymnax},
  version = {0.0.4},
  year = {2022}
}

@misc{proximalpolicyoptimization,
  title = {Proximal Policy Optimization Algorithms},
  author = {John Schulman and Filip Wolski and Prafulla Dhariwal and Alec Radford and Oleg Klimov},
  year = {2017},
  eprint = {1707.06347},
  archivePrefix = {arXiv},
  primaryClass = {cs.LG},
  url = {https://arxiv.org/abs/1707.06347}
}
@@ -1,368 +0,0 @@
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/td3/#td3_continuous_action_jaxpy
import os
import random
import time
from dataclasses import dataclass

import flax
import flax.linen as nn
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
import optax
import tyro
from flax.training.train_state import TrainState
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard.writer import SummaryWriter


@dataclass
class Args:
    exp_name: str = os.path.basename(__file__)[: -len(".py")]
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "cleanRL"
    """the wandb's project name"""
    wandb_entity: str = None
    """the entity (team) of wandb's project"""
    capture_video: bool = False
    """whether to capture videos of the agent performances (check out `videos` folder)"""
    save_model: bool = False
    """whether to save model into the `runs/{run_name}` folder"""
    upload_model: bool = False
    """whether to upload the saved model to huggingface"""
    hf_entity: str = ""
    """the user or org name of the model repository from the Hugging Face Hub"""

    # Algorithm specific arguments
    env_id: str = "MountainCarContinuous-v0"
    """the id of the environment"""
    total_timesteps: int = 1000000
    """total timesteps of the experiments"""
    learning_rate: float = 3e-4
    """the learning rate of the optimizer"""
    buffer_size: int = int(1e6)
    """the replay memory buffer size"""
    gamma: float = 0.99
    """the discount factor gamma"""
    tau: float = 0.005
    """target smoothing coefficient (default: 0.005)"""
    batch_size: int = 256
    """the batch size of samples from the replay memory"""
    policy_noise: float = 0.2
    """the scale of policy noise"""
    exploration_noise: float = 0.1
    """the scale of exploration noise"""
    learning_starts: int = 25000
    """timestep to start learning"""
    policy_frequency: int = 2
    """the frequency of training policy (delayed)"""
    noise_clip: float = 0.5
    """noise clip parameter of the Target Policy Smoothing Regularization"""


def make_env(env_id, seed, idx, capture_video, run_name):
    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        return env

    return thunk


# ALGO LOGIC: initialize agent here:
class QNetwork(nn.Module):
    @nn.compact
    def __call__(self, x: jnp.ndarray, a: jnp.ndarray):
        x = jnp.concatenate([x, a], -1)
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x


class Actor(nn.Module):
    action_dim: int
    action_scale: jnp.ndarray
    action_bias: jnp.ndarray

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(self.action_dim)(x)
        x = nn.tanh(x)
        x = x * self.action_scale + self.action_bias
        return x


class TrainState(TrainState):
    target_params: flax.core.FrozenDict


if __name__ == "__main__":
    import stable_baselines3 as sb3

    if sb3.__version__ < "2.0":
        raise ValueError(
            """Ongoing migration: run the following command to install the new dependencies:
poetry run pip install "stable_baselines3==2.0.0a1"
"""
        )
    args = tyro.cli(Args)
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    key = jax.random.PRNGKey(args.seed)
    key, actor_key, qf1_key, qf2_key = jax.random.split(key, 4)

    # env setup
    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
    assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

    max_action = float(envs.single_action_space.high[0])
    envs.single_observation_space.dtype = np.float32
    rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        device="cpu",
        handle_timeout_termination=False,
    )

    # TRY NOT TO MODIFY: start the game
    obs, _ = envs.reset(seed=args.seed)

    actor = Actor(
        action_dim=np.prod(envs.single_action_space.shape),
        action_scale=jnp.array((envs.action_space.high - envs.action_space.low) / 2.0),
        action_bias=jnp.array((envs.action_space.high + envs.action_space.low) / 2.0),
    )
    actor_state = TrainState.create(
        apply_fn=actor.apply,
        params=actor.init(actor_key, obs),
        target_params=actor.init(actor_key, obs),
        tx=optax.adam(learning_rate=args.learning_rate),
    )
    qf = QNetwork()
    qf1_state = TrainState.create(
        apply_fn=qf.apply,
        params=qf.init(qf1_key, obs, envs.action_space.sample()),
        target_params=qf.init(qf1_key, obs, envs.action_space.sample()),
        tx=optax.adam(learning_rate=args.learning_rate),
    )
    qf2_state = TrainState.create(
        apply_fn=qf.apply,
        params=qf.init(qf2_key, obs, envs.action_space.sample()),
        target_params=qf.init(qf2_key, obs, envs.action_space.sample()),
        tx=optax.adam(learning_rate=args.learning_rate),
    )
    actor.apply = jax.jit(actor.apply)
    qf.apply = jax.jit(qf.apply)

    @jax.jit
    def update_critic(
        actor_state: TrainState,
        qf1_state: TrainState,
        qf2_state: TrainState,
        observations: np.ndarray,
        actions: np.ndarray,
        next_observations: np.ndarray,
        rewards: np.ndarray,
        terminations: np.ndarray,
        key: jnp.ndarray,
    ):
        # TODO Maybe pre-generate a lot of random keys
        # also check https://jax.readthedocs.io/en/latest/jax.random.html
        key, noise_key = jax.random.split(key, 2)
        clipped_noise = (
            jnp.clip(
                (jax.random.normal(noise_key, actions.shape) * args.policy_noise),
                -args.noise_clip,
                args.noise_clip,
            )
            * actor.action_scale
        )
        next_state_actions = jnp.clip(
            actor.apply(actor_state.target_params, next_observations) + clipped_noise,
            envs.single_action_space.low,
            envs.single_action_space.high,
        )
        qf1_next_target = qf.apply(qf1_state.target_params, next_observations, next_state_actions).reshape(-1)
        qf2_next_target = qf.apply(qf2_state.target_params, next_observations, next_state_actions).reshape(-1)
        min_qf_next_target = jnp.minimum(qf1_next_target, qf2_next_target)
        next_q_value = (rewards + (1 - terminations) * args.gamma * (min_qf_next_target)).reshape(-1)

        def mse_loss(params):
            qf_a_values = qf.apply(params, observations, actions).squeeze()
            return ((qf_a_values - next_q_value) ** 2).mean(), qf_a_values.mean()

        (qf1_loss_value, qf1_a_values), grads1 = jax.value_and_grad(mse_loss, has_aux=True)(qf1_state.params)
        (qf2_loss_value, qf2_a_values), grads2 = jax.value_and_grad(mse_loss, has_aux=True)(qf2_state.params)
        qf1_state = qf1_state.apply_gradients(grads=grads1)
        qf2_state = qf2_state.apply_gradients(grads=grads2)

        return (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key

    @jax.jit
    def update_actor(
        actor_state: TrainState,
        qf1_state: TrainState,
        qf2_state: TrainState,
        observations: np.ndarray,
    ):
        def actor_loss(params):
            return -qf.apply(qf1_state.params, observations, actor.apply(params, observations)).mean()

        actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_state.params)
        actor_state = actor_state.apply_gradients(grads=grads)
        actor_state = actor_state.replace(
            target_params=optax.incremental_update(actor_state.params, actor_state.target_params, args.tau)
        )

        qf1_state = qf1_state.replace(
            target_params=optax.incremental_update(qf1_state.params, qf1_state.target_params, args.tau)
        )
        qf2_state = qf2_state.replace(
            target_params=optax.incremental_update(qf2_state.params, qf2_state.target_params, args.tau)
        )
        return actor_state, (qf1_state, qf2_state), actor_loss_value

    start_time = time.time()
    for global_step in range(args.total_timesteps):
        # ALGO LOGIC: put action logic here
        if global_step < args.learning_starts:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            actions = actor.apply(actor_state.params, obs)
            actions = np.array(
                [
                    (
                        jax.device_get(actions)[0]
                        + np.random.normal(0, max_action * args.exploration_noise, size=envs.single_action_space.shape)
                    ).clip(envs.single_action_space.low, envs.single_action_space.high)
                ]
            )

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, terminations, truncations, infos = envs.step(actions)

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        if "final_info" in infos:
            for info in infos["final_info"]:
                print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                break

        # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
        real_next_obs = next_obs.copy()
        for idx, trunc in enumerate(truncations):
            if trunc:
                real_next_obs[idx] = infos["final_observation"][idx]
        rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
        obs = next_obs

        # ALGO LOGIC: training.
        if global_step > args.learning_starts:
            data = rb.sample(args.batch_size)

            (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key = update_critic(
                actor_state,
                qf1_state,
                qf2_state,
                data.observations.numpy(),
                data.actions.numpy(),
                data.next_observations.numpy(),
                data.rewards.flatten().numpy(),
                data.dones.flatten().numpy(),
                key,
            )

            if global_step % args.policy_frequency == 0:
                actor_state, (qf1_state, qf2_state), actor_loss_value = update_actor(
                    actor_state,
                    qf1_state,
                    qf2_state,
                    data.observations.numpy(),
                )

            if global_step % 100 == 0:
                writer.add_scalar("losses/qf1_loss", qf1_loss_value.item(), global_step)
                writer.add_scalar("losses/qf2_loss", qf2_loss_value.item(), global_step)
                writer.add_scalar("losses/qf1_values", qf1_a_values.item(), global_step)
                writer.add_scalar("losses/qf2_values", qf2_a_values.item(), global_step)
                writer.add_scalar("losses/actor_loss", actor_loss_value.item(), global_step)
                print("SPS:", int(global_step / (time.time() - start_time)))
                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

    if args.save_model:
        model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
        with open(model_path, "wb") as f:
            f.write(
                flax.serialization.to_bytes(
                    [
                        actor_state.params,
                        qf1_state.params,
                        qf2_state.params,
                    ]
                )
            )
        print(f"model saved to {model_path}")
        from cleanrl_utils.evals.td3_jax_eval import evaluate

        episodic_returns = evaluate(
            model_path,
            make_env,
            args.env_id,
            eval_episodes=10,
            run_name=f"{run_name}-eval",
            Model=(Actor, QNetwork),
            exploration_noise=args.exploration_noise,
        )
        for idx, episodic_return in enumerate(episodic_returns):
            writer.add_scalar("eval/episodic_return", episodic_return, idx)

        if args.upload_model:
            from cleanrl_utils.huggingface import push_to_hub

            repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
            repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
            push_to_hub(args, episodic_returns, repo_id, "TD3", f"runs/{run_name}", f"videos/{run_name}-eval")

    envs.close()
    writer.close()
@@ -1,121 +0,0 @@
# models to generate different environments that the car can drive in.
# This includes terrain, clouds, wind, solar conditions, and the route along the terrain.


import jax
import jax.numpy as jnp
from jax import random
import pyqtgraph as pg
from functools import partial
from pyqtgraph.Qt import QtCore, QtGui
from typing import NamedTuple
import matplotlib.pyplot as plt
import sys


class TerrainParams(NamedTuple):
    size: int = 256
    octaves: int = 6
    persistence: float = 0.5
    lacunarity: float = 2.0
    seed: int = 42


def lerp(a, b, t):
    # assume a and b are pairs of numbers
    # NOTE: superseded by the lerp(t, a, b) defined further down.
    x = jnp.array([0, 1])
    f = jnp.array([a, b])
    return jnp.interp(t, x, f)


# @partial(jax.jit, static_argnums=(2,))
# def _make_noise_layer(key: random.PRNGKey, frequency: float, shape) -> jnp.ndarray:
#
#     noise = random.normal(key, shape)
#     # create the grid.
#     x = jnp.linspace(0, shape[0] - 1, )


def generate_permutation():
    """Generate a permutation table."""
    p = jnp.arange(256, dtype=jnp.int32)
    return jnp.concatenate([p, p])


@jax.jit
def fade(t):
    """Fade function for smooth interpolation."""
    return t * t * t * (t * (t * 6 - 15) + 10)


@jax.jit
def lerp(t, a, b):
    """Linear interpolation."""
    return a + t * (b - a)


@jax.jit
def grad(hash, x, y):
    """Calculate gradient."""
    h = hash & 15
    grad_x = jnp.where(h < 8, x, y)
    grad_y = jnp.where(h < 4, y, jnp.where((h == 12) | (h == 14), x, y))
    return jnp.where(h & 1, -grad_x, grad_x) + jnp.where(h & 2, -grad_y, grad_y)


def perlin(pos):
    """Perlin noise. Shape (N) where N = n_dims (2,3)"""
    # NOTE: unfinished prototype; superseded by perlin_noise_2d below.
    cellpos = pos % 1.0  # get the position inside the cell
    upos = fade(pos)


@jax.jit
def perlin_noise_2d(x, y, p):
    """Generate 2D Perlin noise value."""
    # Floor coordinates
    xi = jnp.floor(x).astype(jnp.int32) & 255
    yi = jnp.floor(y).astype(jnp.int32) & 255

    # Fractional coordinates
    xf = x - jnp.floor(x)
    yf = y - jnp.floor(y)

    # Fade curves
    u = fade(xf)
    v = fade(yf)

    # Hash coordinates of cube corners
    aa = p[p[xi] + yi]
    ab = p[p[xi] + yi + 1]
    ba = p[p[xi + 1] + yi]
    bb = p[p[xi + 1] + yi + 1]

    # Gradients
    g1 = grad(aa, xf, yf)
    g2 = grad(ba, xf - 1, yf)
    g3 = grad(ab, xf, yf - 1)
    g4 = grad(bb, xf - 1, yf - 1)

    # Interpolate
    x1 = lerp(u, g1, g2)
    x2 = lerp(u, g3, g4)
    return lerp(v, x1, x2)


def generate_noise_grid(width, height, scale=50.0):
    """Generate a grid of Perlin noise values."""
    p = generate_permutation()
    # coordinate grid sampled at the given scale
    x = jnp.linspace(0, width / scale, width)
    y = jnp.linspace(0, height / scale, height)
    X, Y = jnp.meshgrid(x, y)
    return perlin_noise_2d(X, Y, p)


# Example usage:
key = jax.random.PRNGKey(23)
noise = generate_noise_grid(256, 256)
plt.imshow(noise)
plt.savefig("output.png")


def GymV1():
    """Makes a version 1 gym - simply an elevation profile."""
@@ -0,0 +1,74 @@
# start up the main Qt application and load plugins

from PySide6.QtCore import QCoreApplication, QObject, Signal, QTimer
from PySide6.QtWidgets import QApplication, QToolBar, QWidget, QMainWindow, QPlainTextEdit
import logging
from logging import LogRecord
import sys
from pyqtgraph.dockarea import Dock, DockArea

from solarcarsim.satellaview.ui import SatellaUI


class LogHandler(logging.Handler):
    class Carrier(QObject):
        # We need this because both QObject and logging.Handler define an
        # `emit` method, and the two would collide on a single class.
        appendplaintext = Signal(str)

    def __init__(self, parent) -> None:
        super().__init__()
        self.widget = QPlainTextEdit(parent=parent)
        self.carrier = self.Carrier(parent=parent)
        self.widget.setReadOnly(True)
        self.carrier.appendplaintext.connect(self.widget.appendPlainText)  # type: ignore
        self.dock = Dock("Logger", widget=self.widget)

    def emit(self, record: LogRecord) -> None:
        msg = self.format(record)
        self.carrier.appendplaintext.emit(msg)  # type: ignore


class LogViewer(QPlainTextEdit):
    def __init__(self, parent=None):
        super().__init__(parent)


class App:
    """Core application. Sets up the logger, main window (toolbar, etc.),
    and loads plugins."""

    @staticmethod
    def init_core_app():
        # sets the name of the application, etc.
        QCoreApplication.setApplicationName("SolarCarSim")

    def __init__(self) -> None:
        self.main_window = main = QMainWindow()
        main.setWindowTitle("SolarCarSim")
        self.dockarea = DockArea(main)

        self.loghandler = LogHandler(main)
        self.plugins = {}

        self.plugins['satellaview'] = SatellaUI(main)
        logging.getLogger().addHandler(self.loghandler)

        main.setCentralWidget(self.dockarea)
        self.dockarea.addDock(self.loghandler.dock)
        self.dockarea.addDock(self.plugins['satellaview'].dock)
        self.file_menu = main.menuBar().addMenu("File")
        self.settings_menu = main.menuBar().addMenu("Settings")

    def run(self):
        return self.main_window.show()


if __name__ == "__main__":
    app = QApplication()
    myapp = App()
    myapp.run()
    QTimer.singleShot(1000, lambda: logging.warning("logging test"))
    QTimer.singleShot(2000, lambda: logging.warning("logging test 2"))
    sys.exit(app.exec())
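A note on the Carrier indirection above: because the signal is delivered through Qt's auto/queued connection, log records emitted off the GUI thread still reach appendPlainText on the GUI thread. A minimal usage sketch (assumed, not part of this commit):

# Hypothetical sketch: a log call from a worker thread is marshalled to the
# GUI thread by the Carrier's signal, so the widget update stays thread-safe.
import logging
import threading

def worker():
    logging.warning("hello from a worker thread")

threading.Thread(target=worker).start()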
19  src/solarcarsim/plugin.py  (new file)
@@ -0,0 +1,19 @@
# Plugin base class


class BasePlugin:
    """Base class for plugins. A plugin is a tool or feature that can extend the application.
    All application features are implemented as plugins. The list of plugins can be seen in the toolbar.
    """

    def name(self) -> str:
        raise NotImplementedError()

    def version(self) -> str:
        raise NotImplementedError()

    def startup(self, window, toolbar, application):
        raise NotImplementedError()

    def teardown(self, window, toolbar, application):
        raise NotImplementedError()
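To make the contract concrete, a hypothetical plugin might look like the sketch below; the HelloPlugin name and its dock are illustrative, not part of this commit:

# Hypothetical sketch of a concrete plugin built on BasePlugin.
from pyqtgraph.dockarea import Dock
from PySide6.QtWidgets import QLabel

from solarcarsim.plugin import BasePlugin


class HelloPlugin(BasePlugin):
    def name(self) -> str:
        return "hello"

    def version(self) -> str:
        return "0.1.0"

    def startup(self, window, toolbar, application):
        # create the plugin's UI and register it with the main window
        self.dock = Dock("Hello", widget=QLabel("Hello from a plugin"))
        application.dockarea.addDock(self.dock)

    def teardown(self, window, toolbar, application):
        self.dock.close()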
0   src/solarcarsim/satellaview/__init__.py  (new empty file)
0   src/solarcarsim/satellaview/goes.py  (new empty file)
0   src/solarcarsim/satellaview/himawari.py  (new empty file)
0   src/solarcarsim/satellaview/rendering.py  (new empty file)
0   src/solarcarsim/satellaview/source.py  (new empty file)
36  src/solarcarsim/satellaview/ui.py  (new file)
@@ -0,0 +1,36 @@
# Qt widget to select/render satellite data.
# This ties together the renderer and source plugins.


from PySide6.QtCore import QObject
from PySide6.QtWidgets import QPushButton, QSplitter, QVBoxLayout, QWidget, QGridLayout
from pyqtgraph import ImageView
from pyqtgraph.dockarea import Dock
from pyqtgraph.parametertree import ParameterTree


class SatellaUI(QObject):
    def __init__(self, parent=None) -> None:
        super().__init__(parent)

        self.splitter = split = QSplitter()

        self.param_tree = ParameterTree(split)

        self.run_button = QPushButton("Execute", split)

        split.addWidget(self.param_tree)
        split.addWidget(self.run_button)

        self.viewer = ImageView(split, "Data")

        split.addWidget(self.viewer)

        self._dock = None

    @property
    def dock(self):
        # lazily create the dock so it is only built when first requested
        if self._dock is not None:
            return self._dock
        self._dock = Dock("Satellaview", widget=self.splitter)
        return self._dock
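The source modules this widget is meant to tie together (goes.py, himawari.py, source.py) are added empty in this commit. As a hedged sketch only, the kind of interface ui.py could consume might look like this; the SatelliteSource name and its methods are assumptions, nothing here is in the commit:

# Hypothetical sketch of a source interface for the satellaview UI.
import numpy as np


class SatelliteSource:
    """A data source (e.g. GOES or Himawari) that yields an image for the viewer."""

    def parameters(self) -> dict:
        """Parameter definitions to display in the ParameterTree."""
        raise NotImplementedError()

    def fetch(self, **params) -> np.ndarray:
        """Fetch the selected data as a 2D array suitable for ImageView.setImage()."""
        raise NotImplementedError()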
0   src/solarcarsim/simulator/__init__.py  (new empty file)
0   src/solarcarsim/simulator/main.py  (new empty file)
@@ -7,7 +7,7 @@ from functools import partial

 from typing import NamedTuple, Tuple

-from solarcarsim.noise import (
+from .noise import (
     fractal_noise_1d,
     generate_wind_field,
 )
@@ -1,5 +1,5 @@
 import gymnasium as gym
-import solarcarsim.physsim as sim
+import solarcarsim.simulator.physsim as sim
 import jax
 import jax.numpy as jnp
 from jax import jit
@@ -9,8 +9,8 @@ from jax import lax, vmap
 from gymnax.environments import environment
 from gymnax.environments import spaces

-from solarcarsim.physsim import CarParams, fractal_noise_1d
-import solarcarsim.physsim as sim
+from solarcarsim.simulator.physsim import CarParams, fractal_noise_1d
+import solarcarsim.simulator.physsim as sim


 @struct.dataclass
@@ -1 +0,0 @@
-import pyray as ray