Jekyll2023-12-12T23:26:11+00:00http://willbeckman.com/feed.xmlWill Beckman blogPosting publicly, but mostly archiving for my future self to revisit.Wordle, Entropy, and… Decision Trees?2023-12-12T16:00:00+00:002023-12-12T16:00:00+00:00http://willbeckman.com/wordle<center>
<p><img height="200px" src="assets/posts/2023-12-12-wordle/title.png" /></p>
</center>
<p style="text-align: center;">The code associated with this blog post can be found
<a target="_blank" href="https://github.com/wbeckman/wordlebot">here</a>.</p>
<h2 id="table-of-contents-">Table of Contents <!-- omit in toc --></h2>
<ul>
<li><a href="#intro">Intro</a></li>
<li><a href="#the-game-of-wordle">The Game of Wordle</a></li>
<li><a href="#strategy">Strategy</a>
<ul>
<li><a href="#strategy-overview">Strategy Overview</a></li>
<li><a href="#enter-entropy">Enter: Entropy</a></li>
<li><a href="#what-is-entropy">What Is Entropy?</a></li>
</ul>
</li>
<li><a href="#practical-matters">Practical Matters</a>
<ul>
<li><a href="#choosing-word-lists">Choosing Word Lists</a></li>
<li><a href="#writing-and-assessing-an-entropy-based-solver">Writing and Assessing an Entropy-Based Solver</a></li>
<li><a href="#optimality-worst-case-scenarios">Optimality, Worst Case Scenarios</a></li>
<li><a href="#wordlebot">Wordlebot</a></li>
</ul>
</li>
<li><a href="#wait-but-what-about-decision-trees">Wait, But What About Decision Trees?</a></li>
<li><a href="#conclusion">Conclusion</a></li>
</ul>
<h2 id="intro">Intro</h2>
<p>I’m a little late to the party, but I wanted to try my hand at creating a Wordle solver and see how good I could make it. It turns out, if you exploit some information from the New York Times list of allowed words, you can make a solver that is very good. Let’s dive in!</p>
<h2 id="the-game-of-wordle">The Game of Wordle</h2>
<p>If you’re like me and you play Wordle every day, you can skip this part, because it’s just an overview of how the game works.</p>
<p>The setup is as follows: there is a secret 5-letter word chosen by the New York Times (NYT). Your goal is to guess that word within six guesses. Each guess gives you a piece of feedback on how correct that word is in the form of “blocks”:</p>
<ul>
<li>A black block (⬛) means that the letter is NOT contained in the hidden word.</li>
<li>A yellow block (🟨) means that the letter is contained in the hidden word, but the letter’s position in the hidden word differs from our guess.</li>
<li>A green block (🟩) means that the letter is contained in the hidden word in the SAME position as our guess.</li>
</ul>
<center>
<p><img height="200px" src="assets/posts/2023-12-12-wordle/wordle_example.png" /></p>
</center>
<p>In the example above, my first two guesses tell me that the letters “S”, “L”, and “E” are in the target word (i.e. “SOLVE”), but not in the positions I guessed them. With my final guess, there were two possible words left: SMILE and SOLVE, and I got lucky.</p>
<p>There are a few final technicalities. If the same letter appears twice in a guess but only once in the hidden word, one occurrence will be yellow/green and the other will be black. This is shown in the guess ‘TREES’ - the third-position E is yellow and the fourth-position E is black, because there is only one ‘E’ in the target word (SOLVE).</p>
<p>Finally, green squares will ALWAYS resolve before yellow and black squares. For example, if the hidden word were “CRIER” instead of “SOLVE”, the clue for ‘TREES’ would be given as ⬛🟩⬛🟩⬛ rather than ⬛🟩🟨⬛⬛. Even though the third-position ‘E’ comes before the fourth-position ‘E’, the fourth-position ‘E’ is green, and therefore, takes precedence.</p>
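<p>For concreteness, the scoring rules above can be implemented in a short function. This is my own sketch (not the NYT’s implementation); patterns are encoded as strings of ‘G’, ‘Y’, and ‘B’ for green, yellow, and black:</p>

```python
from collections import Counter

def feedback(guess, hidden):
    """Score a guess against a hidden word: greens resolve first,
    then yellows are assigned left to right from the leftover letters."""
    result = ["B"] * 5
    leftover = Counter()
    # First pass: greens consume their letters from the hidden word.
    for i, (g, h) in enumerate(zip(guess, hidden)):
        if g == h:
            result[i] = "G"
        else:
            leftover[h] += 1
    # Second pass: yellows, left to right, only while unmatched letters remain.
    for i, g in enumerate(guess):
        if result[i] == "B" and leftover[g] > 0:
            result[i] = "Y"
            leftover[g] -= 1
    return "".join(result)

print(feedback("TREES", "SOLVE"))  # BBYBY: one E is yellow, the duplicate is black
print(feedback("TREES", "CRIER"))  # BGBGB: the green fourth-position E takes precedence
```

<p>Note that both rules fall out of the two-pass structure: greens are claimed first, and each hidden letter can only be “used up” once.</p>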
<p>If the rules above are unclear, <a href="https://www.nytimes.com/games/wordle/index.html" target="_blank">playing the game</a> on the New York Times’ website will certainly elucidate the rules for you. I learn by doing, and this is especially true with games for me.</p>
<h2 id="strategy">Strategy</h2>
<h3 id="strategy-overview">Strategy Overview</h3>
<p>My Wordle solver is based on the concept of <strong>entropy</strong>. To understand entropy, we are going to explore it using a vocabulary of just five carefully chosen words. These words are: <strong>BILLS, HILLS, MILLS, THUMB,</strong> and <strong>TILLS</strong>. Let’s pretend that our first guess was TILLS. We see the following result:</p>
<center>
<p><img height="50px" src="assets/posts/2023-12-12-wordle/tills.png" /></p>
</center>
<p>At first, this seems pretty nice - we got four characters correct, and it looks like we almost solved the puzzle. But if you look closer, you can see there are still three words in our vocabulary that match this feedback: BILLS, MILLS, and HILLS. Therefore, in the worst case, we are still making four total guesses to determine the solution.</p>
<p>But what if we were to guess THUMB instead? Pretend we guess THUMB and we are given the following feedback:</p>
<center>
<p><img height="50px" src="assets/posts/2023-12-12-wordle/thumb.png" /></p>
</center>
<p>By the colors alone, it might seem worse than TILLS - only one of the letters is in the hidden word. But actually, this guess solves the puzzle for us in the next turn. No other word contains an “H”, so this gives away that HILLS is the correct answer.</p>
<p>So, what we care about for our strategy is - <strong>given a guess word, across ALL possible pieces of feedback the puzzle could give us, how much does this guess reduce the search space on average</strong>? It seems THUMB is better at this than TILLS, but how can we show this quantitatively?</p>
<h3 id="enter-entropy">Enter: Entropy</h3>
<p>Keeping with our vocabulary from before, we can formalize <em>how good each guess is by how much it eliminates other words in the vocabulary on average</em>, assuming all words are equally likely to be the hidden word. “HILLS”, “TILLS”, “MILLS”, and “BILLS” are all equally good guesses, so we will just look at “THUMB” and “TILLS” as examples.</p>
<p>If we guess “THUMB”, what are the possible patterns we could see? We can enumerate them here:</p>
<ul>
<li>If we see: 🟩🟩🟩🟩🟩 - the hidden word is “THUMB” (probability $\frac{1}{5}$)</li>
<li>If we see: 🟩⬛⬛⬛⬛ - the hidden word is “TILLS” (probability $\frac{1}{5}$)</li>
<li>If we see: ⬛🟨⬛⬛⬛ - the hidden word is “HILLS” (probability $\frac{1}{5}$)</li>
<li>If we see: ⬛⬛⬛🟨⬛ - the hidden word is “MILLS” (probability $\frac{1}{5}$)</li>
<li>If we see: ⬛⬛⬛⬛🟨 - the hidden word is “BILLS” (probability $\frac{1}{5}$)</li>
</ul>
<p>So, in the WORST case, we solve the puzzle in two tries guessing “THUMB” as our first guess. Let’s look at the same breakdown if we guess “TILLS”:</p>
<ul>
<li>If we see: 🟩🟩🟩🟩🟩 - the hidden word is “TILLS” (probability $\frac{1}{5}$)</li>
<li>If we see: 🟩⬛⬛⬛⬛ - the hidden word is “THUMB” (probability $\frac{1}{5}$)</li>
<li>If we see: ⬛🟩🟩🟩🟩 - the hidden word is in [“HILLS”, “MILLS”, “BILLS”] (probability $\frac{3}{5}$)</li>
</ul>
<p>So, if we guess “TILLS”, there is a bucket with more than one word that we will hit with 60% probability, meaning we are not guaranteed to solve the puzzle in two turns. If we iterate over all words in our vocabulary and <em>bucket all possible hidden words into their corresponding patterns</em> (an $O(|vocab|^2)$ runtime operation), this prepares us to compute a value called <strong>entropy</strong>. The formula for entropy is:</p>
<center><p>$
Entropy(word)=-\sum_{i}{p(i) * \log_2{p(i)}}
$</p></center>
<p>The ‘$i$’ in the above equation for us is ‘patterns’. So, <strong>for each possible pattern, we take the probability of seeing that pattern ($p(i)$), and multiply it by $log_2{p(i)}$</strong>. Then, this value is summed for all possible patterns and multiplied by $-1$. The entropy for “THUMB” is then:</p>
<center><p>$
Entropy(THUMB)=-\sum_{i=1}^5{\frac{1}{5} * \log_2{\frac{1}{5}}} \approx 2.32
$</p></center>
<p>Now, computing the entropy for TILLS:</p>
<center><p>$
Entropy(TILLS)=-(\frac{1}{5} * \log_2{\frac{1}{5}} + \frac{1}{5} * \log_2{\frac{1}{5}} + \frac{3}{5} * \log_2{\frac{3}{5}}) \approx 1.37
$</p></center>
<p>And we can see that, from an entropy perspective, THUMB is a quantifiably better guess than TILLS. We are effectively <strong>turning each word into a discrete probability distribution over all possible patterns that we could see</strong>, and then <strong>computing an information index (entropy) for that probability distribution</strong>. This translation of word → probability distribution is an important concept that we will discuss in the next section.</p>
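<p>The two calculations above are easy to reproduce in a few lines of Python (a sketch; the input distributions are just the pattern buckets we enumerated):</p>

```python
from math import log2

def entropy(probs):
    """Shannon entropy, in bits, of a discrete probability distribution."""
    return -sum(p * log2(p) for p in probs if p > 0)

print(round(entropy([1/5] * 5), 2))        # 2.32 - THUMB: five equally likely patterns
print(round(entropy([1/5, 1/5, 3/5]), 2))  # 1.37 - TILLS: one bucket holds three words
```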
<h3 id="what-is-entropy">What Is Entropy?</h3>
<p>I mentioned how we calculate entropy, but I didn’t really discuss what it is or why it’s useful. Entropy is an attempt to measure “disorder” in a probability distribution. Probability distributions with high entropies will split <em>more events more evenly</em>. <a href="https://en.wikipedia.org/wiki/Claude_Shannon" target="_blank">Claude Shannon</a> is the one who coined entropy, and he wanted to describe “uncertainty” in a probability distribution in terms of <strong>bits</strong>.</p>
<p>As an extremely simple example, we can consider a fair coin ($P(heads) = P(tails) = \frac{1}{2}$) and a fair 4-sided die ($P(1) = P(2) = P(3) = P(4) = \frac{1}{4}$). For this example, let’s pretend we can observe the result of a coin flip and a die roll and <em>we need to transmit the result to a person that we cannot see using a binary encoding</em>. We can encode the coin flip in a single bit - literally sending <code class="language-plaintext highlighter-rouge">0</code> or <code class="language-plaintext highlighter-rouge">1</code> based on whether we observe heads or tails. The die, however, must be encoded with two bits - the four results must be encoded as <code class="language-plaintext highlighter-rouge">00</code>, <code class="language-plaintext highlighter-rouge">01</code>, <code class="language-plaintext highlighter-rouge">10</code>, and <code class="language-plaintext highlighter-rouge">11</code>.</p>
<p>So, there is more information present in the die roll, because the encoded message has to be larger. <em>Entropy quantifies the amount of information present in both of these events</em>. This is abstract! We can calculate the entropy of the coin flip and the die roll using our equation above:</p>
<center><p>
$Entropy(COIN)=-\sum_{i=1}^2{\frac{1}{2} * \log_2{\frac{1}{2}}}=1.0$
</p></center>
<center><p>
$Entropy(4DIE)=-\sum_{i=1}^4{\frac{1}{4} * \log_2{\frac{1}{4}}}=2.0$
</p></center>
<p>As you can see, <strong>the entropy values of each random variable perfectly represent the number of bits needed to send the messages</strong>. (This is a bit of a contrived example - in general, we have to round up to truly encode the information. A heavily biased 3-sided die with $P(1)=0.98$, $P(2)=0.01$, and $P(3)=0.01$ has an entropy of only about 0.16 bits, but we still need at least one whole bit to transmit the unlikely faces when they do occur.)</p>
<p>Another interpretation of entropy is as “surprise”. If we are more surprised (i.e. <em>probability distribution is close to uniform</em>), entropy is usually higher and we need more bits to represent the information in the probability distribution. If a small number of events happen very frequently (such as with our biased die), we are less surprised, and the entropy of the distribution is lower. So, in the context of our Wordle words, <strong>words with high average entropies are words that are likely to have a higher variety of possible patterns with a relatively even distribution among them</strong>.</p>
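<p>As a quick numerical sanity check of the coin and die examples (repeating the entropy helper so the snippet stands alone):</p>

```python
from math import log2

def entropy(probs):
    # Shannon entropy, in bits, of a discrete distribution.
    return -sum(p * log2(p) for p in probs if p > 0)

print(entropy([1/2, 1/2]))                    # 1.0 bit: a fair coin
print(entropy([1/4] * 4))                     # 2.0 bits: a fair 4-sided die
print(round(entropy([0.98, 0.01, 0.01]), 2))  # 0.16 bits: a heavily biased 3-sided die
```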
<p>Looking ahead for a second, I calculated the pattern probability distributions for the best word - ‘TARSE’ and the worst word - ‘QAJAQ’. The sum of words in all buckets for each word is 2309 - the number of possible hidden words in the NYT solutions list. You can see how massively these distributions differ (particularly, note the scale on the y-axis)!</p>
<center>
<p><img width="600px" src="assets/posts/2023-12-12-wordle/tarse_distributions.png" /></p>
<p><img width="600px" src="assets/posts/2023-12-12-wordle/qajaq_distributions.png" /></p>
</center>
<p>You can interpret this as - after guessing ‘QAJAQ’, it’s extremely likely (roughly 60% chance) that the hidden word will give us a feedback of ⬛⬛⬛⬛⬛. This would reduce the initial state space of 2309 words to roughly $P(⬛⬛⬛⬛⬛) * 2309\approx 0.6 * 2309\approx1385$ words. This is bad! This means that much of the time, our state space will not even be halved if we guess ‘QAJAQ’. This makes sense intuitively - ‘J’, and ‘Q’ are very uncommon letters in the English language, so it makes sense that this is a bad word. On the other hand, even the unluckiest pattern for ‘TARSE’ decreases the state space by a factor of roughly 10x to $\approx 0.1* 2309\approx230.9$, and there are many more patterns that decrease the state space by an even larger factor.</p>
<p>So, because TARSE has so many more possible patterns (147) and has a relatively even distribution amongst them, it has an entropy of 5.95. Similarly, because QAJAQ is overwhelmingly likely to yield ⬛⬛⬛⬛⬛ as a response from the puzzle, and the ‘⬛⬛⬛⬛⬛’ bucket is much larger than all of the other ones, it has a correspondingly low entropy - 1.89. Believe it or not, this concept of computing entropy based on word-pattern buckets is all you need to know to make a Wordle solver!</p>
<h2 id="practical-matters">Practical Matters</h2>
<h3 id="choosing-word-lists">Choosing Word Lists</h3>
<p>Before going into the details of the algorithm, we need to know what our vocabulary is. There are a few well-accepted word lists that the NYT uses for Wordle. First, there is a list of words that they will <strong>allow you to guess, but will never be the solution</strong>. These are typically less common words and plural forms of 4-letter singular words. Then, there is a <strong>shorter list of words that could possibly be the hidden word</strong>. <strong>This word list is a subset of the larger word list</strong>. These lists have changed sizes over time. Most notably, the list of possible guess words has expanded in recent years:</p>
<table>
<thead>
<tr>
<th> </th>
<th>All possible guesses</th>
<th>All possible hidden words</th>
</tr>
</thead>
<tbody>
<tr>
<td>Number of Words - Before 08/29/2022</td>
<td>12,972</td>
<td>2,309</td>
</tr>
<tr>
<td>Number of Words - After 08/29/2022</td>
<td>14,885</td>
<td>2,309</td>
</tr>
</tbody>
</table>
<p>(This information was taken from a <a href="https://sonorouschocolate.com/notes/index.php/The_best_strategies_for_Wordle,_part_2" target="_blank">great blog post by Alex Selby</a>).</p>
<p>I chose to use the more recent word lists in my solver. I computed entropy for ALL possible guess words based on which hidden words were still viable.</p>
<h3 id="writing-and-assessing-an-entropy-based-solver">Writing and Assessing an Entropy-Based Solver</h3>
<p>The process for writing the solver is identical to the process I described above with our small vocabulary, only now, there are far more words. I computed the entropy of all 14,885 words in the dictionary against all 2,309 words in the NYT hidden words list. Because the lists are static, the best first word is always fixed. From an entropy perspective, I mentioned the best first word is always “<a href="https://en.wiktionary.org/wiki/tarse" target="_blank">TARSE</a>”, with an entropy value of 5.95. If you’ve never heard of the word ‘TARSE’, you’re not alone, but Wordle allows you to guess it. Other good options are ‘TIARE’ with an entropy of 5.93, ‘SOARE’ with an entropy of 5.89, ‘ROATE’ with an entropy of 5.88, and perhaps the first “real” word in the top 5 - ‘RAISE’ with an entropy of 5.88.</p>
<p>Once the first guess is made, the puzzle gives you a piece of feedback. Then, we prune the state space of 2309 words based on which words could possibly still be the solution, based on the feedback we got. Then, we recompute entropy for the new subset of words and choose the next best word. This process is then repeated until the puzzle is solved.</p>
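<p>One round of this prune-and-rescore loop can be sketched as follows, reusing the toy vocabulary from earlier. The pattern function here is deliberately simplified (it ignores the repeated-letter edge cases, which never arise among these five words):</p>

```python
from collections import Counter
from math import log2

def pattern(guess, hidden):
    # Simplified scoring: exact match -> G, elsewhere in the word -> Y, else B.
    return "".join("G" if g == h else ("Y" if g in hidden else "B")
                   for g, h in zip(guess, hidden))

def guess_entropy(guess, candidates):
    # Bucket every candidate hidden word by the pattern it would produce,
    # then compute the entropy of the bucket-size distribution.
    buckets = Counter(pattern(guess, h) for h in candidates)
    n = len(candidates)
    return -sum((c / n) * log2(c / n) for c in buckets.values())

def step(candidates, guess, observed):
    """Prune to words consistent with the feedback, then pick the next guess."""
    remaining = [w for w in candidates if pattern(guess, w) == observed]
    best = max(remaining, key=lambda w: guess_entropy(w, remaining))
    return remaining, best

vocab = ["BILLS", "HILLS", "MILLS", "THUMB", "TILLS"]
remaining, nxt = step(vocab, "TILLS", "BGGGG")
print(remaining)  # ['BILLS', 'HILLS', 'MILLS']
```

<p>The real solver does exactly this, just with the full 14,885-word guess list and pruning over the 2,309 possible hidden words.</p>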
<p>On the list of 2309 NYT hidden words, the solver takes 3.61 guesses on average and never needs more than 6 guesses to find the hidden word. This is sort of a “cheat mode”, because it assumes we have access to the hidden words list AND the “guessable” words list. This performance is likely overfit to the choice of words in each list. However, if you gave a person access to these two lists while they were playing, it’s unlikely it would change their play style much at all.</p>
<h3 id="optimality-worst-case-scenarios">Optimality, Worst Case Scenarios</h3>
<p>It’s also worth noting that this entropy-based strategy <em>is not optimal</em>! It is guaranteed to reduce the guess space as much as it possibly can on average with each guess, but <strong>it does not guarantee future word groups will have good splits, on average</strong>. In other words, <strong>this strategy is a greedy strategy</strong>, based on the heuristic of entropy. While it’s not optimal, it likely outperforms the average human player. This is also probably why my solver chooses ‘TARSE’ as the best word, while the NYT bot chooses ‘CRANE’ as the best opening word - I imagine they put a little more computational power into looking a few steps ahead. Some people<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">1</a></sup> have also put a lot of compute hours into finding a tree that is optimal for Wordle play.</p>
<p>Another interesting tidbit I came across in doing research for this blog post is a worst-case Wordle puzzle proof. I found a <a href="https://alexpeattie.com/blog/establishing-minimum-guesses-wordle/" target="_blank">fantastic blog post by Alex Peattie</a> detailing the minimum number of guesses needed to solve Wordle, in the worst case (i.e. you are as unlucky as you possibly could be). Rather than take a brute force computational approach, he describes a game of ‘ILLS’ Wordle, where all of the possible solution words end in ‘-ILLS’. There are 19 such words in the NYT allowable vocabulary list: {b,c,d,f,g,h,j,k,l,m,n,p,r,s,t,v,w,y,z}-ills. His blog post proves that as long as the vocabulary is limited to the 12,972 words listed on the NYT front-end, and that all of the ‘-ILLS’ words are contained in the solution set, in the worst case, the minimum number of guesses required is 6. This result holds for an <em>expanded game of Wordle</em>, where the hidden word set is the same as the 12,972 possible guess words.</p>
<h3 id="wordlebot">Wordlebot</h3>
<p>This is just a plug to a CLI tool that I built to help solve Wordle puzzles, if you feel like cheating on a given day ;). You can go to the <a href="https://github.com/wbeckman/wordlebot" target="_blank">wordlebot</a> Github repo, clone it, and run <code class="language-plaintext highlighter-rouge">pip install .</code></p>
<p>This gives you a CLI utility that you can run by typing <code class="language-plaintext highlighter-rouge">wordlebot</code>. If you want to see what it would guess in a given situation without ruining the Wordle puzzle of the day, <a href="https://wordleunlimited.org/" target="_blank">Wordle Unlimited</a> shares a very similar vocabulary list with <code class="language-plaintext highlighter-rouge">wordlebot</code> (although, because Wordle Unlimited uses the pre-08/29/2022 vocabulary list of 12,972 words, ‘TARSE’ is not in its allowed guesses list and you might have to open with ‘SOARE’ instead). If multiple words are as good as each other, it will recommend a word that could also be the solution.</p>
<h2 id="wait-but-what-about-decision-trees">Wait, But What About Decision Trees?</h2>
<p>Finally, to <em>really</em> drill the concept of entropy home, we are going to talk about its role in decision trees. As it turns out, <em>the decision tree training process uses entropy to determine the next best split!</em> Specifically, it computes the entropy of the current dataset and then compares it to what the average entropy would be between the resultant children datasets after a split. If the average entropy of the children is LOWER, we prefer the split. The optimal split is the split that provides us with the <em>lowest average child entropy</em>. The difference between the parent dataset’s entropy and the weighted entropy of the children datasets is a metric called <strong>information gain</strong>. The formula for information gain is as follows:</p>
<center>
$Gain(feature)=Entropy(Parent)-WeightedEntropy(Children)$
</center>
<p>Let’s look at a small dataset to exemplify this. This is a binary classification dataset I generated with a single informative feature, ‘X1’.</p>
<center>
<p><img width="600px" src="assets/posts/2023-12-12-wordle/toy_example.png" /></p>
</center>
<p>In this case, we calculate entropy by looking at the probability of the <em>classes</em> (rather than the patterns, like before). For example, in this dataset, there are 50 positive class instances and 50 negative class instances. The entropy of the <em>parent dataset</em> is exactly $-(\frac{50}{100} * \log_2{\frac{50}{100}} + \frac{50}{100} * \log_2{\frac{50}{100}})=1.0$.</p>
<p>Now, pretend I were to make a split at <code class="language-plaintext highlighter-rouge">X1=0.0</code>. This partitions the dataset into two smaller datasets - the left half with 58 instances and the right half with 42 instances.</p>
<center>
<p><img width="600px" src="assets/posts/2023-12-12-wordle/example_partition.png" /></p>
</center>
<p>When we calculate the <em>weighted entropy</em> of each of those halves, we get</p>
<center>
$WeightedEntropy(LeftSide)=-\frac{58}{100} \times (\frac{9}{58} * \log_2{\frac{9}{58}} + \frac{49}{58} \times \log_2{\frac{49}{58}})\approx 0.36$
</center>
<center>
$WeightedEntropy(RightSide)=-\frac{42}{100} \times (\frac{41}{42} * \log_2{\frac{41}{42}} + \frac{1}{42} * \log_2{\frac{1}{42}})\approx 0.068$
</center>
<center>
$WeightedEntropy(Children)=0.36+0.068\approx 0.428$
</center>
<p>This weighted average is lower than the parent entropy! So, to get our value for information gain, we compute:</p>
<center>
$Gain(feature)=Entropy(Parent)-WeightedEntropy(Children)=1.0-0.428=0.572$.
</center>
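<p>These numbers are easy to verify with a short calculation (a sketch using the class counts read off the figure):</p>

```python
from math import log2

def entropy(counts):
    """Entropy, in bits, of a dataset given its per-class counts."""
    total = sum(counts)
    return -sum((c / total) * log2(c / total) for c in counts if c)

parent = entropy([50, 50])              # 1.0
left, right = [9, 49], [41, 1]          # class counts on each side of X1 = 0.0
n = 100
weighted = sum(left) / n * entropy(left) + sum(right) / n * entropy(right)
print(round(parent, 3), round(weighted, 3), round(parent - weighted, 3))
```

<p>(Summing before rounding gives 0.429 and 0.571, a hair off from the 0.428 and 0.572 above, which round the two child terms first.)</p>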
<p>This split is definitely an improvement over the original dataset, but can we do better? The way a decision tree finds the best split is by iterating over each possible value of X1 and calculating its information gain. This is a visualization of that process:</p>
<center>
<video loop="" autoplay="">
<source src="assets/posts/2023-12-12-wordle/best_splits.mp4" />
</video>
</center>
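<p>The brute-force threshold search shown in the animation can be sketched like this. The data here is a tiny, perfectly separable toy set of my own, so the gain comes out to a full 1.0 bit rather than the figure from the post’s dataset:</p>

```python
from math import log2

def entropy(labels):
    # Entropy, in bits, of a list of class labels.
    n = len(labels)
    if n == 0:
        return 0.0
    return -sum((labels.count(c) / n) * log2(labels.count(c) / n)
                for c in set(labels))

def best_split(xs, ys):
    """Try every observed value as a threshold; keep the highest-gain one."""
    parent = entropy(ys)
    best_t, best_gain = None, -1.0
    for t in sorted(set(xs)):
        left = [y for x, y in zip(xs, ys) if x <= t]
        right = [y for x, y in zip(xs, ys) if x > t]
        weighted = (len(left) * entropy(left) + len(right) * entropy(right)) / len(ys)
        if parent - weighted > best_gain:
            best_t, best_gain = t, parent - weighted
    return best_t, best_gain

xs = [1, 2, 3, 4, 5, 6]
ys = [0, 0, 0, 1, 1, 1]
print(best_split(xs, ys))  # (3, 1.0)
```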
<p>It looks like the optimal split is actually very close to 0.0 - at 0.04, giving us an information gain of 0.64! In practice, the <a href="https://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity" target="_blank">Gini impurity</a> is a computationally fast way to compute an entropy-like value. Instead of $-\sum_i{\log_2{p(i)} * p(i)}$, Gini impurity computes $1-\sum_i{p(i)^2}$. Avoiding this <code class="language-plaintext highlighter-rouge">log</code> can speed up decision tree training for a large number of features.</p>
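<p>For comparison, a minimal Gini impurity function looks like this (a sketch):</p>

```python
def gini(probs):
    """Gini impurity: 1 minus the sum of squared class probabilities."""
    return 1 - sum(p * p for p in probs)

# Like entropy, it is maximal for a 50/50 node and zero for a pure one:
print(gini([0.5, 0.5]), gini([1.0, 0.0]))  # 0.5 0.0
```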
<h2 id="conclusion">Conclusion</h2>
<p>Thanks for making it this far. Hopefully this has been an educational journey on Wordle solvers, decision trees, and how the two relate. If you notice any errors in the article, don’t hesitate to reach out to me at will@willbeckman.com.</p>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:1" role="doc-endnote">
<p><a href="https://www.poirrier.ca/notes/wordle/" target="_blank">Laurent Poirrier’s blog detailing all the recent Wordle research</a> <a href="#fnref:1" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>Productionizing ML Models - Part 12023-09-26T16:00:00+00:002023-09-26T16:00:00+00:00http://willbeckman.com/flask-api<center>
<p><img height="200px" src="assets/posts/2023-09-26-flask-api/agaricus-campestris.jpg" /></p>
<p><em>Delicious supplement to dinner? Or deadly poison?<br />(image from <a href="https://en.wikipedia.org/wiki/Agaricus#/media/File:Agaricus-campestris-michoacan.jpg">Wikipedia</a>)</em></p>
</center>
<p style="text-align: center;">The code associated with this blog post can be found
<a target="_blank" href="https://github.com/wbeckman/flask-api">here</a>.</p>
<h2 id="table-of-contents-">Table of Contents <!-- omit in toc --></h2>
<ul>
<li><a href="#background">Background</a></li>
<li><a href="#training-a-model">Training a Model</a></li>
<li><a href="#setting-up-a-flask-scoring-server">Setting up a Flask Scoring Server</a>
<ul>
<li><a href="#creating-a-post-endpoint">Creating a POST Endpoint</a></li>
<li><a href="#model-validations">Model Validations</a></li>
<li><a href="#creating-a-feature-vector">Creating a Feature Vector</a></li>
<li><a href="#scoring-with-a-model-object">Scoring With a Model Object</a></li>
</ul>
</li>
<li><a href="#conclusion">Conclusion</a></li>
</ul>
<h2 id="background">Background</h2>
<p>Most of the time, ML models are objects that contain information about how to make predictions within a very specific problem context. Because the knowledge of how to make a decision is built into the object itself through the process of <em>model training</em>, they often do not require an additional database to consult when making decisions. This unique attribute of ML models makes model hosting servers simpler than many other applications. Typically, with the exception of extremely large models, a model can live in-memory on a server/within a container. This post will go over how easy it is to set up a simple API using flask to host your ML models, and then a future post will cover how to scale up this API by dockerizing it and deploying it to a container hosting service.</p>
<p>One caveat for this tutorial is that it assumes all of the data for model scoring can be POST’d to your API, or can easily be obtained with a quick pre-processing step requiring no outside information (min-max scaling, normalization, binning, etc.). This may not be the case, depending on your application. For example, your company may require you to retrieve information about an entity (e.g. a customer) and only pass you the unique identifier of the entity (e.g. customer ID). This blog post also assumes that all requests coming in to the server should be trusted (e.g. this server sits inside of a private VPC). Because of this, there will be no authentication checks for the incoming requests, but it would be easy enough to check if incoming requests are properly authorized with token-based authentication.</p>
<h2 id="training-a-model">Training a Model</h2>
<p>The first thing that we have to do in order to serve a model is to train a model. I trained a model for the UCI Mushroom dataset. If you know me, you know I am an avid mushroom forager, so what would be better than using a dataset with personal meaning? This dataset contains a number of categorical variables describing qualitative features of the mushrooms and a target variable denoting “poisonous” (1) or “edible” (0). The mushroom samples are all limited to the genera <em>Lepiota</em> and <em>Agaricus</em>. This should go without saying, but please don’t use this dataset to assess whether or not a mushroom is edible :).</p>
<center>
<p>
<img src="assets/posts/2023-09-26-flask-api/reddening_lepiota.jpg" height="200px" alt="" />
<img src="assets/posts/2023-09-26-flask-api/lepiota_brunneoincarnata.jpg" height="200px" alt="" />
</p>
<p>
<em>Two mushrooms of the Lepiota genus. The reddening Lepiota on the left/top is an edible mushroom. The other mushroom is the deadly poisonous Lepiota brunneoincarnata! (images from <a href="https://en.wikipedia.org/wiki/Lepiota_brunneoincarnata">Wikipedia</a> and <a href="https://mdc.mo.gov/discover-nature/field-guide/reddening-lepiota">MO Dept. of Conservation</a>)</em>
</p>
</center>
<p>Because all of the input variables are categorical, I chose to use a simple decision tree model, which lends itself to these types of features. I one-hot encoded the categorical features as input to the model and filtered out a number of categories based on feature importance, which left me with 6 features. Not a lot of time went into fine-tuning the model, but the model achieves an accuracy score of >99.6% on the holdout set (which is still not good enough to say for sure whether or not you should eat potentially lethal mushrooms!). These are the six relevant categorical features in our model:</p>
<ul>
<li>odor == None (n)? [N(0)/Y(1)]</li>
<li>stalk_root == club (c)? [N(0)/Y(1)]</li>
<li>stalk_surface_below_ring == scaly (y)? [N(0)/Y(1)]</li>
<li>spore_print_color == green (r)? [N(0)/Y(1)]</li>
<li>odor == almond (a)? [N(0)/Y(1)]</li>
<li>odor == anise (l)? [N(0)/Y(1)]</li>
</ul>
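<p>As an illustrative sketch of how a raw record maps onto those six binary inputs (the field names follow the list above, but the helper and sample record are my own naming, not the post’s actual code):</p>

```python
# (feature, required code) pairs, in the order listed above.
FEATURES = [
    ("odor", "n"),
    ("stalk_root", "c"),
    ("stalk_surface_below_ring", "y"),
    ("spore_print_color", "r"),
    ("odor", "a"),
    ("odor", "l"),
]

def one_hot(record):
    """Turn a dict of raw categorical codes into the 6-element feature vector."""
    return [1 if record.get(field) == code else 0 for field, code in FEATURES]

sample = {"odor": "n", "stalk_root": "c",
          "stalk_surface_below_ring": "y", "spore_print_color": "r"}
print(one_hot(sample))  # [1, 1, 1, 1, 0, 0]
```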
<p>The corresponding correlation matrix for these features is as follows:</p>
<center>
<p>
<img src="assets/posts/2023-09-26-flask-api/correlation_coefficient_matrix.png" height="400px" alt="" />
</p>
<p>
<em>Feature/target correlation coefficient matrix</em>
</p>
</center>
<p>Since we are not thinking too critically about the “business use case” of this model and really only care about model serving, this is fine for our purposes. As a small aside, in the real world, we would want a model like this to have <em>perfect recall</em>. Not poisoning people is much more important than having a few edible mushrooms falsely marked as inedible. So, ideally, a model like this would have 0 false negatives.</p>
<h2 id="setting-up-a-flask-scoring-server">Setting up a Flask Scoring Server</h2>
<p>Now that we have a model, we need a way to allow other services to query it for its score. We are going to make the simplest GET route on a flask server. Flask will need to be installed in the environment we are working with using <code class="language-plaintext highlighter-rouge">pip install flask</code>. Then, we can create a file called <code class="language-plaintext highlighter-rouge">app.py</code> with the following code:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="kn">from</span> <span class="nn">flask</span> <span class="kn">import</span> <span class="n">Flask</span><span class="p">,</span> <span class="n">request</span><span class="p">,</span> <span class="n">Response</span>
<span class="n">app</span> <span class="o">=</span> <span class="n">Flask</span><span class="p">(</span><span class="n">__name__</span><span class="p">)</span>
<span class="o">@</span><span class="n">app</span><span class="p">.</span><span class="n">route</span><span class="p">(</span><span class="s">'/'</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">home</span><span class="p">():</span>
<span class="k">return</span> <span class="s">"<p>Hello World</p>"</span></code></pre></figure>
<p>Now, from this directory, we can run <code class="language-plaintext highlighter-rouge">flask run --port 8000</code>. This will start a flask server on port 8000. To check to see if the server is up and running, in another terminal window, we can run <code class="language-plaintext highlighter-rouge">curl localhost:8000</code> or visit localhost:8000 in a browser window. The response should be: <code class="language-plaintext highlighter-rouge">&lt;p&gt;Hello World&lt;/p&gt;</code>. Great - this is the most barebones flask app that we can run. There is a single endpoint, <code class="language-plaintext highlighter-rouge">/</code>, that returns a static string no matter what the request contains.</p>
<h3 id="creating-a-post-endpoint">Creating a POST Endpoint</h3>
<p>We are going to add another endpoint to this application that uses the POST method and takes the input features for our model as its payload. It will receive this input as JSON, parse it out, feed it into our model, and produce a score. To begin, we will create another function in <code class="language-plaintext highlighter-rouge">app.py</code> that literally just echoes our own request back to us as a string:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="o">@</span><span class="n">app</span><span class="p">.</span><span class="n">route</span><span class="p">(</span><span class="s">'/score'</span><span class="p">,</span> <span class="n">methods</span><span class="o">=</span><span class="p">[</span><span class="s">"POST"</span><span class="p">])</span>
<span class="k">def</span> <span class="nf">score</span><span class="p">():</span>
    <span class="n">json_req</span> <span class="o">=</span> <span class="n">request</span><span class="p">.</span><span class="n">get_json</span><span class="p">()</span>
    <span class="k">return</span> <span class="n">Response</span><span class="p">(</span><span class="n">response</span><span class="o">=</span><span class="nb">str</span><span class="p">(</span><span class="n">json_req</span><span class="p">),</span> <span class="n">status</span><span class="o">=</span><span class="mi">200</span><span class="p">)</span></code></pre></figure>
<p>Easy enough. We can now send valid JSON to the server and expect an echoed response. After restarting the server, we can run the following in our terminal to test:</p>
<p><code class="language-plaintext highlighter-rouge">curl -X POST -H 'Content-Type: application/json' -d '{"blah": "ok"}' http://localhost:8000/score</code></p>
<p>The local server should echo the request payload: <code class="language-plaintext highlighter-rouge">{'blah': 'ok'}</code>.</p>
<h3 id="model-validations">Model Validations</h3>
<p>Before we can create feature vectors to feed into our model, we need to validate that the raw categorical data that was POSTed to us is valid. The two checks that we need to make are that</p>
<ol>
<li>All of the features that we need to score the request are present (<code class="language-plaintext highlighter-rouge">odor</code>, <code class="language-plaintext highlighter-rouge">stalk_root</code>, <code class="language-plaintext highlighter-rouge">stalk_surface_below_ring</code>, <code class="language-plaintext highlighter-rouge">spore_print_color</code>)</li>
<li>All of the features listed above are given a valid value (must be a category that we’ve seen in training before)</li>
</ol>
<p>If either of these two conditions fails, we will return a 400 Bad Request response indicating that the input is malformed, along with a descriptive message explaining what is wrong with the request. For example, when we POST <code class="language-plaintext highlighter-rouge">{"odor": "x", "stalk_root": "c", "stalk_surface_below_ring": "y"}</code> to the API, it returns the response</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Not all columns were specified for model. Missing columns:
spore_print_color
</code></pre></div></div>
<p>Similarly, if we try to send a request containing an invalid value for one of our inputs, we receive the following error message:</p>
<div class="language-plaintext highlighter-rouge"><div class="highlight"><pre class="highlight"><code>Input is invalid:
"x" is not a valid value for "odor"
</code></pre></div></div>
<p>If both of these validation checks succeed, we proceed to create a one-hot encoded feature vector from the raw categorical inputs.</p>
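<p>The scoring route shown later calls a few small validation helpers (<code class="language-plaintext highlighter-rouge">get_relevant_column_data</code>, <code class="language-plaintext highlighter-rouge">get_missing_cols</code>, <code class="language-plaintext highlighter-rouge">validate_input</code>, and a <code class="language-plaintext highlighter-rouge">validity_fn</code>) that aren’t spelled out in the post. Here is a minimal sketch of how they might look; the <code class="language-plaintext highlighter-rouge">VALID_VALUES</code> lookup is an assumption based on the category codes in the UCI mushroom dataset and would really be derived from the training data:</p>

```python
# Sketch of the validation helpers used by the /score route.
# VALID_VALUES is an assumed lookup of category codes seen during
# training; in a real service it would come from the training data.
VALID_VALUES = {
    'odor': {'a', 'c', 'f', 'l', 'm', 'n', 'p', 's', 'y'},
    'stalk_root': {'b', 'c', 'e', 'r', '?'},
    'stalk_surface_below_ring': {'f', 'k', 's', 'y'},
    'spore_print_color': {'b', 'h', 'k', 'n', 'o', 'r', 'u', 'w', 'y'},
}

def get_relevant_column_data(raw, required_cols):
    """Keep only the columns the model needs, in a fixed order."""
    return {c: raw[c] for c in required_cols if c in raw}

def get_missing_cols(required_cols, provided_cols):
    """Return the required columns absent from the request payload."""
    provided = set(provided_cols)
    return [c for c in required_cols if c not in provided]

def validity_fn(col, value):
    """True if `value` is a category we saw for `col` during training."""
    return value in VALID_VALUES.get(col, set())

def validate_input(column_data, required_cols, validity_fn):
    """Map each provided required column to True/False validity."""
    return {c: validity_fn(c, column_data[c])
            for c in required_cols if c in column_data}
```

<p>These names match the helpers referenced in the final scoring route, so the sketch slots into the same flow: filter, check for missing columns, then check each value.</p>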
<h3 id="creating-a-feature-vector">Creating a Feature Vector</h3>
<p>I mentioned in the introduction that this model assumes a lot of nice things about the input we receive. In reality, feature vectors are often not so easy to obtain and require more complex systems, such as <strong>feature stores</strong>, to bring us up-to-date features. This post assumes that we receive the raw categorical information and that only one-hot encoding must be done on our server. The one-hot encoding function looks as follows:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">extract_features</span><span class="p">(</span><span class="n">odor</span><span class="p">,</span> <span class="n">stalk_root</span><span class="p">,</span> <span class="n">stalk_surface_below_ring</span><span class="p">,</span> <span class="n">spore_print_color</span><span class="p">):</span>
    <span class="s">"""
    Turns categorical features into features that we care about for our model. Realistically,
    this would probably be much more complex and would be done in a feature store.
    """</span>
    <span class="k">return</span> <span class="p">(</span>
        <span class="nb">int</span><span class="p">(</span><span class="n">odor</span> <span class="o">==</span> <span class="s">'n'</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">stalk_root</span> <span class="o">==</span> <span class="s">'c'</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">stalk_surface_below_ring</span> <span class="o">==</span> <span class="s">'y'</span><span class="p">),</span>
        <span class="nb">int</span><span class="p">(</span><span class="n">spore_print_color</span> <span class="o">==</span> <span class="s">'r'</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">odor</span> <span class="o">==</span> <span class="s">'a'</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">odor</span> <span class="o">==</span> <span class="s">'l'</span><span class="p">)</span>
    <span class="p">)</span></code></pre></figure>
<p>It takes in our four raw categorical features (<code class="language-plaintext highlighter-rouge">odor</code>, <code class="language-plaintext highlighter-rouge">stalk_root</code>, <code class="language-plaintext highlighter-rouge">stalk_surface_below_ring</code>, <code class="language-plaintext highlighter-rouge">spore_print_color</code>) and returns a 6-tuple representing a one-hot encoded feature vector that can now be scored by our model!</p>
<h3 id="scoring-with-a-model-object">Scoring With a Model Object</h3>
<p>A modeling team needs to decide if they are keeping their model as a <strong>model binary</strong> and loading it in the same framework that it was trained in or if they are storing it in an agnostic format to be loaded in a language/framework-independent manner. An example of a model binary would be the <code class="language-plaintext highlighter-rouge">pickle</code> serialization module in Python. Pickle is an easy way to save model objects within python, but <em>pickled objects can only be deserialized in python, and only if the library used for training is importable in the serving environment</em>. An example of a framework-independent format is the Open Neural Network Exchange (ONNX). ONNX is a universal method of storing neural network weights/architectures so that models can be passed between neural network frameworks. Another flexible representation of models is the Predictive Model Markup Language (PMML).</p>
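<p>As a quick illustration of the model-binary route, here is a hedged sketch of a pickle round trip. <code class="language-plaintext highlighter-rouge">TinyModel</code> is a made-up stand-in for a fitted estimator; the point is that the class (and any training library it depends on) must be importable wherever the bytes are unpickled:</p>

```python
import pickle

class TinyModel:
    """Stand-in for a fitted model object (e.g. a scikit-learn tree)."""
    def predict(self, rows):
        return [1 for _ in rows]

model = TinyModel()
blob = pickle.dumps(model)     # bytes you would write to disk or object storage
restored = pickle.loads(blob)  # fails if TinyModel is not importable here
print(restored.predict([[0, 1]]))  # -> [1]
```

<p>Swapping <code class="language-plaintext highlighter-rouge">TinyModel</code> for a real scikit-learn estimator works the same way, which is exactly why the serving environment then needs scikit-learn installed.</p>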
<p>In my case, I could install <code class="language-plaintext highlighter-rouge">scikit-learn</code>, keep the model object in memory, and score incoming requests using the decision tree object’s <code class="language-plaintext highlighter-rouge">predict()</code> function. Because of the simplicity of my model, however, there is an easier way to generate a model score. I chose to slightly modify <a href="https://stackoverflow.com/a/39772170" target="_blank">some code from Stack Overflow</a> to generate python code to mimic the functionality of a <code class="language-plaintext highlighter-rouge">scikit-learn</code> decision tree object. This has the benefit of not requiring <code class="language-plaintext highlighter-rouge">scikit-learn</code> to be installed on my flask server, but it is not practical for all kinds of models, particularly models whose decisions are not easily explained. The function mimicking the decision tree behavior begins something like this:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">score_input</span><span class="p">(</span><span class="n">odor_n</span><span class="p">,</span> <span class="p">...,</span> <span class="n">odor_l</span><span class="p">):</span>
    <span class="s">"""
    Decision tree spelled out in code for our model.
    """</span>
    <span class="k">if</span> <span class="n">odor_n</span> <span class="o">&lt;=</span> <span class="mf">0.5</span><span class="p">:</span>
        <span class="p">...</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="p">...</span></code></pre></figure>
<p>and continues to cover all of the other branches for the decision tree, returning <code class="language-plaintext highlighter-rouge">True</code> or <code class="language-plaintext highlighter-rouge">False</code> depending on the majority class at each leaf node. This assumes we have chosen a probability threshold of 0.5 for our model, but we could instead return the fraction of negative/positive class instances at each leaf node as a probability and apply a different threshold downstream. This function will be used to score incoming requests.</p>
<h2 id="conclusion">Conclusion</h2>
<p>Putting it all together, we have a POST route on our flask server which receives a payload from an incoming request, parses the request body, validates the features, scores the request, and finally, returns the model’s score as a response. The final function looks as follows:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="o">@</span><span class="n">app</span><span class="p">.</span><span class="n">route</span><span class="p">(</span><span class="s">'/score'</span><span class="p">,</span> <span class="n">methods</span><span class="o">=</span><span class="p">[</span><span class="s">"POST"</span><span class="p">])</span>
<span class="k">def</span> <span class="nf">score</span><span class="p">():</span>
    <span class="n">error_msg</span> <span class="o">=</span> <span class="s">''</span>
    <span class="n">raw_column_data</span> <span class="o">=</span> <span class="n">request</span><span class="p">.</span><span class="n">get_json</span><span class="p">()</span>
    <span class="n">required_cols</span> <span class="o">=</span> <span class="p">[</span><span class="s">'odor'</span><span class="p">,</span> <span class="s">'stalk_root'</span><span class="p">,</span> <span class="s">'stalk_surface_below_ring'</span><span class="p">,</span> <span class="s">'spore_print_color'</span><span class="p">]</span>
    <span class="n">column_data_relevant</span> <span class="o">=</span> <span class="n">get_relevant_column_data</span><span class="p">(</span><span class="n">raw_column_data</span><span class="p">,</span> <span class="n">required_cols</span><span class="p">)</span>
    <span class="n">missing_cols</span> <span class="o">=</span> <span class="n">get_missing_cols</span><span class="p">(</span><span class="n">required_cols</span><span class="p">,</span> <span class="n">column_data_relevant</span><span class="p">.</span><span class="n">keys</span><span class="p">())</span>
    <span class="c1"># Generate missing column error message
</span>    <span class="k">if</span> <span class="n">missing_cols</span><span class="p">:</span> <span class="n">error_msg</span> <span class="o">+=</span> <span class="n">missing_col_error_msg</span><span class="p">(</span><span class="n">missing_cols</span><span class="p">)</span>
    <span class="c1"># Check for valid feature values
</span>    <span class="n">validity_check</span> <span class="o">=</span> <span class="n">validate_input</span><span class="p">(</span><span class="n">column_data_relevant</span><span class="p">,</span> <span class="n">required_cols</span><span class="p">,</span> <span class="n">validity_fn</span><span class="p">)</span>
    <span class="k">if</span> <span class="ow">not</span> <span class="nb">all</span><span class="p">(</span><span class="n">validity_check</span><span class="p">.</span><span class="n">values</span><span class="p">()):</span>
        <span class="n">error_msg</span> <span class="o">+=</span> <span class="n">invalid_input_error_msg</span><span class="p">(</span><span class="n">validity_check</span><span class="p">,</span> <span class="n">column_data_relevant</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">error_msg</span><span class="p">:</span> <span class="c1"># If there are any errors, return error message with 400 response
</span>        <span class="k">return</span> <span class="n">Response</span><span class="p">(</span><span class="n">error_msg</span><span class="p">,</span> <span class="n">status</span><span class="o">=</span><span class="mi">400</span><span class="p">)</span>
    <span class="n">feature_values</span> <span class="o">=</span> <span class="n">extract_features</span><span class="p">(</span><span class="o">*</span><span class="n">column_data_relevant</span><span class="p">.</span><span class="n">values</span><span class="p">())</span>
    <span class="k">return</span> <span class="n">Response</span><span class="p">(</span><span class="n">response</span><span class="o">=</span><span class="nb">str</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">score_input</span><span class="p">(</span><span class="o">*</span><span class="n">feature_values</span><span class="p">))),</span> <span class="n">status</span><span class="o">=</span><span class="mi">200</span><span class="p">)</span></code></pre></figure>
<p>We have set up a very barebones flask server for model scoring, but this is not too different from what we need to do in production to serve model requests. A real server would need more application logging/monitoring to know if it was behaving correctly. The next step is to put this server in a container and send it to a service that can deploy many of our containers at once to scale it out. Thanks for reading!</p>Delicious supplement to dinner? Or deadly poison?(image from Wikipedia)RL Intro and Fundamentals2023-09-12T15:00:00+00:002023-09-12T15:00:00+00:00http://willbeckman.com/rl-fundamentals<center><image src="assets/posts/2023-09-12-rl-fundamentals/present.png"></image></center>
<p><br /></p>
<p style="text-align: center;">The code associated with this blog post can be found
<a target="_blank" href="https://github.com/wbeckman/simple-rl">here</a>.</p>
<p><br /></p>
<h2 id="table-of-contents-">Table of Contents <!-- omit in toc --></h2>
<ul>
<li><a href="#background">Background</a></li>
<li><a href="#mdps">MDPs</a>
<ul>
<li><a href="#values-and-policies-some-terminology">Values and policies? Some terminology</a></li>
<li><a href="#the-slippery-frozen-lake">The Slippery Frozen Lake</a></li>
</ul>
</li>
<li><a href="#exact-solutions-to-mdps">Exact Solutions to MDPs</a>
<ul>
<li><a href="#value-iteration">Value Iteration</a></li>
<li><a href="#policy-iteration">Policy Iteration</a></li>
</ul>
</li>
<li><a href="#reinforcement-learning">Reinforcement Learning</a>
<ul>
<li><a href="#enter-q-learning">Enter: Q-Learning</a></li>
<li><a href="#beyond-q-learning">Beyond Q-Learning</a></li>
</ul>
</li>
</ul>
<h2 id="background">Background</h2>
<p>In this post, I will be talking about Markov Decision Processes (MDPs), their role in reinforcement learning, and a few ways in which they can be solved. This post will first talk about methods for solving MDPs exactly, including value iteration and policy iteration. We will then discuss the reinforcement learning problem, where we assume MDPs are the underlying mechanic in our observations, but we are blind to the MDP’s reward and state transition mechanics. This will require us to infer value estimates using observation tuples and an algorithm called Q-learning.</p>
<h2 id="mdps">MDPs</h2>
<p>Formally, a Markov decision process is defined as a tuple of four things:</p>
\[(S, A, R, T)\]
<p>$S$ is a set of states and $A$ is a set of actions available at each state. $R(s,a,s’)$ is a reward function, namely, “if I take action $a$ in state $s$ and end in state $s’$, what reward will I receive?”. Finally, $T(s,a,s’)$ is a state-transition function that defines transition probabilities between states $s$ and $s’$, given an action $a$. Solving an MDP involves maximizing the reward we can receive from the environment within a finite-time horizon.</p>
<center>
<p>
<img src="assets/posts/2023-09-12-rl-fundamentals/mdp.png" width="400px" alt="" />
</p>
<p>
<em>A simple MDP (credit to wikipedia). States can be seen in green and actions can be seen in red. Each state/action pair is associated with one or more transition probabilities to other states. The squiggly arrows show rewards emanating from certain state transitions being made.</em>
</p>
</center>
<p>To make this somewhat abstract definition of an MDP concrete, this is what a small portion of the transition function, $T$ would look like for the MDP defined in the diagram above:</p>
<table>
<thead>
<tr>
<th>State</th>
<th>Action</th>
<th>Other State</th>
<th>Transition Probability</th>
</tr>
</thead>
<tbody>
<tr>
<td>$S_0$</td>
<td>$A_0$</td>
<td>$S_0$</td>
<td>0.5</td>
</tr>
<tr>
<td>…</td>
<td>…</td>
<td>…</td>
<td>…</td>
</tr>
<tr>
<td>$S_2$</td>
<td>$A_1$</td>
<td>$S_0$</td>
<td>0.3</td>
</tr>
</tbody>
</table>
<p>You can imagine how the rest of the table would look. Similarly, a fragment of the reward function, $R$, would look as follows:</p>
<table>
<thead>
<tr>
<th>State</th>
<th>Action</th>
<th>Other State</th>
<th>Reward Magnitude</th>
</tr>
</thead>
<tbody>
<tr>
<td>$S_0$</td>
<td>$A_0$</td>
<td>$S_0$</td>
<td>0.0</td>
</tr>
<tr>
<td>…</td>
<td>…</td>
<td>…</td>
<td>…</td>
</tr>
<tr>
<td>$S_2$</td>
<td>$A_1$</td>
<td>$S_0$</td>
<td>-1.0</td>
</tr>
</tbody>
</table>
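<p>The two table fragments above can be written directly as Python dictionaries keyed by $(s, a, s')$ tuples; the rows elided with “…” in the tables are left as comments rather than invented:</p>

```python
# Transition function T(s, a, s') and reward function R(s, a, s')
# for the fragments shown above; elided rows are omitted.
T = {
    ('s0', 'a0', 's0'): 0.5,
    # ... remaining transition probabilities ...
    ('s2', 'a1', 's0'): 0.3,
}
R = {
    ('s0', 'a0', 's0'): 0.0,
    # ... remaining rewards ...
    ('s2', 'a1', 's0'): -1.0,
}
```

<p>This tabular representation is exactly what the exact-solution algorithms below consume: both value iteration and policy iteration repeatedly look up entries of $T$ and $R$.</p>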
<p>The definition of an MDP is intentionally very general. For example, the following situations could all be formalized as MDPs:</p>
<ul>
<li>A maze within a grid. Actions consist of N/E/S/W cardinal directions and states consist of valid grid positions. Rewards for transitioning to any position that is not the terminal maze position are small and negative. There is a large positive reward for reaching the terminal maze position.</li>
<li>A game of chess. Actions are any valid action on a chess board, states are valid board positions. There is a positive reward for arriving at a winning game state and a negative reward for arriving at a losing game state.</li>
<li>A self-driving car. Actions are to coast, accelerate, brake, turn left, turn right (and then combinations of those actions as well). The state consists of measurements of the car - current speed/acceleration, GPS position, heading, etc. There is a large positive reward for getting a passenger safely to their destination, a very large negative reward for crashing, hitting pedestrians, or breaking traffic laws.</li>
</ul>
<p>We will talk about how MDPs are tied to reinforcement learning a bit later, but first will discuss two common algorithms for finding optimal state/action combinations within an MDP.</p>
<h3 id="values-and-policies-some-terminology">Values and policies? Some terminology</h3>
<p>I make reference to the ideas of “values” and “policies” a lot in this post, but if you’re totally new to reinforcement learning, these terms aren’t necessarily helpful. A <strong>policy</strong> is how an agent chooses an action in a given state. The <strong>optimal policy</strong> is the best possible policy, where, for every state, we are choosing an action which maximizes the amount of reward we expect to obtain. Our goal is to find a policy that is as close to the optimal policy as possible. The policy is usually denoted as $\pi(s)$, and the optimal policy is denoted as $\pi_*(s)$.</p>
<p>The <strong>values</strong> of a state (or state-values) are the accumulated rewards that we expect to receive for being in state $s$ across the entire episode. The optimal value function tells us “if we were to act optimally according to our value estimates from this point onwards, this value represents the cumulative reward that we expect to receive.” Value functions are denoted as $V(s)$ and the optimal value function is denoted $V_*(s)$. You might worry about the optimal value for a state growing infinitely large if we find a policy containing a positive reward loop. <strong>For this reason, we use a “discounting” factor,</strong> $\gamma$, to denote “how much we value future reward”. On most tasks, we set $\gamma$ close to 1.0 (>0.9), so that we can look far ahead.</p>
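<p>Concretely, the discounted value of a state under a policy $\pi$ is the expected sum of discounted future rewards:</p>
\[V_{\pi}(s) = \mathbb{E}_{\pi}\left[\sum_{t=0}^{\infty} \gamma^{t} r_{t+1} \;\middle|\; s_0 = s\right]\]
<p>With $\gamma < 1$ and bounded rewards, this sum stays finite even when a policy loops through positive rewards forever, which is exactly why discounting prevents values from growing without bound.</p>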
<p>It is worth mentioning there are episodic (finite-time horizon) and non-episodic (infinite-time horizon) environments for MDPs, as well. Episodic environments contain a terminal state which requires us to restart from the beginning. An example of an episodic MDP would be chess, which ends when there is a win, loss, or draw. A non-episodic MDP would be traffic control, which has no end-of-episode trigger.</p>
<p>If you take one thing away from this blog post, it would be good to remember the <strong>Markov property</strong>. The Markov property states that an MDP is memoryless - this means that the best action given a state is completely independent of the states that came before it, and therefore, the best action depends only on the current state observations. This is desirable, since it removes (potentially expensive) conditioning on previous observations.</p>
<p>As an aside, although the reinforcement learning framework assumes the underlying mechanics of the task are an MDP, it has been shown empirically that some RL algorithms can still learn very good policies on tasks that do not satisfy the Markov property (such as arcade games - in pong, we do not know the <em>trajectory</em> of the ball from a single frame). Now that some basic terminology has been cleared up, let’s discuss the simple environment we intend to solve.</p>
<h3 id="the-slippery-frozen-lake">The Slippery Frozen Lake</h3>
<p>The toy environment that will be used to demo MDP algorithms and reinforcement learning strategies is the slippery frozen lake environment. The environment works as follows:</p>
<ul>
<li>The environment is a square grid. Each grid cell represents a state in the MDP, which we will attempt to determine the value of.</li>
<li>Every state has <em>4 actions</em> (N, E, S, W). Walking into a wall is a valid action, but will just make you stay where you are.</li>
<li>Because the lake is slippery, <em>selected actions are only executed correctly 1/3 of the time</em>. With 1/3 probability each, one of the two perpendicular actions is taken instead. For example, if I choose to walk north, I will walk north with 1/3 probability, east with 1/3 probability, or west with 1/3 probability.</li>
<li>There are no rewards except in terminal states. If we reach the goal, the reward is 1.0. If we fall in a hole, the reward is -1.0.</li>
</ul>
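<p>The slip mechanic in the rules above can be sketched as a tiny function (this is a sketch of the described behavior, not the actual environment implementation):</p>

```python
import random

# For each intended direction, the directions that can actually be
# executed on the slippery ice: the intended one plus the two
# perpendicular ones, each with probability 1/3.
SLIPS = {
    'N': ['N', 'E', 'W'],
    'S': ['S', 'E', 'W'],
    'E': ['E', 'N', 'S'],
    'W': ['W', 'N', 'S'],
}

def executed_action(intended, rng=random):
    """Return the direction actually walked: intended or perpendicular, 1/3 each."""
    return rng.choice(SLIPS[intended])
```

<p>This is why "walk into the wall" can be a sensible policy: the worst outcome of slipping against a wall is staying put, whereas slipping next to a hole can be fatal.</p>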
<p>The optimal policy for this environment is one in which we eventually reach our goal, even if it takes us a very long time to do so. This means walking into walls a lot to avoid falling in the lake.</p>
<center>
<p>
<img src="assets/posts/2023-09-12-rl-fundamentals/random_policy.gif" width="400px" alt="" />
</p>
<p>
<em>The untimely demise of our protagonist under a random policy.</em>
</p>
</center>
<p>So how do we find the safest route to our goal, given the uncertainty of our movement?</p>
<h2 id="exact-solutions-to-mdps">Exact Solutions to MDPs</h2>
<p>This section introduces two algorithms that are used to solve MDPs exactly when their state- and reward-transition functions are known. This is almost never the case in practice, but these algorithms are interesting theoretically.</p>
<h3 id="value-iteration">Value Iteration</h3>
<p>The first approach to solving MDPs is an algorithm called <strong>value iteration</strong>. Value iteration initializes value estimates for each state randomly and then iteratively refines its value estimates using the following equation:</p>
\[V(s) = \max_a \sum_{s'} T(s, a, s') * [R(s, a, s') + \gamma V(s')]\]
<p>What this means, effectively, is we iteratively use our <em>previous estimate of</em> $\textit{V}$ to eventually converge to the optimal value function, $V^{*}$. You may be wondering - how on earth does this work? We initialized $V$ as a bunch of random values, so how could this possibly converge to $V^{*}$? The answer lies in the reward function - $R(s, a, s’)$. With each step of value iteration, the true rewards from the MDP are nudging our values closer to the true values of each state, making each state-value estimate monotonically better than the last.</p>
<p>When we apply this iteratively, the values converge in the limit to the true values for each state. In practice, we can stop running value iteration when the <em>largest difference between values from one step to the next</em> is smaller than a manually chosen threshold (say, < 0.001 units, though this is also problem- and reward-magnitude-dependent). We can also choose to stop value iteration if the optimal policy hasn’t changed in a fixed number of iterations.</p>
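<p>As a concrete illustration, here is a short value-iteration sketch on a tiny hypothetical two-state MDP (the states, actions, transition probabilities, and rewards are all made up for illustration; this is not the frozen lake). The stopping rule is the max-difference threshold just described:</p>

```python
# Toy MDP: two states, two actions. T[(s, a)] lists (next_state, prob)
# pairs; R holds rewards for specific (s, a, s') transitions, 0 otherwise.
GAMMA = 0.9
STATES = [0, 1]
ACTIONS = ['stay', 'go']
T = {
    (0, 'stay'): [(0, 1.0)],
    (0, 'go'):   [(1, 0.8), (0, 0.2)],
    (1, 'stay'): [(1, 1.0)],
    (1, 'go'):   [(0, 1.0)],
}
R = {(0, 'go', 1): 1.0}  # only transition with nonzero reward

def backup(s, a, V):
    """One-step lookahead: expected reward plus discounted next value."""
    return sum(p * (R.get((s, a, s2), 0.0) + GAMMA * V[s2])
               for s2, p in T[(s, a)])

def value_iteration(tol=1e-8):
    V = {s: 0.0 for s in STATES}
    while True:
        new_V = {s: max(backup(s, a, V) for a in ACTIONS) for s in STATES}
        # Stop when the largest per-state change drops below the threshold.
        if max(abs(new_V[s] - V[s]) for s in STATES) < tol:
            return new_V
        V = new_V

V = value_iteration()
```

<p>Running this drives $V$ to the fixed point of the update; taking the action that maximizes each state’s backup then recovers the greedy policy.</p>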
<center>
<p>
<img src="assets/posts/2023-09-12-rl-fundamentals/value_iteration_value_heatmap.png" width="400px" alt="" />
</p>
<p>
<em>Optimal values found by value iteration for states in the frozen lake grid. Successful goal state with reward 1.0 is in the bottom right. All other states with no numbers are terminal non-goal states with -1.0 reward. Values for states that are guaranteed a reward of 1.0 are lower when they are further away from the goal state because of $\gamma$ discounting.</em>
</p>
</center>
<p>It’s worth mentioning that there is a one-to-one mapping from the optimal value function to the optimal policy. For each state, we choose the action that maximizes our expected future reward: \(a_{V*}=\mathop{\mathrm{argmax}}_a \sum_{s'} T(s, a, s') * [R(s, a, s') + \gamma V^*(s')]\).</p>
<h3 id="policy-iteration">Policy Iteration</h3>
<p>Another algorithm for solving MDPs is called policy iteration. Policy iteration has similarities to value iteration, specifically in the <em>policy evaluation step</em>.</p>
<p>We begin by initializing a random policy. We then <em>evaluate</em> our policy in a way that is similar to value iteration (except instead of a <em>max</em> over actions, we assess the action that is chosen according to our current policy). This is the <strong>policy-evaluation</strong> step. The policy-evaluation step looks as follows:</p>
\[V_{\pi_t}(s) = \sum_{s'} T(s, \pi_t(s), s') * [R(s, \pi_t(s), s') + \gamma V(s')]\]
<p>We run this iteratively in exactly the same way as value iteration until the largest magnitude difference between timesteps becomes smaller than a manually chosen threshold (<em>note: this convergence criterion is exactly how value iteration converges</em>). Once the policy-evaluation step has converged, we use our value estimates from the policy-evaluation step to improve the policy. This is the <strong>policy-improvement step</strong>:</p>
\[\pi_{t+1}(s) = \mathop{\mathrm{argmax}}_a(\sum_{s'} T(s, a, s') * [R(s, a, s') + \gamma V(s')])\]
<p>Basically, the improved policy is the <em>best action based on the value function that we just computed in the policy-evaluation step</em>. The policy-improvement step is relatively quick compared to the policy evaluation step, since it only passes over each state once (as opposed to an arbitrary number of times). If the improved policy ends up being <em>the same as the policy from the previous timestep, policy iteration has converged</em>. If not, we use the <em>improved policy</em> as input to the policy evaluation step and repeat this evaluation $\rightarrow$ improvement step iteratively until the policy converges.</p>
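<p>The evaluation $\rightarrow$ improvement loop can likewise be sketched in a few lines on a tiny hypothetical two-state MDP (all states, actions, transition probabilities, and rewards here are made up for illustration):</p>

```python
# Toy MDP: T[(s, a)] lists (next_state, prob) pairs; R holds rewards
# for specific (s, a, s') transitions and defaults to 0 elsewhere.
GAMMA = 0.9
STATES = [0, 1]
ACTIONS = ['stay', 'go']
T = {
    (0, 'stay'): [(0, 1.0)],
    (0, 'go'):   [(1, 0.8), (0, 0.2)],
    (1, 'stay'): [(1, 1.0)],
    (1, 'go'):   [(0, 1.0)],
}
R = {(0, 'go', 1): 1.0}

def backup(s, a, V):
    """Expected reward plus discounted next-state value for (s, a)."""
    return sum(p * (R.get((s, a, s2), 0.0) + GAMMA * V[s2])
               for s2, p in T[(s, a)])

def policy_iteration(tol=1e-8):
    pi = {s: 'stay' for s in STATES}  # arbitrary starting policy
    while True:
        # Policy evaluation: iterate V under the *fixed* policy pi.
        V = {s: 0.0 for s in STATES}
        while True:
            new_V = {s: backup(s, pi[s], V) for s in STATES}
            if max(abs(new_V[s] - V[s]) for s in STATES) < tol:
                V = new_V
                break
            V = new_V
        # Policy improvement: act greedily with respect to the evaluated V.
        new_pi = {s: max(ACTIONS, key=lambda a: backup(s, a, V))
                  for s in STATES}
        if new_pi == pi:  # converged: the policy is stable
            return pi, V
        pi = new_pi

pi, V = policy_iteration()
```

<p>Note that the outer loop usually runs only a handful of times: each improvement step changes the policy in at most a few states before it stabilizes.</p>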
<center>
<p>
$$\pi_{0}(S) \rightarrow V_{\pi_0}(S) \rightarrow \pi_{1}(S) \rightarrow V_{\pi_1}(S) \rightarrow ... \rightarrow \pi_{*}(S)$$
<em>Policy evaluation/improvement loop for policy evaluation.</em>
</p>
</center>
<p>The optimal policy that policy iteration (and also value iteration) converge to can be seen below:</p>
<center>
<p>
<img src="assets/posts/2023-09-12-rl-fundamentals/policy_iteration_arrow_heatmap_overlaid.png" width="400px" alt="" />
</p>
<p>
<em>Optimal policy for policy/value iteration.</em>
</p>
</center>
<p>As you can see, the optimal policy rewards extreme caution. It is impossible to leave the perimeter of the map, but, without fail, in an infinite-time horizon, this policy will achieve the goal state and receive a reward of 1.0.</p>
<h2 id="reinforcement-learning">Reinforcement Learning</h2>
<p>Policy iteration and value iteration both converge to the same optimal policy, since they are each guaranteed to find a solution to an MDP. While value/policy iteration are algorithms for solving MDPs when the transition and reward functions are fully accessible, practical problems are not usually formulated in this manner. In the real world, we usually operate through observations (sensor readings, etc.) and the underlying state-transition and reward mechanics are unknown to us. In the reinforcement learning problem, which more closely models the real world, we assume that we simply see a series of tuples $(s_t, a_t, r_{t+1}, s_{t+1})$ and we must act with no knowledge of $R$ or $T$. This means that we could take the same action in the same state as before and get different results, without fully understanding why. The process by which observation tuples are generated is shown below:</p>
<center>
<p>
<img src="assets/posts/2023-09-12-rl-fundamentals/agent-environment.png" alt="" />
</p>
<p>
<em>No blog post about RL is complete without this diagram. Diagram (and lots of other info) from <a href="http://incompleteideas.net/book/the-book-2nd.html">Reinforcement Learning: An Introduction by Sutton & Barto.</a></em>
</p>
</center>
<p>How are MDPs and reinforcement learning related if we can’t observe any internals of an MDP? <em>Reinforcement learning assumes that an MDP is producing the observation tuples</em> that we are seeing when we interact with the environment. Since we don’t have a $T$ or $R$ function available to us, we are faced with a more difficult problem and we need a new algorithm to account for our new uncertainty.</p>
<h3 id="enter-q-learning">Enter: Q-Learning</h3>
<p>Why is it called Q-learning? In Q-learning, we create a model of the <strong>quality</strong> of state-action pairs, which we call $Q(s, a)$. This is similar to a value function $V$, but now, we are considering a state-action value instead of just a state value. Q-Learning is even simpler than value iteration and policy iteration, but it takes much longer to converge. For a single observation tuple $(s_t, a_t, r_{t+1}, s_{t+1})$, this is the Q-learning update rule:</p>
\[Q(s_t, a_t) = Q(s_t, a_t) + \alpha \cdot [(r_{t+1} + \gamma * \max_a{Q(s_{t+1}, a)}) - Q(s_t, a_t)]\]
<p>In this update rule, we choose $\alpha$, which is a “blending” parameter that governs how much we want our observation to influence our current estimate of $Q(s_t,a_t)$. As before, the $\gamma$ parameter governs how much we value future reward. This algorithm is similar to value iteration in that it iteratively refines its value estimates, but this time, we update them based on observations instead of a directly observed model of the MDP.</p>
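<p>The update rule translates almost directly into code. Below is a standalone sketch (not the repo’s frozen-lake implementation) that applies the update repeatedly to a single made-up observation tuple; the estimate approaches the target $r + \gamma \max_a Q(s', a)$:</p>

```python
from collections import defaultdict

ALPHA, GAMMA = 0.1, 0.99
ACTIONS = [0, 1]
Q = defaultdict(float)  # Q[(state, action)], all estimates start at 0.0

def q_update(s, a, r, s_next):
    """Q-learning update for one (s, a, r, s') observation tuple."""
    best_next = max(Q[(s_next, a2)] for a2 in ACTIONS)
    Q[(s, a)] += ALPHA * ((r + GAMMA * best_next) - Q[(s, a)])

# Made-up tuple: action 1 in state 'A' yields reward 1.0 and lands in a
# terminal state 'T' (whose action-values stay 0), so Q('A', 1) -> 1.0.
for _ in range(500):
    q_update('A', 1, 1.0, 'T')
```

<p>Each update blends a fraction $\alpha$ of the newly observed target into the old estimate, which is why many repetitions of the same observation are needed before the estimate settles.</p>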
<center>
<p>
<img src="assets/posts/2023-09-12-rl-fundamentals/q_policy.gif" width="400px" alt="" />
</p>
<p>
<em>The er... meandering... Q-learning policy in action after 20k episodes. Note that some action choices are still not optimal - for example, first row, fifth column.</em>
</p>
</center>
<p>Although we create a “model” of the quality function, $Q$, Q-learning is confusingly <em>not</em> what is known as a “model-based approach”. <strong>Model-based</strong> approaches in reinforcement learning create a model of the state-transition- and reward-dynamics of the underlying MDP. <strong>Q-learning is <em>model-free</em></strong>, because it does not create an explicit model of $T(s,a,s’)$ and $R(s,a,s’)$.</p>
<p>When training a Q-learning algorithm, the update rule is given above, but how do we choose actions to receive experience tuples to update our value estimates? If we always choose actions that we believe to be optimal, we will never know the value of states that our policy visits infrequently. However, if we only behave randomly, we may never actually execute an optimal policy and may be wasting time exploring when we already know how to act optimally. The trade-off between accumulating lots of reward and gaining information about not-yet-visited states is called the <strong>exploration-exploitation dilemma</strong>. We should choose actions so that we can balance exploration with optimizing reward.</p>
<p>A simple and widely used algorithm for choosing actions for Q-learning is called $\epsilon$-greedy action selection. $\epsilon$-greedy action selection states that we should choose an action randomly with probability $\epsilon$ or otherwise, with probability $1-\epsilon$, we choose an action from the policy given by our current estimation of $Q(s, a)$. It is common practice to <em>reduce</em> $\epsilon$ ($\epsilon$-decay) over time to decrease the rate of random action selection as our policy improves.</p>
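<p>A minimal sketch of \(\epsilon\)-greedy selection over the same kind of tabular \(Q\) (the names are illustrative; a real training loop would also decay \(\epsilon\) between episodes):</p>

```python
import random

def epsilon_greedy(Q, state, actions, epsilon):
    """With probability epsilon explore randomly; otherwise act greedily w.r.t. Q."""
    if random.random() < epsilon:
        return random.choice(list(actions))
    # Exploit: pick the action with the highest current Q estimate for this state
    return max(actions, key=lambda a: Q[(state, a)])
```

<p>\(\epsilon\)-decay is then just something like <code>epsilon = max(epsilon_min, epsilon * decay_rate)</code> after each episode.</p>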
<p>Value iteration converges in 643 steps, policy iteration in 10 steps, and Q-learning in 20000 episodes (and not even to the optimal policy for every grid square). These aren’t extremely meaningful comparisons since the steps of each algorithm take a variable amount of computation time, but it’s clear that twenty-thousand episodes for practical RL applications would require a lot of time in the real world. Even if time isn’t an issue, most real-world control problems wouldn’t have the tolerance for twenty-thousand episodes worth of mistakes.</p>
<h3 id="beyond-q-learning">Beyond Q-Learning</h3>
<p>Of course, the slippery frozen lake is incredibly simple compared to many real-world reinforcement learning problems. Many real-world applications use continuous-valued sensor measurements as state measurements. Because continuous-valued sensor measurements are much more varied and represent a massive number of states when combined, it becomes intractable to create a Q value for every state-action pair. In this case, we prefer to <strong>approximate</strong> Q-values with a flexible function approximator, such as a neural network. This approach is called Deep Q-Learning. There are some tricks that are needed to get Deep Q-Learning working, but it is an incredible algorithm that has been used to play video games <strong>at superhuman levels from just pixel data</strong><sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">1</a></sup>. If you’re interested in going deeper into reinforcement learning, the most comprehensive guide, as of writing this, is <a href="http://incompleteideas.net/book/the-book-2nd.html">Reinforcement Learning: An Introduction by Sutton &amp; Barto</a>. Thanks for reading!</p>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:1" role="doc-endnote">
<p><a href="https://www.nature.com/articles/nature14236" target="_blank">DQN paper</a> <a href="#fnref:1" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>Re-writing Micrograd: Training a Neural Network from Scratch2023-06-14T16:00:00+00:002023-06-14T16:00:00+00:00http://willbeckman.com/nn-from-scratch<center><img src="assets/posts/2023-06-14-nn-from-scratch/neural_net.png" alt="" /></center>
<p><br /></p>
<p style="text-align: center;">The code associated with this blog post can be found
<a target="_blank" href="https://github.com/wbeckman/micrograd-without-looking">here</a>.</p>
<p><br /></p>
<h2 id="table-of-contents-">Table of Contents <!-- omit in toc --></h2>
<ul>
<li><a href="#background">Background</a></li>
<li><a href="#neural-networks">Neural Networks</a></li>
<li><a href="#computational-graphs">Computational graphs</a>
<ul>
<li><a href="#what-is-a-computational-graph">What is a computational graph?</a></li>
<li><a href="#representing-simple-neural-networks-as-a-computational-graph">Representing (Simple) Neural Networks as a Computational Graph</a></li>
</ul>
</li>
<li><a href="#calculus">Calculus</a>
<ul>
<li><a href="#derivatives-for-a-simple-feedforward-network">Derivatives for a Simple Feedforward Network</a></li>
<li><a href="#chain-rule">Chain rule</a></li>
</ul>
</li>
<li><a href="#gradients-on-a-computational-graph">Gradients on a Computational Graph</a>
<ul>
<li><a href="#the-trivial-case">The Trivial Case</a></li>
<li><a href="#two-levels-deep---propagating-gradients">Two Levels Deep - Propagating Gradients</a></li>
</ul>
</li>
<li><a href="#the-other-parts-of-training-a-neural-network">The Other Parts of Training a Neural Network</a>
<ul>
<li><a href="#loss-functions">Loss Functions</a></li>
<li><a href="#gradient-descent">Gradient Descent</a></li>
</ul>
</li>
<li><a href="#training-a-network">Training a Network!</a></li>
<li><a href="#conclusion">Conclusion</a></li>
</ul>
<h2 id="background">Background</h2>
<p>In an attempt to refamiliarize myself with the backpropagation algorithm (i.e. “backprop”), I re-wrote an autograd library written by Andrej Karpathy called “<a href="https://github.com/karpathy/micrograd" target="_blank">micrograd</a>”, and (almost successfully) managed to do so without looking. Micrograd runs <em><a href="https://en.wikipedia.org/wiki/Automatic_differentiation#Reverse_accumulation" target="_blank">reverse-mode automatic-differentiation</a> (auto-diff) to compute gradients for a computational graph</em>. Since a neural network is a <em>special case</em> of a computational graph, backprop is a <em>special case</em> of reverse-mode auto-diff when it is applied to a neural network. If this sounds confusing, read on, and I will break this all down step-by-step.</p>
<p>You will need some calculus knowledge (elementary derivatives and the chain rule) to understand this. If you have previously studied calculus but need a refresher, this post should get you up to speed on what you need to remember. Unfortunately, there’s no way to make a post about backpropagation short, but I have provided a table of contents so you can skim/skip sections that you’re already familiar with.</p>
<p>This is an informal post that will use some formal math notation, mostly derivatives. I almost always prefer code and visuals to math notation, but sometimes it can’t be helped.</p>
<p>If you somehow arrived here without having seen Andrej Karpathy’s original video, I highly recommend you check out the original <a href="https://www.youtube.com/watch?v=VMj-3S1tku0&list=PLAqhIrjkxbuWI23v9cThsA9GvCAUhRvKZ&index=2" target="_blank">here</a>.</p>
<p><br /></p>
<hr />
<p><br /></p>
<h2 id="neural-networks">Neural Networks</h2>
<p>If you aren’t familiar with why neural networks are a big deal, I could write thousands of words on that, but I will spare you. Neural networks are non-linear (i.e. any function that’s not a line; think \(y=x^2\)) function approximators that can theoretically approximate <a href="https://en.wikipedia.org/wiki/Universal_approximation_theorem" target="_blank">any continuous function</a>. They can be used to model statistical co-occurrences of words (as in large language models), used to model co-occurrences of image pixels (as in image segmentation/classification models), used to help recommend content on a website (via content embeddings), and are an important component of systems that can play games at superhuman levels (as in deep reinforcement learning). Each of these is extremely cool in its own right and deserves a blog post of its own, but this post is going low-level in how simple neural networks are trained via backpropagation and gradient descent.</p>
<center>
<p>
<video width="300" height="200" autoplay="" loop="" muted="">
<source src="assets/posts/2023-06-14-nn-from-scratch/lunar-lander-trimmed.mp4" type="video/mp4" />
</video>
</p>
<p>
<em>Maybe not... superhuman performance... but a moderately smart reinforcement learning agent I trained using a deep Q-network with experience replay.</em>
</p>
</center>
<p>The most amazing thing about neural networks to me is that however different the network purposes are, they are all trained with <em>one common algorithm</em> - that algorithm is <strong>backpropagation</strong>. GPT-4 is trained using backpropagation. All sorts of generative AI (e.g. stable diffusion, midjourney) are trained with backpropagation. Even local image explanations in the emerging field of explainable AI are produced using backpropagation (on the <em>input pixels</em> of the image instead of the network weights). It’s not an exaggeration to say that recently, backpropagation has become one of the most important algorithms in the world.</p>
<p>In brief: neural networks are function approximators that take an input X and produce a predicted output Y that attempts to model a true distribution of the input data. There is a training procedure called backpropagation that, when combined with gradient-based optimization algorithms, iteratively drives the neural network’s approximation closer to the true function it is trying to approximate. The <em>data</em> the network is fed, the <em>architecture</em>, and the <em>loss function</em> of a neural network primarily govern how it behaves. If any of this is confusing to you, there is a beautiful introduction to neural networks by <a href="https://www.youtube.com/watch?v=aircAruvnKk" target="_blank">3blue1brown</a>. Otherwise, let’s discuss computational graphs.</p>
<p><br /></p>
<hr />
<p><br /></p>
<h2 id="computational-graphs">Computational graphs</h2>
<h3 id="what-is-a-computational-graph">What is a computational graph?</h3>
<p>A <em>computational graph</em> is a directed acyclic graph (DAG) in which nodes correspond to either <em>variables</em> or <em>operations</em>. The simplest computational graph might look as follows:</p>
<center>
<p>
<img src="assets/posts/2023-06-14-nn-from-scratch/computational_graph.png" alt="" />
</p>
<p>
<em>A computational graph in its simplest form - two variables: a=1 and b=2 being multiplied to produce a third value, c=2. The code to generate these visuals was adapted from Andrej's graphviz code in Micrograd.</em>
</p>
</center>
<p>While this is a very simple example of a computational graph, it’s a step in the right direction for what we need for a forward pass in a neural network. It might help to show what a single neuron activation looks like in a computational graph and compare it to the more “classic” representation of a neural network.</p>
<h3 id="representing-simple-neural-networks-as-a-computational-graph">Representing (Simple) Neural Networks as a Computational Graph</h3>
<p>You have probably seen a traditional view of a neural network as a bunch of tangled edges between nodes in an undirected graph. While this is a compact way to represent neural networks visually, for the purposes of backpropagation, it’s much better to think of the network as a computational graph.</p>
<p>Let’s pretend that we have a very, very small neural network with two inputs, two hidden nodes, and a single output. Let’s also pretend that we have just computed the activation for a single neuron in the hidden layer, h1. In the graph, we will assume dark gray nodes have triggered as of the present time and light gray nodes have not yet triggered. Here’s what the traditional view of this might look like:</p>
<center>
<p>
<img src="assets/posts/2023-06-14-nn-from-scratch/single_neuron.png" alt="" />
</p>
<p>
<em>A neural network in which we have just computed</em> $$h_1=\tanh(x_1w_1+x_2w_2)$$
</p>
</center>
<p>The equivalent computational graph, representing just \(h1\)’s activation, would look as follows:</p>
<center><img src="assets/posts/2023-06-14-nn-from-scratch/single_neuron_comp_graph_final.png" alt="" /></center>
<p>While the computational graph view is a lot less… terse… it makes explicit a number of details that the traditional view of neural networks obscures. For example, you can see, step-by-step, the process of computing a neuron’s activation:</p>
<ol>
<li>Multiply the inputs by the neuron’s weights (\(o_1=w_1x_1; o_2=w_2x_2\))</li>
<li>Sum all of the \(wx\) terms (\(h_1=o_1+o_2\))</li>
<li>Compute the activation for \(h_1\) (\(h_1\_activation=\tanh(h_1)\))</li>
</ol>
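<p>The three steps above, written out in plain Python (the names mirror the figure; the input and weight values are arbitrary, just for illustration):</p>

```python
import math

# Example inputs and weights (arbitrary values for illustration)
x1, x2 = 2.0, 3.0
w1, w2 = -1.0, 0.5

# Step 1: multiply the inputs by the neuron's weights
o1 = w1 * x1          # -2.0
o2 = w2 * x2          #  1.5
# Step 2: sum the wx terms
h1 = o1 + o2          # -0.5
# Step 3: compute the activation
h1_activation = math.tanh(h1)
```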
<p>It’s unclear that all of this is happening in the first view. More importantly, the computational graph allows us to show what the data (and the gradients, but more on that later) are at each step of the way. You can imagine that if this is one neuron (h1), all the other neurons in a hidden layer (in this case, h2) fire the same way with different weights. Now that you have seen how computational graphs can be used to represent neural networks, we’re going to put this on hold for a second and take a trip back to Calc 1.</p>
<p><br /></p>
<hr />
<p><br /></p>
<h2 id="calculus">Calculus</h2>
<p>Because we want to find out how we can change the weights of a neural network to make its performance improve, we will need to calculate the gradient of the weights with respect to <em>some sort of performance measurement</em> (known as the loss, but more on that later). You will need to know a few elementary derivatives and have an intuitive grasp of the chain rule to understand how backpropagation works. Both of those things will be covered in brief here.</p>
<h3 id="derivatives-for-a-simple-feedforward-network">Derivatives for a Simple Feedforward Network</h3>
<p>For this tutorial, we are considering a feedforward neural network with one input layer (i.e. data), a hidden layer, and an output layer. The derivatives that you need to know for this network are: addition, multiplication, tanh, and the power function. The derivatives for these are as follows:</p>
<p><strong>Addition</strong>: \(f(x, y)=x+y\); \(\frac{\partial f}{\partial x}=1; \frac{\partial f}{\partial y}=1\)</p>
<p><strong>Multiplication</strong>: \(f(x, y)=xy\); \(\frac{\partial f}{\partial x}=y; \frac{\partial f}{\partial y}=x\)</p>
<p><strong>Tanh</strong>: \(f(x)=\frac{e^x-e^{-x}}{e^x+e^{-x}}\); \(\frac{d f}{d x}=1-\tanh(x)^2\)</p>
<p><strong>Pow</strong>: \(f(x, y)=x^y\); \(\frac{\partial f}{\partial x}=y * x^{(y-1)}\)</p>
<p>(leaving out the derivative \(\frac{\partial f}{\partial y}\) for the power function, because I’m being lazy and it’s not important for the purposes of this post)</p>
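<p>If any of these are rusty, each one can be sanity-checked numerically with a central finite difference (a throwaway sketch, not part of the micrograd code):</p>

```python
import math

def numeric_grad(f, x, h=1e-6):
    """Central finite difference approximation of df/dx."""
    return (f(x + h) - f(x - h)) / (2 * h)

x, y = 1.7, 3.0
# Multiplication: d(xy)/dx = y
assert abs(numeric_grad(lambda t: t * y, x) - y) < 1e-4
# Tanh: d tanh(x)/dx = 1 - tanh(x)^2
assert abs(numeric_grad(math.tanh, x) - (1 - math.tanh(x) ** 2)) < 1e-4
# Pow: d(x^y)/dx = y * x^(y-1)
assert abs(numeric_grad(lambda t: t ** y, x) - y * x ** (y - 1)) < 1e-3
```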
<h3 id="chain-rule">Chain rule</h3>
<p>The number of times that I’ve heard backpropagation described as a “recursive application of the chain rule” without the explainer providing intuition about what that actually <em>means</em> makes my head spin. In Andrej’s video where he covers backpropagation, he references Wikipedia’s explanation of the chain rule, which I think is one of the most cogent explanations of a topic that is frequently over-complicated in the context of backpropagation. Specifically, Wikipedia says:</p>
<blockquote>
<p>If a variable z depends on the variable y, which itself depends on the variable x (…), then z depends on x as well, via the intermediate variable y. In this case, the chain rule is expressed as \(\frac{d z}{d x}=\frac{d z}{d y}\frac{d y}{d x}\)</p>
</blockquote>
<p>Effectively, we are able to simply multiply the rates of change if we have \(\frac{d z}{d y}\) and \(\frac{d y}{d x}\) to get \(\frac{d z}{d x}\). Although \(\frac{d z}{d y}\) isn’t really a fraction, an easy way to remember this is that \(d y\) terms still “cancel each other out” as though they were fractions.</p>
<p>The Wikipedia page offers a concrete example, as well. Specifically, it asks us to consider the speed of a human, a bicycle, and a car. Lets say \(h\) represents the speed of a human, \(b=2h\) the speed of the bicycle, and \(c=4b\) the speed of the car. We want to find the rate of change of the car with respect to the human. We can calculate \(\frac{d c}{d b} = 4\) and \(\frac{d b}{d h} = 2\) and by the chain rule, we know that \(\frac{d c}{d h} = \frac{d c}{d b} \frac{d b}{d h} = 4 * 2 = 8\), or that a car is 8 times as fast as a human.</p>
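<p>This example can also be checked numerically; the derivative of the composed function comes out to the product of the two individual rates, just as the chain rule predicts (another throwaway sketch):</p>

```python
def numeric_grad(f, x, h=1e-6):
    """Central finite difference approximation of df/dx."""
    return (f(x + h) - f(x - h)) / (2 * h)

bike = lambda human: 2 * human            # b = 2h
car = lambda bike_speed: 4 * bike_speed   # c = 4b
car_of_human = lambda human: car(bike(human))

dc_db = numeric_grad(car, 10.0)           # ~4
db_dh = numeric_grad(bike, 10.0)          # ~2
dc_dh = numeric_grad(car_of_human, 10.0)  # ~8 = 4 * 2
```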
<p>In the context of backpropagation, for an arbitrary weight, we have a gradient that flows into <em>the current network weight</em> from somewhere further down the computational graph (i.e. it is “backpropagated”). This value represents how all of the downstream activations that the current weight feeds into affect the loss. Since the performance of the network is affected by all the activations downstream of the current weight, we need to look at <em>how it indirectly affects the loss through the nodes that it contributes to downstream</em>. The chain rule gives us a way to compute this value. We must <em>multiply</em> the “global” derivative that has been propagated backwards with the “local” derivative of the current weight. Once we do this, the current node’s derivative becomes the new “global” derivative and we continue to pass the new “global” derivative backwards. This is what is meant by “recursively applying the chain rule”. You don’t have to understand all of this now, and it should become clear with a spelled out example in the next section.</p>
<p><br /></p>
<hr />
<p><br /></p>
<h2 id="gradients-on-a-computational-graph">Gradients on a Computational Graph</h2>
<p>Now that we’ve covered the relevant bits of calculus and computational graphs, we are going to combine them to take derivatives <em>on</em> a computational graph using the chain rule. We will first go over a graph that is one node deep and then we will make it deeper so that we are forced to use the chain rule to propagate gradients.</p>
<h3 id="the-trivial-case">The Trivial Case</h3>
<p>Let’s look at the simplest form of a computational graph that we made above. We want to compute the derivative of \(c\) with respect to \(a\) and \(b\). First, note that the derivative of any function with respect to itself is trivially 1: \(\frac{dc}{dc} = 1\). We can fill this out in our computational graph:</p>
<center>
<p>
<img src="assets/posts/2023-06-14-nn-from-scratch/computational_graph_grad_1_1.png" alt="" />
</p>
</center>
<p>We know that \(c(a, b) = a \cdot b\) and, further that \(\frac{\partial c}{\partial a} = b\). By the same logic, \(\frac{\partial c}{\partial b} = a\). In a computational graph, when backpropagating gradients, <strong>the multiplication operator acts as a magnitude/gradient “swap” that magnifies the current input’s gradient by the other input’s magnitude</strong>. This makes sense - if the current value is a result of a multiplication operation and it grows, the output value grows proportionally to <em>the value that it’s multiplied by</em>. The final graph looks like this:</p>
<center>
<p>
<img src="assets/posts/2023-06-14-nn-from-scratch/computational_graph_grad_1_2.png" alt="" />
</p>
</center>
<p>Note that b took a’s data as its gradient and vice versa. The full operations that are happening here are \(\frac{d c}{d c}\frac{\partial c}{\partial a} = 1.0 * b = 2.0\) and \(\frac{d c}{d c}\frac{\partial c}{\partial b} = 1.0 * a = 1.0\), since we are multiplying the “local” gradients of \(a\) and \(b\) by the downstream gradient, \(\frac{dc}{dc} = 1\). The code to implement the multiplication operation for a computational graph in python is this simple:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">mul</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">other</span><span class="p">):</span>
<span class="s">"""Multiplies two values together and configure result to compute gradient"""</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">other</span><span class="p">,</span> <span class="p">(</span><span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">)):</span>
<span class="n">other</span> <span class="o">=</span> <span class="n">Value</span><span class="p">(</span><span class="n">other</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">backward</span><span class="p">():</span>
<span class="s">"""
Multiply "swaps" magnitudes from inputs and multiplies upstream grad
"""</span>
<span class="bp">self</span><span class="p">.</span><span class="n">grad</span> <span class="o">+=</span> <span class="n">other</span><span class="p">.</span><span class="n">data</span> <span class="o">*</span> <span class="n">result</span><span class="p">.</span><span class="n">grad</span>
<span class="n">other</span><span class="p">.</span><span class="n">grad</span> <span class="o">+=</span> <span class="bp">self</span><span class="p">.</span><span class="n">data</span> <span class="o">*</span> <span class="n">result</span><span class="p">.</span><span class="n">grad</span>
<span class="n">result</span> <span class="o">=</span> <span class="n">Value</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="bp">self</span><span class="p">.</span><span class="n">data</span> <span class="o">*</span> <span class="n">other</span><span class="p">.</span><span class="n">data</span><span class="p">,</span> <span class="n">children</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span><span class="n">other</span><span class="p">),</span> <span class="n">_backward</span><span class="o">=</span><span class="n">backward</span><span class="p">)</span>
<span class="k">return</span> <span class="n">result</span></code></pre></figure>
<p>Without going too much into the definition of the <code class="language-plaintext highlighter-rouge">Value</code> class (it’s a data node in a computational graph), a new value is created (in our case, c). When <code class="language-plaintext highlighter-rouge">c._backward()</code> is called, it does exactly what we just described - the inputs to the multiplication operation (<code class="language-plaintext highlighter-rouge">self</code> and <code class="language-plaintext highlighter-rouge">other</code>) take the other input’s <em>data</em> as a gradient, multiplied by the downstream gradient (<code class="language-plaintext highlighter-rouge">result.grad</code> - in this case, 1.0). The reason we use the <code class="language-plaintext highlighter-rouge">+=</code> operator is because we are <em>accumulating</em> gradients rather than resetting them - a node in the computational graph often feeds into many downstream nodes.</p>
<p>The next computational graph example should help make the gradient flow more concrete.</p>
<h3 id="two-levels-deep---propagating-gradients">Two Levels Deep - Propagating Gradients</h3>
<p>Since we are working with neural networks, let’s look at a common operation: multiplying weights and inputs (\(w_i\) and \(x_i\)) and then taking their sum (we will skip the activation function for now). Similar to the hidden unit that we expressed above, let’s look at two inputs and two weights feeding into a hidden node, \(h1\). We already know that, trivially, \(\frac{d_{h1}}{d_{h1}} = 1\), but to make the gradient flow clearer, let’s pretend that h1 is nested in a larger graph and that it received a gradient value of 2.0 with respect to some downstream variable that we aren’t concerned with.</p>
<center>
<p>
<img src="assets/posts/2023-06-14-nn-from-scratch/two_level_1.png" alt="" />
</p>
</center>
<p>In calculating \(\frac{d_{h1}}{d_{o1}}\) and \(\frac{d_{h1}}{d_{o2}}\), we need to look at the addition derivative function: \(f(o1, o2)=o1+o2\); \(\frac{d_{h1}}{d_{o1}}=1.0\); \(\frac{d_{h1}}{d_{o2}}=1.0\). Each input to the addition operation has a local gradient of 1. This means any downstream gradient is multiplied with magnitude 1.0 into the current node - <strong>the addition operation just passes on the gradient unchanged and effectively acts as a gradient “splitter”</strong>. The code to implement addition is similarly simple to multiplication:</p>
<figure class="highlight"><pre><code class="language-python" data-lang="python"><span class="k">def</span> <span class="nf">add</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">other</span><span class="p">):</span>
<span class="s">"""Add two values together and assign local gradient"""</span>
<span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">other</span><span class="p">,</span> <span class="p">(</span><span class="nb">int</span><span class="p">,</span> <span class="nb">float</span><span class="p">)):</span>
<span class="n">other</span> <span class="o">=</span> <span class="n">Value</span><span class="p">(</span><span class="n">other</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">backward</span><span class="p">():</span>
<span class="s">"""
Add "splits" gradients from result back to current node.
"""</span>
<span class="bp">self</span><span class="p">.</span><span class="n">grad</span> <span class="o">+=</span> <span class="mf">1.0</span> <span class="o">*</span> <span class="n">result</span><span class="p">.</span><span class="n">grad</span>
<span class="n">other</span><span class="p">.</span><span class="n">grad</span> <span class="o">+=</span> <span class="mf">1.0</span> <span class="o">*</span> <span class="n">result</span><span class="p">.</span><span class="n">grad</span>
<span class="n">result</span> <span class="o">=</span> <span class="n">Value</span><span class="p">(</span><span class="bp">self</span><span class="p">.</span><span class="n">data</span> <span class="o">+</span> <span class="n">other</span><span class="p">.</span><span class="n">data</span><span class="p">,</span> <span class="n">children</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span><span class="n">other</span><span class="p">),</span> <span class="n">_backward</span><span class="o">=</span><span class="n">backward</span><span class="p">,</span> <span class="n">_op</span><span class="o">=</span><span class="s">'+'</span><span class="p">)</span>
<span class="k">return</span> <span class="n">result</span></code></pre></figure>
<p>For each input to the addition function (self and other), we assign the gradient for the local nodes as the gradient that we receive from <code class="language-plaintext highlighter-rouge">result</code> times 1.0 - this gives us our “splitter” behavior. Addition’s behavior in the computational graph is very easy to remember.</p>
<p>When we multiply the downstream gradient at \(h1\) (2.0) into the current node (o1 or o2) with magnitude 1.0 (because of the addition operation), we get (1.0 * 2.0) = 2.0. This is the chain rule! We multiplied the magnitude of the downstream “global” gradient by the current node’s “local” gradient. After following these steps, the computational graph now looks as follows:</p>
<center>
<p>
<img src="assets/posts/2023-06-14-nn-from-scratch/two_level_2.png" alt="" />
</p>
</center>
<p>Continuing backpropagation into the next layer, next we have multiplication operations, which we are already familiar with. To get the gradients for our inputs, we can apply the chain rule to the intermediate gradients that we’ve just computed, and then we simply multiply. Let’s look at the gradient for \(\frac{d_{h1}}{d_{w1}}\). If we have \(\frac{d_{h1}}{d_{o1}}\) and \(\frac{d_{o1}}{d_{w1}}\), we have everything we need: \(\frac{d_{h1}}{d_{w1}}=\frac{d_{h1}}{d_{o1}}\frac{d_{o1}}{d_{w1}}\). We just need to calculate the local derivative, \(\frac{d_{o1}}{d_{w1}}\). We know that \(o1 = w1 \cdot x1\), and therefore, that \(\frac{d o1}{d w1} = x1\). So we get \(\frac{d_{h1}}{d_{w1}}=\frac{d_{h1}}{d_{o1}}\frac{d_{o1}}{d_{w1}}= \frac{d_{h1}}{d_{o1}} \cdot x1.data = 2.0 \cdot 2.0 =4.0\). This means that when we increase \(w1\) by 1.0, the node that we are calculating the gradient with respect to (the one that we aren’t concerned with, somewhere downstream beyond h1) increases by 4.0 units. We can apply the same logic to the other inputs to calculate all of the gradients:</p>
<center>
<p>
<img src="assets/posts/2023-06-14-nn-from-scratch/two_level_3.png" alt="" />
</p>
</center>
<p>There is a lot of derivative notation here, but it really just says that you multiply the “local gradient” by the gradient that flows back to it from further down the graph.</p>
<p>Here’s the magic of it: <strong>as long as you can calculate the derivative of a function, you can include it in the computational graph and compute the gradient for inputs to that function</strong>. Every common neural network operation is differentiable for this reason! A simple neural network forward pass might consist of multiplication -&gt; addition -&gt; (differentiable) non-linearity -&gt; multiplication -&gt; addition -&gt; loss function.</p>
<p>The reason this is called reverse-mode autodiff is because we have to compute a forward pass to cache the “data” values, and then, we start from the end and use the cached forward data values in our derivative calculations while we work backwards. There is also a form of autodiff called forward-mode autodiff, but it makes more sense when there are more outputs than inputs, which is usually not the case with neural networks.</p>
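<p>Putting the pieces together, the full reverse pass is just a topological sort of the graph followed by calling each node’s <code class="language-plaintext highlighter-rouge">_backward</code> from the output backwards. Here is a self-contained sketch with a stripped-down stand-in for the <code class="language-plaintext highlighter-rouge">Value</code> class (only multiplication, to keep it short - the real class has more operations):</p>

```python
class Value:
    """Stripped-down stand-in for the Value node described above (sketch)."""
    def __init__(self, data, children=()):
        self.data, self.grad = data, 0.0
        self.children = children
        self._backward = lambda: None

    def __mul__(self, other):
        result = Value(self.data * other.data, (self, other))
        def backward():
            # Multiply "swaps" magnitudes, scaled by the downstream gradient
            self.grad += other.data * result.grad
            other.grad += self.data * result.grad
        result._backward = backward
        return result

def backward(output):
    """Full reverse pass: topologically sort, then propagate gradients back."""
    order, visited = [], set()
    def build(node):
        if node not in visited:
            visited.add(node)
            for child in node.children:
                build(child)
            order.append(node)
    build(output)
    output.grad = 1.0  # d(output)/d(output) = 1
    for node in reversed(order):
        node._backward()
```

<p>Running <code class="language-plaintext highlighter-rouge">backward(c)</code> on the a=1, b=2 example from earlier gives a.grad=2.0 and b.grad=1.0, matching the trivial case above.</p>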
<p>This is all you need to understand backpropagation. While I’m glossing over the many network weights that feed into an activation in a hidden layer and the activation function, you can imagine it’s more of the same as what we’ve already done with different derivatives<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">1</a></sup>. With the gradients of the weights, we’re ready to do something powerful. What if we could tweak the weights so that the output goes up or down? This is where gradient descent comes in.</p>
<p><br /></p>
<hr />
<p><br /></p>
<h2 id="the-other-parts-of-training-a-neural-network">The Other Parts of Training a Neural Network</h2>
<p>In addition to taking derivatives on a computational graph, we need something that helps us determine how good our predictions are and a way to tweak our network so that the predictions get closer to “good”. This section discusses those two aspects of training a neural network.</p>
<h3 id="loss-functions">Loss Functions</h3>
<p>The ultimate goal of a neural network is to make accurate predictions. In order to assess how well the network makes predictions, we need to consider a <strong>loss function</strong>, which, given a prediction and a <em>true label</em>, scores how close the network was to being correct. If the network is performing poorly, the loss will be high, and if the network is performing well, the loss will be low. Usually, the final output of the network will be this loss function, and the loss function is almost always differentiable. If you are doing regression or classification, you will always have a loss function as the final node in your neural network’s computational graph.</p>
<p>As an example, let’s look at mean-squared error (MSE) - a common loss function used in linear regression:</p>
\[MSE=\frac{1}{n} \sum_{i=1}^n{(Y_i - \hat{Y_i})^2}\]
<p>\(n\) is the number of training examples, \(Y_i\) is the true label and \(\hat{Y_i}\) is the predicted label output by our network. Let’s examine a few simple cases where \(n=1\). When we predict a label of 1.0, but the true label is 0.0, we get a loss of 1.0. Similarly for when we predict a label of 0.0 and the true label is 1.0. As the prediction gets closer to the true label, the mean-squared error shrinks, so driving the loss down drives our predictions toward the true label. Since a neural network ending with MSE is composed entirely of differentiable functions, we can calculate the derivatives of weights at any depth with respect to the loss function in order to change the weights to drive the loss downward. When we calculate gradients with respect to the loss, we are figuring out how we can make our predictions (\(\hat{Y_i}\)) closer to the true function (\(Y_i\))!</p>
<p>Note that to compute MSE, we only need to compute the derivatives for \(f(x)=x ^ 2\) and \(f(x, y) = x - y\). We could calculate a derivative for MSE and make it its own function, or we could make the computational graph a bit deeper and perform those operations separately. <strong>In an autograd engine, it doesn’t matter the level of granularity that you calculate derivatives for - as long as the function is differentiable, we can calculate \(f(x, y) = (x - y) ^ 2\) or separately calculate \(z=x-y\) and then feed it into \(f(z)=z^2\) afterwards</strong>.</p>
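<p>To make the granularity point concrete, here is MSE composed out of those two smaller operations in plain Python (just the arithmetic, no autograd machinery):</p>

```python
def sub(x, y):      # f(x, y) = x - y
    return x - y

def square(z):      # f(z) = z^2
    return z ** 2

def mse(y_true, y_pred):
    """MSE composed from the two smaller differentiable ops above."""
    n = len(y_true)
    return sum(square(sub(y, y_hat)) for y, y_hat in zip(y_true, y_pred)) / n
```

<p>The \(n=1\) cases above check out: <code class="language-plaintext highlighter-rouge">mse([0.0], [1.0])</code> and <code class="language-plaintext highlighter-rouge">mse([1.0], [0.0])</code> both evaluate to 1.0.</p>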
<h3 id="gradient-descent">Gradient Descent</h3>
<p>Once we have calculated the loss, we backpropagate from the end of the graph to calculate gradients for all nodes in the graph. Since the gradients point in the direction of <em>increasing</em> loss, we need to tweak the values of the weights in the <em>opposite</em> direction of the gradient, since we are trying to drive the loss towards 0. The naive gradient descent algorithm is extremely simple: once we have the gradients for the weights, move each weight a small step in the opposite direction of its gradient. The constant that scales this step (e.g. 0.01) is called the <strong>learning rate</strong>, and it is an important parameter when tuning neural networks. In our previous example, the output of a step of gradient descent for the single training example (consisting of (x1, x2)) would look as follows:</p>
<center>
<p>
<img src="assets/posts/2023-06-14-nn-from-scratch/gradient_descent_one_step.png" alt="" />
</p>
</center>
<p>In this case, we used the learning rate \(\alpha=0.01\). W2 was twice as affected as W1 by the gradient descent step because its gradient is twice as large. We have also completed a new forward pass and can observe that the new output of \(h1\) is lower than it previously was! If \(h1\) were a loss function, we would have made the network’s predictions closer to the true function we are trying to model.</p>
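<p>The naive update rule is short enough to sketch in full. The weight and gradient values here are made up for illustration; note how the weight with the 2x larger gradient moves twice as far:</p>

```python
def gradient_descent_step(weights, grads, alpha=0.01):
    # Move each weight opposite its gradient, scaled by the learning rate.
    return [w - alpha * g for w, g in zip(weights, grads)]

# Hypothetical values: the second gradient is twice the first,
# so the second weight takes a step twice as large.
weights = [0.5, -0.3]
grads = [1.0, 2.0]
print(gradient_descent_step(weights, grads))
```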
<p>This is a simplified example, but it’s not that much different from what neural network libraries like PyTorch and TensorFlow do in practice. They’re just much better at parallelizing gradient computations and batching your data to do fancier things like mini-batch gradient descent. There is a fair amount of neural network training we didn’t cover here (different loss functions, other gradient-based optimization algorithms, implementing a full network), but none of it is essential for understanding how a neural network is trained at a low level. Now, I’ll show you some neural networks I trained from scratch using my autodiff library.</p>
<p><br /></p>
<hr />
<p><br /></p>
<h2 id="training-a-network">Training a Network!</h2>
<p>I debated whether or not to include this, since it’s not immediately related to backpropagation, but I figured it would be cool to show the reward for implementing an autodiff engine. Below is the evolving decision boundary of a neural network hand-coded in Python with no external dependencies. It’s orders of magnitude slower than PyTorch would be, but it does (more or less) the same thing without the massive parallelization and numerical-stability precautions.</p>
<p>The function that I am trying to model is whether or not a point lies within a circle of radius 1.0 centered on (0, 0). I generated approximately equal amounts of positive and negative training data. The non-linearity I used was tanh and the loss function was mean-squared error. Here are networks trained with 1, 2, 3, and 9 hidden neurons - the points plotted from each class are from the training set:</p>
<center>
<p>
<img src="assets/posts/2023-06-14-nn-from-scratch/1_neurons.gif" width="350" alt="" />
<img src="assets/posts/2023-06-14-nn-from-scratch/2_neurons.gif" width="350" alt="" />
<img src="assets/posts/2023-06-14-nn-from-scratch/3_neurons.gif" width="350" alt="" />
<img src="assets/posts/2023-06-14-nn-from-scratch/9_neurons.gif" width="350" alt="" />
</p>
</center>
<p>You can see that one neuron can only roughly approximate a linear decision boundary, two extends that to a parabola, three has enough representative power to model a crude circle, and nine can model a fairly precise circle. And this was all done with Python’s built-in libraries! Of course, no blog post about training a neural network would be complete without the loss and accuracy per epoch for the 9-hidden-neuron network - the standard plots for spotting bottlenecks in training:</p>
<center>
<p>
<img src="assets/posts/2023-06-14-nn-from-scratch/avg_losses_by_epoch.png" width="350" alt="" />
<img src="assets/posts/2023-06-14-nn-from-scratch/accuracy_by_epoch.png" width="350" alt="" />
</p>
</center>
<h2 id="conclusion">Conclusion</h2>
<p>When explaining backpropagation, it’s easier to consider a computational graph separately from a neural network than to walk through each individual weight in excruciating detail. Once you see how the “local derivatives” and the “global derivatives” work together and how the chain rule is applied recursively on a small scale, it’s not a big step to go from our example to a full feedforward network. This was long, but I hope it gave you a good framework for approaching and debugging neural network training. Thanks for reading!</p>
<p><br /></p>
<hr />
<p><br /></p>
<p>If you have anything to add or see anything that you believe may be incorrect, please contact me at will@willbeckman.com.</p>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:1" role="doc-endnote">
<p>If you’re feeling ambitious, try to compute the derivatives for a two layer neural network where there is a <code class="language-plaintext highlighter-rouge">tanh</code> non-linearity and a mean-squared error loss function using the derivatives defined <a href="#derivatives-for-a-simple-feedforward-network">above</a>. Doing this by hand will really hammer in how the chain rule is used. <a href="#fnref:1" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>