<?xml version="1.0" encoding="utf-8"?><feed xmlns="http://www.w3.org/2005/Atom" ><generator uri="https://jekyllrb.com/" version="4.4.1">Jekyll</generator><link href="https://yonakalabs.net/feed.xml" rel="self" type="application/atom+xml" /><link href="https://yonakalabs.net/" rel="alternate" type="text/html" /><updated>2026-04-07T20:25:04+09:00</updated><id>https://yonakalabs.net/feed.xml</id><title type="html">Yonaka Research</title><subtitle>Developing reinforcement learning to let Yonaka play games</subtitle><author><name>Kazumi</name></author><entry><title type="html">awawawawawawawawawawa</title><link href="https://yonakalabs.net/s/643f46d5f67711e0/" rel="alternate" type="text/html" title="awawawawawawawawawawa" /><published>2026-04-06T22:00:00+09:00</published><updated>2026-04-06T22:00:00+09:00</updated><id>https://yonakalabs.net/s/test-post</id><content type="html" xml:base="https://yonakalabs.net/s/643f46d5f67711e0/"><![CDATA[]]></content><author><name>Kazumi</name></author><category term="test" /></entry><entry><title type="html">Q-Learning with Multiple Subactions</title><link href="https://yonakalabs.net/subactions/" rel="alternate" type="text/html" title="Q-Learning with Multiple Subactions" /><published>2025-07-31T22:00:00+09:00</published><updated>2025-07-31T22:00:00+09:00</updated><id>https://yonakalabs.net/subactions</id><content type="html" xml:base="https://yonakalabs.net/subactions/"><![CDATA[<p>In my last post, I showed how to handle continuous actions in Q-learning using cubic splines.
That solved one major limitation, but there’s still another one that keeps people away from using DQN.</p>

<p><img src="/assets/images/posts/subactions/stable-baselines-comparison.png" alt="Stable Baselines 3 Algorithm Comparison" /></p>

<p>DQN is still missing support for <a href="https://gymnasium.farama.org/api/spaces/fundamental/#gymnasium.spaces.MultiDiscrete">MultiDiscrete</a> action spaces. That’s the other major limitation I want to address.
The table also lists <a href="https://gymnasium.farama.org/api/spaces/fundamental/#gymnasium.spaces.MultiBinary">MultiBinary</a> as missing, but that’s just a MultiDiscrete with two choices per dimension, so I’ll ignore it.</p>

<p>Box is also only half solved: the previous post handled a single continuous action, but the full Box action space also covers multiple actions at once.</p>

<h2 id="actions-with-multiple-decisions">Actions with multiple decisions</h2>

<p>MultiDiscrete actions represent situations where you need to make multiple independent decisions simultaneously in each step.</p>

<p>Think about controlling a character in a game where you need to be doing multiple things at once: you might need to decide movement direction, where to aim, and whether to use special abilities at any given moment. Each of these might have different buttons controlling it.</p>

<p>Let me use an example. The agent has an Xbox controller as the action space. At any timestep, the agent needs to decide to press:</p>

<ul>
  <li><strong>Face buttons</strong>: Press A, B, X, Y or not (4 binary decisions)</li>
  <li><strong>Joysticks</strong>: Left stick direction, right stick direction (two 2D continuous inputs)</li>
  <li><strong>Triggers/Bumpers</strong>: Left trigger, right trigger, left bumper, right bumper (4 more binary decisions, or 2 binary and 2 continuous actions)</li>
</ul>

<p>I’ll call each of these individual decisions a subaction. The complete action for that timestep is the combination of all subaction choices. 
In general, subactions form an unordered set: they are executed simultaneously.</p>

<p>In this post, I’ll be going over how to handle composite actions.</p>

<h2 id="the-standard-approach">The standard approach</h2>

<p>The easiest way, and what many Q-learning tutorials do, is to flatten everything into a single action space by taking the Cartesian product of the subactions, or a hand-selected subset of it.
This means the number of possible actions the agent needs to consider each step grows exponentially with the number of subactions.</p>

<p>For the controller case, it would be</p>

<ul>
  <li><strong>Face buttons</strong>: \(2^4 = 16\) combinations</li>
  <li><strong>Joysticks</strong>: \(8^2 = 64\) combinations (if discretized to 8 directions each)</li>
  <li><strong>Triggers/bumpers</strong>: \(2^4 = 16\) combinations</li>
</ul>

<p>For a total of \(16 \times 64 \times 16 = 16{,}384\) possible actions.
If you let the joysticks and triggers have variable strengths instead of discretizing them, this grows even further.</p>
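To make the blow-up concrete, here’s a tiny sketch using the hypothetical discretization above:

```python
from itertools import product

# Hypothetical discretization of the controller from above:
# 4 binary face buttons, 2 sticks with 8 directions each, 4 binary triggers/bumpers
subaction_choices = [2] * 4 + [8] * 2 + [2] * 4

# The standard approach flattens everything into one Cartesian product,
# so every combination becomes its own unrelated action
flat_actions = list(product(*(range(n) for n in subaction_choices)))
print(len(flat_actions))  # 2**4 * 8**2 * 2**4 = 16384
```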

<p>The size of the action space isn’t the only problem, though; the bigger problem is that every action is treated as completely distinct from every other.
To the agent, an action that differs in a single component is represented just as differently as an entirely unrelated action. This makes training take much longer, because the agent can’t generalize between similar actions.</p>

<p>It would be much better to keep subactions separate, sample them individually, and combine them into the final action. That way, it would reduce the represented actions, and also relate similar actions together better. This is how every other RL algorithm does it anyway.</p>

<h2 id="how-to-make-sampling-subactions-easier">How to make sampling subactions easier</h2>

<p>Q-values represent expected future reward for taking a particular action. But what should Q-values mean for multiple simultaneous actions?</p>

<p>Q-values only make sense for a complete action that you’ve decided to take, which is why the standard approach keeps a full table of every combination. 
Sampling becomes much easier if each subaction has its own function to sample from, with a predictable effect on the overall Q-value. How could we do this?
The simplest way might be to have a function for each subaction, and let the Q-value be their sum:</p>

\[Q(\mathbf{a}; s) = \displaystyle\sum_{i = 1}^n Q(a_i; s)\]

<p>I tried this initially and it kind of worked, but there’s an identifiability problem during training: even if you’ve identified \(Q(\mathbf{a}; s)\) for every possible \(\mathbf{a}\) in a given state, you still can’t uniquely reconstruct each \(Q(a_i; s)\).
If one subaction’s Q-values have a constant added while another’s have the same constant subtracted, the sum remains unchanged.
This is a problem for gradient-based training, because there isn’t a unique minimum: there’s an entire symmetry group of parameterizations with the same minimum loss, and some of those minima have worse training dynamics.</p>

<p>But this is the exact same problem that <a href="https://arxiv.org/abs/1511.06581">Dueling Networks</a> faced, which means we can use the same trick!</p>

<p>They solved this by adding a state value that is independent of actions, and replacing action values with action advantages where the average over possible actions is subtracted. We can apply the same fix: add an action-independent state value, and replace all subaction values with subaction advantage.</p>

\[Q(\mathbf{a}; s) = V(s) + \displaystyle\sum_{i = 1}^n \left ( Q(a_i; s ) - \displaystyle\sum_{a'_i \in A_i} \frac{1}{|A_i|} Q(a'_i; s) \right)\]

<p>This approach works well, and it’s essentially what the <a href="https://arxiv.org/abs/1711.08946">Action Branching Architectures</a> paper implements. They create independent advantages for each subaction and add them to a state value to make the Q-value.</p>
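As a sanity check of why the dueling-style fix works, here’s a small NumPy sketch with toy tabular values (not the paper’s network): centering each subaction’s values into advantages makes the decomposition invariant to exactly the constant shifts that caused the identifiability problem.

```python
import numpy as np

def combined_q(state_value, subaction_values):
    """Q(a; s) = V(s) + sum_i (Q(a_i; s) - mean over a'_i of Q(a'_i; s)),
    expanded over every combination of subactions (toy sizes only)."""
    advantages = [q - q.mean() for q in subaction_values]  # center each branch
    q = np.array(state_value, dtype=float)
    for i, adv in enumerate(advantages):  # broadcast-sum over the product space
        shape = [1] * len(advantages)
        shape[i] = len(adv)
        q = q + adv.reshape(shape)
    return q

rng = np.random.default_rng(0)
v = 1.5
branches = [rng.normal(size=3), rng.normal(size=4)]
table = combined_q(v, branches)
print(table.shape)                  # (3, 4)
print(np.isclose(table.mean(), v))  # True: advantages average to zero
# Shifting one branch by a constant no longer changes the Q-table
shifted = combined_q(v, [branches[0] + 10.0, branches[1]])
print(np.allclose(table, shifted))  # True
```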


<h2 id="the-independence-problem">The Independence Problem</h2>

<p>While the Action Branching paper treats independent subaction sampling as a feature, I think it’s actually a limitation.</p>

<p>Consider an agent learning to use an art program. For each brush stroke, it needs to simultaneously decide:</p>

<ul>
  <li><strong>What to draw</strong>: Sun, tree, lake, or cloud</li>
  <li><strong>Where to place it</strong>: Top, middle, bottom of canvas</li>
  <li><strong>Which color</strong>: Yellow, green, blue, or white</li>
</ul>

<p>You could draw a sun on the top in yellow, a lake in the middle in blue, and those combinations would make sense.</p>

<p>But with independent sampling, each subaction would have no idea on what it already chose, so you might get “draw sun + use blue + place at bottom” resulting in a blue sun underwater, or “draw lake + use yellow + place at top” giving you a floating yellow lake in the sky!</p>

<p>In Action Branching, the authors argue that the shared state representation can coordinate decisions. The idea is that the state embedding would have already decided
which action it should take, and the independent subactions will all agree to take the action according to the decision.</p>

<p>But it’s not obvious how this could work. Unlike policy gradient methods, where you only make one choice per subaction, in DQN you construct a Q-function that is defined over all possible actions.
If the subaction advantage values are generated independently, the resulting Q-function necessarily has no correlation between
different subactions for a given environment state.</p>

<div class="yonaka-quote-custom" data-image="/assets/images/characters/yonaka-loading.png" data-artist="Crescend Cinnamon" data-artist-link="https://bsky.app/profile/crescend.bsky.social">
  "I was trying to pour milk from the fridge for my cereal while being distracted, but somehow I ended up putting my phone in the fridge and trying to pour milk on my cereal from my empty hoof."
</div>

<h2 id="autoregressive-action-sampling">Autoregressive Action Sampling</h2>

<p>To handle actions being dependent on each other while keeping the sampling easy, I propose an autoregressive action sampling method, where each subaction is conditioned on the subactions already sampled within the same step.</p>

<p>To explain this, let me use the currying interpretation that helped last time. The Q-value function \(Q(s, A)\) could be written as \(Q(s)(a_0)(a_1)...(a_n)\) for some <a href="#action-order">order of subactions</a>.
The idea is to sample one subaction at a time to build up the final Q-value, but we need to be a bit careful, because each step is actually doing two things at once.</p>

<p>When I write \(Q(s)\) or \(Q(s)(a_0)...(a_k)\), the result is a function. What we want is for each step to also produce a value, so that each subaction can be sampled.
I want a Q-advantage function \(F\) that takes any of these intermediates and produces the advantage function for the next subaction. For discrete actions this looks like a table, while for continuous actions it looks like a spline curve.</p>

<p>So using these, the steps to sample the full action for a timestep are:</p>

<ol>
  <li>Make the embedding \(Q(s)\) and a state value that does not depend on this step’s action</li>
  <li>Make \(F(Q(s))\), which is a function of the first subaction, and sample \(a_0\) from it</li>
  <li>Make the next embedding \(Q(s)(a_0)\) from the sampled subaction</li>
  <li>Sample from \(F(Q(s)(a_0))\) to get \(a_1\), and repeat from step 3 until finished</li>
</ol>

<p>Then the Q-value for the state is the sum of state value and the subaction advantages.</p>
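The sampling loop above can be sketched like this, with a hypothetical `advantage_fns` interface standing in for \(F\) applied to each intermediate embedding:

```python
import numpy as np

def sample_action(state_embed, advantage_fns):
    """Greedy autoregressive sampling sketch. advantage_fns[i](state, prefix)
    stands in for F applied to the embedding built from the state and the
    already-sampled prefix; it returns the advantage table for subaction i."""
    prefix = ()
    total_advantage = 0.0
    for f in advantage_fns:
        adv = f(state_embed, prefix)   # advantages for the next subaction
        a_i = int(np.argmax(adv))      # greedy here; use epsilon-greedy in training
        total_advantage += adv[a_i] - adv.mean()  # centered, dueling-style
        prefix += (a_i,)
    return prefix, total_advantage     # Q = V(s) + total_advantage

# Toy example where the second subaction depends on the first choice
f0 = lambda s, p: np.array([0.0, 1.0])
f1 = lambda s, p: np.array([2.0, 0.0]) if p[0] == 1 else np.array([0.0, 2.0])
action, adv_sum = sample_action(None, [f0, f1])
print(action)  # (1, 0)
```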

<p><img src="/assets/images/posts/subactions/sampling.png" alt="Autoregressive Sampling" class="diagram" /></p>

<p>This way, each subaction decision can account for what was already decided. In the earlier drawing example, if it decided to draw a tree as the first subaction, it could then decide to start from the trunk, and then choose brown as the color.</p>

<h2 id="model-architecture">Model Architecture</h2>

<p>Autoregressive action sampling needs to do two things: sample actions, and evaluate a value given the state and the action taken.
Sampling autoregressively is a bit slow, since before a new subaction can be sampled, all the subactions before it need to be sampled first.
Evaluation, on the other hand, can be faster: you already have the full action that was taken, so the overhead of sampling one subaction at a time disappears.</p>

<p>For a recurrent network, these two are done exactly the same way, but transformers can take advantage of the faster evaluation.
Transformers can also use sampling optimizations like KV caching for each sampling step, or attention variants like Grouped Query Attention and Multi-head Latent Attention.</p>

<p>At first I tried full self-attention with the state and action embeddings concatenated together, but this performed worse than Action Branching.
I spent a lot of time trying to figure out why, and it seems there needs to be a clear separation between the states and the actions.
Cross-attention works much better, with the states as the query and the actions as the key and value. In the worst case, it behaves like Action Branching.</p>

<p><img src="/assets/images/posts/subactions/cross-attention.png" alt="Cross Attention" class="diagram" /></p>

<p>What I like to do is add a beginning-of-sequence token before the first action token, so that there is something for the state embedding to attend to during the first sampling step.</p>
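A stripped-down sketch of that cross-attention step, with the state as query and the action tokens (BOS first) as key and value. This omits the learned query/key/value projections and multiple heads a real implementation would have:

```python
import numpy as np

def cross_attention(state_query, action_tokens):
    """Single-head cross attention: the state embedding attends over the
    action-token sequence (no learned projections in this sketch)."""
    d = state_query.shape[-1]
    scores = state_query @ action_tokens.T / np.sqrt(d)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))  # stable softmax
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ action_tokens

d = 8
rng = np.random.default_rng(0)
state = rng.normal(size=(1, d))
bos = np.zeros((1, d))              # beginning-of-sequence token to attend to
sampled = rng.normal(size=(2, d))   # embeddings of subactions sampled so far
out = cross_attention(state, np.concatenate([bos, sampled]))
print(out.shape)  # (1, 8)
```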

<h2 id="action-order">Subaction Ordering Problem</h2>

<p>With autoregressive sampling, there needs to be an order in which to sample the subactions, but how do you decide that order?</p>

<p>In general, subactions are a set with no inherent order. There might be a more natural order for a given environment, but without some domain knowledge
it’s impossible to know which orders are better. You need to already know something about the environment.</p>

<p>Right now I pick a predetermined order at random, and this works, but it feels unsatisfying.
I have two ideas on how to do better, but I haven’t been able to make them work yet; they’re in the <a href="#action-order-method">Action Order</a> section.</p>

<h2 id="conclusion">Conclusion</h2>

<p>I couldn’t really get a clear <a href="#experiment-results">result</a> after much trial and error, and I underestimated just how well the prior work, Action Branching, performs in practice.
I can come up with cases where Action Branching would definitely fail, but I don’t yet have any environments where that’s actually a problem.</p>

<p>Q-learning can handle actions composed of multiple subactions just fine, without hitting the combinatorial explosion people typically face.
I’ve shown that, combined with the spline action space for continuous actions, Q-learning can be used in any environment, just like policy gradient methods.</p>

<div class="kazumi-quote">
When writing this blog post, I've held off comparing against a baseline, and when I finally did compare, found out my method was much worse. That was kind of a good experience though, because it gave me clues on how to improve it, and also taught me that when trying out something new, I should test it against a baseline as soon as possible.
</div>

<h1 id="addendum">Addendum</h1>

<details class="collapse-section">
  <summary>Prior Works</summary>
  <div class="collapse-content">
    
<p>The core idea of handling multiple subactions with Q-learning was introduced in <a href="https://arxiv.org/abs/1711.08946">Action Branching Architectures for Deep Reinforcement Learning</a>. They use independent Q-functions for each subaction and combine them using the dueling network architecture.</p>

<p>I had a similar idea independently, but found their paper while researching for this blog and realized they had already solved the basic version of this problem. Their approach works well for many cases, but struggles with scenarios that require coordinated subaction strategies.</p>

<p>The autoregressive approach I’m proposing builds on their foundation but adds the ability to handle action dependencies through sequential conditioning.</p>


  </div>
</details>

<h3 id="action-order-method">Action Order</h3>

<details class="collapse-section">
  <summary>Dynamic Action Order</summary>
  <div class="collapse-content">
    
<p>If I could try different orderings and learn which performs better dynamically, that could beat picking a random order.</p>

<p>To make dynamic action ordering work, I imagined sampling a permutation matrix that determines the sampling order, and learning orders that generally improve rewards as the model trains. This needs some algorithm to decide which order should be used.</p>

<p>To get to an algorithm, I made a few assumptions: if one subaction order is better than another, the better order should give less uncertain predictions, and the temporal difference loss can be used as a proxy for that uncertainty.</p>

<p>To simplify, I’ll also assume that the TD loss scale is determined only by the order of subactions, and is the sum of pairwise values associated with consecutive subactions in that order.</p>

<p>We can’t just compare the TD loss per subaction, for the same reason the TD loss needed to be summed over the state value and action value: a subaction sampled earlier might give a subaction sampled later better options.</p>

<p>Given these assumptions, I could formulate this as an optimization problem:</p>

<ul>
  <li>$n$ vertices are connected by directed edges with unknown weights. Given sampled Hamiltonian paths together with their total edge weights, what is the path that most likely minimizes the total weight, based on the information gathered from sampling?</li>
</ul>

<p>Each vertex represents a subaction, the directed edges between them represent the sampling order, and the weights are how much each consecutive pair contributes to the total TD loss.</p>

<p>If the edge weights were known, this would just be a traveling salesman problem. You could maybe set up a linear system to estimate the weights as you sample, but that means solving an $n^2$ by $n^2$ system, which is already $O(n^6)$ even before the TSP step.</p>
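As a sketch of the linear-system idea (using least squares instead of an exact solve; purely illustrative):

```python
import numpy as np

def recover_edge_weights(paths, totals, n):
    """Least-squares estimate of directed edge weights from sampled
    Hamiltonian paths and their observed total weights."""
    rows = []
    for path in paths:
        row = np.zeros(n * n)
        for u, v in zip(path, path[1:]):
            row[u * n + v] = 1.0  # edge u -> v appears once in this path
        rows.append(row)
    est, *_ = np.linalg.lstsq(np.array(rows), np.array(totals), rcond=None)
    return est.reshape(n, n)

# Synthetic check: known weights on 4 subactions, randomly sampled orders
rng = np.random.default_rng(0)
n = 4
true_w = rng.normal(size=(n, n))
paths = [tuple(rng.permutation(n)) for _ in range(200)]
totals = [sum(true_w[u, v] for u, v in zip(p, p[1:])) for p in paths]
est = recover_edge_weights(paths, totals, n)
# The estimate reproduces every observed path total
pred = [sum(est[u, v] for u, v in zip(p, p[1:])) for p in paths]
print(np.allclose(pred, totals))  # True
```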

<p>Maybe there is a better way to figure out the weights, or maybe there is a way that doesn’t even need to figure out the weight? I got stuck trying to figure this out.</p>

<p>The assumption that the TD loss is determined only by consecutive pairs probably doesn’t hold up: how much a subaction contributes likely depends on which subactions were taken before it, regardless of how far back they were sampled.</p>

<p>So I could try another assumption. Instead of only consecutive subactions interacting, suppose every pair of subactions has an ordering preference: a value assigned based on whether one subaction comes before another, regardless of how far apart they are. The TD loss would be the sum of all ordering preferences.</p>

<p>This changes the problem into a Linear Ordering Problem when the values are known, which is also NP-hard.</p>

<p>I couldn’t figure out a good way to do this, and this is about the time I started thinking about Action Latents so I’ve put it on indefinite hold.</p>


  </div>
</details>

<details class="collapse-section">
  <summary>Latent Action Space</summary>
  <div class="collapse-content">
    
<p>This is the other idea I came up with while writing this blog post, and I’m still exploring it.</p>

<p>If picking an order is a problem because I don’t know which order would be better, what if I make my own representation of actions, and sample from that?</p>

<p>Imagine you’re riding a bicycle. You don’t consciously think of what muscles to pull at what moment, you have a conceptual understanding of what to do, and let your muscle memory figure out the small movements. In order for your body to figure that out, it takes a while of trying.</p>

<p>A latent action space is kind of like that: instead of dealing with the raw actions, the agent could learn to conceptualize actions in meaningful ways, and act in that space instead.</p>

<p>The learned representations would still be sampled in some order, but the representation could be learned such that the order in which it gets sampled is the optimal order to sample in.</p>

<p>Looking for papers that implement this, I came across <a href="https://arxiv.org/pdf/2103.15793">LASER</a>, Learning a Latent Action Space for Efficient Reinforcement Learning.</p>

<p><img src="/assets/images/posts/subactions/laser.png" alt="Laser Overview" /></p>

<p>The idea is to abstract away actions by dealing with them in a latent space that might be easier to reason about. The properties we want from the latent space are:</p>

<ul>
  <li>Latent actions should preserve the important information of the original action, and allow it to be uniquely reconstructed</li>
  <li>Latent actions should make it easier to reason about the state, and make the future more easily predictable</li>
  <li>Actions that result in similar outcomes should be close together in latent space</li>
  <li>Sampling from a typical latent action distribution and acting on it should naturally give higher rewards than sampling from an atypical latent distribution</li>
</ul>

<p>To achieve this, the collected state \(s\), action \(a\) and state transition \(s'\) from interacting with the environment are used to train an Encoder \(E(a, s)\) to encode actions into latent action \(\overline{a}\), Decoder \(D(\overline{a}, s)\) that decodes latent action into action as a variational autoencoder pair, and a latent state transition function \(\overline{T}(s, \overline{a})\) that predicts next state given previous state and latent action.</p>

<p>Then these are trained using</p>

<ul>
  <li><strong>Action Reconstruction loss</strong>: Squared error \(\| a - D(E(a, s), s) \|^2_2\) if continuous, or Cross entropy \(-a \log( D(E(a, s), s))\) if discrete</li>
  <li><strong>Dynamics loss</strong>: Squared error \(\| s' - \overline{T}(s, E(a, s)) \|^2_2\)</li>
  <li><strong>Regularization loss</strong>: KL divergence \(KL(N(\mu, \sigma) \| N(0, I))\) where \(\mu, \sigma\) are output by \(E(a, s)\)</li>
  <li><strong>Policy loss</strong>: \(- Q_{policy}(E(a, s), s)\)</li>
</ul>
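A minimal sketch of the first three losses for a continuous action, with `encode`, `decode`, and `transition` as stand-ins for the learned networks (the policy loss is omitted since it needs the Q-net, and I skip the reparameterization noise to keep it deterministic):

```python
import numpy as np

def laser_losses(s, a, s_next, encode, decode, transition):
    """Reconstruction, dynamics, and KL losses from the list above.
    encode returns (mu, log_sigma); we use the mean as the latent action."""
    mu, log_sigma = encode(a, s)
    z = mu                                                # latent action
    recon = np.sum((a - decode(z, s)) ** 2)               # action reconstruction
    dynamics = np.sum((s_next - transition(s, z)) ** 2)   # latent dynamics
    sigma2 = np.exp(2 * log_sigma)
    kl = 0.5 * np.sum(sigma2 + mu ** 2 - 1 - 2 * log_sigma)  # KL(N(mu,sigma) || N(0,I))
    return recon, dynamics, kl

# Toy stand-ins: identity encoder/decoder, transition predicts s + z
encode = lambda a, s: (a, np.zeros_like(a))
decode = lambda z, s: z
transition = lambda s, z: s + z
a, s = np.array([0.5, -0.5]), np.zeros(2)
recon, dyn, kl = laser_losses(s, a, s + a, encode, decode, transition)
print(recon, dyn)  # 0.0 0.0: perfect reconstruction and dynamics
```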

<p>I’m still trying to figure out how to make this work with Q-learning; I’m hoping it will solve the action order problem.
It actually kind of works already, just not as well as I’d like. This was mostly me throwing ideas around to see what sticks, and I’d like to spend more time on it.</p>


  </div>
</details>

<h3 id="experiment-results">Results</h3>

<details class="collapse-section">
  <summary>Results</summary>
  <div class="collapse-content">
    
<p>Honestly, I’m not sure what conclusion to draw from the final results, since everything ended up performing pretty much the same as everything else.
Action Branching is actually pretty good, it turns out. It’s also still faster, even though autoregressive sampling has many points left to optimize.
For most environments, it’s probably fine to just use Action Branching first, and then see if my method does any better.</p>

<p>In some environments, Action Latents do much worse than not using them, while in others it’s about the same.
Action Latents need some warmup time where only the latent encoder, decoder, and dynamics model are trained.
During that time the Q-net doesn’t learn anything, yet the agent seems to get better scores than random actions, which is kind of weird.
My guess is that, even without knowing which actions are better or worse, doing things that lead to more interesting states generally earns higher reward, so maybe it learns which actions achieve nothing and avoids them.</p>

<p>I think right now, I need more environments to test on. If every model reaches the same performance at the same time, it usually means you need a better test to tell which is actually better.</p>

<p><img src="/assets/images/results/subactions/walker.png" alt="Walker Result" /></p>


  </div>
</details>]]></content><author><name>Kazumi</name></author><category term="reinforcement-learning" /><category term="dqn" /></entry><entry><title type="html">Extending DQN to Continuous Action Spaces with Cubic Splines</title><link href="https://yonakalabs.net/DQN-spline/" rel="alternate" type="text/html" title="Extending DQN to Continuous Action Spaces with Cubic Splines" /><published>2025-04-18T22:00:00+09:00</published><updated>2025-04-18T22:00:00+09:00</updated><id>https://yonakalabs.net/DQN-spline</id><content type="html" xml:base="https://yonakalabs.net/DQN-spline/"><![CDATA[<p>One of the main things that turns people away from using Deep Q-Learning is its inability to handle continuous actions or multiple sub-actions. In <a href="https://stable-baselines3.readthedocs.io/en/master/guide/algos.html">Stable Baselines 3</a>, they have a table of reinforcement learning algorithms and what kind of action spaces they each work in.</p>

<p><img src="/assets/images/posts/spline/stable-baselines-comparison.png" alt="Stable Baselines 3 Algorithm Comparison" /></p>

<p>In their table, DQN only has a tick on the Discrete actions box. That is very limiting! It would be nice if there was an easy and cheap way of allowing DQN to work with continuous and multiple actions. But for now, let’s focus on how to make the first one work.</p>

<h2 id="the-problem-with-discrete-only-actions">The Problem with Discrete-Only Actions</h2>

<p>In games such as fighting games, where an agent selects from a set of actions (move left, jump, shoot), a normal DQN works wonderfully. But what about games that need more precise control? Think about:</p>

<ul>
  <li>A car adjusting its steering angle</li>
  <li>Twinstick shooter like Binding of Isaac</li>
  <li>A game like Minecraft where you need both discrete actions (moving with WASD keys, mining with click) and continuous control (moving the camera around)</li>
</ul>

<p>Eventually I would have to build an agent that works with continuous control, but I knew DQN wouldn’t work out of the box. The standard approach of discretizing the action space into bins technically works, but produces jerky, unnatural movement. Imagine a car that can only turn its steering wheel in 10-degree increments instead of smoothly!</p>

<p>Most practitioners simply avoid DQN altogether for these tasks, moving to algorithms specifically designed for continuous control like DDPG or SAC. But I wondered: could we adapt DQN to handle continuous actions elegantly?</p>

<h2 id="why-cant-dqn-handle-continuous-actions">Why Can’t DQN Handle Continuous Actions?</h2>

<p>To understand the problem, we need to revisit how Q-learning actually works.</p>

<p>In DQN, the Q-function represents the expected future reward when taking action $a$ in state $s$, then following the policy afterward. This is written as $Q(s, a)$.</p>

<p>For an agent to act, it needs to find the action that maximizes this Q-function:</p>

\[a^* = \arg\max_a Q(s, a)\]

<p>For discrete actions, this is straightforward. If you have 4 possible actions, you calculate a Q-value for each one and pick the highest. Done!</p>

<p>But what happens with continuous actions? If an action can be any value between, say, 0 and 1, we can’t simply enumerate all possibilities.</p>

<p><img src="/assets/images/posts/spline/discrete-vs-continuous.png" alt="Discrete vs Continuous Action Space" class="diagram" /></p>

<h2 id="the-standard-solution-discretization">The Standard Solution: Discretization</h2>

<p>The most common approach is to simply chop up (discretize) the continuous action space into a finite set of actions.</p>

<p>For example, if your action space is $[0, 1]$, you might use $\{0, 0.1, 0.2, \dots, 0.9, 1.0\}$ as your discrete approximation.</p>
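With a made-up Q-function over $[0, 1]$ and the 11 bins above, the resolution limit is easy to see:

```python
import numpy as np

# A made-up Q_s(a) over the action interval [0, 1]; the true best action is 0.37
q_s = lambda a: -(a - 0.37) ** 2

bins = np.linspace(0.0, 1.0, 11)  # {0, 0.1, ..., 1.0}
best = bins[np.argmax(q_s(bins))]
print(best)  # 0.4: the agent can never act closer than the bin resolution allows
```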

<p><img src="/assets/images/posts/spline/discretization-diagram.png" alt="Discretization Diagram" class="diagram" /></p>

<p>This works, but has significant drawbacks:</p>

<ol>
  <li><strong>Resolution problems</strong>: Too few points and your agent can’t make fine adjustments; too many and learning becomes inefficient</li>
  <li><strong>No knowledge transfer</strong>: Learning that an action is good doesn’t tell the agent whether a similar action would also be good</li>
  <li><strong>Curse of dimensionality</strong>: Discretizing multiple continuous actions leads to combinatorial explosion (more on this in next post!)</li>
</ol>

<div class="yonaka-quote-custom" data-image="/assets/images/characters/yonaka-confused.png" data-artist="sroka001" data-artist-link="https://bsky.app/profile/sroka001.bsky.social">
  "I tried moving 45 degrees to the left and it worked well... but should I try 44? 46?"
</div>

<h2 id="a-different-way-of-looking-at-q-functions">A Different Way of Looking at Q-Functions</h2>

<p>Let’s think about what happens when we’re trying to select an action. Notice something important:</p>

<p>For a given state $s$, the argmax operation over actions doesn’t depend on the state anymore. We’ve essentially “locked in” our state and now just need to find the best action for that particular state.</p>

<p>This means, to make the argmax operation easier, we could curry the state into the Q-function $Q(s, a)$ to make a simpler function that only depends on the action $Q_s(a)$, and then take the maximum over the action:</p>

\[Q_s(a) = Q(s, a) \text{ where } s \text{ is fixed}\]

<p>For discrete actions, $Q_s(a)$ is just a lookup table! Finding the maximum value in a table is trivial.</p>

<p>But for continuous actions, $Q_s(a)$ becomes a continuous function over the action space. Finding the maximum of an arbitrary continuous function is much harder.</p>

<h2 id="what-we-need-in-a-continuous-q-function">What We Need in a Continuous Q-Function</h2>

<p>If we want to use Q-learning with continuous actions, our representation of the Q-function needs to support several operations:</p>

<ol>
  <li><strong>Evaluation</strong>: We need to compute $Q(s, a)$ for any action $a$</li>
  <li><strong>Maximization</strong>: We need to efficiently find the action $a$ that maximizes $Q(s, a)$</li>
  <li><strong>Integration</strong>: For some advanced techniques like Dueling Networks, we need to compute the average Q-value across all actions</li>
  <li><strong>Addition</strong>: We need to be able to add Q-functions together (useful for ensemble methods)</li>
</ol>

<p><img src="/assets/images/posts/spline/continuous-values-diagram.png" alt="Continuous Q Value Diagram" class="diagram" /></p>

<p>Many function approximators can handle evaluation, but maximization and integration are trickier. Neural networks, for instance, make evaluation easy but finding the global maximum is very difficult.</p>

<p>So what kind of mathematical construct could satisfy all these requirements?</p>

<h2 id="using-natural-cubic-splines">Using Natural Cubic Splines</h2>

<p>A cubic spline is a piecewise function made up of cubic polynomials that are smoothly connected at specific points called knots.</p>

<p><img src="/assets/images/posts/spline/cubic-spline-diagram.png" alt="Cubic Spline Diagram" class="diagram" /></p>

<p>Cubic splines have several properties that make them perfect for our needs:</p>

<ol>
  <li>They’re smooth and continuous</li>
  <li>They can approximate any continuous function (with enough knots)</li>
  <li>We can analytically find their maximums and compute their integrals</li>
  <li>They’re closed under addition (adding two cubic splines gives you another cubic spline)</li>
</ol>

<h3 id="how-cubic-splines-work">How Cubic Splines Work</h3>

<p>A cubic spline is defined by a set of control points (or knots) $(x_0, y_0), \dots, (x_n, y_n), (x_{n+1}, y_{n+1})$ where the $x$ values are positions in our action space and the $y$ values are our estimated Q-values at those actions.</p>

<p>Given these knots, the spline is</p>

\[\begin{align}S(x) &amp;= a_i t^3 + b_i t^2 + c_i t + d_i &amp; \text{where} &amp; &amp; x_i \leq x \leq x_{i+1} &amp; \text{,} &amp; t = \frac{x-x_i}{x_{i+1} - x_i}  \end{align}\]

<p>These polynomials are crafted to ensure that:</p>

<ul>
  <li>The spline passes through all control points</li>
  <li>The first and second derivatives match at each interior control point</li>
  <li>Specific boundary conditions are met at the endpoints</li>
</ul>

<p>I find that it’s much easier to handle if the internal coordinates of each polynomial goes from 0 to 1, and we translate when using them.</p>

<p>Check out <a href="https://mathworld.wolfram.com/CubicSpline.html">WolframMathWorld</a> for the cubic polynomial formula when the knots are equidistant, and the <a href="#spline-formula">Addendum</a> for non-equidistant knots.</p>
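<p>As a quick sanity check of the properties above, here is a minimal sketch using SciPy’s <code>CubicSpline</code> (not the PyTorch implementation in the Addendum) with natural boundary conditions; the knot values are made up purely for illustration:</p>

```python
import numpy as np
from scipy.interpolate import CubicSpline

# Knot positions (actions) and Q-value estimates at those actions;
# the y values here are arbitrary, purely for illustration.
x = np.linspace(0.0, 1.0, 5)                 # equidistant knots in [0, 1]
y = np.array([0.0, 0.8, 0.3, 0.9, 0.2])

# 'natural' boundary conditions: zero second derivative at both endpoints.
spline = CubicSpline(x, y, bc_type="natural")

# The spline passes through every knot exactly...
assert np.allclose(spline(x), y)
# ...and its second derivative vanishes at the boundaries.
assert abs(spline(x[0], 2)) < 1e-8 and abs(spline(x[-1], 2)) < 1e-8
```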

<h2 id="operations-on-cubic-splines">Operations on Cubic Splines</h2>

<p>Now let’s see how cubic splines handle all the operations we need:</p>

<h3 id="1-evaluation">1. Evaluation</h3>

<p>To evaluate a cubic spline at a particular action value:</p>

<ol>
  <li>Find which segment the action falls into</li>
  <li>Evaluate the cubic polynomial for that segment</li>
</ol>
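<p>The two steps above can be sketched in plain NumPy, assuming the per-segment coefficients are stored as rows $(a, b, c, d)$ with the local coordinate $t$ running from 0 to 1, mirroring the layout used in the implementation in the Addendum (the helper name is hypothetical):</p>

```python
import numpy as np

def eval_spline(coeffs, x, x_min=0.0, x_max=1.0):
    """Evaluate a spline stored as per-segment cubic coefficients.

    coeffs has shape (num_segments, 4): each row is (a, b, c, d) for
    S_i(t) = a t^3 + b t^2 + c t + d, with local t in [0, 1].
    """
    num_segments = coeffs.shape[0]
    # Step 1: find which segment x falls into, and its local coordinate t.
    u = (x - x_min) / (x_max - x_min) * num_segments
    seg = int(np.clip(np.floor(u), 0, num_segments - 1))
    t = u - seg
    # Step 2: evaluate that segment's cubic (Horner form).
    a, b, c, d = coeffs[seg]
    return ((a * t + b) * t + c) * t + d

# Two segments that together represent f(x) = 2x on [0, 1].
coeffs = np.array([[0.0, 0.0, 1.0, 0.0],
                   [0.0, 0.0, 1.0, 1.0]])
assert eval_spline(coeffs, 0.25) == 0.5
assert eval_spline(coeffs, 0.75) == 1.5
```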

<h3 id="2-finding-the-maximum">2. Finding the Maximum</h3>

<p>We can use the derivative test on each segment to find all the candidate points, and then take the maximum over those.</p>

<p>For each cubic polynomial segment:</p>

<ol>
  <li>Calculate its derivative curve</li>
  <li>Find the roots of the derivative (1st derivative test)</li>
  <li>Evaluate the spline at these points and at the boundaries</li>
  <li>Take the maximum of all these values</li>
</ol>

<p><img src="/assets/images/posts/spline/maximization-diagram.png" alt="Maximization of Spline" class="diagram" /></p>

<p>Since we’re dealing with cubic polynomials, the derivative is quadratic, and finding the roots of a quadratic equation is trivial using the quadratic formula.</p>

<p>And we can even use the 2nd derivative test to discard the local minima, halving the number of candidate points to search!</p>
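<p>For a single segment, the whole procedure fits in a few lines. This is a sketch under the same $(a, b, c, d)$ coefficient convention as before, with local $t \in [0, 1]$ (function name hypothetical):</p>

```python
import numpy as np

def segment_max(a, b, c, d):
    """Maximum of S(t) = a t^3 + b t^2 + c t + d over t in [0, 1].

    Candidates are the two endpoints plus any real roots of the
    quadratic derivative S'(t) = 3a t^2 + 2b t + c.
    """
    candidates = [0.0, 1.0]
    disc = b * b - 3.0 * a * c        # discriminant of S'(t), up to a factor of 4
    if a != 0.0 and disc >= 0.0:
        r = np.sqrt(disc)
        candidates += [(-b + r) / (3.0 * a), (-b - r) / (3.0 * a)]
    elif a == 0.0 and b != 0.0:
        candidates.append(-c / (2.0 * b))   # S' is linear in this segment
    S = lambda t: ((a * t + b) * t + c) * t + d
    return max(S(t) for t in candidates if 0.0 <= t <= 1.0)

# S(t) = -(t - 0.5)^2 peaks at t = 0.5 with value 0.
assert abs(segment_max(0.0, -1.0, 1.0, -0.25)) < 1e-12
```

The full-spline maximum is then just the max of <code>segment_max</code> over all segments.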

<h3 id="3-computing-the-mean">3. Computing the Mean</h3>

<p>Taking the mean of the Q-function over the action space is needed in methods like Dueling Networks and a few others.</p>

<p>The mean value of a function over its entire input space can be computed by taking the integral and dividing by the size of the input space.</p>

\[\mu = \int_{\min}^{\max} \frac{ S(x)}{\max - \min} dx\]

<p>For our cubic spline, we just need to integrate each cubic polynomial, weight it by the length of its segment, and sum the results.</p>

<p>If we made the internal coordinates go from 0 to 1, the integrals reduce to fixed weights, and everything simplifies to a single einsum expression:</p>

\[\frac{ \left[\frac{1}{4}, \frac{1}{3}, \frac{1}{2}, 1 \right]_i Coeff_{ij} \Delta x_j}{x_{n+1} - x_0}\]

<p><a href="#spline-mean">Derivation</a></p>
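<p>In code, the expression above is just a dot product with fixed weights. Here is a NumPy sketch of the general, possibly non-equidistant case (the function name is hypothetical; the equidistant case simply divides by the number of segments instead):</p>

```python
import numpy as np

def spline_mean(coeffs, seg_lengths):
    """Mean of a spline given per-segment (a, b, c, d) coefficients.

    With local t in [0, 1], each segment integrates to
    a/4 + b/3 + c/2 + d; weight each integral by its segment's
    length in x and divide by the total domain length.
    """
    weights = np.array([1.0 / 4.0, 1.0 / 3.0, 1.0 / 2.0, 1.0])
    per_segment = coeffs @ weights        # integral of each segment over t
    return per_segment @ seg_lengths / seg_lengths.sum()

# Two segments representing f(x) = 2x on [0, 1]; its mean is 1.
coeffs = np.array([[0.0, 0.0, 1.0, 0.0],
                   [0.0, 0.0, 1.0, 1.0]])
assert abs(spline_mean(coeffs, np.array([0.5, 0.5])) - 1.0) < 1e-12
```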

<h3 id="4-adding-splines">4. Adding Splines</h3>

<p>Adding Q-functions together is needed in some methods, such as extended Dueling Networks or multi-goal learning.</p>

<p>Adding two cubic splines is straightforward:</p>

<ol>
  <li>Combine all unique knot points</li>
  <li>For each segment in the combined domain, add the corresponding polynomial coefficients</li>
</ol>
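<p>When the two splines happen to share the same knots, step 1 is a no-op and step 2 is literally coefficient addition. A sketch with SciPy, using <code>PPoly.construct_fast</code> to build a piecewise polynomial directly from summed coefficients:</p>

```python
import numpy as np
from scipy.interpolate import CubicSpline, PPoly

x = np.linspace(0.0, 1.0, 6)
s1 = CubicSpline(x, np.sin(2.0 * np.pi * x), bc_type="natural")
s2 = CubicSpline(x, np.cos(2.0 * np.pi * x), bc_type="natural")

# With identical knots, adding the per-segment coefficients adds the splines.
s_sum = PPoly.construct_fast(s1.c + s2.c, s1.x)

grid = np.linspace(0.0, 1.0, 101)
assert np.allclose(s_sum(grid), s1(grid) + s2(grid))
```

Note the sum is a cubic spline interpolating the pointwise sums of the knot values, which is why closure under addition holds.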

<h2 id="advantages-of-spline">Advantages of Spline</h2>

<p>Using cubic splines to represent our Q-function gives us several advantages:</p>

<ol>
  <li><strong>Smooth approximation</strong>: Unlike discretization, splines provide a continuous representation with few points</li>
  <li><strong>Knowledge transfer</strong>: Learning about the Q-value at one action informs us about nearby actions</li>
  <li><strong>Analytical maximization</strong>: The optimal action can be found precisely and efficiently, without needing to evaluate the entire space</li>
<li><strong>Circular action spaces</strong>: Spline curves can have their endpoints joined with continuous derivatives, handling angular actions well</li>
</ol>

<p>I am not aware of any easily usable environments with circular action spaces to experiment in yet; let me know if you know of one.</p>

<h2 id="conclusion">Conclusion</h2>

<p>DQN doesn’t have to be limited to discrete action spaces. By representing the Q-function as a cubic spline, we can enable DQN to work with continuous actions, without adding too much overhead.</p>

<p>Since splines are controlled by knots, they take exactly the same shape of network outputs as you would have used for discretized actions, making them pretty much a drop-in replacement.</p>

<p>In the next post, I’ll show how to solve the other limitation of Q-learning: handling multiple subactions in a step without getting hit by the curse of dimensionality.</p>

<div class="yonaka-quote">
"I used to be limited to jumping between discrete steps, but with splines, I can slide smoothly through the action space! No more awkward robot movements - now my actions can be as fluid as a human player's!"
</div>

<h1 id="addendum">Addendum</h1>

<details class="collapse-section">
  <summary>Prior Works</summary>
  <div class="collapse-content">
    
<p>I have found some prior work on this. One is <a href="https://arxiv.org/abs/1909.12397">CAQL: Continuous Action Q-Learning</a>, where they build a tiny ReLU network and use Mixed-Integer Programming (essentially linear programming) to find the maximum.</p>

<p>The method has limitations: it is slow, since each forward pass needs to solve an optimization problem, and it only works when the action space is small; they only test on environments with very limited action spaces.</p>

<p>I also found a work from 1993 that uses spline curves to handle continuous action spaces, called <a href="https://apps.dtic.mil/sti/tr/pdf/ADA280844.pdf">Wire Fitting</a>. This is what the front page looks like:</p>

<p><img src="/assets/images/posts/spline/wirefitting.png" alt="REINFORCEMENT LEARNING WITH HIGHDIMENSIONAL, CONTINUOUS ACTIONS" /></p>


  </div>
</details>

<h2 id="experimental-results">Experimental Results</h2>

<details class="collapse-section">
  <summary>Results</summary>
  <div class="collapse-content">
    
<p>Here I put Weights &amp; Biases plots of rewards for runs done with discretized actions and spline actions.</p>

<p>The environments are Reacher from <a href="https://gymnasium.farama.org/environments/mujoco/reacher/">MuJoCo</a>, and the Walker and Finger environments from DeepMind’s <a href="https://github.com/google-deepmind/dm_control/tree/main">Control Suite</a>.</p>

<p>(click to enlarge)</p>

<div class="grid">
    <img src="/assets/images/results/spline/fingerspin.png" alt="fingerspin" title="fingerspin" />
    <img src="/assets/images/results/spline/walker.png" alt="walker" title="walker" />
    <img src="/assets/images/results/spline/reacher.png" alt="reacher" title="reacher" />
</div>

<p>Comparison of Walker performance between discretized action and spline action.</p>

<div class="grid">
    <img src="/assets/images/results/spline/discretized.gif" alt="buttslide" title="buttslide" />
    <img src="/assets/images/results/spline/spline.gif" alt="I'm late!" title="I'm late!" />
</div>

<p>The discretized agent is doing a butt slide; this strategy seems to be a very stable way of moving that won’t fall over, but it limits how fast the agent can move.</p>

<p>The spline agent is running as fast as it can, losing balance but quickly standing up, only to fall again. It seems to be prioritizing short-term gain over long-term gain.</p>

<div class="kazumi-quote">
Side note, have you ever noticed that some agents with less capacity to learn will converge to a very safe strategy that's hard to mess up, while some agents with high capacity might not even learn a strategy, just have good execution? There is sometimes a sort of Strategy vs Execution trade off that happens.
</div>

<p>Real-time Q-function graph of the Reacher environment:</p>

<p><img src="/assets/images/results/spline/reacher.gif" alt="reacher" title="reacher" /></p>


  </div>
</details>

<h2 id="code-implementation">Code Implementation</h2>

<details class="collapse-section">
  <summary>Spline Implementation</summary>
  <div class="collapse-content">
    
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="n">torch</span>
<span class="kn">import</span> <span class="n">einops</span>

<span class="k">class</span> <span class="nc">SplineLayer</span><span class="p">(</span><span class="n">torch</span><span class="p">.</span><span class="n">nn</span><span class="p">.</span><span class="n">Module</span><span class="p">):</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">num_points</span><span class="p">,</span> <span class="nb">min</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">max</span> <span class="o">=</span> <span class="mi">1</span><span class="p">):</span>
        <span class="nf">super</span><span class="p">().</span><span class="nf">__init__</span><span class="p">()</span>
        <span class="n">self</span><span class="p">.</span><span class="n">num_points</span> <span class="o">=</span> <span class="n">num_points</span>
        <span class="n">self</span><span class="p">.</span><span class="nb">min</span> <span class="o">=</span> <span class="nb">min</span>
        <span class="n">self</span><span class="p">.</span><span class="nb">max</span> <span class="o">=</span> <span class="nb">max</span>

        <span class="n">self</span><span class="p">.</span><span class="nf">register_buffer</span><span class="p">(</span><span class="sh">"</span><span class="s">inverse</span><span class="sh">"</span><span class="p">,</span> <span class="n">self</span><span class="p">.</span><span class="nf">_precompute_inverse</span><span class="p">(),</span> <span class="n">persistent</span> <span class="o">=</span> <span class="bp">False</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">_precompute_inverse</span><span class="p">(</span><span class="n">self</span><span class="p">):</span>
        <span class="n">n</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">num_points</span>
        
        <span class="n">diag</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">ones</span><span class="p">(</span><span class="n">n</span><span class="p">)</span> <span class="o">*</span> <span class="mi">4</span>
        <span class="n">diag</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">diag</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">2</span>
        
        <span class="n">off_diag</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">ones</span><span class="p">(</span><span class="n">n</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">A</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">diag</span><span class="p">(</span><span class="n">diag</span><span class="p">)</span> <span class="o">+</span> <span class="n">torch</span><span class="p">.</span><span class="nf">diag</span><span class="p">(</span><span class="n">off_diag</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">+</span> <span class="n">torch</span><span class="p">.</span><span class="nf">diag</span><span class="p">(</span><span class="n">off_diag</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>

        <span class="k">return</span> <span class="n">torch</span><span class="p">.</span><span class="n">linalg</span><span class="p">.</span><span class="nf">inv</span><span class="p">(</span><span class="n">A</span><span class="p">)</span> <span class="o">*</span> <span class="mi">3</span>

    <span class="k">def</span> <span class="nf">_compute_coefficients</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
        <span class="o">*</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n</span> <span class="o">=</span> <span class="n">y</span><span class="p">.</span><span class="n">shape</span>

        <span class="n">rhs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">zeros</span><span class="p">((</span><span class="o">*</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">n</span><span class="p">),</span> <span class="n">device</span><span class="o">=</span><span class="n">y</span><span class="p">.</span><span class="n">device</span><span class="p">)</span>
        <span class="n">rhs</span><span class="p">[...,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">y</span><span class="p">[...,</span><span class="mi">1</span><span class="p">]</span>  <span class="o">-</span> <span class="n">y</span><span class="p">[...,</span><span class="mi">0</span><span class="p">])</span>
        <span class="n">rhs</span><span class="p">[...,</span> <span class="mi">1</span><span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">y</span><span class="p">[...,</span><span class="mi">2</span><span class="p">:]</span><span class="o">-</span> <span class="n">y</span><span class="p">[...,:</span><span class="o">-</span><span class="mi">2</span><span class="p">])</span>
        <span class="n">rhs</span><span class="p">[...,</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">y</span><span class="p">[...,</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span>  <span class="o">-</span> <span class="n">y</span><span class="p">[...,</span><span class="o">-</span><span class="mi">2</span><span class="p">])</span>

        <span class="n">D</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">matmul</span><span class="p">(</span><span class="n">rhs</span><span class="p">,</span> <span class="n">self</span><span class="p">.</span><span class="n">inverse</span><span class="p">)</span>
        
        <span class="n">yi</span> <span class="o">=</span> <span class="n">y</span><span class="p">[...,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
        <span class="n">yi1</span> <span class="o">=</span> <span class="n">y</span><span class="p">[...,</span> <span class="mi">1</span><span class="p">:]</span>
        <span class="n">di</span> <span class="o">=</span> <span class="n">D</span><span class="p">[...,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
        <span class="n">di1</span> <span class="o">=</span> <span class="n">D</span><span class="p">[...,</span> <span class="mi">1</span><span class="p">:]</span>

        <span class="n">coeffs</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">stack</span><span class="p">([(</span><span class="mi">2</span><span class="o">*</span><span class="p">(</span><span class="n">yi</span><span class="o">-</span><span class="n">yi1</span><span class="p">)</span><span class="o">+</span><span class="n">di</span><span class="o">+</span><span class="n">di1</span><span class="p">),</span>  <span class="p">(</span><span class="mi">3</span> <span class="o">*</span> <span class="p">(</span><span class="n">yi1</span><span class="o">-</span><span class="n">yi</span><span class="p">)</span><span class="o">-</span><span class="mi">2</span><span class="o">*</span><span class="n">di</span><span class="o">-</span><span class="n">di1</span><span class="p">),</span> <span class="n">di</span><span class="p">,</span> <span class="n">yi</span><span class="p">],</span> <span class="n">dim</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">coeffs</span>
    
    <span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">points</span><span class="p">):</span>
        <span class="n">coefficients</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="nf">_compute_coefficients</span><span class="p">(</span><span class="n">points</span><span class="p">)</span>
        <span class="k">return</span> <span class="nc">Spline</span><span class="p">(</span><span class="n">points</span><span class="p">,</span> <span class="n">coefficients</span><span class="p">,</span> <span class="n">self</span><span class="p">.</span><span class="nb">min</span><span class="p">,</span><span class="n">self</span><span class="p">.</span><span class="nb">max</span><span class="p">)</span>

<span class="k">class</span> <span class="nc">Spline</span><span class="p">:</span>
    <span class="k">def</span> <span class="nf">__init__</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">points</span><span class="p">,</span> <span class="n">coefficients</span><span class="p">,</span> <span class="nb">min</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">max</span> <span class="o">=</span> <span class="mi">1</span><span class="p">):</span>
        <span class="n">self</span><span class="p">.</span><span class="n">points</span> <span class="o">=</span> <span class="n">points</span>
        <span class="n">self</span><span class="p">.</span><span class="n">coefficients</span> <span class="o">=</span> <span class="n">coefficients</span>
        <span class="n">self</span><span class="p">.</span><span class="n">num_segments</span> <span class="o">=</span> <span class="n">coefficients</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">]</span>
        <span class="n">self</span><span class="p">.</span><span class="n">batch_shape</span> <span class="o">=</span> <span class="n">points</span><span class="p">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
        <span class="n">self</span><span class="p">.</span><span class="nb">max</span> <span class="o">=</span> <span class="nb">max</span>
        <span class="n">self</span><span class="p">.</span><span class="nb">min</span> <span class="o">=</span> <span class="nb">min</span>
        
    <span class="k">def</span> <span class="nf">evaluate</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">t</span><span class="p">):</span>
        <span class="k">if</span> <span class="nf">isinstance</span><span class="p">(</span><span class="n">t</span><span class="p">,</span> <span class="p">(</span><span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">)):</span>
            <span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">tensor</span><span class="p">([</span><span class="n">t</span><span class="p">],</span> <span class="n">device</span><span class="o">=</span><span class="n">self</span><span class="p">.</span><span class="n">coefficients</span><span class="p">.</span><span class="n">device</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">t</span> <span class="o">=</span> <span class="n">t</span><span class="p">.</span><span class="nf">to</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">coefficients</span><span class="p">.</span><span class="n">device</span><span class="p">)</span>

        <span class="n">t</span> <span class="o">=</span> <span class="p">(</span><span class="n">t</span> <span class="o">-</span> <span class="n">self</span><span class="p">.</span><span class="nb">min</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="nb">max</span> <span class="o">-</span> <span class="n">self</span><span class="p">.</span><span class="nb">min</span><span class="p">)</span> <span class="o">*</span> <span class="n">self</span><span class="p">.</span><span class="n">num_segments</span>
        <span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">clamp</span><span class="p">(</span><span class="n">t</span><span class="p">,</span> <span class="mf">1e-6</span><span class="p">,</span> <span class="n">self</span><span class="p">.</span><span class="n">num_segments</span><span class="o">-</span><span class="mf">1e-6</span><span class="p">)</span>

        <span class="n">seg</span> <span class="o">=</span> <span class="n">t</span><span class="p">.</span><span class="nf">long</span><span class="p">()</span>
        <span class="n">t</span> <span class="o">=</span> <span class="p">(</span><span class="n">t</span> <span class="o">-</span> <span class="n">seg</span><span class="p">.</span><span class="nf">float</span><span class="p">())</span>
        
        <span class="n">abcd</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">coefficients</span><span class="p">.</span><span class="nf">gather</span><span class="p">(</span><span class="o">-</span><span class="mi">2</span><span class="p">,</span><span class="n">seg</span><span class="p">[...,</span><span class="bp">None</span><span class="p">].</span><span class="nf">expand</span><span class="p">(</span><span class="o">*</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">*</span><span class="nf">len</span><span class="p">(</span><span class="n">seg</span><span class="p">.</span><span class="n">shape</span><span class="p">),</span> <span class="mi">4</span><span class="p">))</span>

        <span class="nf">return </span><span class="p">((</span><span class="n">abcd</span><span class="p">[...,</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">t</span> <span class="o">+</span> <span class="n">abcd</span><span class="p">[...,</span><span class="mi">1</span><span class="p">])</span> <span class="o">*</span> <span class="n">t</span> <span class="o">+</span> <span class="n">abcd</span><span class="p">[...,</span><span class="mi">2</span><span class="p">])</span> <span class="o">*</span> <span class="n">t</span> <span class="o">+</span> <span class="n">abcd</span><span class="p">[...,</span><span class="mi">3</span><span class="p">]</span>

    <span class="k">def</span> <span class="nf">maximum</span><span class="p">(</span><span class="n">self</span><span class="p">,</span> <span class="n">dim</span> <span class="o">=</span> <span class="bp">None</span><span class="p">,</span> <span class="n">weight</span> <span class="o">=</span> <span class="mi">1</span><span class="p">):</span>
        <span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
            <span class="n">a</span><span class="p">,</span><span class="n">b</span><span class="p">,</span><span class="n">c</span><span class="p">,</span><span class="n">d</span> <span class="o">=</span> <span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">coefficients</span> <span class="o">*</span> <span class="n">weight</span><span class="p">).</span><span class="nf">sum</span><span class="p">(</span><span class="n">dim</span><span class="p">).</span><span class="nf">tensor_split</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">3</span><span class="p">),</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">a</span><span class="p">,</span><span class="n">b</span><span class="p">,</span><span class="n">c</span><span class="p">,</span><span class="n">d</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">coefficients</span><span class="p">.</span><span class="nf">tensor_split</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">3</span><span class="p">),</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
        <span class="n">aa</span> <span class="o">=</span> <span class="o">-</span><span class="mi">3</span> <span class="o">*</span> <span class="n">a</span> 

        <span class="n">dd</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">sqrt</span><span class="p">(</span><span class="n">b</span><span class="p">.</span><span class="nf">square</span><span class="p">()</span> <span class="o">+</span> <span class="n">aa</span> <span class="o">*</span> <span class="n">c</span><span class="p">)</span>
        
        <span class="n">t</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">clamp</span><span class="p">((</span><span class="n">b</span> <span class="o">+</span> <span class="n">dd</span><span class="p">)</span> <span class="o">/</span> <span class="n">aa</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
        <span class="n">t</span><span class="p">[...,:</span><span class="o">-</span><span class="mi">1</span><span class="p">,:]</span> <span class="o">=</span> <span class="n">t</span><span class="p">[...,:</span><span class="o">-</span><span class="mi">1</span><span class="p">,:].</span><span class="nf">nan_to_num</span><span class="p">(</span><span class="mf">0.</span><span class="p">)</span>
        <span class="n">t</span><span class="p">[...,</span><span class="o">-</span><span class="mi">1</span><span class="p">,:]</span> <span class="o">=</span><span class="n">t</span><span class="p">[...,</span><span class="o">-</span><span class="mi">1</span><span class="p">,:].</span><span class="nf">nan_to_num</span><span class="p">(</span><span class="mf">1.</span><span class="p">)</span>
        <span class="n">m</span> <span class="o">=</span> <span class="p">((</span><span class="n">a</span> <span class="o">*</span> <span class="n">t</span> <span class="o">+</span> <span class="n">b</span><span class="p">)</span> <span class="o">*</span> <span class="n">t</span> <span class="o">+</span> <span class="n">c</span><span class="p">)</span> <span class="o">*</span> <span class="n">t</span> <span class="o">+</span> <span class="n">d</span>
        
        <span class="n">p</span> <span class="o">=</span> <span class="n">m</span><span class="p">.</span><span class="nf">max</span><span class="p">(</span><span class="o">-</span><span class="mi">2</span><span class="p">)</span>

        <span class="k">if</span> <span class="n">dim</span> <span class="ow">is</span> <span class="ow">not</span> <span class="bp">None</span><span class="p">:</span>
            <span class="n">indices</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="nf">ones</span><span class="p">(</span><span class="nf">len</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">coefficients</span><span class="p">.</span><span class="n">shape</span><span class="p">),</span> <span class="n">dtype</span> <span class="o">=</span> <span class="n">torch</span><span class="p">.</span><span class="n">int64</span><span class="p">)</span>
            <span class="n">indices</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">4</span>
            <span class="n">indices</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">coefficients</span><span class="p">.</span><span class="n">shape</span><span class="p">[</span><span class="n">dim</span><span class="p">]</span>
            <span class="n">abcd</span> <span class="o">=</span> <span class="n">self</span><span class="p">.</span><span class="n">coefficients</span><span class="p">.</span><span class="nf">gather</span><span class="p">(</span><span class="o">-</span><span class="mi">2</span><span class="p">,</span><span class="n">p</span><span class="p">.</span><span class="n">indices</span><span class="p">[...,</span><span class="bp">None</span><span class="p">,</span> <span class="bp">None</span><span class="p">].</span><span class="nf">expand</span><span class="p">(</span><span class="o">*</span><span class="n">indices</span><span class="p">))[...,</span> <span class="mi">0</span><span class="p">,</span> <span class="p">:]</span> 
            <span class="n">t</span> <span class="o">=</span> <span class="n">t</span><span class="p">[...,</span> <span class="mi">0</span><span class="p">].</span><span class="nf">gather</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span><span class="n">p</span><span class="p">.</span><span class="n">indices</span><span class="p">)</span>
            <span class="nf">return </span><span class="p">((</span><span class="n">abcd</span><span class="p">[...,</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">t</span> <span class="o">+</span> <span class="n">abcd</span><span class="p">[...,</span><span class="mi">1</span><span class="p">])</span> <span class="o">*</span> <span class="n">t</span> <span class="o">+</span> <span class="n">abcd</span><span class="p">[...,</span> <span class="mi">2</span><span class="p">])</span> <span class="o">*</span> <span class="n">t</span> <span class="o">+</span> <span class="n">abcd</span><span class="p">[...,</span> <span class="mi">3</span><span class="p">],</span> <span class="p">(</span><span class="n">t</span><span class="o">+</span><span class="n">p</span><span class="p">.</span><span class="n">indices</span><span class="p">)</span><span class="o">*</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="nb">max</span> <span class="o">-</span> <span class="n">self</span><span class="p">.</span><span class="nb">min</span><span class="p">)</span><span class="o">/</span><span class="n">self</span><span class="p">.</span><span class="n">num_segments</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="nb">min</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">return</span> <span class="n">p</span><span class="p">.</span><span class="n">values</span><span class="p">,</span> <span class="p">(</span><span class="n">t</span><span class="p">[...,</span> <span class="mi">0</span><span class="p">].</span><span class="nf">gather</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">p</span><span class="p">.</span><span class="n">indices</span><span class="p">))</span><span class="o">*</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="nb">max</span> <span class="o">-</span> <span class="n">self</span><span class="p">.</span><span class="nb">min</span><span class="p">)</span> <span class="o">+</span> <span class="n">self</span><span class="p">.</span><span class="nb">min</span>

    <span class="k">def</span> <span class="nf">mean</span><span class="p">(</span><span class="n">self</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">einops</span><span class="p">.</span><span class="nf">einsum</span><span class="p">(</span><span class="n">self</span><span class="p">.</span><span class="n">coefficients</span><span class="p">,</span> <span class="n">torch</span><span class="p">.</span><span class="nf">tensor</span><span class="p">([</span><span class="mi">1</span><span class="o">/</span><span class="mi">4</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="o">/</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span><span class="n">device</span><span class="o">=</span><span class="n">self</span><span class="p">.</span><span class="n">coefficients</span><span class="p">.</span><span class="n">device</span><span class="p">),</span> <span class="sh">'</span><span class="s">... s f, f -&gt; ...</span><span class="sh">'</span><span class="p">)</span> <span class="o">/</span> <span class="n">self</span><span class="p">.</span><span class="n">num_segments</span>

</code></pre></div></div>


  </div>
</details>

<h3 id="spline-formula">Non-equidistant Knot Case</h3>

<details class="collapse-section">
  <summary>Spline Formula</summary>
  <div class="collapse-content">
    
<p><a href="https://mathworld.wolfram.com/CubicSpline.html">WolframMathWorld</a> goes through the derivation of the spline coefficients when the knots are equally spaced, but I had a hard time finding the derivation for arbitrary knots.</p>

<p>The derivation is very similar, with only a few changes. First, we need to set up the problem.</p>

<p>Given a list of $ n+1 $ knots $(x_i, y_i)$, $ i = 0, \dots, n $, where $ 0 = x_0 &lt; x_1 &lt; … &lt; x_{n-1} &lt; x_n = 1$</p>

<p>We want to define the spline curve $S(x)$ where $ 0 \leq x \leq 1 $.</p>

<p>We define</p>

<ul>
  <li>
    <p>Segment lengths $ \Delta x_i = x_{i+1} - x_i $</p>
  </li>
  <li>
    <p>$ S(x) = S_i(t) = a_i t^3 + b_i t^2 + c_i t + d_i $ where $ x_i \leq x \leq x_{i+1}$, $t = \frac{x-x_i}{\Delta x_i}$</p>
  </li>
</ul>
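<p>To make the segment/local-variable setup concrete, here is a minimal NumPy sketch (the helper name is mine, not from the actual codebase) that evaluates $S(x)$ by locating the segment for each query point and applying Horner's rule in $t$, the same way the PyTorch code above does:</p>

```python
import numpy as np

def eval_spline(x_query, x, coeffs):
    """Evaluate the piecewise cubic S(x) given knot positions x and
    per-segment coefficients (a, b, c, d) in the local variable
    t = (x - x_i) / dx_i. Illustrative sketch only."""
    a, b, c, d = coeffs
    # find the segment index i with x_i <= x_query < x_{i+1}
    i = np.clip(np.searchsorted(x, x_query, side='right') - 1, 0, len(x) - 2)
    t = (x_query - x[i]) / (x[i + 1] - x[i])
    # Horner's rule: ((a t + b) t + c) t + d
    return ((a[i] * t + b[i]) * t + c[i]) * t + d[i]
```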

<p>Then</p>

\[\begin{align}
S_i(0) &amp; = y_i &amp; = d_i \\
S_i(1) &amp; = y_{i+1} &amp; = a_i + b_i + c_i + d_i \\
\end{align}\]

<p>When taking the derivative, we want it with respect to $x$ instead of $t$</p>

\[\begin{align}
S'_i(0) &amp; = D_i &amp; = \frac{c_i}{\Delta x_i} \\
S'_i(1) &amp; = D_{i+1} &amp; = \frac{(3a_i + 2b_i + c_i)}{\Delta x_i}
\end{align}\]

<p>Solving for the coefficients gives</p>

\[\begin{align}
d_i &amp; = y_i \\
c_i &amp; = D_i \Delta x_i \\
b_i &amp; = 3(y_{i+1} - y_i) - (2 D_i + D_{i+1}) \Delta x_i\\
a_i &amp; = 2(y_i - y_{i+1}) + (D_i + D_{i+1}) \Delta x_i
\end{align}\]

<p>We now impose the second-derivative conditions of a natural spline</p>

\[\begin{align}
S''_0(0)  = &amp; 0 \\
S''_{i-1}(1)  = &amp; S''_i(0) \\
S''_{n-1}(1)  = &amp; 0 
\end{align}\]

<p>Substituting the coefficients, we get</p>

\[3y_1 - 3y_0 = \Delta x_0 D_1  +2 \Delta x_0 D_0\]

\[\begin{multline}
 3 \frac{\Delta x_{i-1}}{\Delta x_i} y_{i+1} + 3( \frac{\Delta x_i}{\Delta x_{i-1}} - \frac{\Delta x_{i-1}}{\Delta x_i}) y_i - 3 \frac{\Delta x_i}{\Delta x_{i-1}} y_{i-1} \\
 =  \Delta x_{i-1} D_{i+1} + 2 ( \Delta x_i + \Delta x_{i-1}) D_i + \Delta x_i D_{i-1}
\end{multline}\]

\[3y_n - 3y_{n-1} = \Delta x_{n-1} D_{n-1} +2 \Delta x_{n-1} D_n\]

<p>This can be written in matrix form as</p>

\[\begin{split}
\begin{bmatrix}
2 \Delta x_0 &amp; \Delta x_0 &amp;  &amp;  &amp; \cdots \\
\Delta x_1 &amp; 2(\Delta x_0 + \Delta x_1) &amp; \Delta x_0 &amp;  &amp; \cdots \\
 &amp; \Delta x_2 &amp; 2(\Delta x_1 + \Delta x_2) &amp; \Delta x_1 &amp; \cdots \\
\vdots &amp; \vdots &amp; \vdots &amp; \vdots &amp; \ddots \\
 &amp; &amp; \Delta x_{n-1} &amp; 2(\Delta x_{n-2} + \Delta x_{n-1}) &amp; \Delta x_{n-2} \\
 &amp; &amp; &amp; \Delta x_{n-1} &amp; 2 \Delta x_{n-1}
\end{bmatrix}

\begin{bmatrix}
D_0 \\
D_1 \\
D_2 \\
\vdots \\
D_{n-1} \\
D_n
\end{bmatrix}
\\
= 3 
\begin{bmatrix}
y_1 - y_0 \\
\frac{\Delta x_0}{\Delta x_1} y_2 + ( \frac{\Delta x_1}{\Delta x_0} - \frac{\Delta x_0}{\Delta x_1}) y_1 - \frac{\Delta x_1}{\Delta x_0} y_0 \\
\frac{\Delta x_1}{\Delta x_2} y_3 + ( \frac{\Delta x_2}{\Delta x_1} - \frac{\Delta x_1}{\Delta x_2}) y_2 - \frac{\Delta x_2}{\Delta x_1} y_1 \\
\vdots \\
\frac{\Delta x_{n-2}}{\Delta x_{n-1}} y_n + ( \frac{\Delta x_{n-1}}{\Delta x_{n-2}} - \frac{\Delta x_{n-2}}{\Delta x_{n-1}}) y_{n-1} - \frac{\Delta x_{n-1}}{\Delta x_{n-2}} y_{n-2} \\
y_n - y_{n-1}
\end{bmatrix}
\end{split}\]
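
<p>Assembling and solving this system is straightforward. Below is an illustrative NumPy sketch (names are mine; a dense solve is used for clarity, though the matrix is tridiagonal and a banded solver would be more efficient):</p>

```python
import numpy as np

def natural_spline_slopes(x, y):
    """Solve the tridiagonal system above for the knot derivatives
    D_0..D_n of a natural cubic spline with arbitrary knot spacing."""
    n = len(x) - 1                      # knots x_0..x_n, segments 0..n-1
    dx = np.diff(x)
    A = np.zeros((n + 1, n + 1))
    r = np.zeros(n + 1)
    A[0, 0], A[0, 1] = 2 * dx[0], dx[0]          # S''(x_0) = 0
    r[0] = 3 * (y[1] - y[0])
    A[n, n - 1], A[n, n] = dx[-1], 2 * dx[-1]    # S''(x_n) = 0
    r[n] = 3 * (y[n] - y[n - 1])
    for i in range(1, n):                        # C2 continuity at x_i
        A[i, i - 1] = dx[i]
        A[i, i] = 2 * (dx[i - 1] + dx[i])
        A[i, i + 1] = dx[i - 1]
        r[i] = 3 * (dx[i - 1] / dx[i] * y[i + 1]
                    + (dx[i] / dx[i - 1] - dx[i - 1] / dx[i]) * y[i]
                    - dx[i] / dx[i - 1] * y[i - 1])
    return np.linalg.solve(A, r)
```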


  </div>
</details>

<h3 id="spline-mean">Mean Value of the Cubic Spline</h3>

<details class="collapse-section">
  <summary>Spline Mean</summary>
  <div class="collapse-content">
    
\[\begin{align}
\mu &amp; = \int_{\min}^{\max} \frac{S(x)}{(\max - \min)} dx\\
 &amp; = \frac{1}{x_n-x_0}\sum_{i=0}^{n-1} \int_{x_i}^{x_{i+1}} (a_i t^3 + b_i t^2 + c_i t + d_i) dx &amp; \text{where } &amp; t = \frac{x-x_i}{x_{i+1} - x_i} \text{,} \\ &amp; &amp; &amp; dx = (x_{i+1}-x_i) dt\\
 &amp; = \frac{1}{x_n-x_0}\sum_{i=0}^{n-1} \int_{0}^{1} (a_i t^3 + b_i t^2 + c_i t + d_i) (x_{i+1} - x_i) dt \\
 &amp; = \frac{1}{x_n-x_0}\sum_{i=0}^{n-1} \left[\frac{a_i t^4}{4} + \frac{b_i t^3}{3} + \frac{c_i t^2}{2} + d_i t \right]^1_0 (x_{i+1} - x_i) \\
 &amp; = \frac{1}{x_n-x_0}\sum_{i=0}^{n-1} (\frac{a_i}{4} + \frac{b_i}{3} + \frac{c_i}{2} + d_i) (x_{i+1} - x_i) \\
 &amp; = \frac{ \left[\frac{1}{4}, \frac{1}{3}, \frac{1}{2}, 1 \right]_j \left[ a, b, c, d \right]_{ij}  \Delta x_i}{x_n - x_0}
\end{align}\]
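
<p>The last line is the weighted contraction the <code>mean</code> method implements. As a standalone sanity check, here is a sketch (helper name is mine) of the closed form, verified against brute-force quadrature of each segment:</p>

```python
import numpy as np

def spline_mean(a, b, c, d, x):
    """Closed-form mean of a piecewise cubic over [x_0, x_n]:
    sum_i (a_i/4 + b_i/3 + c_i/2 + d_i) dx_i / (x_n - x_0)."""
    dx = np.diff(x)
    return ((a / 4 + b / 3 + c / 2 + d) * dx).sum() / (x[-1] - x[0])
```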


  </div>
</details>]]></content><author><name>Kazumi</name></author><category term="reinforcement-learning" /><category term="dqn" /></entry></feed>