<?xml version="1.0" encoding="UTF-8"?><rss xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:content="http://purl.org/rss/1.0/modules/content/" xmlns:atom="http://www.w3.org/2005/Atom" version="2.0" xmlns:media="http://search.yahoo.com/mrss/"><channel><title><![CDATA[kevin frans blog]]></title><description><![CDATA[kevin frans blog]]></description><link>https://kvfrans.com/</link><image><url>https://kvfrans.com/favicon.png</url><title>kevin frans blog</title><link>https://kvfrans.com/</link></image><generator>Ghost 3.41</generator><lastBuildDate>Sun, 19 Apr 2026 11:44:03 GMT</lastBuildDate><atom:link href="https://kvfrans.com/rss/" rel="self" type="application/rss+xml"/><ttl>60</ttl><item><title><![CDATA[Alchemist's Notes on Deep Learning]]></title><description><![CDATA[<p>I recently compiled a more cleaned-up set of explanations and implementations on topics relating to deep learning and generative modelling. Check them out at <a href="https://notes.kvfrans.com/">https://notes.kvfrans.com/</a></p>]]></description><link>https://kvfrans.com/alchemists-notes-on-deep-learning/</link><guid isPermaLink="false">68373c8336185f7b9686fc30</guid><dc:creator><![CDATA[Kevin Frans]]></dc:creator><pubDate>Wed, 28 May 2025 16:41:19 GMT</pubDate><content:encoded><![CDATA[<p>I recently compiled a more cleaned-up set of explanations and implementations on topics relating to deep learning and generative modelling. Check them out at <a href="https://notes.kvfrans.com/">https://notes.kvfrans.com/</a></p>]]></content:encoded></item><item><title><![CDATA[Small-Research: Tanh Activations with DDPG]]></title><description><![CDATA[<p>When implementing DDPG-style policy extraction, we often use a tanh normalizer to bound the action space. That way the policy does not attempt to output something OOD when the Q-function is only trained on actions in [-1, 1].</p><pre><code class="language-python">dist = agent.actor(batch['observations'])
normalized_actions = jnp.tanh(dist.loc)
q = agent.critic(batch['observations'], normalized_actions)
q_loss = -q.mean()</code></pre>]]></description><link>https://kvfrans.com/small-research-tanh-activations-with-ddpg/</link><guid isPermaLink="false">65814ba5feb1cd031e03cf12</guid><category><![CDATA[small-research]]></category><dc:creator><![CDATA[Kevin Frans]]></dc:creator><pubDate>Tue, 19 Dec 2023 08:00:24 GMT</pubDate><content:encoded><![CDATA[<p>When implementing DDPG-style policy extraction, we often use a tanh normalizer to bound the action space. That way the policy does not attempt to output something OOD when the Q-function is only trained on actions in [-1, 1].</p><pre><code class="language-python">dist = agent.actor(batch['observations'])
normalized_actions = jnp.tanh(dist.loc)
q = agent.critic(batch['observations'], normalized_actions)
q_loss = -q.mean()</code></pre><p>When I was writing code to execute the policy, I had forgotten to include this tanh shaping. So, it was using the <em>raw</em> actions, but clipped between [-1, 1]. Somehow, the performance was still basically at SOTA. </p><pre><code class="language-python">actions = agent.actor(observations).sample(seed=seed)
actions = jnp.tanh(actions) # We don't actually need this line.
actions = jnp.clip(actions, -1, 1)</code></pre><p><strong>Question: </strong>Is it better to not apply tanh during evaluation? One hypothesis is that the policy can't actually output -1 or 1 using tanh, because the pre-activation has to approach infinity. </p><p>I ran experiments comparing the correct version (tanh applied during eval) with my clipped version (instead of tanh, clip the actions):</p><figure class="kg-card kg-gallery-card kg-width-wide"><div class="kg-gallery-container"><div class="kg-gallery-row"><div class="kg-gallery-image"><img src="https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-11.59.07-PM.png" width="1958" height="610" alt srcset="https://kvfrans.com/content/images/size/w600/2023/12/Screen-Shot-2023-12-18-at-11.59.07-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2023/12/Screen-Shot-2023-12-18-at-11.59.07-PM.png 1000w, https://kvfrans.com/content/images/size/w1600/2023/12/Screen-Shot-2023-12-18-at-11.59.07-PM.png 1600w, https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-11.59.07-PM.png 1958w" sizes="(min-width: 720px) 720px"></div><div class="kg-gallery-image"><img src="https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-11.59.02-PM.png" width="1958" height="610" alt srcset="https://kvfrans.com/content/images/size/w600/2023/12/Screen-Shot-2023-12-18-at-11.59.02-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2023/12/Screen-Shot-2023-12-18-at-11.59.02-PM.png 1000w, https://kvfrans.com/content/images/size/w1600/2023/12/Screen-Shot-2023-12-18-at-11.59.02-PM.png 1600w, https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-11.59.02-PM.png 1958w" sizes="(min-width: 720px) 720px"></div></div></div></figure><p><strong>Conclusion: </strong>It doesn't actually matter. Both perform the same. 
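One way to see why the two evaluation schemes coincide: once the pre-activation is even moderately large, tanh and clip map it to nearly the same bounded action. A quick self-contained check (plain jax, with illustrative values):

```python
import jax.numpy as jnp

# Hypothetical raw policy outputs (pre-activations) of varying magnitude.
raw = jnp.array([-5.0, -2.0, -0.5, 0.5, 2.0, 5.0])

tanh_actions = jnp.tanh(raw)
clip_actions = jnp.clip(raw, -1, 1)

# Both land in [-1, 1]; for large |raw| (bang-bang actions), the two
# bounding schemes are nearly identical.
print(jnp.abs(tanh_actions - clip_actions))
```

This matches the bang-bang intuition: when the policy wants saturated actions, the pre-activations are large, and the choice of squashing function stops mattering.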
DMC environments are known to be solvable with bang-bang control, so the policy is likely outputting actions close to -1 or 1, regardless of the tanh activation.</p>]]></content:encoded></item><item><title><![CDATA[Small-Research: Policy Extraction in IQL]]></title><description><![CDATA[<p><strong>Insight</strong>: IQL-AWR agents on 'maze2d' converge to a standard deviation of 0.4, which is much higher than expected.<br><strong>Hypothesis</strong>: The KL constraint of AWR prevents it from learning Gaussian policies that are centered around the optimal action.<br><strong>Conclusion</strong>: In data-heavy settings, use DDPG or discretized AWR for policy extraction.</p>]]></description><link>https://kvfrans.com/small-research-policy-extraction-in-on-iql/</link><guid isPermaLink="false">6580bbb0feb1cd031e03ce31</guid><category><![CDATA[small-research]]></category><dc:creator><![CDATA[Kevin Frans]]></dc:creator><pubDate>Tue, 19 Dec 2023 07:35:17 GMT</pubDate><content:encoded><![CDATA[<p><strong>Insight</strong>: IQL-AWR agents on 'maze2d' converge to a standard deviation of 0.4, which is much higher than expected.<br><strong>Hypothesis</strong>: The KL constraint of AWR prevents it from learning Gaussian policies that are centered around the optimal action.<br><strong>Conclusion</strong>: In data-heavy settings, use DDPG or discretized AWR for policy extraction. For sparse data, use AWR because DDPG will overfit.</p><p>When doing offline RL with IQL, we learn an implicit Q-function without learning a policy. The policy must be extracted later using the information from Q(s,a). In the original IQL paper, advantage-weighted regression (AWR) is used for policy extraction. 
The AWR objective is a weighted behavior cloning loss with the form:</p><!--kg-card-begin: html--><p><span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>m</mi><mi>a</mi><mi>x</mi><mtext>  </mtext><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><mi>π</mi><mo stretchy="false">(</mo><mi>a</mi><mi mathvariant="normal">∣</mi><mi>s</mi><mo stretchy="false">)</mo><mo>∗</mo><mi>e</mi><mi>x</mi><mi>p</mi><mo stretchy="false">(</mo><mi>λ</mi><mo>∗</mo><mi>A</mi><mo stretchy="false">(</mo><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo stretchy="false">)</mo><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">max \; log \; \pi(a|s) * exp(\lambda * A(s,a))</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">m</span><span class="mord mathnormal">a</span><span class="mord mathnormal">x</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 0.03588em;">π</span><span class="mopen">(</span><span class="mord mathnormal">a</span><span class="mord">∣</span><span class="mord mathnormal">s</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">∗</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">e</span><span class="mord mathnormal">x</span><span class="mord mathnormal">p</span><span class="mopen">(</span><span 
class="mord mathnormal">λ</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">∗</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">A</span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mclose">)</span><span class="mclose">)</span></span></span></span></span></p><!--kg-card-end: html--><p>The AWR objective can be derived as the solution to a constrained optimization problem:</p><!--kg-card-begin: html--><ol>
<li>Maximize <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>E</mi><mo stretchy="false">[</mo><mi>π</mi><mo stretchy="false">(</mo><mi>a</mi><mi mathvariant="normal">∣</mi><mi>s</mi><mo stretchy="false">)</mo><mo>∗</mo><mi>A</mi><mo stretchy="false">(</mo><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo stretchy="false">)</mo><mo stretchy="false">]</mo></mrow><annotation encoding="application/x-tex">E[ \pi(a|s) * A(s,a) ]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.05764em;">E</span><span class="mopen">[</span><span class="mord mathnormal" style="margin-right: 0.03588em;">π</span><span class="mopen">(</span><span class="mord mathnormal">a</span><span class="mord">∣</span><span class="mord mathnormal">s</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">∗</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">A</span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mclose">)</span><span class="mclose">]</span></span></span></span></span></li>
<li>st. <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>K</mi><mi>L</mi><mo stretchy="false">(</mo><msub><mi>π</mi><mrow><mi>d</mi><mi>a</mi><mi>t</mi><mi>a</mi></mrow></msub><mi mathvariant="normal">∣</mi><mi>π</mi><mo stretchy="false">)</mo><mo>&lt;</mo><mi>ϵ</mi></mrow><annotation encoding="application/x-tex">KL(\pi_{data} | \pi) &lt; \epsilon</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.07153em;">K</span><span class="mord mathnormal">L</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">π</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: -0.03588em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">d</span><span class="mord mathnormal mtight">a</span><span class="mord mathnormal mtight">t</span><span class="mord mathnormal mtight">a</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord mathnormal" style="margin-right: 0.03588em;">π</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">&lt;</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal">ϵ</span></span></span></span></span></li>
</ol><!--kg-card-end: html--><!--kg-card-begin: html--><p>The <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>λ</mi></mrow><annotation encoding="application/x-tex">\lambda</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.69444em; vertical-align: 0em;"></span><span class="mord mathnormal">λ</span></span></span></span></span> term is a temperature parameter that controls how tight the constraint is.</p><!--kg-card-end: html--><p>A nice property of AWR is that it avoids querying out-of-data actions. We only weight state-action pairs inside the dataset, instead of searching over the $A(s,a)$ space, which may have inaccurate values.</p><p>But a downside of AWR is that we assign probability mass to each action in the dataset. Since we usually use a Gaussian policy, this leads to bad behavior.</p><p>As an example, let's look at maze-2d. Here the actions are (x,y) velocities to move an agent to a goal. It's not complicated, and intuitively the best actions should be along the edge of the unit circle. 
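In code, the objective above is a small change to the DDPG-style snippet from before: instead of optimizing through the critic, we weight a log-likelihood loss by exponentiated advantages. A minimal sketch, assuming (hypothetically) that `agent.actor` returns a distribution with a `log_prob` method and that advantages are precomputed into the batch; the weight cap is a common stabilization trick, not necessarily what was used here:

```python
import jax.numpy as jnp

def awr_loss(agent, batch, temperature=1.0):
    # Weighted behavior cloning: maximize log pi(a|s) * exp(lambda * A(s,a)).
    dist = agent.actor(batch['observations'])
    log_probs = dist.log_prob(batch['actions'])
    weights = jnp.exp(temperature * batch['advantages'])
    # Cap the weights to avoid numerical blowup at high temperatures.
    weights = jnp.minimum(weights, 100.0)
    return -(log_probs * weights).mean()
```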
But, the policy doesn't behave that way.</p><figure class="kg-card kg-gallery-card kg-width-wide kg-card-hascaption"><div class="kg-gallery-container"><div class="kg-gallery-row"><div class="kg-gallery-image"><img src="https://kvfrans.com/content/images/2023/12/media_images_actions_data_350000_3924a9c5a61e4354233d.png" width="640" height="480" alt srcset="https://kvfrans.com/content/images/size/w600/2023/12/media_images_actions_data_350000_3924a9c5a61e4354233d.png 600w, https://kvfrans.com/content/images/2023/12/media_images_actions_data_350000_3924a9c5a61e4354233d.png 640w"></div><div class="kg-gallery-image"><img src="https://kvfrans.com/content/images/2023/12/media_images_actions_traj_450000_32959380eca73121c010.png" width="640" height="480" alt srcset="https://kvfrans.com/content/images/size/w600/2023/12/media_images_actions_traj_450000_32959380eca73121c010.png 600w, https://kvfrans.com/content/images/2023/12/media_images_actions_traj_450000_32959380eca73121c010.png 640w"></div></div></div><figcaption><strong>Left</strong>: Data actions. <strong>Right</strong>: Actions produced by IQL-AWR policy.</figcaption></figure><p>If we look at the standard deviations of the Gaussian heads, they are not at zero. They are at 0.4, which is quite a bit of variance.</p><figure class="kg-card kg-image-card"><img src="https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-2.57.01-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2023/12/Screen-Shot-2023-12-18-at-2.57.01-PM.png 600w, https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-2.57.01-PM.png 740w" sizes="(min-width: 720px) 720px"></figure><p>Remember that Gaussian heads have a log-probability that is equivalent to L2 loss scaled by variance. They suffer from the issue of mean-matching rather than mode-matching. 
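This mean-matching failure is easy to reproduce in isolation. Fit a single Gaussian by maximum likelihood to one action dimension whose data sits at -1 and +1, and the fit centers on 0, an action that never appears in the data:

```python
import jax.numpy as jnp

# Bimodal 1-D action data: half the dataset at -1, half at +1.
actions = jnp.concatenate([-jnp.ones(500), jnp.ones(500)])

# The maximum-likelihood Gaussian fit is just the empirical mean and std.
mu = actions.mean()     # 0.0: the Gaussian's most likely action
sigma = actions.std()   # 1.0: a large spread, to cover both modes
```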
With AWR, we find a 'compromise' between behavior cloning and advantage-maximization, and this compromise might not make sense.</p><figure class="kg-card kg-image-card kg-card-hascaption"><img src="https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-3.01.46-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2023/12/Screen-Shot-2023-12-18-at-3.01.46-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2023/12/Screen-Shot-2023-12-18-at-3.01.46-PM.png 1000w, https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-3.01.46-PM.png 1244w" sizes="(min-width: 720px) 720px"><figcaption>Gaussians cannot capture multi-modal distributions: the most likely point under the fitted Gaussian is the mean, not a mode of the data.</figcaption></figure><p>Let's try some fixes. What if we extract the policy via DDPG? DDPG-style extraction optimizes <em>through the Q-function</em> to find the best action at each state. Importantly, DDPG ignores the shape of the Q-function landscape, and only cares about where the maximum is.</p><p>Another path is to discretize the action space. If we bin the action space into tokens, we can handle multimodality easily. So we don't run into the mean-matching problem. 
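A sketch of what the discretized variant could look like, using a hypothetical categorical actor head with logits over `num_bins` tokens per action dimension (the bin count and names here are illustrative, not the exact setup used in these experiments):

```python
import jax
import jax.numpy as jnp

def discrete_awr_loss(logits, actions, advantages, temperature=1.0, num_bins=64):
    # Map continuous actions in [-1, 1] onto discrete bin indices.
    bins = jnp.clip(((actions + 1) / 2 * num_bins).astype(jnp.int32), 0, num_bins - 1)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Weighted cross-entropy: the same AWR weighting, but a categorical head
    # can place mass on several bins at once, so multimodality comes for free.
    chosen = jnp.take_along_axis(log_probs, bins[..., None], axis=-1)[..., 0]
    weights = jnp.exp(temperature * advantages)
    return -(chosen * weights).mean()
```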
The token with highest advantage will also have the highest probability.</p><figure class="kg-card kg-gallery-card kg-width-wide kg-card-hascaption"><div class="kg-gallery-container"><div class="kg-gallery-row"><div class="kg-gallery-image"><img src="https://kvfrans.com/content/images/2023/12/media_images_actions_data_350000_3924a9c5a61e4354233d-2.png" width="640" height="480" alt srcset="https://kvfrans.com/content/images/size/w600/2023/12/media_images_actions_data_350000_3924a9c5a61e4354233d-2.png 600w, https://kvfrans.com/content/images/2023/12/media_images_actions_data_350000_3924a9c5a61e4354233d-2.png 640w"></div><div class="kg-gallery-image"><img src="https://kvfrans.com/content/images/2023/12/media_images_actions_traj_500000_fb648f73800ab683eda7.png" width="640" height="480" alt srcset="https://kvfrans.com/content/images/size/w600/2023/12/media_images_actions_traj_500000_fb648f73800ab683eda7.png 600w, https://kvfrans.com/content/images/2023/12/media_images_actions_traj_500000_fb648f73800ab683eda7.png 640w"></div><div class="kg-gallery-image"><img src="https://kvfrans.com/content/images/2023/12/media_images_actions_traj_500000_772898813d71ebfcddb6.png" width="640" height="480" alt srcset="https://kvfrans.com/content/images/size/w600/2023/12/media_images_actions_traj_500000_772898813d71ebfcddb6.png 600w, https://kvfrans.com/content/images/2023/12/media_images_actions_traj_500000_772898813d71ebfcddb6.png 640w"></div></div></div><figcaption><strong>Left</strong>: Data actions. <strong>Middle</strong>: DDPG actions. <strong>Right</strong>: AWR-Discrete actions.</figcaption></figure><p>That looks much more reasonable. The actions are all on the unit circle now, which is consistent with the data. DDPG picks only the highest-advantage actions, whereas the AWR policy still assigns some probability mass to all data actions, but does so in a multimodal way that preserves the correct locations.</p><p><strong>Question: </strong>What about the temperature parameter? 
As lambda goes to infinity, the KL constraint effectively vanishes and AWR approaches the same objective as DDPG. The potential issues are that 1) we run into numerical problems when exponentiating large numbers, and 2) the fraction of actions with non-vanishing gradients shrinks.</p><p>One way to think about it is that AWR acts like a filter. If the action has enough advantage, increase its probability. Otherwise do nothing. DDPG acts like an optimization problem. For every state, find the best action, and memorize this mapping. AWR has the benefit that it only trains on state-action pairs <em>in the data</em>, which is good for offline RL, but it is also less efficient.</p><figure class="kg-card kg-gallery-card kg-width-wide kg-card-hascaption"><div class="kg-gallery-container"><div class="kg-gallery-row"><div class="kg-gallery-image"><img src="https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-6.09.39-PM-1.png" width="1300" height="498" alt srcset="https://kvfrans.com/content/images/size/w600/2023/12/Screen-Shot-2023-12-18-at-6.09.39-PM-1.png 600w, https://kvfrans.com/content/images/size/w1000/2023/12/Screen-Shot-2023-12-18-at-6.09.39-PM-1.png 1000w, https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-6.09.39-PM-1.png 1300w" sizes="(min-width: 720px) 720px"></div><div class="kg-gallery-image"><img src="https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-6.10.21-PM.png" width="646" height="322" alt srcset="https://kvfrans.com/content/images/size/w600/2023/12/Screen-Shot-2023-12-18-at-6.10.21-PM.png 600w, https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-6.10.21-PM.png 646w"></div></div></div><figcaption>Increasing the temperature helps, but not as much as DDPG. We still have high standard deviation.</figcaption></figure><p><strong>Conclusion: </strong>In maze2d, don't use AWR for policy extraction with a Gaussian head. 
Use discretization, or use DDPG.</p><figure class="kg-card kg-image-card kg-card-hascaption"><img src="https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-3.04.22-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2023/12/Screen-Shot-2023-12-18-at-3.04.22-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2023/12/Screen-Shot-2023-12-18-at-3.04.22-PM.png 1000w, https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-3.04.22-PM.png 1570w" sizes="(min-width: 720px) 720px"><figcaption><strong>DDPG does the best on maze2d.</strong> Highest success rate, and also shortest time until completion.</figcaption></figure><!--kg-card-begin: html--><hr class="gridhr"><br><!--kg-card-end: html--><p>How about other environments? I tried the same tactics on other envs, and here's what we get. For the full AntMaze, and on D4RL half-cheetah, the AWR style extraction actually performs the best.</p><figure class="kg-card kg-gallery-card kg-width-wide"><div class="kg-gallery-container"><div class="kg-gallery-row"><div class="kg-gallery-image"><img src="https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-6.13.31-PM.png" width="632" height="456" alt srcset="https://kvfrans.com/content/images/size/w600/2023/12/Screen-Shot-2023-12-18-at-6.13.31-PM.png 600w, https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-6.13.31-PM.png 632w"></div><div class="kg-gallery-image"><img src="https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-6.12.43-PM.png" width="646" height="442" alt srcset="https://kvfrans.com/content/images/size/w600/2023/12/Screen-Shot-2023-12-18-at-6.12.43-PM.png 600w, https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-6.12.43-PM.png 646w"></div></div></div></figure><p>In both of these cases, the standard deviation via AWR is much lower. Both of these offline datasets are collected from expert data. 
So the data is naturally unimodal, which may be why AWR does not have an issue. If we look at the standard deviations, they are much lower than in the maze2d experiments. All of these envs use an action space of (-1, 1).</p><figure class="kg-card kg-gallery-card kg-width-wide"><div class="kg-gallery-container"><div class="kg-gallery-row"><div class="kg-gallery-image"><img src="https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-6.14.05-PM.png" width="632" height="446" alt srcset="https://kvfrans.com/content/images/size/w600/2023/12/Screen-Shot-2023-12-18-at-6.14.05-PM.png 600w, https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-6.14.05-PM.png 632w"></div><div class="kg-gallery-image"><img src="https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-6.12.55-PM.png" width="646" height="442" alt srcset="https://kvfrans.com/content/images/size/w600/2023/12/Screen-Shot-2023-12-18-at-6.12.55-PM.png 600w, https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-6.12.55-PM.png 646w"></div></div></div></figure><p>Why does DDPG do so badly? I am not sure. If we look at the average IQL Q-values, along with the average Q-values seen during the policy optimization step, they look equal for both envs. Then again, what we really should look at are the <em>advantages</em>, which might be small in comparison to the value at each state. Likely, DDPG extraction is exploiting some errors in the Q-value network. 
Both AntMaze and half-cheetah do not provide dense coverage of the action space.</p><figure class="kg-card kg-gallery-card kg-width-wide"><div class="kg-gallery-container"><div class="kg-gallery-row"><div class="kg-gallery-image"><img src="https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-6.14.14-PM-1.png" width="1304" height="500" alt srcset="https://kvfrans.com/content/images/size/w600/2023/12/Screen-Shot-2023-12-18-at-6.14.14-PM-1.png 600w, https://kvfrans.com/content/images/size/w1000/2023/12/Screen-Shot-2023-12-18-at-6.14.14-PM-1.png 1000w, https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-6.14.14-PM-1.png 1304w" sizes="(min-width: 720px) 720px"></div><div class="kg-gallery-image"><img src="https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-6.13.11-PM.png" width="1282" height="468" alt srcset="https://kvfrans.com/content/images/size/w600/2023/12/Screen-Shot-2023-12-18-at-6.13.11-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2023/12/Screen-Shot-2023-12-18-at-6.13.11-PM.png 1000w, https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-6.13.11-PM.png 1282w" sizes="(min-width: 720px) 720px"></div></div></div></figure><!--kg-card-begin: html--><hr class="gridhr"><br><!--kg-card-end: html--><p>If we look at ExoRL cheetah-run, which does have dense coverage of the action space, we get a different story. </p><figure class="kg-card kg-image-card"><img src="https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-10.27.42-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2023/12/Screen-Shot-2023-12-18-at-10.27.42-PM.png 600w, https://kvfrans.com/content/images/2023/12/Screen-Shot-2023-12-18-at-10.27.42-PM.png 664w"></figure><p>DDPG is better again. 
So, it seems like there are two categories here.</p><ul><li>If the dataset has dense coverage, DDPG will solve the multimodality issues of AWR and perform better.</li><li>But without dense coverage, DDPG will overfit to the Q-function errors and learn a degenerate policy.</li></ul><p><strong>Results: </strong>On maze2d and ExoRL, DDPG performs decently better. On AntMaze and D4RL, AWR works and DDPG gets zero reward.</p><p>There's some clear room for improvement here – if the DDPG is prevented from overfitting on the sparse-data environments, likely it can outperform the AWR policy performance.</p>]]></content:encoded></item><item><title><![CDATA[Successor Representations Explained]]></title><description><![CDATA[<!--kg-card-begin: html--><p>Note: We're going to conflate terminology for value-functions and Q-functions. In general the Q-function <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>Q</mi><mo stretchy="false">(</mo><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">Q(s,a)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">Q</span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mclose">)</span></span></span></span></span> is just an action-conditioned version of the value function <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>V</mi><mo stretchy="false">(</mo><mi>s</mi><mo stretchy="false">)</mo></mrow><annotation 
encoding="application/x-tex">V(s)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.22222em;">V</span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mclose">)</span></span></span></span></span>.</p>
<h4 id="successor-representations">Successor Representations</h4>
<p>Let's consider a classic temporal difference method -- estimating the</p>]]></description><link>https://kvfrans.com/successor-representations-explained/</link><guid isPermaLink="false">64fb6769feb1cd031e03ce1f</guid><dc:creator><![CDATA[Kevin Frans]]></dc:creator><pubDate>Fri, 08 Sep 2023 18:28:15 GMT</pubDate><content:encoded><![CDATA[<!--kg-card-begin: html--><p>Note: We're going to conflate terminology for value-functions and Q-functions. In general the Q-function <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>Q</mi><mo stretchy="false">(</mo><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">Q(s,a)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">Q</span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mclose">)</span></span></span></span></span> is just an action-conditioned version of the value function <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>V</mi><mo stretchy="false">(</mo><mi>s</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">V(s)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.22222em;">V</span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mclose">)</span></span></span></span></span>.</p>
<h4 id="successor-representations">Successor Representations</h4>
<p>Let's consider a classic temporal difference method -- estimating the value function. The value of a state is the expected discounted sum of future rewards, starting from that state. Notably, temporal difference methods give us a nice recursive form of the value function.</p>
<p><span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>V</mi><mo stretchy="false">(</mo><mi>s</mi><mo stretchy="false">)</mo><mo>=</mo><mi>r</mi><mo>+</mo><mi>γ</mi><msub><mi>E</mi><mrow><mi>p</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo mathvariant="normal" lspace="0em" rspace="0em">′</mo></msup><mi mathvariant="normal">∣</mi><mi>s</mi><mo stretchy="false">)</mo></mrow></msub><mo stretchy="false">[</mo><mi>V</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo mathvariant="normal" lspace="0em" rspace="0em">′</mo></msup><mo stretchy="false">)</mo><mo stretchy="false">]</mo></mrow><annotation encoding="application/x-tex">V(s) = r + \gamma E_{p(s'|s)}[V(s')]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.22222em;">V</span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.66666em; vertical-align: -0.08333em;"></span><span class="mord mathnormal" style="margin-right: 0.02778em;">r</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1.10709em; vertical-align: -0.3552em;"></span><span class="mord mathnormal" style="margin-right: 0.05556em;">γ</span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.05764em;">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.3448em;"><span style="top: 
-2.5198em; margin-left: -0.05764em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">p</span><span class="mopen mtight">(</span><span class="mord mtight"><span class="mord mathnormal mtight">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.682829em;"><span style="top: -2.786em; margin-right: 0.0714286em;"><span class="pstrut" style="height: 2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span></span></span></span></span><span class="mord mtight">∣</span><span class="mord mathnormal mtight">s</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.3552em;"><span></span></span></span></span></span></span><span class="mopen">[</span><span class="mord mathnormal" style="margin-right: 0.22222em;">V</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.751892em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span></span></span></span></span><span class="mclose">)</span><span class="mclose">]</span></span></span></span></span></p>
<p>Learning a value function is helpful because it lets us reason about future behaviors, from a single function call. What if we could do this not over reward, but over states themselves?</p>
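<p>To make this concrete, here is a minimal tabular sketch of the recursion above. The three-state transition matrix and rewards are toy values, purely for illustration:</p><pre><code class="language-python">import numpy as np

# Toy three-state MDP: P[s, s'] is the policy's transition matrix,
# r[s] is the reward received in state s.
P = np.array([[0.5, 0.5, 0.0],
              [0.0, 0.5, 0.5],
              [0.5, 0.0, 0.5]])
r = np.array([0.0, 0.0, 1.0])
gamma = 0.9

# The recursion V(s) = r + gamma E_{p(s'|s)}[V(s')] is linear in V,
# so the tabular case solves exactly: V = (I - gamma P)^(-1) r.
V_exact = np.linalg.solve(np.eye(3) - gamma * P, r)

# Temporal difference methods instead iterate the recursion to a fixed point.
V = np.zeros(3)
for _ in range(1000):
    V = r + gamma * P @ V
</code></pre><p>Both routes converge to the same fixed point; the iterative route is what matters in practice, where P is unknown and we only see sampled transitions.</p>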
<p>Successor representations are an application of temporal difference learning, where we predict expected <em>state occupancy</em>. If I follow my policy, discounted into the future, what states will I be in?</p>
<p>This is easiest to understand in an MDP with discrete states. Let's say I start out in state A. After one timestep, I transition into a probability distribution over A, B, and C. Each path then leads to a distribution at t=2, then t=3, and so on. Now, we'll combine all of these distributions into a weighted average.</p>
<p><span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mi mathvariant="normal">∣</mi><mi>s</mi><mo stretchy="false">)</mo><mo>=</mo><mo stretchy="false">(</mo><mn>1</mn><mo>−</mo><mi>γ</mi><mo stretchy="false">)</mo><mi>p</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo mathvariant="normal" lspace="0em" rspace="0em">′</mo></msup><mi mathvariant="normal">∣</mi><mi>s</mi><mo stretchy="false">)</mo><mo>+</mo><mi>γ</mi><msub><mi>E</mi><mrow><mi>p</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo mathvariant="normal" lspace="0em" rspace="0em">′</mo></msup><mi mathvariant="normal">∣</mi><mi>s</mi><mo stretchy="false">)</mo></mrow></msub><mi>μ</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mi mathvariant="normal">∣</mi><msup><mi>s</mi><mo mathvariant="normal" lspace="0em" rspace="0em">′</mo></msup><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\mu(s^*|s) = (1-\gamma)p(s'|s) + \gamma E_{p(s'|s)} \mu(s^*|s')</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">μ</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord mathnormal">s</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 
0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mopen">(</span><span class="mord">1</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1.00189em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.05556em;">γ</span><span class="mclose">)</span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.751892em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord mathnormal">s</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1.10709em; vertical-align: -0.3552em;"></span><span class="mord mathnormal" style="margin-right: 0.05556em;">γ</span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.05764em;">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.3448em;"><span style="top: -2.5198em; margin-left: -0.05764em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">p</span><span class="mopen mtight">(</span><span class="mord mtight"><span class="mord mathnormal 
mtight">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.682829em;"><span style="top: -2.786em; margin-right: 0.0714286em;"><span class="pstrut" style="height: 2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span></span></span></span></span><span class="mord mtight">∣</span><span class="mord mathnormal mtight">s</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.3552em;"><span></span></span></span></span></span></span><span class="mord mathnormal">μ</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.751892em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span></p>
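<p>In the tabular case this recursion is again linear, so we can solve for the successor representation in closed form. A small sketch, using a toy transition matrix rather than any specific task:</p><pre><code class="language-python">import numpy as np

# P[s, s'] is the policy's transition matrix over discrete states.
P = np.array([[0.5, 0.5, 0.0],
              [0.0, 0.5, 0.5],
              [0.5, 0.0, 0.5]])
gamma = 0.9
I = np.eye(3)

# Writing the recursion in matrix form, mu = (1 - gamma) P + gamma P mu,
# gives the successor matrix directly: mu = (1 - gamma) (I - gamma P)^(-1) P.
mu = (1 - gamma) * np.linalg.solve(I - gamma * P, P)

# Thanks to the (1 - gamma) weighting, each row of mu is a proper
# probability distribution: the discounted occupancy over future states.
row_sums = mu.sum(axis=1)
</code></pre><p>Each row sums to one, which is exactly the "weighted average of distributions" picture described above.</p>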
<p>Looks similar to the value function, doesn't it? Successor representations follow the same structure as the value function, but we're accumulating state probabilities rather than expected reward.</p><!--kg-card-end: html--><figure class="kg-card kg-image-card"><img src="https://kvfrans.com/content/images/2023/09/Screen-Shot-2023-09-08-at-11.21.24-AM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2023/09/Screen-Shot-2023-09-08-at-11.21.24-AM.png 600w, https://kvfrans.com/content/images/size/w1000/2023/09/Screen-Shot-2023-09-08-at-11.21.24-AM.png 1000w, https://kvfrans.com/content/images/size/w1600/2023/09/Screen-Shot-2023-09-08-at-11.21.24-AM.png 1600w, https://kvfrans.com/content/images/2023/09/Screen-Shot-2023-09-08-at-11.21.24-AM.png 2280w" sizes="(min-width: 720px) 720px"></figure><!--kg-card-begin: html--><h4 id="gamma-models">Gamma Models</h4>
<p>In more recent methods, we use deep neural networks to learn the successor representation. Gamma models attempt to learn an action-conditioned successor function <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mi mathvariant="normal">∣</mi><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\mu(s^*|s,a)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">μ</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mclose">)</span></span></span></span></span>. Notably, <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi></mrow><annotation encoding="application/x-tex">\mu</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.625em; vertical-align: -0.19444em;"></span><span class="mord mathnormal">μ</span></span></span></span></span> is a generative model and can handle both continuous and discrete settings. 
That means we can learn <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mi mathvariant="normal">∣</mi><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\mu(s^*|s,a)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">μ</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mclose">)</span></span></span></span></span> using our favorite generative-modelling methods. The paper uses GANs and normalizing flows.</p>
<p>This interpretation gives us some nice properties:</p>
<ul>
<li>If our reward function is <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>δ</mi><mo stretchy="false">(</mo><mi>s</mi><mo>=</mo><msup><mi>s</mi><mo>∗</mo></msup><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\delta(s = s^*)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.03785em;">δ</span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span> (i.e. 
sparse reward at some goal state), then the value of any state is equivalent to evaluating the gamma-model at <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mi mathvariant="normal">∣</mi><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\mu(s^*|s,a)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">μ</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mclose">)</span></span></span></span></span>. The value function of a sparse reward is equal to the probability that the policy occupies that state.</li>
<li>We can extract the value function for arbitrary reward functions as well. It involves a sum over future states, weighted by their discounted occupancy under the policy:
<ul>
<li><span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>Q</mi><mo stretchy="false">(</mo><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo stretchy="false">)</mo><mo>=</mo><mo>∑</mo><mi>μ</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mi mathvariant="normal">∣</mi><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo stretchy="false">)</mo><mtext>  </mtext><mi>r</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">Q(s,a) = \sum \mu(s^*|s,a) \; r(s^*)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">Q</span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.00001em; vertical-align: -0.25001em;"></span><span class="mop op-symbol small-op" style="position: relative; top: -5e-06em;">∑</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">μ</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mord">∣</span><span 
class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 0.02778em;">r</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span></li>
</ul>
</li>
<li>If we set <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>γ</mi><mo>=</mo><mn>0</mn></mrow><annotation encoding="application/x-tex">\gamma=0</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.625em; vertical-align: -0.19444em;"></span><span class="mord mathnormal" style="margin-right: 0.05556em;">γ</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.64444em; vertical-align: 0em;"></span><span class="mord">0</span></span></span></span></span>, then the gamma model is actually a one-step world model. In this case, <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mi mathvariant="normal">∣</mi><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo stretchy="false">)</mo><mo>=</mo><mi>p</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo mathvariant="normal" lspace="0em" rspace="0em">′</mo></msup><mi mathvariant="normal">∣</mi><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\mu(s^*|s,a) = p(s'|s,a)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">μ</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 
2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.00189em; vertical-align: -0.25em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.751892em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mclose">)</span></span></span></span></span> and we're simply modelling the environment's transition dynamics.</li>
</ul>
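<p>These properties are easy to check in the tabular setting. The sketch below uses a state-level successor matrix with toy numbers; note that with the (1 - gamma)-normalized recursion, recovering the value exactly requires adding back the immediate reward and a 1/(1 - gamma) factor:</p><pre><code class="language-python">import numpy as np

P = np.array([[0.5, 0.5, 0.0],
              [0.0, 0.5, 0.5],
              [0.5, 0.0, 0.5]])
gamma = 0.9
I = np.eye(3)

# Tabular successor matrix (state-level, for brevity).
mu = (1 - gamma) * np.linalg.solve(I - gamma * P, P)

# Value extraction for an arbitrary reward is a single matrix-vector
# product against mu (plus normalization bookkeeping for this tabular mu).
r = np.array([0.0, 0.0, 1.0])
V_from_mu = r + (gamma / (1 - gamma)) * mu @ r
V_direct = np.linalg.solve(I - gamma * P, r)  # ordinary Bellman solve

# With gamma = 0 the recursion collapses to the one-step transition model.
mu_zero = np.linalg.solve(I, P)  # equals P exactly
</code></pre>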
<h4 id="forward-backward-models">Forward-Backward Models</h4>
<p>This next work describes FB models, an extension of successor representations that lets us extract optimal policies for any reward function.</p>
<p>Above, we showed we can extract any Q(s,a), given a reward function r(s) and <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mi mathvariant="normal">∣</mi><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\mu(s^*|s,a)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">μ</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mclose">)</span></span></span></span></span>. 
In continuous space, we need to sum over trajectories to compute  <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>Q</mi><mi>r</mi></msub><mo stretchy="false">(</mo><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo stretchy="false">)</mo><mo>=</mo><mo>∑</mo><mi>μ</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mi mathvariant="normal">∣</mi><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo stretchy="false">)</mo><mtext>  </mtext><mi>r</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">Q_r(s,a) = \sum \mu(s^*|s,a) \; r(s^*)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord"><span class="mord mathnormal">Q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.151392em;"><span style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.02778em;">r</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.00001em; vertical-align: -0.25001em;"></span><span 
class="mop op-symbol small-op" style="position: relative; top: -5e-06em;">∑</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">μ</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 0.02778em;">r</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span>. 
In discrete space, <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mi mathvariant="normal">∣</mi><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\mu(s^*|s,a)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">μ</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mclose">)</span></span></span></span></span> can directly output a vector over <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>s</mi></mrow><annotation encoding="application/x-tex">s</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal">s</span></span></span></span></span>, and we can compute Q with a matrix multiplication <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>Q</mi><mo 
stretchy="false">(</mo><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo stretchy="false">)</mo><mo>=</mo><mi>μ</mi><mo stretchy="false">(</mo><mi>s</mi><mo separator="true">,</mo><mi>a</mi><msup><mo stretchy="false">)</mo><mi>T</mi></msup><mi>r</mi></mrow><annotation encoding="application/x-tex">Q(s,a) = \mu(s,a)^Tr</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">Q</span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.09133em; vertical-align: -0.25em;"></span><span class="mord mathnormal">μ</span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mclose"><span class="mclose">)</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.841331em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.13889em;">T</span></span></span></span></span></span></span></span><span class="mord mathnormal" style="margin-right: 0.02778em;">r</span></span></span></span></span>. Moving forward, we'll assume we are in the continuous setting.</p>
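<p>In the continuous setting, the sum becomes a Monte Carlo estimate: sample future states from the generative model and average their rewards. A hypothetical sketch, where <code>sample_mu</code> stands in for a learned gamma-model sampler (e.g. a GAN or flow) and <code>reward_fn</code> for the task reward; neither is from the original paper:</p><pre><code class="language-python">import numpy as np

rng = np.random.default_rng(0)

def sample_mu(s, a, n):
    # Stand-in for a learned generative gamma-model: pretend discounted
    # future states are Gaussian around s + a. Purely illustrative.
    return s + a + rng.normal(size=(n, s.shape[0]))

def reward_fn(states):
    # Stand-in task reward: negative distance to the origin.
    return -np.linalg.norm(states, axis=1)

def q_estimate(s, a, n=4096):
    # Q_r(s, a) as an average of rewards over sampled future states,
    # up to the discount normalization discussed above.
    future_states = sample_mu(s, a, n)
    return reward_fn(future_states).mean()

q = q_estimate(np.zeros(2), np.ones(2))
</code></pre><p>The appeal is that one learned sampler supports value estimates for any reward function, without rolling out the environment.</p>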
<p>A key insight about value-functions and successor-functions is that they are policy-dependent. Depending on the policy, the distribution of future (reward/states) will change. Assuming we had some representation of a policy <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>z</mi></mrow><annotation encoding="application/x-tex">z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span></span></span></span></span>, we could train a policy-conditioned successor function <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mi mathvariant="normal">∣</mi><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo separator="true">,</mo><mi>z</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\mu(s^*|s,a,z)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">μ</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord 
mathnormal">a</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span></span></span></span></span>. We can further decompose <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi></mrow><annotation encoding="application/x-tex">\mu</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.625em; vertical-align: -0.19444em;"></span><span class="mord mathnormal">μ</span></span></span></span></span> into the policy-dependent and -independent parts, F and B. <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mi mathvariant="normal">∣</mi><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo separator="true">,</mo><mi>z</mi><mo stretchy="false">)</mo><mo>=</mo><mi>F</mi><mo stretchy="false">(</mo><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo separator="true">,</mo><mi>z</mi><msup><mo stretchy="false">)</mo><mi>T</mi></msup><mi>B</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\mu(s^*|s,a,z) = F(s,a,z)^TB(s^*)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">μ</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 
2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.09133em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.13889em;">F</span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose"><span class="mclose">)</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.841331em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.13889em;">T</span></span></span></span></span></span></span></span><span class="mord mathnormal" style="margin-right: 0.05017em;">B</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" 
style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span>.</p>
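<p>To make the decomposition concrete, here is a minimal sketch in plain numpy (toy stand-ins only: random linear maps in place of trained F and B networks, and made-up dimensions). The successor measure becomes a dot product between a policy-dependent embedding and a state embedding:</p><pre><code class="language-python">import numpy as np

rng = np.random.default_rng(0)
d = 8                            # embedding dimension shared by F and B
s_dim, a_dim, z_dim = 4, 2, 8    # toy sizes

# Random linear maps standing in for the trained F and B networks.
W_f = rng.normal(size=(s_dim + a_dim + z_dim, d))
W_b = rng.normal(size=(s_dim, d))

def F(s, a, z):
    # F(s,a,z): policy-dependent embedding of the current state-action pair.
    return np.concatenate([s, a, z]) @ W_f

def B(s_star):
    # B(s*): policy-independent embedding of a candidate future state.
    return s_star @ W_b

s = rng.normal(size=s_dim)
a = rng.normal(size=a_dim)
z = rng.normal(size=z_dim)
s_star = rng.normal(size=s_dim)

# mu(s*|s,a,z) = F(s,a,z)^T B(s*): a single scalar for each candidate s*.
mu = float(F(s, a, z) @ B(s_star))
</code></pre><p>Because B never sees z, its state embeddings can be shared across all policies, which is what makes conditioning everything on a single vector z workable.</p>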
<p>Let's take a look at that decomposition in detail. The original <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi></mrow><annotation encoding="application/x-tex">\mu</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.625em; vertical-align: -0.19444em;"></span><span class="mord mathnormal">μ</span></span></span></span></span> is a neural-network with inputs <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mo separator="true">,</mo><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo separator="true">,</mo><mi>z</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">(s^*,s,a,z)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span 
class="mclose">)</span></span></span></span></span>. It returns a scalar: the probability of visiting state <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msup><mi>s</mi><mo>∗</mo></msup></mrow><annotation encoding="application/x-tex">s^*</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.688696em; vertical-align: 0em;"></span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span></span></span></span></span>. In the decomposition, F is a network that maps inputs (s,a,z) to a d-dimensional vector, and B is a network that maps an input (<span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msup><mi>s</mi><mo>∗</mo></msup></mrow><annotation encoding="application/x-tex">s^*</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.688696em; vertical-align: 0em;"></span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span></span></span></span></span>) to a d-dimensional vector. A final dot product reduces these two vectors to a scalar. 
One way to think about these representations is: F(s,a,z) calculates a set of <em>expected future features when following policy <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>z</mi></mrow><annotation encoding="application/x-tex">z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span></span></span></span></span></em>. B(s*) then calculates how those features map back to real states.</p>
<p>Now, how do we find <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>π</mi><mi>z</mi></msub></mrow><annotation encoding="application/x-tex">\pi_z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.58056em; vertical-align: -0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">π</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.151392em;"><span style="top: -2.55em; margin-left: -0.03588em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.04398em;">z</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span></span></span></span></span> that is optimal for every reward function? Well, we need a way to learn a function z(r). The trick here is we can choose this function, so let's be smart about it. 
Let's define z(r) = <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo>∑</mo><mi>B</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mo stretchy="false">)</mo><mtext>  </mtext><mi>r</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\sum B(s^*) \; r(s^*)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1.00001em; vertical-align: -0.25001em;"></span><span class="mop op-symbol small-op" style="position: relative; top: -5e-06em;">∑</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.05017em;">B</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 0.02778em;">r</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span>. Some derivations:<br>
- Remember that the Q-function for any <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>r</mi></mrow><annotation encoding="application/x-tex">r</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.02778em;">r</span></span></span></span></span> was <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>Q</mi><mi>r</mi></msub><mo stretchy="false">(</mo><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo stretchy="false">)</mo><mo>=</mo><mo>∑</mo><mi>μ</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mi mathvariant="normal">∣</mi><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo stretchy="false">)</mo><mtext>  </mtext><mi>r</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">Q_r(s,a) = \sum \mu(s^*|s,a) \; r(s^*)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord"><span class="mord mathnormal">Q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.151392em;"><span style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.02778em;">r</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span 
class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.00001em; vertical-align: -0.25001em;"></span><span class="mop op-symbol small-op" style="position: relative; top: -5e-06em;">∑</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">μ</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 0.02778em;">r</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span>.<br>
- Also, we decomposed <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mi mathvariant="normal">∣</mi><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo separator="true">,</mo><mi>z</mi><mo stretchy="false">)</mo><mo>=</mo><mi>F</mi><mo stretchy="false">(</mo><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo separator="true">,</mo><mi>z</mi><msup><mo stretchy="false">)</mo><mi>T</mi></msup><mi>B</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\mu(s^*|s,a,z) = F(s,a,z)^TB(s^*)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">μ</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.09133em; 
vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.13889em;">F</span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose"><span class="mclose">)</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.841331em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.13889em;">T</span></span></span></span></span></span></span></span><span class="mord mathnormal" style="margin-right: 0.05017em;">B</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span>.<br>
- The optimal Q is thus <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo>∑</mo><mi>F</mi><mo stretchy="false">(</mo><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo separator="true">,</mo><mi>z</mi><msup><mo stretchy="false">)</mo><mi>T</mi></msup><mi>B</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mo stretchy="false">)</mo><mtext>  </mtext><mi>r</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\sum F(s,a,z)^TB(s^*) \; r(s^*)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1.09134em; vertical-align: -0.25001em;"></span><span class="mop op-symbol small-op" style="position: relative; top: -5e-06em;">∑</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.13889em;">F</span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose"><span class="mclose">)</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.841331em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.13889em;">T</span></span></span></span></span></span></span></span><span class="mord mathnormal" style="margin-right: 0.05017em;">B</span><span class="mopen">(</span><span class="mord"><span class="mord 
mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 0.02778em;">r</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span> = <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>F</mi><mo stretchy="false">(</mo><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo separator="true">,</mo><mi>z</mi><msup><mo stretchy="false">)</mo><mi>T</mi></msup><mo>∑</mo><mi>B</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mo stretchy="false">)</mo><mtext>  </mtext><mi>r</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">F(s,a,z)^T \sum B(s^*) \; r(s^*)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1.09134em; vertical-align: -0.25001em;"></span><span class="mord mathnormal" style="margin-right: 0.13889em;">F</span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span 
class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose"><span class="mclose">)</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.841331em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.13889em;">T</span></span></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mop op-symbol small-op" style="position: relative; top: -5e-06em;">∑</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.05017em;">B</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 0.02778em;">r</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin 
mtight">∗</span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span> .<br>
- Based on <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>z</mi><mo stretchy="false">(</mo><mi>r</mi><mo stretchy="false">)</mo><mo>=</mo><mo>∑</mo><mi>B</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mo stretchy="false">)</mo><mtext>  </mtext><mi>r</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">z(r) = \sum B(s^*) \; r(s^*)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.02778em;">r</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.00001em; vertical-align: -0.25001em;"></span><span class="mop op-symbol small-op" style="position: relative; top: -5e-06em;">∑</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.05017em;">B</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 0.02778em;">r</span><span 
class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span>, we get  <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>Q</mi><mo stretchy="false">(</mo><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo separator="true">,</mo><mi>z</mi><mo stretchy="false">)</mo><mo>=</mo><mi>F</mi><mo stretchy="false">(</mo><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo separator="true">,</mo><mi>z</mi><msup><mo stretchy="false">)</mo><mi>T</mi></msup><mi>z</mi></mrow><annotation encoding="application/x-tex">Q(s,a,z) = F(s,a,z)^Tz</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">Q</span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.09133em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.13889em;">F</span><span 
class="mopen">(</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose"><span class="mclose">)</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.841331em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.13889em;">T</span></span></span></span></span></span></span></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span></span></span></span></span></p>
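<p>We can sanity-check this identity numerically with a toy sketch (random vectors in place of trained F and B, and a small finite state set standing in for the sum over s*): folding the reward into z gives the same Q-value as the expanded sum.</p><pre><code class="language-python">import numpy as np

rng = np.random.default_rng(0)
d, n_states = 8, 100

B_all = rng.normal(size=(n_states, d))   # B(s*) for every candidate state s*
r_all = rng.normal(size=n_states)        # r(s*) for every candidate state

# z(r) = sum_{s*} B(s*) r(s*): the whole reward function collapses to one d-vector.
z = B_all.T @ r_all

F_sa = rng.normal(size=d)                # stand-in for F(s,a,z)

# Q(s,a,z) = F(s,a,z)^T z ...
q_from_z = F_sa @ z
# ... equals the expanded sum: sum_{s*} F(s,a,z)^T B(s*) r(s*).
q_expanded = np.sum((B_all @ F_sa) * r_all)

assert np.allclose(q_from_z, q_expanded)
</code></pre><p>The practical upshot: z is computed once per reward function, after which Q-values for any (s,a) come from a single forward pass of F.</p>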
<p>Notably, the reward term has disappeared! Let's interpret. The <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>z</mi></mrow><annotation encoding="application/x-tex">z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span></span></span></span></span> vector here is playing two roles. First, it serves as a policy representation that is given to F -- it contains information about the behavior of <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>π</mi><mi>z</mi></msub></mrow><annotation encoding="application/x-tex">\pi_z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.58056em; vertical-align: -0.15em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">π</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.151392em;"><span style="top: -2.55em; margin-left: -0.03588em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.04398em;">z</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span></span></span></span></span>. Second, it serves as a feature-reward mapping -- it describes how much reward to assign to each successor feature.</p>
<p>The F and B networks are trained without any reward-function knowledge. We'll do the same thing as in previous works. We update <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi><mo>=</mo><mi>F</mi><mi>B</mi></mrow><annotation encoding="application/x-tex">\mu = FB</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.625em; vertical-align: -0.19444em;"></span><span class="mord mathnormal">μ</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.13889em;">F</span><span class="mord mathnormal" style="margin-right: 0.05017em;">B</span></span></span></span></span> through a Bellman update. 
Remember that <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>π</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo mathvariant="normal" lspace="0em" rspace="0em">′</mo></msup><mo separator="true">,</mo><mi>z</mi><mo stretchy="false">)</mo><mo>=</mo><mtext>argmax</mtext><mtext>  </mtext><mi>F</mi><mo stretchy="false">(</mo><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo separator="true">,</mo><mi>z</mi><msup><mo stretchy="false">)</mo><mi>T</mi></msup><mi>z</mi></mrow><annotation encoding="application/x-tex">\pi(s',z) = \text{argmax} \; F(s,a,z)^Tz</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1.00189em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.03588em;">π</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.751892em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.09133em; vertical-align: -0.25em;"></span><span class="mord text"><span class="mord">argmax</span></span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 
0.13889em;">F</span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose"><span class="mclose">)</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.841331em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.13889em;">T</span></span></span></span></span></span></span></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span></span></span></span></span>, and <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi><mo>=</mo><mi>F</mi><mo stretchy="false">(</mo><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo separator="true">,</mo><mi>z</mi><msup><mo stretchy="false">)</mo><mi>T</mi></msup><mi>B</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\mu = F(s,a,z)^TB(s^*)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.625em; vertical-align: -0.19444em;"></span><span class="mord mathnormal">μ</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.09133em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.13889em;">F</span><span 
class="mopen">(</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose"><span class="mclose">)</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.841331em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.13889em;">T</span></span></span></span></span></span></span></span><span class="mord mathnormal" style="margin-right: 0.05017em;">B</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span>.</p>
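<p>As a concrete illustration (a hedged sketch, not code from the paper), with a discrete action space this greedy policy is just an argmax over the per-action scores <code>F(s, a, z)^T z</code>. The array shapes and names below are assumptions for illustration:</p><pre><code class="language-python">import numpy as np

def greedy_action(F_values, z):
    """Pick argmax_a F(s, a, z)^T z for a single state s.

    F_values: (num_actions, d) array; row a holds the embedding F(s, a, z).
    z: (d,) task vector.
    """
    q_values = F_values @ z  # Q(s, a, z) = F(s, a, z)^T z, one entry per action
    return int(np.argmax(q_values))
</code></pre>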
<p><span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mi mathvariant="normal">∣</mi><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo separator="true">,</mo><mi>z</mi><mo stretchy="false">)</mo><mo>=</mo><mi>δ</mi><mo stretchy="false">(</mo><mi>s</mi><mo>=</mo><msup><mi>s</mi><mo>∗</mo></msup><mo stretchy="false">)</mo><mo>+</mo><mi>γ</mi><msub><mi>E</mi><mrow><mi>p</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo mathvariant="normal" lspace="0em" rspace="0em">′</mo></msup><mi mathvariant="normal">∣</mi><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo stretchy="false">)</mo></mrow></msub><mi>μ</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mi mathvariant="normal">∣</mi><msup><mi>s</mi><mo mathvariant="normal" lspace="0em" rspace="0em">′</mo></msup><mo separator="true">,</mo><mi>π</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo mathvariant="normal" lspace="0em" rspace="0em">′</mo></msup><mo separator="true">,</mo><mi>z</mi><mo stretchy="false">)</mo><mo separator="true">,</mo><mi>z</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\mu(s^*|s,a,z) = \delta(s=s^*) + \gamma E_{p(s'|s,a)} \mu(s^*|s',\pi(s',z),z)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">μ</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span 
class="mord">∣</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.03785em;">δ</span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1.10709em; vertical-align: -0.3552em;"></span><span class="mord mathnormal" style="margin-right: 0.05556em;">γ</span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.05764em;">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.3448em;"><span 
style="top: -2.5198em; margin-left: -0.05764em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">p</span><span class="mopen mtight">(</span><span class="mord mtight"><span class="mord mathnormal mtight">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.682829em;"><span style="top: -2.786em; margin-right: 0.0714286em;"><span class="pstrut" style="height: 2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span></span></span></span></span><span class="mord mtight">∣</span><span class="mord mathnormal mtight">s</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">a</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.3552em;"><span></span></span></span></span></span></span><span class="mord mathnormal">μ</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.751892em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord 
mtight">′</span></span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.03588em;">π</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.751892em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span></span></span></span></span><br>
<span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>F</mi><mo stretchy="false">(</mo><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo separator="true">,</mo><mi>z</mi><msup><mo stretchy="false">)</mo><mi>T</mi></msup><mi>B</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mo stretchy="false">)</mo><mo>=</mo><mi>δ</mi><mo stretchy="false">(</mo><mi>s</mi><mo>=</mo><msup><mi>s</mi><mo>∗</mo></msup><mo stretchy="false">)</mo><mo>+</mo><mi>γ</mi><msub><mi>E</mi><mrow><mi>p</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo mathvariant="normal" lspace="0em" rspace="0em">′</mo></msup><mi mathvariant="normal">∣</mi><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo stretchy="false">)</mo></mrow></msub><mi>F</mi><mo stretchy="false">(</mo><mi>s</mi><mo separator="true">,</mo><mtext>argmax</mtext><mtext>  </mtext><mi>F</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo mathvariant="normal" lspace="0em" rspace="0em">′</mo></msup><mo separator="true">,</mo><mi>a</mi><mo separator="true">,</mo><mi>z</mi><msup><mo stretchy="false">)</mo><mi>T</mi></msup><mi>z</mi><mo separator="true">,</mo><mi>z</mi><msup><mo stretchy="false">)</mo><mi>T</mi></msup><mi>B</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">F(s,a,z)^TB(s^*) = \delta(s=s^*) + \gamma E_{p(s'|s,a)} F(s,\text{argmax} \; F(s',a,z)^Tz,z)^TB(s^*)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1.09133em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.13889em;">F</span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mpunct">,</span><span 
class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose"><span class="mclose">)</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.841331em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.13889em;">T</span></span></span></span></span></span></span></span><span class="mord mathnormal" style="margin-right: 0.05017em;">B</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.03785em;">δ</span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 
2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1.19653em; vertical-align: -0.3552em;"></span><span class="mord mathnormal" style="margin-right: 0.05556em;">γ</span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.05764em;">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.3448em;"><span style="top: -2.5198em; margin-left: -0.05764em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">p</span><span class="mopen mtight">(</span><span class="mord mtight"><span class="mord mathnormal mtight">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.682829em;"><span style="top: -2.786em; margin-right: 0.0714286em;"><span class="pstrut" style="height: 2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span></span></span></span></span><span class="mord mtight">∣</span><span class="mord mathnormal mtight">s</span><span class="mpunct mtight">,</span><span class="mord mathnormal mtight">a</span><span class="mclose mtight">)</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.3552em;"><span></span></span></span></span></span></span><span class="mord mathnormal" style="margin-right: 0.13889em;">F</span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span 
class="mspace" style="margin-right: 0.166667em;"></span><span class="mord text"><span class="mord">argmax</span></span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 0.13889em;">F</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.751892em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">′</span></span></span></span></span></span></span></span></span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose"><span class="mclose">)</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.841331em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.13889em;">T</span></span></span></span></span></span></span></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose"><span class="mclose">)</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.841331em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span 
class="mord mathnormal mtight" style="margin-right: 0.13889em;">T</span></span></span></span></span></span></span></span><span class="mord mathnormal" style="margin-right: 0.05017em;">B</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span></p>
<p>You'll notice that in addition to trajectories, this update is conditioned on <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>z</mi></mrow><annotation encoding="application/x-tex">z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span></span></span></span></span>. Since we don't know <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>z</mi></mrow><annotation encoding="application/x-tex">z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span></span></span></span></span> ahead of time, we just need to guess -- we'll generate random <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>z</mi></mrow><annotation encoding="application/x-tex">z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span></span></span></span></span> vectors from a Gaussian distribution.</p>
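<p>To make the z-sampling concrete, here is a minimal sketch of forming one such Bellman regression target for <code>F^T B</code>. The shapes and stand-in network outputs are made up for illustration; this is not the paper's actual training code:</p><pre><code class="language-python">import numpy as np

rng = np.random.default_rng(0)
d = 8  # representation dimension (illustrative)

# We don't know z ahead of time, so guess: sample it from a Gaussian.
z = rng.normal(size=d)

# Stand-ins for network outputs on a transition (s, a, s') and a probe state s*.
F_sa   = rng.normal(size=d)  # F(s, a, z)
F_next = rng.normal(size=d)  # F(s', pi(s', z), z), ideally from a target network
B_star = rng.normal(size=d)  # B(s*)
gamma  = 0.99

# TD target for mu(s* | s, a, z) = F^T B; the delta(s = s*) term is
# handled separately in practice (e.g. via samples where s' = s*).
target   = gamma * (F_next @ B_star)
td_error = (F_sa @ B_star) - target
loss     = td_error ** 2  # regress F^T B toward its Bellman target
</code></pre>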
<p><em>Test Time</em>. Once we've learned F and B, extracting the optimal policy is simple. We'll calculate <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>z</mi><mo>=</mo><mo>∑</mo><mi>B</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mo stretchy="false">)</mo><mtext>  </mtext><mi>r</mi><mo stretchy="false">(</mo><msup><mi>s</mi><mo>∗</mo></msup><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">z = \sum B(s^*) \; r(s^*)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.00001em; vertical-align: -0.25001em;"></span><span class="mop op-symbol small-op" style="position: relative; top: -5e-06em;">∑</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.05017em;">B</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 0.02778em;">r</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">s</span><span class="msupsub"><span 
class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.688696em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mbin mtight">∗</span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span>. That gives us our Q-function <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>Q</mi><mo stretchy="false">(</mo><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo separator="true">,</mo><mi>z</mi><mo stretchy="false">)</mo><mo>=</mo><mi>F</mi><mo stretchy="false">(</mo><mi>s</mi><mo separator="true">,</mo><mi>a</mi><mo separator="true">,</mo><mi>z</mi><msup><mo stretchy="false">)</mo><mi>T</mi></msup><mi>z</mi></mrow><annotation encoding="application/x-tex">Q(s,a,z) = F(s,a,z)^Tz</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">Q</span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.09133em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.13889em;">F</span><span class="mopen">(</span><span class="mord mathnormal">s</span><span class="mpunct">,</span><span class="mspace" 
style="margin-right: 0.166667em;"></span><span class="mord mathnormal">a</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose"><span class="mclose">)</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.841331em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.13889em;">T</span></span></span></span></span></span></span></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span></span></span></span></span>, which we can act greedily over.</p><!--kg-card-end: html-->]]></content:encoded></item><item><title><![CDATA[Variational Information Bottleneck Explained]]></title><description><![CDATA[<!--kg-card-begin: markdown--><p>Let's take a look at neural networks from the perspective of information theory. We'll be following along with the paper <a href="https://arxiv.org/abs/1612.00410">Deep Variational Information Bottleneck</a> (Alemi et al. 2016).</p>
<!--kg-card-end: markdown--><!--kg-card-begin: html--><p>Given a dataset of inputs <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>X</mi></mrow><annotation encoding="application/x-tex">X</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">X</span></span></span></span></span> and outputs <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>Y</mi></mrow><annotation encoding="application/x-tex">Y</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.22222em;">Y</span></span></span></span></span>, let's define some intermediate representation <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>Z</mi></mrow><annotation encoding="application/x-tex">Z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span></span></span></span></span>. A</p>]]></description><link>https://kvfrans.com/variational-information-bottleneck-explained/</link><guid isPermaLink="false">64ebcf65feb1cd031e03cdf9</guid><dc:creator><![CDATA[Kevin Frans]]></dc:creator><pubDate>Sun, 27 Aug 2023 22:39:44 GMT</pubDate><content:encoded><![CDATA[<!--kg-card-begin: markdown--><p>Let's take a look at neural networks from the perspective of information theory. 
We'll be following along with the paper <a href="https://arxiv.org/abs/1612.00410">Deep Variational Information Bottleneck</a> (Alemi et al. 2016).</p>
<!--kg-card-end: markdown--><!--kg-card-begin: html--><p>Given a dataset of inputs <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>X</mi></mrow><annotation encoding="application/x-tex">X</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">X</span></span></span></span></span> and outputs <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>Y</mi></mrow><annotation encoding="application/x-tex">Y</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.22222em;">Y</span></span></span></span></span>, let's define some intermediate representation <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>Z</mi></mrow><annotation encoding="application/x-tex">Z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span></span></span></span></span>. 
A good analogy here is that <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>Z</mi></mrow><annotation encoding="application/x-tex">Z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span></span></span></span></span> represents the features of a neural network layer. The standard objective is to run gradient descent to predict <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>Y</mi></mrow><annotation encoding="application/x-tex">Y</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.22222em;">Y</span></span></span></span></span> from <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>X</mi></mrow><annotation encoding="application/x-tex">X</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">X</span></span></span></span></span>. But that setup doesn't naturally give us good <em>representations</em>. 
What if we could write an objective over <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>Z</mi></mrow><annotation encoding="application/x-tex">Z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span></span></span></span></span>?</p>
<p>One thing we can say is that <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>Z</mi></mrow><annotation encoding="application/x-tex">Z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span></span></span></span></span> should tell us about the output <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>Y</mi></mrow><annotation encoding="application/x-tex">Y</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.22222em;">Y</span></span></span></span></span>. We want to maximize the mutual information <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>I</mi><mo stretchy="false">(</mo><mi>Z</mi><mo separator="true">,</mo><mi>Y</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">I(Z,Y)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">I</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.22222em;">Y</span><span class="mclose">)</span></span></span></span></span>. 
This objective makes natural sense: we don't want to lose any important information.</p>
<p>Another thing we could say is that we should minimize the information passing through <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>Z</mi></mrow><annotation encoding="application/x-tex">Z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span></span></span></span></span>. The mutual info <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>I</mi><mo stretchy="false">(</mo><mi>Z</mi><mo separator="true">,</mo><mi>X</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">I(Z,X)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">I</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">X</span><span class="mclose">)</span></span></span></span></span> should be minimized. We call this the <em>information bottleneck</em>. 
Basically, <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>Z</mi></mrow><annotation encoding="application/x-tex">Z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span></span></span></span></span> should keep relevant information about <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>Y</mi></mrow><annotation encoding="application/x-tex">Y</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.22222em;">Y</span></span></span></span></span>, but discard the rest of <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>X</mi></mrow><annotation encoding="application/x-tex">X</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">X</span></span></span></span></span>.</p>
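<p>In the discrete case, mutual information has a concrete closed form, which makes these two objectives easy to play with. Here's a toy computation (the joint tables are invented for illustration): when Z is a deterministic copy of X, everything passes through the bottleneck, and when Z ignores X, nothing does.</p>

```python
import numpy as np

def mutual_information(joint):
    """I(A;B) = sum_ab p(a,b) log[ p(a,b) / (p(a) p(b)) ] for a discrete joint table."""
    pa = joint.sum(axis=1, keepdims=True)   # marginal p(a)
    pb = joint.sum(axis=0, keepdims=True)   # marginal p(b)
    mask = joint > 0                        # convention: 0 log 0 = 0
    return float(np.sum(joint[mask] * np.log((joint / (pa * pb))[mask])))

# Z is a deterministic copy of X: maximal information, I(X,Z) = H(X) = log 2.
deterministic = np.array([[0.5, 0.0],
                          [0.0, 0.5]])
# Z is independent of X: no information flows through the bottleneck.
independent = np.array([[0.25, 0.25],
                        [0.25, 0.25]])
print(mutual_information(deterministic))  # ~0.693 (log 2)
print(mutual_information(independent))    # 0.0
```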
<p>The tricky thing is that mutual information is hard to measure. Its definition involves distributions, like p(y|z) and p(z), that we can't compute exactly. Since we don't have those true distributions, we'll need to approximate them using variational inference.</p>
<p>Let's start by stating a few assumptions. First, we'll assume that X, Y, Z adhere to the Markov chain (Y - X - Z). This setup guarantees that P(Z|X) = P(Z|X,Y) and P(Y|X) = P(Y|X,Z). In other words, Y and Z are conditionally independent given X. We'll also assume that we're training a neural network to model p(z|x).</p>
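<p>The conditional-independence claim is easy to verify numerically on a toy joint of the assumed form p(x)p(y|x)p(z|x) (all the probability tables below are made up):</p>

```python
import numpy as np

# Toy Markov chain Y - X - Z: Y and Z are each generated from X alone.
p_x = np.array([0.6, 0.4])
p_y_given_x = np.array([[0.9, 0.1],
                        [0.2, 0.8]])
p_z_given_x = np.array([[0.7, 0.3],
                        [0.4, 0.6]])

# The joint factorizes as p(x, y, z) = p(x) p(y|x) p(z|x).
joint = p_x[:, None, None] * p_y_given_x[:, :, None] * p_z_given_x[:, None, :]

# P(Y|X,Z) should equal P(Y|X): once we know x, z tells us nothing new about y.
p_xz = joint.sum(axis=1, keepdims=True)
p_y_given_xz = joint / p_xz
print(np.allclose(p_y_given_xz, p_y_given_x[:, :, None]))  # True
```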
<h4 id="information-bottleneck-objective">Information Bottleneck Objective</h4><!--kg-card-end: html--><figure class="kg-card kg-image-card"><img src="https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-27-at-3.32.09-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2023/08/Screen-Shot-2023-08-27-at-3.32.09-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2023/08/Screen-Shot-2023-08-27-at-3.32.09-PM.png 1000w, https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-27-at-3.32.09-PM.png 1382w" sizes="(min-width: 720px) 720px"></figure><!--kg-card-begin: html--><p>Here's our main information bottleneck objective (<span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>β</mi></mrow><annotation encoding="application/x-tex">\beta</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.88888em; vertical-align: -0.19444em;"></span><span class="mord mathnormal" style="margin-right: 0.05278em;">β</span></span></span></span></span> determines the "strength" of the bottleneck):</p>
<blockquote>
<p><span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>I</mi><mo stretchy="false">(</mo><mi>Z</mi><mo separator="true">,</mo><mi>Y</mi><mo stretchy="false">)</mo><mo>−</mo><mi>β</mi><mi>I</mi><mo stretchy="false">(</mo><mi>Z</mi><mo separator="true">,</mo><mi>X</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">I(Z,Y) - \beta I(Z,X)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">I</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.22222em;">Y</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.05278em;">β</span><span class="mord mathnormal" style="margin-right: 0.07847em;">I</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">X</span><span class="mclose">)</span></span></span></span></span>.</p>
</blockquote>
<p>We'll break it down piece by piece. Looking at the first term:</p>
<blockquote>
<p><span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>I</mi><mo stretchy="false">(</mo><mi>Z</mi><mo separator="true">,</mo><mi>Y</mi><mo stretchy="false">)</mo><mo>=</mo><mi>H</mi><mo stretchy="false">(</mo><mi>Y</mi><mo stretchy="false">)</mo><mo>−</mo><mi>H</mi><mo stretchy="false">(</mo><mi>Y</mi><mi mathvariant="normal">∣</mi><mi>Z</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">I(Z,Y) = H(Y) - H(Y|Z)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">I</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.22222em;">Y</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.08125em;">H</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.22222em;">Y</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.08125em;">H</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.22222em;">Y</span><span class="mord">∣</span><span class="mord mathnormal" 
style="margin-right: 0.07153em;">Z</span><span class="mclose">)</span></span></span></span></span><br>
<span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>I</mi><mo stretchy="false">(</mo><mi>Z</mi><mo separator="true">,</mo><mi>Y</mi><mo stretchy="false">)</mo><mo>=</mo><mi>H</mi><mo stretchy="false">(</mo><mi>Y</mi><mo stretchy="false">)</mo><mo>+</mo><mo>∫</mo><mi>p</mi><mo stretchy="false">(</mo><mi>y</mi><mo separator="true">,</mo><mi>z</mi><mo stretchy="false">)</mo><mtext>  </mtext><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><mi>p</mi><mo stretchy="false">(</mo><mi>y</mi><mi mathvariant="normal">∣</mi><mi>z</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">I(Z,Y) = H(Y) + \int p(y,z) \; log \; p(y|z)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">I</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.22222em;">Y</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.08125em;">H</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.22222em;">Y</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1.11112em; vertical-align: -0.30612em;"></span><span 
class="mop op-symbol small-op" style="margin-right: 0.19445em; position: relative; top: -0.00056em;">∫</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.03588em;">y</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.03588em;">y</span><span class="mord">∣</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span></span></span></span></span> (Definition of conditional entropy.)</p>
</blockquote>
<p>The H(Y) term is easy because it's a constant: the output distribution is defined by the data, so nothing we learn can change it.</p>
<p>The second conditional term is trickier. Unfortunately, p(y|z) isn't tractable. We don't have that model. So we'll need to approximate it by introducing a second neural network, q(y|z). This is the variational approximation trick.</p>
<p>The key to this trick is recognizing the relationship between <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo>∫</mo><mi>p</mi><mo stretchy="false">(</mo><mi>y</mi><mo separator="true">,</mo><mi>z</mi><mo stretchy="false">)</mo><mtext>  </mtext><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><mi>p</mi><mo stretchy="false">(</mo><mi>y</mi><mi mathvariant="normal">∣</mi><mi>z</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\int p(y,z) \; log \; p(y|z)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1.11112em; vertical-align: -0.30612em;"></span><span class="mop op-symbol small-op" style="margin-right: 0.19445em; position: relative; top: -0.00056em;">∫</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.03588em;">y</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.03588em;">y</span><span class="mord">∣</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span></span></span></span></span> and KL divergence. 
KL divergence is always positive, so <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo>∫</mo><mi>p</mi><mo stretchy="false">(</mo><mi>y</mi><mo separator="true">,</mo><mi>z</mi><mo stretchy="false">)</mo><mtext>  </mtext><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><mi>p</mi><mo stretchy="false">(</mo><mi>y</mi><mi mathvariant="normal">∣</mi><mi>z</mi><mo stretchy="false">)</mo><mo>≥</mo><mo>∫</mo><mi>p</mi><mo stretchy="false">(</mo><mi>y</mi><mo separator="true">,</mo><mi>z</mi><mo stretchy="false">)</mo><mtext>  </mtext><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><mi>q</mi><mo stretchy="false">(</mo><mi>y</mi><mi mathvariant="normal">∣</mi><mi>z</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">\int p(y,z) \; log \; p(y|z) \geq \int p(y,z) \; log \; q(y|z)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1.11112em; vertical-align: -0.30612em;"></span><span class="mop op-symbol small-op" style="margin-right: 0.19445em; position: relative; top: -0.00056em;">∫</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.03588em;">y</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" 
style="margin-right: 0.03588em;">y</span><span class="mord">∣</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">≥</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.11112em; vertical-align: -0.30612em;"></span><span class="mop op-symbol small-op" style="margin-right: 0.19445em; position: relative; top: -0.00056em;">∫</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.03588em;">y</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 0.03588em;">q</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.03588em;">y</span><span class="mord">∣</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span></span></span></span></span>. That means we can plug in q(y|z) instead of p(y|z) and get a lower bound. 
Another way to interpret this trick: H(Y|Z) is the lowest entropy we can achieve by using <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>Z</mi></mrow><annotation encoding="application/x-tex">Z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span></span></span></span></span> in the <em>optimal</em> way. If we use <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>Z</mi></mrow><annotation encoding="application/x-tex">Z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span></span></span></span></span> in a sub-optimal way, we get an upper bound on the true conditional entropy. And since H(Y|Z) enters the objective with a minus sign, an upper bound on it yields a lower bound on I(Z,Y).</p>
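<p>Here's the bound in action with made-up numbers. For a fixed z, the cross term under any approximate q(y|z) can only be smaller than under the true p(y|z), and the slack is exactly KL(p || q):</p>

```python
import numpy as np

p = np.array([0.7, 0.2, 0.1])   # true p(y|z) for some fixed z (invented)
q = np.array([0.5, 0.3, 0.2])   # an arbitrary variational approximation q(y|z)

exact = np.sum(p * np.log(p))   # the intractable term, E_p[log p(y|z)]
bound = np.sum(p * np.log(q))   # the tractable surrogate, E_p[log q(y|z)]
gap = np.sum(p * np.log(p / q)) # KL(p || q), always non-negative

print(exact >= bound)                  # True
print(np.isclose(exact - bound, gap))  # True: the slack is exactly the KL
```

Maximizing the surrogate over q tightens the bound, which is why training the decoder q(y|z) makes the approximation better.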
<p>Because of the Markovian assumption above, we can break down p(y,z) into the marginalization ∫ p(x)p(y|x)p(z|x) dx. This version is easy to estimate: sample x and y from the dataset, then sample z from our model p(z|x).</p>
<p>Let's look at the second term now. We will use a similar decomposition.</p>
<blockquote>
<p><span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>I</mi><mo stretchy="false">(</mo><mi>Z</mi><mo separator="true">,</mo><mi>X</mi><mo stretchy="false">)</mo><mo>=</mo><mi>H</mi><mo stretchy="false">(</mo><mi>Z</mi><mo stretchy="false">)</mo><mo>−</mo><mi>H</mi><mo stretchy="false">(</mo><mi>Z</mi><mi mathvariant="normal">∣</mi><mi>X</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">I(Z,X) = H(Z) - H(Z|X)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">I</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">X</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.08125em;">H</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.08125em;">H</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span><span class="mord">∣</span><span class="mord mathnormal" 
style="margin-right: 0.07847em;">X</span><span class="mclose">)</span></span></span></span></span><br>
<span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>I</mi><mo stretchy="false">(</mo><mi>Z</mi><mo separator="true">,</mo><mi>X</mi><mo stretchy="false">)</mo><mo>=</mo><mo>−</mo><mo>∫</mo><mi>p</mi><mo stretchy="false">(</mo><mi>z</mi><mo stretchy="false">)</mo><mtext>  </mtext><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><mi>p</mi><mo stretchy="false">(</mo><mi>z</mi><mo stretchy="false">)</mo><mo>+</mo><mo>∫</mo><mi>p</mi><mo stretchy="false">(</mo><mi>x</mi><mo separator="true">,</mo><mi>z</mi><mo stretchy="false">)</mo><mtext>  </mtext><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><mi>p</mi><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">I(Z,X) = -\int p(z) \; log \; p(z) + \int p(x,z) \; log \; p(z|x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">I</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">X</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.11112em; vertical-align: -0.30612em;"></span><span class="mord">−</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mop op-symbol small-op" style="margin-right: 0.19445em; position: relative; top: -0.00056em;">∫</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord 
mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1.11112em; vertical-align: -0.30612em;"></span><span class="mop op-symbol small-op" style="margin-right: 0.19445em; position: relative; top: -0.00056em;">∫</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span></span></p>
</blockquote>
<p>This time, the second term is the simple one. We have the model p(z|x), so we can calculate the log-probability by running the model forward. The p(z) term is the tricky one: computing p(z) would require marginalizing p(z|x) over the whole data distribution, so we approximate it with a variational marginal r(z). The same KL-divergence trick makes this an upper bound on H(Z), and therefore an upper bound on I(Z,X), which is the right direction for a term we want to minimize.</p>
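<p>In the common Gaussian parameterization, both approximations become concrete: p(z|x) is a diagonal Gaussian produced by the encoder, r(z) is a standard normal (so the KL has a closed form), and the first term is a sampled cross-entropy through the decoder q(y|z). Here's a rough numpy sketch of one Monte Carlo estimate of the negated bound, with toy linear maps standing in for the networks (all names and shapes below are invented for illustration):</p>

```python
import numpy as np

rng = np.random.default_rng(0)

def vib_loss(x, y_onehot, enc_w, dec_w, beta=1e-3):
    """One Monte Carlo estimate of the negated bound: cross-entropy through
    q(y|z) plus beta times KL( p(z|x) || r(z) ), with r(z) = N(0, I)."""
    h = x @ enc_w                                   # toy "encoder": a linear map
    mu, log_sigma = np.split(h, 2, axis=1)          # diagonal Gaussian p(z|x)
    sigma = np.exp(log_sigma)
    z = mu + sigma * rng.standard_normal(mu.shape)  # reparameterized sample
    logits = z @ dec_w                              # toy "decoder" q(y|z)
    log_q = logits - np.log(np.sum(np.exp(logits), axis=1, keepdims=True))
    nll = -np.mean(np.sum(y_onehot * log_q, axis=1))  # -E[log q(y|z)]
    # Closed-form KL between N(mu, sigma^2) and N(0, 1), summed over dimensions:
    kl = np.mean(np.sum(0.5 * (mu**2 + sigma**2 - 1.0 - 2.0 * log_sigma), axis=1))
    return nll + beta * kl

x = rng.standard_normal((8, 5))             # a minibatch of 8 inputs
y = np.eye(3)[rng.integers(0, 3, size=8)]   # one-hot labels over 3 classes
loss = vib_loss(x, y, 0.1 * rng.standard_normal((5, 8)), 0.1 * rng.standard_normal((4, 3)))
print(loss)
```

In practice the weights would be real networks trained by gradient descent on this loss, with the minibatch supplying the samples of x and y.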
<p>Putting everything together, the information bottleneck term becomes:</p>
<blockquote>
<p><span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>I</mi><mo stretchy="false">(</mo><mi>Z</mi><mo separator="true">,</mo><mi>Y</mi><mo stretchy="false">)</mo><mo>−</mo><mi>β</mi><mi>I</mi><mo stretchy="false">(</mo><mi>Z</mi><mo separator="true">,</mo><mi>X</mi><mo stretchy="false">)</mo><mo>≥</mo><mo>∫</mo><mi>p</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mi>p</mi><mo stretchy="false">(</mo><mi>y</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo><mi>p</mi><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo><mtext>  </mtext><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><mi>q</mi><mo stretchy="false">(</mo><mi>y</mi><mi mathvariant="normal">∣</mi><mi>z</mi><mo stretchy="false">)</mo><mo>−</mo><mi>β</mi><mo>∫</mo><mi>p</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mi>p</mi><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><mstyle displaystyle="true" scriptlevel="0"><mfrac><mrow><mi>p</mi><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo></mrow><mrow><mi>r</mi><mo stretchy="false">(</mo><mi>z</mi><mo stretchy="false">)</mo></mrow></mfrac></mstyle></mrow><annotation encoding="application/x-tex">I(Z,Y) - \beta I(Z,X) \geq \int p(x)p(y|x)p(z|x) \; log \; q(y|z) - \beta \int p(x)p(z|x) log \; \dfrac{p(z|x)}{r(z)}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">I</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span><span class="mpunct">,</span><span 
class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.22222em;">Y</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.05278em;">β</span><span class="mord mathnormal" style="margin-right: 0.07847em;">I</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">X</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">≥</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.11112em; vertical-align: -0.30612em;"></span><span class="mop op-symbol small-op" style="margin-right: 0.19445em; position: relative; top: -0.00056em;">∫</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.03588em;">y</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 
0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 0.03588em;">q</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.03588em;">y</span><span class="mord">∣</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 2.363em; vertical-align: -0.936em;"></span><span class="mord mathnormal" style="margin-right: 0.05278em;">β</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mop op-symbol small-op" style="margin-right: 0.19445em; position: relative; top: -0.00056em;">∫</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 1.427em;"><span style="top: -2.314em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 
0.02778em;">r</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span></span></span><span style="top: -3.23em;"><span class="pstrut" style="height: 3em;"></span><span class="frac-line" style="border-bottom-width: 0.04em;"></span></span><span style="top: -3.677em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.936em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></p>
</blockquote>
<p>It's simpler than it looks. We can compute an unbiased gradient of the above term using sampling. Once we sample <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>x</mi><mo>∼</mo><mi>X</mi></mrow><annotation encoding="application/x-tex">x \sim X</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal">x</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">∼</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">X</span></span></span></span></span>, we can calculate p(y|x) and p(z|x) by running the model forward. The second term is actually a KL divergence between p(z|x) and r(z), which can be computed analytically if we assume both are Gaussian.</p>
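<p>For a concrete look at that analytic term: when p(z|x) is a diagonal Gaussian with mean mu and log-variance log_var, and r(z) is a standard normal, the KL divergence has a closed form. A minimal sketch in plain NumPy (illustrative only, not the exact code used here):</p><pre><code class="language-python">import numpy as np

def kl_to_standard_normal(mu, log_var):
    # Analytic KL( N(mu, diag(exp(log_var))) || N(0, I) ),
    # summed over latent dimensions.
    return 0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var)

# Sanity check: the KL is zero when p(z|x) already equals the prior.
print(kl_to_standard_normal(np.zeros(4), np.zeros(4)))  # prints 0.0</code></pre>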
<p>Thus the objective function can be written as:</p>
<blockquote>
<p><span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mtext>min</mtext><mi>θ</mi></msub><mtext>  </mtext><mi>E</mi><mo stretchy="false">[</mo><mo>−</mo><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><msub><mi>q</mi><mi>ϕ</mi></msub><mo stretchy="false">(</mo><mi>y</mi><mi mathvariant="normal">∣</mi><msub><mi>p</mi><mi>θ</mi></msub><mo stretchy="false">(</mo><mi>z</mi><mo separator="true">,</mo><mi>x</mi><mo stretchy="false">)</mo><mo stretchy="false">)</mo><mo>+</mo><mi>β</mi><mtext>  </mtext><mi>K</mi><mi>L</mi><mo stretchy="false">(</mo><msub><mi>p</mi><mi>θ</mi></msub><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo><mi mathvariant="normal">∣</mi><mi mathvariant="normal">∣</mi><mi>r</mi><mo stretchy="false">(</mo><mi>z</mi><mo stretchy="false">)</mo><mo stretchy="false">)</mo><mo stretchy="false">]</mo></mrow><annotation encoding="application/x-tex">\text{min}_\theta \; E[-log \; q_\phi(y|p_\theta(z,x)) + \beta \; KL(p_\theta(z|x) || r(z))]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1.03611em; vertical-align: -0.286108em;"></span><span class="mord"><span class="mord text"><span class="mord">min</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.02778em;">θ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" 
style="margin-right: 0.05764em;">E</span><span class="mopen">[</span><span class="mord">−</span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: -0.03588em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">ϕ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.03588em;">y</span><span class="mord">∣</span><span class="mord"><span class="mord mathnormal">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.02778em;">θ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mclose">)</span><span class="mspace" style="margin-right: 
0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.05278em;">β</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 0.07153em;">K</span><span class="mord mathnormal">L</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.02778em;">θ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mord">∣</span><span class="mord">∣</span><span class="mord mathnormal" style="margin-right: 0.02778em;">r</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mclose">)</span><span class="mclose">]</span></span></span></span></span></p>
</blockquote>
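<p>Assuming a Gaussian encoder trained with the reparameterization trick and a standard-normal prior r(z), this objective can be estimated with a single sample per datapoint. A sketch in plain NumPy (the helper name and the decoder's log-likelihood interface are hypothetical):</p><pre><code class="language-python">import numpy as np

def vib_loss(mu, log_var, log_q_y_given_z, beta, rng=None):
    # Monte-Carlo estimate of E[-log q(y|z)] + beta * KL(p(z|x) || r(z)),
    # where r(z) = N(0, I) and p(z|x) = N(mu, diag(exp(log_var))).
    if rng is None:
        rng = np.random.default_rng(0)
    # Reparameterized sample z ~ p(z|x): z = mu + sigma * eps, eps ~ N(0, I).
    eps = rng.standard_normal(mu.shape)
    z = mu + np.exp(0.5 * log_var) * eps
    # Analytic KL to the standard-normal prior, one value per batch element.
    kl = 0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var, axis=-1)
    return np.mean(-log_q_y_given_z(z) + beta * kl)</code></pre><p>Here log_q_y_given_z stands for the decoder's per-example log-likelihood of the true label y given the sampled z, and beta trades reconstruction quality against the bottleneck.</p>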
<h4 id="intuitions">Intuitions</h4>
<p>Let's go over some intuitions about the above formula.</p>
<ol>
<li>
<p>The first part is equivalent to maximizing mutual information between Y and Z. Why is this equivalent? Mutual information can be written as <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>H</mi><mo stretchy="false">(</mo><mi>Y</mi><mo stretchy="false">)</mo><mo>−</mo><mi>H</mi><mo stretchy="false">(</mo><mi>Y</mi><mi mathvariant="normal">∣</mi><mi>Z</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">H(Y) - H(Y|Z)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.08125em;">H</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.22222em;">Y</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.08125em;">H</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.22222em;">Y</span><span class="mord">∣</span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span><span class="mclose">)</span></span></span></span></span>. The first term is fixed. The second term asks: "what is the best predictor of Y we can get using Z?" 
We're training the network <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>q</mi><mi>ϕ</mi></msub></mrow><annotation encoding="application/x-tex">q_\phi</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.716668em; vertical-align: -0.286108em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: -0.03588em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">ϕ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span></span></span></span></span></span></span></span></span></span></span> to optimize precisely this objective.</p>
</li>
<li>
<p>The second part is equivalent to minimizing mutual information between <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>Z</mi></mrow><annotation encoding="application/x-tex">Z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span></span></span></span></span> and <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>X</mi></mrow><annotation encoding="application/x-tex">X</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.68333em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">X</span></span></span></span></span>. Why? 
Because no matter which <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>x</mi></mrow><annotation encoding="application/x-tex">x</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal">x</span></span></span></span></span> is sampled, we want the posterior distribution <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>p</mi><mi>θ</mi></msub><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">p_\theta(z|x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord"><span class="mord mathnormal">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.02778em;">θ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span></span> to look like <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math 
xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>r</mi><mo stretchy="false">(</mo><mi>z</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">r(z)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.02778em;">r</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span></span></span></span></span>. In other words, if this term is minimized to zero, then p(z|x) would be the same for every <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>x</mi></mrow><annotation encoding="application/x-tex">x</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal">x</span></span></span></span></span>. This is the "bottleneck" term.</p>
</li>
<li>
<p>The variational information bottleneck shares a lot of similarities with the variational auto-encoder. We can view the auto-encoder as a special case of the information bottleneck, where Y=X. In fact, the VAE objective <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mtext>min</mtext><mi>θ</mi></msub><mtext>  </mtext><mi>E</mi><mo stretchy="false">[</mo><mo>−</mo><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><msub><mi>q</mi><mi>ϕ</mi></msub><mo stretchy="false">(</mo><mi>x</mi><mi mathvariant="normal">∣</mi><msub><mi>p</mi><mi>θ</mi></msub><mo stretchy="false">(</mo><mi>z</mi><mo separator="true">,</mo><mi>x</mi><mo stretchy="false">)</mo><mo stretchy="false">)</mo><mo>+</mo><mi>K</mi><mi>L</mi><mo stretchy="false">(</mo><msub><mi>p</mi><mi>θ</mi></msub><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo><mi mathvariant="normal">∣</mi><mi mathvariant="normal">∣</mi><mi>r</mi><mo stretchy="false">(</mo><mi>z</mi><mo stretchy="false">)</mo><mo stretchy="false">)</mo><mo stretchy="false">]</mo></mrow><annotation encoding="application/x-tex">\text{min}_\theta \; E[-log \; q_\phi(x|p_\theta(z,x)) + KL(p_\theta(z|x) || r(z))]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1.03611em; vertical-align: -0.286108em;"></span><span class="mord"><span class="mord text"><span class="mord">min</span></span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.02778em;">θ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span 
class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 0.05764em;">E</span><span class="mopen">[</span><span class="mord">−</span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: -0.03588em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">ϕ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mord">∣</span><span class="mord"><span class="mord mathnormal">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.02778em;">θ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 
0.166667em;"></span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.07153em;">K</span><span class="mord mathnormal">L</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.02778em;">θ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mord">∣</span><span class="mord">∣</span><span class="mord mathnormal" style="margin-right: 0.02778em;">r</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mclose">)</span><span class="mclose">]</span></span></span></span></span> is equivalent to what we derived above, if y=x and <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>β</mi><mo>=</mo><mn>1</mn></mrow><annotation encoding="application/x-tex">\beta=1</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span 
class="base"><span class="strut" style="height: 0.88888em; vertical-align: -0.19444em;"></span><span class="mord mathnormal" style="margin-right: 0.05278em;">β</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.64444em; vertical-align: 0em;"></span><span class="mord">1</span></span></span></span></span>.</p>
<p>If we set X=Y, one could ask, isn't it weird to maximize <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>I</mi><mo stretchy="false">(</mo><mi>Z</mi><mo separator="true">,</mo><mi>X</mi><mo stretchy="false">)</mo><mo>−</mo><mi>I</mi><mo stretchy="false">(</mo><mi>Z</mi><mo separator="true">,</mo><mi>X</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">I(Z,X) - I(Z,X)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">I</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">X</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">I</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.07153em;">Z</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">X</span><span class="mclose">)</span></span></span></span></span>? It's just zero. The answer is that we're optimizing different approximations to those two terms. The first term is the recreation objective, which optimizes a bound on H(X|Z). The second is the prior-matching objective, which optimizes the true H(Z|X) and a bound on H(Z).</p>
</li>
<li>
<p>What extra networks are needed to implement the variational information bottleneck? Nothing, really. In the end, we're splitting y=f(x) into two networks, p(z|x) and q(y|z). And we're putting a bottleneck on <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>z</mi></mrow><annotation encoding="application/x-tex">z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span></span></span></span></span> via an auxiliary objective that it should match a constant prior r(z).</p>
</li>
</ol><!--kg-card-end: html--><figure class="kg-card kg-image-card"><img src="https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-27-at-3.41.02-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2023/08/Screen-Shot-2023-08-27-at-3.41.02-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2023/08/Screen-Shot-2023-08-27-at-3.41.02-PM.png 1000w, https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-27-at-3.41.02-PM.png 1392w" sizes="(min-width: 720px) 720px"></figure>]]></content:encoded></item><item><title><![CDATA[A Brief Examination of Generative Models]]></title><description><![CDATA[<p>Let's explore some generative models. There's a nice variety of them around. How well can the various methods model a given dataset? What are their flaws and strengths?</p><p>I'm interested in how well generative models handle <em>mode coverage</em>. Let's say my dataset has images of five cat breeds. How is</p>]]></description><link>https://kvfrans.com/a-brief-examination-of-generative-models/</link><guid isPermaLink="false">64d55fb1feb1cd031e03cd0f</guid><dc:creator><![CDATA[Kevin Frans]]></dc:creator><pubDate>Fri, 11 Aug 2023 21:58:26 GMT</pubDate><content:encoded><![CDATA[<p>Let's explore some generative models. There's a nice variety of them around. How well can the various methods model a given dataset? What are their flaws and strengths?</p><p>I'm interested in how well generative models handle <em>mode coverage</em>. Let's say my dataset has images of five cat breeds. How is each breed reflected in the generated images? If one breed is less prominent in the data, does that affect the accuracy of its generated versions?</p><p>The goal with this project was to get a deeper sense of how each generative model method works, and in what unique ways they can fail. 
We will take a look at 1) variational auto-encoders, 2) generative adversarial networks, 3) auto-regressive samplers, and 4) diffusion models.</p><p>Code for all these experiments is available at <a href="https://github.com/kvfrans/generative-model">https://github.com/kvfrans/generative-model</a>.</p><h4 id="the-dataset">The Dataset</h4><p>To get a cleaner look at the various models, I opted to create an artificial dataset rather than use natural images. The dataset is based on MNIST digits. But I've put some tricks into how the dataset is constructed. We'll use only six digit images, which I will refer to as "0A", "0B", "1", "2", "3", and "4".</p><p>The main twist is that each digit is present in the dataset at differing frequencies. "0A", "0B", "1", and "2" are included at equal frequency. "3" is half as likely to be sampled, and "4" is a quarter as likely. So, the digits "3" and "4" are less present than the other digits. A strong generative model should be able to correctly model these frequency differences.</p><figure class="kg-card kg-image-card"><img src="https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.08.27-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2023/08/Screen-Shot-2023-08-10-at-3.08.27-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2023/08/Screen-Shot-2023-08-10-at-3.08.27-PM.png 1000w, https://kvfrans.com/content/images/size/w1600/2023/08/Screen-Shot-2023-08-10-at-3.08.27-PM.png 1600w, https://kvfrans.com/content/images/size/w2400/2023/08/Screen-Shot-2023-08-10-at-3.08.27-PM.png 2400w" sizes="(min-width: 720px) 720px"></figure><p>You'll notice that 0 is in the list twice. This is intentional. There are two kinds of zeros, but they are quite similar. I'm curious if the models can correctly generate both kinds of digits.</p><p>We'll also add a second dimension of variation -- color. Again, we're going to put some twists on the frequencies here so everything isn't just uniform. 
Each digit has a (1/3) chance to be pure green. With (2/3) chance, the digit will instead be a random linear interpolation between red and blue. In other words, there is a continuous distribution of colors between red and blue, and a discontinuous set of pure green digits.</p><figure class="kg-card kg-gallery-card kg-width-wide kg-card-hascaption"><div class="kg-gallery-container"><div class="kg-gallery-row"><div class="kg-gallery-image"><img src="https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.11.31-PM.png" width="1144" height="534" alt srcset="https://kvfrans.com/content/images/size/w600/2023/08/Screen-Shot-2023-08-10-at-3.11.31-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2023/08/Screen-Shot-2023-08-10-at-3.11.31-PM.png 1000w, https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.11.31-PM.png 1144w" sizes="(min-width: 720px) 720px"></div><div class="kg-gallery-image"><img src="https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.09.31-PM-1.png" width="952" height="950" alt srcset="https://kvfrans.com/content/images/size/w600/2023/08/Screen-Shot-2023-08-10-at-3.09.31-PM-1.png 600w, https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.09.31-PM-1.png 952w" sizes="(min-width: 720px) 720px"></div></div></div><figcaption>Real data from our custom dataset.</figcaption></figure><p>How will we measure mode coverage? First, we will need a way to categorize generated digits into one of the 30 possible bins. Let's train a classifier network.</p><p>Actually, that's a trap and we won't. There's a much simpler way to classify them: nearest neighbors. To classify a digit, we'll simply loop over the six possible MNIST digits and find the one with the lowest error. We'll do the same for the average color. In this way, we can cheaply categorize digits to look at frequencies.</p><p>Accuracy is straightforward as well. 
Once a digit has been categorized, we measure the distance between the generated digit and its closest match; that distance serves as our accuracy metric.</p><h4 id="oracle-network">Oracle Network</h4><p>As a starting point, let's construct a network that can access privileged information. We'll call it the oracle network. This network will take as input a vector containing a one-hot encoding of the MNIST digit, and a three-vector of colors. Given this input, it must recreate the matching colored digit.</p><p>The oracle network has a straightforward learning objective. It only has to memorize colored versions of the six digits, and doesn't need to model any sort of probability distribution. We can train the network with a mean-squared error loss.</p><p>I'm using JAX to write the code for these experiments. The network structure is a few fully-connected linear layers, interspersed with tanh activations.</p><pre><code class="language-python">import flax.linen as nn

activation = nn.tanh
class GeneratorLinear(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Input: N-length vector.
        x = nn.Dense(features=64)(x)
        x = activation(x)
        x = nn.Dense(features=128)(x)
        x = activation(x)
        x = nn.Dense(features=28*28*3)(x)
        x = nn.sigmoid(x)
        # Output: 28x28x3 image.
        x = x.reshape(x.shape[0], 28, 28, 3)
        return x</code></pre><p>Throughout this project, we'll do our best to use this same network structure for all the methods.</p><figure class="kg-card kg-gallery-card kg-width-wide kg-card-hascaption"><div class="kg-gallery-container"><div class="kg-gallery-row"><div class="kg-gallery-image"><img src="https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.18.20-PM.png" width="1572" height="734" alt srcset="https://kvfrans.com/content/images/size/w600/2023/08/Screen-Shot-2023-08-10-at-3.18.20-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2023/08/Screen-Shot-2023-08-10-at-3.18.20-PM.png 1000w, https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.18.20-PM.png 1572w" sizes="(min-width: 720px) 720px"></div><div class="kg-gallery-image"><img src="https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.18.28-PM.png" width="878" height="874" alt srcset="https://kvfrans.com/content/images/size/w600/2023/08/Screen-Shot-2023-08-10-at-3.18.28-PM.png 600w, https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.18.28-PM.png 878w" sizes="(min-width: 720px) 720px"></div></div></div><figcaption>Images created by the oracle network.</figcaption></figure><p>As expected, it's an easy task. Memorizing these MNIST digits is simple. 
The trickier part will be generating them at the right frequencies.</p><h4 id="variational-auto-encoder">Variational Auto-Encoder</h4><p>The first generative modelling method we'll examine is the variational auto-encoder (VAE).</p><figure class="kg-card kg-image-card kg-card-hascaption"><img src="https://kvfrans.com/content/images/2023/08/vae.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2023/08/vae.png 600w, https://kvfrans.com/content/images/size/w1000/2023/08/vae.png 1000w, https://kvfrans.com/content/images/2023/08/vae.png 1490w" sizes="(min-width: 720px) 720px"><figcaption>Source: https://danijar.com/building-variational-auto-encoders-in-tensorflow/</figcaption></figure><!--kg-card-begin: html--><p>The VAE works by jointly training an encoder and decoder network. The encoder takes in an image as input, and compresses it down to an N-length latent vector <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>z</mi></mrow><annotation encoding="application/x-tex">z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span></span></span></span></span>. This latent vector is then given to the decoder, which recreates the original image.</p>
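<p>As a sketch of that round trip, here is a plain NumPy version with made-up single-layer weights and an 8-dimensional latent, standing in for the deeper Flax modules used in the real experiments:</p>

```python
import numpy as np

rng = np.random.default_rng(0)

LATENT_DIM = 8  # illustrative; the real latent size is a hyperparameter

# Toy single-layer "networks"; the real encoder/decoder are deeper MLPs.
W_enc = rng.normal(scale=0.01, size=(28 * 28 * 3, LATENT_DIM))
W_dec = rng.normal(scale=0.01, size=(LATENT_DIM, 28 * 28 * 3))

def encode(image):
    # Compress a (28, 28, 3) image down to a LATENT_DIM-length vector z.
    return image.reshape(-1) @ W_enc

def decode(z):
    # Map z back to image space; sigmoid keeps pixels in [0, 1].
    flat = 1.0 / (1.0 + np.exp(-(z @ W_dec)))
    return flat.reshape(28, 28, 3)

x = rng.uniform(size=(28, 28, 3))
x_hat = decode(encode(x))                 # the reconstruction
recon_loss = ((x - x_hat) ** 2).mean()    # mean-squared reconstruction error
```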
<p>The catch is that we're not happy just encoding and decoding an existing image. We want to generate images, which means sampling them from a random distribution. What is the distribution? It's <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>P</mi><mo stretchy="false">(</mo><mi>z</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">P(z)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.13889em;">P</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span></span></span></span></span>. Unfortunately <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>P</mi><mo stretchy="false">(</mo><mi>z</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">P(z)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.13889em;">P</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span></span></span></span></span> is high dimensional and intractable to sample from.</p>
<p>Instead, we shape <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>P</mi><mo stretchy="false">(</mo><mi>z</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">P(z)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.13889em;">P</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span></span></span></span></span> so it's in a form that we can work with. We'll have the encoder output not a latent vector, but a Gaussian parametrized by mean and log-variance. We also apply a KL divergence objective between the output and the unit Gaussian (<span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi><mo>=</mo><mn>0</mn><mo separator="true">;</mo><mi>σ</mi><mo>=</mo><mn>1</mn></mrow><annotation encoding="application/x-tex">\mu = 0; \sigma = 1</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.625em; vertical-align: -0.19444em;"></span><span class="mord mathnormal">μ</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.83888em; vertical-align: -0.19444em;"></span><span class="mord">0</span><span class="mpunct">;</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" 
style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.64444em; vertical-align: 0em;"></span><span class="mord">1</span></span></span></span></span>). This objective encourages the model so that <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>P</mi><mo stretchy="false">(</mo><mi>z</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">P(z)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.13889em;">P</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span></span></span></span></span> looks similar to the unit Gaussian.</p>
<p>This shaping means we can approximate <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>P</mi><mo stretchy="false">(</mo><mi>z</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">P(z)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.13889em;">P</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span></span></span></span></span> by sampling <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>z</mi></mrow><annotation encoding="application/x-tex">z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span></span></span></span></span> from the unit Gaussian. 
And that's precisely what we do to generate new images -- sample a random latent variable, then pass it through the decoder.</p><!--kg-card-end: html--><p>For more on VAEs, check out <a href="https://kvfrans.com/variational-autoencoders-explained/">this post</a> and <a href="https://kvfrans.com/deriving-the-kl/">this newer one</a>.</p><figure class="kg-card kg-image-card kg-card-hascaption"><img src="https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.17.49-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2023/08/Screen-Shot-2023-08-10-at-3.17.49-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2023/08/Screen-Shot-2023-08-10-at-3.17.49-PM.png 1000w, https://kvfrans.com/content/images/size/w1600/2023/08/Screen-Shot-2023-08-10-at-3.17.49-PM.png 1600w, https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.17.49-PM.png 2052w" sizes="(min-width: 720px) 720px"><figcaption>Images created by the variational auto-encoder.</figcaption></figure><p>The VAE is quite stable to train, and it does a good job. It's surprising how well the frequencies match the ground truth. The errors are greater for digits that are less frequently seen during training. Soon we'll see that a well-behaved objective is not to be taken for granted.</p><h4 id="generative-adversarial-networks">Generative Adversarial Networks</h4><p>Next up is the generative adversarial network (GAN). GANs attack the problem that measuring reconstruction via a mean-squared error loss is not ideal. Such a loss often results in a blurry image -- colors will collapse to the mean, so a digit that is either red or blue will be optimized to be purple. 
In the GAN setup, reconstruction error is instead measured by a learned discriminator network.</p><figure class="kg-card kg-image-card kg-card-hascaption"><img src="https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-11-at-5.07.05-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2023/08/Screen-Shot-2023-08-11-at-5.07.05-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2023/08/Screen-Shot-2023-08-11-at-5.07.05-PM.png 1000w, https://kvfrans.com/content/images/size/w1600/2023/08/Screen-Shot-2023-08-11-at-5.07.05-PM.png 1600w, https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-11-at-5.07.05-PM.png 1662w" sizes="(min-width: 720px) 720px"><figcaption>Source: https://developers.google.com/machine-learning/gan/generator</figcaption></figure><p>The setup is as follows. The generator takes random noise as input, and produces fake images. The images are then passed to the discriminator, which predicts whether the image is real or fake. The generator's objective is to maximize the probability that the discriminator classifies its images as real.</p><p>Meanwhile, the discriminator is trained on both real and fake images, and is optimized to classify them accordingly. Thus, there is an adversarial game here. 
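</p><p>Concretely, this game can be written as a pair of binary cross-entropy losses. Below is a NumPy sketch of the standard objectives (the real experiments are in JAX, and the helper names are illustrative):</p>

```python
import numpy as np

def bce(logits, labels):
    # Numerically stable binary cross-entropy on raw logits.
    return np.mean(np.maximum(logits, 0) - logits * labels
                   + np.log1p(np.exp(-np.abs(logits))))

def discriminator_loss(real_logits, fake_logits):
    # Push real images toward label 1 and generated images toward label 0.
    return (bce(real_logits, np.ones_like(real_logits))
            + bce(fake_logits, np.zeros_like(fake_logits)))

def generator_loss(fake_logits):
    # The generator wants the discriminator to call its samples real.
    return bce(fake_logits, np.ones_like(fake_logits))
```

<p>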
At the Nash equilibrium, the generator is producing images that exactly match the distribution of real images.</p><figure class="kg-card kg-gallery-card kg-width-wide kg-card-hascaption"><div class="kg-gallery-container"><div class="kg-gallery-row"><div class="kg-gallery-image"><img src="https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.19.01-PM-1.png" width="2000" height="981" alt srcset="https://kvfrans.com/content/images/size/w600/2023/08/Screen-Shot-2023-08-10-at-3.19.01-PM-1.png 600w, https://kvfrans.com/content/images/size/w1000/2023/08/Screen-Shot-2023-08-10-at-3.19.01-PM-1.png 1000w, https://kvfrans.com/content/images/size/w1600/2023/08/Screen-Shot-2023-08-10-at-3.19.01-PM-1.png 1600w, https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.19.01-PM-1.png 2186w" sizes="(min-width: 720px) 720px"></div><div class="kg-gallery-image"><img src="https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.19.22-PM-1.png" width="2000" height="989" alt srcset="https://kvfrans.com/content/images/size/w600/2023/08/Screen-Shot-2023-08-10-at-3.19.22-PM-1.png 600w, https://kvfrans.com/content/images/size/w1000/2023/08/Screen-Shot-2023-08-10-at-3.19.22-PM-1.png 1000w, https://kvfrans.com/content/images/size/w1600/2023/08/Screen-Shot-2023-08-10-at-3.19.22-PM-1.png 1600w, https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.19.22-PM-1.png 2090w" sizes="(min-width: 720px) 720px"></div></div></div><figcaption>Images created by the generative adversarial network. Note the mode collapse; most digits are missing. Green is also missing.</figcaption></figure><p>Looking at the frequencies hints at a key flaw in GANs. They often suffer from "mode collapse", where the generator fails to capture the entire spectrum of data, and instead produces a limited subset.
In the results above, we see a generator that produces only 2s, and a separate trial that produces only 0s.</p><p>The mode collapse issue arises because the discriminator assesses each image independently. Thus there is no picture of the <em>distribution</em> of generated images, which is needed to assess diversity. In theory, mode collapse is solved in the optimal equilibrium -- if a generator is producing only 0s, the discriminator can penalize 0s, forcing the generator to produce something else, and so on. In practice, it is hard to escape this local minimum.</p><p>One way to address mode collapse is via <a href="https://arxiv.org/pdf/1606.03498v1.pdf">minibatch discrimination</a>. In this idea, the discriminator receives privileged information about the batch as a whole. For each image, the discriminator can see the average distance of the image to the rest of the batch, in a learned feature space. This setup allows the discriminator to learn what statistics should look like in a batch of real images, thus forcing the generator to match those statistics.</p><figure class="kg-card kg-image-card"><img src="https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.24.33-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2023/08/Screen-Shot-2023-08-10-at-3.24.33-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2023/08/Screen-Shot-2023-08-10-at-3.24.33-PM.png 1000w, https://kvfrans.com/content/images/size/w1600/2023/08/Screen-Shot-2023-08-10-at-3.24.33-PM.png 1600w, https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.24.33-PM.png 2350w" sizes="(min-width: 720px) 720px"></figure><p>Does it work? Well, we are doing better on the frequencies. But the quality of the results has taken a hit. GANs are quite unstable to train due to their adversarial nature.
The discriminator must provide a clean signal for the generator to learn from, and often this signal is lost or diminished.</p><h4 id="auto-regressive-sampling">Auto-Regressive Sampling</h4><!--kg-card-begin: html--><p>In this next section, we'll take a look at a method that explicitly models the image probability distribution <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>P</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">P(x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.13889em;">P</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span></span>.</p>
<p>The issue with modelling <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>P</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">P(x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.13889em;">P</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span></span> directly is that the image space is high-dimensional. Instead, we need to break it into smaller components. PixelRNN breaks down an image into a sequence of pixels. If we can model <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>P</mi><mo stretchy="false">(</mo><msub><mi>x</mi><mi>t</mi></msub><mi mathvariant="normal">∣</mi><msub><mi>x</mi><mn>0</mn></msub><mi mathvariant="normal">.</mi><mi mathvariant="normal">.</mi><mi mathvariant="normal">.</mi><msub><mi>x</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">P(x_t|x_0...x_{t-1})</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.13889em;">P</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.280556em;"><span style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 
mtight"><span class="mord mathnormal mtight">t</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.301108em;"><span style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">0</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span><span class="mord">.</span><span class="mord">.</span><span class="mord">.</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.301108em;"><span style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">−</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.208331em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span>, then <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>P</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">P(x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span 
class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.13889em;">P</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span></span> becomes the product of these individual probabilities. In other words -- generate an image pixel-by-pixel.</p>
<p>To model the probability distribution of a pixel, it helps to break the color into a set of discrete bins. I chose to represent each RGB channel as 8 possible options. Together this results in <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msup><mn>8</mn><mn>3</mn></msup><mo>=</mo><mn>512</mn></mrow><annotation encoding="application/x-tex">8^3 = 512</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.814108em; vertical-align: 0em;"></span><span class="mord"><span class="mord">8</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">3</span></span></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.64444em; vertical-align: 0em;"></span><span class="mord">5</span><span class="mord">1</span><span class="mord">2</span></span></span></span></span> bins for each pixel. Our model will be a neural network that outputs a probability distribution over these 512 bins, given the preceding pixels.</p>
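<p>The quantization round trip is small enough to sketch directly in NumPy; decoding to the center of each bin is one reasonable choice:</p>

```python
import numpy as np

N_BINS = 8  # levels per channel; 8**3 = 512 joint classes per pixel

def pixel_to_class(rgb):
    # rgb: three floats in [0, 1]. Quantize each channel to one of 8 levels,
    # then pack the three levels into a single index in [0, 512).
    levels = np.minimum((np.asarray(rgb) * N_BINS).astype(int), N_BINS - 1)
    r, g, b = levels
    return int(r * N_BINS ** 2 + g * N_BINS + b)

def class_to_pixel(idx):
    # Invert the packing, returning the center of each channel's bin.
    r, rem = divmod(idx, N_BINS ** 2)
    g, b = divmod(rem, N_BINS)
    return (np.array([r, g, b]) + 0.5) / N_BINS
```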
<p>To process our image, we'll flatten each image into a <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mn>28</mn><mo>∗</mo><mn>28</mn><mo>=</mo><mn>784</mn></mrow><annotation encoding="application/x-tex">28*28 = 784</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.64444em; vertical-align: 0em;"></span><span class="mord">2</span><span class="mord">8</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">∗</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 0.64444em; vertical-align: 0em;"></span><span class="mord">2</span><span class="mord">8</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.64444em; vertical-align: 0em;"></span><span class="mord">7</span><span class="mord">8</span><span class="mord">4</span></span></span></span></span>-length vector. The vector has 3 channels for the color, and an additional 2 channels for the normalized XY position of that pixel. 
2 more channels are set to <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo stretchy="false">(</mo><mn>1</mn><mo>−</mo><mi>X</mi><mi>Y</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">(1-XY)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mopen">(</span><span class="mord">1</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.07847em;">X</span><span class="mord mathnormal" style="margin-right: 0.22222em;">Y</span><span class="mclose">)</span></span></span></span></span> coordinates. It's also helpful to include a boolean channel that is 1 if the pixel is on the right-most X boundary -- this information helps in modelling the jump between rows of an image.</p><!--kg-card-end: html--><figure class="kg-card kg-image-card kg-card-hascaption"><img src="https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-11-at-5.24.50-PM.png" class="kg-image" alt><figcaption>Source: https://arxiv.org/abs/1601.06759v3</figcaption></figure><p>In the original PixelRNN setup, the neural network model is a recurrent network. It ingests a sequence of pixels recurrently, updating its hidden state, then using the final hidden state to predict the next pixel. I tried this setup but I had quite a hard time getting the model to converge.</p><p>Instead, we're going to use a hacky flattened version. The model gets to see the past 64 pixels, each of which has 8 channels, for a total of $512$ channels per input. That's not bad at all.
A flattened network is often inefficient because there is redundancy in treating each channel as separate, when in reality there is some temporal structure to it. But in our case, it's easier to learn a flattened feedforward network than a recurrent one due to stability in training.</p><figure class="kg-card kg-image-card"><img src="https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.23.40-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2023/08/Screen-Shot-2023-08-10-at-3.23.40-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2023/08/Screen-Shot-2023-08-10-at-3.23.40-PM.png 1000w, https://kvfrans.com/content/images/size/w1600/2023/08/Screen-Shot-2023-08-10-at-3.23.40-PM.png 1600w, https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.23.40-PM.png 1820w" sizes="(min-width: 720px) 720px"></figure><p>During implementation, I ran into some troubles. It was easy to mess up the sampling procedure. The network needs to iterate through the image, predicting the <em>next</em> pixel and placing that pixel in the correct space. You can see in this example that the rows don't match up correctly so the digit is shifted.</p><p>Even with the correct sampling, the network's results aren't great.</p><figure class="kg-card kg-image-card"><img src="https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.24.58-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2023/08/Screen-Shot-2023-08-10-at-3.24.58-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2023/08/Screen-Shot-2023-08-10-at-3.24.58-PM.png 1000w, https://kvfrans.com/content/images/size/w1600/2023/08/Screen-Shot-2023-08-10-at-3.24.58-PM.png 1600w, https://kvfrans.com/content/images/size/w2400/2023/08/Screen-Shot-2023-08-10-at-3.24.58-PM.png 2400w" sizes="(min-width: 720px) 720px"></figure><p>A key issue with auto-regressive sampling is that it's easy to leave the training distribution. 
Since we're generating each image a pixel at a time, if any of those pixels gets generated wrong, then the subsequent pixels will likely also be wrong. So the model needs to generate many pixels without any errors -- a chance that quickly decays to 0 as the sequence length increases.</p><p>We can see in the example images that the generated digits start to lose structure. The images all sort of resemble MNIST digits, but they don't have a coherent shape.</p><p>The frequencies of the generated digits are also quite wrong. This issue points to another flaw in auto-regressive sampling. P(x) is split into a sequence of probabilities, and important decisions have to be made early on. As an example, the choice for what digit to generate might occur in the 28th pixel, because that pixel is blank for a 2 or 3. It's not great for stability to have big swings in probability at certain pixel locations.</p><h4 id="diffusion-models">Diffusion Models</h4><p>Our final generative modelling method also samples auto-regressively, but does so over <em>time</em> instead of over pixels.</p><figure class="kg-card kg-image-card kg-card-hascaption"><img src="https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-11-at-5.26.01-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2023/08/Screen-Shot-2023-08-11-at-5.26.01-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2023/08/Screen-Shot-2023-08-11-at-5.26.01-PM.png 1000w, https://kvfrans.com/content/images/size/w1600/2023/08/Screen-Shot-2023-08-11-at-5.26.01-PM.png 1600w, https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-11-at-5.26.01-PM.png 2068w" sizes="(min-width: 720px) 720px"><figcaption>Source: https://scholar.harvard.edu/binxuw/classes/machine-learning-scratch/materials/foundation-diffusion-generative-models</figcaption></figure><p>We'll start by defining a noising process. This process will take an image from the dataset, and progressively add Gaussian noise to it.
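</p><p>A NumPy sketch of the forward noising process (the step count and variance schedule here are illustrative; the closed form for jumping straight to step t follows from the Gaussian properties discussed below):</p>

```python
import numpy as np

rng = np.random.default_rng(0)

T = 300
betas = np.linspace(1e-4, 0.02, T)       # per-step noise variances (illustrative schedule)
alpha_bars = np.cumprod(1.0 - betas)     # cumulative fraction of signal kept after t steps

def noise_image(x0, t):
    # Jump directly to step t of the noising process in one operation:
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    eps = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_bars[t]) * x0 + np.sqrt(1.0 - alpha_bars[t]) * eps

x0 = rng.uniform(size=(28, 28, 3))
x_late = noise_image(x0, T - 1)          # by the last step, nearly pure noise
```

<p>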
By the end of the process (say, 300 timesteps), the image is purely Gaussian noise.</p><figure class="kg-card kg-image-card"><img src="https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.29.47-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2023/08/Screen-Shot-2023-08-10-at-3.29.47-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2023/08/Screen-Shot-2023-08-10-at-3.29.47-PM.png 1000w, https://kvfrans.com/content/images/size/w1600/2023/08/Screen-Shot-2023-08-10-at-3.29.47-PM.png 1600w, https://kvfrans.com/content/images/size/w2400/2023/08/Screen-Shot-2023-08-10-at-3.29.47-PM.png 2400w" sizes="(min-width: 720px) 720px"></figure><p>The task for the diffusion model is now to reverse that process. Starting from a pure Gaussian image, it should iteratively remove the noise until the original image is recreated. The nice property about recreating an image sequentially is that decisions can be spread out. The first few denoising steps might define the outline of an image, while later steps add in the details.</p><p>Gaussian noise gives us a few nice properties.</p><ul><li>Applying two steps of Gaussian noise to an image is equivalent to applying one step with greater variance. This property lets us easily simulate an image with $T$ steps of noising, in a single operation.</li></ul><!--kg-card-begin: html--><ul>
<li>If the forward process <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi><mo stretchy="false">(</mo><msub><mi>x</mi><mrow><mi>t</mi><mo>+</mo><mn>1</mn></mrow></msub><mi mathvariant="normal">∣</mi><msub><mi>x</mi><mi>t</mi></msub><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">p(x_{t+1}| x_{t})</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.301108em;"><span style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">+</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.208331em;"><span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.280556em;"><span style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span><span 
class="mclose">)</span></span></span></span></span> is Gaussian, then the reverse process <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>q</mi><mo stretchy="false">(</mo><msub><mi>x</mi><mi>t</mi></msub><mi mathvariant="normal">∣</mi><msub><mi>x</mi><mrow><mi>t</mi><mo>+</mo><mn>1</mn></mrow></msub><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">q(x_{t}| x_{t+1})</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.03588em;">q</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.280556em;"><span style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span><span class="mord">∣</span><span class="mord"><span class="mord mathnormal">x</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.301108em;"><span style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight">t</span><span class="mbin mtight">+</span><span class="mord mtight">1</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 
0.208331em;"><span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span> is also Gaussian, provided the per-step variance is small. (<a href="https://lilianweng.github.io/posts/2021-07-11-diffusion-models/" target="_blank" rel="nofollow noopener noreferrer">https://lilianweng.github.io/posts/2021-07-11-diffusion-models/</a>)</li>
</ul><!--kg-card-end: html--><p>Because of the second property, we can approximate the reverse process as a Gaussian. The variances are fixed, but we need to learn the mean of that distribution -- that's what our network does.</p><p>In the <a href="https://arxiv.org/abs/2006.11239">original DDPM paper</a>, the network predicts the noise that was added, not the original image. I found this hard to achieve here. Because our model is quite small, it's hard to compress 28×28×3 samples of independent noise into a feedforward bottleneck. In contrast, it's simpler to compress the original image, since there are only a few distinct MNIST digit shapes. The noise can then be recovered via <code>noise = noisy_image - predicted_original_image</code>.</p><p>Note that the denoising process is a Gaussian at every timestep. So while we're predicting the mean with our learned network, we also add back noise. The variance of this noise follows the pre-determined schedule, so it decreases over time.</p><figure class="kg-card kg-image-card"><img src="https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.30.17-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2023/08/Screen-Shot-2023-08-10-at-3.30.17-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2023/08/Screen-Shot-2023-08-10-at-3.30.17-PM.png 1000w, https://kvfrans.com/content/images/size/w1600/2023/08/Screen-Shot-2023-08-10-at-3.30.17-PM.png 1600w, https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.30.17-PM.png 1757w" sizes="(min-width: 720px) 720px"></figure><p>What do the metrics show?</p><figure class="kg-card kg-image-card"><img src="https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.25.21-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2023/08/Screen-Shot-2023-08-10-at-3.25.21-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2023/08/Screen-Shot-2023-08-10-at-3.25.21-PM.png 1000w, 
https://kvfrans.com/content/images/size/w1600/2023/08/Screen-Shot-2023-08-10-at-3.25.21-PM.png 1600w, https://kvfrans.com/content/images/2023/08/Screen-Shot-2023-08-10-at-3.25.21-PM.png 2364w" sizes="(min-width: 720px) 720px"></figure><p>Works nicely. The frequencies aren't as well-matched as the VAE's, but the model at least covers all the modes. That's better than the GAN or PixelRNN.</p><p>The error rates are interesting. The highest error isn't in shape, but in color. Examining the generated images, it looks like the model has trouble deciding whether to make the digits red or blue. Really, those pixels should all be the same shade of purple.</p><p>As always, code for everything here is available at <a href="https://github.com/kvfrans/generative-model">https://github.com/kvfrans/generative-model</a>.</p>]]></content:encoded></item><item><title><![CDATA[Deriving the KL divergence loss in variational autoencoders]]></title><description><![CDATA[<!--kg-card-begin: html--><p>Let's derive some things related to variational auto-encoders (VAEs).</p>
<h4 id="evidence-lower-bound-elbo">Evidence Lower Bound (ELBO)</h4>
<p>First, we'll state some assumptions. We have a dataset of images, <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>x</mi></mrow><annotation encoding="application/x-tex">x</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal">x</span></span></span></span></span>. We'll assume that each image is generated from some unseen latent code <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>z</mi></mrow><annotation encoding="application/x-tex">z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span></span></span></span></span>, and there's an underlying distribution of latents <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi></mrow></semantics></math></span></span></span></p>]]></description><link>https://kvfrans.com/deriving-the-kl/</link><guid isPermaLink="false">64c914e9feb1cd031e03ccef</guid><dc:creator><![CDATA[Kevin Frans]]></dc:creator><pubDate>Tue, 01 Aug 2023 15:01:54 GMT</pubDate><content:encoded><![CDATA[<!--kg-card-begin: html--><p>Let's derive some things related to variational auto-encoders (VAEs).</p>
<h4 id="evidence-lower-bound-elbo">Evidence Lower Bound (ELBO)</h4>
<p>First, we'll state some assumptions. We have a dataset of images, <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>x</mi></mrow><annotation encoding="application/x-tex">x</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal">x</span></span></span></span></span>. We'll assume that each image is generated from some unseen latent code <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>z</mi></mrow><annotation encoding="application/x-tex">z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span></span></span></span></span>, and there's an underlying distribution of latents <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi><mo stretchy="false">(</mo><mi>z</mi></mrow><annotation encoding="application/x-tex">p(z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span></span></span></span></span>). 
We'd like to discover parameters <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>θ</mi></mrow><annotation encoding="application/x-tex">\theta</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.69444em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.02778em;">θ</span></span></span></span></span> to maximize the likelihood of <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>p</mi><mi>θ</mi></msub><mo stretchy="false">(</mo><mi>x</mi><mi mathvariant="normal">∣</mi><mi>z</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">p_\theta(x|z)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord"><span class="mord mathnormal">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.02778em;">θ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mord">∣</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span></span></span></span></span> under the data. In practical terms -- train a neural network that will generate images from the dataset.</p>
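<p>To make that objective concrete, here's a tiny sketch. If we (hypothetically) assume the decoder outputs a mean image and the likelihood is a unit-variance Gaussian around that mean, then the log-likelihood of a datapoint is just a negative squared error plus a constant -- so maximizing likelihood means preferring closer reconstructions. The <code>decoder_loglik</code> helper below is purely illustrative:</p><pre><code class="language-python">import numpy as np

def decoder_loglik(x, x_mean):
    # log p_theta(x|z) under a unit-variance Gaussian centered at the
    # decoder's output x_mean. Up to a constant, this is -0.5 * squared error.
    d = x.size
    return -0.5 * np.sum((x - x_mean) ** 2) - 0.5 * d * np.log(2 * np.pi)

x = np.array([0.2, 0.8, 0.5])
close = decoder_loglik(x, np.array([0.25, 0.75, 0.5]))  # good reconstruction
far = decoder_loglik(x, np.array([0.9, 0.1, 0.0]))      # bad reconstruction
# close is larger: maximizing likelihood favors accurate reconstructions.</code></pre>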
<p>Let's look more closely at this term by expanding it with Bayes' rule.</p>
<blockquote>
<p><span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>p</mi><mi>θ</mi></msub><mo stretchy="false">(</mo><mi>x</mi><mi mathvariant="normal">∣</mi><mi>z</mi><mo stretchy="false">)</mo><mo>=</mo><mstyle displaystyle="true" scriptlevel="0"><mfrac><mrow><mi>p</mi><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo><mi>p</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></mrow><mrow><mi>p</mi><mo stretchy="false">(</mo><mi>z</mi><mo stretchy="false">)</mo></mrow></mfrac></mstyle></mrow><annotation encoding="application/x-tex">p_\theta(x|z) = \dfrac{p(z|x)p(x)}{p(z)}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord"><span class="mord mathnormal">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.02778em;">θ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mord">∣</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 2.363em; vertical-align: -0.936em;"></span><span class="mord"><span class="mopen 
nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 1.427em;"><span style="top: -2.314em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span></span></span><span style="top: -3.23em;"><span class="pstrut" style="height: 3em;"></span><span class="frac-line" style="border-bottom-width: 0.04em;"></span></span><span style="top: -3.677em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.936em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span></span></span></span></span></p>
</blockquote>
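<p>This is just Bayes' rule, and we can sanity-check the identity numerically on a made-up discrete joint distribution over <code>(z, x)</code> (purely illustrative numbers):</p><pre><code class="language-python">import numpy as np

# Toy joint distribution over (z, x): rows index z, columns index x.
joint = np.array([[0.10, 0.30],
                  [0.40, 0.20]])
p_z = joint.sum(axis=1)  # marginal p(z)
p_x = joint.sum(axis=0)  # marginal p(x)

z, x = 1, 0
p_x_given_z = joint[z, x] / p_z[z]
p_z_given_x = joint[z, x] / p_x[x]

# Bayes' rule: p(x|z) = p(z|x) p(x) / p(z)
assert np.isclose(p_x_given_z, p_z_given_x * p_x[x] / p_z[z])</code></pre>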
<ul>
<li>We can calculate <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>p</mi><mi>θ</mi></msub><mo stretchy="false">(</mo><mi>x</mi><mi mathvariant="normal">∣</mi><mi>z</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">p_\theta(x|z)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord"><span class="mord mathnormal">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.02778em;">θ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mord">∣</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span></span></span></span></span> for any <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>z</mi></mrow><annotation encoding="application/x-tex">z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span></span></span></span></span>.</li>
<li>We can assume <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi><mo stretchy="false">(</mo><mi>z</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">p(z)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span></span></span></span></span>. In the VAE case, it's the standard Gaussian.</li>
<li>We don't know <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">p(x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span></span>. It is intractable to compute <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mo>=</mo><mo>∫</mo><msub><mi>p</mi><mi>θ</mi></msub><mo stretchy="false">(</mo><mi>x</mi><mi mathvariant="normal">∣</mi><mi>z</mi><mo stretchy="false">)</mo><mi>p</mi><mo stretchy="false">(</mo><mi>z</mi><mo stretchy="false">)</mo><mi>d</mi><mi>z</mi></mrow><annotation encoding="application/x-tex">p(x) = \int p_\theta(x|z)p(z)dz</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.11112em; vertical-align: -0.30612em;"></span><span class="mop op-symbol small-op" style="margin-right: 0.19445em; position: relative; top: -0.00056em;">∫</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord"><span class="mord 
mathnormal">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.02778em;">θ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mord">∣</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mord mathnormal">d</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span></span></span></span></span>, because the space of all <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>z</mi></mrow><annotation encoding="application/x-tex">z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span></span></span></span></span> is large.</li>
<li>We also don't know <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">p(z|x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span></span>.</li>
</ul>
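<p>To see why that integral is a problem in practice, consider the naive Monte Carlo estimate: sample z from the prior and average p(x|z). Even in this toy one-dimensional setup (the distributions here are made up for illustration), almost every sampled z explains x poorly, so a tiny fraction of lucky samples dominates the estimate -- and it only gets worse as z becomes high-dimensional:</p><pre><code class="language-python">import numpy as np

rng = np.random.default_rng(0)

def log_p_x_given_z(x, z):
    # Toy likelihood: x is Gaussian around z with standard deviation 0.1.
    return -0.5 * ((x - z) / 0.1) ** 2 - np.log(0.1 * np.sqrt(2 * np.pi))

x = 2.0
z = rng.standard_normal(10_000)       # z ~ p(z) = N(0, 1)
weights = np.exp(log_p_x_given_z(x, z))
estimate = weights.mean()             # p(x) = E_z[p(x|z)]

# Only a tiny fraction of samples contribute meaningfully to the estimate.
frac_useful = np.mean(weights > 1e-3 * weights.max())</code></pre>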
<p>To help us, we'll introduce an approximate model: <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>q</mi><mi>ϕ</mi></msub><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">q_\phi(z|x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1.03611em; vertical-align: -0.286108em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: -0.03588em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">ϕ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span></span>. 
The idea is to closely match the true distribution <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">p(z|x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span></span>, but in a way that we can sample from it. We can quantify the distance between the two distributions via KL divergence.</p>
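<p>As a refresher, the KL divergence between q and p is the expectation E_q[log q(z) - log p(z)]. For one-dimensional Gaussians it has a closed form, which we can sanity-check against a Monte Carlo estimate of that expectation (a standalone toy example, not the VAE objective itself):</p><pre><code class="language-python">import numpy as np

rng = np.random.default_rng(0)

def gauss_logpdf(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma * np.sqrt(2 * np.pi))

mu_q, sigma_q = 0.5, 0.8  # q = N(0.5, 0.8^2), e.g. an encoder's output
mu_p, sigma_p = 0.0, 1.0  # p = N(0, 1), the standard Gaussian prior

# Closed form for KL(q || p) between two Gaussians.
kl_exact = (np.log(sigma_p / sigma_q)
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2 * sigma_p ** 2)
            - 0.5)

# Monte Carlo estimate of E_q[log q(z) - log p(z)], sampling z from q.
z = mu_q + sigma_q * rng.standard_normal(200_000)
kl_mc = np.mean(gauss_logpdf(z, mu_q, sigma_q) - gauss_logpdf(z, mu_p, sigma_p))
# kl_mc closely matches kl_exact.</code></pre>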
<p><span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>K</mi><mi>L</mi><mo stretchy="false">(</mo><msub><mi>q</mi><mi>ϕ</mi></msub><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo><mo separator="true">,</mo><mi>p</mi><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo><mo stretchy="false">)</mo><mo>=</mo><msub><mi>E</mi><mi>q</mi></msub><mo stretchy="false">[</mo><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><msub><mi>q</mi><mi>ϕ</mi></msub><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo><mo>−</mo><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><mi>p</mi><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo><mo stretchy="false">]</mo></mrow><annotation encoding="application/x-tex"> KL(q_\phi(z|x), p(z|x)) = E_q[log \; q_\phi(z|x) - log \; p(z|x)]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1.03611em; vertical-align: -0.286108em;"></span><span class="mord mathnormal" style="margin-right: 0.07153em;">K</span><span class="mord mathnormal">L</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: -0.03588em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">ϕ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 
0.286108em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.03611em; vertical-align: -0.286108em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.05764em;">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.151392em;"><span style="top: -2.55em; margin-left: -0.05764em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.03588em;">q</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span></span></span></span></span></span></span><span class="mopen">[</span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span 
style="top: -2.55em; margin-left: -0.03588em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">ϕ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mclose">]</span></span></span></span></span></p>
<p>In fact, breaking down this term will give us a hint on how to measure <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">p(x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span></span>.</p>
<blockquote>
<p><span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo>=</mo><msub><mi>E</mi><mi>q</mi></msub><mo stretchy="false">[</mo><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><msub><mi>q</mi><mi>ϕ</mi></msub><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo><mo>−</mo><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><mi>p</mi><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo><mo stretchy="false">]</mo></mrow><annotation encoding="application/x-tex"> = E_q[log \; q_\phi(z|x) - log \; p(z|x)]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.36687em; vertical-align: 0em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.03611em; vertical-align: -0.286108em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.05764em;">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.151392em;"><span style="top: -2.55em; margin-left: -0.05764em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.03588em;">q</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span></span></span></span></span></span></span><span class="mopen">[</span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord"><span 
class="mord mathnormal" style="margin-right: 0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: -0.03588em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">ϕ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mclose">]</span></span></span></span></span><br>
<span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo>=</mo><msub><mi>E</mi><mi>q</mi></msub><mo stretchy="false">[</mo><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><msub><mi>q</mi><mi>ϕ</mi></msub><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo><mo>−</mo><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><mi>p</mi><mo stretchy="false">(</mo><mi>z</mi><mo separator="true">,</mo><mi>x</mi><mo stretchy="false">)</mo><mi mathvariant="normal">/</mi><mi>p</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mo stretchy="false">]</mo></mrow><annotation encoding="application/x-tex">= E_q[log \; q_\phi(z|x) - log \; p(z,x)/p(x)]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.36687em; vertical-align: 0em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.03611em; vertical-align: -0.286108em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.05764em;">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.151392em;"><span style="top: -2.55em; margin-left: -0.05764em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.03588em;">q</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span></span></span></span></span></span></span><span class="mopen">[</span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 
0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: -0.03588em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">ϕ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mord">/</span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mclose">]</span></span></span></span></span><br>
<span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo>=</mo><msub><mi>E</mi><mi>q</mi></msub><mo stretchy="false">[</mo><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><msub><mi>q</mi><mi>ϕ</mi></msub><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo><mo>−</mo><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><mi>p</mi><mo stretchy="false">(</mo><mi>z</mi><mo separator="true">,</mo><mi>x</mi><mo stretchy="false">)</mo><mo>+</mo><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><mi>p</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mo stretchy="false">]</mo></mrow><annotation encoding="application/x-tex">= E_q[log \; q_\phi(z|x) - log \; p(z,x) + log \; p(x)]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.36687em; vertical-align: 0em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.03611em; vertical-align: -0.286108em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.05764em;">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.151392em;"><span style="top: -2.55em; margin-left: -0.05764em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.03588em;">q</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span></span></span></span></span></span></span><span class="mopen">[</span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord 
mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: -0.03588em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">ϕ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" 
style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mclose">]</span></span></span></span></span><br>
<span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo>=</mo><msub><mi>E</mi><mi>q</mi></msub><mo stretchy="false">[</mo><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><msub><mi>q</mi><mi>ϕ</mi></msub><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo><mo>−</mo><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><mi>p</mi><mo stretchy="false">(</mo><mi>z</mi><mo separator="true">,</mo><mi>x</mi><mo stretchy="false">)</mo><mo stretchy="false">]</mo><mo>+</mo><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><mi>p</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">= E_q[log \; q_\phi(z|x) - log \; p(z,x)] + log \; p(x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.36687em; vertical-align: 0em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.03611em; vertical-align: -0.286108em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.05764em;">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.151392em;"><span style="top: -2.55em; margin-left: -0.05764em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.03588em;">q</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span></span></span></span></span></span></span><span class="mopen">[</span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord 
mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: -0.03588em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">ϕ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mclose">]</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord 
mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span></span><br>
<span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo>=</mo><msub><mi>E</mi><mi>q</mi></msub><mo stretchy="false">[</mo><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><msub><mi>q</mi><mi>ϕ</mi></msub><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo><mo>−</mo><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><msub><mi>p</mi><mi>θ</mi></msub><mo stretchy="false">(</mo><mi>x</mi><mi mathvariant="normal">∣</mi><mi>z</mi><mo stretchy="false">)</mo><mi>p</mi><mo stretchy="false">(</mo><mi>z</mi><mo stretchy="false">)</mo><mo stretchy="false">]</mo><mo>+</mo><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><mi>p</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">= E_q[log \; q_\phi(z|x) - log \; p_\theta(x|z)p(z)] + log \; p(x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.36687em; vertical-align: 0em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.03611em; vertical-align: -0.286108em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.05764em;">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.151392em;"><span style="top: -2.55em; margin-left: -0.05764em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.03588em;">q</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span></span></span></span></span></span></span><span class="mopen">[</span><span class="mord 
mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: -0.03588em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">ϕ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord"><span class="mord mathnormal">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.02778em;">θ</span></span></span></span><span 
class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mord">∣</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mclose">]</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span></span></p>
</blockquote>
<p>We can rearrange terms as:</p>
<blockquote>
<p><span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><mi>p</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mo>=</mo><msub><mi>E</mi><mi>q</mi></msub><mo stretchy="false">[</mo><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><msub><mi>p</mi><mi>θ</mi></msub><mo stretchy="false">(</mo><mi>x</mi><mi mathvariant="normal">∣</mi><mi>z</mi><mo stretchy="false">)</mo><mi>p</mi><mo stretchy="false">(</mo><mi>z</mi><mo stretchy="false">)</mo><mo>−</mo><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><msub><mi>q</mi><mi>ϕ</mi></msub><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo><mo stretchy="false">]</mo><mo>+</mo><mi>K</mi><mi>L</mi><mo stretchy="false">(</mo><msub><mi>q</mi><mi>ϕ</mi></msub><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo><mo separator="true">,</mo><mi>p</mi><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">log \; p(x) = E_q[log \; p_\theta(x|z)p(z) - log \; q_\phi(z|x)] + KL(q_\phi(z|x), p(z|x))</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" 
style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.03611em; vertical-align: -0.286108em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.05764em;">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.151392em;"><span style="top: -2.55em; margin-left: -0.05764em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.03588em;">q</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span></span></span></span></span></span></span><span class="mopen">[</span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord"><span class="mord mathnormal">p</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: 0em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.02778em;">θ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mord">∣</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mclose">)</span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span 
class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1.03611em; vertical-align: -0.286108em;"></span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: -0.03588em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">ϕ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mclose">]</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1.03611em; vertical-align: -0.286108em;"></span><span class="mord mathnormal" style="margin-right: 0.07153em;">K</span><span class="mord mathnormal">L</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; 
margin-left: -0.03588em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">ϕ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mclose">)</span></span></span></span></span></p>
</blockquote>
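<p>As a sanity check, this identity can be verified numerically in a one-dimensional toy model where every term has a closed form. The sketch below is illustrative (the specific model is my choice, not part of the derivation): it assumes a prior p(z) = N(0, 1) and likelihood p(x|z) = N(z, 1), so that p(x) = N(0, 2) and the true posterior is p(z|x) = N(x/2, 1/2), while q_phi(z|x) is an arbitrary Gaussian.</p><pre><code class="language-python">import numpy as np

rng = np.random.default_rng(0)
x = 1.5
mu, sigma = 0.3, 0.8  # an arbitrary (suboptimal) Gaussian q_phi(z|x)

def log_normal(v, m, s):
    # log density of N(m, s^2) evaluated at v
    return -0.5 * np.log(2 * np.pi * s**2) - (v - m)**2 / (2 * s**2)

# Monte Carlo estimate of the ELBO: E_q[log p(x|z) + log p(z) - log q_phi(z|x)]
z = rng.normal(mu, sigma, size=1_000_000)
elbo = np.mean(log_normal(x, z, 1.0) + log_normal(z, 0.0, 1.0)
               - log_normal(z, mu, sigma))

# Closed forms for this toy model: p(x) = N(0, 2), p(z|x) = N(x/2, 1/2)
log_px = log_normal(x, 0.0, np.sqrt(2.0))
m_post, s_post = x / 2, np.sqrt(0.5)
# KL between two Gaussians: KL(q, p(z|x))
kl = np.log(s_post / sigma) + (sigma**2 + (mu - m_post)**2) / (2 * s_post**2) - 0.5

print(log_px, elbo + kl)  # the two sides agree up to Monte Carlo error
</code></pre>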
<p>So, <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><mi>p</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">log \; p(x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span></span> breaks down into two terms.</p>
<ul>
<li>
<p>The KL term here is intractable because it involves the true posterior <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">p(z|x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span></span>, which we can't evaluate. At least, we know that a KL divergence is always greater than or equal to zero.</p>
</li>
<li>
<p>The expectation is tractable! It involves three functions that we can evaluate. Note that <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>E</mi><mi>q</mi></msub></mrow><annotation encoding="application/x-tex">E_q</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.969438em; vertical-align: -0.286108em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.05764em;">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.151392em;"><span style="top: -2.55em; margin-left: -0.05764em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.03588em;">q</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span></span></span></span></span></span></span></span></span></span></span> means <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>x</mi></mrow><annotation encoding="application/x-tex">x</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal">x</span></span></span></span></span> is sampled from the data, then <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>z</mi></mrow><annotation encoding="application/x-tex">z</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 
0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span></span></span></span></span> is sampled from <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msub><mi>q</mi><mi>ϕ</mi></msub><mo stretchy="false">(</mo><mi>z</mi><mi mathvariant="normal">∣</mi><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">q_\phi(z|x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1.03611em; vertical-align: -0.286108em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">q</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.336108em;"><span style="top: -2.55em; margin-left: -0.03588em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight">ϕ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.286108em;"><span></span></span></span></span></span></span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.04398em;">z</span><span class="mord">∣</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span></span>. 
This expectation is a lower bound on <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><mi>p</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">log \; p(x)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span></span></span></span></span>. We call it the evidence lower bound, or ELBO.</p>
</li>
</ul>
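<p>Concretely, the expectation can be estimated by drawing z through the reparameterization trick and averaging the three log-densities. In the sketch below, the encoder and decoder are stand-in functions for illustration only; in a real VAE they would be neural networks with parameters phi and theta, and the decoder variance is assumed to be 1.</p><pre><code class="language-python">import numpy as np

rng = np.random.default_rng(0)

def log_normal(v, m, s):
    # log density of N(m, s^2) evaluated at v
    return -0.5 * np.log(2 * np.pi * s**2) - (v - m)**2 / (2 * s**2)

def encoder(x):
    # stand-in for q_phi(z|x) = N(mu, sigma^2); normally a network
    return 0.5 * x, 0.7

def decoder(z):
    # stand-in for the mean of p_theta(x|z); normally a network
    return 2.0 * z

def elbo(x, n_samples=100_000):
    mu, sigma = encoder(x)
    eps = rng.normal(size=n_samples)
    z = mu + sigma * eps  # reparameterization: z ~ q_phi(z|x), differentiable in (mu, sigma)
    return np.mean(log_normal(x, decoder(z), 1.0)   # log p_theta(x|z)
                   + log_normal(z, 0.0, 1.0)        # log p(z)
                   - log_normal(z, mu, sigma))      # log q_phi(z|x)

print(elbo(1.0))  # a lower bound on log p(x) under this toy model
</code></pre><p>Maximizing this estimate with respect to the encoder and decoder parameters is exactly the VAE training objective.</p>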
<p>The ELBO can be further broken down into two components.</p>
<blockquote>
<p>\( E_q[\log p_\theta(x|z)p(z) - \log q_\phi(z|x)] \)<br>
\( = E_q[\log p_\theta(x|z) + \log p(z) - \log q_\phi(z|x)] \)<br>
\( = E_q[\log p_\theta(x|z)] + E_q[\log p(z) - \log q_\phi(z|x)] \)<br>
\( = E_q[\log p_\theta(x|z)] - KL(q_\phi(z|x) \,\|\, p(z)) \)</p>
</blockquote>
<p>The ELBO is equal to:</p>
<ul>
<li>A reconstruction objective. For all \( x \) in the dataset, encoding it via \( q_\phi \) then decoding via \( p_\theta \) should give high probability to the original \( x \).</li>
<li>A prior-matching objective. For all \( x \) in the dataset, the distribution \( q_\phi(z|x) \) should be similar to the prior \( p(z) \).</li>
</ul>
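<p>To make the two terms concrete, here is a minimal NumPy sketch of a one-sample Monte Carlo estimate of the negative ELBO. The names are hypothetical: it assumes a diagonal-Gaussian encoder that outputs <code>mu</code> and <code>log_std</code>, and a unit-variance Gaussian decoder passed in as <code>decode</code>.</p>

```python
import numpy as np

rng = np.random.default_rng(0)

def negative_elbo(x, mu, log_std, decode):
    """One-sample estimate of -ELBO = -E_q[log p(x|z) + log p(z) - log q(z|x)]."""
    std = np.exp(log_std)
    # Reparameterized sample z ~ q_phi(z|x) = N(mu, std^2)
    eps = rng.standard_normal(mu.shape)
    z = mu + std * eps
    # log q_phi(z|x) and log p(z), summed over independent latent dimensions
    log_q = np.sum(-0.5 * eps**2 - log_std - 0.5 * np.log(2 * np.pi), axis=-1)
    log_p = np.sum(-0.5 * z**2 - 0.5 * np.log(2 * np.pi), axis=-1)
    # log p_theta(x|z): a unit-variance Gaussian decoder reduces to squared error
    x_recon = decode(z)
    log_px = np.sum(-0.5 * (x - x_recon)**2 - 0.5 * np.log(2 * np.pi), axis=-1)
    return -(log_px + log_p - log_q)
```

<p>The reparameterized sample \( z = \mu + \sigma \epsilon \) keeps the estimate differentiable with respect to the encoder outputs, which is what lets us train the whole thing by gradient descent.</p>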
<p>Here's a practical way to look at the ELBO objective. We can't maximize \( p(x) \) directly because we don't have access to \( p(z|x) \). But if we approximate \( p(z|x) \) with \( q_\phi(z|x) \), we can get a lower bound on \( \log p(x) \) and maximize that instead. The lower bound depends on 1) how well the reconstruction \( x' = p(q(x)) \) recreates the data, and 2) how well \( q(x) \) matches the true prior \( p(z) \).</p><!--kg-card-end: html--><h4 id="analytical-kl-divergence-for-gaussians">Analytical KL divergence for Gaussians</h4><p>In the classic VAE setup, the prior p(z) is a standard Gaussian. The prior-matching objective above can then be computed analytically, without the need to sample from q(z|x).</p><!--kg-card-begin: html-->
<p>Since the VAE setup defines an independent Gaussian for each latent dimension, we only need to consider the univariate Gaussian case (instead of a multivariate one). The encoder network outputs a mean \( \mu \) and a standard deviation \( \sigma \). We'll refer to this distribution as \( P = N(\mu, \sigma^2) \). The standard Gaussian is \( Q = N(0,1) \).</p>
<p>The KL divergence measures how much one probability distribution diverges from another (note that it is not symmetric in its two arguments):</p>
<blockquote>
<p><span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>K</mi><mi>L</mi><mo stretchy="false">(</mo><mi>P</mi><mo separator="true">,</mo><mi>Q</mi><mo stretchy="false">)</mo><mo>=</mo><msub><mi>E</mi><mi>P</mi></msub><mo stretchy="false">[</mo><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><mi>P</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mo>−</mo><mi>l</mi><mi>o</mi><mi>g</mi><mtext>  </mtext><mi>Q</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mo stretchy="false">]</mo></mrow><annotation encoding="application/x-tex">KL(P,Q) = E_{P}[log \;P(x) - log \;Q(x)]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.07153em;">K</span><span class="mord mathnormal">L</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.13889em;">P</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal">Q</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.05764em;">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.328331em;"><span style="top: -2.55em; margin-left: -0.05764em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right: 
0.13889em;">P</span></span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span><span class="mopen">[</span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal" style="margin-right: 0.13889em;">P</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mord mathnormal">Q</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mclose">]</span></span></span></span></span></p>
</blockquote>
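<p>As a sanity check on this definition, we can estimate the expectation by sampling. The sketch below (assuming numpy and scipy are available) draws samples from P and averages the log-density difference:</p><pre><code class="language-python">import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)
mu, sigma = 0.5, 0.8  # parameters of P; Q is the standard Gaussian

# Monte Carlo estimate of KL(P, Q) = E_P[log P(x) - log Q(x)]
x = rng.normal(mu, sigma, size=1_000_000)  # samples drawn from P
kl_mc = np.mean(norm.logpdf(x, mu, sigma) - norm.logpdf(x, 0.0, 1.0))</code></pre><p>With a million samples this lands very close to the analytic value we will derive below.</p>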
<p>We'll also need the probability density for a Gaussian:</p>
<blockquote>
<p><span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>p</mi><mo stretchy="false">(</mo><mi>x</mi><mi mathvariant="normal">∣</mi><mi>μ</mi><mo separator="true">,</mo><mi>σ</mi><mo stretchy="false">)</mo><mo>=</mo><mstyle displaystyle="true" scriptlevel="0"><mfrac><mn>1</mn><mrow><mi>σ</mi><msqrt><mrow><mn>2</mn><mi>π</mi></mrow></msqrt></mrow></mfrac></mstyle><msup><mi>e</mi><mrow><mo>−</mo><mstyle displaystyle="true" scriptlevel="0"><mfrac><mn>1</mn><mn>2</mn></mfrac></mstyle><mo stretchy="false">(</mo><mstyle displaystyle="true" scriptlevel="0"><mfrac><mrow><mi>x</mi><mo>−</mo><mi>μ</mi></mrow><mi>σ</mi></mfrac></mstyle><msup><mo stretchy="false">)</mo><mn>2</mn></msup></mrow></msup></mrow><annotation encoding="application/x-tex">p(x|\mu,\sigma) = \dfrac{1}{\sigma \sqrt{2 \pi}} e^{-\dfrac{1}{2}(\dfrac{x-\mu}{\sigma})^2}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal">p</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mord">∣</span><span class="mord mathnormal">μ</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 3.16619em; vertical-align: -0.93em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 1.32144em;"><span style="top: -2.20278em;"><span class="pstrut" style="height: 3em;"></span><span 
class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.90722em;"><span class="svg-align" style="top: -3em;"><span class="pstrut" style="height: 3em;"></span><span class="mord" style="padding-left: 0.833em;"><span class="mord">2</span><span class="mord mathnormal" style="margin-right: 0.03588em;">π</span></span></span><span style="top: -2.86722em;"><span class="pstrut" style="height: 3em;"></span><span class="hide-tail" style="min-width: 0.853em; height: 1.08em;"><svg width="400em" height="1.08em" viewbox="0 0 400000 1080" preserveaspectratio="xMinYMin slice"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"/></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.13278em;"><span></span></span></span></span></span></span></span><span style="top: -3.23em;"><span class="pstrut" style="height: 3em;"></span><span class="frac-line" style="border-bottom-width: 0.04em;"></span></span><span style="top: -3.677em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.93em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord"><span class="mord mathnormal">e</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 2.23619em;"><span style="top: -4.23619em; margin-right: 0.05em;"><span class="pstrut" style="height: 3.37644em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">−</span><span class="mord sizing reset-size3 size6 mtight"><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 1.37644em;"><span style="top: -2.248em;"><span class="pstrut" style="height: 3em;"></span><span class="mord mtight"><span class="mord mtight">2</span></span></span><span style="top: -3.2255em;"><span class="pstrut" style="height: 3em;"></span><span class="frac-line mtight" style="border-bottom-width: 0.049em;"></span></span><span style="top: -3.732em;"><span class="pstrut" style="height: 3em;"></span><span class="mord mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.752em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 
size6"></span></span><span class="mopen mtight">(</span><span class="mord sizing reset-size3 size6 mtight"><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 1.31533em;"><span style="top: -2.248em;"><span class="pstrut" style="height: 3em;"></span><span class="mord mtight"><span class="mord mathnormal mtight" style="margin-right: 0.03588em;">σ</span></span></span><span style="top: -3.2255em;"><span class="pstrut" style="height: 3em;"></span><span class="frac-line mtight" style="border-bottom-width: 0.049em;"></span></span><span style="top: -3.732em;"><span class="pstrut" style="height: 3em;"></span><span class="mord mtight"><span class="mord mathnormal mtight">x</span><span class="mbin mtight">−</span><span class="mord mathnormal mtight">μ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.752em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 size6"></span></span><span class="mclose mtight"><span class="mclose mtight">)</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.891314em;"><span style="top: -2.931em; margin-right: 0.0714286em;"><span class="pstrut" style="height: 2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></p>
</blockquote>
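<p>This density is easy to transcribe directly. A quick check (a numpy/scipy sketch; <code>gaussian_pdf</code> is a helper name of my own) that the formula above matches a library implementation:</p><pre><code class="language-python">import numpy as np
from scipy.stats import norm

def gaussian_pdf(x, mu, sigma):
    # p(x | mu, sigma), written out term-by-term from the formula above
    return 1.0 / (sigma * np.sqrt(2 * np.pi)) * np.exp(-0.5 * ((x - mu) / sigma) ** 2)

xs = np.linspace(-3.0, 3.0, 13)
assert np.allclose(gaussian_pdf(xs, 0.5, 0.8), norm.pdf(xs, 0.5, 0.8))</code></pre>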
<p>For Q, <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi><mo>=</mo><mn>0</mn></mrow><annotation encoding="application/x-tex">\mu=0</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.625em; vertical-align: -0.19444em;"></span><span class="mord mathnormal">μ</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.64444em; vertical-align: 0em;"></span><span class="mord">0</span></span></span></span></span> and <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>σ</mi><mo>=</mo><mn>1</mn></mrow><annotation encoding="application/x-tex">\sigma=1</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.64444em; vertical-align: 0em;"></span><span class="mord">1</span></span></span></span></span> so the term simplifies.</p>
<blockquote>
<p><span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>q</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mo>=</mo><mstyle displaystyle="true" scriptlevel="0"><mfrac><mn>1</mn><msqrt><mrow><mn>2</mn><mi>π</mi></mrow></msqrt></mfrac></mstyle><msup><mi>e</mi><mrow><mo>−</mo><mstyle displaystyle="true" scriptlevel="0"><mfrac><mn>1</mn><mn>2</mn></mfrac></mstyle><msup><mi>x</mi><mn>2</mn></msup></mrow></msup></mrow><annotation encoding="application/x-tex">q(x) = \dfrac{1}{\sqrt{2 \pi}} e^{-\dfrac{1}{2}x^2}</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.03588em;">q</span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 3.16619em; vertical-align: -0.93em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 1.32144em;"><span style="top: -2.20278em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.90722em;"><span class="svg-align" style="top: -3em;"><span class="pstrut" style="height: 3em;"></span><span class="mord" style="padding-left: 0.833em;"><span class="mord">2</span><span class="mord mathnormal" style="margin-right: 0.03588em;">π</span></span></span><span style="top: -2.86722em;"><span class="pstrut" style="height: 3em;"></span><span class="hide-tail" style="min-width: 
0.853em; height: 1.08em;"><svg width="400em" height="1.08em" viewbox="0 0 400000 1080" preserveaspectratio="xMinYMin slice"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"/></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.13278em;"><span></span></span></span></span></span></span></span><span style="top: -3.23em;"><span class="pstrut" style="height: 3em;"></span><span class="frac-line" style="border-bottom-width: 0.04em;"></span></span><span style="top: -3.677em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.93em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mord"><span class="mord mathnormal">e</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 2.23619em;"><span style="top: -4.23619em; margin-right: 0.05em;"><span class="pstrut" style="height: 3.37644em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight"><span class="mord mtight">−</span><span class="mord sizing reset-size3 size6 mtight"><span class="mopen nulldelimiter sizing reset-size3 size6"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 1.37644em;"><span style="top: -2.248em;"><span class="pstrut" style="height: 3em;"></span><span class="mord mtight"><span class="mord mtight">2</span></span></span><span style="top: -3.2255em;"><span class="pstrut" style="height: 3em;"></span><span class="frac-line mtight" style="border-bottom-width: 0.049em;"></span></span><span style="top: -3.732em;"><span class="pstrut" style="height: 3em;"></span><span class="mord mtight"><span class="mord mtight">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.752em;"><span></span></span></span></span></span><span class="mclose nulldelimiter sizing reset-size3 
size6"></span></span><span class="mord mtight"><span class="mord mathnormal mtight">x</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.891314em;"><span style="top: -2.931em; margin-right: 0.0714286em;"><span class="pstrut" style="height: 2.5em;"></span><span class="sizing reset-size3 size1 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></span></p>
</blockquote>
<p>Given the above, let's plug these two densities into the KL divergence.</p>
<blockquote>
<p><span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>K</mi><mi>L</mi><mo stretchy="false">(</mo><mi>P</mi><mo separator="true">,</mo><mi>Q</mi><mo stretchy="false">)</mo><mo>=</mo><msub><mi>E</mi><mi>P</mi></msub><mo stretchy="false">[</mo><mi>l</mi><mi>o</mi><mi>g</mi><mo stretchy="false">(</mo><mstyle displaystyle="true" scriptlevel="0"><mfrac><mn>1</mn><mrow><mi>σ</mi><msqrt><mrow><mn>2</mn><mi>π</mi></mrow></msqrt></mrow></mfrac></mstyle><mo stretchy="false">)</mo><mo>−</mo><mstyle displaystyle="true" scriptlevel="0"><mfrac><mn>1</mn><mn>2</mn></mfrac></mstyle><mo stretchy="false">(</mo><mstyle displaystyle="true" scriptlevel="0"><mfrac><mrow><mi>x</mi><mo>−</mo><mi>μ</mi></mrow><mi>σ</mi></mfrac></mstyle><msup><mo stretchy="false">)</mo><mn>2</mn></msup><mo>−</mo><mi>l</mi><mi>o</mi><mi>g</mi><mo stretchy="false">(</mo><mstyle displaystyle="true" scriptlevel="0"><mfrac><mn>1</mn><msqrt><mrow><mn>2</mn><mi>π</mi></mrow></msqrt></mfrac></mstyle><mo stretchy="false">)</mo><mo>+</mo><mstyle displaystyle="true" scriptlevel="0"><mfrac><mn>1</mn><mn>2</mn></mfrac></mstyle><mo stretchy="false">(</mo><mi>x</mi><msup><mo stretchy="false">)</mo><mn>2</mn></msup><mo stretchy="false">]</mo></mrow><annotation encoding="application/x-tex">KL(P,Q) = E_P[log(\dfrac{1}{\sigma \sqrt{2 \pi}}) -\dfrac{1}{2}(\dfrac{x-\mu}{\sigma})^2 -log(\dfrac{1}{\sqrt{2 \pi}}) +\dfrac{1}{2}(x)^2]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.07153em;">K</span><span class="mord mathnormal">L</span><span class="mopen">(</span><span class="mord mathnormal" style="margin-right: 0.13889em;">P</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord 
mathnormal">Q</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 2.25144em; vertical-align: -0.93em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.05764em;">E</span><span class="msupsub"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.328331em;"><span style="top: -2.55em; margin-left: -0.05764em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mathnormal mtight" style="margin-right: 0.13889em;">P</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.15em;"><span></span></span></span></span></span></span><span class="mopen">[</span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mopen">(</span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 1.32144em;"><span style="top: -2.20278em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.90722em;"><span class="svg-align" style="top: -3em;"><span class="pstrut" style="height: 3em;"></span><span class="mord" style="padding-left: 0.833em;"><span class="mord">2</span><span class="mord mathnormal" style="margin-right: 0.03588em;">π</span></span></span><span style="top: -2.86722em;"><span class="pstrut" style="height: 3em;"></span><span 
class="hide-tail" style="min-width: 0.853em; height: 1.08em;"><svg width="400em" height="1.08em" viewbox="0 0 400000 1080" preserveaspectratio="xMinYMin slice"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"/></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.13278em;"><span></span></span></span></span></span></span></span><span style="top: -3.23em;"><span class="pstrut" style="height: 3em;"></span><span class="frac-line" style="border-bottom-width: 0.04em;"></span></span><span style="top: -3.677em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.93em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 2.00744em; vertical-align: -0.686em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 1.32144em;"><span style="top: -2.314em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord">2</span></span></span><span style="top: -3.23em;"><span class="pstrut" style="height: 3em;"></span><span class="frac-line" style="border-bottom-width: 0.04em;"></span></span><span style="top: -3.677em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.686em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mopen">(</span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 
1.26033em;"><span style="top: -2.314em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span></span></span><span style="top: -3.23em;"><span class="pstrut" style="height: 3em;"></span><span class="frac-line" style="border-bottom-width: 0.04em;"></span></span><span style="top: -3.677em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord mathnormal">x</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mord mathnormal">μ</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.686em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose"><span class="mclose">)</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 2.25144em; vertical-align: -0.93em;"></span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mopen">(</span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 1.32144em;"><span style="top: -2.20278em;"><span class="pstrut" style="height: 
3em;"></span><span class="mord"><span class="mord sqrt"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 0.90722em;"><span class="svg-align" style="top: -3em;"><span class="pstrut" style="height: 3em;"></span><span class="mord" style="padding-left: 0.833em;"><span class="mord">2</span><span class="mord mathnormal" style="margin-right: 0.03588em;">π</span></span></span><span style="top: -2.86722em;"><span class="pstrut" style="height: 3em;"></span><span class="hide-tail" style="min-width: 0.853em; height: 1.08em;"><svg width="400em" height="1.08em" viewbox="0 0 400000 1080" preserveaspectratio="xMinYMin slice"><path d="M95,702
c-2.7,0,-7.17,-2.7,-13.5,-8c-5.8,-5.3,-9.5,-10,-9.5,-14
c0,-2,0.3,-3.3,1,-4c1.3,-2.7,23.83,-20.7,67.5,-54
c44.2,-33.3,65.8,-50.3,66.5,-51c1.3,-1.3,3,-2,5,-2c4.7,0,8.7,3.3,12,10
s173,378,173,378c0.7,0,35.3,-71,104,-213c68.7,-142,137.5,-285,206.5,-429
c69,-144,104.5,-217.7,106.5,-221
l0 -0
c5.3,-9.3,12,-14,20,-14
H400000v40H845.2724
s-225.272,467,-225.272,467s-235,486,-235,486c-2.7,4.7,-9,7,-19,7
c-6,0,-10,-1,-12,-3s-194,-422,-194,-422s-65,47,-65,47z
M834 80h400000v40h-400000z"/></svg></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.13278em;"><span></span></span></span></span></span></span></span><span style="top: -3.23em;"><span class="pstrut" style="height: 3em;"></span><span class="frac-line" style="border-bottom-width: 0.04em;"></span></span><span style="top: -3.677em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.93em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 2.00744em; vertical-align: -0.686em;"></span><span class="mord"><span class="mopen nulldelimiter"></span><span class="mfrac"><span class="vlist-t vlist-t2"><span class="vlist-r"><span class="vlist" style="height: 1.32144em;"><span style="top: -2.314em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord">2</span></span></span><span style="top: -3.23em;"><span class="pstrut" style="height: 3em;"></span><span class="frac-line" style="border-bottom-width: 0.04em;"></span></span><span style="top: -3.677em;"><span class="pstrut" style="height: 3em;"></span><span class="mord"><span class="mord">1</span></span></span></span><span class="vlist-s">​</span></span><span class="vlist-r"><span class="vlist" style="height: 0.686em;"><span></span></span></span></span></span><span class="mclose nulldelimiter"></span></span><span class="mopen">(</span><span class="mord mathnormal">x</span><span class="mclose"><span class="mclose">)</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span 
class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mclose">]</span></span></span></span></span></p>
</blockquote>
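<p>Two standard Gaussian moments will do the heavy lifting here: under P, the standardized square ((x−μ)/σ)² has expectation 1, and x² has expectation μ² + σ². A quick numerical check (a numpy sketch):</p><pre><code class="language-python">import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 0.5, 0.8
x = rng.normal(mu, sigma, size=1_000_000)  # samples from P

# E_P[((x - mu) / sigma)^2], analytically equal to 1
standardized_sq = np.mean(((x - mu) / sigma) ** 2)

# E_P[x^2], analytically equal to mu^2 + sigma^2
second_moment = np.mean(x ** 2)</code></pre>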
<p>By linearity of expectation, we can break this into three terms and deal with each one-by-one.</p>
<blockquote>
<p><span class="math math-inline">\(KL(P,Q) = E_P[\log(\dfrac{1}{\sigma \sqrt{2 \pi}}) - \log(\dfrac{1}{\sqrt{2 \pi}})] + E_P[-\dfrac{1}{2}(\dfrac{x-\mu}{\sigma})^2] + E_P[\dfrac{1}{2}x^2]\)</span></p>
</blockquote>
<p>The first expectation involves two logs that can be combined into one. There is no <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>x</mi></mrow><annotation encoding="application/x-tex">x</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal">x</span></span></span></span></span> term, so the expectation can be dropped.</p>
<blockquote>
<p><span class="math math-inline">\(= E_P[-\log(\dfrac{\sigma \sqrt{2 \pi}}{\sqrt{2 \pi}})]\)</span><br>
<span class="math math-inline">\(= E_P[-\log(\sigma)]\)</span><br>
<span class="math math-inline">\(= -\log(\sigma) = -(1/2)\log(\sigma^2)\)</span></p>
</blockquote>
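<p>A minimal numeric sketch of this step (the value of <code>sigma</code> is arbitrary; nothing inside the logs depends on <code>x</code>, so the expectation drops and the logs combine):</p>

```python
import numpy as np

# Arbitrary positive sigma for P = N(mu, sigma^2); mu does not appear in this term.
sigma = 1.3

# log(1/(sigma*sqrt(2*pi))) - log(1/sqrt(2*pi)) combines to -log(sigma).
first_term = np.log(1 / (sigma * np.sqrt(2 * np.pi))) - np.log(1 / np.sqrt(2 * np.pi))

print(first_term, -np.log(sigma), -0.5 * np.log(sigma**2))  # all three agree
```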
<p>The second expectation can be simplified by noting that <span class="math math-inline">\(E_P[(x-\mu)^2]\)</span> is the definition of the variance <span class="math math-inline">\(\sigma^2\)</span>.</p>
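2</span>">
<p>A short Monte Carlo sketch of this identity and its consequences (the values of <code>mu</code> and <code>sigma</code> are arbitrary, and the assembled closed form uses the standard fact that <span class="math math-inline">\(E_P[x^2] = \sigma^2 + \mu^2\)</span>):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 0.7, 1.3                        # arbitrary parameters for P = N(mu, sigma^2)
x = rng.normal(mu, sigma, size=1_000_000)   # samples from P

# E_P[(x - mu)^2] is the variance sigma^2, so the second
# expectation collapses to -1/2 regardless of mu and sigma.
second_term = np.mean(-0.5 * ((x - mu) / sigma) ** 2)

# Assembling all three terms gives the familiar closed form
# KL(N(mu, sigma^2) || N(0, 1)) = -log(sigma) - 1/2 + (sigma^2 + mu^2) / 2.
kl_closed = -np.log(sigma) - 0.5 + 0.5 * (sigma**2 + mu**2)

# Direct Monte Carlo estimate of E_P[log p(x) - log q(x)] for comparison.
log_p = -np.log(sigma * np.sqrt(2 * np.pi)) - 0.5 * ((x - mu) / sigma) ** 2
log_q = -np.log(np.sqrt(2 * np.pi)) - 0.5 * x**2
kl_mc = np.mean(log_p - log_q)

print(second_term)       # close to -0.5
print(kl_closed, kl_mc)  # the two KL estimates agree closely
```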
<blockquote>
<p><span class="math math-inline">\(E_P[- \dfrac{1}{2}(\dfrac{(x-\mu)^2}{\sigma^2})]\)</span><br>
<span class="math math-inline">\(= -(1/2) E_P[(x-\mu)^2] (1/\sigma^2)\)</span><br>
<span class="math math-inline">\(= -(1/2) \sigma^2 (1/\sigma^2)\)</span><br>
<span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo>=</mo><mo>−</mo><mo stretchy="false">(</mo><mn>1</mn><mi mathvariant="normal">/</mi><mn>2</mn><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">= -(1/2)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.36687em; vertical-align: 0em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord">−</span><span class="mopen">(</span><span class="mord">1</span><span class="mord">/</span><span class="mord">2</span><span class="mclose">)</span></span></span></span></span></p>
</blockquote>
<p>The third expectation can also be explained in terms of variance. An equivalent equation for variance is <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msup><mi>σ</mi><mn>2</mn></msup><mo>=</mo><mi>E</mi><mo stretchy="false">[</mo><msup><mi>X</mi><mn>2</mn></msup><mo stretchy="false">]</mo><mo>−</mo><mi>E</mi><mo stretchy="false">[</mo><mi>X</mi><msup><mo stretchy="false">]</mo><mn>2</mn></msup></mrow><annotation encoding="application/x-tex">\sigma^2 = E[X^2] - E[X]^2</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.814108em; vertical-align: 0em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.06411em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.05764em;">E</span><span class="mopen">[</span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.07847em;">X</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span 
class="mclose">]</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1.06411em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.05764em;">E</span><span class="mopen">[</span><span class="mord mathnormal" style="margin-right: 0.07847em;">X</span><span class="mclose"><span class="mclose">]</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span></span></span></span></span>. In our case, <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>E</mi><mo stretchy="false">[</mo><mi>X</mi><mo stretchy="false">]</mo><mo>=</mo><mi>μ</mi></mrow><annotation encoding="application/x-tex">E[X] = \mu</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.05764em;">E</span><span class="mopen">[</span><span class="mord mathnormal" style="margin-right: 0.07847em;">X</span><span class="mclose">]</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.625em; vertical-align: -0.19444em;"></span><span class="mord mathnormal">μ</span></span></span></span></span>, and <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math 
xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>E</mi><mo stretchy="false">[</mo><msup><mi>X</mi><mn>2</mn></msup><mo stretchy="false">]</mo></mrow><annotation encoding="application/x-tex">E[X^2]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1.06411em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.05764em;">E</span><span class="mopen">[</span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.07847em;">X</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mclose">]</span></span></span></span></span> is what we want to find. So,</p>
<blockquote>
<p><span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><msup><mi>σ</mi><mn>2</mn></msup><mo>=</mo><mi>E</mi><mo stretchy="false">[</mo><msup><mi>X</mi><mn>2</mn></msup><mo stretchy="false">]</mo><mo>−</mo><msup><mi>μ</mi><mn>2</mn></msup></mrow><annotation encoding="application/x-tex">\sigma^2 = E[X^2] - \mu^2</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.814108em; vertical-align: 0em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.06411em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.05764em;">E</span><span class="mopen">[</span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.07847em;">X</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mclose">]</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span 
class="base"><span class="strut" style="height: 1.00855em; vertical-align: -0.19444em;"></span><span class="mord"><span class="mord mathnormal">μ</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span></span></span></span></span><br>
<span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>E</mi><mo stretchy="false">[</mo><msup><mi>X</mi><mn>2</mn></msup><mo stretchy="false">]</mo><mo>=</mo><msup><mi>σ</mi><mn>2</mn></msup><mo>+</mo><msup><mi>μ</mi><mn>2</mn></msup></mrow><annotation encoding="application/x-tex">E[X^2] = \sigma^2 + \mu^2</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1.06411em; vertical-align: -0.25em;"></span><span class="mord mathnormal" style="margin-right: 0.05764em;">E</span><span class="mopen">[</span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.07847em;">X</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mclose">]</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.897438em; vertical-align: -0.08333em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 
0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1.00855em; vertical-align: -0.19444em;"></span><span class="mord"><span class="mord mathnormal">μ</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span></span></span></span></span><br>
<span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo stretchy="false">(</mo><mn>1</mn><mi mathvariant="normal">/</mi><mn>2</mn><mo stretchy="false">)</mo><mi>E</mi><mo stretchy="false">[</mo><msup><mi>X</mi><mn>2</mn></msup><mo stretchy="false">]</mo><mo>=</mo><mo stretchy="false">(</mo><mn>1</mn><mi mathvariant="normal">/</mi><mn>2</mn><mo stretchy="false">)</mo><mo stretchy="false">(</mo><msup><mi>σ</mi><mn>2</mn></msup><mo>+</mo><msup><mi>μ</mi><mn>2</mn></msup><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">(1/2)E[X^2] = (1/2)(\sigma^2 + \mu^2)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1.06411em; vertical-align: -0.25em;"></span><span class="mopen">(</span><span class="mord">1</span><span class="mord">/</span><span class="mord">2</span><span class="mclose">)</span><span class="mord mathnormal" style="margin-right: 0.05764em;">E</span><span class="mopen">[</span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.07847em;">X</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mclose">]</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.06411em; vertical-align: -0.25em;"></span><span class="mopen">(</span><span class="mord">1</span><span class="mord">/</span><span class="mord">2</span><span class="mclose">)</span><span class="mopen">(</span><span 
class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1.06411em; vertical-align: -0.25em;"></span><span class="mord"><span class="mord mathnormal">μ</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mclose">)</span></span></span></span></span></p>
</blockquote>
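<p>As a quick numerical check of the identity above, we can estimate <code>E[X^2]</code> by sampling from a Gaussian and comparing against the closed form (a small NumPy sketch; the particular mean, standard deviation, and seed are arbitrary):</p><pre><code class="language-python">import numpy as np

# Monte-Carlo check that E[X^2] = sigma^2 + mu^2 for X ~ N(mu, sigma^2).
rng = np.random.default_rng(0)
mu, sigma = 1.5, 0.7
x = rng.normal(mu, sigma, size=1_000_000)

# Sample estimate of E[X^2] vs. the closed-form sigma^2 + mu^2 = 2.74.
print(np.mean(x**2), sigma**2 + mu**2)</code></pre>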
<p>Nice, all of our expectations are now expressed as functions of <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi></mrow><annotation encoding="application/x-tex">\mu</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.625em; vertical-align: -0.19444em;"></span><span class="mord mathnormal">μ</span></span></span></span></span> and <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>σ</mi></mrow><annotation encoding="application/x-tex">\sigma</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span></span></span></span></span>. Put together, we get the final equation for the KL divergence loss:</p>
<blockquote>
<p><span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>K</mi><mi>L</mi><mi mathvariant="normal">_</mi><mi>L</mi><mi>o</mi><mi>s</mi><mi>s</mi><mo stretchy="false">(</mo><mi>μ</mi><mo separator="true">,</mo><mi>σ</mi><mo stretchy="false">)</mo><mo>=</mo><mo stretchy="false">(</mo><mn>1</mn><mi mathvariant="normal">/</mi><mn>2</mn><mo stretchy="false">)</mo><mo stretchy="false">[</mo><mo>−</mo><mi>l</mi><mi>o</mi><mi>g</mi><mo stretchy="false">(</mo><msup><mi>σ</mi><mn>2</mn></msup><mo stretchy="false">)</mo><mo>−</mo><mn>1</mn><mo>+</mo><msup><mi>σ</mi><mn>2</mn></msup><mo>+</mo><msup><mi>μ</mi><mn>2</mn></msup><mo stretchy="false">]</mo></mrow><annotation encoding="application/x-tex">KL\_Loss(\mu,\sigma) = (1/2) [-log(\sigma^2) - 1 +\sigma^2 + \mu^2]</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1.06em; vertical-align: -0.31em;"></span><span class="mord mathnormal" style="margin-right: 0.07153em;">K</span><span class="mord mathnormal">L</span><span class="mord" style="margin-right: 0.02778em;">_</span><span class="mord mathnormal">L</span><span class="mord mathnormal">o</span><span class="mord mathnormal">s</span><span class="mord mathnormal">s</span><span class="mopen">(</span><span class="mord mathnormal">μ</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 1.06411em; vertical-align: -0.25em;"></span><span class="mopen">(</span><span class="mord">1</span><span class="mord">/</span><span class="mord">2</span><span 
class="mclose">)</span><span class="mopen">[</span><span class="mord">−</span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 0.72777em; vertical-align: -0.08333em;"></span><span class="mord">1</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 0.897438em; vertical-align: -0.08333em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1.06411em; vertical-align: -0.25em;"></span><span 
class="mord"><span class="mord mathnormal">μ</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mclose">]</span></span></span></span></span></p>
</blockquote>
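<p>The closed-form loss is straightforward to implement. A minimal NumPy sketch (the function name <code>kl_loss</code> is my own; the same expression works elementwise on arrays of <code>mu</code> and <code>sigma</code>, summed over latent dimensions for the total loss):</p><pre><code class="language-python">import numpy as np

def kl_loss(mu, sigma):
    # KL( N(mu, sigma^2) || N(0, 1) ), applied elementwise:
    # (1/2) [ -log(sigma^2) - 1 + sigma^2 + mu^2 ]
    return 0.5 * (-np.log(sigma**2) - 1.0 + sigma**2 + mu**2)

# The loss vanishes exactly at the standard Gaussian:
print(kl_loss(0.0, 1.0))  # 0.0</code></pre>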
<p>Note that our encoder would give us a batch-sized vector of <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi></mrow><annotation encoding="application/x-tex">\mu</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.625em; vertical-align: -0.19444em;"></span><span class="mord mathnormal">μ</span></span></span></span></span> and <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>σ</mi></mrow><annotation encoding="application/x-tex">\sigma</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span></span></span></span></span>. Because we assume each <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo stretchy="false">(</mo><mi>μ</mi><mo separator="true">,</mo><mi>σ</mi><mo stretchy="false">)</mo></mrow><annotation encoding="application/x-tex">(\mu, \sigma)</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mopen">(</span><span class="mord mathnormal">μ</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span><span class="mclose">)</span></span></span></span></span> pair parametrizes an independent Gaussian, the total loss is the sum of the above equation applied elementwise.</p>
<p>Sanity check: the KL Loss should be minimized when <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi><mo>=</mo><mn>0</mn><mo separator="true">,</mo><mi>σ</mi><mo>=</mo><mn>1</mn></mrow><annotation encoding="application/x-tex">\mu=0, \sigma=1</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.625em; vertical-align: -0.19444em;"></span><span class="mord mathnormal">μ</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.83888em; vertical-align: -0.19444em;"></span><span class="mord">0</span><span class="mpunct">,</span><span class="mspace" style="margin-right: 0.166667em;"></span><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.64444em; vertical-align: 0em;"></span><span class="mord">1</span></span></span></span></span>.</p>
<blockquote>
<p><span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo stretchy="false">(</mo><mi>d</mi><mi mathvariant="normal">/</mi><mi>d</mi><mi>μ</mi><mo stretchy="false">)</mo><mtext>  </mtext><mo stretchy="false">[</mo><mo>−</mo><mi>l</mi><mi>o</mi><mi>g</mi><mo stretchy="false">(</mo><msup><mi>σ</mi><mn>2</mn></msup><mo stretchy="false">)</mo><mo>−</mo><mn>1</mn><mo>+</mo><msup><mi>σ</mi><mn>2</mn></msup><mo>+</mo><msup><mi>μ</mi><mn>2</mn></msup><mo stretchy="false">]</mo><mo>=</mo><mn>2</mn><mi>μ</mi><mo>→</mo></mrow><annotation encoding="application/x-tex">(d/d \mu) \; [-log(\sigma^2) - 1 +\sigma^2 + \mu^2] = 2 \mu \rightarrow</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1.06411em; vertical-align: -0.25em;"></span><span class="mopen">(</span><span class="mord mathnormal">d</span><span class="mord">/</span><span class="mord mathnormal">d</span><span class="mord mathnormal">μ</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mopen">[</span><span class="mord">−</span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span 
class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 0.72777em; vertical-align: -0.08333em;"></span><span class="mord">1</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 0.897438em; vertical-align: -0.08333em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1.06411em; vertical-align: -0.25em;"></span><span class="mord"><span class="mord mathnormal">μ</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mclose">]</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.83888em; vertical-align: -0.19444em;"></span><span class="mord">2</span><span class="mord mathnormal">μ</span><span class="mspace" style="margin-right: 0.277778em;"></span><span 
class="mrel">→</span></span></span></span></span> min. at <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>μ</mi><mo>=</mo><mn>0</mn></mrow><annotation encoding="application/x-tex">\mu=0</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.625em; vertical-align: -0.19444em;"></span><span class="mord mathnormal">μ</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.64444em; vertical-align: 0em;"></span><span class="mord">0</span></span></span></span></span>.<br>
<span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mo stretchy="false">(</mo><mi>d</mi><mi mathvariant="normal">/</mi><mi>d</mi><mi>σ</mi><mo stretchy="false">)</mo><mtext>  </mtext><mo stretchy="false">[</mo><mo>−</mo><mi>l</mi><mi>o</mi><mi>g</mi><mo stretchy="false">(</mo><msup><mi>σ</mi><mn>2</mn></msup><mo stretchy="false">)</mo><mo>−</mo><mn>1</mn><mo>+</mo><msup><mi>σ</mi><mn>2</mn></msup><mo>+</mo><msup><mi>μ</mi><mn>2</mn></msup><mo stretchy="false">]</mo><mo>=</mo><mn>2</mn><mi>σ</mi><mo>−</mo><mn>2</mn><mi mathvariant="normal">/</mi><mi>σ</mi><mo>→</mo></mrow><annotation encoding="application/x-tex">(d/d \sigma) \; [-log(\sigma^2) - 1 +\sigma^2 + \mu^2] = 2 \sigma - 2/\sigma \rightarrow</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 1.06411em; vertical-align: -0.25em;"></span><span class="mopen">(</span><span class="mord mathnormal">d</span><span class="mord">/</span><span class="mord mathnormal">d</span><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span><span class="mclose">)</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mopen">[</span><span class="mord">−</span><span class="mord mathnormal" style="margin-right: 0.01968em;">l</span><span class="mord mathnormal">o</span><span class="mord mathnormal" style="margin-right: 0.03588em;">g</span><span class="mopen">(</span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span 
class="mclose">)</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 0.72777em; vertical-align: -0.08333em;"></span><span class="mord">1</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 0.897438em; vertical-align: -0.08333em;"></span><span class="mord"><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">+</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1.06411em; vertical-align: -0.25em;"></span><span class="mord"><span class="mord mathnormal">μ</span><span class="msupsub"><span class="vlist-t"><span class="vlist-r"><span class="vlist" style="height: 0.814108em;"><span style="top: -3.063em; margin-right: 0.05em;"><span class="pstrut" style="height: 2.7em;"></span><span class="sizing reset-size6 size3 mtight"><span class="mord mtight">2</span></span></span></span></span></span></span></span><span class="mclose">]</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.72777em; vertical-align: -0.08333em;"></span><span class="mord">2</span><span class="mord 
mathnormal" style="margin-right: 0.03588em;">σ</span><span class="mspace" style="margin-right: 0.222222em;"></span><span class="mbin">−</span><span class="mspace" style="margin-right: 0.222222em;"></span></span><span class="base"><span class="strut" style="height: 1em; vertical-align: -0.25em;"></span><span class="mord">2</span><span class="mord">/</span><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">→</span></span></span></span></span> min. at <span class="math math-inline"><span class="katex"><span class="katex-mathml"><math xmlns="http://www.w3.org/1998/Math/MathML"><semantics><mrow><mi>σ</mi><mo>=</mo><mn>1</mn></mrow><annotation encoding="application/x-tex">\sigma=1</annotation></semantics></math></span><span class="katex-html" aria-hidden="true"><span class="base"><span class="strut" style="height: 0.43056em; vertical-align: 0em;"></span><span class="mord mathnormal" style="margin-right: 0.03588em;">σ</span><span class="mspace" style="margin-right: 0.277778em;"></span><span class="mrel">=</span><span class="mspace" style="margin-right: 0.277778em;"></span></span><span class="base"><span class="strut" style="height: 0.64444em; vertical-align: 0em;"></span><span class="mord">1</span></span></span></span></span>.</p>
</blockquote>
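<p>As a quick numerical check of the minimum above (a small NumPy sketch, not part of the original derivation):</p><pre><code class="language-python">import numpy as np

def kl_term(sigma, mu=0.0):
    # Per-element KL between N(mu, sigma^2) and N(0, 1), up to a factor of 1/2:
    # -log(sigma^2) - 1 + sigma^2 + mu^2
    return -np.log(sigma**2) - 1.0 + sigma**2 + mu**2

sigmas = np.linspace(0.1, 3.0, 2901)  # grid with step 0.001, passes through 1.0
best_sigma = sigmas[np.argmin(kl_term(sigmas))]
# The derivative 2*sigma - 2/sigma vanishes at sigma = 1,
# so the numerical minimum should sit at sigma close to 1.
</code></pre>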
<p>All good!</p><!--kg-card-end: html--><h4 id="an-aside-on-batch-wise-kl-vs-element-wise-kl">An aside on batch-wise KL vs. element-wise KL</h4><p>A common misconception when examining this KL loss is how it relates to each batch of data. One could imagine that the KL loss is meant to shape a given batch of latent vectors so the approximate distribution within <em>the batch</em> is standard Gaussian. This interpretation would mean that on average, the latents should have mean 0 and variance 1 – but each sample can vary. It is possible to nearly-perfectly match the standard Gaussian while conveying information about the input image.</p><p>The above interpretation is incorrect. Instead, the KL loss is applied <em>element-wise</em>. Each image is encoded into a mean and variance pair, and the loss encourages this explicit mean and variance to resemble 0 and 1. Thus, to nearly match the standard Gaussian, each image would encode to a standard Gaussian (and thus no unique information is conveyed). In this way, the KL loss and the reconstruction objective are in conflict, and a balance is struck depending on the relative scaling of the two objectives.</p>]]></content:encoded></item><item><title><![CDATA[For AGI, we need better tasks. For better tasks, we need open-endedness. (ALOE 2022 Notes)]]></title><description><![CDATA[<p>Open-endedness is the idea of a system that "endlessly creates increasingly interesting designs". That's a simple concept with big implications. Biological life on earth is diverse and beautiful and the result of open-ended evolution. 
Human culture is open-ended; we've created a lot of art and technology without any real plan.</p>]]></description><link>https://kvfrans.com/notes-on-aloe/</link><guid isPermaLink="false">626ef9471668172f903c5d99</guid><category><![CDATA[research]]></category><category><![CDATA[thoughts]]></category><dc:creator><![CDATA[Kevin Frans]]></dc:creator><pubDate>Sun, 15 May 2022 19:00:11 GMT</pubDate><content:encoded><![CDATA[<p>Open-endedness is the idea of a system that "endlessly creates increasingly interesting designs". That's a simple concept with big implications. Biological life on earth is diverse and beautiful and the result of open-ended evolution. Human culture is open-ended; we've created a lot of art and technology without any real plan.</p><p>For those of us interested in AI and the emergence of intelligence, open-endedness is especially important. Because even with powerful learning algorithms, intelligence is one of those tricky things that's hard and complicated to specify. But that's the beauty of open-ended systems: they can generate <em>unexpected</em> designs, designs that go beyond the intention of the designer. And so even if intelligence is hard to create, it might be possible to create a system where intelligence naturally emerges.</p><p>This weekend, a bunch of us thinking about how open-endedness can boost AI gathered at the Agent Learning in Open-Endedness (ALOE) workshop at ICLR. I had the pleasure of attending, and since this topic is really fascinating I took some notes, so here they are.</p><p>As context, I (Kevin Frans) am a master's student at MIT under Phillip Isola, and I've spent time at places like OpenAI and Cross Labs. The hypothesis I've been studying is: a strong learning algorithm, coupled with a rich enough set of tasks, will result in intelligence. 
We have strong learning algorithms, but our tasks are quite bad, so what are consistent principles for making task distributions that are rich, learnable, yet increasingly complex?</p><h3 id="more-tasks-smarter-agents-">More tasks = smarter agents.</h3><p>A common theme in AI is that we take the same basic learning algorithms, scale them up with harder tasks, and they discover more complicated behaviors. There are arguments over whether this is good or bad for research, but the fact is that it works, and accepting this can lead to cool practical results.</p><p>XLand, the new task set from DeepMind, works by accepting this lesson and pushing it to the limit. XLand aims to create a whole bunch of related tasks by 1) defining a distribution of games, and 2) making them all multi-agent. </p><p>The key lesson here is that the distribution of XLand tasks is just so varied and dense. XLand tasks take place in a 3D physics world, and agents need to do things like move blocks around, catch other agents, or play hide-and-seek. These games are defined in a modular way, so some agents get reward based on Distance(Agent, Ball) while others are rewarded for -Distance(AgentA, AgentB). Each task also takes place on a procedurally-generated terrain. Tasks are inherently multi-agent, so the same task can have a different reward landscape based on which agents are involved. </p><p>Now that we've got a rich distribution of games, just train agents with RL and they will be smart. There's some fancy population-based training involved, but it's a side point; the main idea is that agents trained on this rich set of games have no choice but to become generally capable, and they do. Strong XLand agents perform well on a bunch of the tasks, even unseen ones, and seem to generally do the right thing.</p><p>To me, the nice part about XLand is that it confirms we can achieve strong generalization just by engineering the task space to be richer, denser, and larger. 
Now AGI is a game design problem, and game design is fun. The downside is that this all seems awfully expensive to compute.</p><figure class="kg-card kg-image-card"><img src="https://kvfrans.com/content/images/2022/05/Screen-Shot-2022-05-01-at-6.35.49-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2022/05/Screen-Shot-2022-05-01-at-6.35.49-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2022/05/Screen-Shot-2022-05-01-at-6.35.49-PM.png 1000w, https://kvfrans.com/content/images/size/w1600/2022/05/Screen-Shot-2022-05-01-at-6.35.49-PM.png 1600w, https://kvfrans.com/content/images/2022/05/Screen-Shot-2022-05-01-at-6.35.49-PM.png 2214w" sizes="(min-width: 720px) 720px"></figure><h3 id="procedural-content-generation-ai-">Procedural content generation + AI.</h3><p>My favorite video games can be replayed hundreds of times, and these games usually make use of procedural content generation (PCG). Instead of manually building out levels, we can define an algorithm that builds the levels for us, and use it to generate an infinite amount of content.</p><p>Sam Earle's work is on how PCG and AI intersect. First, can AI methods be used as a PCG algorithm? You could imagine PCG being a data modelling problem – given a set of 1000 Mario levels, train a GAN or autoencoder to model the distribution. But this is open-endedness, we're not satisfied with mimicry. Sam's early work on <a href="https://ojs.aaai.org/index.php/AIIDE/article/view/7416/7341">PCGRL</a> frames content generation as a reinforcement learning task. If we can define a measurement of how good a level is, we can try and learn a generator that best fits this measurement. 
A <a href="https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&amp;arnumber=9619159">follow-up paper</a> takes this further, and lets the user specify a bunch of parameters like level-length or difficulty, which the AI generator does its best to fulfill.</p><figure class="kg-card kg-image-card"><img src="https://kvfrans.com/content/images/2022/05/Screen-Shot-2022-05-13-at-4.28.00-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2022/05/Screen-Shot-2022-05-13-at-4.28.00-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2022/05/Screen-Shot-2022-05-13-at-4.28.00-PM.png 1000w, https://kvfrans.com/content/images/2022/05/Screen-Shot-2022-05-13-at-4.28.00-PM.png 1336w" sizes="(min-width: 720px) 720px"></figure><p>To me, the simplest open-ended loop is a zero-sum competitive game. Two players compete against each other, and adapt to each other's weakness in a continuous learning loop. So if we want to train a really good Mario-playing agent, it's crucial that we think just as hard about a Mario level-generating agent. That's where AI PCG methods come in – we're only going to get so far with human-made levels, so we're going to need to build generator agents that never stop learning.</p><h3 id="the-tao-of-agents-and-environments">The Tao of Agents and Environments</h3><p>"Instead of a real presentation, here are some slides that I threw together in an hour", started Julius Togelius. "Actually I spent two days on this, but it's funnier to pretend I didn't." He then proceeded to show pictures of Agent Smith from The Matrix, and an environment diagram of the H2O cycle.</p><p>The purpose of all this was to point out that our classic framework of "agents interacting with an environment" is limiting. What if our environment has other agents in it? Why does the game-player have to be the agent? Can the environment be the player? 
If we start to question these terms, maybe there's new ground we can think about when it comes to "agent-based" systems.</p><p>From an open-endedness perspective, this idea makes sense, because we want to think of our systems as more than just optimization problems. A traditional reinforcement learning problem is to train a robotic agent to solve a control environment; it's clean, simple, and solvable. But in an open-ended system, maybe we want the robot to compete against other robots. Maybe the real goal is not to develop strong robots, but to see the emergence of cooperation or communication. Maybe we care about these robots because we want to discover new challenges for new robots. In an open-endedness view, everything is adapting, so everything is an agent; and everything affects everything else, so everything is also an environment.</p><h3 id="to-see-emergent-complexity-use-antagonistic-pairs-">To see emergent complexity, use antagonistic pairs.</h3><p>A fundamental open-ended loop is:</p><ol><li>Solve a task</li><li>Generate a new task</li><li>Repeat</li></ol><p>Solving a task is well-studied; we have whole fields on how to do this. How to <em>generate new tasks</em>, however, is ripe for exploration. We need to be careful: if the tasks are too similar, then we're not learning anything new, but if the tasks are too distinct, an agent might not be able to solve them. Ideally, we want to define some implicit curriculum of tasks that an agent can gradually work its way through.</p><p>Past works on open-ended task generation include ideas like <a href="https://arxiv.org/abs/1901.01753">POET</a>, where tasks are generated by 1) making 100 variations of a solved task, and 2) trying to solve as many of the variations as we can. This framework does work, but it's limited by the hard-coded method of making task variations. 
We can do something smarter.</p><figure class="kg-card kg-image-card"><img src="https://kvfrans.com/content/images/2022/05/Screen-Shot-2022-05-13-at-4.28.24-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2022/05/Screen-Shot-2022-05-13-at-4.28.24-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2022/05/Screen-Shot-2022-05-13-at-4.28.24-PM.png 1000w, https://kvfrans.com/content/images/2022/05/Screen-Shot-2022-05-13-at-4.28.24-PM.png 1434w" sizes="(min-width: 720px) 720px"></figure><p><a href="https://natashajaques.ai/publication/paired/">PAIRED</a>, as presented by Natasha Jaques, offers a more adaptive way of generating tasks. If we think about the kinds of tasks we want, it's tasks that the agent is just barely capable of solving. PAIRED proposes a heuristic using two agents, a protagonist and an antagonist, and a task generator. The generator tries to propose tasks that the protagonist can't solve, but the antagonist can. This way, we're always generating tasks in the sweet spot where they're solvable, but the protagonist hasn't quite figured it out yet. "Hey protagonist", says the generator, "here's a task that your brother the antagonist can solve, so I know you can solve it too, so start trying".</p><p>I think these little loops are quite powerful, and we just need to find better places to apply them. As a researcher, if I want to train a robust AI agent, I should try and parametrize a super-expressive distribution of tasks, then let loose an open-ended process to gradually expand what it can solve.</p><h3 id="panel-discussion">Panel Discussion</h3><p>Along with the presentations, here were some fun thoughts from the panel discussions:</p><ul><li>Ken: What's the difference between AI and ALIFE? Well in AI we're somewhat focusing on economic value, whereas in ALIFE we really just care about what's cool. </li><li>Ken: ALIFE focuses too much on evolution, which we don't really need. 
But in AI and machine learning, we need to work on weirder and braver ideas.</li><li>Danijar: We need more open-ended environments. Atari is dumb, robotic control is too simple.</li><li>Ken: Environmental design for OE is a deep field. It needs to be understood. What gives our universe properties that let it sustain OE for so long?</li><li>Julian: People who designed games like Skyrim and Elden Ring have figured out a bunch of heuristics that make good OE experiences.</li><li>Danijar: Minecraft kind of has crafting, but it's not OE crafting, it's hardcoded. We want to have emergent crafting, but also be computationally feasible.</li><li>Joel: The best part of ALIFE is the belief that it's totally possible to make life or life-like evolution, and it should be easy, and it's just right there. </li><li><strong>Moderator: Say you travel 10 years into the future, and there's an open-ended simulation that works. What does it look like?</strong></li><li>(everyone looks up and thinks)</li><li>Joel: We tackle the meta learning objective. We optimize for human preferences, we tell a system to keep making cool stuff that is cool to us as humans.</li><li>Natasha: Inter-group competition. We get something interesting where we have competition, and groups competing against each other, so we have cooperative stuff like language, communities, etc.</li><li>Julian: The system is super diverse. There are many things going on. I don't know what the principles are, and I won’t understand what’s going on. The metrics look great, but it's like looking at a coral reef, I don't know what's going on but there’s pretty colors. 
Like a coral reef, there will be many extremely different kinds of things.</li><li>Danijar: We get better and better at simulating some version of the real world, at a level of abstraction that is more and more accurate, maybe some MMO game to collect resources, or a mathematics environment where we discover more theorems.</li><li>Ken: There’s so many possibilities it's hard to say; at a low level there’s creatures in a world that are reminiscent of nature, there’s theorem proving, agents that are with humans in the web, OE interacting with humans.</li></ul><p>For more, check out the ALOE <a href="https://sites.google.com/view/aloe2022">proceedings</a> and <a href="https://iclr.cc/virtual/2022/workshop/4547">workshop</a>.</p>]]></content:encoded></item><item><title><![CDATA[A Mathematical Definition of Interestingness]]></title><description><![CDATA[<p>In open-endedness research, we care about building systems that continuously generate interesting designs. So what makes something interesting? Previously, <a href="https://kvfrans.com/thoughts-oct-29/">I argued</a> that interestingness depends on a viewer's perspective. But this answer isn't very satisfying.</p><p>Ideally, we would want some mathematically-definable way to measure how interesting something is. A naive approach</p>]]></description><link>https://kvfrans.com/kvfrans-thoughts-nov-15/</link><guid isPermaLink="false">6192dc0c4365cb5561d981dc</guid><category><![CDATA[thoughts]]></category><dc:creator><![CDATA[Kevin Frans]]></dc:creator><pubDate>Mon, 06 Dec 2021 22:28:59 GMT</pubDate><content:encoded><![CDATA[<p>In open-endedness research, we care about building systems that continuously generate interesting designs. So what makes something interesting? Previously, <a href="https://kvfrans.com/thoughts-oct-29/">I argued</a> that interestingness depends on a viewer's perspective. 
But this answer isn't very satisfying.</p><p>Ideally, we would want some mathematically-definable way to measure how interesting something is. A naive approach would be to define interestingness as variation or entropy, but we run into the following problem:</p><figure class="kg-card kg-image-card"><img src="https://kvfrans.com/content/images/2021/11/image-2.png" class="kg-image" alt></figure><p>In this example, you get the highest entropy by mixing milk and coffee, since at the molecular level the particles are the most widely distributed. But somehow the middle cup is the most interesting to us as humans, since the half-mixed milk forms all sorts of cool swirls and twists.</p><p>An analogy I like to draw here is the <a href="https://galaxykate0.tumblr.com/post/139774965871/so-you-want-to-build-a-generator">10,000 oatmeal problem</a>, which states:</p><blockquote>I can easily generate 10,000 bowls of plain oatmeal, with each oat being in a different position and different orientation, and <em>mathematically speaking</em> they will all be completely unique. But the user will likely just see <em>a lot of oatmeal</em>.</blockquote><p>Entropy is not all that great! Sure, we get a lot of unique designs, but at some level they all look like the same design. Something is missing in our interestingness formulation.</p><p>The best definition of interestingness I've come across so far is the idea of <em>sophistication</em>, which is mentioned in <a href="https://scottaaronson.blog/?p=762">this blog post by Scott Aaronson</a>:</p><blockquote>Sophistication(x) is the length of the shortest computer program that describes, not necessarily x itself, but a set S of which x is a “random” or “generic” member.</blockquote><p>Sophistication solves our milk-coffee problem and our oatmeal problem. In both cases, the high-entropy examples can be described by very short programs – just sample random positions for each particle. 
Something like the swirls in the half-mixed cup has higher sophistication, since you'd need a more complicated program to generate those artifacts.</p><p>How can we apply sophistication in the context of open-endedness and machine learning? Let's say we're trying to generate a lot of interesting images. One formulation is to do something like:</p><ol><li>An image gets less reward the more times it is generated.</li><li>Continuously search for high-reward images.</li></ol><p>Obviously, this formulation kind of sucks, since the space of images is really big, and we will basically end up with a lot of images of noise. Ideally, we want to eliminate all images that look like 'static noise' in one swoop, and move on to more interesting kinds of variation.</p><p>Here's a sophistication-like formulation that makes use of smarter learning:</p><ol><li>A generative adversarial network is trained to model all past images that have been generated.</li><li>An image gets less reward if the GAN's discriminator thinks it looks in-distribution.</li><li>Continuously search for high-reward images and re-train the GAN.</li></ol><p>In theory, this GAN-based formulation should explore the image space in a more efficient manner. Once a few static-noise images have been generated, the GAN can model all types of static-noise, so high-reward images need to display other kinds of variation. </p><p>The GAN-based formulation kind of resembles sophistication, where we assume that Soph(x) is just -Disc(x), i.e. an image is sophisticated if the GAN struggles to model it correctly. As long as we're continuously re-training our GAN, this approximation shouldn't be too bad.</p><p>(To be continued...)</p><!--kg-card-begin: html--><!---

meta-thoughts on sophistication
it's just the same as curiosity
what we get is a function of "what states are harder to get to"
* because modeling by neural net is hard
* because finding the action seq that gets here is hard
* because creating an design that is seen as this object is hard
here's another option – can we just optimize directly for how hard something is to do. the interesting part is how to measure 'hard'
difficulty = new, but also hard to learn to do (N number of grads to imitate me)
difficulty = novel (tabular?)
difficulty = unstable
difficulty = shortest program to 
what are we trying to achieve?
emergent behavior that is interesting, to us as humans
technical: ground what it means to be interesting
at some point, we do need to tie these things back into what humans call interesting
other definition = interesting if it can be used for downstream task
quality-diversity -> quality-[] difficulty? diversity?

agent parkour = agents in a world, and they try to mimic each other, and they get reward if they can't mimic each other, we can have a reward function to bootstrap off of. Maybe there is a cheaper way to measure 'hard to mimic' than imitation? Heuristics such as: less contact with ground, more instability, center of mass is not supported,
(other idea = mimic)
innovator: trained with RL, reward = inverse of fakability
faker: given state sequence, output action sequence that matches
goal-proposer: come up with an image
goal-achiever: make image
goal-proposer: come up with an image
goal-achiever: autoencode this image
goal-proposer: come up with a clip point
goal-achiever: make image that matches that point
network input: CLIP, output: pixels
network in: CLIP, output: VQGAN z, output 2: pixels
starting point: use translator as a jump, then optimize a little bit after
goal-proposer: uniformly sample points
goal-achiever: make image that matches that point
one definition that leads to inequality in clip space: some images are harder to generate sequentially
Interestingness is at some point philosophically a means to achieving quality.

--><!--kg-card-end: html-->]]></content:encoded></item><item><title><![CDATA[To extract information from language models, optimize for causal response]]></title><description><![CDATA[<p>One way to view AI progress is that we can measure increasingly abstract concepts.<br>A decade ago, it was really hard to take a photo and measure “how much does this image look like a dog?”. Nowadays, just pick your favorite image classification model and run the image through. Similarly,</p>]]></description><link>https://kvfrans.com/causal-language-model/</link><guid isPermaLink="false">6181770c4365cb5561d97f16</guid><category><![CDATA[research]]></category><dc:creator><![CDATA[Kevin Frans]]></dc:creator><pubDate>Fri, 12 Nov 2021 20:56:33 GMT</pubDate><content:encoded><![CDATA[<p>One way to view AI progress is that we can measure increasingly abstract concepts.<br>A decade ago, it was really hard to take a photo and measure “how much does this image look like a dog?”. Nowadays, just pick your favorite image classification model and run the image through. Similarly, it used to be infeasible to measure “how realistic is this photo?”, but with GANs, we can do it easily.</p><p>With the rise of powerful language models, e.g. GPT-3, we’ve unlocked the ability to work with even higher-level ideas. GPT contains a powerfully rich representation of our world, and it’s been shown to solve logical puzzles, summarize large texts, and utilize factual knowledge about the human world.</p><p>Can we somehow extract this knowledge out of GPT? 
Let’s say an alien wants to learn about the human world, so it goes up to its local GPT engine and asks, “What is coffee?”</p><figure class="kg-card kg-image-card"><img src="https://kvfrans.com/content/images/2021/11/Screen-Shot-2021-11-08-at-6.18.09-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2021/11/Screen-Shot-2021-11-08-at-6.18.09-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2021/11/Screen-Shot-2021-11-08-at-6.18.09-PM.png 1000w, https://kvfrans.com/content/images/2021/11/Screen-Shot-2021-11-08-at-6.18.09-PM.png 1440w" sizes="(min-width: 720px) 720px"></figure><p>“Oh, got it”, the alien says, “so coffee is a kind of orange juice.” This definition is shallow! Coffee is definitely something you drink in the morning, but so are a lot of other things. This answer is vague, and vagueness is a common trend when asking GPT to describe various concepts:</p><figure class="kg-card kg-image-card"><img src="https://kvfrans.com/content/images/2021/11/Screen-Shot-2021-11-08-at-6.25.56-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2021/11/Screen-Shot-2021-11-08-at-6.25.56-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2021/11/Screen-Shot-2021-11-08-at-6.25.56-PM.png 1000w, https://kvfrans.com/content/images/size/w1600/2021/11/Screen-Shot-2021-11-08-at-6.25.56-PM.png 1600w, https://kvfrans.com/content/images/2021/11/Screen-Shot-2021-11-08-at-6.25.56-PM.png 1860w" sizes="(min-width: 720px) 720px"></figure><p>How can we improve these results? Let’s take a step back. We know that GPT has good knowledge of what coffee is, but that doesn’t mean it’s good at <em>telling us</em> what it knows. We want a stronger way of interacting with GPT.</p><p>A nice field to borrow from is <a href="https://distill.pub/2017/feature-visualization/">feature visualization</a>, which focuses on understanding what types of inputs activate various areas of neural networks. 
Crucially, feature visualization techniques tend to examine <em>causal </em>relationships – results are obtained by optimizing through the neural network many times, and discovering an input which maximizes a certain neuron activation.</p><p>So, let’s do the same with language models. Instead of asking GPT what it thinks about certain topics, let’s just search for phrases which directly affect what GPT is holding in memory. We can do this through what I’ll call <em>causal-response optimization,</em> which tries to maximize the probability of GPT replying “Yes” to a specific question.</p><figure class="kg-card kg-image-card"><img src="https://kvfrans.com/content/images/2021/11/Screen-Shot-2021-11-12-at-12.26.00-AM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2021/11/Screen-Shot-2021-11-12-at-12.26.00-AM.png 600w, https://kvfrans.com/content/images/size/w1000/2021/11/Screen-Shot-2021-11-12-at-12.26.00-AM.png 1000w, https://kvfrans.com/content/images/2021/11/Screen-Shot-2021-11-12-at-12.26.00-AM.png 1260w" sizes="(min-width: 720px) 720px"></figure><p>Does this metric actually mean anything? It's hard to confirm this quantitatively, but we can at least look at how the causal-response score of "Am I thinking of coffee?" responds to some human-written options.</p><figure class="kg-card kg-image-card"><img src="https://kvfrans.com/content/images/2021/11/Screen-Shot-2021-11-02-at-9.01.41-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2021/11/Screen-Shot-2021-11-02-at-9.01.41-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2021/11/Screen-Shot-2021-11-02-at-9.01.41-PM.png 1000w, https://kvfrans.com/content/images/size/w1600/2021/11/Screen-Shot-2021-11-02-at-9.01.41-PM.png 1600w, https://kvfrans.com/content/images/2021/11/Screen-Shot-2021-11-02-at-9.01.41-PM.png 2396w" sizes="(min-width: 720px) 720px"></figure><p>OK, it's not bad! 
The more details the phrase has, the higher the likelihood that GPT responds with "Yes". That sounds reasonable. Now, we just need to replace the human-written options with automatically generated options.</p><p>There’s actually a whole body of work, known as <a href="https://arxiv.org/abs/2107.13586">“prompting”</a>, that deals with how to locate phrases that adjust language model behavior. But we’re going to take a simple approach: 1) generate 20 varied answers for “What is coffee?”, and 2) select the answer with the greatest causal-response score for “Am I thinking of coffee?”.</p><figure class="kg-card kg-image-card"><img src="https://kvfrans.com/content/images/2021/11/Screen-Shot-2021-11-25-at-10.55.56-AM-1.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2021/11/Screen-Shot-2021-11-25-at-10.55.56-AM-1.png 600w, https://kvfrans.com/content/images/size/w1000/2021/11/Screen-Shot-2021-11-25-at-10.55.56-AM-1.png 1000w, https://kvfrans.com/content/images/2021/11/Screen-Shot-2021-11-25-at-10.55.56-AM-1.png 1586w" sizes="(min-width: 720px) 720px"></figure><p>How well does this setup work? Here’s a comparison example on the word-definition task we looked at earlier.</p><figure class="kg-card kg-image-card"><img src="https://kvfrans.com/content/images/2021/11/Screen-Shot-2021-11-12-at-1.05.52-AM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2021/11/Screen-Shot-2021-11-12-at-1.05.52-AM.png 600w, https://kvfrans.com/content/images/size/w1000/2021/11/Screen-Shot-2021-11-12-at-1.05.52-AM.png 1000w, https://kvfrans.com/content/images/2021/11/Screen-Shot-2021-11-12-at-1.05.52-AM.png 1372w" sizes="(min-width: 720px) 720px"></figure><p>Nice, we're on the right track. Generating text through causal-response optimization yields substantially different results than pure forward-prediction. 
In general, the causal response definitions are more detailed and specific, which makes sense in this case, since the objective is to maximize GPT's recognition of the word.</p><p>What else can we do with causal-response optimization? Well, it turns out that GPT has a slight tendency to lie when doing forward-prediction. Causal-optimization can help deal with this issue. Take a look at the following list-completions:</p><figure class="kg-card kg-image-card kg-card-hascaption"><img src="https://kvfrans.com/content/images/2021/11/Screen-Shot-2021-11-12-at-1.01.46-AM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2021/11/Screen-Shot-2021-11-12-at-1.01.46-AM.png 600w, https://kvfrans.com/content/images/size/w1000/2021/11/Screen-Shot-2021-11-12-at-1.01.46-AM.png 1000w, https://kvfrans.com/content/images/size/w1600/2021/11/Screen-Shot-2021-11-12-at-1.01.46-AM.png 1600w, https://kvfrans.com/content/images/2021/11/Screen-Shot-2021-11-12-at-1.01.46-AM.png 1654w" sizes="(min-width: 720px) 720px"><figcaption>Note: Suwa is a city, not a prefecture. Note 2: These are cherry-picked examples, not rigorously tested</figcaption></figure><p>The forward-prediction answers seem right at first glance, but they're a little wrong ... you can't read a painting. If we use causal-response optimization, we get "art in some form", which is a little more satisfying.</p><p>Another fun thing we can do is use causal-response optimization to steer long-form text generation. 
GPT does a nice job of evaluating various aspects of text through yes-no answers:</p><figure class="kg-card kg-image-card kg-width-wide"><img src="https://kvfrans.com/content/images/2021/11/Screen-Shot-2021-11-02-at-10.27.38-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2021/11/Screen-Shot-2021-11-02-at-10.27.38-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2021/11/Screen-Shot-2021-11-02-at-10.27.38-PM.png 1000w, https://kvfrans.com/content/images/size/w1600/2021/11/Screen-Shot-2021-11-02-at-10.27.38-PM.png 1600w, https://kvfrans.com/content/images/size/w2400/2021/11/Screen-Shot-2021-11-02-at-10.27.38-PM.png 2400w" sizes="(min-width: 1200px) 1200px"></figure><p>Since we can now measure concepts like happiness or old-man-ness, we can also optimize for them.</p><p>Let's look at a "story about apples". If we wanted to hear it from the perspective of an old man, one option is to adjust our forward-prediction prompt like so:</p><figure class="kg-card kg-image-card kg-width-wide"><img src="https://kvfrans.com/content/images/2021/11/Screen-Shot-2021-11-10-at-10.59.44-AM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2021/11/Screen-Shot-2021-11-10-at-10.59.44-AM.png 600w, https://kvfrans.com/content/images/size/w1000/2021/11/Screen-Shot-2021-11-10-at-10.59.44-AM.png 1000w, https://kvfrans.com/content/images/size/w1600/2021/11/Screen-Shot-2021-11-10-at-10.59.44-AM.png 1600w, https://kvfrans.com/content/images/size/w2400/2021/11/Screen-Shot-2021-11-10-at-10.59.44-AM.png 2400w" sizes="(min-width: 1200px) 1200px"></figure><p>Or, we could apply causal-response optimization to each sentence generated. For each line, ten sentence options are generated, and the one that has the best score for "Is this something an old man would say?" 
is chosen.</p><figure class="kg-card kg-image-card kg-width-wide"><img src="https://kvfrans.com/content/images/2021/11/Screen-Shot-2021-11-10-at-11.01.15-AM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2021/11/Screen-Shot-2021-11-10-at-11.01.15-AM.png 600w, https://kvfrans.com/content/images/size/w1000/2021/11/Screen-Shot-2021-11-10-at-11.01.15-AM.png 1000w, https://kvfrans.com/content/images/size/w1600/2021/11/Screen-Shot-2021-11-10-at-11.01.15-AM.png 1600w, https://kvfrans.com/content/images/2021/11/Screen-Shot-2021-11-10-at-11.01.15-AM.png 2130w" sizes="(min-width: 1200px) 1200px"></figure><p>Almost every sentence in the causal-optimization story reinforces the idea of an old man speaking: there are references to growing up in the country, apples straight off the tree, and the absence of supermarkets.</p><p>So, as researchers, why should we care about causal-response optimization? </p><p>I'm not trying to argue that this method constantly provides better quality results. The examples in this post are mostly anecdotal, and are somewhat cherry-picked to showcase key differences. While I'm sure causal-response can be used to improve text-generation if applied correctly, that claim would require more rigor than I can provide at the moment.</p><p>What I do want to present are some hints that causal-response is a useful tool for extracting knowledge out of language models. Causal optimization results in substantially different content from forward-prediction – we're explicitly probing into the language model, instead of relying on the model to say what it is thinking. We should explore this direction more! Strong pre-trained models are going to be around for a while, so understanding how they view the world is an important task to be looking at.</p><p>The code for these results is available at <a href="https://gist.github.com/kvfrans/1ba1f36d490b7a26c442d36f23e0941a">this gist</a>, although you need an OpenAI GPT-3 key to run it. 
I think the code is very readable, so it's probably easy to re-implement this stuff with e.g. GPT-Neo.</p><p>Further work can involve:</p><ul><li>How can we quantitatively measure how good generated sentences are? <br>Are there good metrics for knowledge-extraction?</li><li>What questions do we have about large language-model representations, and can causal-optimization be used to answer those questions?</li><li>Are there better ways of optimizing text for causal-response than bootstrapping off forward-prediction? E.g. evolutionary search, gradient optimization, etc. (Somewhat addressed in Appendix)</li></ul><!--kg-card-begin: html--><hr class="gridhr">
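<p>For concreteness, the selection loop used throughout this post can be sketched in a few lines. This is a minimal, model-agnostic sketch: <code>generate</code> and <code>p_yes</code> are hypothetical stand-ins for a real language model, where <code>generate(prompt, n)</code> samples n forward-prediction candidates and <code>p_yes(text, question)</code> returns the model's probability of answering "Yes" to the question after reading the text.</p>

```python
def causal_response_select(prompt, question, generate, p_yes, n=20):
    """Generate candidates via forward-prediction, then keep the one
    with the highest causal-response score for the yes/no question.

    `generate` and `p_yes` are hypothetical model-backed helpers,
    not part of any particular API."""
    candidates = generate(prompt, n)
    return max(candidates, key=lambda text: p_yes(text, question))
```

<p>In practice both helpers would be backed by an actual language model, e.g. the GPT-3 calls in the gist above.</p>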
<br>
<div class="float-outside">
<div class="appendix">
    <div class="float-label">
        Appendix
    </div><!--kg-card-end: html--><p>Cite as:</p><pre><code class="language-x">@article{frans2021causalresponse,
  title   = "To extract information from language models, optimize for causal response",
  author  = "Frans, Kevin",
  journal = "kvfrans.com",
  year    = "2021",
  url     = "https://kvfrans.com/causal-language-model/"
}</code></pre><p>This work isn't rigorous enough for me to publish it in a formal setting, but I would love to see further research in the directions discussed here. Thanks to Phillip Isola and others in the <a href="http://web.mit.edu/phillipi/">Isolab</a> for the help :)</p><!--kg-card-begin: html--><br><div class="float-label" style="clear: both">
	Source Code
</div><!--kg-card-end: html--><figure class="kg-card kg-bookmark-card"><a class="kg-bookmark-container" href="https://gist.github.com/kvfrans/1ba1f36d490b7a26c442d36f23e0941a"><div class="kg-bookmark-content"><div class="kg-bookmark-title">lang.ipynb</div><div class="kg-bookmark-description">GitHub Gist: instantly share code, notes, and snippets.</div><div class="kg-bookmark-metadata"><img class="kg-bookmark-icon" src="https://github.githubassets.com/favicons/favicon.svg"><span class="kg-bookmark-author">Gist</span><span class="kg-bookmark-publisher">262588213843476</span></div></div><div class="kg-bookmark-thumbnail"><img src="https://github.githubassets.com/images/modules/gists/gist-og-image.png"></div></a></figure><!--kg-card-begin: html--><br><div class="float-label" style="clear: both">
	Things To Read
</div><!--kg-card-end: html--><p><a href="https://arxiv.org/pdf/2107.13586.pdf"><strong>"Pre-train, Prompt, and Predict: A Systematic Survey of Prompting Methods in Natural Language Processing"</strong></a><strong>, by Liu et al. </strong>A growing field in recent years deals with prompting for pre-trained language models. In prompting, instead of fine-tuning the network parameters on a task, the goal is to learn a <em>prompt phrase</em> that is appended to input text in order to influence the behavior of the language model. This paper is a good review of the area.</p><p><strong>"<strong><a href="https://arxiv.org/abs/2010.15980">AutoPrompt: Eliciting Knowledge from Language Models with Automatically Generated Prompts</a></strong>", by Shin et al. </strong>This paper is one of the works in prompting; I picked it because it cleanly states the prompting problem and presents a framework. In AutoPrompt, the aim is to learn a set of trigger words which influence a language model into predicting the correct output for a set of tasks.</p><p><a href="https://arxiv.org/abs/2102.01645">"<strong>Generating images from caption and vice versa via CLIP-Guided Generative Latent Space Search"</strong></a><strong>, by Galatolo et al. </strong>CLIP-GlaSS is a way to move from text-to-image and image-to-text via CLIP encodings. There are similarities to prompting work here, especially in the image-to-text task, where the aim is to locate text which best matches a given image. </p><p>In CLIP-GlaSS they use an evolutionary search to locate text; I tried that initially in the causal-response optimization, but the method struggled with finding grammatically-correct answers. I played with various combinations of beam search, where language model logits are weighted by the causal-response objective, but the results had a lot of variation. 
In the end, generating options via forward-prediction and selecting via causal-response led to the most satisfying results.</p><p><strong><a href="https://arxiv.org/abs/1909.01066">"Language Models as Knowledge Bases?"</a>, by Petroni et al. </strong>This work is an analysis of BERT and how it performs well on various question-answering datasets without additional fine-tuning. To be honest, I chose this paper based on its name, which implies a perspective I agree with – language models are rich knowledge bases, and what we need are tools to extract this knowledge from them.</p></div></div>]]></content:encoded></item><item><title><![CDATA[Data digesters, ML^2, Interestingness]]></title><description><![CDATA[<p>This post is a bunch of random thoughts about things I've read recently. It's organized into sections, and the sections don't really relate to each other, but here they are.</p><h2 id="neural-networks-are-data-digesters">Neural Networks are Data Digesters</h2><p>A common philosophy around my head is that "representation learning is everything". Need to label</p>]]></description><link>https://kvfrans.com/thoughts-oct-29/</link><guid isPermaLink="false">617c58914365cb5561d97d80</guid><category><![CDATA[thoughts]]></category><dc:creator><![CDATA[Kevin Frans]]></dc:creator><pubDate>Tue, 02 Nov 2021 17:16:40 GMT</pubDate><content:encoded><![CDATA[<p>This post is a bunch of random thoughts about things I've read recently. It's organized into sections, and the sections don't really relate to each other, but here they are.</p><h2 id="neural-networks-are-data-digesters">Neural Networks are Data Digesters</h2><p>A common philosophy around my head is that "representation learning is everything". Need to label some images? Train a linear classifier over a powerful representation space. Need to navigate a robot to pick up some blocks? Train a simple policy over a strong representation of the robot. 
If you can map your input data into a general enough representation, everything becomes easy.</p><p>Eric Jang's blog post <a href="https://evjang.com/2021/10/23/generalization.html">"Just Ask For Generalization"</a> takes this philosophy a step further. If what we really need is strong representations, and neural networks are really strong sponges for absorbing data, then we might as well cram as much relevant data as we can into our models. Eric argues:</p><blockquote>1. Large amounts of diverse data are more important to generalization than clever model biases.<br>2. If you believe (1), then how much your model generalizes is directly proportional to how fast you can push diverse data into a sufficiently high-capacity model.</blockquote><p>To me, it feels like we're reaching diminishing returns in designing new models – e.g., transformers are highly effective at modeling everything from text to video to decision-making. Progress lies, instead, in how we can feed these models the data they need to properly generalize in whatever dimensions we care about.</p><p>One hypothesis is that we should give our models as much data as we can, regardless of how relevant the data is to the task at hand. Take, for example, language models – the strongest models are pretrained on billions of arbitrary texts from the internet. In reinforcement learning, we're starting to see the same parallels. Traditionally, agents are trained to maximize rewards on the task at hand, but it turns out we get sizable improvements by training agents to <a href="https://arxiv.org/abs/1707.01495">reach past states</a>, <a href="https://arxiv.org/abs/2106.01345">achieve a range of rewards</a>, or <a href="https://arxiv.org/abs/1901.01753">solve many mutations of the task</a>.</p><p>If we believe this hypothesis, we should really be researching how to best extract relevant data for the task at hand.  
Should a strong RL agent be trained to achieve many goals instead of one, or to achieve a spectrum of rewards instead of optimal reward? Should its model be able to predict future states? To distinguish augmented data from fake data? Should it be trained over thousands of games, instead of just one?</p><h2 id="machine-learning-via-machine-learning">Machine Learning via Machine Learning</h2><p>MIT offered a meta-learning class last year, which I really wanted to take, except I was away on gap year so I couldn't. The highlight of the class is to build "a system that passes 6.036 Introduction to Machine Learning as a virtual student". That's so cool! Thankfully the group behind the class released a paper about it, so <a href="https://arxiv.org/pdf/2107.01238.pdf">here it is</a>.</p><blockquote>We generate a new training set of questions and answers consisting of course exercises, homework, and quiz questions from MIT’s 6.036 Introduction to Machine Learning course and train a machine learning model to answer these questions. Our system uses Transformer models within an encoder-decoder architecture with graph and tree representations ... Thus, our system automatically generates new questions across topics, answers both open-response questions and multiple-choice questions, classifies problems, and generates problem hints, pushing the envelope of AI for STEM education.</blockquote><p>Summary-wise, they've got a Transformer model which reads in machine learning questions, and outputs a graph of mathematical operators which solve the questions. It seems to work quite well, and passes the class with an A (!). 
</p><figure class="kg-card kg-image-card"><img src="https://kvfrans.com/content/images/2021/10/Screen-Shot-2021-10-29-at-5.36.24-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2021/10/Screen-Shot-2021-10-29-at-5.36.24-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2021/10/Screen-Shot-2021-10-29-at-5.36.24-PM.png 1000w, https://kvfrans.com/content/images/2021/10/Screen-Shot-2021-10-29-at-5.36.24-PM.png 1488w" sizes="(min-width: 720px) 720px"></figure><p>A lot of the benefit seems to come from the data augmentation:</p><blockquote>After collecting questions from 6.036 homework assignments, exercises, and quizzes, each question is augmented by paraphrasing the question text. This results in five variants per question. Each variant, in turn, is then augmented by automatically plugging in 100 different sets of values for variables. Overall, we curate a dataset consisting of 14,000 augmented questions.</blockquote><p>One thing that's a bit disappointing is that coding questions are skipped. In my opinion, well-formulated coding questions are the best part of classes, since they force you to replicate the content and really understand it. Maybe a tandem team with <a href="https://openai.com/blog/openai-codex/">OpenAI Codex</a> can pass the course in full.</p><h2 id="what-makes-something-interesting">What makes something interesting?</h2><p>In open-endedness research, we want to build systems that can keep producing unbounded amounts of interesting things. There are many problems with this definition. What qualifies as unbounded? What makes something interesting? What makes something a thing?</p><p>One answer is that all these questions depend on the perspective that we look at the world in. <a href="https://www.youtube.com/watch?v=2gExH4Tzkzc">Alyssa M Adams' talk at Cross Roads</a> this month gives an example:</p><blockquote>The sea was never blue. Homer described the sea as a deep wine color. 
The Greek color experience was made of movement and shimmer, while modern humans care about hue, saturation and brightness.</blockquote><p>It's a representation learning problem! Interestingness isn't a fact by itself; an object is only interesting to the agent that is looking at it. Normally, this answer isn't that satisfying, since modeling what is interesting <em>to humans</em> is expensive at best and inaccurate at worst.</p><p>With the rise of <a href="https://arxiv.org/abs/2108.07258">foundation models</a>, however, we can make a non-trivial argument that AI models like <a href="https://openai.com/blog/gpt-3-apps/">GPT-3 </a>and <a href="https://openai.com/blog/clip/">CLIP</a> are learning strong human-centric representations of reality. Could we just, like, <em>ask</em> these models what they think is interesting? </p><p>Maybe GPT-3 thinks "walking" and "running" are too similar, but "flying" is qualitatively different. Maybe CLIP thinks a truck is a truck regardless of what roads it's on, but if the truck is in the ocean, now that is interesting.</p><h2 id="interestingness-in-architecture">Interestingness in Architecture</h2><p>"Nearly everything being built is boring, joyless, and/or ugly", writes Nathan Robinson. In <a href="https://www.currentaffairs.org/2021/04/when-is-the-revolution-in-architecture-coming">an April article</a>, Robinson showcases colorful classical temples, mosques, and coastal villages. 
Then he shows photographs of modern brutalist design – cold and gray.</p><figure class="kg-card kg-gallery-card kg-width-wide"><div class="kg-gallery-container"><div class="kg-gallery-row"><div class="kg-gallery-image"><img src="https://kvfrans.com/content/images/2021/11/Venice-1024x682.png" width="1024" height="682" alt srcset="https://kvfrans.com/content/images/size/w600/2021/11/Venice-1024x682.png 600w, https://kvfrans.com/content/images/size/w1000/2021/11/Venice-1024x682.png 1000w, https://kvfrans.com/content/images/2021/11/Venice-1024x682.png 1024w" sizes="(min-width: 720px) 720px"></div><div class="kg-gallery-image"><img src="https://kvfrans.com/content/images/2021/11/enough-1024x899.jpg" width="1024" height="899" alt srcset="https://kvfrans.com/content/images/size/w600/2021/11/enough-1024x899.jpg 600w, https://kvfrans.com/content/images/size/w1000/2021/11/enough-1024x899.jpg 1000w, https://kvfrans.com/content/images/2021/11/enough-1024x899.jpg 1024w" sizes="(min-width: 720px) 720px"></div></div></div></figure><blockquote>What is beauty? Beauty is that which gives aesthetic pleasure ... a thing can be beautiful to some people and not to others. Pritzker Prize-winning buildings are <em>beautiful to architects</em>... They are just not beautiful to me.</blockquote><p>Hey, look, it's the problem of interestingness again. What makes a building lively, welcoming, or as Christopher Alexander puts it, contains <a href="https://en.wikipedia.org/wiki/The_Timeless_Way_of_Building">"the quality without a name"</a>? There's probably no objective answer, but there is a human-centered answer. Could an AI understand architectural aesthetics? If it has the right representations, I don't see why not.</p>]]></content:encoded></item><item><title><![CDATA[CLIPDraw: Exploring Text-to-Drawing Synthesis]]></title><description><![CDATA[<!--kg-card-begin: html--><figure class="kg-card kg-gallery-card">
    <div class="kg-gallery-container">
        <div class="kg-gallery-row">
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/clipdraw_cropped.mp4">
                </video>
                <figcaption>
CLIPDraw synthesizes novel drawings from text. <a href="https://colab.research.google.com/github/kvfrans/clipdraw/blob/main/clipdraw.ipynb">Play with the codebase yourself.</a>
                </figcaption>
            </div>
        </div>
    </div>
</figure><!--kg-card-end: html--><p>AI-assisted art has always been intriguing to me. To me, visual art is a very human thing -- I can’t imagine a way that a computer rediscovers our own cultural concepts without some kind of experience living in</p>]]></description><link>https://kvfrans.com/clipdraw-exploring-text-to-drawing-synthesis/</link><guid isPermaLink="false">60d36b8e17e26d47c77e91e1</guid><category><![CDATA[research]]></category><dc:creator><![CDATA[Kevin Frans]]></dc:creator><pubDate>Mon, 28 Jun 2021 20:07:30 GMT</pubDate><content:encoded><![CDATA[<!--kg-card-begin: html--><figure class="kg-card kg-gallery-card">
    <div class="kg-gallery-container">
        <div class="kg-gallery-row">
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/clipdraw_cropped.mp4">
                </video>
                <figcaption>
CLIPDraw synthesizes novel drawings from text. <a href="https://colab.research.google.com/github/kvfrans/clipdraw/blob/main/clipdraw.ipynb">Play with the codebase yourself.</a>
                </figcaption>
            </div>
        </div>
    </div>
</figure><!--kg-card-end: html--><p>AI-assisted art has always been intriguing to me. To me, visual art is a very human thing -- I can’t imagine a way that a computer rediscovers our own cultural concepts without some kind of experience living in our world. So when AI methods are able to learn from us, and produce artwork that we as humans might make ourselves, something cool is definitely happening.</p><p>This post presents <em>CLIPDraw</em>, a text-to-drawing synthesis method that I’ve been playing with nonstop for the past few weeks. At its core, CLIPDraw is quite simple, yet it is able to produce drawings that display a whole range of interesting behaviors and connections. So for this article, I want to focus on showing off behaviors of CLIPDraw that I personally found intriguing and wanted to learn more about. For more detailed analysis and technical detail, check out the <a href="https://arxiv.org/abs/2106.14843">CLIPDraw paper</a>, or play around with <a href="https://colab.research.google.com/github/kvfrans/clipdraw/blob/main/clipdraw.ipynb">the Colab notebook</a> yourself!</p><figure class="kg-card kg-image-card kg-width-wide"><img src="https://kvfrans.com/content/images/2021/06/Screen-Shot-2021-06-10-at-8.47.23-PM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2021/06/Screen-Shot-2021-06-10-at-8.47.23-PM.png 600w, https://kvfrans.com/content/images/size/w1000/2021/06/Screen-Shot-2021-06-10-at-8.47.23-PM.png 1000w, https://kvfrans.com/content/images/2021/06/Screen-Shot-2021-06-10-at-8.47.23-PM.png 1080w"></figure><h2 id="image-synthesis-drawing-synthesis">Image Synthesis -&gt; Drawing Synthesis</h2><p>The field of text-to-image synthesis has a broad history, and recent work has shown stunningly realistic image generation through GAN-like methods. Realism, however, is a double-edged sword – there's a lot of overhead in generating photorealistic renderings, when often all we want are simple drawings. 
With CLIPDraw, I took inspiration from the web game <a href="https://skribbl.io/">Skribbl.io</a>, where players only have a few seconds to draw out a word for other players to guess. What if an AI could play Skribbl.io? What would it draw? Are simple shapes enough to represent increasingly complex concepts?</p><figure class="kg-card kg-image-card kg-width-wide"><img src="https://kvfrans.com/content/images/2021/06/Screen-Shot-2021-06-18-at-11.05.09-AM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2021/06/Screen-Shot-2021-06-18-at-11.05.09-AM.png 600w, https://kvfrans.com/content/images/size/w1000/2021/06/Screen-Shot-2021-06-18-at-11.05.09-AM.png 1000w, https://kvfrans.com/content/images/2021/06/Screen-Shot-2021-06-18-at-11.05.09-AM.png 1214w" sizes="(min-width: 1200px) 1200px"></figure><figure class="kg-card kg-image-card kg-width-wide"><img src="https://kvfrans.com/content/images/2021/06/Screen-Shot-2021-06-18-at-11.12.52-AM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2021/06/Screen-Shot-2021-06-18-at-11.12.52-AM.png 600w, https://kvfrans.com/content/images/size/w1000/2021/06/Screen-Shot-2021-06-18-at-11.12.52-AM.png 1000w, https://kvfrans.com/content/images/2021/06/Screen-Shot-2021-06-18-at-11.12.52-AM.png 1212w" sizes="(min-width: 1200px) 1200px"></figure><h2 id="how-does-clipdraw-work">How does CLIPDraw work?</h2><p>CLIPDraw's knowledge is powered through a pre-trained CLIP model, a recent release which I definitely recommend <a href="https://openai.com/blog/clip/">reading about.</a> In short, a CLIP model consists of an image encoder and a text encoder, which both map onto the same representational space. This setup allows us to measure the <em>similarities</em> between images and text. 
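</p><p>As a concrete illustration of that similarity measurement (a minimal sketch, assuming the two encoders output vectors in a shared space; the vectors below are hypothetical stand-ins, not real CLIP embeddings), the score is just the cosine similarity of the two embeddings:</p>

```python
import math

def cosine_similarity(u, v):
    """Similarity between an image embedding and a text embedding:
    the dot product of the two vectors, normalized by their lengths."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)
```

<p>Embeddings pointing in the same direction score 1.0, and orthogonal ones score 0.0.</p><p>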
And if we can measure similarities, we can also try to discover images that maximize that similarity, thereby matching a given textual prompt.</p><p>The basic CLIPDraw loop follows this principle of synthesis-through-optimization. First, start with a human-given description prompt and a random set of Bezier curves. Then, gradually adjust those curves through gradient descent so that the drawing best matches the given prompt. There are a few tricks that also help, <a href="https://arxiv.org/abs/2106.14843">as detailed in the paper,</a> but this loop is basically all there is to it.</p><p>Now, let's examine what CLIPDraw comes up with in practice.</p><!--kg-card-begin: html--><hr class="gridhr"><br><!--kg-card-end: html--><figure class="kg-card kg-image-card kg-width-wide"><img src="https://kvfrans.com/content/images/2021/06/Screen-Shot-2021-06-23-at-10.35.48-AM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2021/06/Screen-Shot-2021-06-23-at-10.35.48-AM.png 600w, https://kvfrans.com/content/images/size/w1000/2021/06/Screen-Shot-2021-06-23-at-10.35.48-AM.png 1000w, https://kvfrans.com/content/images/size/w1600/2021/06/Screen-Shot-2021-06-23-at-10.35.48-AM.png 1600w, https://kvfrans.com/content/images/2021/06/Screen-Shot-2021-06-23-at-10.35.48-AM.png 1867w" sizes="(min-width: 1200px) 1200px"></figure><h2 id="what-visual-techniques-does-clipdraw-use">What visual techniques does CLIPDraw use?</h2><p>A recurring theme is that CLIPDraw tends to interpret the description prompts in multiple, unexpected ways. A great example for this is "<em>A painting of a starry night sky</em>", which shows a painterly-styled sky with a moon and stars, along with an actual painting canvas and painter in the foreground, which then also features black and blue swirls resembling Van Gogh's <em>"The Starry Night"</em>.</p><p>CLIPDraw also likes to use abstract symbols, the most prominent being when it writes out the literal word inside the image itself. 
Sometimes, CLIPDraw will use tangentially related concepts, like the Google Maps screenshot when asked for "自転車 (<em>Bicycle</em> in Japanese)”. A fun result is the prompt for "<em>Fast Food", </em>which shows hamburgers and a McDonald's logo, but also has a bunch of joggers racing in the background.</p><!--kg-card-begin: html--><hr class="gridhr"><br><!--kg-card-end: html--><figure class="kg-card kg-image-card kg-width-wide"><img src="https://kvfrans.com/content/images/2021/06/Screen-Shot-2021-06-23-at-10.31.19-AM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2021/06/Screen-Shot-2021-06-23-at-10.31.19-AM.png 600w, https://kvfrans.com/content/images/size/w1000/2021/06/Screen-Shot-2021-06-23-at-10.31.19-AM.png 1000w, https://kvfrans.com/content/images/size/w1600/2021/06/Screen-Shot-2021-06-23-at-10.31.19-AM.png 1600w, https://kvfrans.com/content/images/2021/06/Screen-Shot-2021-06-23-at-10.31.19-AM.png 1855w" sizes="(min-width: 1200px) 1200px"></figure><h2 id="how-does-clipdraw-react-to-different-styles">How does CLIPDraw react to different styles?</h2><p>An experiment I really enjoyed was to synthesize images of cats, but in different artistic styles. With CLIPDraw, this was as easy as changing the descriptor adjectives of the textual prompt. Surprisingly, CLIPDraw is quite robust at handling different styles, contrary to the initial intent to synthesize scribble-like drawings.</p><p>One interesting feature is that CLIPDraw adjusts not only the textures of drawings, à la <a href="https://arxiv.org/abs/1508.06576">Style Transfer</a> methods, but also the structure of the underlying content. 
In the cat experiments, asking for <em>"a drawing"</em> produces a simplified cartoonish cat, while prompts like <em>"a 3D wireframe"</em> produce a cat in perspective, with depth and shadows.</p><!--kg-card-begin: html--><hr class="gridhr"><br><!--kg-card-end: html--><figure class="kg-card kg-image-card kg-width-wide kg-card-hascaption"><img src="https://kvfrans.com/content/images/2021/06/Screen-Shot-2021-06-23-at-10.41.16-AM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2021/06/Screen-Shot-2021-06-23-at-10.41.16-AM.png 600w, https://kvfrans.com/content/images/size/w1000/2021/06/Screen-Shot-2021-06-23-at-10.41.16-AM.png 1000w, https://kvfrans.com/content/images/size/w1600/2021/06/Screen-Shot-2021-06-23-at-10.41.16-AM.png 1600w, https://kvfrans.com/content/images/2021/06/Screen-Shot-2021-06-23-at-10.41.16-AM.png 1767w" sizes="(min-width: 1200px) 1200px"><figcaption>"The Eiffel Tower"</figcaption></figure><h2 id="does-stroke-count-affect-the-drawings-that-clipdraw-produces">Does stroke count affect the drawings that CLIPDraw produces?</h2><p>A key aspect of CLIPDraw is that drawings are represented as a set of Bezier curves rather than a matrix of pixels. This feature gives us a nice parameter to tweak: the number of curves a drawing is composed of.</p><p>Empirically, drawings with low stroke counts tend to result in a more cartoonish or abstract representation of the prompt, such as the 16-stroke version of <em>"The Eiffel Tower"</em> being basically made of a few straight lines. 
As stroke count is increased, CLIPDraw begins to target more structured shapes, shown as our Eiffel Tower begins gaining 3D perspective, lights, and finally a background.</p><!--kg-card-begin: html--><hr class="gridhr"><br><!--kg-card-end: html--><figure class="kg-card kg-image-card kg-width-wide"><img src="https://kvfrans.com/content/images/2021/06/Screen-Shot-2021-06-23-at-10.36.16-AM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2021/06/Screen-Shot-2021-06-23-at-10.36.16-AM.png 600w, https://kvfrans.com/content/images/size/w1000/2021/06/Screen-Shot-2021-06-23-at-10.36.16-AM.png 1000w, https://kvfrans.com/content/images/size/w1600/2021/06/Screen-Shot-2021-06-23-at-10.36.16-AM.png 1600w, https://kvfrans.com/content/images/2021/06/Screen-Shot-2021-06-23-at-10.36.16-AM.png 1739w" sizes="(min-width: 1200px) 1200px"></figure><h2 id="what-happens-if-abstract-words-are-given-as-a-prompt">What happens if abstract words are given as a prompt?</h2><p>A fun thing to push the limits is to give CLIPDraw abstract descriptions, and see what it does with them. As a human artist, even I would have to stop and think about how to convey these concepts through visuals, so it's interesting to see how the AI will approach things.</p><p>In most cases, CLIPDraw likes to use symbols to showcase concepts that are culturally related to the given phrase, like the fireworks and smiles in <em>"Happiness"</em> or the Japanese and English-like letters in <em>"Translation"</em>. </p><p>My favorite here is the drawing for <em>"Self"</em>, which features a body holding up multiple heads. The drawing can almost be seen as a metaphor for e.g. the idea that a person's self may contain multiple outward personalities, or that a self is actually a sum of many cognitive processes. 
This piece is definitely the most "art-like" example I came across; there's a lot of room for individual interpretation, and it almost feels like CLIP knows something that I don't.</p><!--kg-card-begin: html--><hr class="gridhr"><br><!--kg-card-end: html--><figure class="kg-card kg-image-card kg-width-wide"><img src="https://kvfrans.com/content/images/2021/06/Screen-Shot-2021-06-23-at-10.40.08-AM.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2021/06/Screen-Shot-2021-06-23-at-10.40.08-AM.png 600w, https://kvfrans.com/content/images/size/w1000/2021/06/Screen-Shot-2021-06-23-at-10.40.08-AM.png 1000w, https://kvfrans.com/content/images/size/w1600/2021/06/Screen-Shot-2021-06-23-at-10.40.08-AM.png 1600w, https://kvfrans.com/content/images/2021/06/Screen-Shot-2021-06-23-at-10.40.08-AM.png 1891w" sizes="(min-width: 1200px) 1200px"></figure><h2 id="can-drawings-be-fine-tuned-via-negative-prompts">Can drawings be fine-tuned via negative prompts?</h2><p>A final experiment was to see if CLIPDraw behavior could be finely adjusted by introducing additional optimization objectives. In the normal process, drawings are optimized to best match a textual prompt. What if we also tried to optimize for dissimilarity with a set of <em>negative prompts</em>?</p><p>In a few situations, this works! By penalizing the prompt for "Words and text", the <em>"Hashtag</em>" drawing features less prominent words and instead draws a set of selfie-like faces. Negative prompts can also do things like adjust the color of images, or force drawings to contain only a single subject rather than many.</p><p>Practically, the examples above were pretty hard to achieve, and most of the time negative prompts don't change much about the final drawing at all. I think there's still a lot of room for experimentation on how to improve this technique. 
One thing I'd love to see is a panacea prompt, like "A messy drawing", that consistently improves drawing quality if used as a negative objective regardless of context.</p><!--kg-card-begin: html--><hr class="gridhr"><br><!--kg-card-end: html--><h2 id="clipdraw-parting-thoughts">CLIPDraw: Parting Thoughts</h2><p>The CLIPDraw algorithm isn't particularly novel; people have been doing synthesis-through-optimization for a while through activation-maximization methods, and recently through CLIP-matching objectives. However, I do believe biasing towards drawings rather than photorealism gives images more freedom of expression, and optimizing Bezier curves is a nice way to do this efficiently. I also personally love this art style and I think the drawings are quite similar to what an artist would produce ;). </p><p>That being said, I do think the behaviors showcased here should be pretty generalizable to any CLIP-based optimization method. Already a few extensions come to mind – can we synthesize videos? 3D models? Can an AI play <a href="https://skribbl.io/">Skribbl</a> or <a href="https://www.brokenpicturephone.com/">Broken Picture Phone</a> with itself? I'm sure you as a reader have a bunch of ideas I haven't even considered. So please, feel free to take this method and go wherever you feel is exciting. And then <a href="https://twitter.com/kvfrans">tell me about it!</a></p><p>You can experiment with <a href="https://colab.research.google.com/github/kvfrans/clipdraw/blob/main/clipdraw.ipynb">the Colab notebook here</a>, and play with CLIPDraw in the browser. Results are generally pretty quick (within a minute) unless you crank up stroke count and iterations. You can also check out <a href="https://arxiv.org/abs/2106.14843">the full paper</a>, which contains a bit of a deeper analysis and details on the technical implementation.</p><!--kg-card-begin: html--><figure class="kg-card kg-gallery-card">
    <div class="kg-gallery-container">
        <div class="kg-gallery-row">
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                                Thanks for reading!<br>
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" style="width:10em" autoplay muted playsinline loop src="https://kvfrans.com/static/outputleo.mp4">
                </video>
                <figcaption>
                    
"Leonardo Dicaprio cheers meme template"
                </figcaption>
            </div>
        </div>
    </div>
</figure><!--kg-card-end: html--><!--kg-card-begin: html--><hr class="gridhr">
<br>
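<p>As a concrete sketch of the negative-prompt objective discussed earlier: the drawing's CLIP embedding is pulled toward the positive prompt and pushed away from each negative prompt. The function names and the weight below are hypothetical, and the real pipeline first rasterizes the Bezier curves with diffvg and augments the image before encoding:</p>

```python
import numpy as np

def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

def clipdraw_loss(image_feat, pos_feat, neg_feats, neg_weight=0.3):
    """Loss to minimize: reward similarity to the positive prompt,
    penalize similarity to each negative prompt (e.g. "Words and text")."""
    loss = -cosine(image_feat, pos_feat)
    for neg in neg_feats:
        loss += neg_weight * cosine(image_feat, neg)
    return loss
```

<p>An image embedding that matches the positive prompt scores lower (better) than one that matches a negative prompt, and <code>neg_weight</code> controls how strongly the penalty fights the main objective.</p>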
<div class="float-outside">
<div class="appendix">
    <div class="float-label">
        Appendix
    </div><!--kg-card-end: html--><p>Thank you to everyone at <a href="http://crosslabs.org/">Cross Labs</a> for assistance and feedback.</p><!--kg-card-begin: html--><div class="float-label" style="clear: both">
	Things To Read
</div><!--kg-card-end: html--><p><strong><a href="https://kvfrans.com/p/2db72cf5-f055-4ab2-990c-23ac65f4a138/arxiv">"CLIPDraw: Exploring Text-to-Drawing Synthesis via Language-Image Encoders"</a>, by Frans et al. </strong>Of course a nice place to start is the paper for CLIPDraw! The paper mostly shares this post's focus on analyzing interesting behaviors, but also includes some comparisons and a Related Work section of good papers to read up on.</p><p><strong><a href="https://openai.com/blog/clip/">"CLIP: Connecting Text and Images"</a>, by Radford et al. </strong>This is the blog for the original CLIP release, from OpenAI. CLIP is a language-image dual encoder that maps text and images onto the same feature space. Trained on a huge amount of online data, CLIP ends up being super robust and can solve a bunch of image tasks without any additional training. <a href="https://arxiv.org/abs/2103.00020">(Paper)</a></p><p><strong><a href="https://people.csail.mit.edu/tzumao/diffvg/">"Differentiable Vector Graphics Rasterization for Editing and Learning"</a>, by Li et al.</strong> This work is the differentiable vector graphics method that we use in CLIPDraw to backprop through Bezier curves. It's really impressive and basically works out of the box, so I highly recommend giving it a read.</p><p><strong><a href="https://arxiv.org/abs/2105.00162">"Generative Art Using Neural Visual Grammars and Dual Encoders"</a>, by Fernando et al. </strong>This paper is another CLIP-based generative art method, also using stroke-based images, but with an emphasis on AI creativity. Instead of backprop, they use evolutionary methods to evolve a custom LSTM-based grammar which defines the strokes. Also their figures are cool.</p><p><strong><a href="https://twitter.com/images_ai">"Images Generated By AI Machines"</a>, by <a href="https://twitter.com/samburtonking">@samburtonking</a>. 
</strong>This Twitter account shows off cool generative art made with CLIP + GAN optimization methods; a lot of interesting pieces have been made, and it's well worth looking through. Also check out the <a href="https://colab.research.google.com/drive/1NCceX2mbiKOSlAd_o7IU7nA9UskKN5WR?usp=sharing">Colab</a> <a href="https://colab.research.google.com/drive/1go6YwMFe5MX6XM9tv-cnQiSTU50N9EeT?fbclid=IwAR30ZqxIJG0-2wDukRydFA3jU5OpLHrlC_Sg1iRXqmoTkEhaJtHdRi6H7AI#scrollTo=CppIQlPhhwhs">notebooks</a> by <a href="https://twitter.com/advadnoun">@advadnoun</a> and <a href="https://twitter.com/RiversHaveWings">@RiversHaveWings</a>, respectively.</p><p><strong><a href="https://github.com/ajayjain/VectorAscent">"VectorAscent: Generate vector graphics from a textual description"</a> by Ajay Jain. </strong>Here's a project from earlier this year that used the same technique of combining CLIP and diffvg. The key difference is in the image augmentation, which seems to help in ensuring synthesized drawings look more consistent.</p></div></div>]]></content:encoded></item><item><title><![CDATA[StampCA: Growing Emoji with Conditional Neural Cellular Automata]]></title><description><![CDATA[<!--kg-card-begin: html--><figure class="kg-card kg-gallery-card">
    <div class="kg-gallery-container">
        <div class="kg-gallery-row">
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/stampca2/emoji.mp4"></video>
                <figcaption>
StampCA growing emoji. <a href="https://colab.research.google.com/drive/1FBEuRymdpgQiDPl5aLPrMDPVIZUM5xpg#scrollTo=8_qZe_c1uPHf">Play with the codebase yourself.</a>
                </figcaption>
            </div>
        </div>
    </div>
</figure><!--kg-card-end: html--><p>When a baby is born, it doesn’t just appear out of nowhere -- it starts as a single cell. This seed cell contains all the information needed to replicate and grow into a full adult. In biology, we call this process</p>]]></description><link>https://kvfrans.com/stampca-conditional-neural-cellular-automata/</link><guid isPermaLink="false">6062a14517e26d47c77e8f44</guid><category><![CDATA[research]]></category><dc:creator><![CDATA[Kevin Frans]]></dc:creator><pubDate>Sat, 03 Apr 2021 08:54:15 GMT</pubDate><content:encoded><![CDATA[<!--kg-card-begin: html--><figure class="kg-card kg-gallery-card">
    <div class="kg-gallery-container">
        <div class="kg-gallery-row">
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/stampca2/emoji.mp4"></video>
                <figcaption>
StampCA growing emoji. <a href="https://colab.research.google.com/drive/1FBEuRymdpgQiDPl5aLPrMDPVIZUM5xpg#scrollTo=8_qZe_c1uPHf">Play with the codebase yourself.</a>
                </figcaption>
            </div>
        </div>
    </div>
</figure><!--kg-card-end: html--><p>When a baby is born, it doesn’t just appear out of nowhere -- it starts as a single cell. This seed cell contains all the information needed to replicate and grow into a full adult. In biology, we call this process <em>morphogenesis</em>: the development of a seed into a structured design. </p><figure class="kg-card kg-image-card kg-card-hascaption"><img src="https://lh5.googleusercontent.com/fhkkGwxPLLkTrMwxA97oC03menLZUIjv9wN8sZpDecGxfEA08O_lFcq9KsU7VKp1nFlF_1DigX1LGfgkwPx3pJ0QJQLDr8VYUUcWKAtz6Zg1WpGTtdUKCOuFzEewfe-k4qgx_dMO" class="kg-image" alt><figcaption>Morphogenesis builds up an embryo. https://www.nature.com/articles/s41467-018-04155-2</figcaption></figure><p>Of course, if there’s a cool biological phenomenon, someone has tried to replicate it in artificial-life land. One family of work examines Neural Cellular Automata (NCA), where a bunch of cells interact with their neighbors to produce cool emergent designs. By adjusting the behavior of the cells, we can influence them to grow whatever we want. NCAs have been trained to produce a whole bunch of things, such as <a href="https://distill.pub/2020/growing-ca/">self-repairing images</a>, <a href="https://distill.pub/selforg/2021/textures/">patternscapes</a>, <a href="https://arxiv.org/abs/2102.02579">soft robots,</a> and <a href="https://arxiv.org/pdf/2103.08737.pdf">Minecraft designs</a>.</p><figure class="kg-card kg-image-card kg-card-hascaption"><img src="https://lh4.googleusercontent.com/f7HLyTjDE5Nrri0h96md7NDbzxTEEupm3wGVdYIEZFO1y4lBeytkRiKuq_uAevQsD8wpXvLEQMdOPCBnuVsJATDDiG1aWUoXvO34eeRxQzrJg3DvaOJshtFavHPAX6LSPZzRZvC_" class="kg-image" alt><figcaption>Growing a Minecraft tree through Neural Cellular Automata. https://arxiv.org/pdf/2103.08737.pdf</figcaption></figure><p>A common thread in these works is that a single NCA represents a single design. 
Instead, I wanted to explore <em>conditional</em> NCAs, which represent design-producing functions rather than single designs. The analogy here is DNA translation -- our biology doesn’t just define our bodies; it defines a system which translates DNA into bodies. Change the DNA, and the body changes accordingly.</p><p>This post introduces StampCA, a conditional NCA which acts like a translation network. In StampCA, the starting cell receives an encoding vector, which dictates what design it must grow into. StampCA worlds can mimic the generative capabilities of traditional image networks, such as an MNIST-digit GAN and an Emoji-producing autoencoder. Since design-specific information is stored in cell state, StampCA worlds can 1) support morphogenetic growth for many designs, and 2) grow them all in the same world.</p><p>At the end, I’ll go over some thoughts on how NCAs relate to work in convnets, their potential in evo-devo and artificial life, and why I’m pretty excited about them looking ahead.</p><h2 id="neural-cellular-automata-how-do-they-work">Neural Cellular Automata: How do they work?<br></h2><p>First, let’s go over how an NCA works in general. The ideas here are taken from this <a href="https://distill.pub/2020/growing-ca/">excellent Distill article by Mordvintsev et al.</a> Go check it out for a nice in-depth explanation.</p><p>The basic idea of an NCA is to simulate a world of cells which locally adjust based on their neighbors. In an NCA, the world is a grid of cells. Let’s say 32 by 32. Each cell has a state, which is a 64-length vector. The first four components correspond to RGBA, while the others are only used internally.</p><p>Every timestep, all cells will adjust their states locally. Crucially, cells can only see the states of themselves and their direct neighbors. This information is given to a neural network, which then updates the cell's state accordingly. 
The NCA will iterate for 32 timesteps.</p><figure class="kg-card kg-image-card kg-card-hascaption"><img src="https://lh4.googleusercontent.com/veb-wU-ubWAbLuz9ItBP0No5UzSM_2UJSwlyscd-326_KRLMzgkitfsZzBDte7JFFTGG68ukZsV9FfmbGP2nKospbP5RjGMCNB1WgO9GReV9tw_JU7aRxm_8baLRS_rhJK6Ne_qQ" class="kg-image" alt><figcaption>One timestep of an NCA. Every cell updates their state based on their neighbors' states, as defined by a convolutional neural network.</figcaption></figure><p>The key idea here is to learn a set of local rules which combine to form a global structure. Cells can only see their neighbors, so they need to rely on their local surroundings to figure out what to do. Again, the analogy is in biological development -- individual cells adjust their structure based on their neighbors, creating a functional animal when viewed as a whole.</p><p>In practice, we can define an NCA as a 3x3 convolution layer, applied recursively many times. This convolutional kernel represents a cell's receptive field – each cell updates its own state based on the states of its neighbors. This is repeated many times, until a final result is produced. Since everything is differentiable, we can plug in our favorite loss function and optimize.</p><p>By being creative with our objectives, we can train NCAs with various behaviors. The classic objective is to start with every cell set to zero except a single starting cell, and to grow a desired design in N timesteps. We can also train NCAs that <em>maintain</em> designs by using old states as a starting point, or NCAs which <em>repair</em> designs by starting with damaged states. There’s a lot of room for play here.</p><!--kg-card-begin: html--><figure class="kg-card kg-gallery-card">
    <div class="kg-gallery-container">
        <div class="kg-gallery-row">
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <img class="b-lazy b-loaded gridvideo" src="https://kvfrans.com/static/open/stampca2/diagram2.png">
                <figcaption>
Different objectives for training NCAs.
                </figcaption>
            </div>
        </div>
    </div>
</figure>
<br>
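<p>To make the update rule described above concrete, here is a minimal NCA step in plain numpy. This is an illustrative sketch, not the notebook code: the single linear layer, weight scale, and random parameters are assumptions, and a trained NCA would use a small learned MLP instead of random weights:</p>

```python
import numpy as np

def nca_step(state, w, b):
    """One NCA timestep: every cell updates from its 3x3 neighborhood."""
    H, W, C = state.shape
    padded = np.pad(state, ((1, 1), (1, 1), (0, 0)))  # zero-pad the edges
    # stack each cell's 3x3 neighborhood of states: shape (H, W, 9*C)
    neigh = np.concatenate(
        [padded[i:i + H, j:j + W] for i in range(3) for j in range(3)], axis=-1)
    # local update rule, shared by all cells (here a single linear layer)
    return state + np.tanh(neigh @ w + b)

H = W = 32; C = 64; T = 32
rng = np.random.default_rng(0)
w = rng.normal(scale=0.1, size=(9 * C, C))
b = np.zeros(C)

state = np.zeros((H, W, C))
state[H // 2, W // 2] = 1.0   # the single seed cell
for _ in range(T):
    state = nca_step(state, w, b)
rgba = state[..., :4]          # first four channels are the visible RGBA
```

<p>Because <code>tanh</code> of an all-zero neighborhood is zero (with zero bias), an empty cell stays empty until a neighbor becomes active, so information spreads at most one cell per timestep outward from the seed.</p>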
<br><!--kg-card-end: html--><p>The common shortcoming here is that the NCA needs to be retrained if we want to produce a different design. We’ve successfully learned an algorithm which locally generates a global design -- now can we learn an algorithm which locally generates <em>many</em> global designs?</p><h2 id="stampca-conditional-neural-cellular-automata">StampCA: Conditional Neural Cellular Automata</h2><p><br>One way to view biological development is as a system of two factors: DNA and DNA translation. DNA produces a code; a translation system then creates proteins accordingly. In general, the same translation system is present in every living thing. This lets evolution ignore the generic heavy work (how to assemble proteins) and focus on the variation that matters for their species (which proteins to create).</p><p>The goal of a conditional NCA is in the same direction: separate a system which <em>grows</em> designs from a system which <em>encodes</em> designs. The key distinction is that encodings are design-specific, whereas the growth system is shared for all designs.</p><p>It turns out that the NCA formulation gives us a straightforward way to do this, which we’ll refer to as StampCA. In a typical NCA, we initialize the grid with all zeros except for a seed cell. Instead of initializing this seed cell with a constant, let’s set its state to an <em>encoding vector.</em> Thus, the final state of the NCA becomes a function of its neural network parameters along with the given encoding.</p><!--kg-card-begin: html--><figure class="kg-card kg-gallery-card">
    <div class="kg-gallery-container">
        <div class="kg-gallery-row">
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <img style="height: auto" class="b-lazy b-loaded gridvideo" src="https://kvfrans.com/static/open/stampca2/diagram3.png">
                <figcaption>
A StampCA is an NCA where the seed is conditioned on a learned encoding.
                </figcaption>
            </div>
        </div>
    </div>
</figure>
<br><!--kg-card-end: html--><p>Armed with this, we can construct an autoencoder-type objective. Given a desired design, an encoder network first compresses it into an encoding vector. This encoding vector is used as the initial state in the NCA, which then iterates until a final design has been created. We then optimize everything so the final design matches the desired design.</p><p>A correctly-trained StampCA learns general morphogenesis rules which can grow an arbitrary distribution of designs. There are two main advantages over single-design NCAs: </p><ol><li>StampCAs can grow new designs without additional training</li><li>Many designs can all grow in the same world</li></ol><p>Let’s look at some examples on Emoji and MNIST datasets.</p><h2 id="stampca-autoencoder-growing-emojis">StampCA Autoencoder: Growing Emojis</h2><p>The <a href="https://distill.pub/2020/growing-ca/">original Distill paper</a> trained NCAs to grow emojis, so let’s follow suit. We’ll take a dataset of ~3000 32x32 emojis from <a href="https://github.com/googlefonts/noto-emoji/tree/main/png/32">this font repo</a>. Our aim here is to train a StampCA which can grow any emoji in the dataset.</p><p>After training for about 4 hours on a Colab GPU, here’s the outputs:</p><!--kg-card-begin: html--><figure class="kg-card kg-gallery-card">
    <div class="kg-gallery-container">
        <div class="kg-gallery-row">
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/stampca2/stampca_compare.mp4"></video>
                <figcaption>
StampCA Emojis growing. <a href="https://colab.research.google.com/drive/1FBEuRymdpgQiDPl5aLPrMDPVIZUM5xpg#scrollTo=8_qZe_c1uPHf">Replicate these results yourself.</a>
                </figcaption>
            </div>
        </div>
    </div>
</figure><!--kg-card-end: html--><p>Overall, it’s doing the right thing! A single network can grow many kinds of emoji. </p><p>For a StampCA network to correctly work, two main problems need to be solved: cells need to share the encoding information with each other, and cells need to figure out where they are relative to the center.</p><!--kg-card-begin: html--><div class="float-outside">
<div class="float-figure">
<video class="b-lazy b-loaded gridvideo-full" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/stampca2/stampca_zoom.mp4"></video>
    <figcaption>Ripple growth pattern.</figcaption>
</div><!--kg-card-end: html--><p>It’s interesting that the growth behavior for every emoji is roughly the same -- waves ripple out from the center seed, which then form into the right colors. Since cells can only interact with their direct neighbors, information has to travel locally, which can explain this behavior. The ripples may contain information on where each cell is relative to the original seed, and/or information about the overall encoding.</p><!--kg-card-begin: html--></div><!--kg-card-end: html--><p>A cool thing about StampCAs is that multiple designs can grow in the same world. This works because all designs share the same network, they just have different starting seeds. Since NCAs by definition only affect their local neighbors, two seeds won’t interfere with each other, as long as they’re placed far enough apart.</p><!--kg-card-begin: html--><hr class="gridhr">
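<p>Seeding multiple designs, as described above, just means writing several encoding vectors into an otherwise-zero grid before iterating. A sketch with assumed shapes and names (not the notebook code); the spacing between seeds must exceed the distance information can travel in the available timesteps:</p>

```python
import numpy as np

def seed_world(encodings, positions, H=64, W=64, C=64):
    """Build a StampCA initial state: all zeros, except that each seed
    cell's state is set to a design-specific encoding vector."""
    state = np.zeros((H, W, C))
    for z, (r, c) in zip(encodings, positions):
        state[r, c] = z
    return state

# two different designs, placed far enough apart that their purely
# local updates cannot reach each other during growth
rng = np.random.default_rng(0)
z_a, z_b = rng.normal(size=(2, 64))
state = seed_world([z_a, z_b], [(16, 16), (16, 48)])
```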
<figure class="kg-card kg-gallery-card">
    <div class="kg-gallery-container">
        <div class="kg-gallery-row">
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/stampca2/stampca_circle.mp4"></video>
                <figcaption>
Placing seeds down along a path.
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/stampca2/emoji_rand.mp4"></video>
                <figcaption>
Different designs can grow in the same world.
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/stampca2/stampca_mess.mp4"></video>
                <figcaption>
If seeds are placed too close, weird behavior.
                </figcaption>
            </div>
        </div>
    </div>
</figure>
<hr class="gridhr"><!--kg-card-end: html--><h2 id="stampca-gan-learning-fake-mnist"><br>StampCA GAN: Learning fake MNIST</h2><p>A nice part about conditional NCAs is that they can approximate any kind of vector-to-image function. In the Emoji experiment we used an autoencoder-style loss to train NCAs that could reconstruct images. This time, let's take a look at a GAN objective: given a random vector, grow a typical MNIST digit.</p><p>OK, in an ideal world, this setup would be pretty straightforward. A GAN setup involves a generator network and a discriminator network. The generator attempts to produce fake designs, while the discriminator attempts to distinguish between real and fake. We train these networks in tandem, and over time the generator learns to produce realistic-looking fakes.</p><p>It turns out that this was a bit tricky to train, so I ended up using a hack. GANs are notoriously hard to stabilize, since the generator and discriminator need to be around the same strength. Early on, NCA behavior is quite unstable, so the NCA-based generator has a hard time getting anywhere.</p><p>Instead, the trick was to first train a traditional generator with a feedforward neural network. This generator learns to take in a random vector, and output a realistic MNIST digit. Then, we train an NCA to mimic the behavior of the generator. This is a more stable objective for the NCA to solve, and it eventually learns to basically match the generator's performance.</p><!--kg-card-begin: html--><figure class="kg-card kg-gallery-card">
    <div class="kg-gallery-container">
        <div class="kg-gallery-row">
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/stampca2/mnist_full.mp4"></video>
                <figcaption style="width:80%">
StampCA MNIST digits growing. Unlike with Emojis, we don't need to supply a dataset to encode->decode.<br>MNIST digits are generated from scratch with the GAN objective. <a href="https://colab.research.google.com/drive/1kgYy6jebUl3bPBVlybWtHOlv635HrcL2#scrollTo=iqBXLoqV5bQq">Replicate these results yourself.</a>
                </figcaption>
            </div>
        </div>
    </div>
</figure><!--kg-card-end: html--><!--kg-card-begin: html--><div class="float-outside">
<div class="float-figure">
<video class="b-lazy b-loaded gridvideo-full" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/stampca2/mnist_zoom.mp4"></video>
    <figcaption>MNIST StampCA grows following the digits' curve, in comparison to the ripple of Emoji StampCA.</figcaption>
</div>
<!--kg-card-end: html--><p>The cool thing about these GAN-trained StampCAs is that we don't need a dataset of base images anymore. Since GANs can create designs from scratch, every digit that is grown from this MNIST StampCA is a fake design.</p><p>Another interesting observation is how the MNIST StampCA grows its digits. In the Emoji StampCA, we saw a sort of ripple behavior, where the emojis would grow outwards from a center seed. In the MNIST case, growth more closely follows the path of the digit. This is especially visible when generating "0" digits, since the center is hollow.</p><!--kg-card-begin: html--></div><!--kg-card-end: html--><!--kg-card-begin: html--><hr class="gridhr">
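<p>The distillation trick above can be reduced to a toy: freeze a trained generator, then fit a student to match its outputs under a simple regression loss, instead of training the student adversarially. Here both networks are stand-in linear maps (all names, shapes, and hyperparameters are illustrative, not the notebook code):</p>

```python
import numpy as np

rng = np.random.default_rng(0)
latent, pixels = 16, 784

# stand-in for the already-trained feedforward generator (frozen)
A = rng.normal(size=(pixels, latent)) / np.sqrt(latent)
def teacher(z):
    return A @ z

# stand-in for the NCA-based generator we want to train
W = np.zeros((pixels, latent))

lr = 0.01
for step in range(2000):
    z = rng.normal(size=latent)        # random latent vector
    err = W @ z - teacher(z)           # mismatch with the teacher's fake digit
    W -= lr * np.outer(err, z)         # SGD step on 0.5 * ||err||^2
```

<p>Mimicking a fixed teacher is an ordinary regression problem, so it sidesteps the generator/discriminator balancing act; once the student matches the teacher, sampling a random <code>z</code> yields a fake digit with no dataset in the loop.</p>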
<figure class="kg-card kg-gallery-card">
    <div class="kg-gallery-container">
        <div class="kg-gallery-row">
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/stampca2/mnist_circle.mp4"></video>
                <figcaption>
Placing seeds down along a path.
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/stampca2/mnist_many.mp4"></video>
                <figcaption>
Many MNIST digits in the same world. Note the "drifter" artifact coming from the 9 digit, reminiscent of cellular automata like Conway's Game of Life.
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/stampca2/mnist_crowded.mp4"></video>
                <figcaption>
If seeds are placed too close, weird behavior.
                </figcaption>
            </div>
        </div>
    </div>
</figure>
<hr class="gridhr"><!--kg-card-end: html--><h2 id="discussion">Discussion</h2><p>First off, you can replicate these experiments through these Colab notebooks, <a href="https://colab.research.google.com/drive/1FBEuRymdpgQiDPl5aLPrMDPVIZUM5xpg?usp=sharing">here for the Emojis</a> and <a href="https://colab.research.google.com/drive/1kgYy6jebUl3bPBVlybWtHOlv635HrcL2?usp=sharing">here for the MNIST</a>. </p><p>The point of this post was to introduce conditional NCAs and show that they exist as a tool. I didn't focus much on optimizing for performance, so I didn't spend much time tuning hyperparameters or network structures. If you play around with the models, I'm sure you can improve on the results here.</p><p>An NCA is really just a fancy convolutional network, where the same kernel is applied over and over. If we view things this way, then it's easy to plug in an NCA to whatever kind of image-based tasks you're interested in.</p><p>The interesting part about NCAs are that it becomes easy to define objectives that last over time. In this work, we cared about NCAs that matched the appearance of an image after a certain number of timesteps. We could have also done more complicated objectives: e.g, grow into a pig emoji after 50 timesteps, then transform into a cow emoji after 100.</p><p>The other aspect distinguishing NCAs are their locality constraints. In an NCA, it's guaranteed that a cell can only interact with its direct neighbors. We took advantage of this by placing multiple seeds in the same world, since we can be confident they won't interfere with one another. Locality also means we can increase the size of an NCA world to an arbitrary size, and it won't change the behavior at any location. </p><p>Looking ahead, there are a bunch of weirder experiments we can do with linking NCAs to task-based objectives. We could view an NCA as defining a multi-cellular creature, and train it to grow and capture various food spots. 
We can also view NCAs as communication systems, since cells have to form channels to transfer information from one point to another. Finally, there's a bunch of potential in viewing NCAs as evo-devo systems: imagine an NCA defining a plant which dynamically adjusts how it grows based on the surroundings it encounters.</p><p>Code for these experiments is available at <a href="https://colab.research.google.com/drive/1FBEuRymdpgQiDPl5aLPrMDPVIZUM5xpg?usp=sharing">these</a> <a href="https://colab.research.google.com/drive/1kgYy6jebUl3bPBVlybWtHOlv635HrcL2#scrollTo=iqBXLoqV5bQq">notebooks</a>. Feel free to <a href="https://twitter.com/kvfrans/status/1379925309311442944">shout at me on Twitter</a> with any thoughts.</p><!--kg-card-begin: html--><hr class="gridhr">
<br>
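<p>The "objectives over time" idea above can be sketched as a rollout with per-timestep checkpoints, each compared against a different target image (hypothetical names and shapes; in practice you would backprop this loss through the NCA's convolution weights):</p>

```python
import numpy as np

def rollout_loss(state, step_fn, targets):
    """targets: {timestep: target_rgba}. Advance the NCA with step_fn and
    penalize the visible RGBA channels at each checkpoint timestep."""
    loss = 0.0
    for t in range(1, max(targets) + 1):
        state = step_fn(state)
        if t in targets:
            loss += float(np.mean((state[..., :4] - targets[t]) ** 2))
    return loss

# toy usage: a "step" that just decays the state, checked at t=50 and t=100
state = np.ones((8, 8, 16))
targets = {50: np.zeros((8, 8, 4)), 100: np.ones((8, 8, 4))}
loss = rollout_loss(state, lambda s: 0.99 * s, targets)
```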
<div class="float-outside">
<div class="appendix">
    <div class="float-label">
        Appendix
    </div><!--kg-card-end: html--><p>Cite as:</p><pre><code class="language-x">@article{frans2021stampca,
  title   = "StampCA: Growing Emoji with Conditional Neural Cellular Automata",
  author  = "Frans, Kevin",
  journal = "kvfrans.com",
  year    = "2021",
  url     = "http://kvfrans.com/stampca-conditional-neural-cellular-automata/"
}</code></pre><p>Why a blog post instead of a paper? I like the visualizations that the blog format allows, and I don't think a full scientific comparison is needed here – the main idea is that these kinds of conditional NCA are possible, not that the method proposed is the optimal setup.</p><!--kg-card-begin: html--><div class="float-label" style="clear: both">
	Things To Read
</div><!--kg-card-end: html--><p><strong><a href="https://distill.pub/2020/selforg/mnist/">"Growing Neural Cellular Automata"</a>, by Mordvintsev et al</strong>. This article was the original inspiration for this work in NCAs. They show that you can construct differentiable cellular automata out of convolutional units, and then train these NCAs to grow emojis. With the correct objectives, these emojis are also stable over time and can self-repair damage. Since it's on Distill, there's a bunch of cool interactive bits to play around with.</p><p><strong><a href="https://distill.pub/selforg/2021/textures/">"Self-Organising Textures"</a>, by Niklasson et al</strong>. This Distill article is a continuation of the one above. They show that you can train NCAs towards producing styles or activating certain Inception layers. The driftiness of the NCAs leads to cool dynamic patterns that move around and are resistant to errors.</p><p><strong><a href="https://arxiv.org/abs/2103.08737">"Growing 3D Artefacts and Functional Machines with Neural Cellular Automata"</a>, by Sudhakaran et al.</strong> This paper takes the traditional NCA setup and shifts it to 3D. They train 3D NCAs to grow Minecraft structures, and are able to retain all the nice self-repair properties in this domain. </p><p><a href="https://arxiv.org/abs/2102.02579">"<strong>Regenerating Soft Robots through Neural Cellular Automata"</strong></a><strong>, by Horibe et al. </strong>This paper looks at NCAs that define soft robots, which are tested in 2D and 3D on locomotion. The cool thing here is they look at regeneration in the context of achieving a behavior – robots are damaged, then allowed to repair to continue walking. Key differences are that everything is optimized through evolution, and a separate growth+generation network is learned.</p><p><strong><a href="https://www.amazon.com/Endless-Forms-Most-Beautiful-Science/dp/0393327795">"Endless Forms Most Beautiful"</a>, by Sean B. Carroll. 
</strong>This book is all about evolutionary development (evo-devo) from a biological perspective. It shows how the genome defines a system which grows itself, through many interacting systems such as repressor genes and axes of organization. One can argue that the greatest aspect of life on Earth isn't the thousands of species, but that they all share a genetic system that can <em>generate</em> a thousand species. There's a lot of inspiration here in how we should construct artificial systems: to create good designs, we should focus on creating systems that can <em>produce</em> designs.</p><p><strong><a href="https://arxiv.org/abs/2006.12155">"Neural Cellular Automata Manifold"</a>, by Ruiz et al.</strong> This paper tackles a similar problem of using NCAs to grow multiple designs. Instead of encoding the design-specific information in the cell state, they try to produce unique network parameters for every design. The trick is that you can produce these parameters through a secondary neural network to bypass retraining.</p></div></div>]]></content:encoded></item><item><title><![CDATA[Quality Diversity: Evolving Ocean Creatures]]></title><description><![CDATA[<p>If you've ever folded a paper airplane, chances are you folded The Dart. It's fast and simple, and kids all around the world have launched it through the air. Soon, however, they will move on to other designs – airplanes that glide, make loops, or twirl in flight. There are thousands</p>]]></description><link>https://kvfrans.com/quality-diversity-evolving-ocean-creatures/</link><guid isPermaLink="false">60225e9617e26d47c77e8ecf</guid><category><![CDATA[research]]></category><dc:creator><![CDATA[Kevin Frans]]></dc:creator><pubDate>Wed, 02 Dec 2020 02:05:00 GMT</pubDate><content:encoded><![CDATA[<p>If you've ever folded a paper airplane, chances are you folded The Dart. It's fast and simple, and kids all around the world have launched it through the air. 
Soon, however, they will move on to other designs – airplanes that glide, make loops, or twirl in flight. There are thousands of paper airplane designs, each with its own quirks and structure.</p><figure class="kg-card kg-image-card kg-card-hascaption"><img src="https://kvfrans.com/content/images/2021/02/dart.jpg" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2021/02/dart.jpg 600w, https://kvfrans.com/content/images/2021/02/dart.jpg 640w"><figcaption>Paper Airplane Designs: https://www.foldnfly.com/#/1-1-1-1-1-1-1-1-2</figcaption></figure><p>In the artificial life world, we would call paper airplane design a <strong>quality-diversity</strong> problem. A paper airplane should fly fast, but we don’t only care about an optimal design – we’re also interested in airplanes with unique bodies or wing types. Rather than a single design, quality-diversity algorithms aim to develop a population of diverse solutions to a problem.</p><p>In this post, let’s explore how quality-diversity algorithms can give us a peek into a world of swimming, soft creatures.</p><h2 id="oceanworld">Oceanworld</h2><hr><!--kg-card-begin: html--><br>
<div class="float-outside">
<div class="float-figure">
<img style="width:100%" src="https://kvfrans.com/static/open/4/creature.png">
    <figcaption>An Oceanworld organism. Red squares represent nodes, and blue connections are bonds. Green lines represent rigid bonds that are affected by fluid resistance.</figcaption>
</div>
In Oceanworld, a creature is a 16x16 matrix of flexible nodes and bonds. Each node is a simulated physics object, and bonds can rotate around and deform.
<br><br>
A creature moves by a periodic heartbeat. Whenever a heartbeat begins, any muscle bonds will extend their length, deforming the shape of the creature. When the heartbeat ends, the muscle bonds will return to their original length. In addition, a bond can also be rigid. Rigid bonds are affected by fluid resistance – when they move, a force pushes back in the opposite direction.
<br><br>
The great part about evolutionary methods is that this is all a black box -- we don't actually need to understand how the physics works; we just need to run the algorithm.
</div><!--kg-card-end: html--><!--kg-card-begin: html--><hr class="gridhr">
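<p><em>A minimal sketch of these dynamics, as described above. The period, extension factor, and resistance coefficient here are made-up illustrative values, not the actual simulation parameters:</em></p><pre><code class="language-python">def muscle_length(rest_length, t, period=1.0, extend=0.5):
    # During the first half of each heartbeat a muscle bond extends,
    # then returns to its rest length for the second half.
    phase = (t % period) / period
    return rest_length * (1.0 + extend) if phase < 0.5 else rest_length

def drag_force(vx, vy, resistance=0.8):
    # Fluid resistance on a rigid bond: a force opposing its motion.
    return (-resistance * vx, -resistance * vy)</code></pre>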
<figure class="kg-card kg-gallery-card">
    <div class="kg-gallery-container">
        <div class="kg-gallery-row">
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/1.mp4"></video>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/2.mp4"></video>
                <figcaption>
                    Various Oceanworld creatures formed by sampling random genomes. When rigid bonds move, fluid resistance applies a force.
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/3.mp4"></video>
            </div>
        </div>
    </div>
</figure>
<hr class="gridhr"><!--kg-card-end: html--><p>Cool, we have some creatures. But how much of the total possible space do these creatures actually represent?</p><p>One way we can gain intuition about the possible space is through a technique called <strong>novelty search</strong>. Novelty search tries to find a population of creatures that are diverse from each other according to some metric. In our case, let’s use the number of nodes, and the ratio of bonds to nodes.</p><!--kg-card-begin: html--><div class="float-outside">
<div class="float-figure-big">
<video class="b-lazy b-loaded gridvideo-full" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/4.mp4"></video>
    <figcaption>Novelty search over Oceanworld genomes. Grey points represent the current population of genomes. Red points represent genomes sampled from a random distribution.</figcaption>
</div><!--kg-card-end: html--><p>Novelty Search works by a simple loop:</p><ol><li>Maintain a population of 1000 genomes.</li><li>Mutate each genome and add it to the population.</li><li>Sum the distance each genome has to its 10 closest neighbors, in terms of some novelty metric.</li><li>Sort the population by this distance, and keep only the 1000 most novel genomes.</li></ol><p>Working correctly, the novelty search algorithm should optimize for a group of genomes that are spread out in terms of node count and bond-to-node ratio.</p><p>I always like to sanity-check by looking at the results directly, so let’s take a peek at some sample points and see how the creatures look.</p><!--kg-card-begin: html--></div><!--kg-card-end: html--><!--kg-card-begin: html--><hr class="gridhr">
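<p><em>The loop above can be sketched in a few lines. The genome representation, mutation operator, and novelty metric are left as placeholders; in Oceanworld, the metric compares (node count, bond-to-node ratio) pairs:</em></p><pre><code class="language-python">def novelty(genome, population, metric, k=10):
    # Summed distance to the k nearest neighbors under the novelty metric.
    dists = sorted(metric(genome, other) for other in population)
    return sum(dists[1:k + 1])  # dists[0] is the genome's distance to itself

def novelty_search(population, mutate, metric, generations=50, pop_size=1000):
    population = list(population)
    for _ in range(generations):
        # Mutate each genome and add the children to the population.
        population += [mutate(g) for g in population]
        # Keep only the genomes most distant from their neighbors.
        ranked = sorted(population, key=lambda g: novelty(g, population, metric), reverse=True)
        population = ranked[:pop_size]
    return population</code></pre>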
<figure class="kg-card kg-gallery-card">
    <div class="kg-gallery-container">
        <div class="kg-gallery-row">
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/5.mp4"></video>
                <figcaption>
                    10% nodes, 100% bond ratio
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/6.mp4"></video>
                <figcaption>
                    50% nodes, 100% bond ratio
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/7.mp4"></video>
                <figcaption>
                    100% nodes, 100% bond ratio
                </figcaption>
            </div>
        </div>
        <div class="kg-gallery-row">
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/8.mp4"></video>
                <figcaption>
                    10% nodes, 10% bond ratio
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/9.mp4"></video>
                <figcaption>
                    50% nodes, 10% bond ratio
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/10.mp4"></video>
                <figcaption>
                    100% nodes, 10% bond ratio
                </figcaption>
            </div>
        </div>
    </div>
</figure>
<hr class="gridhr"><!--kg-card-end: html--><p>There’s a key thing missing in our Oceanworld creatures – they don’t actually do anything. The “Quality” in Quality-Diversity has to measure something. So let’s choose something: a creature that swims faster is better.</p><p>First, let’s try out a greedy evolutionary algorithm. We’ll start with a random creature, and measure how far it travels. Then, we’ll randomly mutate its genome and try again. If the new creature performs better, it replaces the old one.</p><!--kg-card-begin: html--><hr class="gridhr">
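<p><em>A minimal sketch of this greedy hill-climber. The fitness and mutation functions are stand-ins; in Oceanworld, fitness would be the distance a creature swims:</em></p><pre><code class="language-python">def greedy_evolve(genome, mutate, fitness, steps=1000):
    best, best_fit = genome, fitness(genome)
    for _ in range(steps):
        child = mutate(best)
        child_fit = fitness(child)
        # Keep the mutant only if it strictly outperforms its parent.
        if child_fit > best_fit:
            best, best_fit = child, child_fit
    return best, best_fit</code></pre><p><em>Because only improving mutations survive, this climber converges to whichever local optimum is nearest its starting genome.</em></p>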
<figure class="kg-card kg-gallery-card">
    <div class="kg-gallery-container">
        <div class="kg-gallery-row">
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/11.mp4"></video>
                <figcaption>
Initial genome #1
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/12.mp4"></video>
                <figcaption>
Halfway through training
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/13.mp4"></video>
                <figcaption>
Fully trained
                </figcaption>
            </div>
        </div>
        <div class="kg-gallery-row">
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/14.mp4"></video>
                <figcaption>
Initial genome #2
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/15.mp4"></video>
                <figcaption>
Halfway through training
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/16.mp4"></video>
                <figcaption>
                    Fully trained

                </figcaption>
            </div>
        </div>
    </div>
</figure>
<hr class="gridhr"><!--kg-card-end: html--><p>Like a pair of deformed jellyfish, our creatures have evolved to move forwards by pulsating their muscle bonds. I doubt I would have been able to hand-design any of these creatures – the physics relationships are just too complex. Thankfully, evolution takes care of it for us.</p><p>But are these really the best creatures? The key observation here is that different starting genomes converge to different solutions. This tells us that our possibility space has local minima – genomes where any small change will hurt performance, even though a larger change could be beneficial.</p><p>It turns out that local minima are part of a classic paradox: “the path to success does not look like success”. When Carl Benz introduced the motor car in a world of horse carriages, it was slow and heavy. But he kept at it, and the motor car emerged victorious.</p><p>In some form, Benz was using a quality-diversity algorithm: motor engines were different, so they deserved a chance. Quality-diversity algorithms help us escape local minima, because they force us to consider a group of genomes rather than a single leader. With more diversity, there is more chance to discover a path forwards.</p><h2 id="map-elites">MAP-Elites</h2><hr><p>MAP-Elites is a quality-diversity algorithm that combines novelty search with evolution. Instead of evolving a single optimal genome, we want to evolve an optimal genome for every point in some grid of features.</p><!--kg-card-begin: html--><div class="float-outside">
<div class="float-figure-big">
<video class="b-lazy b-loaded gridvideo-full" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/18.mp4"></video>
    <figcaption>MAP-Elites over Oceanworld genomes. Yellow spots represent genomes with higher performance.</figcaption>
</div><!--kg-card-end: html--><p>In our case, we can use the two metrics from before – node count and bond ratio. Every spot in the matrix corresponds to some range of the metrics (e.g. 50-60 nodes and 30%-40% bonds), and we store the optimal genome that fits the criteria.</p><p>MAP-Elites works by:</p><ol><li>Maintain a 32 x 32 matrix of genomes.</li><li>Pick a random genome, and mutate it.</li><li>Measure its metrics, and find which spot in the grid the genome belongs to.</li><li>If the new genome performs better than the old genome in the spot, replace it.</li></ol><p>As always, let's get our hands dirty and check out what the optimal genomes end up looking like.</p><!--kg-card-begin: html--></div><!--kg-card-end: html--><!--kg-card-begin: html--><hr class="gridhr">
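<p><em>The MAP-Elites loop above can be sketched as follows. The evaluation, mutation, and feature functions are placeholders; for Oceanworld, the features would be node count and bond ratio bucketed into the 32 x 32 grid:</em></p><pre><code class="language-python">import random

def map_elites(seeds, mutate, evaluate, features, steps=10000):
    # archive maps a feature cell (row, col) to the best (fitness, genome) found.
    archive = {}

    def try_insert(genome):
        cell = features(genome)  # which spot in the grid this genome belongs to
        fitness = evaluate(genome)
        if cell not in archive or fitness > archive[cell][0]:
            archive[cell] = (fitness, genome)

    for genome in seeds:
        try_insert(genome)
    for _ in range(steps):
        # Pick a random elite, mutate it, and file the child into its cell.
        _, parent = random.choice(list(archive.values()))
        try_insert(mutate(parent))
    return archive</code></pre>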
<figure class="kg-card kg-gallery-card">
    <div class="kg-gallery-container">
        <div class="kg-gallery-row">
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/map10.mp4"></video>
                <figcaption>
10% nodes, 100% bond ratio
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/map5.mp4"></video>
                <figcaption>
50% nodes, 100% bond ratio
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/map3.mp4"></video>
                <figcaption>
100% nodes, 100% bond ratio
                </figcaption>
            </div>
        </div>
        <div class="kg-gallery-row">
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/map9.mp4"></video>
                <figcaption>
10% nodes, 50% bond ratio
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/map4.mp4"></video>
                <figcaption>
50% nodes, 50% bond ratio
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/map6.mp4"></video>
                <figcaption>
100% nodes, 50% bond ratio
                </figcaption>
            </div>
        </div>
        <div class="kg-gallery-row">
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/map8.mp4"></video>
                <figcaption>
10% nodes, 10% bond ratio
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/map12.mp4"></video>
                <figcaption>
50% nodes, 10% bond ratio
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/map2.mp4"></video>
                <figcaption>
100% nodes, 10% bond ratio
                </figcaption>
            </div>
        </div>
    </div>
</figure>
<hr class="gridhr"><!--kg-card-end: html--><p>In the MAP-Elites genomes, we get a whole bunch of different creatures, and some cool patterns emerge. The fastest creatures turn out to be the smaller ones, likely due to their lower mass. Larger creatures tend to develop a kind of "tail" -- my guess is that this acts as a stabilizer so the creature doesn't drift off course.</p><p>The last row is especially interesting, since the larger creatures develop a completely different movement strategy. Instead of moving its entire body, the creature unrolls itself to stretch forwards.</p><p>Quality-diversity algorithms are an easy way to explore the possibility space of an open-ended world. The techniques, such as novelty search and MAP-Elites, are simple but yield strong results when given proper metrics of diversity. In this Oceanworld, the limiting factor is the expressivity of our genome -- there are only so many ways to build swimming creatures out of a single type of bond.</p><p>Further work can involve:</p><ul><li>How can we increase the expressivity of our genome? Can we make the genome more information-dense in relation to the creature it defines? What kinds of genomes allow innovations to transfer between creatures?</li><li>What are good definitions of diversity? Can we define diversity in terms of a creature's behavior, rather than its structure?</li></ul><!--kg-card-begin: html--><hr class="gridhr">
<br>
<div class="float-outside">
<div class="appendix">
    <div class="float-label">
        Appendix
    </div><!--kg-card-end: html--><p>In the MAP-Elites method, it turns out to be helpful to first run novelty search to get a population of diverse genomes. We start the MAP-Elites matrix with this diverse population, rather than a population of random genomes.</p><!--kg-card-begin: html--><div style="float: left;
  position: relative;
  width: 50%">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/17.mp4"></video>
                <p>
MAP-Elites from random starting population.
                </p>
</div>
<div style="float: right;
  position: relative;
  width: 50%">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/4/18.mp4"></video>
                <p>
MAP-Elites from novelty-search population.
                </p>
</div><!--kg-card-end: html--><!--kg-card-begin: html--><div class="float-label" style="clear: both">
	Things To Read
</div><!--kg-card-end: html--><p><em>For more on quality-diversity, the ICML tutorial slides below are a good starting point. The rest are some highlight papers with useful ideas.</em></p><p><strong><a href="https://icml.cc/media/Slides/icml/2019/halla(10-09-15)-10-09-15-4336-recent_advances.pdf">"Recent Advances in Population-Based Search"</a>, ICML 2019 Tutorial, by Jeff Clune, Joel Lehman, and Ken Stanley</strong>. This is a great presentation that goes over recent works related to quality diversity. It introduces novelty search as a way to get around environments where greedy optimization traps you in a local minimum, and then goes on to motivate quality-diversity algorithms where we simultaneously train populations of high-performing agents.</p><p><strong><a href="https://www.cs.swarthmore.edu/~meeden/DevelopmentalRobotics/lehman_ecj11.pdf">"Abandoning Objectives: Evolution through the Search for Novelty Alone"</a>, by Joel Lehman and Ken Stanley</strong>. This is the paper on Novelty Search. It argues that basing evolutionary search on fitness is limiting, since only creatures with high fitness will be considered. In environments such as hard mazes, the fitness landscape is non-convex and greedy optimization will fail. Instead, strong diversity metrics are often a better landscape to search through.</p><p><strong><a href="https://arxiv.org/pdf/1504.04909.pdf">"Illuminating search spaces by mapping elites"</a>, by Jean-Baptiste Mouret and Jeff Clune</strong>. This is the MAP-Elites paper. They introduce the idea of "illumination algorithms", which aim to find the highest-performing solution for every point in a given feature space. The core MAP-Elites idea is pretty straightforward -- store the best solution at every point, and keep trying new solutions.</p><p><strong><a href="https://arxiv.org/abs/1901.01753">"POET: Paired Open-Ended Trailblazer"</a>, by Rui Wang et al</strong>. POET is a quality-diversity algorithm with a step towards open-endedness. 
The idea is similar to MAP-Elites, but instead of niches defined as properties of the genome, each niche is a task. New tasks are introduced randomly, and stay in the pool if agents can solve them. You end up with a growing population of tasks and a population of agents that can solve those tasks.</p><p><br><br><em>What are good diversity metrics? Hand-designed metrics such as creature size only work to some degree. In my mind, this is an important question as "novelty" is hard to define concretely.</em></p><p><strong><a href="https://arxiv.org/pdf/1802.06070.pdf">"Diversity is All You Need"</a>, by Benjamin Eysenbach et al</strong>. This paper describes a system where we want to find a set of diverse policies. They define diversity as:</p><ol><li>A discriminator should be able to predict which policy is active based on state history.</li><li>A policy should cover as many states as it can.</li></ol><p><strong><a href="https://arxiv.org/pdf/1704.03012.pdf">"Stochastic Neural Networks for Hierarchical Reinforcement Learning"</a>, by Carlos Florensa, Yan Duan, and Pieter Abbeel</strong>. Motivated by information theory, this work aims to find diverse policies by maximizing entropy. This ends up being a similar method to the one above -- policies should visit different states to be maximally diverse.</p><p><strong><a href="https://arxiv.org/pdf/1905.11874.pdf">"Autonomous skill discovery with Quality-Diversity and Unsupervised Descriptors"</a>, by Antoine Cully</strong>. This work takes hints from dimensionality-reduction methods such as PCA (principal component analysis). Basically, given a set of trajectories, we can map every trajectory onto a 2D space. This mapping becomes our diversity metric for running a quality-diversity method such as MAP-Elites.</p><p><br><br><em>Aside from how to discover diverse populations, the world and environment this all takes place in is crucial as well. 
These are some nice references for work in ALIFE and evolved creatures.</em></p><p><strong><a href="https://openai.com/blog/evolution-strategies/">"Evolution Strategies"</a>, OpenAI</strong>. A good overview of how evolutionary systems can potentially outscale traditional reinforcement learning. The basic idea is that ES is inherently parallelizable, so systems with lots of CPU cores can scale well. ES also treats an entire episode with a single reward, so it doesn't suffer from things such as sparsity or time dependency.</p><p><strong><a href="http://www.karlsims.com/papers/siggraph94.pdf">"Evolving Virtual Creatures"</a>, by Karl Sims</strong>. This is a classic work from 1994, where virtual swimming robots are evolved to move forwards. The methods used are similar to what we did in this blog post, but the genome is defined as a hierarchy of parts, and there are basic neural networks for control.</p><p><strong><a href="https://arxiv.org/abs/1711.06605">"Evolving soft locomotion"</a>, by Francesco Corucci et al</strong>. A more recent paper that uses evolutionary methods to design soft swimming and walking robots. The world here uses creatures of deformable voxels, akin to LEGOs, that apply a fluid resistance force when expanding (similar to what we did but in 3D). Their genome is a CPPN, a neural network C(x,y) that outputs the creature's properties at each coordinate.</p><!--kg-card-begin: html--></div><!--kg-card-end: html--></div>]]></content:encoded></item><item><title><![CDATA[Open-Endedness 3: Multicell World]]></title><description><![CDATA[<p>How many shapes can be made from six 2x4 LEGO bricks? The answer is 915,103,765. 
By combining only a few simple parts, it’s surprisingly easy to achieve high levels of complexity.</p><figure class="kg-card kg-image-card kg-card-hascaption"><img src="https://kvfrans.com/content/images/2021/06/lego--1-.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2021/06/lego--1-.png 600w, https://kvfrans.com/content/images/size/w1000/2021/06/lego--1-.png 1000w, https://kvfrans.com/content/images/size/w1600/2021/06/lego--1-.png 1600w, https://kvfrans.com/content/images/2021/06/lego--1-.png 1624w" sizes="(min-width: 720px) 720px"><figcaption>A Lego Counting Problem: http://web.math.ku.dk/~eilers/lego.html</figcaption></figure><p>If you look at</p>]]></description><link>https://kvfrans.com/open-endedness-3/</link><guid isPermaLink="false">60225e9617e26d47c77e8ece</guid><category><![CDATA[research]]></category><dc:creator><![CDATA[Kevin Frans]]></dc:creator><pubDate>Thu, 12 Nov 2020 01:35:00 GMT</pubDate><content:encoded><![CDATA[<p>How many shapes can be made from six 2x4 LEGO bricks? The answer is 915,103,765. By combining only a few simple parts, it’s surprisingly easy to achieve high levels of complexity.</p><figure class="kg-card kg-image-card kg-card-hascaption"><img src="https://kvfrans.com/content/images/2021/06/lego--1-.png" class="kg-image" alt srcset="https://kvfrans.com/content/images/size/w600/2021/06/lego--1-.png 600w, https://kvfrans.com/content/images/size/w1000/2021/06/lego--1-.png 1000w, https://kvfrans.com/content/images/size/w1600/2021/06/lego--1-.png 1600w, https://kvfrans.com/content/images/2021/06/lego--1-.png 1624w" sizes="(min-width: 720px) 720px"><figcaption>A Lego Counting Problem: http://web.math.ku.dk/~eilers/lego.html</figcaption></figure><p>If you look at the single-celled organisms on Earth today, you’ll mostly be stuck with bacteria. But once you enter the multi-cellular domain, suddenly life shows all sorts of shapes and structures. 
The <strong>possibility space</strong> of multi-celled organisms is much richer.</p><p>Recall our principles of open-endedness:</p><ul><li><strong>Possibility Space:</strong> A world consists of a set of organisms. The space of possible organisms must be arbitrarily large.</li><li><strong>Peer Dependency:</strong> When new organisms appear, an environment must develop new niches.</li><li><strong>Search Algorithm:</strong> A procedure must find new organisms to fit new niches.</li></ul><p>In this post, let’s dive into how multi-cellularity can expand our possibility space towards increasingly complex genomes.</p><h2 id="multicell-world">Multicell World</h2><hr><p>Last time, we worked with a simple Cellworld where cells could gain energy and reproduce. Starting with this as a base, let’s restructure the genome to support multi-cellularity, and also mix in our bitwise chemical system.</p><p>In the Multicell world, an organism is made of a group of cells. Each cell performs a single chemical reaction, creating energy. When a cell has gathered enough energy, it will create a new cell in a neighboring space.</p><p>Similar to the Bitwise Chemicals project, there are eight chemicals. Every space in the grid has its own chemical counts, and a cell will only be able to perform a reaction if the required chemicals are there. Within an organism’s cells, chemical counts and energy are averaged out.</p><!--kg-card-begin: html--><hr class="gridhr">
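<p><em>A small sketch of the reaction rule described above. The dict-based chemical encoding and the specific amounts are my own illustrative choices, not the actual implementation:</em></p><pre><code class="language-python">def react(space_chems, reaction):
    # reaction: ({chem: amount consumed}, {chem: amount produced}, energy gained).
    # A cell can only fire if the required chemicals are present in its space.
    inputs, outputs, energy_gain = reaction
    if any(space_chems.get(c, 0) < n for c, n in inputs.items()):
        return 0  # missing an input chemical, so no energy is produced
    for c, n in inputs.items():
        space_chems[c] -= n
    for c, n in outputs.items():
        space_chems[c] = space_chems.get(c, 0) + n
    return energy_gain

def share(values):
    # Within an organism, chemical counts and energy are averaged across cells.
    mean = sum(values) / len(values)
    return [mean] * len(values)</code></pre>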
<figure class="kg-card kg-gallery-card">
    <div class="kg-gallery-container">
        <div class="kg-gallery-row">
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/3/1.mp4"></video>
                <figcaption>
                    Green cells represent the reaction (01+02). These chemicals are provided periodically, at a higher rate near the bottom.
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/3/2.mp4"></video>
                <figcaption>
                    A brown reaction appears, which reacts the waste produced by the green reaction.
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/3/3.mp4"></video>
                <figcaption>
                    A wide variety of cells coexist, living off the outputs of neighboring reactions.
                </figcaption>
            </div>
        </div>
    </div>
</figure>
<hr class="gridhr"><!--kg-card-end: html--><p>So, we now have a world that supports multi-cell organisms, and they can interact with each other and evolve. But there’s an obvious flaw: every cell is doing the same thing!</p><p>While we have different species performing different reactions, within an organism all the cells are the same. Let’s allow for cells to specialize, and perform different roles.</p><!--kg-card-begin: html--><div class="float-outside">
<div class="float-figure">
<img style="width:100%" src="https://kvfrans.com/static/open/3/cell2.png">
    <figcaption>An organism comprised of different specialized cells. Cells behave differently depending on their pheromone.</figcaption>
</div>
    In this setup, a genome is a decision tree. When a new cell is born, this tree will tell it what reaction to specialize in. Once a cell specializes, it will only perform that reaction.
<br><br>
Crucially, the genome decision tree takes as input a pheromone – four bits that are unique to each individual cell. When a cell makes a child, it can assign it a new pheromone, which can cause the child to specialize in a different reaction.
    <br><br>
Put together, the genome of an organism is like an algorithm that tells the organism how to grow and what its cells should specialize in. Will species evolve to take advantage of multi-cellularity?
</div><!--kg-card-end: html--><!--kg-card-begin: html--><hr class="gridhr">
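<p><em>A toy version of this genome, where a plain lookup table stands in for the decision tree. The two-reaction genome below is hypothetical, chosen only to show how pheromones drive specialization:</em></p><pre><code class="language-python">def grow(genome, pheromone, generations):
    # genome maps a 4-bit pheromone to (reaction, pheromone assigned to children).
    # Follow one lineage of cells and record each generation's specialization.
    lineage = []
    for _ in range(generations):
        reaction, pheromone = genome[pheromone]
        lineage.append(reaction)
    return lineage

# A toy genome whose cells alternate between two synergistic reactions.
genome = {
    (0, 0, 0, 1): ("green", (0, 0, 1, 0)),  # a green cell births a brown child
    (0, 0, 1, 0): ("brown", (0, 0, 0, 1)),  # a brown cell births a green child
}</code></pre>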
<figure class="kg-card kg-gallery-card">
    <div class="kg-gallery-container">
        <div class="kg-gallery-row">
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/3/4.mp4"></video>
                <figcaption>
                    In the beginning, cells consist only of green reactions, living mostly near the bottom.
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/3/5.mp4"></video>
                <figcaption>
                    A species with both green and brown cells is twice as energy-efficient, since the output of green is used to power the brown reaction.
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/3/6.mp4"></video>
                <figcaption>
                    The next generation begins to differentiate, and we see diverse combinations of multi-cell reactions.
                </figcaption>
            </div>
        </div>
        <div class="kg-gallery-row">
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/3/7.mp4"></video>
                <figcaption>
                    Here, the first species develops Killer cells. Killer cells cost energy, and have to be powered by neighboring reaction cells.
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/3/8.mp4"></video>
                <figcaption>
                    Descendants of the first Killer species spread. They develop different base reactions to power themselves.
                </figcaption>
            </div>
            <div class="kg-gallery-image" style="flex: 1 1 0%;">
                <video class="b-lazy b-loaded gridvideo" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/3/9.mp4"></video>
                <figcaption>
                    The world eventually settles into species containing a few reaction cells along with sporadic Killer cells.
                </figcaption>
            </div>
        </div>
    </div>
</figure>
<hr class="gridhr"><!--kg-card-end: html--><p>Through the multi-cell experiment, we see a bunch of species appear that use multiple specialized cells to perform chains of reactions. These reactions usually synergize, with one providing the materials for the other. As the world advances, we also see cells that use their excess energy to power Killer cells and gain space.</p><p>One interesting note are the patterns of cells. With the way genomes work, there’s no real definition of a “pattern” – it all depends on the dynamics of the pheromones. Perhaps a cell with a pheromone of [0001] will produce a child of pheromone [0010], whose own child will loop back to [0001].</p><p>In real-life plants, we often see the Fibonacci series, or spiral patterns. These are patterns that naturally arise when cells grow by replication, since they are easy ways to gain repetition. In our gridworld, we often see checker-board style patterns, which are similarly easy to achieve.</p><p>Further work can involve:</p><ul><li>How can we increase the complexity of inter-species interactions beyong sharing chemicals and/or killing each other? What kinds of complexity lead to novel/interesing behavior?</li><li>Other than grid-based worlds, are there more interesting ways to represent physical organisms? What is the simplest physics system that will allow for movement?</li></ul><!--kg-card-begin: html--><hr class="gridhr">
<br>
<div class="float-outside">
<div class="appendix">
    <div class="float-label">
        Appendix
    </div>
    <video class="b-lazy b-loaded" type="video/mp4" autoplay muted playsinline loop src="https://kvfrans.com/static/open/3/recfastS.mp4"></video>
<br><br><!--kg-card-end: html--><p>Here is a full run of the Multicell world. Shown on the left are average chemical counts throughout the grid, and on the right is a legend for what reactions each color represents.</p><!--kg-card-begin: html--></div><!--kg-card-end: html--></div>]]></content:encoded></item></channel></rss>