The reward has three major components. The continuous reward is given at every step and is based on the position of the car
relative to the goal distance. The victory reward is a constant, minus the energy used and an early-arrival penalty;
the penalty was added to guide the agent towards arriving with as little time to spare as possible. Finally, there is a penalty
for exceeding the goal time, since after that point the car is disqualified from the race.
It took a few iterations to find a reward function that promoted fast learning, and these issues were exacerbated by the initially low
throughput when using Stable Baselines. A crucial improvement was applying the energy penalty only on wins:
the model could first learn to drive forward and finish, and only then refine its speed
profile\footnote{I looked into Q-initialization but could not find a straightforward way to implement it.}.
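As a concrete illustration, below is a minimal sketch of this reward structure in Jax; the function signature and constant values are illustrative assumptions rather than the exact implementation.
\begin{verbatim}
import jax.numpy as jnp

def reward(position, goal_distance, energy_used, t, goal_time,
           won, victory_bonus=100.0, overtime_penalty=50.0):
    # Continuous shaping term: progress towards the goal at every step.
    progress = position / goal_distance
    # Victory term: constant bonus minus energy used and an
    # early-arrival penalty proportional to the time left over.
    # The energy penalty is only ever applied on a win.
    time_left = jnp.maximum(goal_time - t, 0.0)
    win_term = jnp.where(won, victory_bonus - energy_used - time_left, 0.0)
    # Disqualification: penalize running past the goal time.
    overtime = jnp.where(t > goal_time, -overtime_penalty, 0.0)
    return progress + win_term + overtime
\end{verbatim}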
\subsection{State and Observation Spaces}
The complete state of the simulator consists of the position, velocity, and energy of the car, together with the entire environment.
These parameters are sufficient for a deterministic snapshot of the simulator. However, one goal of the project
was to support partial observability. To this end, the observation is restricted to a small window
of the upcoming wind and slope. This also simplifies the agent's calculations, since the view of the environment is
relative to the car's current position. The size of the window is controlled by a parameter.
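A sketch of how such a relative window can be extracted in Jax is shown below; the discretization and variable names are illustrative assumptions, not the actual simulator code.
\begin{verbatim}
import jax
import jax.numpy as jnp

def observe(car_position, wind_profile, slope_profile,
            window_size, cell_size=1.0):
    # Grid cell the car currently occupies (illustrative discretization).
    idx = jnp.floor(car_position / cell_size).astype(jnp.int32)
    # Slice out the upcoming window of wind and slope, relative to the car.
    # window_size must be a static Python int when this runs under jit.
    wind_view = jax.lax.dynamic_slice_in_dim(wind_profile, idx, window_size)
    slope_view = jax.lax.dynamic_slice_in_dim(slope_profile, idx, window_size)
    return jnp.concatenate([wind_view, slope_view])
\end{verbatim}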
\section{Experiments and Results}
An implementation of the aforementioned simulator was developed in Jax. Jax was chosen because it enables
vectorization and compiler optimization, which improve performance. Additionally, Jax can compute gradients of arbitrary functions,
which is useful for certain classes of reinforcement learning; in our case we did not take advantage of this,
as there seemed to be little off-the-shelf support for such methods.
Initially, Stable Baselines was used, since it is one of the most popular implementations of common RL algorithms.
Stable Baselines3\cite{stable-baselines3} is written in PyTorch\cite{Ansel_PyTorch_2_Faster_2024} and uses the Gym\cite{gymnasium} format for environments, so a basic Gym wrapper
was created to connect SB3 to our environment.
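A minimal sketch of such a wrapper is shown below; the \texttt{reset}/\texttt{step} signature of the underlying Jax environment and the space shapes are hypothetical, while the outer interface follows the standard Gymnasium API.
\begin{verbatim}
import gymnasium as gym
import numpy as np
import jax

class CarRaceGymWrapper(gym.Env):
    """Thin Gym adapter around the Jax simulator (names are illustrative)."""

    def __init__(self, jax_env, obs_size, rng_seed=0):
        self.jax_env = jax_env
        self.key = jax.random.PRNGKey(rng_seed)
        self.state = None
        self.observation_space = gym.spaces.Box(-np.inf, np.inf, shape=(obs_size,))
        self.action_space = gym.spaces.Box(-1.0, 1.0, shape=(1,))

    def reset(self, *, seed=None, options=None):
        self.key, subkey = jax.random.split(self.key)
        obs, self.state = self.jax_env.reset(subkey)
        # np.asarray forces a device-to-host copy of the Jax array.
        return np.asarray(obs), {}

    def step(self, action):
        self.key, subkey = jax.random.split(self.key)
        obs, self.state, reward, done, info = self.jax_env.step(
            subkey, self.state, action)
        return np.asarray(obs), float(reward), bool(done), False, info
\end{verbatim}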
PPO was chosen as the RL algorithm because it is simple while still being effective~\cite{proximalpolicyoptimization}.
Training performance and convergence were poor, and this was
difficult to diagnose, since the model could need millions of steps before learning anything interesting.
The primary performance loss came from the Jax/NumPy/PyTorch conversion, which requires a roundtrip through the CPU.
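To illustrate where this roundtrip comes from, each observation under this setup follows roughly the path below (array size and target device are illustrative):
\begin{verbatim}
import jax.numpy as jnp
import numpy as np
import torch

obs_device = jnp.ones((64,))            # observation lives on the accelerator
obs_host = np.asarray(obs_device)       # device-to-host copy for the Gym API
obs_torch = torch.from_numpy(obs_host)  # zero-copy wrap on the CPU side
obs_gpu = obs_torch.to("cuda")          # host-to-device copy back for PyTorch
\end{verbatim}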
To combat this, I found a Jax-based implementation of PPO called \texttt{purejaxrl}. This library is
written in the style of CleanRL but uses pure Jax together with an environment library called \texttt{gymnax}\cite{gymnax2022github}.
The primary advantage of writing everything in Jax is that both the RL agent and the environment can be offloaded to the GPU;
additionally, the infrastructure provided by \texttt{gymnax} allows environments to be vectorized. The speedup from
using this library cannot be overstated. The SB3 PPO implementation ran at around 150 actions per second. After rewriting
some of the code to work with \texttt{purejaxrl}, the effective action rate\footnote{I ran 2048 environments in parallel.}
was nearly $238{,}000$ actions per second\footnote{It is likely that performance with SB3 could have been improved, but I was unable to work out exactly how.}.
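The vectorization pattern behind this is roughly the following sketch; it assumes a \texttt{gymnax}-style \texttt{step} signature, and the batch size is illustrative.
\begin{verbatim}
import jax

NUM_ENVS = 2048  # number of parallel environments (illustrative)

def batched_step(env, env_params, keys, states, actions):
    # Map env.step over a batch of RNG keys, states, and actions while
    # broadcasting the shared env_params; the whole batch runs on the GPU.
    step_fn = jax.vmap(env.step, in_axes=(0, 0, 0, None))
    return step_fn(keys, states, actions, env_params)

# Usage: keys, states, and actions are batched along their leading axis.
# obs, states, rewards, dones, infos = batched_step(env, params,
#                                                   keys, states, actions)
\end{verbatim}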
The episode returns after 50 million timesteps of training with a PPO model can be seen in Figure~\ref{fig:returns}. Each update step
is performed after collecting minibatches of rollouts under the current policy. There is a clean ascent at the start of training,
which corresponds to the agent learning to drive forward. After a certain point the returns become noisy, most likely because the
energy score varies with the randomly generated terrain. A solution to this, which was not pursued due to lack of time, would be to compute a
``nominal energy'' use based on travelling at $v_{avg}$: energy consumption above the nominal use would be penalized, and
consumption below it would be heavily rewarded. Despite the noise, performance continued to improve, which is a good sign that the agent was still learning.
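One possible formulation of this shaping term, with $E$ the energy consumed in an episode and the weighting coefficients purely illustrative, would be
\[
r_{\mathrm{energy}} = -\alpha\,\max\!\left(E - E_{\mathrm{nom}},\, 0\right)
  + \beta\,\max\!\left(E_{\mathrm{nom}} - E,\, 0\right), \qquad \beta > \alpha > 0,
\]
where $E_{\mathrm{nom}}$ is the energy required to cover the course at a constant $v_{avg}$.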
While the PPO performance was decent, it still left a significant amount of improvement on the table. Tuning the reward function
would probably help it find a better solution. One strategy that would help significantly is to pre-tune the model so that it outputs the
average speed by default, so the agent does not have to learn that behaviour from scratch. This is called Q-initialization and is a common
trick for problem spaces where an initial estimate exists and is easy to define.
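A lightweight approximation of this idea is to initialize the bias of the policy output layer so that the untrained network already proposes roughly the average speed; the Flax module below is an illustrative sketch, not the architecture used by \texttt{purejaxrl}.
\begin{verbatim}
import flax.linen as nn

class Actor(nn.Module):
    v_avg_action: float  # action value corresponding to driving at v_avg

    @nn.compact
    def __call__(self, obs):
        x = nn.relu(nn.Dense(64)(obs))
        x = nn.relu(nn.Dense(64)(x))
        # Zero the output kernel and set the bias so the initial policy
        # mean equals the average-speed action instead of random noise.
        mean = nn.Dense(
            1, kernel_init=nn.initializers.zeros,
            bias_init=nn.initializers.constant(self.v_avg_action))(x)
        return mean
\end{verbatim}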
Perhaps the most important takeaway from this work is the power of end-to-end Jax RL. \texttt{purejaxrl} offers CleanRL levels of code clarity,
with everything for an agent contained in a single file, while significantly surpassing Stable Baselines3 in performance. One drawback is that
the ecosystem is very new, so there was very little to reference while developing the simulator. Often an opaque error message
would yield no results on search engines and required digging into the Jax source code to diagnose;
typically the cause was some misunderstanding of the inner workings of Jax.

Future work on this project would involve trying out other agents and comparing different reward functions.
Adjusting the actor-critic network would also be an interesting avenue, especially since a CNN is likely to work well with
the wind and cloud information, which have both a spatial and a temporal
axis\footnote{You can probably tell that the quality dropped off near the end; a number of life events got in the way, so this did not go as well as I had hoped, but I learned a lot.}.
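As a sketch of what such a network could look like, the following minimal Flax module applies a small CNN over a window with a temporal and a spatial axis; the shapes and layer sizes are assumptions for illustration rather than a tested architecture.
\begin{verbatim}
import flax.linen as nn

class WindCloudEncoder(nn.Module):
    """Encodes a (time, space, channels) window of wind/cloud forecasts."""

    @nn.compact
    def __call__(self, window):
        # window: (batch, time, space, channels); channels hold wind and cloud.
        x = nn.relu(nn.Conv(features=16, kernel_size=(3, 3))(window))
        x = nn.relu(nn.Conv(features=32, kernel_size=(3, 3))(x))
        x = x.reshape((x.shape[0], -1))  # flatten for the actor-critic heads
        return nn.relu(nn.Dense(64)(x))
\end{verbatim}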