<?xml version="1.0" encoding="utf-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
  <title>Cristóbal Alcázar</title>
  <link rel="self" href="https://alkzar.cl/feed.xml"/>
  <link href="https://alkzar.cl/"/>
  <updated>2024-10-02T00:00:00Z</updated>
  <id>https://alkzar.cl/</id>
  <author><name>Cristóbal Alcázar</name></author>
  <entry>
    <title>Reinforcement Learning</title>
    <link href="https://alkzar.cl/posts/reinforcement-learning/"/>
    <id>https://alkzar.cl/posts/reinforcement-learning/</id>
    <published>2024-10-02T00:00:00Z</published>
    <updated>2024-10-02T00:00:00Z</updated>
    <content type="html"><![CDATA[<details>
    <summary><b>Table of Contents</b></summary>
    <ul>
        <li><a href="#the-framework-for-learning-to-act">The Framework for Learning to Act</a></li>
        <li><a href="#policy-optimization">Policy Optimization</a>
            <ul>
                <li><a href="#learning-the-policy">Learning the Policy</a></li>
                <li><a href="#gradient-estimation-via-score-function">Gradient Estimation via Score Function</a></li>
            </ul>
        </li>
        <li><a href="#vanilla-policy-gradient-aka-reinforce">Vanilla Policy Gradient, aka REINFORCE</a></li>
        <li><a href="#actor-critic-methods">Actor-Critic Methods</a></li>
        <li><a href="#references">References</a></li>
    </ul>
</details>
<br>
<p>Reinforcement learning (RL) <a href="http://incompleteideas.net/book/the-book-2nd.html" target="_blank">(Sutton, 1998)</a> is about the interaction between an agent and its environment, where learning occurs through trial and error. The agent observes the current state of the environment, takes actions based on these observations, and thereby influences which states come next, while receiving rewards for its actions. The primary objective is to maximize cumulative reward, which drives the agent's sequence of decisions towards specific goals, such as escaping from a maze, <a href="https://arxiv.org/abs/1312.5602" target="_blank">winning Atari games (Mnih et al., 2013)</a>, or <a href="https://deepmind.google/technologies/alphago/" target="_blank">defeating the world champion of Go (Silver et al., 2016)</a>. But how does the agent learn to act effectively to achieve its goal? RL algorithms answer this question by adjusting the agent's behavior so that the total reward it collects is maximized.</p>
<p>In this post, we will introduce the essential concepts of RL required to implement these agents. We will focus specifically on model-free RL, where the agent learns to act without constructing a model of its environment, as opposed to model-based RL, which involves such modeling. The goal is to design agents that learn to perform well solely by consuming experience from their environment. Building on these fundamentals, we will explore policy optimization methods, such as REINFORCE and PPO, which are used to refine the agent’s behavior.</p>
<p>With the knowledge gained from this post, we will be equipped to set up and implement this framework in popular research environments such as Atari Pong.</p>
<h2>The Framework for Learning to Act</h2>
<p>The starting point for designing agents that learn to act is the Markov Decision Process (MDP) framework (Sutton, 1998). An MDP is a mathematical object that describes the interaction between the agent and the environment. This interaction is characterized by a tuple <span class="math inline"><math display="inline"><mo symmetric="false" stretchy="false">⟨</mo><mrow><mi>𝒮</mi></mrow><mo>,</mo><mrow><mi>𝒜</mi></mrow><mo>,</mo><mi>P</mi><mo>,</mo><mi>R</mi><mo>,</mo><msub><mi>ρ</mi><mrow><mn>0</mn></mrow></msub><mo>,</mo><mi>γ</mi><mo symmetric="false" stretchy="false">⟩</mo></math></span>, where:</p>
<ul>
<li><span class="math inline"><math display="inline"><mrow><mi>𝒮</mi></mrow></math></span>, <strong>state space</strong>, set of possible states in the environment.</li>
<li><span class="math inline"><math display="inline"><mrow><mi>𝒜</mi></mrow></math></span>, <strong>action space</strong>, set of possible actions available to the agent.</li>
<li><span class="math inline"><math display="inline"><mi>P</mi><mo>:</mo><mrow><mi>𝒮</mi></mrow><mo>×</mo><mrow><mi>𝒜</mi></mrow><mo>→</mo><mi mathvariant="normal">Δ</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi>𝒮</mi></mrow><mo symmetric="false" stretchy="false">)</mo></math></span>, <strong>transition probability distribution</strong>, which gives the probability that the environment transitions to the next state <span class="math inline"><math display="inline"><msub><mi>s</mi><mrow><mi>t</mi><mo>+</mo><mn>1</mn></mrow></msub></math></span>, given the current state <span class="math inline"><math display="inline"><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub></math></span> and action <span class="math inline"><math display="inline"><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub></math></span>.</li>
<li><span class="math inline"><math display="inline"><mi>R</mi><mo>:</mo><mrow><mi>𝒮</mi></mrow><mo>×</mo><mrow><mi>𝒜</mi></mrow><mo>→</mo><mrow><mi>ℝ</mi></mrow></math></span>, <strong>reward function</strong>, which provides a scalar feedback signal <span class="math inline"><math display="inline"><msub><mi>r</mi><mrow><mi>t</mi></mrow></msub></math></span> (aka reward) to the agent after taking an action <span class="math inline"><math display="inline"><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub></math></span> and reaching the subsequent state <span class="math inline"><math display="inline"><msub><mi>s</mi><mrow><mi>t</mi><mo>+</mo><mn>1</mn></mrow></msub></math></span>.</li>
<li><span class="math inline"><math display="inline"><msub><mi>ρ</mi><mrow><mn>0</mn></mrow></msub></math></span>, <strong>initial state distribution</strong>, which determines the probability of the agent starting in a particular state.</li>
<li><span class="math inline"><math display="inline"><mi>γ</mi><mo>∈</mo><mrow><mo stretchy="true">[</mo><mn>0</mn><mo>,</mo><mn>1</mn><mo stretchy="true">]</mo></mrow></math></span> is the <strong>discount factor</strong>, which determines the importance of future rewards.</li>
</ul>
<br>
<figure id="fig:mdp-diagram">
  <img src="https://alkzar.cl/img/rl/MDP-diagram.png" alt="Markov Decision Process Diagram">
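  <figcaption><br><small>Figure 1. <b>Left:</b> A loop representation of a Markov Decision Process (MDP). <b>Right:</b> An unrolled MDP depicting an episodic case with a finite horizon <span class="math inline"><math display="inline"><mi>T</mi></math></span> and a parameterized policy <span class="math inline"><math display="inline"><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub></math></span>.</small></figcaption>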
</figure>
<p>Markov Decision Processes generate sequences of state-action pairs, or trajectories <span class="math inline"><math display="inline"><mi>τ</mi></math></span>, starting from an initial state <span class="math inline"><math display="inline"><msub><mi>s</mi><mrow><mn>0</mn></mrow></msub><mo>∼</mo><msub><mi>ρ</mi><mrow><mn>0</mn></mrow></msub></math></span>. The agent's behavior is determined by a policy <span class="math inline"><math display="inline"><mi>π</mi><mo>:</mo><mrow><mi>𝒮</mi></mrow><mo>→</mo><mi mathvariant="normal">Δ</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi>𝒜</mi></mrow><mo symmetric="false" stretchy="false">)</mo></math></span>, which maps states to a probability distribution over actions. An action <span class="math inline"><math display="inline"><msub><mi>a</mi><mrow><mn>0</mn></mrow></msub><mo>∼</mo><mi>π</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mn>0</mn></mrow></msub><mo symmetric="false" stretchy="false">)</mo></math></span> is chosen, leading to the next state <span class="math inline"><math display="inline"><msub><mi>s</mi><mrow><mn>1</mn></mrow></msub></math></span> according to the transition distribution <span class="math inline"><math display="inline"><mi>P</mi></math></span>, and a reward <span class="math inline"><math display="inline"><msub><mi>r</mi><mrow><mn>0</mn></mrow></msub><mo>=</mo><mi>R</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mn>0</mn></mrow></msub><mo>,</mo><msub><mi>a</mi><mrow><mn>0</mn></mrow></msub><mo symmetric="false" stretchy="false">)</mo></math></span> is received. This cycle repeats iteratively, with the agent selecting actions, transitioning through states, and receiving rewards, as shown on the left side of <a href="#fig:mdp-diagram">Figure 1</a>.
Thus, the trajectory <span class="math inline"><math display="inline"><mi>τ</mi></math></span> encapsulates the dynamic sequence of state-action pairs resulting from the agent's interaction with its environment.</p>
<p>The process can either continue indefinitely, known as an infinite horizon, or be divided into episodes that end in a terminal state <span class="math inline"><math display="inline"><msub><mi>s</mi><mrow><mi>T</mi></mrow></msub></math></span>, such as winning or losing a game; the latter are referred to as episodic tasks and are illustrated on the right side of <a href="#fig:mdp-diagram">Figure 1</a>. Note that the transition to the next state depends only on the current state and action, not on the sequence of prior events. This characteristic is known as the <em>Markov property</em>: given the present, the future is conditionally independent of the past (the process is <em>memoryless</em>). In this post, we focus on the episodic setting, where the trajectory begins at <span class="math inline"><math display="inline"><msub><mi>s</mi><mrow><mn>0</mn></mrow></msub></math></span> and concludes at <span class="math inline"><math display="inline"><msub><mi>s</mi><mrow><mi>T</mi></mrow></msub></math></span>, with a finite horizon <span class="math inline"><math display="inline"><mi>T</mi></math></span>. Therefore, the trajectory <span class="math inline"><math display="inline"><mi>τ</mi></math></span> is defined as <span class="math inline"><math display="inline"><mi>τ</mi><mo>=</mo><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mn>0</mn></mrow></msub><mo>,</mo><msub><mi>a</mi><mrow><mn>0</mn></mrow></msub><mo>,</mo><mi>…</mi><mo>,</mo><msub><mi>s</mi><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></msub><mo>,</mo><msub><mi>a</mi><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></msub><mo>,</mo><msub><mi>s</mi><mrow><mi>T</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo></math></span>, summarizing the agent's behavior throughout the episode.</p>
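<p>To make the tuple and the trajectory concrete, here is a minimal Python sketch, assuming NumPy, of a hypothetical three-state, two-action tabular MDP and a rollout with a fixed horizon <i>T</i>. The numbers in <code>P</code>, <code>R</code> and <code>rho0</code> are made up for illustration, the policy is uniform random, and <code>rollout</code> is a helper introduced only for this example, not part of any library.</p>
<pre><code class="code">import numpy as np

# Hypothetical tabular MDP (S, A, P, R, rho0, gamma); all numbers are illustrative only.
n_states, n_actions = 3, 2
P = np.array([  # P[s, a] is a probability distribution over the next state s'
    [[0.9, 0.1, 0.0], [0.1, 0.8, 0.1]],
    [[0.0, 0.9, 0.1], [0.5, 0.0, 0.5]],
    [[0.0, 0.0, 1.0], [0.0, 0.0, 1.0]],
])
R = np.array([[0.0, 0.1], [0.0, 1.0], [0.0, 0.0]])  # R[s, a], the reward function
rho0 = np.array([1.0, 0.0, 0.0])                     # initial state distribution: always start in state 0
gamma = 0.99


def rollout(policy, T, rng):
    """Sample tau = (s_0, a_0, ..., s_{T-1}, a_{T-1}, s_T) plus the rewards r_0, ..., r_{T-1}."""
    states, actions, rewards = [], [], []
    s = rng.choice(n_states, p=rho0)             # s_0 ~ rho_0
    for t in range(T):
        a = rng.choice(n_actions, p=policy[s])   # a_t ~ pi(. | s_t)
        states.append(s)
        actions.append(a)
        rewards.append(R[s, a])                  # r_t = R(s_t, a_t)
        s = rng.choice(n_states, p=P[s, a])      # s_{t+1} ~ P(. | s_t, a_t): depends only on (s_t, a_t)
    states.append(s)                             # final state s_T (the episode stops at horizon T)
    return states, actions, rewards


uniform_policy = np.full((n_states, n_actions), 1.0 / n_actions)
rng = np.random.default_rng(0)
states, actions, rewards = rollout(uniform_policy, T=5, rng=rng)
print(len(states), len(actions), len(rewards))  # 6 5 5
</code></pre>
<p>With <i>T</i> = 5, the rollout yields six states (from s<sub>0</sub> to s<sub>5</sub>) and five actions and rewards, matching the structure of τ. The next state is sampled from <code>P[s, a]</code> alone, which is the Markov property in code.</p>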
<p>In reinforcement learning, the primary goal is for the agent to develop a behavior that maximizes the expected return of its actions in the environment. This goal is formalized through the objective function <span class="math inline"><math display="inline"><msub><mrow><mi>𝒥</mi></mrow><mrow><mtext>RL</mtext></mrow></msub><mo symmetric="false" stretchy="false">(</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo></math></span>, the expected return of the trajectories generated by the policy <span class="math inline"><math display="inline"><mi>π</mi></math></span>. In practice, this expectation is estimated from a collection of sampled trajectories <span class="math inline"><math display="inline"><msubsup><mrow><mo stretchy="false">{</mo><msup><mi>τ</mi><mrow><mo stretchy="false">(</mo><mi>i</mi><mo stretchy="false">)</mo></mrow></msup><mo stretchy="false">}</mo></mrow><mrow><mi>i</mi><mo>=</mo><mn>1</mn></mrow><mi>N</mi></msubsup></math></span>, commonly referred to as "policy rollouts". The term "rollout" describes the process of simulating the agent's behavior in the environment by executing the policy <span class="math inline"><math display="inline"><mi>π</mi></math></span> and observing the resulting trajectory <span class="math inline"><math display="inline"><mi>τ</mi></math></span>. The objective function is defined as follows:</p>
<p><div class="math display"><math display="block"><msub><mrow><mi>𝒥</mi></mrow><mrow><mtext>RL</mtext></mrow></msub><mo>=</mo><munder><mrow><mtext>maximize&nbsp;</mtext></mrow><mrow><mi>π</mi></mrow></munder><msub><mrow><mi>𝔼</mi></mrow><mrow><mi>τ</mi><mo>∼</mo><mi>π</mi></mrow></msub><mrow><mo stretchy="true">[</mo><mi>R</mi><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo><mo stretchy="true">]</mo></mrow></math></div></p>
<p>The return of a trajectory <span class="math inline"><math display="inline"><mi>τ</mi></math></span> is defined as its accumulated discounted reward, <span class="math inline"><math display="inline"><mi>R</mi><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><msubsup><mo movablelimits="false">∑</mo><mrow><mi>t</mi><mo>=</mo><mn>0</mn></mrow><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></msubsup><msup><mi>γ</mi><mrow><mi>t</mi></mrow></msup><msub><mi>r</mi><mrow><mi>t</mi></mrow></msub></math></span>. The reward signals <span class="math inline"><math display="inline"><msub><mi>r</mi><mrow><mi>t</mi></mrow></msub></math></span> are the immediate effect of taking the actions, while the return accumulates the rewards obtained during the trajectory, weighted by the discount factor <span class="math inline"><math display="inline"><mi>γ</mi></math></span>: for <span class="math inline"><math display="inline"><mi>γ</mi><mo>&lt;</mo><mn>1</mn></math></span>, rewards received sooner count more than rewards received later.</p>
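<p>As a small worked example, the following Python sketch (assuming NumPy) computes <i>R</i>(τ) for a few made-up reward sequences and averages them, which is the Monte Carlo estimate of the expected return used by the objective above. The helper name <code>discounted_return</code> and the reward sequences are illustrative assumptions; <i>γ</i> = 0.5 is chosen only so the numbers are exact.</p>
<pre><code class="code">import numpy as np


def discounted_return(rewards, gamma):
    """R(tau) = sum_{t=0}^{T-1} gamma^t * r_t for one trajectory's reward sequence."""
    rewards = np.asarray(rewards, dtype=float)
    discounts = gamma ** np.arange(len(rewards))  # [1, gamma, gamma^2, ...]
    return float(np.sum(discounts * rewards))


# Made-up reward sequences standing in for N = 3 policy rollouts of length T = 3.
rollouts = [[0.0, 0.0, 1.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
returns = [discounted_return(r, gamma=0.5) for r in rollouts]
print(returns)           # [0.25, 1.0, 0.0]: the same reward is worth less when it arrives later
print(np.mean(returns))  # Monte Carlo estimate of E[R(tau)] from these 3 rollouts
</code></pre>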
<h2>Policy Optimization</h2>
<p>In reinforcement learning, there are different approaches to solving the MDP formulated in the previous section, summarized in <a href="#fig:rl-model-free-taxonomy">Figure 2</a>. The most common are value-based methods and policy-based methods. In value-based methods, the agent learns how valuable each state (or state-action pair) is and takes the actions that lead to the most valuable ones. In policy-based methods, the agent learns a policy that directly maps states to actions. In this post, we focus on the latter, specifically on policy gradient methods.</p>
<figure id="fig:rl-model-free-taxonomy">
  <img src="https://alkzar.cl/img/rl/rf-solve-methods-schulman-thesis-img.png" alt="Taxonomy of model-free RL algorithms">
  <figcaption>Figure 2. <b>Illustration of a taxonomy of model-free RL algorithms.</b> Source: <a href="https://rail.eecs.berkeley.edu/deeprlcourse/" target="_blank">Optimizing Expectations: From Deep Reinforcement Learning to Stochastic Computation Graphs by John Schulman (2016)</a>.</figcaption>
</figure>
<p>Other approaches find a policy without solving the MDP, by optimizing the policy parameters directly as a black box. This is the case for derivative-free optimization (DFO) and evolutionary algorithms, in which the policy is parameterized by a vector <span class="math inline"><math display="inline"><mi>θ</mi></math></span> and the agent searches the parameter space, scoring each candidate only by the total return it obtains. Neither the temporal structure of the MDP nor the individual actions are exploited by this kind of method.</p>
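<p>To illustrate the black-box view, here is a minimal sketch of greedy random search, one of the simplest derivative-free methods. It is not an algorithm described in this post: <code>black_box_return</code> is a hypothetical stand-in for the total return of an episode run with parameters θ, and the quadratic shape and <code>TARGET</code> are assumptions made so the example is self-contained. The optimizer only ever sees one scalar per candidate.</p>
<pre><code class="code">import numpy as np

TARGET = np.array([1.0, -2.0])  # hypothetical best parameters, unknown to the optimizer


def black_box_return(theta):
    """Hypothetical stand-in for the total return of one episode run with parameters theta.
    DFO only sees this scalar: no states, actions or per-step rewards."""
    return -float(np.sum((theta - TARGET) ** 2))


def random_search(n_iters=500, sigma=0.1, seed=0):
    """Greedy random search: perturb theta, keep the perturbation if the return does not decrease."""
    rng = np.random.default_rng(seed)
    theta = np.zeros(2)
    best = black_box_return(theta)
    for _ in range(n_iters):
        candidate = theta + sigma * rng.standard_normal(theta.shape)
        score = black_box_return(candidate)
        if score >= best:
            theta, best = candidate, score
    return theta


theta = random_search()
print(theta)  # ends up close to TARGET
</code></pre>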
<p>Policy gradient methods reduce reinforcement learning to stochastic gradient ascent on the expected return (equivalently, stochastic gradient descent on its negative). This connects RL to how function approximation is solved in supervised learning, with the key difference that the dataset is collected by the model itself, and a reward signal acts as the "label".</p>
<h3>Learning the Policy</h3>
<p>The starting point is to think of trajectories, rather than individual observations (i.e., actions), as the units of learning. What dynamics generate a trajectory?
Given a policy <span class="math inline"><math display="inline"><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub></math></span>, represented as a function with parameters <span class="math inline"><math display="inline"><mi>θ</mi><mo>∈</mo><msup><mrow><mi>ℝ</mi></mrow><mrow><mi>d</mi></mrow></msup></math></span>, whose input is a representation of the state and whose output is a vector of action-selection probabilities, we can deploy the agent into its environment at an initial state <span class="math inline"><math display="inline"><msub><mi>s</mi><mn>0</mn></msub></math></span> and observe its actions in inference mode, or <em>evaluation phase</em> (Sutton et al., 1999). The agent keeps selecting actions based on the current state <span class="math inline"><math display="inline"><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub></math></span> until the episode ends in a terminal state, when <span class="math inline"><math display="inline"><mi>t</mi><mo>=</mo><mi>T</mi></math></span>. At this point, we can determine whether the goal was accomplished, such as winning the Atari Pong game, <a href="https://github.com/alcazar90/ddpo-celebahq" target="_blank"><i>or generating aesthetically pleasing samples from a diffusion model</i></a>.
The return is the scalar value that assesses performance, i.e., whether we achieved the ultimate goal, and effectively acts as a proxy "label" for the whole trajectory. Thus, the trajectory serves as our unit of learning, and the remaining task is to establish the feedback mechanism for the <em>learning phase</em>.</p>
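<p>One concrete way to realize "a function with parameters θ whose output is action-selection probabilities" is a tabular softmax policy. The sketch below, in Python with NumPy, is an assumption for illustration (in practice a neural network usually plays this role): θ is a |𝒮| × |𝒜| table of logits, so <i>d</i> = |𝒮|·|𝒜|, and <code>TabularSoftmaxPolicy</code> is a hypothetical class name.</p>
<pre><code class="code">import numpy as np


def softmax(logits):
    z = logits - np.max(logits)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()


class TabularSoftmaxPolicy:
    """pi_theta(a | s) = softmax(theta[s])_a with theta in R^{|S| x |A|}.
    A hypothetical minimal parameterization used only for illustration."""

    def __init__(self, n_states, n_actions, seed=0):
        self.theta = np.zeros((n_states, n_actions))  # all-zero logits: uniform policy
        self.rng = np.random.default_rng(seed)

    def action_probs(self, s):
        return softmax(self.theta[s])

    def act(self, s):
        # Evaluation phase: sample a_t ~ pi_theta(. | s_t); no parameters change here.
        return int(self.rng.choice(len(self.theta[s]), p=self.action_probs(s)))


policy = TabularSoftmaxPolicy(n_states=3, n_actions=2)
policy.theta[0] = [2.0, 0.0]   # make action 0 more likely in state 0
print(policy.action_probs(0))  # approx [0.881 0.119]
print(policy.action_probs(1))  # [0.5 0.5]
</code></pre>
<p>Learning then amounts to changing the entries of <code>theta</code> so that actions from high-return trajectories get larger logits.</p>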
<p>Intuitively, we want to collect trajectories and then make the actions that led to high returns more probable, and the actions that led to low returns less probable.</p>
<p>Mathematically, we aim to perform stochastic optimization to learn the agent’s parameters. This involves obtaining gradient information from sampled trajectories, with performance assessed by a scalar-valued function (i.e., the return). The optimization is stochastic because both the agent and the environment contain elements of randomness, so we can only compute estimates of the gradient. Crucially, we are estimating the gradient of the expected return with respect to the policy parameters. To do so, we employ Monte Carlo gradient estimation (Mohamed et al., 2020), specifically the score function method. From a machine learning perspective, this involves dealing with the stochasticity of the gradient estimates, <span class="math inline"><math display="inline"><mover><mrow><mi>g</mi></mrow><mi>^</mi></mover></math></span>, and using gradient ascent to update the policy parameters based on these estimates, with a learning rate <span class="math inline"><math display="inline"><mi>α</mi></math></span> controlling the step size of the optimization process,</p>
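<p><div class="math display"><math display="block"><mi>θ</mi><mo>←</mo><mi>θ</mi><mo>+</mo><mi>α</mi><msub><mover><mrow><mi>g</mi></mrow><mi>^</mi></mover><mrow><mi>N</mi></mrow></msub><mo>.</mo></math></div></p>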
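<p>The update rule can be sketched in a few lines. In the following Python example (assuming NumPy), the toy objective <i>F</i>(θ) = −(θ − 3)² and the noise model are assumptions: <code>noisy_gradient_estimate</code> is a hypothetical stand-in for ĝ<sub>N</sub> that returns the exact gradient plus zero-mean noise shrinking like 1/√N. It shows that ascending noisy but unbiased estimates still drives θ towards the maximizer.</p>
<pre><code class="code">import numpy as np

rng = np.random.default_rng(0)


def noisy_gradient_estimate(theta, n_samples):
    """Hypothetical stand-in for g_hat_N: exact gradient of F(theta) = -(theta - 3)^2
    plus zero-mean noise with standard deviation 1 / sqrt(N)."""
    exact = -2.0 * (theta - 3.0)
    return exact + rng.standard_normal() / np.sqrt(n_samples)


theta, alpha = 0.0, 0.05
for step in range(200):
    g_hat = noisy_gradient_estimate(theta, n_samples=16)
    theta = theta + alpha * g_hat  # gradient ascent: move along +g_hat
print(round(theta, 2))  # close to 3, with some residual noise
</code></pre>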
<h3>Gradient Estimation via Score Function</h3>
<p>The gradient can be estimated with the score function gradient estimator. Consider the following objective <span class="math inline"><math display="inline"><mrow><mi>ℱ</mi></mrow></math></span>, an expectation over the <a href="https://en.wikipedia.org/wiki/Ambient_space_(mathematics)" target="_blank">ambient space</a> <span class="math inline"><math display="inline"><mrow><mi>𝒳</mi></mrow><mo>⊆</mo><msup><mrow><mi>ℝ</mi></mrow><mi>n</mi></msup></math></span> under a distribution with parameters <span class="math inline"><math display="inline"><mi>θ</mi><mo>∈</mo><msup><mrow><mi>ℝ</mi></mrow><mi>d</mi></msup></math></span>,</p>
<p><div class="math display"><math display="block"><mrow><mi>ℱ</mi></mrow><mo symmetric="false" stretchy="false">(</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><msub><mo movablelimits="false">∫</mo><mrow><mrow><mi>𝒳</mi></mrow></mrow></msub><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi><mo>;</mo><mi mathvariant="normal">θ</mi></mrow><mo symmetric="false" stretchy="false">)</mo><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo symmetric="false" stretchy="false">)</mo><mtext>&nbsp;</mtext><mi>d</mi><mrow><mi mathvariant="normal">x</mi></mrow><mo>=</mo><msub><mrow><mi>𝔼</mi></mrow><mrow><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msub><mrow><mo stretchy="true">[</mo><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo symmetric="false" stretchy="false">)</mo><mo stretchy="true">]</mo></mrow><mo>.</mo></math></div></p>
<p>Here, <span class="math inline"><math display="inline"><mi>f</mi></math></span> is a scalar-valued function, playing the role that the reward plays in the reinforcement learning setting. The <em>score function</em> <span class="math inline"><math display="inline"><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace width="0.1667em" /><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo></math></span> is the gradient of the log-probability with respect to the parameters <span class="math inline"><math display="inline"><mi>θ</mi></math></span>. The following identity connects the score function to the gradient of the distribution <span class="math inline"><math display="inline"><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo></math></span> itself,</p>
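<p><div class="math display"><math display="block"><mtable class="menv-alignlike"><mtr><mtd><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace width="0.1667em" /><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo></mtd><mtd><mo>=</mo><mfrac><mrow><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo></mrow><mrow><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo></mrow></mfrac></mtd></mtr><mtr><mtd></mtd></mtr><mtr><mtd><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace width="0.1667em" /><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo></mtd><mtd><mo>=</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo><mo>.</mo></mtd></mtr></mtable></math></div></p>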
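<p>The identity can be checked numerically. The Python sketch below (assuming NumPy) uses a unit-variance Gaussian <i>p</i>(x; θ) = N(x; θ, 1) as an illustrative choice, whose score with respect to the mean is x − θ, and compares it with a central finite-difference approximation of ∇<sub>θ</sub><i>p</i> divided by <i>p</i>. The helper names are hypothetical.</p>
<pre><code class="code">import numpy as np


def gaussian_pdf(x, theta, sigma=1.0):
    return np.exp(-0.5 * ((x - theta) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))


def score(x, theta, sigma=1.0):
    # Analytic score of N(theta, sigma^2) w.r.t. its mean: d/dtheta log p = (x - theta) / sigma^2
    return (x - theta) / sigma ** 2


x, theta, eps = 0.7, 0.2, 1e-6
# Central finite difference for grad_theta p(x; theta)
grad_p = (gaussian_pdf(x, theta + eps) - gaussian_pdf(x, theta - eps)) / (2 * eps)
lhs = score(x, theta)                  # grad_theta log p(x; theta)
rhs = grad_p / gaussian_pdf(x, theta)  # grad_theta p(x; theta) / p(x; theta)
print(lhs, rhs)                        # both approximately 0.5
</code></pre>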
<p>Therefore, taking the gradient of the objective <span class="math inline"><math display="inline"><mrow><mi>ℱ</mi></mrow><mo symmetric="false" stretchy="false">(</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo></math></span> with respect to the parameter <span class="math inline"><math display="inline"><mi>θ</mi></math></span>, we have</p>
<p><div class="math display"><math display="block"><mtable class="menv-alignlike"><mtr><mtd><mi>g</mi><mo>=</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><msub><mrow><mi>𝔼</mi></mrow><mrow><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msub><mrow><mo stretchy="true">[</mo><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo symmetric="false" stretchy="false">)</mo><mo stretchy="true">]</mo></mrow></mtd><mtd><mo>=</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><msub><mo movablelimits="false">∫</mo><mrow><mrow><mi>𝒳</mi></mrow></mrow></msub><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo symmetric="false" stretchy="false">)</mo><mi>d</mi><mrow><mi mathvariant="normal">x</mi></mrow></mtd></mtr><mtr><mtd></mtd></mtr><mtr><mtd></mtd><mtd><mo>=</mo><msub><mo movablelimits="false">∫</mo><mrow><mi>𝒳</mi></mrow></msub><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mtext>&nbsp;</mtext><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo symmetric="false" stretchy="false">)</mo><mi>d</mi><mrow><mi mathvariant="normal">x</mi></mrow></mtd></mtr><mtr><mtd></mtd></mtr><mtr><mtd></mtd><mtd><mo>=</mo><msub><mo movablelimits="false">∫</mo><mrow><mrow><mi>𝒳</mi></mrow></mrow></msub><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" 
stretchy="false">)</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace width="0.1667em" /><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo symmetric="false" stretchy="false">)</mo><mi>d</mi><mrow><mi mathvariant="normal">x</mi></mrow></mtd></mtr><mtr><mtd></mtd></mtr><mtr><mtd></mtd><mtd><mo>=</mo><msub><mrow><mi>𝔼</mi></mrow><mrow><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msub><mrow><mo stretchy="true">[</mo><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo symmetric="false" stretchy="false">)</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace width="0.1667em" /><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo><mo stretchy="true">]</mo></mrow></mtd></mtr></mtable></math></div></p>
<p>The use of the log-derivative rule on the above equation to introduce the score function is also known as the <a href="https://blog.shakirm.com/2015/11/machine-learning-trick-of-the-day-5-log-derivative-trick/" target="_blank"><em>log-derivative trick</em></a>. Now, we can compute an estimate of the gradient, <span class="math inline"><math display="inline"><mover><mrow><mi>g</mi></mrow><mi>^</mi></mover></math></span>, using Monte Carlo estimation with samples from the distribution <span class="math inline"><math display="inline"><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo></math></span> as follows:</p>
<p><div class="math display"><math display="block"><msub><mover><mrow><mi>g</mi></mrow><mi>^</mi></mover><mrow><mi>N</mi></mrow></msub><mo>=</mo><mfrac><mrow><mn>1</mn></mrow><mrow><mi>N</mi></mrow></mfrac><munderover><mo movablelimits="false">∑</mo><mrow><mi>i</mi><mo>=</mo><mn>1</mn></mrow><mrow><mi>N</mi></mrow></munderover><mi>f</mi><mrow><mo stretchy="true">(</mo><msup><mover><mrow><mrow><mi mathvariant="normal">x</mi></mrow></mrow><mi>^</mi></mover><mrow><mo symmetric="false" stretchy="false">(</mo><mi>i</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msup><mo stretchy="true">)</mo></mrow><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace width="0.1667em" /><mi>p</mi><mrow><mo stretchy="true">(</mo><msup><mover><mrow><mrow><mi mathvariant="normal">x</mi></mrow></mrow><mi>^</mi></mover><mrow><mo symmetric="false" stretchy="false">(</mo><mi>i</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msup><mo>;</mo><mi>θ</mi><mo stretchy="true">)</mo></mrow><mo>.</mo></math></div></p>
<p>We draw <span class="math inline"><math display="inline"><mi>N</mi></math></span> samples <span class="math inline"><math display="inline"><mover><mrow><mrow><mi mathvariant="normal">x</mi></mrow></mrow><mi>^</mi></mover><mo>∼</mo><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo></math></span>, compute the gradient of the log-probability for each sample, and multiply by the scalar-valued function <span class="math inline"><math display="inline"><mi>f</mi></math></span> evaluated at the sample. The average of these terms is an unbiased estimate of the gradient of the objective <span class="math inline"><math display="inline"><mi>g</mi></math></span>, which we can use for gradient ascent.</p>
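<p>The estimator ĝ<sub>N</sub> translates almost line by line into code. In this Python sketch (assuming NumPy), the choices <i>p</i>(x; θ) = N(x; θ, 1) and <i>f</i>(x) = x² are illustrative assumptions picked because the exact answer is known: <i>ℱ</i>(θ) = θ² + 1, so <i>g</i> = 2θ. Note that <code>f</code> is only evaluated at the samples, never differentiated.</p>
<pre><code class="code">import numpy as np


def score_function_gradient(f, theta, n_samples, rng):
    """Score-function (log-derivative) estimate of d/dtheta E_{x ~ N(theta, 1)}[f(x)]."""
    x = rng.normal(loc=theta, scale=1.0, size=n_samples)  # x_hat^(i) ~ p(x; theta)
    score = x - theta                                     # grad_theta log N(x; theta, 1)
    return float(np.mean(f(x) * score))                   # (1/N) sum_i f(x_i) * score_i


def f(x):
    return x ** 2  # F(theta) = E[x^2] = theta^2 + 1, so the exact gradient is 2 * theta


rng = np.random.default_rng(0)
theta = 1.5
g_hat = score_function_gradient(f, theta, n_samples=100_000, rng=rng)
print(g_hat, 2 * theta)  # Monte Carlo estimate vs. exact value 3.0
</code></pre>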
<p>There are two important points to mention about the previous equation.</p>
<ul>
<li>The function <span class="math inline"><math display="inline"><mi>f</mi></math></span> can be any arbitrary function we can evaluate on <span class="math inline"><math display="inline"><mrow><mi mathvariant="normal">x</mi></mrow></math></span>. Even if <span class="math inline"><math display="inline"><mi>f</mi></math></span> is non-differentiable with respect to <span class="math inline"><math display="inline"><mi>θ</mi></math></span>, it can still be used to compute the gradient estimate <span class="math inline"><math display="inline"><mover><mrow><mi>g</mi></mrow><mi>^</mi></mover></math></span>.</li>
<li>The expectation of the score function is zero:</li>
</ul>
<p><div class="math display"><math display="block"><mtable class="menv-alignlike"><mtr><mtd><msub><mrow><mi>𝔼</mi></mrow><mrow><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msub><mrow><mo stretchy="true">[</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace width="0.1667em" /><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo><mo stretchy="true">]</mo></mrow></mtd><mtd><mo>=</mo><msub><mo movablelimits="false">∫</mo><mrow><mrow><mi>𝒳</mi></mrow></mrow></msub><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace width="0.1667em" /><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo><mi>d</mi><mrow><mi mathvariant="normal">x</mi></mrow></mtd></mtr><mtr><mtd></mtd></mtr><mtr><mtd></mtd><mtd><mo>=</mo><msub><mo movablelimits="false">∫</mo><mrow><mrow><mi>𝒳</mi></mrow></mrow></msub><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo><mfrac><mrow><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo></mrow><mrow><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" 
stretchy="false">)</mo></mrow></mfrac><mi>d</mi><mrow><mi mathvariant="normal">x</mi></mrow></mtd></mtr><mtr><mtd></mtd></mtr><mtr><mtd></mtd><mtd><mo>=</mo><msub><mo movablelimits="false">∫</mo><mrow><mrow><mi>𝒳</mi></mrow></mrow></msub><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo><mi>d</mi><mrow><mi mathvariant="normal">x</mi></mrow></mtd></mtr><mtr><mtd></mtd></mtr><mtr><mtd></mtd><mtd><mo>=</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><msub><mo movablelimits="false">∫</mo><mrow><mrow><mi>𝒳</mi></mrow></mrow></msub><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo><mi>d</mi><mrow><mi mathvariant="normal">x</mi></mrow><mo>=</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mn>1</mn><mo>=</mo><mn>0</mn><mo>.</mo></mtd></mtr></mtable></math></div></p>
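<p>Both properties can be observed by sampling. The Python sketch below (assuming NumPy, again with the illustrative choice <i>p</i>(x; θ) = N(x; θ, 1)) first checks that the sample mean of the score is close to zero, then applies the estimator to a non-differentiable step function <i>f</i>(x) = 1 if x &gt; 0, else 0. For this <i>f</i>, <i>ℱ</i>(θ) = Φ(θ), the standard normal CDF, so the exact gradient is the standard normal density φ(θ).</p>
<pre><code class="code">import numpy as np
from math import exp, pi, sqrt  # standard library, used only for the exact reference value

rng = np.random.default_rng(1)
theta, n = 0.3, 200_000
x = rng.normal(loc=theta, scale=1.0, size=n)  # x ~ N(theta, 1)
score = x - theta                             # grad_theta log p(x; theta)

# 1) The score has zero mean: E[grad_theta log p(x; theta)] = 0.
print(score.mean())  # close to 0

# 2) A non-differentiable f: the step function (1 for x above 0, else 0).
#    F(theta) = P(x above 0) = Phi(theta), so the exact gradient is phi(theta).
f_x = np.heaviside(x, 0.0)
g_hat = float(np.mean(f_x * score))
exact = exp(-theta ** 2 / 2) / sqrt(2 * pi)
print(g_hat, exact)  # estimate vs. phi(0.3), about 0.381
</code></pre>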
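<p>This result is particularly useful: since the score function has zero expectation, we can replace <span class="math inline"><math display="inline"><mi>f</mi></math></span> with the shifted version <span class="math inline"><math display="inline"><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo symmetric="false" stretchy="false">)</mo><mo>−</mo><mi>β</mi></math></span> for any constant <span class="math inline"><math display="inline"><mi>β</mi></math></span> and still obtain an unbiased estimate of the gradient, which can make the optimization better behaved:</p>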
<p><div class="math display"><math display="block"><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><msub><mrow><mi>𝔼</mi></mrow><mrow><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msub><mo symmetric="false" stretchy="false">[</mo><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo symmetric="false" stretchy="false">)</mo><mo symmetric="false" stretchy="false">]</mo><mo>=</mo><msub><mrow><mi>𝔼</mi></mrow><mrow><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msub><mrow><mo stretchy="true">[</mo><mo symmetric="false" stretchy="false">(</mo><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo symmetric="false" stretchy="false">)</mo><mo>−</mo><mi>β</mi><mo symmetric="false" stretchy="false">)</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace width="0.1667em" /><mi>p</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo><mo stretchy="true">]</mo></mrow><mo>.</mo></math></div></p>
<p>Choosing <span class="math inline"><math display="inline"><mi>β</mi></math></span> with a <em><strong>baseline function</strong></em> can reduce the variance of the estimator (Mohamed et al., 2020). Because the score function has zero expectation, subtracting the baseline leaves the expected gradient unchanged, so the baseline can be any quantity that does not depend on the sample <span class="math inline"><math display="inline"><mrow><mi mathvariant="normal">x</mi></mrow></math></span> and that is treated as a constant with respect to <span class="math inline"><math display="inline"><mi>θ</mi></math></span> when differentiating. When the baseline is close to the typical values of the scalar-valued function <span class="math inline"><math display="inline"><mi>f</mi></math></span>, for example its mean, the centered term <span class="math inline"><math display="inline"><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo symmetric="false" stretchy="false">)</mo><mo>−</mo><mi>β</mi></math></span> is small and the variance of the estimator often drops. Lower variance means less noisy gradient estimates, which stabilizes the updates and leads to more reliable and efficient learning.</p>
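<p>To make the effect of a baseline concrete, here is a minimal NumPy sketch under assumptions chosen only for illustration (they are not from the text): <code>p(x; θ)</code> is a unit-variance Gaussian 𝒩(θ, 1) and <code>f(x) = x²</code>, so the exact gradient is ∇θ 𝔼[x²] = ∇θ(θ² + 1) = 2θ and the score is x − θ. It compares the per-sample score-function terms with β = 0 and with β = θ² + 1, the exact mean of f; in practice that mean is unknown and would have to be estimated. The helper name <code>score_function_terms</code> is hypothetical.</p>

```python
import numpy as np

def score_function_terms(theta, n_samples, beta=0.0, seed=0):
    """Per-sample (f(x) - beta) * d/dtheta log p(x; theta) for p = N(theta, 1), f(x) = x**2."""
    rng = np.random.default_rng(seed)
    x = rng.normal(loc=theta, scale=1.0, size=n_samples)
    f = x ** 2
    score = x - theta  # d/dtheta log N(x; theta, 1)
    return (f - beta) * score

theta = 2.0
true_grad = 2 * theta  # d/dtheta E[x^2] = d/dtheta (theta^2 + 1)

# Same seed, so both estimators see identical samples x.
plain = score_function_terms(theta, 200_000)                        # beta = 0
centered = score_function_terms(theta, 200_000, beta=theta**2 + 1)  # beta = E[f(x)]

print("true gradient:", true_grad)
print("means:", plain.mean(), centered.mean())
print("variances:", plain.var(), centered.var())
```

Both sample means should land near 2θ = 4, while the centered terms have a smaller variance (analytically 42 versus 87 for θ = 2): the baseline changes the spread of the estimator, not its mean.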
<h2>Vanilla Policy Gradient, aka REINFORCE</h2>
<p>The REINFORCE algorithm (Williams, 1992) translates the score-function gradient estimator derived above into reinforcement learning terms. It is the earliest member of the Policy Gradient family (see the <a href="#fig:rl-model-free-taxonomy">taxonomy of model-free RL methods</a>), whose objective is to maximize the expected return of a trajectory <span class="math inline"><math display="inline"><mi>τ</mi></math></span> under a policy <span class="math inline"><math display="inline"><mi>π</mi></math></span> parameterized by <span class="math inline"><math display="inline"><mi>θ</mi></math></span> (e.g., the weights of a neural network). At each state <span class="math inline"><math display="inline"><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub></math></span>, the agent samples an action <span class="math inline"><math display="inline"><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub></math></span> from the policy <span class="math inline"><math display="inline"><mi>π</mi></math></span>, which defines a probability distribution over actions <span class="math inline"><math display="inline"><mi>π</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo>∣</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo></math></span>. From here on, we write <span class="math inline"><math display="inline"><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><mi>⋅</mi><mo symmetric="false" stretchy="false">)</mo></math></span> instead of <span class="math inline"><math display="inline"><mi>π</mi><mo symmetric="false" stretchy="false">(</mo><mi>⋅</mi><mo>;</mo><mi>θ</mi><mo symmetric="false" stretchy="false">)</mo></math></span>.</p>
<p>As mentioned in the previous section, a trajectory <span class="math inline"><math display="inline"><mi>τ</mi></math></span> is the sequence of state-action pairs produced by the agent's interaction with its environment. From the initial state <span class="math inline"><math display="inline"><msub><mi>s</mi><mrow><mn>0</mn></mrow></msub></math></span> to the terminal state <span class="math inline"><math display="inline"><msub><mi>s</mi><mrow><mi>T</mi></mrow></msub></math></span>, it is written as <span class="math inline"><math display="inline"><mi>τ</mi><mo>=</mo><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mn>0</mn></mrow></msub><mo>,</mo><msub><mi>a</mi><mrow><mn>0</mn></mrow></msub><mo>,</mo><mi>…</mi><mo>,</mo><msub><mi>s</mi><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></msub><mo>,</mo><msub><mi>a</mi><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></msub><mo>,</mo><msub><mi>s</mi><mrow><mi>T</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo></math></span>, and it describes how the agent acts during an episodic task. Let <span class="math inline"><math display="inline"><msub><mi>p</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo></math></span> be the probability of obtaining this trajectory under the policy <span class="math inline"><math display="inline"><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub></math></span>.</p>
<p>This defines a distribution over trajectories. The trajectory <span class="math inline"><math display="inline"><mi>τ</mi></math></span> is the learning unit for the policy <span class="math inline"><math display="inline"><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub></math></span>: it tells us whether the actions taken along the way led to a favorable outcome at the terminal state <span class="math inline"><math display="inline"><msub><mi>s</mi><mrow><mi>T</mi></mrow></msub></math></span> (e.g., win or lose). The goal is to maximize the expected return over trajectories, where the return <span class="math inline"><math display="inline"><mi>R</mi><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo></math></span> can be the sum of rewards collected during the <em><strong>episode</strong></em> or the discounted sum of rewards. The expected return is:</p>
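<p>The two common choices of return can be written as one small function, sketched here in Python; the name <code>trajectory_return</code>, the discount factor <code>gamma</code>, and the reward values are illustrative assumptions rather than anything defined in the text. Setting <code>gamma=1.0</code> gives the plain sum of rewards in the episode.</p>

```python
def trajectory_return(rewards, gamma=1.0):
    """R(tau) = sum_t gamma**t * r_t; gamma=1.0 gives the undiscounted episode return."""
    return sum(gamma ** t * r for t, r in enumerate(rewards))

rewards = [1.0, 0.0, 2.0]  # hypothetical rewards r_0, r_1, r_2 from one episode
print(trajectory_return(rewards))             # 3.0
print(trajectory_return(rewards, gamma=0.5))  # 1.0 + 0.0 + 0.25 * 2.0 = 1.5
```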
<p><div class="math display"><math display="block"><mrow><mi>𝒥</mi></mrow><mo symmetric="false" stretchy="false">(</mo><mi>θ</mi><msub><mo symmetric="false" stretchy="false">)</mo><mrow><mtext>RL</mtext></mrow></msub><mo>=</mo><msub><mrow><mi>𝔼</mi></mrow><mrow><mi>τ</mi><mo>∼</mo><msub><mi>p</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msub><mrow><mo stretchy="true">[</mo><mi>R</mi><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo><mo stretchy="true">]</mo></mrow><mo>.</mo></math></div></p>
<p>This is the objective we want to maximize. It is a particular case of the general probabilistic objective from the previous section, with the scalar-valued function <span class="math inline"><math display="inline"><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi mathvariant="normal">x</mi></mrow><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><mi>R</mi><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo></math></span> given by the return of the trajectory. Applying the score-function technique from the previous section, the gradient of this objective with respect to the policy parameter <span class="math inline"><math display="inline"><mi>θ</mi></math></span> is:</p>
<p><div class="math display"><math display="block"><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><msub><mrow><mi>𝔼</mi></mrow><mrow><mi>τ</mi><mo>∼</mo><msub><mi>p</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msub><mo symmetric="false" stretchy="false">[</mo><mi>R</mi><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo><mo symmetric="false" stretchy="false">]</mo><mo>=</mo><msub><mrow><mi>𝔼</mi></mrow><mrow><mi>τ</mi><mo>∼</mo><msub><mi>p</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msub><mrow><mo stretchy="true">[</mo><mi>R</mi><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace width="0.1667em" /><msub><mi>p</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo><mo stretchy="true">]</mo></mrow><mo>.</mo></math></div></p>
<p>What is <span class="math inline"><math display="inline"><msub><mi>p</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo></math></span> exactly? Given that the trajectory is a sequence of states and actions, and assuming the Markov property imposed by the MDP, the probability of the trajectory is defined as follows:</p>
<p><div class="math display"><math display="block"><mtable class="menv-alignlike"><mtr><mtd><msub><mi>p</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo></mtd><mtd><mo>=</mo><msub><mi>p</mi><mi>θ</mi></msub><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mn>0</mn></mrow></msub><mo>,</mo><msub><mi>a</mi><mrow><mn>0</mn></mrow></msub><mo>,</mo><msub><mi>s</mi><mrow><mn>1</mn></mrow></msub><mo>,</mo><msub><mi>a</mi><mrow><mn>1</mn></mrow></msub><mo>,</mo><mi>…</mi><mo>,</mo><msub><mi>s</mi><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></msub><mo>,</mo><msub><mi>a</mi><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></msub><mo>,</mo><msub><mi>s</mi><mrow><mi>T</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo></mtd></mtr><mtr><mtd></mtd></mtr><mtr><mtd></mtd><mtd><mo>=</mo><mi>ρ</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mn>0</mn></msub><mo symmetric="false" stretchy="false">)</mo><mtext>&nbsp;</mtext><munderover><mo movablelimits="false">∏</mo><mrow><mi>t</mi><mo>=</mo><mn>0</mn></mrow><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></munderover><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo>∣</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mtext>&nbsp;</mtext><mi>P</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mi>t</mi><mo>+</mo><mn>1</mn></mrow></msub><mo>,</mo><msub><mi>r</mi><mrow><mi>t</mi></mrow></msub><mo>∣</mo><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo>,</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mo>.</mo></mtd></mtr></mtable></math></div></p>
<p>In the above expression, <span class="math inline"><math display="inline"><mi>ρ</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mn>0</mn></mrow></msub><mo symmetric="false" stretchy="false">)</mo></math></span> denotes the distribution of initial states, and <span class="math inline"><math display="inline"><mi>P</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mi>t</mi><mo>+</mo><mn>1</mn></mrow></msub><mo>,</mo><msub><mi>r</mi><mrow><mi>t</mi></mrow></msub><mo>∣</mo><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo>,</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo></math></span> is the transition model, which gives the probability of the next state and reward after taking action <span class="math inline"><math display="inline"><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub></math></span> in the current state <span class="math inline"><math display="inline"><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub></math></span>. To estimate the gradient, we first take the logarithm of the trajectory probability, which turns the product into a sum, and then differentiate with respect to the policy parameter <span class="math inline"><math display="inline"><mi>θ</mi></math></span>:</p>
<p>\[
\begin{aligned}
\log p_{\theta}(\tau) &amp;= \log \rho(s_0) + \sum_{t=0}^{T-1}\Big[\log \pi_{\theta}(a_{t}\mid s_{t}) + \log P(s_{t+1}, r_{t}\mid a_{t}, s_{t})\Big] \\
\nabla_{\theta}\log p_{\theta}(\tau) &amp;= \nabla_{\theta}\log \rho(s_0) + \sum_{t=0}^{T-1}\Big[\nabla_{\theta}\log \pi_{\theta}(a_{t}\mid s_{t}) + \nabla_{\theta}\log P(s_{t+1}, r_{t}\mid a_{t}, s_{t})\Big] \\
&amp;= \sum_{t=0}^{T-1}\nabla_{\theta}\log \pi_{\theta}(a_{t}\mid s_{t}).
\end{aligned}
\]</p>
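<p>This result can be sanity-checked numerically. The NumPy sketch below assumes a hypothetical tabular softmax policy with logits <code>theta[s, a]</code> (3 states, 2 actions) and a made-up trajectory. It differentiates the full log p<sub>θ</sub>(τ), including a fixed log ρ(s<sub>0</sub>) + Σ log P term, by central finite differences, and compares the result with Σ<sub>t</sub> ∇θ log πθ(a<sub>t</sub>∣s<sub>t</sub>), which for a softmax policy is one_hot(a<sub>t</sub>) − πθ(·∣s<sub>t</sub>) placed in row s<sub>t</sub>. All function names are illustrative.</p>

```python
import numpy as np

def log_softmax(logits):
    """Numerically stable log of softmax(logits)."""
    m = logits.max()
    return logits - m - np.log(np.exp(logits - m).sum())

def log_p_traj(theta, states, actions, log_dynamics):
    """log p_theta(tau): policy terms plus a fixed log rho(s_0) + sum_t log P term."""
    log_pi = sum(log_softmax(theta[s])[a] for s, a in zip(states, actions))
    return log_pi + log_dynamics

def grad_log_p_traj(theta, states, actions):
    """Analytic score sum_t grad log pi(a_t | s_t); the dynamics never appear."""
    g = np.zeros_like(theta)
    for s, a in zip(states, actions):
        g[s] -= np.exp(log_softmax(theta[s]))  # -pi(. | s_t)
        g[s, a] += 1.0                          # + one_hot(a_t)
    return g

rng = np.random.default_rng(0)
theta = rng.normal(size=(3, 2))               # hypothetical logits: 3 states x 2 actions
states, actions = [0, 1, 1, 2], [1, 0, 1, 1]  # hypothetical (s_t, a_t) for t = 0..3
log_dynamics = np.log(0.3) + np.log(0.8)      # arbitrary values that do not depend on theta

# Central finite differences of the full log-probability, dynamics included.
eps = 1e-6
numeric = np.zeros_like(theta)
for idx in np.ndindex(*theta.shape):
    step = np.zeros_like(theta)
    step[idx] = eps
    numeric[idx] = (log_p_traj(theta + step, states, actions, log_dynamics)
                    - log_p_traj(theta - step, states, actions, log_dynamics)) / (2 * eps)

analytic = grad_log_p_traj(theta, states, actions)
print("max abs difference:", np.max(np.abs(numeric - analytic)))
```

The difference is at the level of finite-difference error, and changing <code>log_dynamics</code> to any other constant leaves both gradients unchanged.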
<p>The initial-state distribution and the transition probabilities drop out because they do not depend on <span class="math inline"><math display="inline"><mi>θ</mi></math></span>, so their gradients are zero. This simplifies gradient estimation considerably: we never need to know or differentiate the environment dynamics. Substituting this last expression into the gradient of the RL objective above yields the REINFORCE gradient and its sample-based estimator:</p>
<p><div class="math display"><math display="block"><mtable class="menv-alignlike"><mtr><mtd><mi>g</mi></mtd><mtd><mo>=</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><msub><mrow><mi>𝔼</mi></mrow><mrow><mi>τ</mi><mo>∼</mo><msub><mi>p</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msub><mo symmetric="false" stretchy="false">[</mo><mi>R</mi><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo><mo symmetric="false" stretchy="false">]</mo></mtd></mtr><mtr><mtd></mtd></mtr><mtr><mtd></mtd><mtd><mo>=</mo><msub><mrow><mi>𝔼</mi></mrow><mrow><mi>τ</mi><mo>∼</mo><msub><mi>p</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msub><mrow><mo stretchy="true">[</mo><munderover><mo movablelimits="false">∑</mo><mrow><mi>t</mi><mo>=</mo><mn>0</mn></mrow><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></munderover><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace width="0.1667em" /><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mi>t</mi></msub><mo>∣</mo><msub><mi>s</mi><mi>t</mi></msub><mo symmetric="false" stretchy="false">)</mo><mi>R</mi><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo><mo stretchy="true">]</mo></mrow></mtd></mtr><mtr><mtd></mtd></mtr><mtr><mtd><mover><mrow><mi>g</mi></mrow><mi>^</mi></mover></mtd><mtd><mo>=</mo><mfrac><mrow><mn>1</mn></mrow><mrow><mo>∣</mo><msup><mrow><mi>𝒟</mi></mrow><mrow><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub></mrow></msup><mo>∣</mo></mrow></mfrac><munder><mo movablelimits="false">∑</mo><mrow><mi>τ</mi><mo>∈</mo><msup><mrow><mi>𝒟</mi></mrow><mrow><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub></mrow></msup></mrow></munder><mrow><mo 
stretchy="true">[</mo><mtext>&nbsp;</mtext><munderover><mo movablelimits="false">∑</mo><mrow><mi>t</mi><mo>=</mo><mn>0</mn></mrow><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></munderover><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace width="0.1667em" /><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo>∣</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mi>R</mi><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo><mo stretchy="true">]</mo></mrow><mo>.</mo></mtd></mtr></mtable></math></div></p>
<p>The core idea is to collect a set of trajectories <span class="math inline"><math display="inline"><msup><mrow><mi>𝒟</mi></mrow><mrow><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub></mrow></msup></math></span> under the policy <span class="math inline"><math display="inline"><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub></math></span> and update the policy parameters <span class="math inline"><math display="inline"><mi>θ</mi></math></span> to increase the likelihood of high-return trajectories and decrease the likelihood of low-return ones, as illustrated in the <a href="#fig:anatomy-rl-trajectories">figure of simulated trajectories</a> below. This trial-and-error procedure, described in <a href="#alg:reinforce">Algorithm 1</a>, repeats over many iterations, reinforcing successful trajectories and discouraging unsuccessful ones, so that the agent's behavior becomes encoded in its parameters.</p>
<!-- naive REINFORCE algorithm -->
<div id="alg:reinforce">
    <big><b>Algorithm 1: Vanilla Policy Gradient, aka REINFORCE</b></big>
    <ol>
        <li>Initialize policy \( \pi_{\theta} \), set learning rate \( \alpha \)</li>
        <li>For \( \text{iteration}=0, 1, 2, \dots, N \):
            <ol>
                <li>Collect a set of trajectories \( \mathcal{D}^{\pi_{\theta}}=\{\tau^{(i)}\} \) by sampling from the current policy \( \pi_{\theta} \)</li>
                <li>Calculate the returns \( R(\tau) \) for each trajectory \( \tau\in\mathcal{D}^{\pi_{\theta}} \)</li>
                <li>Update the policy: \( \theta \leftarrow \theta + \alpha \left(\frac{1}{|\mathcal{D}^{\pi_{\theta}}|}\sum_{\tau\in\mathcal{D}^{\pi_{\theta}}}\left[\sum_{t=0}^{T-1}\nabla_{\theta}\log\pi_{\theta}(a_{t}| s_{t})R(\tau)\right]\right) \)</li>
            </ol>
        </li>
    </ol>
</div>
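<p>The loop in Algorithm 1 can be sketched in NumPy as follows. Everything about the environment is a hypothetical toy chosen for illustration: three states (0, 1, 2), the agent starts in state 0, action 1 moves right, action 0 moves left (staying put in state 0), reaching state 2 ends the episode with return 1, and otherwise the episode stops after 6 steps with return 0. The policy is a tabular softmax with logits <code>theta[s, a]</code>, so ∇θ log πθ(a∣s) is one_hot(a) − πθ(·∣s) in row s. The hyperparameters (<code>n_iters</code>, <code>batch_size</code>, <code>lr</code>) are arbitrary.</p>

```python
import numpy as np

GOAL, HORIZON = 2, 6  # hypothetical chain: states 0, 1, 2 (goal); at most 6 steps
RIGHT = 1             # action 1 moves right, action 0 moves left

def softmax(logits):
    z = np.exp(logits - logits.max())
    return z / z.sum()

def sample_trajectory(theta, rng):
    """Roll out one episode with pi_theta; return the (s_t, a_t) pairs and R(tau)."""
    s, pairs = 0, []
    for _ in range(HORIZON):
        a = rng.choice(2, p=softmax(theta[s]))
        pairs.append((s, a))
        s = s + 1 if a == RIGHT else max(s - 1, 0)
        if s == GOAL:
            return pairs, 1.0  # reached the goal: R(tau) = 1
    return pairs, 0.0          # ran out of steps: R(tau) = 0

def grad_log_pi(theta, s, a):
    """Gradient of log softmax(theta[s])[a] w.r.t. the full logit table."""
    g = np.zeros_like(theta)
    g[s] = -softmax(theta[s])
    g[s, a] += 1.0
    return g

def reinforce(n_iters=300, batch_size=32, lr=0.5, seed=0):
    rng = np.random.default_rng(seed)
    theta = np.zeros((GOAL, 2))  # logits for non-terminal states 0 and 1 (uniform policy)
    for _ in range(n_iters):
        g_hat = np.zeros_like(theta)
        for _ in range(batch_size):       # collect D with the current policy
            pairs, ret = sample_trajectory(theta, rng)
            for s, a in pairs:            # sum_t grad log pi(a_t | s_t) * R(tau)
                g_hat += grad_log_pi(theta, s, a) * ret
        theta += lr * g_hat / batch_size  # gradient ascent on J(theta)
    return theta

theta = reinforce()
p_right = [softmax(theta[s])[RIGHT] for s in range(GOAL)]
print("P(right | s) after training:", p_right)
```

With these settings the probability of moving right should approach 1 in both non-terminal states. Since the return is never negative here, only successful episodes contribute to <code>g_hat</code>, which is one source of the high variance that the techniques in the next paragraphs address.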
<p><strong>Reducing the variance of the estimator</strong>. Two techniques,
<a href="https://spinningup.openai.com/en/latest/spinningup/rl_intro3.html#don-t-let-the-past-distract-you" target="_blank">reward-to-go</a> and <em>baselines</em>, reduce the variance of the REINFORCE gradient estimator above without biasing it.</p>
<figure id="fig:anatomy-rl-trajectories" style="text-align: center;">
  <img src="https://alkzar.cl/img/rl/simulated-trajectories-levine-slides.png" alt="Simulated trajectories levine slides">
  <figcaption>
    <b>Illustration of three simulated trajectories</b>, denoted as \(\{\tau^{(i)}\}\) with \(i = 1, 2, 3\), traversing the parametric space \(\theta\in\mathbb{R}^2\) under the policy \(\pi_{\theta}\). Each trajectory is marked with a colored symbol (cross or check) indicating its <em>goodness</em> according to its return \(R(\tau^{(i)})\). <b>Source:</b> <a href="https://rail.eecs.berkeley.edu/deeprlcourse/" target="_blank">Policy Gradients Lecture, Deep Reinforcement Learning Course</a> by Sergey Levine.
  </figcaption>
</figure>
<p>The reward-to-go technique is a simple trick that reduces the variance of the gradient estimator by exploiting the <em>temporal structure</em> of the problem. The idea is to weight the gradient of the log-probability of an action <span class="math inline"><math display="inline"><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub></math></span> by the sum of rewards from the current timestep <span class="math inline"><math display="inline"><mi>t</mi></math></span> to the end of the trajectory <span class="math inline"><math display="inline"><mi>T</mi><mo>−</mo><mn>1</mn></math></span>, instead of by the full return. An action cannot influence rewards received before it was taken, so those earlier rewards only add noise: their contribution to the gradient has zero expectation, and dropping them keeps only the consequences of <span class="math inline"><math display="inline"><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub></math></span> on future rewards. To see this, start from the REINFORCE estimator above and write <span class="math inline"><math display="inline"><mi>R</mi><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo></math></span> explicitly as the sum of all rewards in the trajectory (the same argument applies to discounted returns and other definitions of <span class="math inline"><math display="inline"><mi>R</mi><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo></math></span>):</p>
<p><div class="math display"><math display="block"><mtable class="menv-alignlike"><mtr><mtd><mover><mrow><mi>g</mi></mrow><mi>^</mi></mover></mtd><mtd><mo>=</mo><mfrac><mrow><mn>1</mn></mrow><mrow><mo>∣</mo><msup><mrow><mi>𝒟</mi></mrow><mrow><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub></mrow></msup><mo>∣</mo></mrow></mfrac><munder><mo movablelimits="false">∑</mo><mrow><mi>τ</mi><mo>∈</mo><msup><mrow><mi>𝒟</mi></mrow><mrow><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub></mrow></msup></mrow></munder><mrow><mo stretchy="true">[</mo><munderover><mo movablelimits="false">∑</mo><mrow><mi>t</mi><mo>=</mo><mn>0</mn></mrow><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></munderover><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace width="0.1667em" /><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo>∣</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><munderover><mo movablelimits="false">∑</mo><mrow><mi>t</mi><mi>′</mi><mo>=</mo><mn>0</mn></mrow><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></munderover><msub><mi>r</mi><mrow><mi>t</mi><mi>′</mi></mrow></msub><mo stretchy="true">]</mo></mrow></mtd></mtr><mtr><mtd></mtd></mtr><mtr><mtd></mtd><mtd><mo>=</mo><mfrac><mrow><mn>1</mn></mrow><mrow><mo>∣</mo><msup><mrow><mi>𝒟</mi></mrow><mrow><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub></mrow></msup><mo>∣</mo></mrow></mfrac><munder><mo movablelimits="false">∑</mo><mrow><mi>τ</mi><mo>∈</mo><msup><mrow><mi>𝒟</mi></mrow><mrow><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub></mrow></msup></mrow></munder><mrow><mo stretchy="true">[</mo><mtext>&nbsp;</mtext><munderover><mo movablelimits="false">∑</mo><mrow><mi>t</mi><mo>=</mo><mn>0</mn></mrow><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></munderover><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace width="0.1667em" 
/><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo>∣</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mrow><mo stretchy="true">(</mo><munderover><mo movablelimits="false">∑</mo><mrow><mi>t</mi><mi>′</mi><mo>=</mo><mn>0</mn></mrow><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></munderover><msub><mi>r</mi><mrow><mi>t</mi><mi>′</mi></mrow></msub><mo>+</mo><munderover><mo movablelimits="false">∑</mo><mrow><mi>t</mi><mi>′</mi><mo>=</mo><mi>t</mi></mrow><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></munderover><msub><mi>r</mi><mrow><mi>t</mi><mi>′</mi></mrow></msub><mo stretchy="true">)</mo></mrow><mo stretchy="true">]</mo></mrow></mtd></mtr><mtr><mtd></mtd></mtr><mtr><mtd></mtd><mtd><mo>=</mo><mfrac><mrow><mn>1</mn></mrow><mrow><mo>∣</mo><msup><mrow><mi>𝒟</mi></mrow><mrow><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub></mrow></msup><mo>∣</mo></mrow></mfrac><munder><mo movablelimits="false">∑</mo><mrow><mi>τ</mi><mo>∈</mo><msup><mrow><mi>𝒟</mi></mrow><mrow><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub></mrow></msup></mrow></munder><mrow><mo stretchy="true">[</mo><mtext>&nbsp;</mtext><munderover><mo movablelimits="false">∑</mo><mrow><mi>t</mi><mo>=</mo><mn>0</mn></mrow><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></munderover><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace width="0.1667em" /><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo>∣</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><munderover><mo movablelimits="false">∑</mo><mrow><mi>t</mi><mi>′</mi><mo>=</mo><mi>t</mi></mrow><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></munderover><msub><mi>r</mi><mrow><mi>t</mi><mi>′</mi></mrow></msub><mtext>&nbsp;</mtext><mo 
stretchy="true">]</mo></mrow><mo>.</mo></mtd></mtr></mtable></math></div></p>
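<p>The reward-to-go weights for every timestep of a trajectory can be computed in a single backward pass, as in this Python sketch. The function name and the optional discount <code>gamma</code> are assumptions for illustration: the equation above is undiscounted (<code>gamma=1.0</code>), and the discounted variant shown uses the common convention of discounting relative to step t.</p>

```python
import numpy as np

def rewards_to_go(rewards, gamma=1.0):
    """G_t = sum_{t'=t}^{T-1} gamma**(t'-t) * r_{t'} for every t, via one backward pass."""
    out = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running  # G_t = r_t + gamma * G_{t+1}
        out[t] = running
    return out

rewards = [1.0, 2.0, 4.0]  # hypothetical rewards r_0, r_1, r_2
print(rewards_to_go(rewards))             # [7. 6. 4.]
print(rewards_to_go(rewards, gamma=0.5))  # [3. 4. 4.]
```

In a REINFORCE update, each term ∇θ log πθ(a<sub>t</sub>∣s<sub>t</sub>) is then multiplied by G<sub>t</sub> instead of by the full return R(τ).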
<p>As we saw at the end of the <a href="#gradient-estimation-via-score-function">Gradient Estimation via Score Function</a> section, subtracting a baseline can reduce the variance of a score-function gradient estimator without biasing it. In the sequential setting, the natural choice is a state-dependent baseline <span class="math inline"><math display="inline"><mi>b</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo></math></span>. Does the estimator remain unbiased when we subtract it from the reward-to-go?</p>
<p><div class="math display"><math display="block"><mtable class="menv-alignlike"><mtr><mtd><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><msub><mrow><mi>𝔼</mi></mrow><mrow><mi>τ</mi><mo>∼</mo><msub><mi>p</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msub><mo symmetric="false" stretchy="false">[</mo><mi>R</mi><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo><mo symmetric="false" stretchy="false">]</mo></mtd><mtd><mo>=</mo><msub><mrow><mi>𝔼</mi></mrow><mrow><mi>τ</mi><mo>∼</mo><msub><mi>p</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msub><mrow><mo stretchy="true">[</mo><munderover><mo movablelimits="false">∑</mo><mrow><mi>t</mi><mo>=</mo><mn>0</mn></mrow><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></munderover><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace width="0.1667em" /><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo>∣</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mrow><mo stretchy="true">(</mo><munderover><mo movablelimits="false">∑</mo><mrow><mi>t</mi><mi>′</mi><mo>=</mo><mi>t</mi></mrow><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></munderover><msub><mi>r</mi><mrow><mi>t</mi><mi>′</mi></mrow></msub><mo>−</mo><mi>b</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mo stretchy="true">)</mo></mrow><mo stretchy="true">]</mo></mrow><mo>.</mo></mtd></mtr></mtable></math></div></p>
<p>The proof follows the same argument as the zero-expectation property of the score function derived earlier, with the key difference that the expectation is now taken with respect to <span class="math inline"><math display="inline"><msub><mi>p</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo></math></span>, a distribution over a sequence of random variables. By linearity of expectation, we can focus on a single term at step <span class="math inline"><math display="inline"><mi>t</mi></math></span> of the estimator above and show that the baseline term has zero expectation. We split the trajectory <span class="math inline"><math display="inline"><mi>τ</mi></math></span> at step <span class="math inline"><math display="inline"><mi>t</mi></math></span> into <span class="math inline"><math display="inline"><msub><mi>τ</mi><mrow><mn>0</mn><mo>:</mo><mi>t</mi></mrow></msub></math></span> and <span class="math inline"><math display="inline"><msub><mi>τ</mi><mrow><mi>t</mi><mo>+</mo><mn>1</mn><mo>:</mo><mi>T</mi><mo>−</mo><mn>1</mn></mrow></msub></math></span>, and then expand each part into its states and actions. The prefix contains <span class="math inline"><math display="inline"><msub><mi>s</mi><mrow><mn>0</mn><mo>:</mo><mi>t</mi></mrow></msub></math></span> and <span class="math inline"><math display="inline"><msub><mi>a</mi><mrow><mn>0</mn><mo>:</mo><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub></math></span>, since <span class="math inline"><math display="inline"><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub></math></span> is a consequence of action <span class="math inline"><math display="inline"><msub><mi>a</mi><mrow><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub></math></span>; the suffix contains <span class="math inline"><math display="inline"><msub><mi>a</mi><mrow><mi>t</mi><mo>:</mo><mi>T</mi><mo>−</mo><mn>1</mn></mrow></msub></math></span> and <span class="math inline"><math display="inline"><msub><mi>s</mi><mrow><mi>t</mi><mo>+</mo><mn>1</mn><mo>:</mo><mi>T</mi></mrow></msub></math></span>, since taking action <span class="math inline"><math display="inline"><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub></math></span> results in state <span class="math inline"><math display="inline"><msub><mi>s</mi><mrow><mi>t</mi><mo>+</mo><mn>1</mn></mrow></msub></math></span>. Compare the subscripts of the two expectations in the first two lines of the derivation:</p>
<p><div class="math display"><math display="block"><mtable class="menv-alignlike"><mtr><mtd><msub><mrow><mi>𝔼</mi></mrow><mrow><mi>τ</mi><mo>∼</mo><msub><mi>p</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msub><mrow><mo stretchy="true">[</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace width="0.1667em" /><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mi>t</mi></msub><mo>∣</mo><msub><mi>s</mi><mi>t</mi></msub><mo symmetric="false" stretchy="false">)</mo><mi>b</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mi>t</mi></msub><mo symmetric="false" stretchy="false">)</mo><mo stretchy="true">]</mo></mrow></mtd><mtd><mo>=</mo><msub><mrow><mi>𝔼</mi></mrow><mrow><msub><mi>τ</mi><mrow><mo symmetric="false" stretchy="false">(</mo><mn>0</mn><mo>:</mo><mi>t</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msub></mrow></msub><mrow><mo stretchy="true">[</mo><msub><mrow><mi>𝔼</mi></mrow><mrow><msub><mi>τ</mi><mrow><mo symmetric="false" stretchy="false">(</mo><mi>t</mi><mo>+</mo><mn>1</mn><mo>:</mo><mi>T</mi><mo>−</mo><mn>1</mn><mo symmetric="false" stretchy="false">)</mo></mrow></msub></mrow></msub><mrow><mo stretchy="true">[</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace width="0.1667em" /><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo>∣</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mi>b</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mo stretchy="true">]</mo></mrow><mo 
stretchy="true">]</mo></mrow></mtd></mtr><mtr><mtd></mtd></mtr><mtr><mtd></mtd><mtd><mo>=</mo><msub><mrow><mi>𝔼</mi></mrow><mrow><msub><mi>s</mi><mrow><mn>0</mn><mo>:</mo><mi>t</mi></mrow></msub><mo>,</mo><msub><mi>a</mi><mrow><mn>0</mn><mo>:</mo><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub></mrow></msub><mrow><mo stretchy="true">[</mo><msub><mrow><mi>𝔼</mi></mrow><mrow><msub><mi>s</mi><mrow><mi>t</mi><mo>+</mo><mn>1</mn><mo>:</mo><mi>T</mi></mrow></msub><mo>,</mo><msub><mi>a</mi><mrow><mi>t</mi><mo>:</mo><mi>T</mi><mo>−</mo><mn>1</mn></mrow></msub></mrow></msub><mrow><mo stretchy="true">[</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace width="0.1667em" /><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo>∣</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mi>b</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mo stretchy="true">]</mo></mrow><mo stretchy="true">]</mo></mrow></mtd></mtr><mtr><mtd></mtd></mtr><mtr><mtd></mtd><mtd><mo>=</mo><msub><mrow><mi>𝔼</mi></mrow><mrow><msub><mi>s</mi><mrow><mn>0</mn><mo>:</mo><mi>t</mi></mrow></msub><mo>,</mo><msub><mi>a</mi><mrow><mn>0</mn><mo>:</mo><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub></mrow></msub><mrow><mo stretchy="true">[</mo><mi>b</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><msub><mrow><mi>𝔼</mi></mrow><mrow><msub><mi>s</mi><mrow><mi>t</mi><mo>+</mo><mn>1</mn><mo>:</mo><mi>T</mi></mrow></msub><mo>,</mo><msub><mi>a</mi><mrow><mi>t</mi><mo>:</mo><mi>T</mi><mo>−</mo><mn>1</mn></mrow></msub></mrow></msub><mrow><mo stretchy="true">[</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace 
width="0.1667em" /><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo>∣</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mo stretchy="true">]</mo></mrow><mo stretchy="true">]</mo></mrow></mtd></mtr><mtr><mtd></mtd></mtr><mtr><mtd></mtd><mtd><mo>=</mo><msub><mrow><mi>𝔼</mi></mrow><mrow><msub><mi>s</mi><mrow><mn>0</mn><mo>:</mo><mi>t</mi></mrow></msub><mo>,</mo><msub><mi>a</mi><mrow><mn>0</mn><mo>:</mo><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub></mrow></msub><mrow><mo stretchy="true">[</mo><mi>b</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><msub><mrow><mi>𝔼</mi></mrow><mrow><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub></mrow></msub><mrow><mo stretchy="true">[</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace width="0.1667em" /><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo>∣</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mo stretchy="true">]</mo></mrow><mo stretchy="true">]</mo></mrow></mtd></mtr>
<mtr><mtd></mtd></mtr><mtr><mtd></mtd><mtd><mo>=</mo><msub><mrow><mi>𝔼</mi></mrow><mrow><msub><mi>s</mi><mrow><mn>0</mn><mo>:</mo><mi>t</mi></mrow></msub><mo>,</mo><msub><mi>a</mi><mrow><mn>0</mn><mo>:</mo><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub></mrow></msub><mrow><mo stretchy="true">[</mo><mi>b</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mo>∫</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo>∣</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mspace width="0.1667em" /><mi>d</mi><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo stretchy="true">]</mo></mrow></mtd></mtr>
<mtr><mtd></mtd></mtr><mtr><mtd></mtd><mtd><mo>=</mo><msub><mrow><mi>𝔼</mi></mrow><mrow><msub><mi>s</mi><mrow><mn>0</mn><mo>:</mo><mi>t</mi></mrow></msub><mo>,</mo><msub><mi>a</mi><mrow><mn>0</mn><mo>:</mo><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub></mrow></msub><mrow><mo stretchy="true">[</mo><mi>b</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mo>∫</mo><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo>∣</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mspace width="0.1667em" /><mi>d</mi><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo stretchy="true">]</mo></mrow></mtd></mtr>
<mtr><mtd></mtd></mtr><mtr><mtd></mtd><mtd><mo>=</mo><msub><mrow><mi>𝔼</mi></mrow><mrow><msub><mi>s</mi><mrow><mn>0</mn><mo>:</mo><mi>t</mi></mrow></msub><mo>,</mo><msub><mi>a</mi><mrow><mn>0</mn><mo>:</mo><mi>t</mi><mo>−</mo><mn>1</mn></mrow></msub></mrow></msub><mrow><mo stretchy="true">[</mo><mi>b</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mn>1</mn><mo stretchy="true">]</mo></mrow></mtd></mtr>
<mtr><mtd></mtd></mtr><mtr><mtd></mtd><mtd><mo>=</mo><mn>0</mn><mo>.</mo></mtd></mtr></mtable></math></div></p>
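<p>We can drop the variables that do not appear in the integrand from the inner expectation over the portion of the trajectory <span class="math inline"><math display="inline"><msub><mi>τ</mi><mrow><mo symmetric="false" stretchy="false">(</mo><mi>t</mi><mo>+</mo><mn>1</mn><mo symmetric="false" stretchy="false">)</mo><mo>:</mo><mo symmetric="false" stretchy="false">(</mo><mi>T</mi><mo>−</mo><mn>1</mn><mo symmetric="false" stretchy="false">)</mo></mrow></msub></math></span>, because the term at step <span class="math inline"><math display="inline"><mi>t</mi></math></span> depends only on <span class="math inline"><math display="inline"><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub></math></span> and <span class="math inline"><math display="inline"><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub></math></span>. Since <span class="math inline"><math display="inline"><mi>b</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo></math></span> does not depend on <span class="math inline"><math display="inline"><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub></math></span>, it factors out of the inner expectation, leaving <span class="math inline"><math display="inline"><msub><mrow><mi>𝔼</mi></mrow><mrow><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub></mrow></msub><mrow><mo stretchy="true">[</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace width="0.1667em" /><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo>∣</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mo stretchy="true">]</mo></mrow></math></span>. Using the identity <span class="math inline"><math display="inline"><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mi>log</mi><mo>⁡</mo><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub><mo>=</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub></math></span>, this expectation equals <span class="math inline"><math display="inline"><mo>∫</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo>∣</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mspace width="0.1667em" /><mi>d</mi><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo>=</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mo>∫</mo><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo>∣</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mspace width="0.1667em" /><mi>d</mi><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub></math></span> (a sum for discrete actions). A probability distribution integrates to one, so <span class="math inline"><math display="inline"><mo>∫</mo><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo>∣</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mspace width="0.1667em" /><mi>d</mi><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo>=</mo><mn>1</mn></math></span>, and its gradient with respect to <span class="math inline"><math display="inline"><mi>θ</mi></math></span> is zero. The baseline therefore multiplies zero and has no effect on the expectation. By linearity of expectation, the same argument applies to every other term in the sum over <span class="math inline"><math display="inline"><mi>t</mi></math></span>. Therefore, subtracting a state-dependent baseline keeps the policy gradient estimator unbiased.</p>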
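<p>The argument can also be checked numerically. The sketch below is our own illustration, not part of the original derivation: a single state with a softmax policy over three actions, where the logits <code>theta</code> and the per-action <code>returns</code> are arbitrary hypothetical values. For a softmax policy the score is ∇<sub>θ</sub> log π<sub>θ</sub>(a) = one_hot(a) − π<sub>θ</sub>, so every expectation can be computed exactly by summing over actions instead of sampling.</p>

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax."""
    z = np.exp(logits - logits.max())
    return z / z.sum()


def score(theta, action):
    """grad_theta log softmax(theta)[action] = one_hot(action) - softmax(theta)."""
    g = -softmax(theta)
    g[action] += 1.0
    return g


def total_variance(theta, weights):
    """Exact sum over parameters of Var_{a~pi}[score(a) * weights[a]]."""
    pi = softmax(theta)
    per_action = np.array([score(theta, a) * weights[a] for a in range(len(pi))])
    mean = pi @ per_action
    second_moment = pi @ per_action ** 2
    return float((second_moment - mean ** 2).sum())


# Hypothetical single-state problem: logits and per-action returns chosen arbitrarily.
theta = np.array([0.5, -0.2, 0.1])
returns = np.array([1.0, 2.0, 5.0])
pi = softmax(theta)

# E_{a~pi}[grad log pi(a)] = sum_a pi(a) (one_hot(a) - pi) = pi - pi = 0.
expected_score = sum(pi[a] * score(theta, a) for a in range(len(pi)))

# A constant baseline b = E[R] leaves the expected gradient unchanged...
b = float(pi @ returns)
grad_no_baseline = sum(pi[a] * score(theta, a) * returns[a] for a in range(len(pi)))
grad_with_baseline = sum(pi[a] * score(theta, a) * (returns[a] - b) for a in range(len(pi)))

# ...while lowering the variance of the single-sample estimator in this example.
var_no_baseline = total_variance(theta, returns)
var_with_baseline = total_variance(theta, returns - b)
print(expected_score, np.allclose(grad_no_baseline, grad_with_baseline))
print(var_no_baseline, var_with_baseline)
```

<p>Here the baseline only guarantees unbiasedness; how much variance it removes depends on the problem, and this toy example happens to show a large reduction.</p>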
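<p>Choosing a baseline is an important design decision \citep{foundations-deeprl-series-l3}. Any baseline that does not depend on the action <span class="math inline"><math display="inline"><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub></math></span> keeps the estimator unbiased, but the options differ in how much variance they remove and in how expensive they are to compute. Common choices range from constants estimated from the batch to a learned value function:</p>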
<ul>
<li>Constant baseline: <span class="math inline"><math display="inline"><mi>b</mi><mo>=</mo><mrow><mi>𝔼</mi></mrow><mrow><mo stretchy="true">[</mo><mi>R</mi><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo><mo stretchy="true">]</mo></mrow><mo>≈</mo><mfrac><mrow><mn>1</mn></mrow><mrow><mi>m</mi></mrow></mfrac><msubsup><mo movablelimits="false">∑</mo><mrow><mi>i</mi><mo>=</mo><mn>1</mn></mrow><mrow><mi>m</mi></mrow></msubsup><mi>R</mi><mo symmetric="false" stretchy="false">(</mo><msup><mi>τ</mi><mrow><mo symmetric="false" stretchy="false">(</mo><mi>i</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msup><mo symmetric="false" stretchy="false">)</mo></math></span>.</li>
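<li>Optimal constant baseline, computed separately for each component of the gradient (the square is taken elementwise): <span class="math inline"><math display="inline"><mi>b</mi><mo>=</mo><mfrac><mrow><msub><mo movablelimits="false">∑</mo><mrow><mi>i</mi></mrow></msub><msup><mrow><mo symmetric="false" stretchy="false">(</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mi>log</mi><mo>⁡</mo><msub><mi>p</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msup><mi>τ</mi><mrow><mo symmetric="false" stretchy="false">(</mo><mi>i</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msup><mo symmetric="false" stretchy="false">)</mo><mo symmetric="false" stretchy="false">)</mo></mrow><mn>2</mn></msup><mi>R</mi><mo symmetric="false" stretchy="false">(</mo><msup><mi>τ</mi><mrow><mo symmetric="false" stretchy="false">(</mo><mi>i</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msup><mo symmetric="false" stretchy="false">)</mo></mrow><mrow><msub><mo movablelimits="false">∑</mo><mrow><mi>i</mi></mrow></msub><msup><mrow><mo symmetric="false" stretchy="false">(</mo><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mi>log</mi><mo>⁡</mo><msub><mi>p</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msup><mi>τ</mi><mrow><mo symmetric="false" stretchy="false">(</mo><mi>i</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msup><mo symmetric="false" stretchy="false">)</mo><mo symmetric="false" stretchy="false">)</mo></mrow><mn>2</mn></msup></mrow></mfrac></math></span>.</li>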
<li>Time-dependent baseline: <span class="math inline"><math display="inline"><msub><mi>b</mi><mrow><mi>t</mi></mrow></msub><mo>=</mo><mfrac><mrow><mn>1</mn></mrow><mrow><mi>m</mi></mrow></mfrac><msubsup><mo movablelimits="false">∑</mo><mrow><mi>i</mi><mo>=</mo><mn>1</mn></mrow><mrow><mi>m</mi></mrow></msubsup><msubsup><mo movablelimits="false">∑</mo><mrow><mi>k</mi><mo>=</mo><mi>t</mi></mrow><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></msubsup><mi>R</mi><mo symmetric="false" stretchy="false">(</mo><msubsup><mi>s</mi><mrow><mi>k</mi></mrow><mrow><mo symmetric="false" stretchy="false">(</mo><mi>i</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msubsup><mo>,</mo><msubsup><mi>a</mi><mrow><mi>k</mi></mrow><mrow><mo symmetric="false" stretchy="false">(</mo><mi>i</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msubsup><mo symmetric="false" stretchy="false">)</mo></math></span>.</li>
<li>State-dependent expected return: <span class="math inline"><math display="inline"><mi>b</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><mrow><mi>𝔼</mi></mrow><mrow><mo stretchy="true">[</mo><msub><mi>r</mi><mrow><mi>t</mi></mrow></msub><mo>+</mo><msub><mi>r</mi><mrow><mi>t</mi><mo>+</mo><mn>1</mn></mrow></msub><mo>+</mo><msub><mi>r</mi><mrow><mi>t</mi><mo>+</mo><mn>2</mn></mrow></msub><mo>+</mo><mi>⋯</mi><mo>+</mo><msub><mi>r</mi><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></msub><mo>∣</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo stretchy="true">]</mo></mrow><mo>=</mo><msup><mi>V</mi><mrow><mi>π</mi></mrow></msup><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo></math></span>.</li>
</ul>
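<p>As a rough sketch, the first three baselines can be computed from a batch of <code>m</code> trajectories stored as an <code>(m, T)</code> reward array. The function names and the toy numbers are hypothetical, and <code>traj_scores</code> stands in for the per-trajectory score vectors ∇<sub>θ</sub> log p<sub>θ</sub>(τ<sup>(i)</sup>), which in practice would come from automatic differentiation. The state-dependent baseline is left out because it requires a learned model of V<sup>π</sup>, the subject of the actor-critic section below.</p>

```python
import numpy as np


def reward_to_go(rewards):
    """rewards: (m, T) array -> G with G[i, t] = sum_{k=t}^{T-1} rewards[i, k]."""
    return np.flip(np.cumsum(np.flip(rewards, axis=1), axis=1), axis=1)


def constant_baseline(rewards):
    """b = (1/m) sum_i R(tau^(i)): the batch average of total returns."""
    return float(rewards.sum(axis=1).mean())


def optimal_constant_baseline(rewards, traj_scores):
    """b_j = sum_i g_ij^2 R_i / sum_i g_ij^2, one value per parameter component j.

    traj_scores[i] plays the role of grad_theta log p_theta(tau^(i)), shape (m, d).
    """
    returns = rewards.sum(axis=1)              # (m,)
    g2 = traj_scores ** 2                      # (m, d), elementwise square
    return (g2 * returns[:, None]).sum(axis=0) / g2.sum(axis=0)


def time_dependent_baseline(rewards):
    """b_t = (1/m) sum_i sum_{k=t}^{T-1} r_k^(i): the batch average reward-to-go at step t."""
    return reward_to_go(rewards).mean(axis=0)  # (T,)


# Hypothetical batch: m = 2 trajectories of length T = 3, made-up score vectors (d = 2).
rewards = np.array([[1.0, 0.0, 2.0],
                    [0.0, 1.0, 1.0]])
traj_scores = np.array([[1.0, 0.5],
                        [2.0, 0.5]])

print(constant_baseline(rewards))  # 2.5
print(time_dependent_baseline(rewards))
print(optimal_constant_baseline(rewards, traj_scores))
```

<p>With this batch the constant baseline coincides with the time-dependent baseline at t = 0, since both average the full return of each trajectory.</p>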
<p>A baseline is an instance of the control variates method, which can significantly reduce estimator variance and thereby improve the stability and performance of RL algorithms \cite{NIPS2001_584b98aa}. Despite the differences among baseline methods, the central concept is the <em>advantage</em>, shown in the equation below: the log probability of action <span class="math inline"><math display="inline"><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub></math></span> is increased in proportion to how much the return obtained after taking it, the reward-to-go <span class="math inline"><math display="inline"><msubsup><mo movablelimits="false">∑</mo><mrow><mi>t</mi><mi>′</mi><mo>=</mo><mi>t</mi></mrow><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></msubsup><mi>R</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mrow><mi>t</mi><mi>′</mi></mrow></msub><mo>,</mo><msub><mi>s</mi><mrow><mi>t</mi><mi>′</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo></math></span>, exceeds the expected return under the current policy, given by the value function <span class="math inline"><math display="inline"><msup><mi>V</mi><mrow><mi>π</mi></mrow></msup><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo></math></span>, and decreased when it falls short.</p>
<p><div class="math display"><math display="block"><msub><mrow><mi>𝔼</mi></mrow><mrow><mi>τ</mi><mo>∼</mo><msub><mi>p</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msub><mrow><mo stretchy="true">[</mo><munderover><mo movablelimits="false">∑</mo><mrow><mi>t</mi><mo>=</mo><mn>0</mn></mrow><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></munderover><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace width="0.1667em" /><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo>∣</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mrow><mo stretchy="true">(</mo><munder><munder><mrow><munderover><mo movablelimits="false">∑</mo><mrow><mi>t</mi><mi>′</mi><mo>=</mo><mi>t</mi></mrow><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></munderover><mi>R</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mrow><mi>t</mi><mi>′</mi></mrow></msub><mo>,</mo><msub><mi>s</mi><mrow><mi>t</mi><mi>′</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mo>−</mo><msup><mi>V</mi><mrow><mi>π</mi></mrow></msup><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo></mrow><mo stretchy="true">⏟</mo></munder><mrow><mtext>advantage</mtext></mrow></munder><mo stretchy="true">)</mo></mrow><mo stretchy="true">]</mo></mrow><mo>.</mo></math></div></p>
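<p>For a single trajectory, the advantage term in the equation above reduces to a reward-to-go computation followed by a subtraction. This minimal sketch assumes the value estimates V<sup>π</sup>(s<sub>t</sub>) are already available as an array; the function name and the rewards and values are made-up for illustration.</p>

```python
import numpy as np


def advantages(rewards, values):
    """A_t = sum_{t'=t}^{T-1} r_{t'} - V(s_t) for a single trajectory.

    rewards: (T,) rewards r_0..r_{T-1}; values: (T,) estimates V(s_0)..V(s_{T-1}).
    """
    rewards = np.asarray(rewards, dtype=float)
    rtg = np.cumsum(rewards[::-1])[::-1]  # reward-to-go at every step
    return rtg - np.asarray(values, dtype=float)


# Hypothetical 4-step trajectory and value estimates.
r = [0.0, 1.0, 0.0, 2.0]
v = [2.5, 2.0, 1.5, 1.0]
print(advantages(r, v))
```

<p>Positive entries push up the log probability of the corresponding action and negative entries push it down. In an automatic differentiation framework, a common way to obtain the estimator is to minimize −Σ<sub>t</sub> log π<sub>θ</sub>(a<sub>t</sub>|s<sub>t</sub>) · A<sub>t</sub> with the advantages treated as constants, so that no gradient flows through them.</p>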
<p>What remains is how to estimate <span class="math inline"><math display="inline"><msup><mi>V</mi><mrow><mi>π</mi></mrow></msup></math></span> in practice.</p>
<h2>Actor-Critic Methods</h2>
<p>Actor-critic methods learn two models concurrently: one for the policy and one for the value function. These methods can be more data efficient because the critic is fitted on the same samples <span class="math inline"><math display="inline"><msup><mrow><mi>𝒟</mi></mrow><mrow><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub></mrow></msup></math></span> used for the Monte Carlo estimates, and its predictions in turn reduce the variance of the gradient estimator. The actor controls how the agent behaves (<em>by updating the policy parameters <span class="math inline"><math display="inline"><mi>θ</mi></math></span> as in the previous sections</em>), whereas the critic measures how good the taken action is. The critic can be a state-value function (<span class="math inline"><math display="inline"><mi>V</mi></math></span>) or an action-value function (<span class="math inline"><math display="inline"><mi>Q</mi></math></span>), where the action-value function gives the value of taking action <span class="math inline"><math display="inline"><mi>a</mi></math></span> in state <span class="math inline"><math display="inline"><mi>s</mi></math></span> under a policy <span class="math inline"><math display="inline"><mi>π</mi></math></span>. Notice that this combines, in some way, both approaches for solving MDPs depicted in Figure~\ref{fig:rl-model-free-taxonomy}.</p>
<p>We introduce a new function approximator for the value function, <span class="math inline"><math display="inline"><msub><mi>V</mi><mrow><mi>ϕ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo></math></span>, where <span class="math inline"><math display="inline"><mi>ϕ</mi></math></span> are the parameters of the value function, and use it as the baseline in the policy gradient:</p>
<p><div class="math display"><math display="block"><msub><mrow><mi>𝔼</mi></mrow><mrow><mi>τ</mi><mo>∼</mo><msub><mi>p</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><mi>τ</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msub><mrow><mo stretchy="true">[</mo><munderover><mo movablelimits="false">∑</mo><mrow><mi>t</mi><mo>=</mo><mn>0</mn></mrow><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></munderover><msub><mi>∇</mi><mrow><mi>θ</mi></mrow></msub><mspace width="0.1667em" /><mi>log</mi><mo>⁡</mo><mspace width="0.1667em" /><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mrow><mi>t</mi></mrow></msub><mo>∣</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mrow><mo stretchy="true">(</mo><munderover><mo movablelimits="false">∑</mo><mrow><mi>t</mi><mi>′</mi><mo>=</mo><mi>t</mi></mrow><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></munderover><mi>R</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mrow><mi>t</mi><mi>′</mi></mrow></msub><mo>,</mo><msub><mi>s</mi><mrow><mi>t</mi><mi>′</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mo>−</mo><msubsup><mi>V</mi><mrow><mi>ϕ</mi></mrow><mrow><mi>π</mi></mrow></msubsup><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mo stretchy="true">)</mo></mrow><mo stretchy="true">]</mo></mrow><mo>.</mo></math></div></p>
<p>The critic is trained to minimize the mean squared error (MSE) between the estimated value and the empirical return, i.e., we regress the value estimate onto the empirical reward-to-go in a supervised learning fashion:</p>
<p><div class="math display"><math display="block"><mi>ϕ</mi><mo>←</mo><munder><mrow><mi>arg</mi><mo>⁡</mo><mspace width="0.1667em" /><mi>min</mi></mrow><mrow><mi>ϕ</mi></mrow></munder><mfrac><mrow><mn>1</mn></mrow><mrow><mo>∣</mo><msup><mrow><mi>𝒟</mi></mrow><mrow><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub></mrow></msup><mo>∣</mo></mrow></mfrac><munder><mo movablelimits="false">∑</mo><mrow><mi>τ</mi><mo>∈</mo><msup><mrow><mi>𝒟</mi></mrow><mrow><msub><mi>π</mi><mrow><mi>θ</mi></mrow></msub></mrow></msup></mrow></munder><munderover><mo movablelimits="false">∑</mo><mrow><mi>t</mi><mo>=</mo><mn>0</mn></mrow><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></munderover><mrow><mo stretchy="true">[</mo><msup><mrow><mo stretchy="true">(</mo><mrow><mo stretchy="true">(</mo><munderover><mo movablelimits="false">∑</mo><mrow><mi>t</mi><mi>′</mi><mo>=</mo><mi>t</mi></mrow><mrow><mi>T</mi><mo>−</mo><mn>1</mn></mrow></munderover><mi>R</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>a</mi><mrow><mi>t</mi><mi>′</mi></mrow></msub><mo>,</mo><msub><mi>s</mi><mrow><mi>t</mi><mi>′</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mo stretchy="true">)</mo></mrow><mo>−</mo><msub><mi>V</mi><mrow><mi>ϕ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><msub><mi>s</mi><mrow><mi>t</mi></mrow></msub><mo symmetric="false" stretchy="false">)</mo><mo stretchy="true">)</mo></mrow><mn>2</mn></msup><mtext>&nbsp;</mtext><mo stretchy="true">]</mo></mrow><mo>.</mo></math></div></p>
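<p>A minimal sketch of this regression, assuming a linear critic V<sub>ϕ</sub>(s) = x(s)<sup>⊤</sup>ϕ over a hypothetical feature map x(s) (one-hot state features here) as a stand-in for a neural network. It averages over all visited states rather than over trajectories, which only rescales the objective above by a constant factor. The function name, learning rate, and data are illustrative assumptions.</p>

```python
import numpy as np


def fit_linear_value(features, targets, lr=0.1, steps=1000):
    """Regress a linear critic V_phi(s) = features @ phi onto reward-to-go targets.

    Gradient descent on the mean squared error; the factor 2 from differentiating
    the square is folded into the learning rate, as in the critic update of Algorithm 2.
    """
    phi = np.zeros(features.shape[1])
    n = len(targets)
    for _ in range(steps):
        error = targets - features @ phi          # G_t - V_phi(s_t) for every sample
        phi = phi + lr * features.T @ error / n   # dV/dphi is the feature vector
    return phi


# Hypothetical data: 3 states with one-hot features, two observed returns per state.
visited_states = np.array([0, 1, 2, 0, 1, 2])
features = np.eye(3)[visited_states]                 # (6, 3) one-hot feature matrix
targets = np.array([1.0, 2.0, 3.0, 1.2, 1.8, 3.0])   # empirical reward-to-go G_t

phi = fit_linear_value(features, targets, lr=0.5, steps=200)
phi_lstsq = np.linalg.lstsq(features, targets, rcond=None)[0]  # closed-form least squares
print(phi, phi_lstsq)
```

<p>With one-hot features the least-squares solution is the average observed return in each state, so both estimates should agree at [1.1, 1.9, 3.0]. For a neural network critic, the same objective is typically minimized with a few epochs of stochastic gradient descent on minibatches.</p>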
<p><a href="#alg:reinforce-with-critic">Algorithm 2</a> describes the steps of a REINFORCE variant with an advantage estimate, which combines the actor-critic approach with the traditional REINFORCE algorithm. The additional components bring implementation choices that can affect performance; for instance, the policy and value networks may or may not share parameters. A useful study that runs ablations over such choices and offers practical recommendations for implementing these algorithms is <i>What Matters In On-Policy Reinforcement Learning? A Large-Scale Empirical Study (Andrychowicz, 2020 \cite{andrychowicz2020mattersonpolicyreinforcementlearning})</i>.</p>
<div id="alg:reinforce-with-critic">
    <big><b>Algorithm 2: REINFORCE with advantage</b></big>
    <ol>
        <li>Initialize policy \( \pi_{\theta} \)</li>
        <li>Initialize value \( V_{\phi} \)</li>
        <li>Set learning rates \( \alpha_{a} \) and \( \alpha_{c} \)</li>
        <li>For \( \text{iteration}=0, 1, 2, \dots, N \):
            <ol>
                <li>Collect a set of trajectories \( \mathcal{D}^{\pi_{\theta}}=\{\tau^{(i)}\} \) by sampling from the current policy \( \pi_{\theta} \)</li>
                <li>Calculate the rewards-to-go \( \sum_{t'=t}^{T-1} R(a_{t'}, s_{t'}) \) for every step \( t \) of each trajectory \( \tau\in\mathcal{D}^{\pi_{\theta}} \)</li>
                <li>Update the policy:
                    <ul>
                        <li>\( \theta \leftarrow \theta + \alpha_{a} \left(\frac{1}{|\mathcal{D}^{\pi_{\theta}}|}\sum_{\tau\in\mathcal{D}^{\pi_{\theta}}}\left[\sum_{t=0}^{T-1}\nabla_{\theta}\log\pi_{\theta}(a_{t}| s_{t})\left(\sum_{t'=t}^{T-1} R(a_{t'}, s_{t'}) - V_{\phi}^{\pi_{\theta}}(s_{t})\right)\right]\right) \)</li>
                    </ul>
                </li>
                <li>Update the value:
                    <ul>
                        <li>\( \phi \leftarrow \phi + \alpha_{c} \left(\frac{1}{|\mathcal{D}^{\pi_{\theta}}|}\sum_{\tau\in\mathcal{D}^{\pi_{\theta}}}\left[\sum_{t=0}^{T-1}\left(\sum_{t'=t}^{T-1} R(a_{t'}, s_{t'}) - V_{\phi}^{\pi_{\theta}}(s_{t})\right)\nabla_{\phi}V_{\phi}^{\pi_{\theta}}(s_{t})\right]\right) \)</li>
                    </ul>
                </li>
            </ol>
        </li>
    </ol>
</div>
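<p>To make Algorithm 2 concrete, here is a compact NumPy sketch on a hypothetical toy problem: a four-state corridor where the agent starts at state 0, can move left or right, pays −1 per step, and stops on reaching the last state or after <code>T</code> steps. A tabular softmax policy and a tabular critic stand in for the policy and value networks; the environment, hyperparameters, and function names are our own assumptions rather than part of the original algorithm. Each iteration collects a batch, computes rewards-to-go, and applies the two updates of Algorithm 2, with the advantage computed from the critic before it is updated.</p>

```python
import numpy as np

# Hypothetical toy MDP: states 0..N_STATES-1 on a line, start at 0, goal at the right end.
# Action 0 moves left (clipped at 0), action 1 moves right; every step costs -1 and
# the episode ends at the goal or after T steps.
N_STATES, N_ACTIONS, T = 4, 2, 8
GOAL = N_STATES - 1


def step(state, action):
    """Deterministic transition: returns (next_state, reward, done)."""
    next_state = min(state + 1, GOAL) if action == 1 else max(state - 1, 0)
    return next_state, -1.0, next_state == GOAL


def softmax(logits):
    z = np.exp(logits - logits.max())
    return z / z.sum()


def rollout(theta, rng):
    """Sample one trajectory with pi_theta; returns (states, actions, rewards)."""
    states, actions, rewards = [], [], []
    s = 0
    for _ in range(T):
        a = int(rng.choice(N_ACTIONS, p=softmax(theta[s])))
        s_next, r, done = step(s, a)
        states.append(s)
        actions.append(a)
        rewards.append(r)
        s = s_next
        if done:
            break
    return states, actions, rewards


def reinforce_with_advantage(iterations=200, batch_size=32, lr_actor=0.1, lr_critic=0.1, seed=0):
    rng = np.random.default_rng(seed)
    theta = np.zeros((N_STATES, N_ACTIONS))  # policy logits: pi_theta(a|s) = softmax(theta[s])[a]
    phi = np.zeros(N_STATES)                 # tabular critic: V_phi(s) = phi[s]
    mean_returns = []
    for _ in range(iterations):
        # Collect a set of trajectories with the current policy.
        batch = [rollout(theta, rng) for _ in range(batch_size)]
        grad_theta = np.zeros_like(theta)
        grad_phi = np.zeros_like(phi)
        for states, actions, rewards in batch:
            # Rewards-to-go: sum_{t'=t}^{T-1} r_{t'} for every step t.
            rtg = np.cumsum(rewards[::-1])[::-1]
            for s, a, g in zip(states, actions, rtg):
                advantage = g - phi[s]          # critic value from before this iteration's update
                score = -softmax(theta[s])
                score[a] += 1.0                 # grad of log pi_theta(a|s) w.r.t. theta[s]
                grad_theta[s] += score * advantage
                grad_phi[s] += advantage        # grad of V_phi(s) w.r.t. phi is one_hot(s)
        # Actor update (gradient ascent) and critic update (descent on the squared error).
        theta += lr_actor * grad_theta / batch_size
        phi += lr_critic * grad_phi / batch_size
        mean_returns.append(float(np.mean([sum(r) for _, _, r in batch])))
    return theta, phi, mean_returns


theta, phi, mean_returns = reinforce_with_advantage()
print("P(right | s):", [round(float(softmax(theta[s])[1]), 3) for s in range(GOAL)])
print("mean return, first vs last iteration:", mean_returns[0], mean_returns[-1])
```

<p>Since moving right always shortens the path to the goal, the learned probabilities of moving right should approach one and the average return should approach the optimum of −3; the exact numbers depend on the seed and the learning rates.</p>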
<h2>References</h2>
<p>
    [1] Sutton, R. S., & Barto, A. G. (2018). 
    <a href="http://incompleteideas.net/book/the-book-2nd.html" target="_blank">Reinforcement learning: An introduction</a>. A Bradford Book.
</p>
<p>
    [2] Mnih, V., et al. (2013). <a href="https://arxiv.org/abs/1312.5602" target="_blank">Playing Atari with deep reinforcement learning</a>. arXiv preprint arXiv:1312.5602.
</p>
<p>
    [3] Silver, D., Schrittwieser, J., Simonyan, K., Antonoglou, I., Huang, A., Guez, A., ... & Hassabis, D. (2017). <a href="https://www.nature.com/articles/nature24270" target="_blank">Mastering the game of Go without human knowledge</a>. Nature, 550(7676), 354-359.
</p>
<p>
    [4] Schulman, J. (2016). <a href="https://www2.eecs.berkeley.edu/Pubs/TechRpts/2016/EECS-2016-217.html" target="_blank">Optimizing expectations: From deep reinforcement learning to stochastic computation graphs</a> (Doctoral dissertation, UC Berkeley).
</p>
<!-- 

## Improving Sample Efficiency: Behavior and Target Policies


The main drawback of the REINFORCE algorithm is its sample complexity. Once we roll out the policy and collect the data, we cannot reuse it after the policy has been updated. We must collect new data following the \textit{target policy} $\pi_{\theta}$ that we want to update. In RL literature, this is referred to as \textit{on-policy} learning. Reusing the data $\mathcal{D}\sim\pi_{\theta_{\text{old}}}$ to update the current policy $\pi_{\theta}$ would significantly improve sample efficiency\footnote{This issue also arises when attempting to transfer behavior from one task to another using existing data.}. However, once we update the policy, the previously collected data is no longer valid because the policy has changed. The distribution from which the data was sampled is now $\pi_{\theta_{\text{old}}}$. \\

 Using data collected by another policy, known as a \textit{behavior policy}, to update the current policy is referred to as \textit{off-policy} learning in the RL literature. Let's introduce a \textit{behavior policy} into the RL objective defined in Equation~(\ref{eqn:rl-objective}) using \href{https://timvieira.github.io/blog/post/2014/12/21/importance-sampling/}{importance sampling} (see MacKay's book, Section 29.2 \cite{mackay-book}):

% Derive RL objective with importance sampling to use data from another policy

    \begin{split}
        \nabla_{\theta}\mathcal{J}(\theta) &= \mathbb{E}_{\tau \sim p_{\theta}(\tau)}\bigg[ \nabla_{\theta}\log p_{\theta}(\tau) R(\tau)\bigg] \\
        &= \mathbb{E}_{\tau \sim p_{\theta}(\tau)} \bigg[ \frac{\nabla_{\theta}p_{\theta}(\tau)}{p_{\theta}(\tau)} R(\tau) \bigg] \\
        &= \int_{\mathcal{X}} p_{\theta}(\tau) \frac{\nabla_{\theta}p_{\theta}(\tau)}{p_{\theta}(\tau)} R(\tau) d\tau \\
        &= \int_{\mathcal{X}} \frac{p_{\theta_{\text{old}}}(\tau)}{p_{\theta_{\text{old}}}(\tau)} \cancel{p_{\theta}(\tau)} \frac{\nabla_{\theta}p_{\theta}(\tau)}{\cancel{p_{\theta}}(\tau)} R(\tau) d\tau \\
        &= \int_{\mathcal{X}} p_{\theta_{\text{old}}}(\tau) \frac{\nabla_{\theta}p_{\theta}(\tau)}{p_{\theta_{\text{old}}}(\tau)} R(\tau) d\tau \\
        &= \mathbb{E}_{\tau\sim p_{\theta_{\text{old}}}(\tau)} \bigg[\frac{\nabla_{\theta} p_{\theta}(\tau)}{p_{\theta_{\text{old}}}(\tau)} R(\tau)\bigg].
    \end{split}


 This yields a more general objective that covers both \textit{on-policy} and \textit{off-policy} learning through the importance weight,
or importance correction ($p_{\theta}(\tau) / p_{\theta_{\text{old}}}(\tau)$):

% RL objective with IS

    \mathcal{J}_{\text{IS}}(\theta) = \mathbb{E}_{\tau\sim p_{\theta_{\text{old}}}(\tau)}\bigg[\frac{p_{\theta}(\tau)}{p_{\theta_{\text{old}}}(\tau)} R(\tau)\bigg].


 If we assume that the behavior policy is not too different from the target policy,
we can use a first-order approximation to
update the policy:


    \begin{split}
        \nabla_{\theta}\mathcal{J}(\theta)\rvert_{\theta=\theta_{\text{old}}} &= \mathbb{E}_{\tau\sim p_{\theta_{\text{old}}}(\tau)} \bigg[\frac{\nabla_{\theta} p_{\theta}(\tau)\rvert_{\theta=\theta_{\text{old}}}}{p_{\theta_{\text{old}}}(\tau)} R(\tau)\bigg] \\
        &= \mathbb{E}_{\tau\sim p_{\theta_{\text{old}}}(\tau)} \big[\nabla_{\theta}\log p_{\theta}(\tau)\rvert_{\theta=\theta_{\text{old}}} R(\tau) \big].
    \end{split}



 \textbf{The problem with first-order approximation}. The gradient estimate is only accurate in the immediate vicinity of $\theta_{\text{old}}$, because it is a local approximation of the objective. Hence, the step size is crucial to avoid policy degradation: if the policy is updated with a bad gradient step,
it is difficult to recover. Since the data are collected by the policy itself, this feedback loop can be dangerous for training
stability. \\

## Trust Region and Proximal Policy Optimization

Trust Region Policy Optimization (TRPO) \cite{schulman2015trust} helps
avoid policy degradation caused by bad updates. The idea is to
define a trust region within which updating the policy parameters is safe,
balancing policy improvement with stability:

\begin{align}
    \text{Surrogate loss:} \quad & \underset{\pi_{\theta}}{\max}~L(\pi) = \mathbb{E}_{\pi_{\theta_{\text{old}}}} \left[ \frac{\pi_{\theta}(a\mid s)}{\pi_{\theta_{\text{old}}}(a\mid s)} A^{\pi_{\theta_{\text{old}}}}(s, a) \right]  \\
    \text{Constraint:} \quad & \mathbb{E}_{s\sim\pi_{\theta_{\text{old}}}} \left[ D_{\text{KL}}(\pi_{\theta_{\text{old}}}(\cdot\mid s) \,\|\, \pi_{\theta}(\cdot\mid s)) \right]  \leq \epsilon \nonumber.
\end{align}

% \textbf{Maximize data efficiency in comparison to traditional policy gradients}. 

 TRPO aims to increase data efficiency while avoiding step-size problems when updating the parameters, compared to traditional policy gradients (PG). The main idea is to improve a surrogate objective as much as possible while making only small changes to the policy. These changes are quantified by the KL divergence between the old and new action distributions. The trust region is the area where the new policy remains close to the old one, which helps keep training stable. \\

% The trust region is the area where the new policy remains close to the old one, allowing for constrained improvement. \\

 Proximal Policy Optimization (PPO) \cite{schulman2017proximal} simplifies TRPO in order to (i) be easier to implement, avoiding the constrained second-order optimization in Equation~(\ref{eqn:trpo-loss}), (ii) use first-order optimizers such as Adam \cite{kingma2017adammethodstochasticoptimization}, and (iii) be compatible with neural network architectures that include noise, such as dropout, which are incompatible with the TRPO setting. \\

 Let's rename the importance weights as the probability ratio $r$: 


    r_{t}(\theta) = \frac{\pi_{\theta}(a_{t}\mid s_{t})}{\pi_{\theta_{\text{old}}}(a_{t}\mid s_{t})}.


 The strategy is to keep this ratio close to 1. We can create a trust region by clipping the ratio so that it stays within the range $\left[1-\epsilon, 1+\epsilon \right]$,


\mathcal{L}^{\text{CLIP}}(\theta) = \hat{\mathbb{E}}_t \left[ \min \left( r_t(\theta) \hat{A}_t, \text{clip}(r_t(\theta), 1 - \epsilon, 1 + \epsilon) \hat{A}_t \right) \right].


 For a walkthrough implementation covering important details that the paper omits and that significantly impact performance, see \textit{``The 37 implementation details of proximal policy optimization''} (Huang, 2023 \cite{dlr191986}).


% PPO algorithm, actor-critic style
\begin{algorithm}
    \caption{Proximal Policy Optimization (PPO), Actor-Critic Style}
    \begin{algorithmic}
    \STATE Initialize policy parameter $\theta$, set learning rate $\alpha$
    \STATE Initialize value $V_{\phi}$
    \FOR{$\text{iteration}=1, 2, \dots$}
        \FOR{$\text{actor}=1, 2, \dots, N$}
            \STATE Run policy $\pi_{\theta_{\text{old}}}$ in environment for $T$ timesteps
            \STATE Compute advantage estimates $\hat{A}_{0}, \dots, \hat{A}_{T-1}$
        \ENDFOR
        \STATE Optimize surrogate $\mathcal{L}^{\text{CLIP}}$ wrt $\theta$ (Equation~\ref{eqn:clip-ac-objective}), with $K$ epochs and minibatch size $M\leq NT$
        \STATE $\theta_{\text{old}} \leftarrow \theta$
    \ENDFOR
    \end{algorithmic}
\end{algorithm}


## Summary

In this chapter, we have explored the foundational concepts and methodologies in reinforcement learning (RL). The core of RL is the interaction between an agent and its environment, where learning occurs through trial-and-error. The agent's goal is to maximize cumulative rewards by taking actions based on its observations, influencing the state of the environment, and receiving rewards. 

 We began by introducing the Markov Decision Process (MDP), a mathematical framework that describes the interaction between the agent and the environment. An MDP is characterized by a state space, action space, transition probabilities, and reward functions. The agent aims to learn a policy that maximizes the expected return, which is the sum of discounted rewards over time. 

 We then delved into policy optimization methods, focusing on policy gradients, a popular approach in model-free RL. Policy gradient methods reduce RL to a problem of stochastic gradient descent, leveraging trajectories of state-action pairs to update the policy parameters. We discussed techniques such as the reward-to-go and baselines to reduce the variance of gradient estimators, thus improving learning efficiency. 

 In conclusion, reinforcement learning offers a powerful framework for designing intelligent agents capable of learning optimal behaviors through interaction with their environment. By understanding and implementing the principles and techniques covered in this chapter, one can develop sophisticated RL agents for a wide range of applications. 
 -->
]]></content>
  </entry>
  <entry>
    <title>A Deep Learning Workflow Part 1, Hugging Face datasets + Weights &amp; Biases</title>
    <link href="https://alkzar.cl/posts/a-deep-learning-project-workflow-part-1/"/>
    <id>https://alkzar.cl/posts/a-deep-learning-project-workflow-part-1/</id>
    <published>2023-02-22T00:00:00Z</published>
    <updated>2023-02-22T00:00:00Z</updated>
    <content type="html"><![CDATA[<p>Last update: 22/02/2023</p>
<figure>
<img src="/img/deep-learning-workflow/post-banner.png"
alt="wandb.ai/alcazar90/cell-segmentation W&B project runs.summary picture">
</figure> 
<a href="https://colab.research.google.com/drive/1tCN__7HxJ61WFUm14kr6ziloNaOH9Def?usp=sharing" target="_blank">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
<p><i>This post was highlighted by the Weights &amp; Biases community and published in their <a href="https://wandb.ai/fully-connected" target="_blank">Fully Connected blog</a>. You can read the interactive version <a href="https://wandb.ai/alcazar90/cell-segmentation/reports/A-Deep-Learning-Project-Workflow-Part-1--VmlldzozNjE4NzYy" target="_blank">here.</a></i></p>
<blockquote>
<p><em>tl;dr Colabs are powerful, but they make experimentation difficult. In this article, we explore how to change your workflow with Hugging Face and Weights &amp; Biases.</em></p>
</blockquote>
<p>Over the years, I've used many Google Colab notebooks. They're great for experimentation and sharing your deep learning projects with others and save you the hassle of needing to set up a Python environment. That makes it much easier to open a notebook and start to figure out what's actually inside or to jump directly into working on your problem.</p>
<p>Past that, if you’re working on a deep learning problem, you'll undoubtedly require GPUs. Thankfully, Colab provides you with a free usage quota. Exporting into a standard Jupyter notebook is trivial if you want to start a GitHub repository and move the notebook there.</p>
<p>In a nutshell, Colab is an excellent tool to prototype small to mid-size projects and create tutorials and interactive code to share with your community.</p>
<p>Still, they aren't perfect. Here are the two major friction points for Colabs as I see them:</p>
<ol>
<li>Experimenting with your custom dataset challenges reproducibility and makes collaboration harder.</li>
<li>Training or fine-tuning a model involves running the Colab multiple times and changing hyperparameters many times. Things quickly get messy. Tracking your experiments in Colab is suboptimal and basically nets out to you using a knife as a fork.</li>
</ol>
<p>Thankfully, there are two open-source tools that I've started to use to alleviate both problems: <a href="https://huggingface.co/docs/datasets/index" target="_blank">Hugging Face Datasets</a>
and <a href="https://wandb.ai/site" target="_blank">Weights &amp; Biases</a>. If you'd like to follow along with this article as a Colab, please follow the link above!</p>
<p>In this post, I'll discuss how these tools allow you to transition from a project notebook approach into a more mature deep learning repository with the respective Python modules and a command line interface for running your experiments wherever you want.</p>
<h3>Example Project: Fine-Tuning an Image Segmentation Model</h3>
<p>Our project today involves fine-tuning an image segmentation model (<a href="https://arxiv.org/abs/2105.15203" target="_blank">SegFormer</a>) with cellular images from a high-throughput microscope 🔬.</p>
<p>The idea is to train a model using cellular photography with mask labels that denote the living cells. A good model can connect directly to the microscope and help scientists quickly detect cells affected by a given treatment.</p>
<center>
<figure>
<img src="/img/deep-learning-workflow/one_cell.png"
alt="One observation from the alkzar90/cell_benchmark dataset">
</figure> 
<figcaption>
Figure 1: One observation from the dataset. At the left is the input image, and at the right is the mask for labelling the transformed cells.
</figcaption>
</center>
<br>
<p>Since we are talking about reproducibility and collaboration, you can follow along with the <a href="https://colab.research.google.com/drive/1tCN__7HxJ61WFUm14kr6ziloNaOH9Def?usp=sharing" target="_blank">Google Colab notebook</a> (the same one linked in the banner at the start of the post). The notebook has three sections:</p>
<ol>
<li><strong>Image Segmentation Walkthrough with SegFormer 📸:</strong> How to use the model for this domain-specific task and how the dataset feeds into it.</li>
<li><strong>Training + Weights &amp; Biases experiment tracking 🪄 + 🐝:</strong> Training, hyperparameter optimization, and experiment tracking with Weights &amp; Biases (W&amp;B).</li>
<li><strong>Training script via command line 🚀:</strong> A section to experiment with different model configurations using a training script.</li>
</ol>
<p>These three sections broadly follow the development of the project.
First, we'll learn how to use the model and how to prepare the dataset to feed it. Next, we'll work on training: what we want to track and record, and the hyperparameter configurations (such as the learning rate and batch size).</p>
<p>These two initial sections work as an internal exploration and as documentation for anyone who wants to understand the project. The third and final step is an engineering effort to make life easier and get the job done.</p>
<h3>Hugging Face Datasets</h3>
<p>If you've dabbled in machine learning, chances are you've worked with the classic MNIST dataset in PyTorch. <code>torchvision</code> makes this easy. The single call below downloads MNIST (only if it isn't already in <code>root</code>) and loads it, so you don't have to write boilerplate code or fetch the data by hand every time you want to experiment. You can instead focus on learning new models and concepts.</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta statement import python"><span class="keyword control import python">import</span> <span class="meta qualified-name python"><span class="meta generic-name python">torchvision</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">datasets</span></span> <span class="keyword control import as python">as</span> <span class="meta qualified-name python"><span class="meta generic-name python">datasets</span></span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">mnist_train</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">datasets</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">MNIST</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="variable parameter python">root</span><span class="keyword operator assignment python">=</span><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">./data<span class="punctuation definition string end python">&#39;</span></span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">train</span><span class="keyword operator assignment python">=</span><span class="constant language python">True</span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">download</span><span class="keyword operator assignment python">=</span><span class="constant language python">True</span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">transform</span><span class="keyword operator assignment python">=</span><span class="constant language python">None</span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre>
<p>In the same spirit of pulling MNIST from <code>torchvision</code>, we want the data for this project in a central repository, ready to consume and easy to share. We get all three features for free with a Hugging Face dataset repository.</p>
<figure>
  <img src="/img/deep-learning-workflow/hf-cell-dataset.png" alt="alkzar90/cell_benchmark Hugging Face dataset picture">
    <center><figcaption>Figure 2: <a href="https://huggingface.co/datasets/alkzar90/cell_benchmark" target="_blank">alkzar90/cell_benchmark</a> Hugging Face dataset repository</figcaption></center>
</figure>
<br>
<p>Figure 2 shows the dataset repository for the example project, a page with a preview and documentation of the data. There is also a <a href="https://huggingface.co/datasets/alkzar90/cell_benchmark/tree/main" target="_blank">repository for storing images, text, or any other data you have</a>, and the documentation doesn't report a storage size limit. For reference, I created a repository for the <a href="https://huggingface.co/datasets/alkzar90/NIH-Chest-X-ray-dataset" target="_blank">NIH Chest X-ray dataset</a> (&gt; 40 GB) without problems, and there are datasets with terabytes of data.</p>
<p>The payoff of investing time in a repository for your dataset is that you end up with a Python module for downloading and loading the data, in the same fashion that <a href="https://pytorch.org/vision/stable/datasets.html" target="_blank"><code>torchvision.datasets</code></a> provides MNIST and other benchmark datasets.</p>
<p>Below you can see how we load the dataset for the cell segmentation example project:</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta statement import python"><span class="keyword control import from python">from</span></span><span class="meta statement import python"><span class="meta import-source python"> <span class="meta import-path python"><span class="meta import-name python">datasets</span></span> <span class="meta statement import python"><span class="keyword control import python">import</span></span></span></span><span class="meta statement import python"></span><span class="meta statement import python"> <span class="meta generic-name python">load_dataset</span></span>

<span class="meta qualified-name python"><span class="meta generic-name python">repo_name</span></span> <span class="keyword operator assignment python">=</span> <span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">alkzar90/cell_benchmark<span class="punctuation definition string end python">&quot;</span></span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">train_ds</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">load_dataset</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">repo_name</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">split</span><span class="keyword operator assignment python">=</span><span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">train<span class="punctuation definition string end python">&quot;</span></span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">streaming</span><span class="keyword operator assignment python">=</span><span class="constant language python">False</span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">val_ds</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">load_dataset</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">repo_name</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">split</span><span class="keyword operator assignment python">=</span><span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">validation<span class="punctuation definition string end python">&quot;</span></span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">streaming</span><span class="keyword operator assignment python">=</span><span class="constant language python">False</span></span><span class="punctuation section arguments end python">)</span></span> 
<span class="meta qualified-name python"><span class="meta generic-name python">test_ds</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">load_dataset</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">repo_name</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">split</span><span class="keyword operator assignment python">=</span><span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">test<span class="punctuation definition string end python">&quot;</span></span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">streaming</span><span class="keyword operator assignment python">=</span><span class="constant language python">False</span></span><span class="punctuation section arguments end python">)</span></span> 
</span></code></pre>
<p>How does Hugging Face know how to load the dataset (i.e., read the files and split the data)?</p>
<ol>
<li>The library supports <a href="https://huggingface.co/docs/datasets/how_to" target="_blank">common data types</a> out of the box, such as tabular data, images, text, and audio. You need to follow a template for organizing the stored dataset, similar to the <a href="https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html" target="_blank"><code>torchvision.datasets.ImageFolder</code></a> approach (i.e., data_split/label/image_001.jpg). This helps the Hugging Face Datasets module figure out how to load your data.</li>
<li>Sometimes you must accommodate multiple data types (<a href="https://huggingface.co/datasets/alkzar90/CC6204-Hackaton-Cub-Dataset" target="_blank">CUB 200 2011 dataset</a>), you might want to provide different configurations of the dataset for various tasks such as image classification and object detection (<a href="https://huggingface.co/datasets/alkzar90/rock-glacier-dataset" target="_blank">in the rock glacier dataset preview, notice the data subset option</a>), or your data type may lack default support in the library. In these cases, you must write a custom Python loading script (<a href="https://huggingface.co/datasets/alkzar90/cell_benchmark/blob/main/cell_benchmark.py" target="_blank">simple example</a> / <a href="https://huggingface.co/datasets/alkzar90/NIH-Chest-X-ray-dataset/blob/main/NIH-Chest-X-ray-dataset.py" target="_blank">complex example</a>). The script tells the module how to navigate the file structure and how to handle the data and labels.</li>
</ol>
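<p>The folder convention in the first point can be made concrete with a minimal, standard-library-only sketch. It shows how splits and labels can be inferred from a <code>data_split/label/image_001.jpg</code> layout. It only illustrates the idea and is not how the Hugging Face library is implemented. The helper <code>infer_examples</code> and the label names <code>class_a</code>/<code>class_b</code> are hypothetical.</p>

```python
from pathlib import Path
import tempfile


def infer_examples(root):
    """Walk a root/split/label/file layout and return {split: [(filename, label), ...]}.

    Hypothetical helper: the split comes from the first directory level under
    root and the label from the second, mirroring the ImageFolder-style convention.
    """
    examples = {}
    for path in sorted(Path(root).glob("*/*/*")):
        if path.is_file():
            split, label = path.parent.parent.name, path.parent.name
            examples.setdefault(split, []).append((path.name, label))
    return examples


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        # Build a tiny fake dataset made of empty placeholder files
        for rel in ["train/class_a/image_001.jpg",
                    "train/class_b/image_002.jpg",
                    "test/class_a/image_003.jpg"]:
            file = Path(root) / rel
            file.parent.mkdir(parents=True, exist_ok=True)
            file.touch()
        print(infer_examples(root))
```

<p>Running it groups <code>image_001.jpg</code> and <code>image_002.jpg</code> under <code>train</code> and <code>image_003.jpg</code> under <code>test</code>, each paired with its folder label. A consistent layout therefore carries enough information for a loader to work without extra code.</p>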
<p>The dataset repository works much like a GitHub repo. You can version code and data with commits, collaborate with others via pull requests, and document the dataset in a README. Under the hood, the library stores data in the <a href="https://huggingface.co/docs/datasets/about_arrow" target="_blank">Apache Arrow format</a>, which is memory-mapped from disk. Once a dataset is downloaded and cached, you can work with it even when it is too heavy to fit in memory, because rows are read on demand. Alternatively, streaming mode (<code>streaming=True</code>) lets you iterate over the data without downloading the whole dataset first.</p>
<h3>Experiment tracking with W&amp;B</h3>
<p>Weights &amp; Biases (W&amp;B or <code>wandb</code>) provides a free service for logging information about your project to a web server that you can monitor from a dashboard. It’s helpful to think of your W&amp;B project as a database with tools for interacting with your experiment information.</p>
<p>Once you have a W&amp;B account, you can create a project such as <a href="https://wandb.ai/alcazar90/cell-segmentation" target="_blank">alcazar90/cell-segmentation</a> to log the information from each experiment you run.</p>
<p>In section 2 of the Google Colab, sub-section “🦾 Run experiment”, you'll initialize a run with <code>wandb.init</code>, passing (i.) the name of the project and (ii.) a config dictionary that provides context about your experiment, such as the number of epochs and the batch size. You can also give your runs memorable names with the <code>name</code> argument. If you don’t, W&amp;B generates random ones such as resplendent-rocket-27 or abundant-moon-38 (yes, the number is the experiment number).</p>
<p>Your project will usually accumulate a lot of runs. Once you get a taste of the improvements you can make in how you log information, you'll find yourself with a ton of new ideas.</p>
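<pre><code class="code lang-python"><span class="source python"><span class="meta statement import python"><span class="keyword control import python">import</span> <span class="meta qualified-name python"><span class="meta generic-name python">wandb</span></span></span>

<span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> EPOCHS, BS, LR, GAMMA and SEED hold hyperparameter values set beforehand
</span><span class="meta qualified-name python"><span class="variable other constant python">PROJECT</span></span> <span class="keyword operator assignment python">=</span> <span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">cell-segmentation<span class="punctuation definition string end python">&#39;</span></span></span>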
<span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">wandb</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">init</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="variable parameter python">project</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="variable other constant python">PROJECT</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">config</span><span class="keyword operator assignment python">=</span><span class="meta structure dictionary-or-set python"><span class="punctuation section dictionary-or-set begin python">{</span><span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">epochs<span class="punctuation definition string end python">&quot;</span></span></span><span class="punctuation separator key-value python">:</span> <span class="meta qualified-name python"><span class="variable other constant python">EPOCHS</span></span><span class="punctuation separator dictionary-or-set python">,</span> <span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">batch_size<span class="punctuation definition string end python">&quot;</span></span></span><span class="punctuation separator key-value python">:</span> <span class="meta qualified-name python"><span class="meta generic-name python">BS</span></span><span class="punctuation separator dictionary-or-set python">,</span> <span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">lr<span class="punctuation definition string end python">&quot;</span></span></span><span class="punctuation separator key-value python">:</span> <span class="meta qualified-name python"><span class="meta generic-name python">LR</span></span><span class="punctuation separator dictionary-or-set python">,</span>
                                    <span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">lr_scheduler_exponential__gamma<span class="punctuation definition string end python">&quot;</span></span></span><span class="punctuation separator key-value python">:</span> <span class="meta qualified-name python"><span class="variable other constant python">GAMMA</span></span><span class="punctuation separator dictionary-or-set python">,</span> 
                                    <span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">seed<span class="punctuation definition string end python">&quot;</span></span></span><span class="punctuation separator key-value python">:</span> <span class="meta qualified-name python"><span class="variable other constant python">SEED</span></span><span class="punctuation section dictionary-or-set end python">}</span></span></span><span class="punctuation section arguments end python">)</span></span>

<span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> Add additional configs to wandb if needed
</span><span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">wandb</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">config</span></span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">len_train<span class="punctuation definition string end python">&#39;</span></span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">len</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">train_ds</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">wandb</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">config</span></span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">len_val<span class="punctuation definition string end python">&#39;</span></span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">len</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">val_ds</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre>
<p>For example, you can see the config information for the run <code>auspicious-paper-44</code> in the <a href="https://wandb.ai/alcazar90/cell-segmentation/runs/goqt0zp7/overview?workspace=user-alcazar90" target="_blank">Overview option in the left menu</a> of its dashboard. There is a table describing the context of this experiment (mostly hyperparameter settings in this case):</p>
<figure>
<img src="/img/deep-learning-workflow/w&b-config-screenshoot.png"
alt="wandb.ai/alcazar90/cell-segmentation W&B run config table picture">
    <center><figcaption>Figure 3: <a href="https://wandb.ai/alcazar90/cell-segmentation/runs/goqt0zp7?workspace=user-alcazar90" target="_blank">alcazar90/cell-segmentation</a> W&B project</figcaption></center>
</figure>
<br>
<p>After initializing a run and logging the config, we want to log information during model training. Typically, we track the main metrics on the training and validation sets. These are floating-point values over time that we log with <code>wandb.log</code>.</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta statement for python"><span class="keyword control flow for python">for</span> <span class="meta generic-name python">epoch</span> <span class="meta statement for python"><span class="keyword control flow for in python">in</span></span></span><span class="meta statement for python"> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">tqdm</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">range</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="variable other constant python">EPOCHS</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation section block for python">:</span></span>
  <span class="constant language python">...</span>
  <span class="meta qualified-name python"><span class="meta generic-name python">metrics</span></span> <span class="keyword operator assignment python">=</span> <span class="meta structure dictionary-or-set python"><span class="punctuation section dictionary-or-set begin python">{</span><span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">train/train_loss<span class="punctuation definition string end python">&quot;</span></span></span><span class="punctuation separator key-value python">:</span> <span class="meta qualified-name python"><span class="meta generic-name python">train_loss</span></span><span class="punctuation separator dictionary-or-set python">,</span> 
               <span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">train/epoch<span class="punctuation definition string end python">&quot;</span></span></span><span class="punctuation separator key-value python">:</span> <span class="meta group python"><span class="punctuation section group begin python">(</span><span class="meta qualified-name python"><span class="meta generic-name python">step</span></span> <span class="keyword operator arithmetic python">+</span> <span class="constant numeric integer decimal python">1</span> <span class="keyword operator arithmetic python">+</span> <span class="meta group python"><span class="punctuation section group begin python">(</span><span class="meta qualified-name python"><span class="meta generic-name python">n_steps_per_epoch</span></span> <span class="keyword operator arithmetic python">*</span> <span class="meta qualified-name python"><span class="meta generic-name python">epoch</span></span><span class="punctuation section group end python">)</span></span><span class="punctuation section group end python">)</span></span> <span class="keyword operator arithmetic python">/</span> <span class="meta qualified-name python"><span class="meta generic-name python">n_steps_per_epoch</span></span><span class="punctuation separator dictionary-or-set python">,</span> 
               <span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">train/example_ct<span class="punctuation definition string end python">&quot;</span></span></span><span class="punctuation separator key-value python">:</span> <span class="meta qualified-name python"><span class="meta generic-name python">example_ct</span></span><span class="punctuation separator dictionary-or-set python">,</span>
               <span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">train/cur_learning_rate<span class="punctuation definition string end python">&quot;</span></span></span><span class="punctuation separator key-value python">:</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">scheduler</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">get_last_lr</span></span><span class="punctuation section arguments begin python">(</span><span class="punctuation section arguments end python">)</span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="constant numeric integer decimal python">0</span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span><span class="punctuation section dictionary-or-set end python">}</span></span>
  <span class="constant language python">...</span>
  <span class="meta qualified-name python"><span class="meta generic-name python">val_metrics</span></span> <span class="keyword operator assignment python">=</span> <span class="meta structure dictionary-or-set python"><span class="punctuation section dictionary-or-set begin python">{</span><span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">val/val_loss<span class="punctuation definition string end python">&quot;</span></span></span><span class="punctuation separator key-value python">:</span> <span class="meta qualified-name python"><span class="meta generic-name python">val_loss</span></span><span class="punctuation separator dictionary-or-set python">,</span> 
                 <span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">val/val_accuracy<span class="punctuation definition string end python">&quot;</span></span></span><span class="punctuation separator key-value python">:</span> <span class="meta qualified-name python"><span class="meta generic-name python">accuracy</span></span><span class="punctuation separator dictionary-or-set python">,</span>
                 <span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">val/mIoU<span class="punctuation definition string end python">&quot;</span></span></span><span class="punctuation separator key-value python">:</span> <span class="meta qualified-name python"><span class="meta generic-name python">mIoU</span></span><span class="punctuation section dictionary-or-set end python">}</span></span>
  <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">wandb</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">log</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure dictionary-or-set python"><span class="punctuation section dictionary-or-set begin python">{</span><span class="keyword operator arithmetic python">*</span><span class="keyword operator arithmetic python">*</span><span class="meta qualified-name python"><span class="meta generic-name python">metrics</span></span><span class="punctuation separator dictionary-or-set python">,</span> <span class="keyword operator unpacking mapping python">**</span><span class="meta qualified-name python"><span class="meta generic-name python">val_metrics</span></span><span class="punctuation section dictionary-or-set end python">}</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre>
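<p>The <code>train/epoch</code> value logged in the loop above is a fractional epoch counter, so the x-axis of the charts advances smoothly within each epoch. The small sketch below reproduces that arithmetic. It assumes <code>step</code> is the 0-based batch index inside the current epoch (the inner loop is elided above) and <code>epoch</code> is the 0-based index from <code>range(EPOCHS)</code>. The function name <code>fractional_epoch</code> is hypothetical.</p>

```python
def fractional_epoch(step, epoch, n_steps_per_epoch):
    """Return the epoch progress logged as "train/epoch".

    Assumes `step` is the 0-based batch index within the current epoch and
    `epoch` is the 0-based epoch index, as in the training loop.
    """
    return (step + 1 + n_steps_per_epoch * epoch) / n_steps_per_epoch


# With 4 batches per epoch, the 2nd batch of the 2nd epoch sits at 1.5 epochs,
# and the last batch of the 2nd epoch lands exactly on 2.0.
print(fractional_epoch(step=1, epoch=1, n_steps_per_epoch=4))  # 1.5
print(fractional_epoch(step=3, epoch=1, n_steps_per_epoch=4))  # 2.0
```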
<p>W&amp;B knows how to display these metrics, so it makes charts for you automatically in the run’s dashboard.</p>
<figure>
<img src="/img/deep-learning-workflow/w&b-charts.png"
alt="wandb.ai/alcazar90/cell-segmentation W&B run metric charts picture">
    <center><figcaption>Figure 4: <a href="https://wandb.ai/alcazar90/cell-segmentation/runs/goqt0zp7?workspace=user-alcazar90" target="_blank">alcazar90/cell-segmentation</a> W&B project</figcaption></center>
</figure>
<br>
<p>Beyond the obvious things to log (like training and validation loss), you can log whatever is useful for your specific project. In Figure 5, I log information from the dev set into a
<a href="https://wandb.ai/stacey/mnist-viz/reports/Guide-to-W-B-Tables--Vmlldzo2NTAzOT" target="_blank">wandb.Table</a>, which includes:</p>
<ul>
<li>The actual image and mask</li>
<li>The predicted mask</li>
<li>The probability map (it’s so cool)</li>
<li>The <strong>intersection over union</strong> (IoU) metric for each individual example</li>
</ul>
<figure>
<img src="/img/deep-learning-workflow/w&b-project-summary-table.png"
alt="wandb.ai/alcazar90/cell-segmentation W&B project runs.summary picture">
    <center><figcaption>Figure 5: <a href="https://wandb.ai/alcazar90/cell-segmentation/runs/goqt0zp7?workspace=user-alcazar90" target="_blank">alcazar90/cell-segmentation</a> W&B project</figcaption></center>
</figure>
<br>
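```python
import numpy as np
import matplotlib.pyplot as plt
import wandb

# `images`, `masks`, `predicted`, `probs` and `iou_by_example` are the dev-set
# tensors computed in the validation step.
# 🐝 Create a wandb Table to log images, labels and predictions to
table = wandb.Table(columns=["image", "mask", "pred_mask", "probs", "iou"])
for img, mask, pred, prob, iou_metric in zip(images.to("cpu"), masks.to("cpu"), predicted.to("cpu"), probs.to("cpu"), iou_by_example.to("cpu")):
    # Render the probability map in its own matplotlib figure
    fig, ax = plt.subplots()
    ax.imshow(prob.detach().numpy())
    ax.axis("off")
    fig.tight_layout()
    table.add_data(
        wandb.Image(img.permute(1, 2, 0).numpy()),                   # CHW to HWC
        wandb.Image(mask.view(img.shape[1:]).unsqueeze(2).numpy()),  # (H, W, 1)
        wandb.Image(np.uint8(pred.unsqueeze(2).numpy()) * 255),      # 0/1 mask to 0/255
        wandb.Image(fig),                                            # probability map
        iou_metric.item())
    plt.close(fig)  # free the figure so open figures don't pile up

wandb.log({"dev_predictions": table})  # the key name here is illustrative
```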
<p>Notice in the code that <code>wandb.Table</code> has image columns, which we fill using <code>wandb.Image</code>. Here <code>wandb.Image</code> receives numpy arrays, but it also accepts plots created with matplotlib, as in the probability column. This gives us tables with rendered images as values that we can inspect quickly, a convenient complement to traditional metrics. For generative image models in particular, checking the pictures the model generates during training gives you more information about the model than tracking the loss.</p>
<p>Finally, on your project’s main page, the “Table” option in the left menu gives you a bird’s-eye view of all runs and their metrics for comparison. You can export this information to a CSV file or download it through the API for further analysis.</p>
<figure>
<img src="/img/deep-learning-workflow/w&b-runs-table.png"
alt="wandb.ai/alcazar90/cell-segmentation W&B project, summary table">
    <center><figcaption>Figure 6: the <a href="https://wandb.ai/alcazar90/cell-segmentation/table?workspace=user-alcazar90" target="_blank">summary table</a> of every experiment run in the alcazar90/cell-segmentation project</figcaption></center>
</figure>
<br>
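<p>The CSV export mentioned above can also be scripted with the W&amp;B public API. This is a minimal sketch that assumes you are logged in to W&amp;B. <code>wandb.Api().runs(...)</code> and <code>run.summary._json_dict</code> follow the run-export example in the W&amp;B docs. The helpers <code>runs_to_rows</code>, <code>write_csv</code> and <code>export_project_runs</code> are hypothetical names for this sketch.</p>

```python
import csv


def runs_to_rows(runs):
    """Flatten W&B runs into one dict per run: name, config and summary metrics."""
    rows = []
    for run in runs:
        row = {"name": run.name}
        # Skip keys starting with "_": W&B uses them for internal bookkeeping (e.g. "_step").
        row.update({k: v for k, v in run.config.items() if not k.startswith("_")})
        # Summary metrics are added last, so they win if a key also appears in the config.
        row.update({k: v for k, v in run.summary._json_dict.items() if not k.startswith("_")})
        rows.append(row)
    return rows


def write_csv(rows, path):
    """Write rows to CSV; columns are "name" followed by the sorted union of all other keys."""
    other = sorted({key for row in rows for key in row} - {"name"})
    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["name", *other])
        writer.writeheader()
        writer.writerows(rows)  # keys missing from a row are written as empty cells


def export_project_runs(project_path, csv_path):
    """Fetch every run of an "entity/project" W&B project and save it as a CSV file."""
    import wandb  # imported lazily: requires `pip install wandb` and a logged-in API key

    api = wandb.Api()
    write_csv(runs_to_rows(api.runs(project_path)), csv_path)
```

<p>For example, <code>export_project_runs("alcazar90/cell-segmentation", "runs.csv")</code> would write one row per run. The flattening step is kept separate from the API call so that it can be checked without network access.</p>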
<p><em><strong>Note 1:</strong> If you are not logged in yet, initializing a run (<code>wandb.init</code>) will prompt you for your API key for authentication; you can find it at W&amp;B &gt; Settings &gt; API keys.</em></p>
<p><em><strong>Note 2:</strong> W&amp;B offers a short course called <a href="https://www.wandb.courses/courses/effective-mlops-model-development">"Effective MLOps: Model Development"</a> for learning the fundamentals.</em></p>
<h3>Training Script Via Command Line 🚀</h3>
<p>In the last section, we saw how to integrate W&amp;B to log information about our model training. Still, fine-tuning a model or training from scratch requires a lot of experimentation. The idea is to iterate and try many configurations.</p>
<p>And sure, you can do this in the notebook, but doing it that way is redundant and inefficient. For example, every time you change the batch size of your data loaders, you have to re-run all the code cells that come after it. The next step is to wrap all the code required to train your model into a <strong>training script</strong>: downloading the dataset, creating the data loaders, importing utility functions, and setting hyperparameters and training configurations.</p>
<p>By wrapping the code into a training script and using the
<a href="https://docs.python.org/3/library/argparse.html"><code>argparse</code></a> module, you can call the training script directly from the command line:</p>
<pre><code class="code lang-python"><span class="source python">!<span class="meta qualified-name python"><span class="meta generic-name python">python</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">finetune_model</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">py</span></span> <span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="meta qualified-name python"><span class="meta generic-name python">train_batch_size</span></span> <span class="constant numeric integer decimal python">4</span> <span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="meta qualified-name python"><span class="meta generic-name python">validation_batch_size</span></span> <span class="constant numeric integer decimal python">3</span><span class="punctuation separator continuation line python">\</span>
     <span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="meta qualified-name python"><span class="meta generic-name python">init_learning_rate</span></span> <span class="constant numeric float python">3e-4</span> <span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="meta qualified-name python"><span class="meta generic-name python">learning_rate_scheduler_gamma</span></span> <span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>92</span><span class="punctuation separator continuation line python">\</span>
     <span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="meta qualified-name python"><span class="meta generic-name python">num_train_epochs</span></span> <span class="constant numeric integer decimal python">15</span> <span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="meta qualified-name python"><span class="meta generic-name python">reproducibility_seed</span></span> <span class="constant numeric integer decimal python">42313988</span><span class="punctuation separator continuation line python">\</span>
     <span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="meta qualified-name python"><span class="meta generic-name python">log_images_in_validation</span></span> <span class="constant language python">True</span> <span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="meta qualified-name python"><span class="meta generic-name python">dataloader_num_workers</span></span> <span class="constant numeric integer decimal python">0</span><span class="punctuation separator continuation line python">\</span>
     <span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="meta qualified-name python"><span class="meta generic-name python">model_name</span></span> <span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">huggingface-segformer-nvidia-mit-b0<span class="punctuation definition string end python">&quot;</span></span></span><span class="punctuation separator continuation line python">\</span>
     <span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="meta qualified-name python"><span class="meta generic-name python">project_name</span></span> <span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">cell-segmentation<span class="punctuation definition string end python">&quot;</span></span></span>
</span></code></pre>
<p>You can see the <a href="https://github.com/alcazar90/cell-segmentation/blob/main/finetune_model.py" target="_blank">training script here</a>. The heavy lifting is done by the argparse module, which lets you declare the parameters the script accepts and parse them when you run it from the command line. The idea is as follows:</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta statement import python"><span class="keyword control import python">import</span> <span class="meta qualified-name python"><span class="meta generic-name python">argparse</span></span></span> 

<span class="meta function python"><span class="storage type function python">def</span> <span class="entity name function python"><span class="meta generic-name python">parse_args</span></span></span><span class="meta function parameters python"><span class="punctuation section parameters begin python">(</span></span><span class="meta function parameters python"><span class="variable parameter python">input_args</span></span><span class="meta function parameters default-value python"><span class="keyword operator assignment python">=</span><span class="constant language python">None</span></span><span class="meta function parameters python"><span class="punctuation section parameters end python">)</span></span><span class="meta function python"><span class="punctuation section function begin python">:</span></span>
  <span class="meta qualified-name python"><span class="meta generic-name python">parser</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">argparse</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">ArgumentParser</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="variable parameter python">description</span><span class="keyword operator assignment python">=</span><span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">Training loop script for fine-tuning a pretrained SegFormer model.<span class="punctuation definition string end python">&quot;</span></span></span></span><span class="punctuation section arguments end python">)</span></span>
  
  <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">parser</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">add_argument</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python">
    <span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">--train_batch_size<span class="punctuation definition string end python">&quot;</span></span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">type</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="support type python">int</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">default</span><span class="keyword operator assignment python">=</span><span class="constant numeric integer decimal python">4</span><span class="punctuation separator arguments python">,</span>
    <span class="variable parameter python">help</span><span class="keyword operator assignment python">=</span><span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">Batch size (per device) for the training dataloader.<span class="punctuation definition string end python">&quot;</span></span></span>
  </span><span class="punctuation section arguments end python">)</span></span>
  <span class="constant language python">...</span>
</span></code></pre>
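<p>The snippet above elides the remaining arguments. Here is a minimal, self-contained sketch of the same pattern with a few of the flags from the command above. Apart from <code>train_batch_size=4</code>, the defaults are placeholders rather than the values in the original script. <code>str2bool</code> is a helper introduced here. It avoids a common argparse pitfall: with <code>type=bool</code>, any non-empty string, including <code>"False"</code>, is parsed as <code>True</code>.</p>

```python
import argparse


def str2bool(value: str) -> bool:
    """Convert "True"/"False"-style strings to bool (type=bool would treat "False" as True)."""
    lowered = value.lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(input_args=None):
    parser = argparse.ArgumentParser(
        description="Training loop script for fine-tuning a pretrained SegFormer model."
    )
    parser.add_argument("--train_batch_size", type=int, default=4,
                        help="Batch size (per device) for the training dataloader.")
    # The defaults below are placeholders for this sketch, not the original script's values.
    parser.add_argument("--validation_batch_size", type=int, default=4,
                        help="Batch size (per device) for the validation dataloader.")
    parser.add_argument("--init_learning_rate", type=float, default=3e-4,
                        help="Initial learning rate.")
    parser.add_argument("--num_train_epochs", type=int, default=10,
                        help="Number of training epochs.")
    parser.add_argument("--log_images_in_validation", type=str2bool, default=False,
                        help="Whether to log validation images to W&B.")
    parser.add_argument("--project_name", type=str, default="cell-segmentation",
                        help="W&B project name.")
    # input_args=None makes argparse read sys.argv[1:]; an explicit list works from a notebook.
    return parser.parse_args(input_args)


args = parse_args(["--train_batch_size", "8", "--log_images_in_validation", "False"])
print(args.train_batch_size, args.log_images_in_validation)  # 8 False
```

<p>With no argument, <code>parse_args()</code> reads <code>sys.argv</code> as usual, so the same function serves the command line. Passing an explicit list makes it easy to try the parser from a notebook cell.</p>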
<p>So far, we have developed the entire project in Google Colab, created a HuggingFace dataset repository, and integrated W&amp;B to log the model training information. But the project doesn’t have a home on GitHub yet.</p>
<p>When is it actually necessary to create a code repository for the project? It depends. For example, if creating the dataset requires pre-processing scripts and tests, it makes sense to keep all those files in a GitHub repository. Either way, once we have a training script, creating a GitHub repository for the script and its dependencies is a good decision: we want the training script to be accessible, even to ourselves. In the last section of the Google Colab notebook, I download the training script from the GitHub repo and run it on the compute provided by Colab.</p>
<pre><code class="code lang-python"><span class="source python">!<span class="meta qualified-name python"><span class="meta generic-name python">wget</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">https</span></span><span class="punctuation separator annotation variable python">:</span><span class="keyword operator arithmetic python">/</span><span class="keyword operator arithmetic python">/</span><span class="meta qualified-name python"><span class="meta generic-name python">raw</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">githubusercontent</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">com</span></span><span class="keyword operator arithmetic python">/</span><span class="meta qualified-name python"><span class="meta generic-name python">alcazar90</span></span><span class="keyword operator arithmetic python">/</span><span class="meta qualified-name python"><span class="meta generic-name python">cell</span></span><span class="keyword operator arithmetic python">-</span><span class="meta qualified-name python"><span class="meta generic-name python">segmentation</span></span><span class="keyword operator arithmetic python">/</span><span class="meta qualified-name python"><span class="meta generic-name python">main</span></span><span class="keyword operator arithmetic python">/</span><span class="meta qualified-name python"><span class="meta generic-name python">finetune_model</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">py</span></span>
</span></code></pre>
<p>Two more things worth doing with your repo: upload a Jupyter notebook version of the Colab and write a good README to provide context.</p>
<h3>Next Steps 🦶🏼</h3>
<p>If you (or I) wanted to continue the project, some next steps to consider:</p>
<ol>
<li>Create a bash script that runs the training script once for each configuration listed in a <code>.txt</code> file.</li>
<li>Split the training script (<code>finetune_model.py</code>) into separate files, such as <code>model.py</code> for downloading and loading models and <code>inference.py</code> for computing evaluation metrics.</li>
<li>Use a cloud provider such as <a href="https://lambdalabs.com/service/gpu-cloud" target="_blank">Lambda Cloud</a>, connect via SSH, and run the training script there. Check that the results are saved to W&amp;B.</li>
<li>Explore how to use GitHub Actions and Hugging Face webhooks to push model checkpoints to the Hugging Face Hub every time the training script outperforms the current best model. Check out the <a href="https://github.com/nateraw/huggingface-sync-action" target="_blank">Hugging Face sync action</a> and <a href="https://huggingface.co/docs/hub/webhooks" target="_blank">HuggingFace Webhooks</a>.</li>
<li>Use popular deep learning code repositories as case studies, such as
<a href="https://github.com/openai/whisper" target="_blank">OpenAI's Whisper</a> and <a href="https://github.com/karpathy/nanoGPT" target="_blank">nanoGPT</a>.</li>
</ol>
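<p>For the first item, here is a sketch written in Python with <code>subprocess</code> instead of bash, so that this post's examples stay in one language. It assumes a hypothetical <code>configs.txt</code> with one set of command-line flags per line. It then launches <code>finetune_model.py</code> once per line. <code>read_configs</code> and <code>run_all</code> are illustrative names.</p>

```python
import shlex
import subprocess
import sys

# Hypothetical configs.txt format: one set of command-line flags per line, e.g.
#   --train_batch_size 4 --init_learning_rate 3e-4 --num_train_epochs 15
#   --train_batch_size 8 --init_learning_rate 1e-4 --num_train_epochs 15
# Blank lines and lines starting with "#" are ignored.


def read_configs(path):
    """Return one list of command-line arguments per configuration line."""
    configs = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line and not line.startswith("#"):
                configs.append(shlex.split(line))  # respects quoted values
    return configs


def run_all(path, script="finetune_model.py", dry_run=False):
    """Launch the training script once per configuration, one after another."""
    commands = [[sys.executable, script, *args] for args in read_configs(path)]
    for cmd in commands:
        print("Running:", shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # raises CalledProcessError if a run fails
    return commands
```

<p>Calling <code>run_all("configs.txt", dry_run=True)</code> only prints the commands. This is a cheap way to check the file before launching real training runs.</p>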
<h3>Conclusion</h3>
<p>Using W&amp;B and Hugging Face makes projects like the one above a lot easier to manage, reproduce, and understand. Having the code ready in a Colab gives us GPU access and makes running discrete steps a breeze.</p>
<p>I hope this piece helps you as you consider how best to experiment in your project. If you have any questions, feel free to drop them in the comments below. Thanks!</p>
]]></content>
  </entry>
  <entry>
    <title>Ukiyo-e style postcard generator App</title>
    <link href="https://alkzar.cl/posts/ukiyo-e-style-postal-generator-app/"/>
    <id>https://alkzar.cl/posts/ukiyo-e-style-postal-generator-app/</id>
    <published>2023-02-02T00:00:00Z</published>
    <updated>2023-02-02T00:00:00Z</updated>
    <content type="html"><![CDATA[<p><img src="https://collectionapi.metmuseum.org/api/collection/v1/iiif/55735/140194/main-image" alt="Source: Ejiri in Suruga Province (Sunshū Ejiri), from the series Thirty-six Views of Mount Fuji (Fugaku sanjūrokkei)" /></p>
<a href="https://colab.research.google.com/drive/1F7SH4T9y5fJKxj5lU9HqTzadv836Zj_G?usp=sharing" target="_blank">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
<p>How can we train a model to generate novel images while preserving a given style? What is a generative model? What is a diffusion model? Is there a way to control the image generation process (aka sampling)? In this short post, I will answer these questions at a high level through a mini project, the <a href="https://huggingface.co/spaces/alkzar90/ukiyo-e-postal" target="_blank">Ukiyo-e style postcard generator App</a>.</p>
<p>The project, in a nutshell, consists of the following:</p>
<ol>
<li>Take a bunch of images in the Ukiyo-e style</li>
<li>Train (fine-tune, really) a diffusion model to learn the distribution of our images</li>
<li>Use the learned "unconditional distribution" to generate novel images in the ukiyo-e style</li>
<li>Explore ways to gain control when we sample new pictures from our distribution</li>
<li>Wrap the whole inference pipeline: model (distribution) + image generation (sampling) into a Gradio App</li>
</ol>
<h2>What is Ukiyo-e style?</h2>
<p>Pictures of the floating world: that's what the Japanese word Ukiyo-e means. It's a term for an entire art genre, so don't mistake it for an artist's name or a pseudonym. Its roots are popular, even vulgar, and accessible, so Ukiyo-e prints were easy to find in Japanese homes. The genre took influence from the West and China and, at some point, went on to influence Europe and the Western world. If you want to know more about this Japanese art genre, I highly recommend the interactive piece from the New York Times titled <i><a href="https://www.nytimes.com/interactive/2020/08/07/arts/design/hokusai-fuji.html" target="_blank">"A Picture of Change for a World in Constant Motion" (Farago 2020)</a></i>. It's a journey through this popular genre's history and distinctive elements.</p>
<p>The dataset <a href="https://huggingface.co/datasets/huggan/ukiyoe2photo" target="_blank">ukiyoe2photo</a> contains pictures of the floating world, which I used in this project to train a model that generates new images in the ukiyo-e style. The figure below shows some of the ukiyo-e images in the dataset.</p>
<center>
<img src="/img/ukiyo-e-postalcard-app/ukiyoe-dataset-portrait-lightweight.png">
</center>
<br>
<h2>Unconditional image generation</h2>
<p><a href="https://huggingface.co/tasks/unconditional-image-generation" target="_blank">What is unconditional image generation?</a></p>
<blockquote>
<p><em>Unconditional image generation is the task of generating images with no condition in any context (like a prompt text or another image). Once trained, the model will create images that resemble its training data distribution.</em></p>
</blockquote>
<p>My goal here is to highlight the main elements involved in a
<a href="https://arxiv.org/pdf/2006.11239.pdf" target="_blank">Denoising Diffusion Probabilistic Model (Ho 2020)</a>, or DDPM for short, and its training dynamics. I will write a more detailed and technical post about this model in the future.</p>
<p>Let's start with the training dynamics: a diffusion model consists of two chained processes with the same number of steps, <span class="math inline"><math display="inline"><mi>T</mi></math></span>.</p>
<ol>
<li>A <strong>forward process</strong>, in which the input image is gradually destroyed by adding gaussian noise until the entire image structure is reduced to pure noise</li>
<li>A <strong>backward process</strong>: the inverse process, encoded by a distribution with learnable parameters whose goal is to predict the noise added at each transition step (i.e., denoising)</li>
</ol>
<center>
<figure>
  <img src="/img/ukiyo-e-postalcard-app/calvin-luo-2022-figure-5.png">
  <figcaption> Source: <a href="https://arxiv.org/abs/2208.11970" target="_blank">Understanding Diffusion Models: A Unified Perspective (Luo 2022)</a></figcaption>
</figure>
</center>
<br>
<p>There are important elements to take into account:</p>
<ul>
<li>We use gaussian distributions for both processes, yes, the lovely normal distribution <span class="math inline"><math display="inline"><mrow><mi>𝒩</mi></mrow><mo symmetric="false" stretchy="false">(</mo><mi>μ</mi><mo>,</mo><msup><mi>σ</mi><mn>2</mn></msup><mo symmetric="false" stretchy="false">)</mo></math></span></li>
<li>Both processes are Markovian: the distribution at a given step <span class="math inline"><math display="inline"><mi>t</mi></math></span> depends only on the immediately previous state, <span class="math inline"><math display="inline"><mi>t</mi><mo>−</mo><mn>1</mn></math></span> in the forward process or <span class="math inline"><math display="inline"><mi>t</mi><mo>+</mo><mn>1</mn></math></span> in the backward process</li>
<li>The gaussian distribution of the forward process (<span class="math inline"><math display="inline"><mi>q</mi><mo symmetric="false" stretchy="false">(</mo><mi>⋅</mi><mo symmetric="false" stretchy="false">)</mo></math></span> in the above diagram) has fixed parameters; in other words, we don't have to learn any parameter here</li>
<li>In contrast, the gaussian distribution for the backward process (<span class="math inline"><math display="inline"><msub><mi>p</mi><mrow><mi>θ</mi></mrow></msub><mo symmetric="false" stretchy="false">(</mo><mi>⋅</mi><mo symmetric="false" stretchy="false">)</mo></math></span> in the diagram) has learnable parameters <span class="math inline"><math display="inline"><mrow><mi>𝛉</mi></mrow></math></span></li>
<li>We use a noise schedule to destroy the images: a deterministic function that sets how much noise is injected at each step of the <span class="math inline"><math display="inline"><mi>T</mi></math></span>-length forward process</li>
<li>We use a neural network (e.g., a U-Net) to predict the parameters of the backward gaussian distribution <span class="math inline"><math display="inline"><msub><mi>p</mi><mrow><mrow><mi>𝛉</mi></mrow></mrow></msub></math></span></li>
<li>If we have a process of length <span class="math inline"><math display="inline"><mi>T</mi><mo>=</mo><mn>1000</mn></math></span> (as in Ho 2020), <span class="math inline"><math display="inline"><msub><mrow><mi>𝐱</mi></mrow><mrow><mn>0</mn></mrow></msub></math></span> is the input image, <span class="math inline"><math display="inline"><msub><mrow><mi>𝐱</mi></mrow><mrow><mn>1000</mn></mrow></msub></math></span> is pure noise at the same resolution as the input, and any intermediate <span class="math inline"><math display="inline"><msub><mrow><mi>𝐱</mi></mrow><mi>t</mi></msub></math></span> with <span class="math inline"><math display="inline"><mn>0</mn><mo>&lt;</mo><mi>t</mi><mo>&lt;</mo><mi>T</mi></math></span> is a latent state: a blend of the input image and noise, in proportions given by the noise scheduler</li>
<li>This means that every latent state has the same resolution as the input (a difference from typical variational autoencoders, whose latent space is usually lower-dimensional)</li>
<li>The advantage of using a noise scheduler plus a gaussian distribution with known parameters <span class="math inline"><math display="inline"><mi>q</mi><mo symmetric="false" stretchy="false">(</mo><mi>⋅</mi><mo symmetric="false" stretchy="false">)</mo></math></span> is that we have a closed-form expression to compute the latent state at any given level <span class="math inline"><math display="inline"><mi>t</mi></math></span> (we don't need to compute the entire chain from <span class="math inline"><math display="inline"><mn>0</mn></math></span> to <span class="math inline"><math display="inline"><mi>t</mi></math></span>).</li>
</ul>
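<p>The closed-form shortcut in the last bullet can be checked numerically. Below is a minimal NumPy sketch. It assumes the linear noise schedule from Ho 2020, with β growing linearly from 10⁻⁴ to 0.02 over T = 1000 steps. It also uses the usual DDPM notation α<sub>t</sub> = 1 − β<sub>t</sub> and ᾱ<sub>t</sub> = α<sub>1</sub>⋯α<sub>t</sub>. Each forward step samples x<sub>t</sub> ~ 𝒩(√α<sub>t</sub> x<sub>t−1</sub>, β<sub>t</sub> I), while the closed form samples x<sub>t</sub> ~ 𝒩(√ᾱ<sub>t</sub> x<sub>0</sub>, (1 − ᾱ<sub>t</sub>) I) directly. The names <code>forward_step</code> and <code>q_sample</code> are illustrative, not the API of any particular library. The "image" is a toy array of constant pixels.</p>

```python
import numpy as np

# Linear noise schedule from Ho et al. (2020): beta goes from 1e-4 to 0.02 over T = 1000 steps.
# Arrays are 0-indexed: index i holds the value for step i + 1.
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = np.cumprod(alphas)  # alpha_bars[i] = alphas[0] * ... * alphas[i]


def forward_step(x_prev, i, rng):
    """One Markov transition: x_{i+1} ~ N(sqrt(1 - betas[i]) * x_i, betas[i] * I)."""
    noise = rng.standard_normal(x_prev.shape)
    return np.sqrt(alphas[i]) * x_prev + np.sqrt(betas[i]) * noise


def q_sample(x0, i, noise):
    """Closed form: x_{i+1} ~ N(sqrt(alpha_bars[i]) * x_0, (1 - alpha_bars[i]) * I)."""
    return np.sqrt(alpha_bars[i]) * x0 + np.sqrt(1.0 - alpha_bars[i]) * noise


rng = np.random.default_rng(0)
x0 = np.ones(100_000)  # toy "image": 100k pixels, all equal to 1
k = 500                # number of noising steps

x_chain = x0.copy()
for i in range(k):     # walk the chain step by step: x_1, x_2, ..., x_k
    x_chain = forward_step(x_chain, i, rng)

x_direct = q_sample(x0, k - 1, rng.standard_normal(x0.shape))  # jump straight to x_k

print(f"mean: chain={x_chain.mean():.3f} direct={x_direct.mean():.3f} theory={np.sqrt(alpha_bars[k - 1]):.3f}")
print(f"var:  chain={x_chain.var():.3f} direct={x_direct.var():.3f} theory={1 - alpha_bars[k - 1]:.3f}")
```

<p>Both routes should print matching means and variances, up to sampling noise. This is why training can draw a random t for each image and jump straight to x<sub>t</sub> instead of simulating the whole chain.</p>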
<p>Back to the training dynamics: a single parameter-update cycle looks like the following:</p>
<ol>
<li>Get a batch of images <span class="math inline"><math display="inline"><mrow><mi mathvariant="normal">x</mi></mrow></math></span> from our dataset <code>(batch_size, channels, height, width)</code></li>
<li>Sample random gaussian noise <span class="math inline"><math display="inline"><mrow><mi mathvariant="normal">ϵ</mi></mrow></math></span> for each image <code>(batch_size, channels, height, width)</code></li>
<li>Sample a random vector of timesteps <span class="math inline"><math display="inline"><mrow><mi mathvariant="normal">t</mi></mrow></math></span>, which determines where in the forward process each image is <code>(batch_size,)</code> (think of it as extending parallel chains of unequal length)</li>
<li>Compute the latent state <span class="math inline"><math display="inline"><mrow><mi mathvariant="normal">z</mi></mrow></math></span>: each image in the batch uses its corresponding level <span class="math inline"><math display="inline"><mi>t</mi></math></span> from <span class="math inline"><math display="inline"><mrow><mi mathvariant="normal">t</mi></mrow></math></span> in the closed-form expression. The noise scheduler knows how to blend the image with the noise at the right level <code>(batch_size, channels, height, width)</code></li>
<li>Pass <span class="math inline"><math display="inline"><mrow><mi mathvariant="normal">z</mi></mrow></math></span> and <span class="math inline"><math display="inline"><mrow><mi mathvariant="normal">t</mi></mrow></math></span> to the model <span class="math inline"><math display="inline"><msub><mi>p</mi><mrow><mrow><mi>𝛉</mi></mrow></mrow></msub></math></span> to get the predicted noise <span class="math inline"><math display="inline"><msup><mi>ϵ</mi><mrow><mi>∗</mi></mrow></msup></math></span></li>
<li>Using the mean squared error as the loss function, we compare the actual noise (<span class="math inline"><math display="inline"><mrow><mi mathvariant="normal">ϵ</mi></mrow></math></span>) with the predicted noise (<span class="math inline"><math display="inline"><msup><mrow><mi mathvariant="normal">ϵ</mi></mrow><mrow><mi>∗</mi></mrow></msup></math></span>). Remember that we know the actual noise beforehand: we sampled it ourselves in step 2 to blend it with the image and create the latent state.</li>
<li>Backprop to compute the gradients</li>
<li>Update the parameters in the direction that minimizes the loss</li>
</ol>
<pre><code class="code lang-python"><span class="source python"><span class="meta qualified-name python"><span class="meta generic-name python">num_epochs</span></span> <span class="keyword operator assignment python">=</span> <span class="constant numeric integer decimal python">1</span>  <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> number of epochs
</span><span class="meta qualified-name python"><span class="meta generic-name python">lr</span></span> <span class="keyword operator assignment python">=</span> <span class="constant numeric float python">1e-5</span>  <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> learning rate
</span><span class="meta qualified-name python"><span class="meta generic-name python">grad_accumulation_steps</span></span> <span class="keyword operator assignment python">=</span> <span class="constant numeric integer decimal python">2</span>  <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> how many batches to accumulate the gradient before the update step
</span><span class="meta qualified-name python"><span class="meta generic-name python">optimizer</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">torch</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">optim</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">AdamW</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">image_pipe</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">unet</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">parameters</span></span><span class="punctuation section arguments begin python">(</span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">lr</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">lr</span></span></span><span class="punctuation section arguments end python">)</span></span>

<span class="meta statement for python"><span class="keyword control flow for python">for</span> <span class="meta generic-name python">epoch</span> <span class="meta statement for python"><span class="keyword control flow for in python">in</span></span></span><span class="meta statement for python"> <span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">range</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">num_epochs</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation section block for python">:</span></span>
    <span class="meta statement for python"><span class="keyword control flow for python">for</span> <span class="meta generic-name python">step</span><span class="punctuation separator target-list python">,</span> <span class="meta generic-name python">batch</span> <span class="meta statement for python"><span class="keyword control flow for in python">in</span></span></span><span class="meta statement for python"> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">tqdm</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">enumerate</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">train_dataloader</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">total</span><span class="keyword operator assignment python">=</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">len</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">train_dataloader</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation section block for python">:</span></span>
        <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> Get a batch of images
</span>        <span class="meta qualified-name python"><span class="meta generic-name python">clean_images</span></span> <span class="keyword operator assignment python">=</span> <span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">batch</span></span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">images<span class="punctuation definition string end python">&quot;</span></span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span><span class="meta function-call python"><span class="meta qualified-name python"><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">to</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">device</span></span></span><span class="punctuation section arguments end python">)</span></span>
        <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> Sample noise to add to the images
</span>        <span class="meta qualified-name python"><span class="meta generic-name python">noise</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">torch</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">randn</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">clean_images</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">shape</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="meta function-call python"><span class="meta qualified-name python"><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">to</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">clean_images</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">device</span></span></span><span class="punctuation section arguments end python">)</span></span>
        <span class="meta qualified-name python"><span class="meta generic-name python">bs</span></span> <span class="keyword operator assignment python">=</span> <span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">clean_images</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">shape</span></span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="constant numeric integer decimal python">0</span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span>

        <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> Sample a random timestep for each image
</span>        <span class="meta qualified-name python"><span class="meta generic-name python">timesteps</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">torch</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">randint</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python">
            <span class="constant numeric integer decimal python">0</span><span class="punctuation separator arguments python">,</span>
            <span class="meta qualified-name python"><span class="meta generic-name python">image_pipe</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">scheduler</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">num_train_timesteps</span></span><span class="punctuation separator arguments python">,</span>
            <span class="meta group python"><span class="punctuation section group begin python">(</span><span class="meta qualified-name python"><span class="meta generic-name python">bs</span></span><span class="punctuation separator tuple python">,</span><span class="punctuation section group end python">)</span></span><span class="punctuation separator arguments python">,</span>
            <span class="variable parameter python">device</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">clean_images</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">device</span></span><span class="punctuation separator arguments python">,</span>
        </span><span class="punctuation section arguments end python">)</span></span><span class="meta function-call python"><span class="meta qualified-name python"><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">long</span></span><span class="punctuation section arguments begin python">(</span><span class="punctuation section arguments end python">)</span></span>

        <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> Add noise to the clean images according to the noise magnitude at each timestep
</span>        <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> (this is the forward diffusion process)
</span>        <span class="meta qualified-name python"><span class="meta generic-name python">noisy_images</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">image_pipe</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">scheduler</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">add_noise</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">clean_images</span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">noise</span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">timesteps</span></span></span><span class="punctuation section arguments end python">)</span></span>

        <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> Get the model prediction for the noise
</span>        <span class="meta qualified-name python"><span class="meta generic-name python">noise_pred</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">image_pipe</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">unet</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">noisy_images</span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">timesteps</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">return_dict</span><span class="keyword operator assignment python">=</span><span class="constant language python">False</span></span><span class="punctuation section arguments end python">)</span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="constant numeric integer decimal python">0</span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span>

        <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> Compare the prediction with the actual noise:
</span>        <span class="meta qualified-name python"><span class="meta generic-name python">loss</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">F</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">mse_loss</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python">
            <span class="meta qualified-name python"><span class="meta generic-name python">noise_pred</span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">noise</span></span>
        </span><span class="punctuation section arguments end python">)</span></span>  

        <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> Update the model parameters with the optimizer based on this loss
</span>        <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">loss</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">backward</span></span><span class="punctuation section arguments begin python">(</span><span class="punctuation section arguments end python">)</span></span>

        <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> Gradient accumulation:
</span>        <span class="meta statement if python"><span class="keyword control flow conditional python">if</span> <span class="meta group python"><span class="punctuation section group begin python">(</span><span class="meta qualified-name python"><span class="meta generic-name python">step</span></span> <span class="keyword operator arithmetic python">+</span> <span class="constant numeric integer decimal python">1</span><span class="punctuation section group end python">)</span></span> <span class="keyword operator arithmetic python">%</span> <span class="meta qualified-name python"><span class="meta generic-name python">grad_accumulation_steps</span></span> <span class="keyword operator comparison python">==</span> <span class="constant numeric integer decimal python">0</span><span class="punctuation section block conditional python">:</span></span>
            <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">optimizer</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">step</span></span><span class="punctuation section arguments begin python">(</span><span class="punctuation section arguments end python">)</span></span>
            <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">optimizer</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">zero_grad</span></span><span class="punctuation section arguments begin python">(</span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre>
<p>It looks straightforward, but many theoretical building blocks are needed to end up with a simple mean squared error loss. The code above uses the Hugging Face <a href="https://huggingface.co/docs/diffusers/index" target="_blank">diffusers library 🧨</a>, so some lines hide a lot of complexity. For instance, <a href="https://huggingface.co/docs/diffusers/api/schedulers/overview" target="_blank"><code>image_pipe.scheduler.add_noise</code></a> knows exactly how to blend the images with the noise to get the latent state at step <span class="math inline"><math display="inline"><mi>t</mi></math></span>. It can do this because the scheduler was initialized beforehand with the number of steps <span class="math inline"><math display="inline"><mi>T</mi></math></span>, the noise schedule type, etc. The object <code>image_pipe.unet</code> contains the neural network architecture that processes the images. Remember that the latent space has the same shape as the input (i.e. the image). This explains the authors' decision to choose a <a href="https://arxiv.org/abs/1505.04597" target="_blank">U-Net architecture</a>, which is well known for producing an output with exactly the same dimensions as its input.</p>
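<p>To make the blending step concrete, here is a minimal NumPy sketch of the DDPM forward process. It assumes the linear noise schedule from the DDPM paper: <code>T = 1000</code>, with β rising from 1e-4 to 0.02. To our understanding, this closed form matches what diffusers' <code>DDPMScheduler.add_noise</code> computes with default settings. It is still only an illustration, not the library's code. The toy array sizes and the all-zero "prediction" are placeholders.</p>

```python
import numpy as np

# Linear beta schedule from the DDPM paper (Ho et al., 2020): T = 1000 steps,
# beta rising from 1e-4 to 0.02. Other schedules (e.g. cosine) also exist.
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas_cumprod = np.cumprod(1.0 - betas)  # \bar{alpha}_t for t = 0..T-1


def add_noise(x0, noise, t):
    """Closed-form forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise.

    x0, noise: arrays of shape (batch, C, H, W); t: integer array of shape (batch,).
    """
    abar = alphas_cumprod[t].reshape(-1, 1, 1, 1)  # broadcast over C, H, W
    return np.sqrt(abar) * x0 + np.sqrt(1.0 - abar) * noise


rng = np.random.default_rng(0)
clean = rng.uniform(-1, 1, size=(4, 3, 8, 8))  # toy "images" scaled to [-1, 1]
noise = rng.standard_normal(clean.shape)
t = rng.integers(0, T, size=(clean.shape[0],))  # one random timestep per image
noisy = add_noise(clean, noise, t)

# Training target: the network sees (noisy, t) and predicts `noise`;
# the loss is the MSE between its prediction and the true noise.
noise_pred = np.zeros_like(noise)  # placeholder for unet(noisy, t)
loss = np.mean((noise_pred - noise) ** 2)
```

<p>At small <code>t</code>, the noisy image is almost the clean one. Near <code>T</code>, it is almost pure noise, which is why sampling can start from Gaussian noise.</p>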
<p>Training a generative model such as DDPM from scratch takes a long time and requires a lot of images. Instead, we can get fair results with much less training and far fewer pictures. We use the same approach, but start from a pre-trained model such as <a href="https://huggingface.co/google/ddpm-celebahq-256" target="_blank">google/ddpm-celebahq-256</a>. Of course, this forces a compromise on the resolution of the images we fine-tune on: the Google model was trained at a 256px resolution.</p>
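<p>The training loop above calls <code>optimizer.step()</code> only every <code>grad_accumulation_steps</code> batches. In between, PyTorch sums the gradients of several batches in each parameter's <code>.grad</code>. This is handy when fine-tuning at 256px on a GPU that cannot hold a large batch.</p>
<p>A common refinement is to divide each batch loss by <code>grad_accumulation_steps</code>. The accumulated gradient then equals the gradient of the mean loss over the larger effective batch. The NumPy sketch below checks that equivalence on a toy linear model with an MSE loss. The model is an assumption, chosen only because its gradient is easy to write by hand.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((32, 5))  # 32 toy examples, 5 features
y = rng.standard_normal(32)
w = rng.standard_normal(5)  # toy model parameters


def mse_grad(w, X, y):
    """Gradient of mean((X @ w - y) ** 2) with respect to w."""
    return 2.0 * X.T @ (X @ w - y) / len(y)


grad_accumulation_steps = 4
micro_batches = zip(np.array_split(X, grad_accumulation_steps),
                    np.array_split(y, grad_accumulation_steps))

accumulated = np.zeros_like(w)
for Xb, yb in micro_batches:
    # Scaling each micro-batch gradient by 1/steps makes the running sum equal
    # the gradient of the mean loss over the whole effective batch.
    accumulated += mse_grad(w, Xb, yb) / grad_accumulation_steps

full_batch = mse_grad(w, X, y)
```

<p>Without the division, the accumulated gradient is <code>grad_accumulation_steps</code> times larger. This behaves like a proportionally larger learning rate.</p>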
<p>Now we can use the model to sample 12 postcards:</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">torch</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">randn</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric integer decimal python">12</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric integer decimal python">3</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric integer decimal python">256</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric integer decimal python">256</span></span><span class="punctuation section arguments end python">)</span></span><span class="meta function-call python"><span class="meta qualified-name python"><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">to</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">device</span></span></span><span class="punctuation section arguments end python">)</span></span>  <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> Batch of 12 
</span><span class="meta statement for python"><span class="keyword control flow for python">for</span> <span class="meta generic-name python">i</span><span class="punctuation separator target-list python">,</span> <span class="meta generic-name python">t</span> <span class="meta statement for python"><span class="keyword control flow for in python">in</span></span></span><span class="meta statement for python"> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">tqdm</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">enumerate</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">scheduler</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">timesteps</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation section block for python">:</span></span>
    <span class="meta qualified-name python"><span class="meta generic-name python">model_input</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">scheduler</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">scale_model_input</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">t</span></span></span><span class="punctuation section arguments end python">)</span></span>
    <span class="meta statement with python"><span class="keyword control flow with python">with</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">torch</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">no_grad</span></span><span class="punctuation section arguments begin python">(</span><span class="punctuation section arguments end python">)</span></span><span class="punctuation section block with python">:</span></span>
        <span class="meta qualified-name python"><span class="meta generic-name python">noise_pred</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">image_pipe</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">unet</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">model_input</span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">t</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">sample<span class="punctuation definition string end python">&quot;</span></span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span>
    <span class="meta qualified-name python"><span class="meta generic-name python">x</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">scheduler</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">step</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">noise_pred</span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">t</span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">x</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">prev_sample</span>
</span></code></pre>
<p>Notice that to get samples from the model, we start from Gaussian random noise and denoise it step by step along the backward (reverse) process chain. In the next section, we will add some complexity to this inference pipeline beyond the model and the noise scheduler. You can find the model used to generate the images in <a href="https://huggingface.co/alkzar90/sd-class-ukiyo-e-256" target="_blank">alkzar90/sd-class-ukiyo-e-256</a>.</p>
<p><img src="https://huggingface.co/alkzar90/sd-class-ukiyo-e-256/resolve/main/ukyo-e-portrait.jpeg" alt="" /></p>
<p>Some images are psychedelic, others are dreamy and show no recognizable realistic object, but they all still have an artistic appeal. Observe the opaque pastel colours characteristic of the Ukiyo-e style.</p>
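<p>The sampling loop above can be sketched in plain NumPy as an ancestral DDPM sampler (Ho et al., 2020). The sketch assumes the same linear schedule and the variance choice σ<sub>t</sub>² = β<sub>t</sub>. A trained U-Net cannot be shipped here, so a hypothetical <code>oracle_noise_predictor</code> stands in for it. It "knows" that the clean images are all zeros, so it recovers the noise exactly, and the chain should land on all-zero images.</p>
<p>In the post's code, this update is delegated to <code>scheduler.step(...)</code>. Its exact formula depends on the scheduler; a DDIM scheduler, for example, uses a different update and fewer steps.</p>

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)


def oracle_noise_predictor(x, t):
    """Hypothetical stand-in for image_pipe.unet.

    It assumes the clean image is all zeros, so x_t = sqrt(1 - abar_t) * eps
    and the noise can be recovered exactly.
    """
    return x / np.sqrt(1.0 - alphas_cumprod[t])


def ddpm_step(eps_pred, t, x, rng):
    """One ancestral DDPM step x_t -> x_{t-1}, using sigma_t^2 = beta_t."""
    coef = betas[t] / np.sqrt(1.0 - alphas_cumprod[t])
    mean = (x - coef * eps_pred) / np.sqrt(alphas[t])
    if t == 0:
        return mean  # no fresh noise is added at the final step
    return mean + np.sqrt(betas[t]) * rng.standard_normal(x.shape)


def sample(shape, predictor, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(shape)  # start from pure Gaussian noise x_T
    for t in reversed(range(T)):  # t = T-1, ..., 0
        x = ddpm_step(predictor(x, t), t, x, rng)
    return x


images = sample((2, 3, 16, 16), oracle_noise_predictor)
```

<p>With a perfect noise predictor, the chain recovers the clean images (here, all zeros). A trained U-Net only approximates the noise, and those approximations are where the learned Ukiyo-e style comes from.</p>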
<h2>Classifier Guidance</h2>
<p>What is guidance? We can take this unconditional image generation process and steer it toward images that have a desired attribute or property of interest. It works a bit like a hacky way of conditioning the image distribution. We don't need to learn a conditional distribution, as popular text-to-image models such as <a href="https://stability.ai/blog/stable-diffusion-public-release" target="_blank">Stable Diffusion</a> or <a href="https://openai.com/dall-e-2/" target="_blank">DALL-E 2</a> do; we reuse the same unconditional model. Instead, we guide the sampling (denoising) process with a loss function that measures the property of interest in the generated sample, and we steer each denoising step to minimize that objective. Specifically, the Ukiyo-e style postcard generator App uses:</p>
<ul>
<li>Colour guidance: sample images that have a desired colour</li>
<li>Text prompt guidance: sample images that match a text description</li>
</ul>
<p>As we saw before, generating a new image (sampling) means taking pure Gaussian noise and passing it through the denoising process. We do it in many small steps: starting from the noise, we denoise it slowly through an entire stochastic Markovian process of, say, T=1000 steps (<a href="https://arxiv.org/pdf/2006.11239.pdf" target="_blank">Ho 2020</a>). When we use guidance, we enable gradient tracking on the sample tensor. At each step, we compute the loss of the current sample with respect to our target attribute (e.g. colour or text prompt). Then we compute the gradient of that loss with respect to the sample and update the tensor in the direction that minimizes the loss. This nudges the following samples toward the attribute the loss measures. The process is called classifier guidance, and it was introduced in the paper <a href="https://arxiv.org/pdf/2105.05233.pdf" target="_blank">Diffusion Models Beat GANs on Image Synthesis (Dhariwal 2021)</a>.</p>
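<p>A single guidance update can be sketched as follows. In PyTorch you would call <code>x.requires_grad_()</code> on the sample and obtain the gradient with <code>torch.autograd.grad(loss, x)</code>. Here the gradient is written out by hand, which keeps the example dependency-free.</p>
<p>Two things in the sketch are assumptions. The loss is a mean squared distance between every pixel and a target colour in the [-1, 1] image range. <code>guided_update</code> is a hypothetical helper that would run once inside each step of the sampling loop, right after the scheduler update.</p>

```python
import numpy as np


def colour_loss_and_grad(x, target_rgb):
    """Mean squared distance to a target colour, plus its gradient w.r.t. x.

    x: (batch, 3, H, W) in [-1, 1]; target_rgb: 3 values in [-1, 1].
    """
    target = np.asarray(target_rgb, dtype=float).reshape(1, 3, 1, 1)
    diff = x - target  # broadcasts the colour over every pixel
    loss = np.mean(diff ** 2)
    grad = 2.0 * diff / diff.size  # d(loss)/dx, same shape as x
    return loss, grad


def guided_update(x, target_rgb, guidance_scale):
    """Move the current sample one scaled gradient step downhill on the loss."""
    loss, grad = colour_loss_and_grad(x, target_rgb)
    return x - guidance_scale * grad, loss


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 3, 8, 8))  # toy current sample x_t
green = [-0.8, 0.6, -0.8]  # a green in [-1, 1] RGB space (assumption)
x_new, loss_before = guided_update(x, green, guidance_scale=20.0)
loss_after, _ = colour_loss_and_grad(x_new, green)
```

<p>A larger <code>guidance_scale</code> pushes harder toward the colour. If the scale is too large, the update overshoots and the sample degrades.</p>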
<p>Now a beautiful image of a sakura tree...</p>
<center><img src="https://media.istockphoto.com/vectors/sketchy-little-tree-spring-vector-id92725461?b=1&k=20&m=92725461&s=170667a&w=0&h=xmfz8Gqa7-Gb0FtKiwPu0ZclWYm6WJUZ1pjec_BVj2I=" alt="A sakura tree"></center>
<br>
<p>We will use <em>"a sakura tree"</em> as a text prompt and the same starting Gaussian noise that produced the 12 Ukiyo-e postcards above. This time, though, we will drive the sampling process with gradients. At each step <span class="math inline"><math display="inline"><mi>t</mi></math></span>, we compare the encoding of the sample image with the encoding vector of the text prompt.</p>
<p><img src="https://huggingface.co/alkzar90/sd-class-ukiyo-e-256/resolve/main/ukyo-e-sakura-tree.jpeg" alt="" /></p>
<p>Compare these new postcards with the previous ones. Notice how the text guidance makes the sakura tree's pattern emerge during the denoising process: branches and pink leaves here and there, with touches of hallucination. The model is far from perfect and lacks surgical precision, but it's still amazing that we have gained some control over the sampling process. Behind the curtains, the following steps generate the images above:</p>
<ol>
<li>Download a pre-trained model that encodes images and text into a shared embedding space, such as OpenAI's <a href="https://openai.com/blog/clip/" target="_blank">CLIP model</a></li>
<li>Pass the current noisy sample to the model's image encoder to get an encoding vector</li>
<li>Use a loss function that compares the encoding of the current sample with the encoding of the text prompt; the latter vector is always the same because the prompt doesn't change during the process</li>
<li>Compute the gradient of the loss with respect to the sample, and update the sample values in the direction that minimizes the loss</li>
<li>Repeat these steps at every step of the stochastic Markovian denoising process</li>
</ol>
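<p>The steps above can be sketched with stand-ins for CLIP. A fixed random linear map plays the image encoder, and a fixed random vector plays the text-prompt encoding; both are hypothetical. The loss is a cosine distance, one common choice for comparing embeddings. The chain rule pushes its gradient back through the encoder to the sample. The denoising step itself is omitted so only the guidance part is visible, and the guidance scale of 0.5 is arbitrary.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
D_IMG, D_EMB = 48, 8
# Hypothetical stand-in for CLIP's image encoder: a fixed random linear map.
W = rng.standard_normal((D_EMB, D_IMG)) / np.sqrt(D_IMG)
# Hypothetical stand-in for the CLIP encoding of the (fixed) text prompt.
text_emb = rng.standard_normal(D_EMB)


def prompt_loss_and_grad(x):
    """Cosine distance between the image embedding W @ x and the text embedding."""
    e = W @ x
    e_norm, t_norm = np.linalg.norm(e), np.linalg.norm(text_emb)
    cos = e @ text_emb / (e_norm * t_norm)
    # d(cos)/d(e), then chain rule through the linear encoder back to x
    dcos_de = text_emb / (e_norm * t_norm) - cos * e / e_norm ** 2
    return 1.0 - cos, -(W.T @ dcos_de)


x = rng.standard_normal(D_IMG)  # flattened toy "sample"
losses = []
for _ in range(50):  # one guidance update per denoising step
    loss, grad = prompt_loss_and_grad(x)
    losses.append(loss)
    x = x - 0.5 * grad  # hypothetical guidance scale of 0.5
```

<p>The text embedding is computed once and reused, just as the prompt encoding stays fixed during real sampling. Only the image side is re-encoded at every step.</p>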
<p>Generally, each designed loss is multiplied by a scale factor that increases or decreases the effect of its attribute, which lets you move between novelty and fidelity. Moreover, nothing stops you from using more than one objective. The Ukiyo-e postcard generator uses colour and text prompts together as guidance, and at each denoising step the gradients of both losses are accumulated to modify the sample. Of course, these gradients can interact. Imagine that you want green images but also use the text prompt "a volcano lava": the green colour guidance will work against the red/brown colours implicit in the text prompt.</p>
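<p>To see how scale factors trade objectives off against each other, the sketch below works on a single toy pixel. It sums the scaled gradients of two quadratic colour losses: one pulls toward green, the other toward red, standing in for the colour a prompt like "a volcano lava" implies. Everything here is an assumption for illustration. With quadratic losses, the pixel settles at the scale-weighted average of the two targets, so neither colour is reached exactly.</p>

```python
import numpy as np


def mse_grad(x, target):
    """Gradient of mean((x - target) ** 2) with respect to x."""
    return 2.0 * (x - target) / x.size


green = np.array([-1.0, 1.0, -1.0])  # target colours in [-1, 1] RGB space
red = np.array([1.0, -1.0, -1.0])


def guide(x, scale_green, scale_red, steps=500):
    for _ in range(steps):
        # Each objective contributes its own scaled gradient; they are summed.
        grad = scale_green * mse_grad(x, green) + scale_red * mse_grad(x, red)
        x = x - grad
    return x


x0 = np.zeros(3)  # a single toy "pixel"
balanced = guide(x0, 0.3, 0.3)
mostly_green = guide(x0, 0.45, 0.15)
```

<p>With equal scales, the pixel ends halfway between the two targets, at <code>[0, 0, -1]</code>. Making the green scale three times the red one moves it to <code>[-0.5, 0.5, -1]</code>.</p>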
<h2>Wrap the inference pipeline into an App</h2>
<p>Now that we have a trained model and know how to control the sampling process of the unconditional distribution, <strong>why not wrap the entire inference pipeline into a friendly interface such as a <a href="https://gradio.app/" target="_blank">Gradio App</a>?</strong></p>
<p>Let's start with some context: Gradio is a framework that lets us build machine-learning apps quickly, with inputs and outputs matched to the task our model was designed for. For instance, running inference with a generative image model requires setting different parameter types: a scale factor (slider), a text prompt (text input), or a colour (colour picker). It also always produces an image as output. Gradio accommodates these requirements and hides many tedious details, which saves you valuable time.</p>
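<p>The sketch below shows one way such a UI could be wired with Gradio's <code>gr.Interface</code>. It uses a colour picker, a text box, sliders for the scale factors and a number field for the seed. This is not the App's actual <code>app.py</code>. <code>generate</code> is a hypothetical placeholder that only tints seeded noise toward the chosen colour, so the wiring can be tried without the diffusion model. It also assumes the colour picker returns a hex string such as <code>"#00ff00"</code>. Gradio is imported inside <code>build_demo</code>, so the helper functions work without it installed.</p>

```python
import numpy as np


def hex_to_rgb(hex_colour):
    """Convert '#00ff00' into an RGB float array such as [0., 255., 0.]."""
    h = hex_colour.lstrip("#")
    return np.array([int(h[i:i + 2], 16) for i in (0, 2, 4)], dtype=float)


def generate(colour, prompt, colour_scale, text_scale, seed):
    """Hypothetical placeholder for the real guided-sampling pipeline.

    It only blends seeded noise toward the chosen colour; `prompt` and
    `text_scale` are accepted, to match the UI, but are not used here.
    """
    rng = np.random.default_rng(int(seed))
    noise = rng.uniform(0, 255, size=(64, 64, 3))
    w = min(max(colour_scale / 100.0, 0.0), 1.0)  # clamp blend weight to [0, 1]
    img = (1 - w) * noise + w * hex_to_rgb(colour)
    return img.astype(np.uint8)


def build_demo():
    """Build the Gradio interface; gradio is imported lazily."""
    import gradio as gr

    return gr.Interface(
        fn=generate,
        inputs=[
            gr.ColorPicker(value="#00ff00", label="Colour"),
            gr.Textbox(value="a sakura tree", label="Text prompt"),
            gr.Slider(minimum=0, maximum=100, value=40, label="Colour guidance scale"),
            gr.Slider(minimum=0, maximum=100, value=40, label="Text guidance scale"),
            gr.Number(value=42, precision=0, label="Seed"),
        ],
        outputs=gr.Image(label="Postcard"),
    )
```

<p>Calling <code>build_demo().launch()</code> starts a local web server. On a Hugging Face Space, the same file would typically serve as the app's entry point.</p>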
<p>In addition, Hugging Face gives you a free Space to host your App on a CPU (you can upgrade it to a GPU, but you need to pay for it). It's an excellent combo that gives you a friendly web interface with minimal resources. There is a trend in the generative model community of using this interface to showcase and prototype models' properties and features, such as <a href="https://huggingface.co/spaces/stabilityai/stable-diffusion" target="_blank">Stable Diffusion 2.1 Demo</a> or <a href="https://huggingface.co/spaces/huggingface-projects/diffuse-the-rest" target="_blank">Diffuse the rest</a>. Therefore, taking the time to learn Gradio is an excellent decision if you want to share your project behind a friendly facade.</p>
<p><a href="https://huggingface.co/spaces/alkzar90/ukiyo-e-postal/blob/main/app.py" target="_blank">Here is the Python script</a> with the pipeline and the
App code. Below is a screenshot of the <a href="https://huggingface.co/spaces/alkzar90/ukiyo-e-postal" target="_blank">Ukiyo-e style postcard generator App</a>:</p>
<center>
<img src="/img/ukiyo-e-postalcard-app/ukiyo-e-app.png">
</center>
<br>
<p>As you can see, the interface makes it easy to pick the colour you want and to enter a text prompt that guides the sampling. You can also adjust the scale factors to experiment with different guidance intensity levels. Setting a seed reproduces the Gaussian noise that the denoising process uses as its starting point. You can then edit the generated image iteratively by playing with the scale factors.</p>
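<p>Seeding can be sketched with NumPy's generator API: the same seed always yields the same starting noise, so only the guidance scales change between runs. In PyTorch, the equivalent is passing a seeded <code>torch.Generator</code> to <code>torch.randn</code>. The helper name <code>initial_noise</code> and the 256px shape are assumptions for illustration.</p>

```python
import numpy as np


def initial_noise(seed, shape=(1, 3, 256, 256)):
    """Starting Gaussian noise for the denoising chain, fixed by `seed`."""
    return np.random.default_rng(seed).standard_normal(shape)


a = initial_noise(42)
b = initial_noise(42)  # same seed: identical starting point
c = initial_noise(7)  # different seed: a different starting point
```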
<p>P.S.: Unfortunately, this kind of model involves a lot of computation, and the Gradio app runs on a CPU-only Hugging Face Space, so the whole inference pipeline is slow. The good news is that, if you're short on patience, you can run the Gradio app from the Google Colab notebook linked at the start of this post with the GPU runtime enabled.</p>
]]></content>
  </entry>
  <entry>
    <title>Deep Learning for Coders - chapter 2 notes</title>
    <link href="https://alkzar.cl/posts/fastai-chapter-2/"/>
    <id>https://alkzar.cl/posts/fastai-chapter-2/</id>
    <published>2022-08-11T00:00:00Z</published>
    <updated>2022-08-11T00:00:00Z</updated>
    <content type="html"><![CDATA[<h2>Data Augmentation</h2>
<p>One of the interesting points of the chapter is the introduction of the
set of techniques known as <em>data augmentation</em> ✨. It is a simple but
clever idea: instead of following the direct (and more obvious) route of
improving model performance through architecture design, it shifts the focus
to the data, an essential but definitely not the most popular part (?). The
point is that any system in the <em>machine learning</em> category (including
<em>deep learning</em>) is inseparable from its data, and
<em>data augmentation</em> is a way to increase the diversity of our
<em>dataset</em> using synthetic data created from the original
<em>dataset</em>.</p>
<center>
<img src="/img/fastai-chapter-2/data_augmentation_example.png">
</center>
<p>Focusing on image data, new examples are generated by applying
transformations to an image. Above we see different variations of the same
image: the picture of an elf undergoes alterations such as rotations, colour
saturation changes, and so on. This artificially increases the variation in
our data and, if all goes well, improves the model's generalization.
Intuitively, colour saturation keeps the model from relying too much on the
colour green to detect elves, so it can identify them in less common but
plausible settings (e.g. mountainous scenes with a palette dominated by earth
tones), while rotations add diversity to the elf poses the algorithm sees
during training, reducing the bias toward the most common poses in which elves
are drawn and depicted.</p>
<p>In practice, the <code>fastai</code> library applies the image transformations
during every epoch: with some probability, one or more perturbations are
applied to the image, or its original version is shown.</p>
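<pre><code class="code lang-python"><span class="source python"><span class="keyword control import from python">from</span> <span class="meta qualified-name python"><span class="meta generic-name python">fastai</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">vision</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">all</span></span> <span class="keyword control import python">import</span> <span class="keyword operator wildcard python">*</span>

<span class="comment line number-sign python"># Assumes images_path points to a folder with one sub-folder of images per class</span>
<span class="meta qualified-name python"><span class="meta generic-name python">db</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">DataBlock</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python">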
    <span class="variable parameter python">blocks</span> <span class="keyword operator assignment python">=</span> <span class="meta group python"><span class="punctuation section group begin python">(</span><span class="meta qualified-name python"><span class="meta generic-name python">ImageBlock</span></span><span class="punctuation separator tuple python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">CategoryBlock</span></span><span class="punctuation section group end python">)</span></span><span class="punctuation separator arguments python">,</span>
    <span class="variable parameter python">get_items</span> <span class="keyword operator assignment python">=</span> <span class="meta qualified-name python"><span class="meta generic-name python">get_image_files</span></span><span class="punctuation separator arguments python">,</span>
    <span class="variable parameter python">splitter</span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">RandomSplitter</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="variable parameter python">valid_pct</span><span class="keyword operator assignment python">=</span><span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>2</span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">seed</span><span class="keyword operator assignment python">=</span><span class="constant numeric integer decimal python">42</span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator arguments python">,</span>
    <span class="variable parameter python">get_y</span> <span class="keyword operator assignment python">=</span> <span class="meta qualified-name python"><span class="meta generic-name python">parent_label</span></span><span class="punctuation separator arguments python">,</span>
    <span class="variable parameter python">item_tfms</span><span class="keyword operator assignment python">=</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">Resize</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric integer decimal python">128</span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator arguments python">,</span>
    <span class="variable parameter python">batch_tfms</span><span class="keyword operator assignment python">=</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">aug_transforms</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="variable parameter python">mult</span><span class="keyword operator assignment python">=</span><span class="constant numeric integer decimal python">2</span></span><span class="punctuation section arguments end python">)</span></span>
    </span><span class="punctuation section arguments end python">)</span></span>
    
<span class="meta qualified-name python"><span class="meta generic-name python">dls</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">db</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">dataloaders</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">images_path</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre>
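<p>The behaviour described above, where each perturbation is applied only with some probability and drawn again every time an image is served, can be sketched in a few lines of NumPy. This is a simplified illustration, not <code>fastai</code>'s implementation: the helper <code>random_augment</code>, the chosen perturbations (horizontal flip, 90-degree rotation, brightness change), and <code>p=0.5</code> are all assumptions.</p>

```python
import numpy as np


def random_augment(image: np.ndarray, rng: np.random.Generator, p: float = 0.5) -> np.ndarray:
    """Hypothetical helper: apply each perturbation independently with probability p.

    `image` is a square (H, W, C) float array with values in [0, 1].
    """
    out = image
    if rng.random() < p:
        out = out[:, ::-1]  # horizontal flip
    if rng.random() < p:
        out = np.rot90(out, k=int(rng.integers(1, 4)))  # rotate 90, 180 or 270 degrees
    if rng.random() < p:
        out = np.clip(out * rng.uniform(0.8, 1.2), 0.0, 1.0)  # brightness change
    return out


rng = np.random.default_rng(0)
image = rng.random((32, 32, 3))  # toy square "image"
for epoch in range(3):
    augmented = random_augment(image, rng)
    # A new random version each epoch; sometimes the original comes back unchanged.
    print(epoch, np.array_equal(augmented, image))
```

<p>Since the random draws happen when the image is requested, the model sees a different version of the same picture in each epoch without the dataset ever being enlarged on disk.</p>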
<p>Recall that the <code>DataLoaders</code> object is an instance of the class responsible for
providing <em>mini-batches</em> to the algorithm during training and moving them to the GPU.
Therefore, the transformations specified in the <code>batch_tfms</code> argument
of <code>DataBlock</code> are the ones executed on the GPU for all the images
in the <em>mini-batch</em>, and they are re-applied epoch after epoch.</p>
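<p>Why apply <code>batch_tfms</code> to the whole <em>mini-batch</em> at once? A single vectorized operation over a batch tensor is much cheaper than looping over images one by one, especially on a GPU. The NumPy sketch below shows the idea on the CPU; it is an assumption-based illustration, not how <code>fastai</code> implements its batch transforms: the hypothetical <code>batch_flip</code> draws one random decision per image and flips the selected images in one step.</p>

```python
import numpy as np


def batch_flip(batch: np.ndarray, rng: np.random.Generator, p: float = 0.5) -> np.ndarray:
    """Horizontally flip a random subset of a (B, H, W, C) batch in one vectorized step."""
    flip = rng.random(batch.shape[0]) < p  # one decision per image, shape (B,)
    flipped = batch[:, :, ::-1, :]  # every image flipped along the width axis (a view)
    # np.where broadcasts the (B, 1, 1, 1) mask over height, width and channels.
    return np.where(flip[:, None, None, None], flipped, batch)


rng = np.random.default_rng(0)
batch = rng.random((64, 128, 128, 3))  # a mini-batch already resized to 128x128 per item
out = batch_flip(batch, rng)
print(out.shape)  # (64, 128, 128, 3)
```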
<p>Although the second chapter does not go into much more detail, the <a href="https://docs.fast.ai/vision.augment.html"><code>fastai</code> documentation</a> has
more information on how to apply several transformations and combine them. There
are also methods for finding the best set of transformations for a specific
<em>dataset</em>, such as the one described in the <em>paper</em>
<a href="https://paperswithcode.com/paper/autoaugment-learning-augmentation-policies"><em>AutoAugment: Learning Augmentation Policies from Data</em> (Cubuk 2018)</a>.
In fact, <a href="https://pytorch.org/vision/main/generated/torchvision.transforms.AutoAugment.html">PyTorch's <em>vision</em> module, in <code>torchvision.transforms.AutoAugment</code></a>,
provides the optimal sets of transformations found with that paper's methodology
for the ImageNet, CIFAR10, and SVHN <em>datasets</em>. One alternative is to use
one of these learned policies <em>versus</em> arbitrarily defined
transformations.</p>
<p>Finally, a word on how to interpret <em>data augmentation</em> theoretically.
There is a Bayesian justification whose line of argument is discussed in the <em>post</em>
<a href="https://statmodeling.stat.columbia.edu/2019/12/02/a-bayesian-view-of-data-augmentation/"><em>A Bayesian view of data augmentation</em> (O'Rourke 2019)</a>,
and there is also a short section in the new edition of <a href="https://probml.github.io/pml-book/book1.html">Kevin Murphy's book,
page 622</a>. Quoting the book:
<em>"the data augmentation mechanism can be viewed as a way to algorithmically inject
prior knowledge"</em> 💉🧠.</p>
<h2>Questionnaire</h2>
<ol>
<li>
<p>Where do text models currently have a major deficiency?</p>
<ul>
<li><strong>A:</strong> Although text models are good at generating prose that
fits the context, they are not consistent and cannot guarantee correct
answers.</li>
</ul>
</li>
<li>
<p>What are possible negative societal implications of text generation models?</p>
<ul>
<li><strong>A:</strong> Text generation models reproduce the implicit biases
contained in the texts they were trained on. One negative social impact is
reproducing and amplifying these biases, given how easily these models scale
through social networks or other applications.</li>
</ul>
</li>
<li>
<p>In situations where a model might make mistakes, and those mistakes could
be harmful, what is a good alternative to automating a process?</p>
<ul>
<li><strong>A:</strong> Use the system together with an expert: the system
provides recommendations or alternatives as predictions, and the expert
can use them to complement their analysis or to validate results, avoiding
mistakes while taking advantage of a support system. The model's suggestion
can always be discarded if it is not relevant.</li>
</ul>
</li>
<li>
<p>What kind of tabular data is deep learning particularly good at?</p>
<ul>
<li><strong>A:</strong> Tabular data containing text columns (e.g. customer
comments or movie <em>reviews</em>) or other tabulated but unstructured
information (e.g. users' avatar images in a forum).</li>
</ul>
</li>
<li>
<p>What's a key downside of directly using a deep learning model for
recommendation systems?</p>
<ul>
<li><strong>A:</strong> Recommendation systems are good at suggesting items
the user may like, but these are not necessarily useful options. If a
recommendation system offers me vinyl records by artists I already know, those
options add little value: I very likely already know all the alternatives and
don't need a recommendation system for that.</li>
</ul>
</li>
<li>
<p>What are the steps of the Drivetrain Approach?</p>
<p>i. Objective: What do we want to achieve with our data product?<br>
ii. Levers: What <em>inputs</em> can we control to achieve our objective?<br>
iii. Data: What data do we have, or can we acquire, that are relevant for
taking the actions and meeting the objective?<br>
iv. Models: What concrete actions do we generate based on our <em>levers</em> (aka <em>inputs</em>)?</p>
<p><img src="https://github.com/fastai/fastbook/raw/2b8b8a20974baa756e3702778270aa12e0ab046e//images/drivetrain-approach.png" alt="" /></p>
</li>
<li>
<p>How do the steps of the Drivetrain Approach map to a recommendation system?</p>
<ul>
<li>Objective: Increase sales through novel and delightful
recommendations for our customers.</li>
<li>Levers: <em>Rank</em> the recommendations as well as possible to
achieve the increase in sales.</li>
<li>Data: What data do we need to collect to increase sales? (e.g.
plays of new artists or purchase information from last
season).</li>
<li>Models: Build two purchase-probability models, one conditioned on
seeing the recommendations and one not. The difference between the two
probabilities is the utility of giving a recommendation to the customer.</li>
</ul>
</li>
<li>
<p>Create an image recognition model using data you curate, and deploy it on
the web.</p>
<ul>
<li><strong>A:</strong> <a href="https://huggingface.co/spaces/alkzar90/croupier-creature-app">Bestiario</a>
is a simple app for identifying kinds of creatures (i.e. elves,
goblins, zombies, and knights) from images. I will write about the project's
steps and development in another post.</li>
</ul>
</li>
</ol>
<iframe src="https://hf.space/embed/alkzar90/croupier-creature-app/+" width="950" height="400"></iframe>
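<p>The "two models" idea from the Drivetrain answer for recommendation systems amounts to ranking items by the difference between two purchase probabilities. A minimal sketch, assuming both models have already been fit and their predictions for one customer are given as plain dictionaries; <code>rank_by_uplift</code>, the item names, and the probabilities are hypothetical values for illustration only.</p>

```python
def rank_by_uplift(p_with_rec: dict[str, float], p_without_rec: dict[str, float]) -> list[tuple[str, float]]:
    """Rank items by how much recommending them raises the purchase probability."""
    uplift = {item: p_with_rec[item] - p_without_rec[item] for item in p_with_rec}
    return sorted(uplift.items(), key=lambda kv: kv[1], reverse=True)


# Hypothetical model outputs for one customer.
p_with = {"artist_a_vinyl": 0.90, "artist_b_vinyl": 0.40, "artist_c_vinyl": 0.30}
p_without = {"artist_a_vinyl": 0.88, "artist_b_vinyl": 0.10, "artist_c_vinyl": 0.25}

for item, gain in rank_by_uplift(p_with, p_without):
    print(f"{item}: {gain:+.2f}")
```

<p>In this toy example artist A would top a ranking by purchase probability alone, yet recommending it barely changes anything (the customer would buy it anyway), which is exactly the downside described for recommending artists the user already knows; artist B has the largest uplift.</p>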
<ol start="9">
<li>
<p>What is <code>DataLoaders</code>?</p>
<ul>
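<li><strong>A:</strong> <code>DataLoaders</code> is a helper class that implements the abstraction
of managing and providing data to the model: it stores the <code>DataLoader</code> objects
passed to it (typically training and validation) and exposes them as <code>train</code> and
<code>valid</code>. The code below shows the basic functionality of this class highlighted in the chapter:</li>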
</ul>
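<pre><code class="code lang-python"><span class="source python"><span class="keyword control import from python">from</span> <span class="meta qualified-name python"><span class="meta generic-name python">fastcore</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">basics</span></span> <span class="keyword control import python">import</span> <span class="meta qualified-name python"><span class="meta generic-name python">GetAttr</span></span><span class="punctuation separator import-list python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">add_props</span></span>

<span class="meta class python"><span class="storage type class python">class</span> <span class="entity name class python"><span class="meta generic-name python">DataLoaders</span></span><span class="meta class inheritance python"><span class="punctuation section inheritance begin python">(</span></span></span><span class="meta class inheritance python"><span class="entity other inherited-class python">GetAttr</span><span class="punctuation section inheritance end python">)</span></span><span class="meta class python"><span class="punctuation section class begin python">:</span></span>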
<span class="meta function python">    <span class="storage type function python">def</span> <span class="entity name function python"><span class="support function magic python">__init__</span></span></span><span class="meta function parameters python"><span class="punctuation section parameters begin python">(</span></span><span class="meta function parameters python"><span class="variable parameter python">self</span><span class="punctuation separator parameters python">,</span> <span class="keyword operator unpacking sequence python">*</span><span class="variable parameter python">loaders</span><span class="punctuation section parameters end python">)</span></span><span class="meta function python"><span class="punctuation section function begin python">:</span></span> <span class="meta qualified-name python"><span class="variable language python">self</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">loaders</span></span> <span class="keyword operator assignment python">=</span> <span class="meta qualified-name python"><span class="meta generic-name python">loaders</span></span>
<span class="meta function python">    <span class="storage type function python">def</span> <span class="entity name function python"><span class="support function magic python">__getitem__</span></span></span><span class="meta function parameters python"><span class="punctuation section parameters begin python">(</span></span><span class="meta function parameters python"><span class="variable parameter python">self</span><span class="punctuation separator parameters python">,</span> <span class="variable parameter python">i</span><span class="punctuation section parameters end python">)</span></span><span class="meta function python"><span class="punctuation section function begin python">:</span></span> <span class="keyword control flow return python">return</span> <span class="meta item-access python"><span class="meta qualified-name python"><span class="variable language python">self</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">loaders</span></span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">i</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span>
    <span class="meta qualified-name python"><span class="meta generic-name python">train</span></span>, <span class="meta qualified-name python"><span class="meta generic-name python">valid</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">add_props</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function inline python"><span class="storage type function inline python">lambda</span></span><span class="meta function inline python"><span class="meta function inline parameters python"> <span class="variable parameter python">i</span><span class="punctuation separator parameters python">,</span> <span class="variable parameter python">self</span></span><span class="punctuation section function begin python">:</span></span> <span class="meta item-access python"><span class="meta qualified-name python"><span class="variable language python">self</span></span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">i</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre></li>
<li>
<p>What four things do we need to tell fastai to create <code>DataLoaders</code>?</p>
<p>i. What kind of data are we working with (e.g. images, audio)? -&gt; <code>blocks=(ImageBlock, CategoryBlock)</code><br>
ii. How do we get the list of items (data)? -&gt; <code>get_items=get_image_files</code><br>
iii. How are the items labelled? -&gt; <code>get_y=parent_label</code><br>
iv. How do we create the validation set? -&gt; <code>splitter=RandomSplitter(valid_pct=0.2, seed=42)</code></p>
</li>
<li>
<p>What does the <code>splitter</code> parameter to <code>DataBlock</code> do?</p>
<ul>
<li><strong>A:</strong> The <code>splitter</code> argument of <code>DataBlock</code> defines
how the validation set is created. With <code>RandomSplitter(valid_pct=0.2, seed=42)</code>,
it sets the percentage of items assigned to the validation set and, through the
seed, makes the split reproducible.</li>
</ul>
</li>
<li>
<p>How do we ensure a random split always gives the same validation set?</p>
<ul>
<li><strong>A:</strong> By fixing a seed (e.g. <code>seed=42</code>) so that
the random number generator produces the same sequence of values and,
therefore, the same split.</li>
</ul>
</li>
<li>
<p>What letters are often used to signify the independent and dependent
variables?</p>
<ul>
<li><strong>A:</strong> The letter <span class="math inline"><math display="inline"><mi>y</mi></math></span> is used for the dependent variable
(i.e. the <em>output</em>) and <span class="math inline"><math display="inline"><mi>x</mi></math></span> for the independent variables (i.e. the <em>input</em>).</li>
</ul>
</li>
<li>
<p>What's the difference between the crop, pad, and squish resize approaches?
When might you choose one over the others?</p>
<ul>
<li>
<p><strong>A:</strong> First, it is important to standardize our images so they
can be turned into tensors and then fed to the model architecture. Most of
the time, when we collect images from the web or other sources, they will
have different dimensions. How do we standardize them? There are several
ways, and each can affect the quality of our data.</p>
<ul>
<li><em>Crop</em>: Crops the image to a square of the required size,
using the full width or height. Relevant information may be lost along the
truncated dimension, such as the rear of a car, which may be what
distinguishes one kind of car 🚓 from another 🏎️. -&gt; <code>Resize(128)</code></li>
<li><em>Pad</em>: Adds black (zero) regions at the borders to fill out the dimensions,
which produces empty information that simply wastes computational resources
(think of millions of images that need this transformation to be
standardized). -&gt; <code>Resize(128, ResizeMethod.Pad, pad_mode='zeros')</code></li>
<li><em>Squish</em>: Shrinks or stretches the image to reach the required
size. The problem is that it can distort what the image represents. For
example, if we have an image of a teapot 🫖 that must be stretched to reach
the required dimensions, the teapot ends up looking like an inflated ball ⚽
instead of keeping the object's real proportions. -&gt; <code>Resize(128, ResizeMethod.Squish)</code></li>
</ul>
</li>
<li>
<p>When should we choose one over the others? It depends a lot on the nature
of the images and what they represent. Images of digits and letters can be
affected if a distinctive part of a particular character is cropped out: a 7
could look like a 1 if the image is cropped a certain way. However, if we are
distinguishing very different landscapes (e.g. meadows and oceans),
<em>cropping</em> does not matter much.</p>
</li>
</ul>
</li>
<li>
<p>What is data augmentation? Why is it needed?</p>
<ul>
<li><strong>A:</strong> <em>Data augmentation</em> is a set of techniques for
artificially increasing the data through random perturbations that do not
change its intrinsic meaning. For example, if we rotate a photo of a dog or
change its colour saturation, the image is still a picture of a dog regardless
of the transformations applied. Note that, in practice, these perturbations
are not used to enlarge the dataset before training: with just two
transformations such as rotation and colour saturation, the space of
combinations already gives rise to infinitely many versions of an image.
Instead, different versions of an <em>input</em> are shown during training,
adding variation and diversity while the parameters are fit.</li>
</ul>
</li>
<li>
<p>Provide an example of where the bear classification model might work
poorly in production, due to structural or style differences in the training
data.</p>
<ul>
<li><strong>A:</strong> The camera angles in the training photos may differ
from those produced by where the camera is placed in the park or wherever
the model is used. Another problem can be variations in the production
environment that were not properly captured in the training <em>dataset</em>,
such as changes in lighting across the seasons of the year.</li>
</ul>
</li>
<li>
<p>What is the difference between <code>item_tfms</code> and <code>batch_tfms</code>?</p>
<ul>
<li><strong>A:</strong> <code>item_tfms</code> is applied to each item individually,
on the CPU, as image preprocessing before the items are grouped into a batch
(e.g. standardizing all images to the same size, such as 128x128). In
contrast, <code>batch_tfms</code> is applied every time the <code>DataLoader</code>
delivers a <em>mini-batch</em> (a group of items) to the model, usually on the
GPU, so the transformations are applied efficiently to the whole
<em>mini-batch</em> and the model fits its parameters on that <em>epoch</em>'s
particular random perturbations.</li>
</ul>
</li>
<li>
<p>What is a confusion matrix?</p>
<ul>
<li><strong>A:</strong> A <a href="https://en.wikipedia.org/wiki/Confusion_matrix">confusion matrix</a>
is a table that summarizes the predictive performance of a classification
model. Its metrics should be computed on a held-out test set, that is,
observations that were not used during training, so that they reflect how
well the model generalizes to data it has never seen. Below is an example
confusion matrix for an image model that classifies 10 types of clothing
from the <a href="https://huggingface.co/datasets/fashion_mnist"><em>FashionMNIST dataset</em></a>.
The diagonal represents the <em>accuracy</em> for each class (the fraction of
its images that were classified correctly): the whiter the diagonal, the
better, since more clothing images were assigned to their correct category.
The off-diagonal cells represent the errors the model made by assigning
images to the other 9 classes. In particular, the model has the most
difficulty with images of <em>shirt</em>: <span class="math inline"><math display="inline"><mn>75</mn></math></span>%
<em>accuracy</em>, mostly confused with <em>T-shirt/top</em> and
<em>coat</em>, with errors of <span class="math inline"><math display="inline"><mn>8.4</mn></math></span>% and <span class="math inline"><math display="inline"><mn>6.8</mn></math></span>% respectively.</li>
</ul>
</li>
</ol>
<center>
<img src="/img/fastai-chapter-2/confusion_matrix.png">
</center>
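<p>The per-class percentages in a figure like this come from dividing each row of the confusion matrix by the number of true examples of that class. A minimal NumPy sketch of that computation, offered as an illustration rather than the code that produced the figure; <code>confusion_matrix</code>, <code>normalize_rows</code>, and the toy labels are hypothetical.</p>

```python
import numpy as np


def confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, n_classes: int) -> np.ndarray:
    """Count matrix: rows are true classes, columns are predicted classes."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)  # unbuffered add handles repeated (true, pred) pairs
    return cm


def normalize_rows(cm: np.ndarray) -> np.ndarray:
    """Divide each row by its total, so the diagonal is the fraction classified correctly.

    Assumes every class appears at least once in y_true (no all-zero rows).
    """
    return cm / cm.sum(axis=1, keepdims=True)


# Toy example with 3 classes (hypothetical labels).
y_true = np.array([0, 0, 0, 0, 1, 1, 1, 2, 2, 2])
y_pred = np.array([0, 0, 0, 1, 1, 1, 0, 2, 2, 1])
cm = confusion_matrix(y_true, y_pred, n_classes=3)
print(cm)
print(np.round(normalize_rows(cm), 2))
```

<p>Reading a row of the normalized matrix gives exactly the kind of statement made above: for class 0 in this toy example, 75% of its images are predicted correctly and 25% are confused with class 1.</p>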
<ol start="19">
<li>
<p>What does export save?</p>
<ul>
<li><strong>A:</strong> <code>learn.export()</code> saves a file with the
<code>.pkl</code> extension containing the model architecture, its trained
parameter values, and the definition of how the <code>DataLoaders</code> were
built, so the model can be loaded and instantiated later. A <a href="https://docs.python.org/3/library/pickle.html"><code>.pkl</code> file</a> is
a pickle file, created by a Python module that serializes objects
into a stream of <em>bytes</em>.</li>
</ul>
</li>
<li>
<p>What is it called when we use a model for making predictions, instead of
training?</p>
<ul>
<li><strong>A:</strong> Using a model to make predictions is called
inference: the model is used as a program rather than in training or
fitting mode. Not to be confused with the statistical term.</li>
</ul>
</li>
<li>
<p>What are IPython widgets?</p>
<ul>
<li><strong>A:</strong> IPython <em>widgets</em> are GUI components (buttons,
sliders, file uploaders, and so on) that combine JavaScript and Python in the
browser, letting us build interactive elements inside a Jupyter notebook.
Recall that when we work with Jupyter notebook a local server runs behind it,
so we can take advantage of web technologies.</li>
</ul>
</li>
<li>
<p>When would you use a CPU for deployment? When might a GPU be better?</p>
<ul>
<li><strong>A:</strong> If the model does not need to handle a large flow of
requests, a CPU is the recommended option for <em>deployment</em> because it
is cheaper and simpler to manage. It avoids the unnecessary expense of a GPU
whose capacity the application would not fill, as well as the extra technical
difficulty of managing one. A GPU pays off when the model receives a large
number of simultaneous inference requests that the GPU can process at the
same time.</li>
</ul>
</li>
<li>
<p>What are the downsides of deploying your app to a server, instead of to a
client (or edge) device such as a phone or PC?</p>
<ul>
<li>Sending information from the <em>edge</em> device to the server may
require more computational resources to keep latency tolerable for the
client.</li>
<li>Information privacy and <em>compliance</em> issues arising from sending
the data to the server.</li>
</ul>
</li>
<li>
<p>What are three examples of problems that could occur when rolling out a
bear warning system in practice?</p>
<p>i. Detecting bears in images captured at night: if the training data
contain only daytime images, predictions on these observations will be
of poor quality.<br>
ii. Keeping inference times short enough for park rangers to respond to
alerts in time. Even if the system correctly identifies the bears, it is
useless in production if it takes too long.<br>
iii. Bear poses that the cameras may capture but that are not represented
in the training set, so the model will miss them or predict them
poorly.</p>
<ul>
<li>Note: remember that a neural network's behaviour on a new example
depends on how that example relates to the training set the model was
fit on, which represents one particular distribution.</li>
</ul>
</li>
<li>
<p>What is out-of-domain data?</p>
<ul>
<li><strong>A:</strong> In general, the term refers to data that differ from
the data used to train the model, such as those described in the examples
of the previous answer.</li>
</ul>
</li>
<li>
<p>What is domain shift?</p>
<ul>
<li><strong>A:</strong> The data the model consumes in production change over
time, drifting further and further from the dataset used to fit the model
and hurting its performance on new observations. For example, musical tastes
adapt to new trends and cultural styles that emerge with each generation, so
a static model that was trained only once and does not account for these
changes will become less useful over time.</li>
</ul>
</li>
<li>
<p>What are the three steps in the deployment process?</p>
<p>i. <strong>Manual process:</strong> run the model in parallel and review all its
predictions to get a sense of the model's state, as well as potential problems
and improvements. It is important that the predictions do not trigger any
automatic action and that the process is carried out manually.<br>
ii. <strong>Limited-scope deployment:</strong> the model runs with a limited,
low-risk scope. This can be defined by geographic area or by operating over a
bounded period of time. Constant human supervision is important.<br>
iii. <strong>Gradual expansion:</strong> gradually increase the model's scope.
Good monitoring and reporting systems are needed to detect any relevant
problem, since we will no longer have the <em>input</em> of whoever ran the
manual process about new behaviours the process must take into account.
Always consider what could go wrong.</p>
</li>
</ol>
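<p>To see what "serializing objects into bytes" means, here is a minimal sketch with Python's standard <code>pickle</code> module. The <code>model</code> dictionary is a hypothetical stand-in for an architecture description plus trained weights; the real <code>learn.export()</code> serializes a whole <code>Learner</code>, which this sketch does not reproduce.</p>

```python
import pickle
import tempfile
from pathlib import Path

# Hypothetical stand-in for an exported model: architecture name, weights, class labels.
model = {
    "architecture": "resnet18",
    "weights": [[0.12, -0.5], [0.33, 0.9]],
    "classes": ["elf", "goblin", "knight", "zombie"],
}

raw = pickle.dumps(model)  # object -> bytes
print(type(raw).__name__, len(raw))

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "export.pkl"
    path.write_bytes(raw)  # save the bytes to a .pkl file
    restored = pickle.loads(path.read_bytes())  # bytes -> equivalent object

print(restored == model)  # True
```

<p>One design caveat worth remembering: unpickling can execute arbitrary code, so only load <code>.pkl</code> files from sources you trust.</p>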
]]></content>
  </entry>
  <entry>
    <title>Deep Learning for Coders - chapter 1 notes</title>
    <link href="https://alkzar.cl/posts/fastai-chapter-1/"/>
    <id>https://alkzar.cl/posts/fastai-chapter-1/</id>
    <published>2022-07-09T00:00:00Z</published>
    <updated>2022-07-09T00:00:00Z</updated>
    <content type="html"><![CDATA[<p>First post in a series about reading and working through
the book <a href="https://course.fast.ai"><em>Deep Learning for Coders with fastai &amp; PyTorch</em></a> by
Jeremy Howard &amp; Sylvain Gugger. It contains a summary of and notes on the chapter 📝, as well as
references to additional material that complements the reading.
It also includes my answers to the questionnaire and to the research
questions proposed at the end of the chapter.</p>
<h2>A brief history of neural networks</h2>
<p>At a very high level, <em>Deep Learning</em> is defined as a computational
technique for making predictions from data using neural networks
made up of multiple layers. Each of these layers receives an input and produces
an output, refining the results as the information moves through the
network. Training is guided by an optimization algorithm (e.g. SGD, Adagrad, Adam) that seeks to minimize the
error between the predictions generated by the model
and the true values given by the data. These deep neural networks are
used in many research fields such as <em>Natural Language Processing (NLP)</em>,
<em>Computer Vision</em>, <em>Image Generation</em>, <em>Robotics</em>, and <em>Recommendation Systems</em>, among others.</p>
<p>The chapter then builds a brief timeline of neural network
models.</p>
<ul>
<li>1943: Warren McCulloch and Walter Pitts develop the mathematical model
of an artificial neuron in the paper <a href="https://www.cs.cmu.edu/~./epxing/Class/10715/reading/McCulloch.and.Pitts.pdf"><em>A Logical Calculus of the Ideas
Immanent in Nervous Activity</em></a>.</li>
<li>1957: Frank Rosenblatt implements the first artificial neuron model
with the ability to "learn", called the <em>Perceptron</em>.</li>
<li>1969: Marvin Minsky and Seymour Papert publish the book <a href="https://mitpress.mit.edu/books/perceptrons-expanded-edition">Perceptrons</a> about
Rosenblatt's work. They show that a single layer of these neurons is unable
to learn simple functions such as XOR. However, in the same book, they show how to overcome this problem by adding
more layers of neurons (aka <em>multilayer perceptron</em>).</li>
<li>1970-1985: A significant decline in neural network research,
with the exception of a small group of researchers. In the last episode of
season 2 of <a href="https://open.spotify.com/episode/3GpQhNqRdYgVz1X8vswpB9?si=16bb0e19cbab4116"><em>The Robot Brains Podcast</em></a>,
Geoffrey Hinton is interviewed and tells an anecdote about presenting
research he was doing in 1973 that used neural networks. After
the talk, in front of a rather skeptical audience, one of the questions
Hinton received was why he used those methods when Minsky and Papert
"had said" they did not work (supposedly in the book <em>Perceptrons</em>).</li>
<li>1986: The multi-volume book <em>Parallel Distributed Processing</em> (PDP)
is published by David Rumelhart, James McClelland, and colleagues. Building on and extending the earlier work of Rosenblatt and Minsky, the book further formalizes the theory and the implementation
aspects.</li>
<li>2012: Geoffrey Hinton's group wins the ImageNet competition, drastically
reducing the model's error compared with the solutions of the other
participants and of previous editions of the competition.</li>
</ul>
<p>There are many more details and contributions in the history of Artificial Intelligence
and the use of neural networks. Jürgen Schmidhuber goes deeper into this, offering
many interesting details and references in the following video,
which premiered at the AIJ conference at the end of 2020.</p>
<center>
<iframe width="560" height="315" src="https://www.youtube.com/embed/pGftUCTqaGg?start=505" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>
</center>
<h2>What is Machine Learning?</h2>
<p>The chapter quotes and comments on the essay <a href="https://journals.sagepub.com/doi/abs/10.1177/000271626234000103"><em>Artificial Intelligence: A Frontier of Automation</em></a>
by Arthur L. Samuel (1962), who coined the term <em>machine learning</em> and was
director of research communications at IBM. The essay begins by
demystifying anthropomorphic considerations and grandiloquent claims
about the field of Artificial Intelligence, while at the same time legitimizing
its validity and impact as a discipline for solving specific problems such as
automatically translating from Russian to English, recognizing handwritten digits
or handwritten text, and solving board games,
which make it possible to explore the design of agents capable of search and inference. The essay also delimits
the scope in which Artificial Intelligence operates with respect
to the role of the computer, ruling out what it is not. Samuel draws
an analogy: what distinguishes a good worker
from a less capable one is the former's ability to investigate
and learn how to perform the task, whereas the latter
must be guided step by step through it. This
means that, regardless of how complex the software is (for example,
computing the stress the wind produces on an airplane's wings),
we would still be dealing with instructions spelled out in advance by a programmer,
and therefore with intelligence packaged up and handed over to the machine.</p>
<blockquote>
<p>"<em>Programming a computer for such computations is, at best, a difficult task, not
primarily because of any inherent complexity in the computer itself but, rather,
because of the need to spell out every minute step of the process in the most
exasperating detail. Computers, as any programmer will tell you, are giant morons,
not giant brains</em>" (Samuel, p. 13)</p>
</blockquote>
<p>Samuel's goal, and his idea of artificial intelligence, was to specify
the task to the machine and let it find the solution on its own.
Samuel lays out some critical requirements for a
machine to be able to search for solutions, and he does so
with the example of programming a computer to play
checkers. In essence, once the computer has a representation
of the board and the rules that govern the game, it
can take actions to generate moves and explore the consequences
of different board states. However, as we will see, exploring
every possible sequence of moves ahead
is an endless path: <a href="https://www.deepmind.com/research/highlighted-research/alphago">think of the <span class="math inline"><math display="inline"><msup><mn>10</mn><mrow><mn>170</mn></mrow></msup></math></span> possible board
configurations in the game of Go</a>, which
exceed the number of atoms in the universe, an
impossible undertaking even for a computer. The search should not be framed
in terms of secondary goals (e.g. going after a knight, or making a particular move)
but in some other way.</p>
<blockquote>
<p><em>It is here that we encounter the idea of machine learning. Suppose we arrange for
some automatic means of testing the effectiveness of any current weight assignment in terms of actual performance and provide a mechanism for altering the weight assignment
so as to maximize the performance. We need not go into the details of such a
procedure to see that it could be made entirely automatic and to see that a
machine so programed would "learn" from its experience</em> (Samuel, p. 17)</p>
</blockquote>
<center>
<img src="/img/fastai-chapter-1/Samuels_Diagram.png">
</center>
<p>The diagram contains the concepts Samuel refers to: a machine
endowed with an automatic <em>feedback</em> mechanism, where experience is produced
by comparing labels with predictions based on features
of the data, plus the ability to assign the program's weights in order to
change its state and guide the search for solutions toward maximizing performance (i.e. winning boards).
Using this paradigm, Samuel built a checkers-playing program that ended up
beating one of the US state champions.</p>
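<p>To make Samuel's loop concrete, here is a toy sketch of the two pieces the quote describes: an automatic test of the current weight assignment and a mechanism that alters the weights to maximize performance. It is not Samuel's checkers program; the toy task, the linear scoring, and the random-perturbation update (a simple hill climber) are assumptions chosen for brevity.</p>

```python
import random

def performance(weights, examples):
    """Automatic test of a weight assignment: fraction of examples classified correctly."""
    correct = 0
    for features, label in examples:
        score = sum(w * f for w, f in zip(weights, features))
        prediction = 1 if score > 0 else 0
        correct += int(prediction == label)
    return correct / len(examples)

def learn_weights(examples, n_features, steps=2000, step_size=0.5, seed=0):
    """Mechanism for altering the weights: try a random change and keep it
    if performance does not get worse (a simple hill climber)."""
    rng = random.Random(seed)
    weights = [0.0] * n_features
    best = performance(weights, examples)
    for _ in range(steps):
        candidate = [w + rng.uniform(-step_size, step_size) for w in weights]
        score = performance(candidate, examples)
        if score >= best:
            weights, best = candidate, score
    return weights, best

# Hypothetical toy task: the label is 1 when the first feature exceeds the second.
data_rng = random.Random(42)
examples = []
for _ in range(200):
    a, b = data_rng.uniform(-1, 1), data_rng.uniform(-1, 1)
    examples.append(((a, b), int(a > b)))

weights, accuracy = learn_weights(examples, n_features=2)
print(weights, accuracy)
```

Nobody tells the program which weights to use: it only receives a way to measure performance and a way to change the weights, which is exactly the "learning from experience" Samuel describes.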
<p>A complementary read that the essay brought to mind, and an interesting
updated perspective, is a post by Andrej Karpathy
that calls Samuel's description of a machine
figuring out its own instructions <a href="https://karpathy.medium.com/software-2-0-a64152b37c35"><em>software 2.0</em></a>. One caveat: Karpathy restricts the
paradigm exclusively to neural networks.</p>
<blockquote>
<p><em>"Neural networks are not just another classifier, they represent
the beginning of a fundamental shift in how we develop software.
They are Software 2.0"</em> (Karpathy)</p>
</blockquote>
<p>Karpathy builds on a comparison with the traditional way of
writing programs, or <em>software 1.0</em>, where the set
of instructions that makes up a solution is designed by hand, and where each
line written by the programmer is the product of decisions
that pin down a single point within the space
of possible programs. A neural network, or <em>software
2.0</em>, is instead given a goal through
input/output pairs, together with a code skeleton (the model's architecture) that defines the "program
space" and the details that can be modified. Through a process of parameter fitting (i.e.
<em>weight assignment</em> in Samuel's words), guided by its
evaluation mechanism (i.e. the loss), the neural network explores different
configurations within that space and keeps the
solution, encapsulated in the values of its parameters, that best
satisfies the evaluation criterion. The following diagram
appears in the post and is one way to illustrate this:</p>
<p><img src="https://miro.medium.com/max/1400/1*5NG3U8MsaTqmQpjkr_-UOw.png" alt="Source: Software 2.0 - Andrej Karpathy" /></p>
<blockquote>
<p><em>"To make the analogy explicit, in Software 1.0, human-engineered source code (e.g. some .cpp files) is compiled into a binary that does useful work. In Software 2.0 most often the source code comprises 1) the dataset that defines the desirable behavior and 2) the neural net architecture that gives the rough skeleton of the code, but with many details (the weights) to be filled in. The process of training the neural network compiles the dataset into the binary — the final neural network."</em> (Karpathy)</p>
</blockquote>
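<p>A minimal sketch of Karpathy's analogy, under simplifying assumptions: the same task (Celsius to Fahrenheit) written once as software 1.0 and once "compiled" from input/output pairs. The tiny dataset, the one-line linear "architecture" <code>y = w * x + b</code>, and plain gradient descent on the mean squared error are illustrative choices, not anything taken from the post.</p>

```python
import numpy as np

# Software 1.0: the programmer spells out every step.
def celsius_to_fahrenheit_v1(c):
    return c * 9 / 5 + 32

# Software 2.0: the "source code" is a dataset of input/output pairs...
celsius = np.array([-40.0, -10.0, 0.0, 8.0, 15.0, 22.0, 38.0])
fahrenheit = np.array([-40.0, 14.0, 32.0, 46.4, 59.0, 71.6, 100.4])

# ...plus a skeleton whose details (the weights) are still blank: y = w * x + b.
w, b = 0.0, 0.0
lr = 1e-3
for _ in range(20_000):
    error = (w * celsius + b) - fahrenheit
    # Gradient of the mean squared error with respect to w and b.
    w -= lr * 2 * np.mean(error * celsius)
    b -= lr * 2 * np.mean(error)

# Training "compiled" the dataset into the weights.
print(round(w, 2), round(b, 2))  # approximately 1.8 and 32.0
```

Nobody wrote the constants 9/5 and 32 into the second version; they were recovered from the examples, which is the sense in which the dataset plays the role of source code.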
<p>Another interesting point Karpathy makes is that if we understand
neural networks not as just another classifier, but as
a new way of thinking about program development, we can
better see the patterns and trends that make it easier to
create <em>software 2.0</em>. Just as a set of
tools supports the creation of <em>software 1.0</em> (e.g. IDEs, version control, package
managers), Karpathy writes that it will be natural to have a
<em>stack</em> for creating <em>software 2.0</em>. What is interesting is that, between the publication of the post in 2017 and
today, a number of tools have appeared that make up
part of the stack Karpathy envisioned. For example:</p>
<ul>
<li>The equivalent of a repository hosting <em>software 1.0</em> code, such as GitHub -&gt; Today we have the <a href="https://huggingface.co/datasets">HuggingFace <em>datasets</em> hub</a>, an
implementation of what Karpathy describes: <em>"in this case repositories are datasets and commits are made up of additions and edits of the labels."</em></li>
<li>Labeling or relabeling inputs to define
the program's goal. Projects such as <a href="https://github.com/snorkel-team/snorkel">Snorkel</a>
have been built around a data-centric approach (<em>weak supervision</em>):
for an unlabeled dataset, or one whose labeling quality is not
guaranteed, heuristics based on expert judgment can be used
to label the data programmatically (i.e. <em>labeling functions</em>).</li>
<li>Something similar to package managers (e.g. pip, conda) but for
checkpoints of already trained models. Again, the <a href="https://huggingface.co/models">HuggingFace
models <em>hub</em></a>
makes it easy to import models and use <em>transfer learning</em>
to adapt them to new tasks. Under the
<em>software 2.0</em> paradigm, this would amount to using one program to write another program.</li>
</ul>
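<p>The labeling-function idea from the Snorkel item can be sketched in plain Python. This is not Snorkel's API: the spam task, the three heuristics, and the majority vote used to combine them are hypothetical choices for illustration (Snorkel itself can go further and learn a label model that weights each function).</p>

```python
from collections import Counter

ABSTAIN, HAM, SPAM = -1, 0, 1

# Hypothetical labeling functions: each encodes one expert heuristic and may abstain.
def lf_contains_link(text):
    return SPAM if "http" in text.lower() else ABSTAIN

def lf_mentions_prize(text):
    lowered = text.lower()
    return SPAM if "prize" in lowered or "winner" in lowered else ABSTAIN

def lf_short_reply(text):
    return ABSTAIN if len(text.split()) > 5 else HAM

LABELING_FUNCTIONS = [lf_contains_link, lf_mentions_prize, lf_short_reply]

def majority_vote(text, lfs=LABELING_FUNCTIONS):
    """Combine the non-abstaining votes; abstain when nobody votes or on a tie."""
    votes = [v for v in (lf(text) for lf in lfs) if v != ABSTAIN]
    if not votes:
        return ABSTAIN
    ranked = Counter(votes).most_common()
    if len(ranked) > 1 and ranked[0][1] == ranked[1][1]:
        return ABSTAIN
    return ranked[0][0]

unlabeled = [
    "You are a winner! Claim your prize at http://example.com",
    "See you tomorrow",
    "Meeting notes attached, let me know what you think about the plan",
]
print([majority_vote(t) for t in unlabeled])  # [1, 0, -1]
```

Examples where every function abstains stay unlabeled, which is the trade-off of weak supervision: labels are cheap to produce but cover only what the heuristics know about.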
<h2>What is a neural network?</h2>
<p>A neuron is a dot product between an input (<span class="math inline"><math display="inline"><mrow><mi>𝐱</mi></mrow></math></span>) and a set of parameters (<span class="math inline"><math display="inline"><mrow><mi>𝐰</mi></mrow></math></span>), plus
a coefficient called the <em>bias</em> (<span class="math inline"><math display="inline"><mrow><mi>𝐛</mi></mrow></math></span>). An activation function (<span class="math inline"><math display="inline"><mi>σ</mi><mo symmetric="false" stretchy="false">(</mo><mi>⋅</mi><mo symmetric="false" stretchy="false">)</mo></math></span>) is then applied to the result
of this operation, so a neuron is specified as:</p>
<p><div class="math display"><math display="block"><msup><mover><mrow><mi>y</mi></mrow><mi>^</mi></mover><mrow><mo symmetric="false" stretchy="false">(</mo><mn>1</mn><mo symmetric="false" stretchy="false">)</mo></mrow></msup><mo>=</mo><mi>σ</mi><mo symmetric="false" stretchy="false">(</mo><msup><mrow><mi>𝐰</mi></mrow><mrow><mi>⊤</mi></mrow></msup><mrow><mi>𝐱</mi></mrow><mo>+</mo><mrow><mi>𝐛</mi></mrow><mo symmetric="false" stretchy="false">)</mo></math></div></p>
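<p>The formula translates almost literally into NumPy. A minimal sketch, assuming the sigmoid as the activation σ (the text leaves σ unspecified) and made-up values for the input, weights, and bias:</p>

```python
import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

def neuron(x, w, b):
    """One artificial neuron: dot product of weights and input, plus bias, then activation."""
    return sigmoid(w @ x + b)

x = np.array([1.0, 2.0, -1.0])   # input (example values)
w = np.array([0.5, -0.25, 1.0])  # weights
b = 0.5                          # bias
print(neuron(x, w, b))  # sigmoid(0.5 - 0.5 - 1.0 + 0.5) = sigmoid(-0.5), about 0.3775
```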
<p>In a neural network, sets of neurons are organized layer by layer, so
the input-output information flows through the network. If we
chain two neurons together, we get something like:</p>
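<p><div class="math display"><math display="block"><mtable><mtr><mtd columnalign="right"><msup><mover><mrow><mi>y</mi></mrow><mi>^</mi></mover><mrow><mo symmetric="false" stretchy="false">(</mo><mn>2</mn><mo symmetric="false" stretchy="false">)</mo></mrow></msup></mtd><mtd columnalign="left"><mo>=</mo><mi>σ</mi><mo symmetric="false" stretchy="false">(</mo><msup><mrow><mi>𝐰</mi></mrow><mrow><mi>⊤</mi></mrow></msup><mi>σ</mi><mo symmetric="false" stretchy="false">(</mo><msup><mrow><mi>𝐰</mi></mrow><mrow><mi>⊤</mi></mrow></msup><mrow><mi>𝐱</mi></mrow><mo>+</mo><mrow><mi>𝐛</mi></mrow><mo symmetric="false" stretchy="false">)</mo><mo>+</mo><mrow><mi>𝐛</mi></mrow><mo symmetric="false" stretchy="false">)</mo></mtd></mtr><mtr><mtd></mtd><mtd columnalign="left"><mo>=</mo><mi>σ</mi><mo symmetric="false" stretchy="false">(</mo><msup><mrow><mi>𝐰</mi></mrow><mrow><mi>⊤</mi></mrow></msup><msup><mover><mrow><mi>y</mi></mrow><mi>^</mi></mover><mrow><mo symmetric="false" stretchy="false">(</mo><mn>1</mn><mo symmetric="false" stretchy="false">)</mo></mrow></msup><mo>+</mo><mrow><mi>𝐛</mi></mrow><mo symmetric="false" stretchy="false">)</mo></mtd></mtr></mtable></math></div></p>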
<p>The inner layers (i.e. those other than the model's
input and output) are usually called <em>hidden layers</em>, since their results are not observed
directly. This chaining of functions can be continued to build
models with more layers. Looking at the expression
above, a neural network could well be understood
as a <em>stack</em> of logistic regressions (when σ is the sigmoid function).</p>
<p>An important point is that a neural network does not chain neurons together
like a simple chain, or <em>linked list</em>; instead, each layer has several
neurons. The weights (or parameters) of these connections are then no longer the vector
<span class="math inline"><math display="inline"><mrow><mi>𝐰</mi></mrow></math></span> but are encoded in a matrix or tensor <span class="math inline"><math display="inline"><mrow><mi>𝐖</mi></mrow></math></span>.</p>
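<p>A minimal sketch of layers built from a weight matrix, where each row of the matrix holds the weights of one neuron. The sigmoid activation, the layer sizes, and the random weights are assumptions for illustration; unlike the two-neuron formula above, each layer here gets its own weight matrix and bias vector (the hypothetical names <code>W1</code>, <code>b1</code>, <code>W2</code>, <code>b2</code>), which is the usual setup.</p>

```python
import numpy as np

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

def layer(x, W, b):
    """A layer of neurons: each row of W holds the weights of one neuron."""
    return sigmoid(W @ x + b)

rng = np.random.default_rng(0)
x = rng.normal(size=4)                                # input with 4 features
W1, b1 = rng.normal(size=(3, 4)), rng.normal(size=3)  # hidden layer: 3 neurons
W2, b2 = rng.normal(size=(1, 3)), rng.normal(size=1)  # output layer: 1 neuron

hidden = layer(x, W1, b1)      # hidden-layer output, not observed directly
y_hat = layer(hidden, W2, b2)  # network output
print(hidden.shape, y_hat.shape)  # (3,) (1,)
```

With a sigmoid on the single output neuron, the last layer is exactly a logistic regression whose inputs are the hidden layer's outputs, matching the "stack of logistic regressions" reading.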
<p>The activation function has two roles:</p>
<ol>
<li>Activation functions allow the model to have multiple slopes for
different input values, something a linear function by definition cannot do.</li>
<li>The function in the last layer maps the outputs of the
linear operation into the specific range of values required by the problem
we are solving (e.g. between 0 and 1 for a probability).</li>
</ol>
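<p>Both roles can be checked numerically. A short sketch with random weights (sizes and the choice of ReLU between the layers and sigmoid at the end are assumptions for illustration): without an activation, two stacked linear layers collapse into a single matrix; with ReLU in between, the network stops being linear; and a final sigmoid keeps outputs in (0, 1).</p>

```python
import numpy as np

rng = np.random.default_rng(1)
W1 = rng.normal(size=(5, 3))  # first linear layer: 3 inputs to 5 units
W2 = rng.normal(size=(2, 5))  # second linear layer: 5 units to 2 outputs
x = rng.normal(size=3)

def relu(z):
    # Slope 0 for negative inputs, slope 1 for positive inputs.
    return np.maximum(z, 0)

def sigmoid(z):
    return 1 / (1 + np.exp(-z))

# Role 1a: without an activation, two stacked linear layers equal one linear layer.
two_linear_layers = W2 @ (W1 @ x)
single_linear_layer = (W2 @ W1) @ x
print(np.allclose(two_linear_layers, single_linear_layer))  # True

# Role 1b: with ReLU in between, the network is no longer linear; any linear map
# satisfies f(-x) == -f(x), and this one does not.
def f(v):
    return W2 @ relu(W1 @ v)

print(np.allclose(f(-x), -f(x)))  # False

# Role 2: a final sigmoid squashes any real-valued score into (0, 1).
scores = np.array([-10.0, 0.0, 10.0])
probs = sigmoid(scores)
print(probs)  # roughly [0.0000454, 0.5, 0.9999546]
```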
<h2>Neural networks and feature learning</h2>
<blockquote>
<p><em>"Attempts have been made to mechanize both of these steps (creation of concepts &amp; weight assignment), but,
to date, very little progress has been made with
respect to the concept-formation step, and most of the workers
have been content to supply man-generated concepts (features)
and to develop machine procedures for assigning weights to these
concepts"</em> (Samuel, p. 17)</p>
</blockquote>
<p>One of the advantages of neural networks, and the reason why
Karpathy refers only to neural networks when he talks
about <em>software 2.0</em>, is their ability to learn representations of
the data. More traditional statistical models
focus only on the weight-assignment step, or <em>fitting</em>, leaving the representation of the data
as a prior step so that the model can consume the
concepts Samuel talks about. The model therefore lacks the
capacity to learn features from the data, or
learning them is simply not part of its design.
Neural networks with multiple layers, in contrast, can
incorporate into parameter fitting the
learning of the data representation: the "concept creation" step
Samuel refers to in the quote at the beginning of this section.
This is very useful for dealing with unstructured data such as images or
text, whose representation is usually not trivial
to build; in other words, their <em>feature engineering</em> is
prohibitive. For interesting differences between <em>deep learning</em> and
more traditional statistics, see the post <a href="https://windowsontheory.org/2022/06/20/the-uneasy-relationship-between-deep-learning-and-classical-statistics/"><em>"The uneasy relationship between deep learning and classical statistics"</em></a>.</p>
<p>The chapter cites and comments on the <em>paper</em> <a href="https://arxiv.org/abs/1311.2901">Visualizing and Understanding Convolutional Networks (Zeiler, Fergus 2013)</a>
to illustrate the above. I think presenting this article is very useful
because (i) it shows that these models are not impenetrable black boxes and
(ii) it is a very visual demonstration of concept creation by
neural networks. It also shows how deeper layers
are able to learn higher-level concepts built from
more primitive ones. The results of this study allowed the
research group to better understand AlexNet, the model that won the ImageNet competition
in 2012, then modify the model's architecture, and win
the competition the following year. Here is the paper's <em>abstract</em>:</p>
<blockquote>
<p><strong>Abstract</strong>:
<em>Large Convolutional Network models have recently demonstrated impressive classification performance on the ImageNet benchmark (Krizhevsky et al., 2012). However there is no clear understanding of why they perform so well, or how they might be improved. In this paper we address both issues. We introduce a novel visualization technique that gives insight into the function of intermediate feature layers and the operation of the classifier. Used in a diagnostic role, these visualizations allow us to find model architectures that outperform Krizhevsky et al. on the ImageNet classification benchmark. We also perform an ablation study to discover the performance contribution from different model layers. We show our ImageNet model generalizes well to other datasets: when the softmax classifier is retrained, it convincingly beats the current state-of-the-art results on Caltech-101 and Caltech-256 datasets.</em></p>
</blockquote>
<h2>Training models with <code>fastai</code> and <em>transfer learning</em></h2>
<p>This is a practical book, and already in the first chapter there is
a brief demonstration of how to implement a
classification model. The goal is to identify cats
and dogs in images using the fastai library and the Oxford Pet dataset.
Although the task is simple, what I found most interesting is not
the model's performance, but the introduction of the technique
used to solve the problem, <em>transfer learning</em>, which is
at the core of the API. This technique takes a pre-trained model,
whose parameters have already been fitted, and adapts it to a new task.
The advantage is that from the start we have the capabilities the previous model
already acquired, which in some cases we can transfer to our new model
and obtain good results without needing much data.</p>
<p>When <strong>loading the data</strong>, the book highlights Python's <code>Path</code> object (from <code>pathlib</code>), which is what <code>untar_data</code> returns and which lets us build paths with the <code>/</code> operator.
The data are wrapped in a <em>dataloader</em>, an abstraction PyTorch uses to
manage the dataset (e.g. minibatches, labels, etc).</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta statement import python"><span class="keyword control import from python">from</span></span><span class="meta statement import python"><span class="meta import-source python"> <span class="meta import-path python"><span class="meta import-name python">fastai</span><span class="punctuation accessor dot python">.</span><span class="meta import-name python">vision</span><span class="punctuation accessor dot python">.</span><span class="meta import-name python">all</span></span> <span class="meta statement import python"><span class="keyword control import python">import</span></span></span></span><span class="meta statement import python"></span><span class="meta statement import python"> <span class="constant language import-all python">*</span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">path</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">untar_data</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">URLs</span><span class="punctuation accessor dot python">.</span><span class="variable other constant python">PETS</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="keyword operator arithmetic python">/</span><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">images<span class="punctuation definition string end python">&#39;</span></span></span>

<span class="meta function python"><span class="storage type function python">def</span> <span class="entity name function python"><span class="meta generic-name python">is_cat</span></span></span><span class="meta function parameters python"><span class="punctuation section parameters begin python">(</span></span><span class="meta function parameters python"><span class="variable parameter python">x</span><span class="punctuation section parameters end python">)</span></span><span class="meta function python"><span class="punctuation section function begin python">:</span></span> <span class="keyword control flow return python">return</span> <span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="constant numeric integer decimal python">0</span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span><span class="meta function-call python"><span class="meta qualified-name python"><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">isupper</span></span><span class="punctuation section arguments begin python">(</span><span class="punctuation section arguments end python">)</span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">dls</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">ImageDataLoaders</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">from_name_func</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python">
  <span class="meta qualified-name python"><span class="meta generic-name python">path</span></span><span class="punctuation separator arguments python">,</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">get_image_files</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">path</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">valid_pct</span><span class="keyword operator assignment python">=</span><span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>2</span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">seed</span><span class="keyword operator assignment python">=</span><span class="constant numeric integer decimal python">42</span><span class="punctuation separator arguments python">,</span>
  <span class="variable parameter python">label_func</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">is_cat</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">item_tfms</span><span class="keyword operator assignment python">=</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">Resize</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric integer decimal python">224</span></span><span class="punctuation section arguments end python">)</span></span>
</span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre>
<p><code>ImageDataLoaders.from_name_func()</code> is one of the factory methods
for initializing the <em>dataloader</em>. This particular one creates a <em>dataloader</em>
directly from the images in a directory whose file names
encode the labels of the dataset. It therefore
receives the list of image files (obtained with <code>get_image_files</code>), the
fraction of data held out for validation (<code>valid_pct</code>), a <code>seed</code> that makes that random split reproducible, the function that
extracts the label from each file name (<code>label_func</code>), and the <code>item_tfms</code> argument, which
lets us apply transformations to each image in the dataset, such as resizing
or cropping.</p>
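<p>To see what <code>valid_pct</code>, <code>seed</code>, and <code>label_func</code> do without downloading the dataset, here is a plain-Python sketch of the same idea: a reproducible random train/validation split plus a label computed from each file name. The helper <code>split_and_label</code> and the file names are hypothetical, and fastai's own splitter shuffles differently, so the exact split will not match the library's.</p>

```python
import random
from pathlib import Path

def is_cat(x):
    # In the Oxford Pet dataset, cat breeds start with an uppercase letter.
    return x[0].isupper()

def split_and_label(files, label_func, valid_pct=0.2, seed=42):
    """Hypothetical helper: random train/validation split, label from each file name."""
    files = list(files)
    idx = list(range(len(files)))
    random.Random(seed).shuffle(idx)  # same seed, same split
    valid_idx = set(idx[:int(valid_pct * len(files))])
    train = [(f, label_func(f.name)) for i, f in enumerate(files) if i not in valid_idx]
    valid = [(f, label_func(f.name)) for i, f in enumerate(files) if i in valid_idx]
    return train, valid

# Made-up file names in the style of the Pets dataset.
names = ["Abyssinian_1.jpg", "Bengal_12.jpg", "beagle_3.jpg", "pug_7.jpg", "Persian_5.jpg",
         "boxer_2.jpg", "Siamese_9.jpg", "havanese_4.jpg", "Birman_6.jpg", "chihuahua_8.jpg"]
train, valid = split_and_label([Path(n) for n in names], is_cat)
print(len(train), len(valid))  # 8 2
```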
<p>Once the <em>dataloader</em> is initialized, we can <strong>train the model</strong>. The
<code>Learner</code> object in the <code>fastai</code> library controls the learning process and takes
all the necessary elements (i.e. model, data, optimizer, etc). There are
specific learners for well-known architectures, such as <code>cnn_learner</code>, which is
meant for networks with convolutional layers. One of its
arguments is <code>resnet34</code> (34 for the number of layers): here we specify that
we will use this pre-trained model and adapt it to our problem. <strong>When
using <em>transfer learning</em>, the parameters are not fitted from scratch</strong>;
instead we call <code>fine_tune(epochs)</code>, which (i) first fits the parameters of the
model's head, the part in charge of adapting the model to the new problem, while the rest of the network stays frozen, and
(ii) then fits all the parameters for the number of epochs given as the function's
argument, updating the parameters of the last layers faster than those
of the first ones, which
makes sense given that the first layers have already been trained.</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta qualified-name python"><span class="meta generic-name python">learn</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">cnn_learner</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">dls</span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">resnet34</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">metrics</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">error_rate</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">learn</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">fine_tune</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric integer decimal python">1</span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre>
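<p>The two phases of <code>fine_tune</code> can be summarized without a GPU. The sketch below only computes a plan (which layers train, for how many epochs, and with which peak learning rates); it trains nothing. The defaults are assumptions based on our reading of fastai's <code>fine_tune</code> (<code>base_lr=2e-3</code>, <code>freeze_epochs=1</code>, <code>lr_mult=100</code>, halving the base rate before the second phase, and geometric spacing of rates across layer groups); <code>n_groups=3</code> and both helper names are hypothetical, so check the documentation for your fastai version.</p>

```python
def layer_group_lrs(lr_min, lr_max, n_groups):
    """Geometrically spaced learning rates: lr_min for the earliest layer group,
    lr_max for the last one (the head)."""
    if n_groups == 1:
        return [lr_max]
    ratio = (lr_max / lr_min) ** (1 / (n_groups - 1))
    return [lr_min * ratio**i for i in range(n_groups)]

def fine_tune_plan(epochs, base_lr=2e-3, freeze_epochs=1, lr_mult=100, n_groups=3):
    """Hypothetical summary of the two phases of fine-tuning (peak learning rates only)."""
    plan = []
    # Phase 1: pre-trained body frozen, only the new head is trained.
    plan.append({"phase": "head only", "epochs": freeze_epochs, "lrs": [base_lr]})
    # Phase 2: everything unfrozen; earlier layer groups get smaller learning rates.
    base_lr /= 2
    lrs = layer_group_lrs(base_lr / lr_mult, base_lr, n_groups)
    plan.append({"phase": "all layers", "epochs": epochs, "lrs": lrs})
    return plan

for step in fine_tune_plan(epochs=1):
    print(step["phase"], step["epochs"], [f"{lr:.0e}" for lr in step["lrs"]])
# head only 1 ['2e-03']
# all layers 1 ['1e-05', '1e-04', '1e-03']
```

The geometric spacing reflects the intuition in the text: the earliest layers, which already encode generic features, move a hundred times more slowly than the head.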
<p>Once the parameters have been fitted, we can use the model like
any other program, one that receives an input and returns an output; this mode
is known as the <strong>inference phase</strong>. Finally, since the program created
in this chapter was designed to solve a visual perception problem
that answers the query "does the image contain a cat or a dog?", we could
integrate it into other software that, for example, lets
our cat in by opening the patio door, but does not do the same for
a neighbor's dog.</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta qualified-name python"><span class="meta generic-name python">img</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">PILImage</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">create</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">uploader</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">data</span></span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="constant numeric integer decimal python">0</span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">is_cat</span></span>, <span class="meta qualified-name python"><span class="variable language python">_</span></span>, <span class="meta qualified-name python"><span class="meta generic-name python">probs</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">learn</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">predict</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">img</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">print</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="storage type string python">f</span><span class="meta string interpolated python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string interpolated python"><span class="string quoted double python">Is it a cat?: </span><span class="meta interpolation python"><span class="punctuation section interpolation begin python">{</span><span class="source python embedded"><span class="meta qualified-name python"><span class="meta generic-name python">is_cat</span></span></span></span><span class="meta interpolation python"><span class="punctuation section interpolation end python">}</span></span><span class="string quoted double python">.<span class="punctuation definition string end python">&quot;</span></span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">print</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="storage type string python">f</span><span class="meta string interpolated python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string interpolated python"><span class="string quoted double python">Probability of being a cat: </span><span class="meta interpolation python"><span class="punctuation section interpolation begin python">{</span><span class="source python embedded"><span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">probs</span></span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="constant numeric integer decimal python">1</span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span><span class="meta function-call python"><span class="meta qualified-name python"><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">item</span></span><span class="punctuation section arguments begin python">(</span><span class="punctuation section arguments end python">)</span></span></span><span class="meta format-spec python"><span class="constant other format-spec python">:.6f</span></span></span><span class="meta interpolation python"><span class="punctuation section interpolation end python">}</span></span><span class="string quoted double python"><span class="punctuation definition string end python">&quot;</span></span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre>
<p>The library is high-level and provides abstractions over the training
<em>loop</em> used to fit a model. Notably, the <em>transfer learning</em> technique
sits at the core of its API.</p>
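<p>To make the idea of transfer learning concrete without depending on fastai or on downloading pretrained weights, here is a minimal pure-Python sketch. It assumes a hypothetical, hand-written <code>pretrained_body</code> that stands in for a network already trained on a large dataset; the body stays frozen and only a new logistic-regression head is trained on the new task. This illustrates the feature-extraction flavour of transfer learning (roughly the frozen first stage of <code>fine_tune</code>), not fastai's actual implementation.</p>

```python
import math
import random


def pretrained_body(x):
    """Hypothetical frozen feature extractor standing in for a pretrained network.

    It maps a raw input pair to two "features"; its weights are never updated.
    """
    return [x[0] + x[1], x[0] - x[1]]


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def head(weights, bias, features):
    """New task-specific head: a single logistic unit (binary output)."""
    return sigmoid(sum(w * f for w, f in zip(weights, features)) + bias)


def train_head(data, epochs=200, lr=0.1):
    """Fit only the head with SGD on binary cross-entropy; the body stays frozen."""
    weights, bias = [0.0, 0.0], 0.0
    for _ in range(epochs):
        for x, y in data:
            feats = pretrained_body(x)  # frozen: nothing in the body is updated
            err = head(weights, bias, feats) - y  # d(BCE)/d(logit) for a sigmoid output
            weights = [w - lr * err * f for w, f in zip(weights, feats)]
            bias -= lr * err
    return weights, bias


# Hypothetical target task: the label is 1 when the two inputs sum to more than 1.
random.seed(0)
points = [(random.random(), random.random()) for _ in range(200)]
data = [(p, 1 if p[0] + p[1] > 1 else 0) for p in points]

weights, bias = train_head(data)
accuracy = sum(
    (head(weights, bias, pretrained_body(x)) > 0.5) == bool(y) for x, y in data
) / len(data)
print(f"training accuracy of the new head: {accuracy:.2f}")
```

<p>Because only the head's three numbers are trained, fitting is cheap; the price is that the new task can only use whatever the frozen body already computes.</p>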
<h2>Questionnaire</h2>
<ol>
<li>
<p>Do you need these for deep learning?</p>
<ul>
<li>Lots of math T/<strong>F</strong></li>
<li>Lots of data T/<strong>F</strong></li>
<li>Lots of expensive computers T/<strong>F</strong></li>
<li>A PhD T/<strong>F</strong></li>
</ul>
</li>
<li>
<p>Name five areas where deep learning is now the best tool in the world.</p>
<ul>
<li><strong>R:</strong> Natural Language Processing, Computer Vision, Recommendation Systems, Image Generation, Text Generation.</li>
</ul>
</li>
<li>
<p>What was the name of the first device that was based on the principle of the artificial neuron?</p>
<ul>
<li><strong>R:</strong> The Mark I Perceptron, developed by Frank Rosenblatt. A photo
of the machine can be seen <a href="https://americanhistory.si.edu/collections/search/object/nmah_334414">here</a>.</li>
</ul>
</li>
<li>
<p>Based on the book of the same name, what are the requirements for parallel distributed processing (PDP)?</p>
<ul>
<li>A set of processing units</li>
<li>A state of activation</li>
<li>An <em>output</em> function for each unit</li>
<li>A pattern of connectivity among units</li>
<li>A propagation rule for propagating patterns of activity through
the network of connectivity</li>
<li>An activation rule for combining the <em>inputs</em> impinging on a
unit with the current state of that unit to produce an <em>output</em></li>
<li>A learning rule whereby patterns of connectivity are
modified by experience (data)</li>
<li>An environment within which the system must operate</li>
</ul>
</li>
<li>
<p>What were the two theoretical misunderstandings that held back the field
of neural networks?</p>
<ul>
<li><strong>R:</strong> The first misunderstanding came from the work of Marvin Minsky
and Seymour Papert in their book <em>Perceptrons</em>, where they showed that a
single-layer perceptron cannot learn elementary mathematical functions such as
exclusive or (XOR). However, in the same book they also showed that using multiple
layers of these units would address these limitations, an insight that was far
less widely recognized. The second misunderstanding is that these models are
impenetrable black boxes. Although they are challenging to interpret, the chapter
cites <em>Visualizing and Understanding Convolutional Networks</em> (Zeiler and
Fergus, 2013) to dispel the idea that neural networks are inscrutable models. That
<em>paper</em> inspected the network's parameters layer by layer and identified the
<em>features</em> the model had learned once trained. Using this information, the
authors improved the AlexNet architecture and won the following ImageNet
competition, in 2013.</li>
</ul>
</li>
<li>
<p>What is a GPU?</p>
<ul>
<li><strong>R:</strong> Graphics Processing Unit (GPU). This piece of <em>hardware</em> is
good at computing many operations in parallel. Since training neural
networks involves a huge number of multiplications and additions,
GPUs have proven very effective for training these models.</li>
</ul>
</li>
<li>
<p>Open a notebook and execute a cell containing: <code>1+1</code>. What happens?</p>
<ul>
<li><strong>R:</strong> It returns the result, 2.</li>
</ul>
</li>
<li>
<p>Follow through each cell of the stripped version of the notebook for this
chapter. Before executing each cell, guess what will happen.</p>
<ul>
<li><strong>R:</strong> Done.</li>
</ul>
</li>
<li>
<p>Complete the Jupyter Notebook online appendix (https://oreil.ly/9uPZe)</p>
<ul>
<li><strong>R:</strong> Done.</li>
</ul>
</li>
<li>
<p>Why is it hard to use a traditional computer program to recognize images in a photo?</p>
<ul>
<li><strong>R:</strong> Developing a traditional program means writing the
machine's instructions in full detail. In the words of Arthur
Samuel, <em>"Programming a computer for such computations is, at best, a
difficult task, ...because of the <strong>need to spell out every minute step
of the process in the most exasperating detail</strong>"</em>. In perception tasks,
such as recognizing objects in an image, we humans succeed easily,
but subconsciously. Therefore, abstracting that process and writing these instructions
takes a great deal of effort (<em>feature engineering</em>) and heuristics to solve
the problem. Moreover, those heuristics depend on the particular context (X-rays, digits)
and do not generalize.</li>
</ul>
</li>
<li>
<p>What did Samuel mean by "weight assignment"?</p>
<ul>
<li><strong>R:</strong> A weight assignment is a particular choice of values for the
model's parameters. Training a neural network is simply a process of
estimating, or fitting, those parameters.</li>
</ul>
</li>
<li>
<p>What term do we normally use in deep learning for what Samuel called
"weights"?</p>
<ul>
<li><strong>R:</strong> Today the most common term is <em>parameters</em> (e.g. the
name used in most current <em>frameworks</em>).</li>
</ul>
</li>
<li>
<p>Draw a picture that summarizes Samuel's view of a machine learning model.</p>
</li>
</ol>
<center>
<img src="/img/fastai-chapter-1/Samuels_Diagram.png">
</center>
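<p>Samuel's diagram can be read as a loop: the model combines <em>inputs</em> and <em>weights</em> to produce <em>results</em>, a measure of <em>performance</em> compares them with the labels, and that performance drives an update of the weights. Below is a minimal sketch of that loop in plain Python, assuming a one-variable linear model and plain gradient descent on the mean squared error as the update rule; Samuel did not prescribe either choice, and the toy data is hypothetical.</p>

```python
def model(x, weights):
    """Architecture (assumed): a linear function of the input with weights (w, b)."""
    w, b = weights
    return w * x + b


def performance(inputs, labels, weights):
    """Loss (assumed): mean squared error between results and labels."""
    return sum((model(x, weights) - y) ** 2 for x, y in zip(inputs, labels)) / len(inputs)


def update(inputs, labels, weights, lr):
    """Update rule (assumed): one gradient-descent step on the mean squared error."""
    w, b = weights
    n = len(inputs)
    grad_w = sum(2 * (model(x, weights) - y) * x for x, y in zip(inputs, labels)) / n
    grad_b = sum(2 * (model(x, weights) - y) for x, y in zip(inputs, labels)) / n
    return (w - lr * grad_w, b - lr * grad_b)


# Hypothetical noise-free data generated by y = 3x + 1, so the ideal weights are known.
inputs = [0.0, 1.0, 2.0, 3.0, 4.0]
labels = [3 * x + 1 for x in inputs]

weights = (0.0, 0.0)  # initial weight assignment
for _ in range(2000):
    weights = update(inputs, labels, weights, lr=0.02)

print(weights, performance(inputs, labels, weights))
```

<p>After training, the weights end up very close to <code>(3, 1)</code>: each pass through the loop is one round of "test the current weight assignment, then adjust it".</p>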
<ol start="14">
<li>
<p>Why is it hard to understand why a deep learning model makes a particular
prediction?</p>
<ul>
<li><strong>R:</strong> Every statistical model becomes harder to understand as its
complexity grows (e.g. more parameters and layers) and when the data it is
operating on differs substantially from the data it was trained on (i.e. <em>distribution shift</em>).
In particular, we saw that one advantage of neural networks is their modular
ability to grow by adding layers and different architectures, but this also
makes their predictions harder to interpret. We should always be cautious about
interpretability and about claims regarding a model's capabilities, and apply
several methods to inspect the internal workings of its parameters.</li>
</ul>
</li>
<li>
<p>What is the name of the theorem that shows that a neural network can solve
any mathematical problem to any level of accuracy?</p>
<ul>
<li><strong>R:</strong> It is called the <em>Universal Approximation Theorem</em>. The
following video by Michael Nielsen gives a visual explanation of the theorem:</li>
</ul>
<right>
<iframe width="560" height="315" src="https://www.youtube.com/embed/Ijqkc7OLenI" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe>
</right>
</li>
<li>
<p>What do you need in order to train a model?</p>
<ul>
<li><strong>R:</strong> From the diagram above we can infer that to train
a model we need data (blue ellipses), meaning the
inputs (e.g. images, text, tabular features) and good-quality labels.
Without the labels, the <em>feedback</em> mechanism, made up of
the loss function (purple ellipse) and the update rule, cannot
guide the fitting of the parameters (brown ellipse). We also need a functional form
for the model (aka the architecture) to make predictions (pink ellipses) from
the <em>inputs</em>, which the fitting mechanism then compares against the labels.
Once the model has been trained, we have a program that receives <em>inputs</em> and returns <em>outputs</em>,
which can be used as a component inside any <em>software</em>.</li>
</ul>
</li>
<li>
<p>How could a feedback loop impact the rollout of a predictive policing model?</p>
<ul>
<li><strong>R:</strong> The model is fitted on data. If its predictions lead
the police to focus on a certain geographic area, arrests and recorded
incidents in that area will likely increase simply because patrolling
is concentrated there. As a result, there will be more data about that area
when new information is added to the model. When the model is refitted,
the parameter updates will reinforce the association between crime and that area,
increasing the number of predictions there and appearing to justify the police
actions already taken. This is a <em>positive feedback loop</em>: the more we use the
model, the more bias we introduce into the data.</li>
</ul>
</li>
<li>
<p>Do we always have to use 224x224 pixel images with the cat recognition
model?</p>
<ul>
<li><strong>R:</strong> The 224x224 size is a historical choice tied to the design of
a particular architecture. We can increase the image resolution so the
model captures more information, at a higher computational cost. Conversely,
a lower resolution tends to reduce the model's performance but improves
computational efficiency. Another historical convention when training neural
networks is to choose <em>batch</em> sizes that are powers of 2;
see <a href="https://sebastianraschka.com/blog/2022/batch-size-2.html"><em>No, We Don't Have to Choose Batch Sizes As Powers of 2</em></a> (Sebastian Raschka). On
the influence of <em>random seeds</em> when training models, see <a href="https://arxiv.org/abs/2109.08203"><em>"Torch.manual_seed(3407) is all you need: On the influence of random seed in
deep learning architectures for computer vision"</em></a> (David Picard).</li>
</ul>
</li>
<li>
<p>What is the difference between classification and regression?</p>
<ul>
<li><strong>R:</strong> The difference between classification and regression problems
simply comes down to the type of response variable we are
modeling. If it is discrete (e.g. dog, cat, socioeconomic level),
it is a classification problem. If instead the response variable
is continuous (e.g. salary), it is a regression problem.</li>
</ul>
</li>
<li>
<p>What is a validation set? What is a test set? Why do we need them?</p>
<ul>
<li><strong>R:</strong> The validation set is used to compute metrics
while the model is being trained; remember that metrics are meant
for human consumption. The validation set also lets us try out
different model configurations specified by the hyperparameters.
The test set, in contrast, is a dataset held out
exclusively to report the final <em>performance</em> of our
model, once all ideas and experiment iterations have been tried. We need
it because, by choosing hyperparameters based on validation results, we can end
up overfitting the validation set itself; the test set provides an untouched final check.</li>
</ul>
</li>
<li>
<p>What will fastai do if you don't provide a validation set?</p>
<ul>
<li><strong>R:</strong> The <code>fastai</code> library automatically splits the dataset 80/20,
holding out a random 20% of the data for the validation set. To change
this percentage, set the <code>valid_pct</code> argument when creating the
<em>DataLoaders</em>.</li>
</ul>
</li>
<li>
<p>Can we always use a random sample for a validation set?</p>
<ul>
<li><strong>R:</strong> No, a random validation set is not always appropriate. What
matters most for both the validation set and the test
set is that they are representative of future data we have not seen yet,
and taking a dataset and drawing a random fraction of it
is not always the answer. Consider time series: it makes little
sense to draw a random sample of the dataset to
build the validation set, but it does make sense to hold out
the final part of the series to evaluate the model, simulating
future, never-seen data. Another example involves redundancy
among observations: if it is not properly isolated, the model may get
good results on the validation set only because it has memorized
particular characteristics of that group of observations
instead of finding a general pattern. For example, if the same pet
appears in several photos of the <em>dataset</em>, all the
examples of that pet should end up in the same set rather than
being split across both sets.</li>
</ul>
</li>
<li>
<p>What is overfitting? Provide an example.</p>
<ul>
<li><strong>R:</strong> Overfitting is the phenomenon in which a model
starts memorizing the noise, or "idiosyncratic" part, of the data
used for training, so that its fitted parameters reflect peculiarities of that
<em>dataset</em> instead of capturing patterns that
generalize. The goal is to train a model that performs well on unseen data, not
to memorize the training data perfectly. For example, if we use
a model to predict house prices and the model overfits
during training, its parameters will reflect the particular conditions
of the houses used to fit it rather than a
generalizable pattern about what drives housing prices
that would be useful for any other house not in
the <em>dataset</em>. The model will perform worse on houses located
outside the areas covered by the training set, or whose
characteristics differ from the ranges of values
seen in the training houses.</li>
</ul>
</li>
<li>
<p>What is a metric? How does it differ from loss?</p>
<ul>
<li><strong>R:</strong> A metric measures the model's performance with respect to
some objective, such as error rate, precision, bias in the predictions, or
a business-specific indicator (KPI). In other words, metrics
are for human consumption and are closely tied to the problem
we want to solve. The loss function, in contrast, is designed for the
process of fitting the model's parameters: it is part of the model's automatic
feedback mechanism. For example, knowing that the loss
dropped 20% over 100 iterations tells us nothing about whether
we are getting better at identifying fraudulent transactions in the
payment network, but the story is different if our <em>accuracy</em>
improved by 20%. What the loss function does provide is a feedback
signal on the model's predictions for a particular state
of the parameters (i.e. a set of values), so the parameters can be changed
in the direction that minimizes the loss. For this
reason, the loss function must satisfy certain properties, such as being differentiable and
efficient to compute. In short, the loss function is for the
computer and the metric is for the human.</li>
</ul>
</li>
<li>
<p>How can pretrained models help?</p>
<ul>
<li><strong>R:</strong> A pretrained model has already gone through a parameter-fitting
process, so it already has some capabilities that can
speed up learning on new data. In the best case, these capabilities
are characteristics, or <em>features</em>, that the model
has already learned and that generalize. For example, in an
image classification problem, having a
feature that already detects "corners" will always be useful, and
it can be reused. In a less favourable case, a
pretrained model can be seen as a warm start for the new
fitting process, providing a good parameter initialization
instead of a completely random one.</li>
</ul>
</li>
<li>
<p>What is the "head" of a model?</p>
<ul>
<li><strong>R:</strong> The head of the model is the final layer (or layers) added to a
pretrained architecture, specific to the dataset
we are working with. When we use a pretrained model, we must
adapt the final layer of the architecture to the dimensions
of the <em>output</em> of the problem to which we want to "transfer" the model.
For example, if the pretrained model was fitted on the ImageNet <em>dataset</em>,
which has 1000 categories, and our problem only needs to
distinguish between two, we must adapt the head of the model to an output
of length 2.</li>
</ul>
</li>
<li>
<p>What kind of features do the early layers of a CNN find? How about the
later layers?</p>
<ul>
<li><strong>R:</strong> The <em>features</em> in the first layers are more primitive, or
basic, such as textures, gradients, or corners. As we move deeper
through the layers, the <em>features</em> the network learns become
higher-level, such as geometric shapes, faces, and so on. One
explanation has to do with convolutional layers, which
are volumetric layers that apply several filters, or kernels, that specialize
on the same region of pixels (aka the receptive field), learning concepts
and exploiting the "locality" structure of the image: nearby pixels
are more related than distant ones. In addition, between
convolutional layers these architectures usually include a
<em>pooling</em> layer, which is essentially a <em>downsampling</em> technique that shrinks
images, for example from 28x28 to 14x14, by compacting pixels
through an aggregation operation. When another convolutional layer is
applied afterwards, its new kernels cover more of the image's
information, increasing their receptive field. In this way
the last layers start learning higher-level concepts by relating
different regions that the initial filters observed and building on
the more primitive concepts.</li>
</ul>
</li>
<li>
<p>Are image models useful only for photos?</p>
<ul>
<li><strong>R:</strong> No. Image models can be used for any problem that
can be reformulated as an image (e.g. sound converted to spectrograms). As a rule of thumb,
if a human can interpret a chart of some phenomenon that is not
inherently an image problem, an architecture designed
for image models will probably work well.</li>
</ul>
</li>
<li>
<p>What is an architecture?</p>
<ul>
<li><strong>R:</strong> Neural networks are functions. The architecture is the functional
form a neural network takes, made up of its different layers and
connections, whose values are given by the parameters. The image below
shows the functional form of the model that won the
ImageNet 2012 competition, AlexNet:</li>
</ul>
</li>
</ol>
<right>
<img src="https://www.researchgate.net/profile/Moumita-Tora/publication/318796117/figure/fig4/AS:631679571996680@1527615554120/AlexNet-illustration-The-input-is-a-224-by-224-image-that-goes-through-several-hidden.png" alt="AlexNet illustration">
</right>
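<p>Going back to the question of whether a random sample always works as a validation set, the two alternatives described there can be sketched with the standard library. The records below are hypothetical (a <code>timestamp</code> and a <code>pet_id</code> per photo); <code>time_split</code> keeps the most recent fraction of a time series for validation, and <code>group_split</code> assigns whole groups (here, every photo of the same pet) to one side of the split. Both helper names are illustrative, not part of fastai.</p>

```python
import random


def time_split(records, valid_frac=0.2):
    """Hold out the most recent records instead of a random sample."""
    ordered = sorted(records, key=lambda r: r["timestamp"])
    cut = len(ordered) - round(len(ordered) * valid_frac)
    return ordered[:cut], ordered[cut:]


def group_split(records, group_key, valid_frac=0.2, seed=42):
    """Put every record of a given group (e.g. the same pet) on one side only."""
    groups = sorted({r[group_key] for r in records})
    random.Random(seed).shuffle(groups)
    n_valid = max(1, round(len(groups) * valid_frac))
    valid_groups = set(groups[:n_valid])
    train = [r for r in records if r[group_key] not in valid_groups]
    valid = [r for r in records if r[group_key] in valid_groups]
    return train, valid


# Hypothetical dataset: 10 pets with 3 photos each, taken on consecutive days.
records = [
    {"timestamp": day, "pet_id": day // 3, "photo": f"photo_{day}.jpg"}
    for day in range(30)
]

train_t, valid_t = time_split(records)
train_g, valid_g = group_split(records, "pet_id")
print(max(r["timestamp"] for r in train_t), min(r["timestamp"] for r in valid_t))
print(sorted({r["pet_id"] for r in valid_g}))
```

<p>With the time split, every validation record is later than every training record; with the group split, no pet appears in both sets, so a good validation score cannot come from recognizing an animal already seen in training.</p>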
<ol start="30">
<li>
<p>What is segmentation?</p>
<ul>
<li><strong>R:</strong> It is a computer vision problem that
consists of identifying what each pixel in an image
belongs to (e.g. cars, traffic lights, pedestrians).</li>
</ul>
</li>
<li>
<p>What is <code>y_range</code> used for? When do we need it?</p>
<ul>
<li><strong>R:</strong> It specifies the range of the target variable in a regression
problem, that is, when the response variable is continuous, so that the model's
predictions are constrained to that range (e.g. movie ratings on a 1 to 5 scale).</li>
</ul>
</li>
<li>
<p>What are hyperparameters?</p>
<ul>
<li><strong>R:</strong> Hyperparameters are variables that control aspects of
the model's training process, for example the amount of regularization
or the size of the learning rate. They are parameters about parameters,
since they affect, one way or another, how the parameters are fitted.</li>
</ul>
</li>
<li>
<p>What's the best way to avoid failures when using AI in an organization?</p>
<ul>
<li><strong>R:</strong> Always design and build a good validation set
to properly evaluate how well models generalize. If you work
with third parties who are asked to solve a problem using
models fitted on data, always keep a test set
that the vendors have not seen. That way we can properly
diagnose the model's performance. Another important point
is to build a good baseline model in parallel, so we know in advance
the potential improvement from using more complex models to
solve the problem.</li>
</ul>
</li>
</ol>
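<p>To see how a range constraint like <code>y_range</code> can work, here is a minimal sketch of a scaled sigmoid in plain Python: the model's raw output (any real number) is squashed into the interval <code>(low, high)</code>. fastai uses a function of this shape (<code>sigmoid_range</code>) for <code>y_range</code>, but the code below is a standalone illustration, not the library's source; the 0.5 to 5.5 range is only an example for ratings.</p>

```python
import math


def sigmoid_range(x, low, high):
    """Map any real number x into the open interval (low, high) with a scaled sigmoid."""
    return 1.0 / (1.0 + math.exp(-x)) * (high - low) + low


low, high = 0.5, 5.5  # example range, e.g. for movie ratings
for raw in (-10.0, 0.0, 10.0):
    print(raw, round(sigmoid_range(raw, low, high), 3))
```

<p>A raw output of 0 lands on the midpoint of the range, 3.0, while very large or very small outputs approach the bounds, so the model never has to learn to stay inside the valid range on its own.</p>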
<h2>Further Research Questions</h2>
<p><em>Last updated: 09/07/2022</em></p>
<p><strong>Why is a GPU useful for Deep Learning? How is a CPU different, and why is it less effective for deep learning?</strong></p>
<p>Stuart Oberman, vice president at NVIDIA, gave a talk at
Stanford in 2017 that offers a good <em>overview</em> of GPUs: <a href="https://www.youtube.com/watch?v=98Xis1W1mMk">Nvidia GPU Computing: A Journey From PC Gaming to Deep Learning</a> (slides of the
<a href="http://web.stanford.edu/class/ee380/Abstracts/171004-slides.pdf">talk</a>).</p>
<center><iframe width="560" height="315" src="https://www.youtube.com/embed/98Xis1W1mMk" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture" allowfullscreen></iframe></center>
<p>GPU computing falls into two broad groups:</p>
<ol>
<li>Simulation: drug design, options pricing, weather forecasting</li>
<li>Visualization: seismic imaging, automotive design, medical imaging</li>
</ol>
<p>NVIDIA introduced the GPU in 1999, a single processor to accelerate
video games and 3D graphics.</p>
<p>Goal: approach the image quality of offline-rendered studio films,
but in real time. That means millions of pixels per <em>frame</em> at &gt; 60 <em>frames</em> per second,
using large <em>arrays</em> of floating-point units to exploit both wide and deep parallelism.</p>
<p>The G80 was the first GPU to introduce a unified shader processor. All
shader stages use the same instruction set and run on the same unit, the <em>streaming
multiprocessor</em> (CUDA).</p>
<blockquote>
<p>CUDA: C++ for throughput computers, on-chip memory management, asynchronous, parallel
API; programmability makes it possible to innovate.</p>
</blockquote>
<p>Slide 22 makes an interesting comparison between the paradigms
that guide a GPU and a CPU:</p>
<p><strong>Latency oriented (CPU):</strong></p>
<ul>
<li>Fewer, bigger cores with out-of-order, speculative execution</li>
<li>Big caches optimized for latency</li>
<li>Math units are small part of the die</li>
</ul>
<p><strong>Throughput oriented (GPU):</strong></p>
<ul>
<li>Lots of simple compute cores and hardware scheduling</li>
<li>Big register files. Caches optimized for bandwidth.</li>
<li>Math units are most of the die</li>
</ul>
<p>Definitions of the concepts above, from Martin Kleppmann's book <em>Designing Data-Intensive
Applications</em>:</p>
<p><strong>Throughput</strong></p>
<blockquote>
<p><em>Throughput is the number of records we can process per second, or the total time
it takes to run a job on a dataset of a certain size</em></p>
</blockquote>
<p><strong>Latency</strong></p>
<blockquote>
<p><em>Latency is the duration that a request is waiting to be handled -- during which
it is latent, awaiting service</em></p>
</blockquote>
<p><strong>Response time</strong></p>
<blockquote>
<p><em>Response time is what the client sees: besides the actual time to process the
request (the service time), it includes network delays and queueing delays</em></p>
</blockquote>
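<p>The three definitions can be tied together with a toy single-server queue. The sketch below assumes hypothetical arrival times, service times and a fixed network delay, all in milliseconds: each request waits while the server is busy (its latency, in Kleppmann's sense), its response time adds the service time and the network delay, and throughput is the number of completed requests per second. The function name <code>simulate_fifo</code> is illustrative.</p>

```python
def simulate_fifo(arrivals_ms, service_ms, network_ms):
    """Single first-in-first-out server; returns per-request metrics and throughput."""
    server_free_at = 0.0
    latencies, response_times = [], []
    for arrival, service in zip(arrivals_ms, service_ms):
        start = max(arrival, server_free_at)  # wait if the server is still busy
        latency = start - arrival             # time spent latent, awaiting service
        server_free_at = start + service
        latencies.append(latency)
        response_times.append(latency + service + network_ms)
    elapsed_s = (server_free_at - arrivals_ms[0]) / 1000.0
    throughput = len(arrivals_ms) / elapsed_s  # completed requests per second
    return latencies, response_times, throughput


# Hypothetical workload: a request every 10 ms, each needing 15 ms of service.
arrivals = [0.0, 10.0, 20.0, 30.0]
services = [15.0, 15.0, 15.0, 15.0]
lat, resp, tput = simulate_fifo(arrivals, services, network_ms=5.0)
print(lat, resp, round(tput, 1))
```

<p>Here requests arrive every 10 ms but need 15 ms each, so the queue keeps growing: latency rises with every request (0, 5, 10, 15 ms) while throughput stays capped at the service rate, 1000 / 15 ≈ 66.7 requests per second.</p>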
<p>Pascal GP100: NVIDIA's first GPU adapted for Deep Learning, with 21 TFLOPS of FP16
for Deep Learning training and inference acceleration. It was the first time the
FP16 datatype was added with the goal of accelerating the training and inference
of Deep Learning models.</p>
<p>Tensor Core: a mixed-precision matrix unit. It multiplies A and B in FP16 and accumulates the result in FP32 (or FP16).</p>
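<p>Why accumulate in FP32? A small experiment makes the point. The sketch below uses only the standard library: the <code>struct</code> module's <code>'e'</code> and <code>'f'</code> formats round a Python float to IEEE 754 half and single precision. It computes the same dot product of FP16 inputs twice, once accumulating in FP16 and once in FP32. This only emulates the numerics of mixed precision; it is not how a Tensor Core is programmed, and the input vectors are arbitrary.</p>

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


def to_fp32(x):
    """Round a Python float to the nearest IEEE 754 single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def dot(a, b, accumulate):
    """Dot product of FP16 inputs; products and the running sum are rounded by `accumulate`."""
    total = 0.0
    for x, y in zip(a, b):
        product = accumulate(to_fp16(x) * to_fp16(y))
        total = accumulate(total + product)
    return total


n = 4096
a = [0.1] * n  # arbitrary example inputs
b = [1.0] * n
exact = n * to_fp16(0.1)  # exact sum of the FP16-rounded inputs

fp16_sum = dot(a, b, to_fp16)
fp32_sum = dot(a, b, to_fp32)
print(exact, fp16_sum, fp32_sum)
```

<p>With these inputs the FP16 accumulator stalls at 256: above 256, consecutive half-precision numbers are 0.25 apart, so adding roughly 0.1 rounds back to the same value. The FP32 accumulator reproduces the exact sum, 409.5, even though every input was FP16.</p>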
<br>
<br>
<p><strong>Try to think of three areas where feedback loops might impact the use of machine learning. See if you can find documented examples of that happening in practice.</strong></p>
<p>TODO...</p>
]]></content>
  </entry>
  <entry>
    <title>Get N colours from a continuous colourmap in matplotlib 🎨</title>
    <link href="https://alkzar.cl/posts/get-n-colours-from-a-continuous-colourmap-in-matplotlib/"/>
    <id>https://alkzar.cl/posts/get-n-colours-from-a-continuous-colourmap-in-matplotlib/</id>
    <published>2022-05-25T00:00:00Z</published>
    <updated>2022-05-25T00:00:00Z</updated>
    <content type="html"><![CDATA[<p>During this week, I researched how to discretize a continuous colour palette with matplotlib. I had this problem with doing a data visualization in which I needed a sizeable discrete colour palette (i.e. 20). So, the solution is <a href="https://stackoverflow.com/a/14779462/5843243" target="_blank">easily findable on the web</a> if you've dealt with colourmaps in the past. However, I want to wrap up the why of the problem and how to solve it.</p>
<center>
<img src="/img/n-colors-from-cmap-matplotlib-post/lr_convergence.png">
</center>
<br>
<p>The above visualization shows 20 different lines that exhibit a clear
convergence-rate pattern relating the hyperparameter <span class="math inline"><math display="inline"><mi>η</mi></math></span> to some
performance measure such as the log-likelihood.
How can we map each <span class="math inline"><math display="inline"><mi>η</mi></math></span> value to its line? The natural option
is to use colours, but you immediately notice a problem with such a large number
of distinct values for <span class="math inline"><math display="inline"><mi>η</mi></math></span>. OK, it's not really a problem: you can just
design a <code>custom_palette</code> (e.g. from the <a href="https://github.com/BlakeRMills/MetBrewer">MetBrewer repo</a>)
and write code as follows:</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta statement import python"><span class="keyword control import from python">from</span></span><span class="meta statement import python"><span class="meta import-source python"> <span class="meta import-path python"><span class="meta import-name python">cycler</span></span> <span class="meta statement import python"><span class="keyword control import python">import</span></span></span></span><span class="meta statement import python"></span><span class="meta statement import python"> <span class="meta generic-name python">cycler</span></span>
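<span class="meta statement import python"><span class="keyword control import python">import</span> <span class="meta qualified-name python"><span class="meta generic-name python">matplotlib</span></span> <span class="keyword control import as python">as</span> <span class="meta qualified-name python"><span class="meta generic-name python">mpl</span></span></span>
<span class="meta statement import python"><span class="keyword control import python">import</span> <span class="meta qualified-name python"><span class="meta generic-name python">matplotlib</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">pyplot</span></span> <span class="keyword control import as python">as</span> <span class="meta qualified-name python"><span class="meta generic-name python">plt</span></span></span>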

<span class="meta qualified-name python"><span class="meta generic-name python">custom_palette</span></span> <span class="keyword operator assignment python">=</span> <span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">#hexcode_1<span class="punctuation definition string end python">&#39;</span></span></span><span class="punctuation separator list python">,</span> <span class="constant language python">...</span><span class="punctuation separator list python">,</span> <span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">#hexcode_20<span class="punctuation definition string end python">&#39;</span></span></span><span class="punctuation section list end python">]</span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">custom_cycler</span></span> <span class="keyword operator assignment python">=</span> <span class="meta group python"><span class="punctuation section group begin python">(</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">cycler</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="variable parameter python">color</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">custom_palette</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation section group end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">plt</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">rc</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">axes<span class="punctuation definition string end python">&#39;</span></span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">prop_cycle</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">custom_cycler</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre>
<p>If you haven't come across <a href="https://matplotlib.org/cycler/">Cycler</a> before, it is simply a
convenient way that <code>matplotlib</code> provides to cycle through style options
such as colours, line styles, and others.</p>
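<p>To see what a cycler actually iterates over, here is a minimal sketch. The three colours and line styles are arbitrary placeholders, not the palette of this post. Adding two cyclers of the same length pairs their entries step by step, and each step is a plain dictionary of style keyword arguments:</p>
<pre><code class="code lang-python">from cycler import cycler

# Arbitrary placeholder styles; adding two equal-length cyclers zips them together
style_cycle = cycler(color=['#1f77b4', '#ff7f0e', '#2ca02c']) + cycler(linestyle=['-', '--', ':'])

# Each step is a dict of style options that a new line would receive
for style in style_cycle:
    print(style)
</code></pre>
<p>Passing a cycler like this one to <code>plt.rc('axes', prop_cycle=...)</code> makes every new line take the next step, so colour and line style change together.</p>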
<p>But wait, there is a constraint here: we want the colours to exhibit a pattern
(<em>as we increase the magnitude of <span class="math inline"><math display="inline"><mi>η</mi></math></span>, the log-likelihood
convergence rate increases as well</em>), so we need a notion of a gradient. Look
at the <span class="math inline"><math display="inline"><mi>η</mi></math></span> values: they increase in regular steps of 0.005. That
makes picking a colour sequence by hand annoying, because consecutive colours would
have to change just as regularly as the values they represent.</p>
<p>A <a href="https://matplotlib.org/3.5.0/tutorials/colors/colormaps.html">continuous colormap</a> (<code>CMAP</code>) solves this issue; we just need a way
to discretize it and pick <code>N=20</code> colours.</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta qualified-name python"><span class="meta generic-name python">cmap</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">plt</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">get_cmap</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="variable other constant python">CMAP</span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">N</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">custom_palette</span></span> <span class="keyword operator assignment python">=</span> <span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">mpl</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">colors</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">rgb2hex</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">cmap</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">i</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span> <span class="meta expression generator python"><span class="keyword control flow for generator python">for</span> <span class="meta generic-name python">i</span> <span class="keyword control flow for in python">in</span></span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">range</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">cmap</span><span class="punctuation accessor dot python">.</span><span class="variable other constant python">N</span></span></span><span 
class="punctuation section arguments end python">)</span></span><span class="punctuation section list end python">]</span></span>
</span></code></pre>
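<p>An alternative sketch, not the approach used above: instead of resampling the colormap's lookup table to <code>N</code> entries, evaluate the continuous colormap at <code>N</code> evenly spaced points in [0, 1]. Here <code>'RdBu'</code> and <code>N = 20</code> mirror the values in this post, and the headless <code>Agg</code> backend is only there so the snippet runs without a display:</p>
<pre><code class="code lang-python">import numpy as np
import matplotlib
matplotlib.use('Agg')  # headless backend; not needed inside a notebook
import matplotlib.pyplot as plt
from matplotlib.colors import rgb2hex

CMAP, N = 'RdBu', 20

# Calling a Colormap on an array of floats in [0, 1] returns one RGBA row per value
rgba = plt.get_cmap(CMAP)(np.linspace(0, 1, N))
custom_palette = [rgb2hex(c) for c in rgba]  # '#rrggbb' strings, alpha dropped
</code></pre>
<p>Either way you end up with 20 hex strings running from one end of the colormap to the other, ready to be fed to the cycler.</p>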
<p>Can we do better? Yes! We can decouple precision from the colour assigned
to each of the 20 labels: highlight particular lines (e.g. the best and worst)
with annotations, and use a colourbar to communicate how <span class="math inline"><math display="inline"><mi>η</mi></math></span> affects the
convergence rate as its value increases. The result is much cleaner
than the visualization above with two columns of labels.</p>
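<p>As a minimal sketch of the annotation idea, the snippet below greys out every curve and labels only the best and worst ones with <code>ax.annotate</code>. The curves are synthetic stand-ins (<code>-exp(-10·η·step)</code>), not the post's log-likelihood results, and "best" here simply means the highest final value:</p>
<pre><code class="code lang-python">import numpy as np
import matplotlib
matplotlib.use('Agg')  # headless backend; not needed inside a notebook
import matplotlib.pyplot as plt

# Hypothetical data: 20 eta values in steps of 0.005 and one synthetic curve each
learning_rates = 0.005 * np.arange(1, 21)
steps = np.arange(50)
curves = {eta: -np.exp(-10 * eta * steps) for eta in learning_rates}

fig, ax = plt.subplots()
for curve in curves.values():
    ax.plot(steps, curve, color='lightgrey', linewidth=1)

# Pick the curves to highlight by their final value
final = {eta: curve[-1] for eta, curve in curves.items()}
best_eta = max(final, key=final.get)
worst_eta = min(final, key=final.get)

for eta, label in [(best_eta, 'best'), (worst_eta, 'worst')]:
    ax.plot(steps, curves[eta], linewidth=2)  # redraw the highlighted curve in colour
    ax.annotate(f'{label}: η = {eta:.3f}',
                xy=(steps[-1], final[eta]),  # point being labelled
                xytext=(-90, 15), textcoords='offset points',
                arrowprops=dict(arrowstyle='->'))
</code></pre>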
<p>The <a href="https://stackoverflow.com/a/70192912/5843243" target="_blank">trick</a>
is done by <code>ScalarMappable</code>, which requires two elements:</p>
<ol>
<li>A <code>cmap</code>, which we can recover from our custom palette with <code>ListedColormap(custom_palette)</code>,
or by using the original colormap directly ('RdBu' in this case)</li>
<li>A <code>norm</code> that sets the boundaries of each discrete colour in the colourbar, built with <code>BoundaryNorm</code>,
which takes two inputs: the list of values (i.e. the different <span class="math inline"><math display="inline"><mi>η</mi></math></span> values) and its length (<code>ncolors</code>)</li>
</ol>
<pre><code class="code lang-python"><span class="source python"><span class="meta statement import python"><span class="keyword control import from python">from</span></span><span class="meta statement import python"><span class="meta import-source python"> <span class="meta import-path python"><span class="meta import-name python">matplotlib</span><span class="punctuation accessor dot python">.</span><span class="meta import-name python">colors</span></span> <span class="meta statement import python"><span class="keyword control import python">import</span></span></span></span><span class="meta statement import python"></span><span class="meta statement import python"> <span class="meta generic-name python">ListedColormap</span><span class="punctuation separator import-list python">,</span> <span class="meta generic-name python">BoundaryNorm</span></span>
<span class="meta statement import python"><span class="keyword control import from python">from</span></span><span class="meta statement import python"><span class="meta import-source python"> <span class="meta import-path python"><span class="meta import-name python">matplotlib</span><span class="punctuation accessor dot python">.</span><span class="meta import-name python">cm</span></span> <span class="meta statement import python"><span class="keyword control import python">import</span></span></span></span><span class="meta statement import python"></span><span class="meta statement import python"> <span class="meta generic-name python">ScalarMappable</span></span>

<span class="meta qualified-name python"><span class="meta generic-name python">cbar</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">plt</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">colorbar</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">ScalarMappable</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="variable parameter python">norm</span><span class="keyword operator assignment python">=</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">BoundaryNorm</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">learning_rates</span></span><span class="punctuation separator arguments python">,</span> 
                                                     <span class="variable parameter python">ncolors</span><span class="keyword operator assignment python">=</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">len</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">learning_rates</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator arguments python">,</span> 
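                                   <span class="variable parameter python">cmap</span><span class="keyword operator assignment python">=</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">ListedColormap</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">custom_palette</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator arguments python">,</span>
                    <span class="variable parameter python">ax</span><span class="keyword operator assignment python">=</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">plt</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">gca</span></span><span class="punctuation section arguments begin python">(</span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation terminator statement python">;</span>  <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> recent matplotlib needs the Axes to take space from</span>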
<span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">cbar</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">ax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">tick_params</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="variable parameter python">labelsize</span><span class="keyword operator assignment python">=</span><span class="constant numeric integer decimal python">7</span></span><span class="punctuation section arguments end python">)</span></span> 
</span></code></pre><center>
<img src="/img/n-colors-from-cmap-matplotlib-post/lr_convergence2.png">
<br>
<img src="/img/n-colors-from-cmap-matplotlib-post/dual_lr_convergence_curves.png">
</center>
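<p>For reference, here is a self-contained sketch that puts the pieces together. The η values (20 steps of 0.005 starting at 0.005) and the straight placeholder lines are assumptions for illustration only. It passes <code>ax=ax</code> to <code>fig.colorbar</code> because the <code>ScalarMappable</code> is not attached to any Axes, so matplotlib has to be told where to take space for the colourbar:</p>
<pre><code class="code lang-python">import numpy as np
import matplotlib
matplotlib.use('Agg')  # headless backend; not needed inside a notebook
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm, rgb2hex
from matplotlib.cm import ScalarMappable

# Hypothetical eta values: 20 values in regular steps of 0.005
learning_rates = 0.005 * np.arange(1, 21)

# Discretize the continuous colormap into one colour per eta value
cmap = plt.get_cmap('RdBu', len(learning_rates))
custom_palette = [rgb2hex(cmap(i)) for i in range(cmap.N)]

fig, ax = plt.subplots()
for eta, colour in zip(learning_rates, custom_palette):
    ax.plot([0, 1], [0, eta], color=colour)  # placeholder line per eta

# The eta values act as bin edges; ncolors is how many colours the norm can index
norm = BoundaryNorm(learning_rates, ncolors=len(learning_rates))
mappable = ScalarMappable(norm=norm, cmap=ListedColormap(custom_palette))
cbar = fig.colorbar(mappable, ax=ax)
cbar.ax.tick_params(labelsize=7)
</code></pre>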
<br>
<br>
<br>
<p>                         <em>That's the way computer talks to each other.</em></p>
]]></content>
  </entry>
  <entry>
    <title>How Gauss would compute a Confusion matrix for their classification model</title>
    <link href="https://alkzar.cl/posts/this-is-how-gauss-would-compute-a-confusion-matrix/"/>
    <id>https://alkzar.cl/posts/this-is-how-gauss-would-compute-a-confusion-matrix/</id>
    <published>2022-05-19T00:00:00Z</published>
    <updated>2022-05-19T00:00:00Z</updated>
    <content type="html"><![CDATA[<figure>
  <center>
    <img src="https://imgs.xkcd.com/comics/what_to_bring.png" alt="xkcd confusion matrix comic">
    <figcaption><a href="https://xkcd.com/1890/" target="_blank">Source: xkcd.com</a></figcaption>
  </center>
</figure>
<br>
<p>A <a href="https://en.wikipedia.org/wiki/Confusion_matrix" target="_blank">confusion matrix</a>
is a practical and conceptually simple tool to evaluate a classification model.
So it deserves an equally simple way to compute it: the way Gauss would have done it,
without the magic of <del><code>from sklearn.metrics import confusion_matrix</code></del>,
using plain linear algebra operations:</p>
<blockquote>
<p><em>A confusion matrix is the matrix product of the true and predicted labels, both encoded as one-hot vectors.</em></p>
</blockquote>
<p>Suppose we have the true labels of 4 observations in the vector <span class="math inline"><math display="inline"><mrow><mi>𝐲</mi></mrow><mo>=</mo><mo symmetric="false" stretchy="false">[</mo><mn>1</mn><mo>,</mo><mn>0</mn><mo>,</mo><mn>2</mn><mo>,</mo><mn>1</mn><mo symmetric="false" stretchy="false">]</mo></math></span>, with 3 different classes (i.e. 0, 1 and 2). Their one-hot encoding is:</p>
<p><div class="math display"><math display="block"><mrow><mi>𝐓</mi></mrow><mo>=</mo><mrow><mo stretchy="true">[</mo><mtable class="menv-arraylike"><mtr><mtd><mn>0</mn></mtd><mtd><mn>1</mn></mtd><mtd><mn>0</mn></mtd></mtr><mtr><mtd><mn>1</mn></mtd><mtd><mn>0</mn></mtd><mtd><mn>0</mn></mtd></mtr><mtr><mtd><mn>0</mn></mtd><mtd><mn>0</mn></mtd><mtd><mn>1</mn></mtd></mtr><mtr><mtd><mn>0</mn></mtd><mtd><mn>1</mn></mtd><mtd><mn>0</mn></mtd></mtr></mtable><mo stretchy="true">]</mo></mrow><mtext>&nbsp;</mtext><mo>∈</mo><mtext>&nbsp;</mtext><mo symmetric="false" stretchy="false">[</mo><mn>0,1</mn><msup><mo symmetric="false" stretchy="false">]</mo><mrow><mn>4</mn><mtext>&nbsp;</mtext><mo>×</mo><mtext>&nbsp;</mtext><mn>3</mn></mrow></msup></math></div></p>
<p>A classification model gives us the predicted label for each observation in the vector <span class="math inline"><math display="inline"><mover><mrow><mrow><mi>𝐲</mi></mrow></mrow><mi>^</mi></mover><mo>=</mo><mo symmetric="false" stretchy="false">[</mo><mn>2</mn><mo>,</mo><mn>0</mn><mo>,</mo><mn>2</mn><mo>,</mo><mn>0</mn><mo symmetric="false" stretchy="false">]</mo></math></span>;
by the same logic, its one-hot encoding is:</p>
<p><div class="math display"><math display="block"><mover><mrow><mrow><mi>𝐓</mi></mrow></mrow><mi>^</mi></mover><mo>=</mo><mrow><mo stretchy="true">[</mo><mtable class="menv-arraylike"><mtr><mtd><mn>0</mn></mtd><mtd><mn>0</mn></mtd><mtd><mn>1</mn></mtd></mtr><mtr><mtd><mn>1</mn></mtd><mtd><mn>0</mn></mtd><mtd><mn>0</mn></mtd></mtr><mtr><mtd><mn>0</mn></mtd><mtd><mn>0</mn></mtd><mtd><mn>1</mn></mtd></mtr><mtr><mtd><mn>1</mn></mtd><mtd><mn>0</mn></mtd><mtd><mn>0</mn></mtd></mtr></mtable><mo stretchy="true">]</mo></mrow><mtext>&nbsp;</mtext><mo>∈</mo><mtext>&nbsp;</mtext><mo symmetric="false" stretchy="false">[</mo><mn>0,1</mn><msup><mo symmetric="false" stretchy="false">]</mo><mrow><mn>4</mn><mtext>&nbsp;</mtext><mo>×</mo><mtext>&nbsp;</mtext><mn>3</mn></mrow></msup></math></div>
We now have everything we need to compute the confusion matrix, which is
<span class="math inline"><math display="inline"><msup><mrow><mi>𝐓</mi></mrow><mrow><mi>⊤</mi></mrow></msup><mover><mrow><mrow><mi>𝐓</mi></mrow></mrow><mi>^</mi></mover><mtext>&nbsp;</mtext><mo>∈</mo><mtext>&nbsp;</mtext><msubsup><mrow><mi>𝐙</mi></mrow><mrow><mn>0</mn><mi>+</mi></mrow><mrow><mn>3</mn><mo>×</mo><mn>3</mn></mrow></msubsup></math></span>. So again,</p>
<blockquote>
<p><em>A confusion matrix is the matrix product of the true and predicted labels, both encoded as one-hot vectors.</em></p>
</blockquote>
<p><div class="math display"><math display="block"><msup><mrow><mi>𝐓</mi></mrow><mrow><mi>⊤</mi></mrow></msup><mover><mrow><mrow><mi>𝐓</mi></mrow></mrow><mi>^</mi></mover><mo>=</mo><mrow><mo stretchy="true">[</mo><mtable class="menv-arraylike"><mtr><mtd><mn>1</mn></mtd><mtd><mn>0</mn></mtd><mtd><mn>0</mn></mtd></mtr><mtr><mtd><mn>1</mn></mtd><mtd><mn>0</mn></mtd><mtd><mn>1</mn></mtd></mtr><mtr><mtd><mn>0</mn></mtd><mtd><mn>0</mn></mtd><mtd><mn>1</mn></mtd></mtr></mtable><mo stretchy="true">]</mo></mrow><mtext>&nbsp;</mtext><mo>∈</mo><mtext>&nbsp;</mtext><msubsup><mrow><mi>𝐙</mi></mrow><mrow><mn>0</mn><mi>+</mi></mrow><mrow><mn>3</mn><mtext>&nbsp;</mtext><mo>×</mo><mtext>&nbsp;</mtext><mn>3</mn></mrow></msubsup></math></div></p>
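<p>Why does this product count confusions? Row <span class="math inline"><math display="inline"><mi>n</mi></math></span> of <span class="math inline"><math display="inline"><mrow><mi>𝐓</mi></mrow></math></span> has a single 1, in the column of the true class of observation <span class="math inline"><math display="inline"><mi>n</mi></math></span>, and the same row of <span class="math inline"><math display="inline"><mover><mrow><mrow><mi>𝐓</mi></mrow></mrow><mi>^</mi></mover></math></span> has a single 1 in the column of its predicted class. Each observation therefore adds exactly one unit to the entry (true class, predicted class):</p>
<p><div class="math display"><math display="block"><msub><mrow><mo symmetric="false" stretchy="false">(</mo><msup><mrow><mi>𝐓</mi></mrow><mrow><mi>⊤</mi></mrow></msup><mover><mrow><mrow><mi>𝐓</mi></mrow></mrow><mi>^</mi></mover><mo symmetric="false" stretchy="false">)</mo></mrow><mrow><mi>i</mi><mi>j</mi></mrow></msub><mo>=</mo><munderover><mo>∑</mo><mrow><mi>n</mi><mo>=</mo><mn>1</mn></mrow><mi>N</mi></munderover><msub><mi>T</mi><mrow><mi>n</mi><mi>i</mi></mrow></msub><msub><mover><mi>T</mi><mi>^</mi></mover><mrow><mi>n</mi><mi>j</mi></mrow></msub><mo>=</mo><mo>#</mo><mo symmetric="false" stretchy="false">{</mo><mi>n</mi><mo>:</mo><msub><mi>y</mi><mi>n</mi></msub><mo>=</mo><mi>i</mi><mo>,</mo><mtext>&nbsp;</mtext><msub><mover><mi>y</mi><mi>^</mi></mover><mi>n</mi></msub><mo>=</mo><mi>j</mi><mo symmetric="false" stretchy="false">}</mo></math></div></p>
<p>Here <span class="math inline"><math display="inline"><mi>N</mi></math></span> is the number of observations (4 in this example), so the rows of the confusion matrix index the true class and the columns the predicted class.</p>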
<p>As you can see, the confusion matrix correctly summarizes the information in
both vectors:</p>
<p><div class="math display"><math display="block"><mtable><mtr><mtd><mrow><mi>𝐲</mi></mrow><mo>=</mo><mo symmetric="false" stretchy="false">[</mo><mn>1</mn><mo>,</mo><mn>0</mn><mo>,</mo><mn>2</mn><mo>,</mo><mn>1</mn><mo symmetric="false" stretchy="false">]</mo></mtd></mtr><mtr><mtd><mover><mrow><mrow><mi>𝐲</mi></mrow></mrow><mi>^</mi></mover><mo>=</mo><mo symmetric="false" stretchy="false">[</mo><mn>2</mn><mo>,</mo><mn>0</mn><mo>,</mo><mn>2</mn><mo>,</mo><mn>0</mn><mo symmetric="false" stretchy="false">]</mo></mtd></mtr></mtable></math></div></p>
<ul>
<li>The sum of the diagonal elements tells us that two observations were correctly
classified by the model</li>
<li>The model correctly classified the one observation of class <code>0</code> (and the one of class <code>2</code>)</li>
<li>The model never predicted class <code>1</code> (the second column is all zeros), producing two errors</li>
<li>The model confuses the two class-<code>1</code> observations, one with a <code>2</code> and the
other with a <code>0</code>; look at the second row</li>
</ul>
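<p>These readings can also be checked numerically. A minimal sketch, with the matrix copied from the product above: the trace counts the correct predictions, and dividing the diagonal by the row sums gives, for each true class, the fraction that was classified correctly (per-class recall):</p>
<pre><code class="code lang-python">import numpy as np

# Confusion matrix from above: rows = true class, columns = predicted class
C = np.array([[1, 0, 0],
              [1, 0, 1],
              [0, 0, 1]])

correct = np.trace(C)                # diagonal sum: correctly classified observations
accuracy = correct / C.sum()         # fraction of all observations
recall = np.diag(C) / C.sum(axis=1)  # per true class; a class absent from y would divide by zero
print(correct, accuracy, recall)     # 2 0.5 [1. 0. 1.]
</code></pre>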
<h3>Now with <code>import numpy as np</code></h3>
<p>We need two steps to compute our confusion matrix.</p>
<p>First, we need a way to transform a label vector <span class="math inline"><math display="inline"><mrow><mi>𝐯</mi></mrow></math></span> with K classes into its one-hot encoding,
<code>v_one_hot = one_hot_encoding(v)</code>:</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta statement import python"><span class="keyword control import python">import</span> <span class="meta qualified-name python"><span class="meta generic-name python">numpy</span></span> <span class="keyword control import as python">as</span> <span class="meta qualified-name python"><span class="meta generic-name python">np</span></span></span>

<span class="meta function python"><span class="storage type function python">def</span> <span class="entity name function python"><span class="meta generic-name python">one_hot_encoding</span></span></span><span class="meta function parameters python"><span class="punctuation section parameters begin python">(</span></span><span class="meta function parameters python"><span class="variable parameter python">v</span><span class="punctuation separator parameters python">,</span> <span class="variable parameter python">num_classes</span><span class="keyword operator assignment python">=</span><span class="constant language python">None</span><span class="punctuation section parameters end python">)</span></span><span class="meta function python"><span class="punctuation section function begin python">:</span></span>
  <span class="comment block documentation python"><span class="punctuation definition comment begin python">&#39;&#39;&#39;</span>Return the one-hot encoding matrix of a label vector with classes 0, ..., K-1<span class="punctuation definition comment end python">&#39;&#39;&#39;</span></span>
  <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> Infer K from the largest label: np.unique(v).size undercounts when a class is missing
</span>  <span class="keyword control conditional if python">if</span> <span class="meta qualified-name python"><span class="meta generic-name python">num_classes</span></span> <span class="keyword operator logical python">is</span> <span class="constant language python">None</span><span class="punctuation section block conditional python">:</span>
    <span class="meta qualified-name python"><span class="meta generic-name python">num_classes</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">np</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">max</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">v</span></span></span><span class="punctuation section arguments end python">)</span></span> <span class="keyword operator arithmetic python">+</span> <span class="constant numeric integer decimal python">1</span>
  <span class="keyword control flow return python">return</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">np</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">eye</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">num_classes</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">dtype</span><span class="keyword operator assignment python">=</span><span class="support type python">int</span></span><span class="punctuation section arguments end python">)</span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">v</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span>
</span></code></pre>
<p>Second, compute the confusion matrix <span class="math inline"><math display="inline"><msup><mrow><mi>𝐓</mi></mrow><mrow><mi>⊤</mi></mrow></msup><mover><mrow><mrow><mi>𝐓</mi></mrow></mrow><mi>^</mi></mover><mtext>&nbsp;</mtext><mo>∈</mo><mtext>&nbsp;</mtext><msubsup><mrow><mi>𝐙</mi></mrow><mrow><mn>0</mn><mi>+</mi></mrow><mrow><mi>K</mi><mo>×</mo><mi>K</mi></mrow></msubsup></math></span> for K classes. There are several equivalent ways of doing it with <code>numpy</code>, as
you can see in the following code. I use the conventional names for the
true labels (<code>y</code>) and the predicted ones (<code>y_pred</code>):</p>
<pre><code class="code lang-python"><span class="source python"><span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> 1st option: Using the matrix multiplication &#39;@&#39; operator
</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">one_hot_encoding</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">y</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">T</span> <span class="keyword operator matrix python">@</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">one_hot_encoding</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">y_pred</span></span></span><span class="punctuation section arguments end python">)</span></span>

<span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> 2nd option: Using np.dot()
</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">np</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">dot</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">one_hot_encoding</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">y</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">T</span><span class="punctuation separator arguments python">,</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">one_hot_encoding</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">y_pred</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>

<span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> 3rd option: Using np.matmul()
</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">np</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">matmul</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">one_hot_encoding</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">y</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">T</span><span class="punctuation separator arguments python">,</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">one_hot_encoding</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">y_pred</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre>
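<p>Putting both steps together on the example from the beginning of the post, as a runnable sketch. The required <code>num_classes</code> argument is a choice of this sketch: it keeps both one-hot matrices at K columns even when a class never shows up in one of the vectors (here class 1 never appears in <code>y_pred</code>), so the product is always K × K:</p>
<pre><code class="code lang-python">import numpy as np

def one_hot_encoding(v, num_classes):
    '''One-hot encode integer labels 0, ..., num_classes - 1 (one row per label).'''
    return np.eye(num_classes, dtype=int)[v]

y = np.array([1, 0, 2, 1])          # true labels
y_pred = np.array([2, 0, 2, 0])     # predicted labels
K = max(y.max(), y_pred.max()) + 1  # number of classes, assuming labels start at 0

T = one_hot_encoding(y, K)
T_hat = one_hot_encoding(y_pred, K)
confusion = T.T @ T_hat             # rows: true class, columns: predicted class
print(confusion)
# [[1 0 0]
#  [1 0 1]
#  [0 0 1]]
</code></pre>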
<p>And we are done! Of course, you can always get your confusion matrix from your favourite store ;)</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta statement import python"><span class="keyword control import from python">from</span></span><span class="meta statement import python"><span class="meta import-source python"> <span class="meta import-path python"><span class="meta import-name python">sklearn</span><span class="punctuation accessor dot python">.</span><span class="meta import-name python">metrics</span></span> <span class="meta statement import python"><span class="keyword control import python">import</span></span></span></span><span class="meta statement import python"></span><span class="meta statement import python"> <span class="meta generic-name python">confusion_matrix</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">confusion_matrix</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">y</span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">y_pred</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre><br>
<br>
<br>
<p>                         <em>That's the way computer talks to each other.</em></p>
]]></content>
  </entry>
  <entry>
    <title>Notes about how I am internalizing VIM key bindings ☕</title>
    <link href="https://alkzar.cl/posts/notes-about-how-i-am-internalizing-vim-keybindings/"/>
    <id>https://alkzar.cl/posts/notes-about-how-i-am-internalizing-vim-keybindings/</id>
    <published>2022-05-11T00:00:00Z</published>
    <updated>2022-05-11T00:00:00Z</updated>
    <content type="html"><![CDATA[<p><img src="https://imgs.xkcd.com/comics/real_programmers.png" alt="https://xkcd.com/378/" /></p>
<p><em>Wed, May 11th, 2022</em></p>
<p><strong>To my future self:</strong></p>
<p>My progress with VIM key bindings got stuck at the movement commands; even though
I realized the power of moving freely and accurately around the editor, I didn't
have a reason to waste more of my life on key bindings. This post is a never-ending excuse to
keep developing my VIM workflow. If I am reading this now, I am sure I forgot
something.</p>
<h3>Repeating a <code>char</code> or a <code>seq</code> of chars <code>n</code> times</h3>
<p><strong>Sequence</strong>: <code>ESC</code>-<code>n</code>-<code>i</code>-<code>char/seq</code>-<code>ESC</code>-<code>ESC</code></p>
<ul>
<li><code>n</code>: number of times</li>
<li><code>i</code>: insert mode</li>
<li><code>char/seq</code>: char or sequence of chars</li>
</ul>
<p><strong>Usage:</strong> creating a header for a file, or whenever I need to repeat an arbitrary sequence
of characters.</p>
<script id="asciicast-J51Zv4rYzqsb2n6RUeM6iSqN4" src="https://asciinema.org/a/J51Zv4rYzqsb2n6RUeM6iSqN4.js" async data-autoplay="true" data-cols="84" data-rows="7" data-loop="1" data-speed="1.5"></script>
<h3>Basic replacement in the current line <code>:s</code> or in all lines <code>:%s</code></h3>
<p><strong>Sequence</strong>: <code>:s/this/for that/g</code> or <code>:%s/this/for that/g</code></p>
<ul>
<li><code>:s</code>: short for "substitute", on the current line</li>
<li><code>:%s</code>: "substitute" on the whole document (<code>%</code> is the range of all lines)</li>
<li><code>this</code>: the pattern you want to replace</li>
<li><code>for that</code>: the text you want instead</li>
<li><code>g</code>: short for "global"; replaces every occurrence in the line, not just the
first one</li>
</ul>
<p><strong>Usage:</strong> guess what!</p>
<script id="asciicast-KSRic1uaEZi9SByNEFbYb3YRK" src="https://asciinema.org/a/KSRic1uaEZi9SByNEFbYb3YRK.js" async data-autoplay="true" data-cols="84" data-rows="10" data-loop="1" data-speed="1.5"></script>
<h3>Change upper-to-lower case and vice versa</h3>
<p><strong>Sequence:</strong> <code>guu</code> (upper-to-lower) or <code>gUU</code> (lower-to-upper) on the current line</p>
<ul>
<li><code>U</code>: change lower-to-upper case (in <code>gU</code>, or on a visual selection)</li>
<li><code>u</code>: change upper-to-lower case (in <code>gu</code>, or on a visual selection); on its own in normal mode, <code>u</code> is undo</li>
<li>VIM selection + {<code>u</code> or <code>U</code>}: change the case of the selection following the <code>u</code> or <code>U</code> behavior</li>
</ul>
<script id="asciicast-XPrdvDChbuUuipvaQvqqBvLP8" src="https://asciinema.org/a/XPrdvDChbuUuipvaQvqqBvLP8.js" async data-autoplay="true" data-cols="84" data-rows="10" data-loop="1" data-speed="1.5"></script>
<br>
<br>
<br>
<p>                         <em>That's the way computers talk to each other.</em></p>
]]></content>
  </entry>
  <entry>
    <title>Directional derivatives and JAX</title>
    <link href="https://alkzar.cl/posts/directional-derivatives-and-jax/"/>
    <id>https://alkzar.cl/posts/directional-derivatives-and-jax/</id>
    <published>2022-02-08T00:00:00Z</published>
    <updated>2022-02-08T00:00:00Z</updated>
    <content type="html"><![CDATA[<a href="https://colab.research.google.com/drive/1VD0QIfC-Q3WgmBPpgfAcF9zMZL_NAa5G?usp=sharing" target="_blank">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
<center>
<img src="/img/directional-derivative-post/directional_derivatives_setting.png">
</center>
<br>
<p>A directional derivative measures how a function's output changes when its
input is nudged in an arbitrary direction of the input space. It can be
computed with a Jacobian-vector product, which the automatic differentiation
library <a href="https://jax.readthedocs.io/en/latest/#">JAX</a> implements.</p>
<p>Partial derivatives <span class="math inline"><math display="inline"><mi>∂</mi><mi>f</mi><mo>/</mo><mi>∂</mi><msub><mi>x</mi><mi>i</mi></msub></math></span> give us the rate of change when
we slightly modify the <em>ith element</em> of the input vector <span class="math inline"><math display="inline"><mrow><mi>𝐱</mi></mrow></math></span> by h, holding the
rest constant.</p>
<br>
<p>$$
\begin{aligned}
\frac{\partial f}{\partial x_1} &= \lim_{h \to 0} \frac{f(x_1 + h, x_2, \dots, x_n) - f({\bf x})}{h} \\
& \vdots \\
\frac{\partial f}{\partial x_n} &= \lim_{h \to 0} \frac{f(x_1, x_2, \dots, x_n + h) - f({\bf x})}{h}
\end{aligned}
$$</p>
<br>
<p>The above definition can be written more compactly using vector notation.</p>
<br>
<p><div class="math display"><math display="block"><mfrac><mrow><mi>∂</mi><mi>f</mi></mrow><mrow><mi>∂</mi><msub><mi>x</mi><mi>i</mi></msub></mrow></mfrac><mo symmetric="false" stretchy="false">(</mo><mrow><msub><mi>𝐱</mi><mn>𝟎</mn></msub></mrow><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><mi>l</mi><mi>i</mi><msub><mi>m</mi><mrow><mi>h</mi><mo>→</mo><mn>0</mn></mrow></msub><mfrac><mrow><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mrow><msub><mi>𝐱</mi><mn>𝟎</mn></msub></mrow><mo>+</mo><mi>h</mi><mrow><msub><mi>𝐞</mi><mi>𝐢</mi></msub></mrow><mo symmetric="false" stretchy="false">)</mo><mo>−</mo><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mrow><msub><mi>𝐱</mi><mn>𝟎</mn></msub></mrow><mo symmetric="false" stretchy="false">)</mo></mrow><mrow><mi>h</mi></mrow></mfrac></math></div>
<br></p>
<p>The vector <span class="math inline"><math display="inline"><msub><mi>e</mi><mi>i</mi></msub></math></span> is the unit vector in the direction of the <span class="math inline"><math display="inline"><mi>i</mi></math></span>-th axis, with the
same number of dimensions as <span class="math inline"><math display="inline"><mrow><msub><mi>𝐱</mi><mn>𝟎</mn></msub></mrow></math></span>. The only
element of <span class="math inline"><math display="inline"><msub><mi>e</mi><mi>i</mi></msub></math></span> different from 0 is the <em>ith element</em>, which equals 1.</p>
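<p>As a quick numerical check of the vector form, the sketch below approximates the partial derivative with a forward difference, (f(x0 + h e_i) - f(x0)) / h, using a small but finite h instead of the limit. It is plain Python; the helper names <code>unit_vector</code> and <code>partial_fd</code> and the step size <code>h=1e-6</code> are illustrative choices, not part of the original post. The test function is f(x, y) = x²y, the one used in the example later in this post.</p>
<pre><code class="code lang-python">def unit_vector(i, n):
    """Return e_i: an n-dimensional list with 1.0 at index i and 0.0 elsewhere."""
    return [1.0 if j == i else 0.0 for j in range(n)]


def partial_fd(f, x0, i, h=1e-6):
    """Forward-difference approximation of df/dx_i at x0, with a finite step h."""
    e_i = unit_vector(i, len(x0))
    x_shifted = [xj + h * ej for xj, ej in zip(x0, e_i)]  # x0 + h * e_i
    return (f(x_shifted) - f(x0)) / h


def f(x):
    # f(x, y) = x^2 * y, taking the input as a list x = [x, y]
    return x[0] ** 2 * x[1]


x0 = [1.0, 1.0]
print(partial_fd(f, x0, 0))  # close to 2.0, since df/dx = 2xy
print(partial_fd(f, x0, 1))  # close to 1.0, since df/dy = x^2
</code></pre>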
<p>As you can see in the initial diagram, in a 2D input space, there are two
partial derivatives:</p>
<ul>
<li><span class="math inline"><math display="inline"><mi>∂</mi><mi>f</mi><mo>/</mo><mi>∂</mi><mi>x</mi></math></span>: computed parallel to the x-axis (<span class="math inline"><math display="inline"><msub><mi>e</mi><mn>1</mn></msub></math></span> <em>typically known as</em> <span class="math inline"><math display="inline"><mover><mrow><mi>i</mi></mrow><mi>^</mi></mover></math></span>)</li>
<li><span class="math inline"><math display="inline"><mi>∂</mi><mi>f</mi><mo>/</mo><mi>∂</mi><mi>y</mi></math></span>: computed parallel to the y-axis (<span class="math inline"><math display="inline"><msub><mi>e</mi><mn>2</mn></msub></math></span> <em>typically known as</em> <span class="math inline"><math display="inline"><mover><mrow><mi>j</mi></mrow><mi>^</mi></mover></math></span>)</li>
</ul>
<p>Computing derivatives with unit vectors such as <span class="math inline"><math display="inline"><msub><mi>e</mi><mi>i</mi></msub></math></span> gives us the change of
<span class="math inline"><math display="inline"><mi>f</mi></math></span> in the direction of <span class="math inline"><math display="inline"><mi>i</mi></math></span>, that is, parallel to the <em>i-axis</em>. How can we compute
the derivative of <span class="math inline"><math display="inline"><mi>f</mi></math></span> given a slight nudge of the inputs in an arbitrary
direction?</p>
<p>The directional derivative gives the rate of change of <span class="math inline"><math display="inline"><mi>f</mi></math></span> in the
direction of an arbitrary vector <span class="math inline"><math display="inline"><mrow><mi>𝐯</mi></mrow></math></span>.</p>
<br>
$$
\nabla_{{\bf v}}f({\bf x_0}) = \lim_{h \to 0} \frac{f({\bf x_0} + h{\bf v}) - f({\bf x_0})}{h}
$$
<br>
<p>Think of <span class="math inline"><math display="inline"><mrow><mi>𝐯</mi></mrow></math></span> as a weighted combination of the <em>n directions</em> of the input
space. We are no longer limited to changes of <span class="math inline"><math display="inline"><mi>f</mi></math></span> along directions parallel to the axes.</p>
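<p>Numerically, this limit can be approached with shrinking steps. The plain-Python sketch below evaluates the difference quotient (f(x0 + h v) - f(x0)) / h for f(x, y) = x²y at x0 = (1, 1) along v = (1, 2), the point and direction used in the example later in this post; the helper name <code>directional_fd</code> is hypothetical. For this f the quotient works out to 4 + 5h + 2h², so it approaches 4 as h shrinks.</p>
<pre><code class="code lang-python">def directional_fd(f, x0, v, h):
    """Difference quotient (f(x0 + h*v) - f(x0)) / h for a finite step h."""
    x_shifted = [xi + h * vi for xi, vi in zip(x0, v)]  # x0 + h * v
    return (f(x_shifted) - f(x0)) / h


def f(x):
    # f(x, y) = x^2 * y, taking the input as a list x = [x, y]
    return x[0] ** 2 * x[1]


x0 = [1.0, 1.0]
v = [1.0, 2.0]
for h in (1e-1, 1e-3, 1e-5):
    print(h, directional_fd(f, x0, v, h))  # tends to 4.0 as h shrinks
</code></pre>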
<p>We can compute the directional derivative as the dot product between the
gradient (<span class="math inline"><math display="inline"><mi>∇</mi><mi>f</mi></math></span>, the Jacobian of the scalar function f) and the vector <span class="math inline"><math display="inline"><mrow><mi>𝐯</mi></mrow></math></span>. For instance, in a two-dimensional input space with <span class="math inline"><math display="inline"><mrow><mi>𝐯</mi></mrow><mo>=</mo><mo symmetric="false" stretchy="false">(</mo><msub><mi>v</mi><mn>1</mn></msub><mo>,</mo><msub><mi>v</mi><mn>2</mn></msub><mo symmetric="false" stretchy="false">)</mo></math></span>, at an arbitrary point <span class="math inline"><math display="inline"><mi>p</mi></math></span>:</p>
<p><div class="math display"><math display="block"><msub><mi>∇</mi><mrow><mi>𝐯</mi></mrow></msub><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mi>p</mi><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><mi>∇</mi><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mi>p</mi><mo symmetric="false" stretchy="false">)</mo><mo>⋅</mo><mrow><mi>𝐯</mi></mrow><mo>=</mo><mfrac><mrow><mi>∂</mi><mi>f</mi></mrow><mrow><mi>∂</mi><msub><mi>x</mi><mn>1</mn></msub></mrow></mfrac><mo symmetric="false" stretchy="false">(</mo><mi>p</mi><mo symmetric="false" stretchy="false">)</mo><msub><mi>v</mi><mn>1</mn></msub><mo>+</mo><mfrac><mrow><mi>∂</mi><mi>f</mi></mrow><mrow><mi>∂</mi><msub><mi>x</mi><mn>2</mn></msub></mrow></mfrac><mo symmetric="false" stretchy="false">(</mo><mi>p</mi><mo symmetric="false" stretchy="false">)</mo><msub><mi>v</mi><mn>2</mn></msub></math></div></p>
<p>More generally, for an <em>n</em>-dimensional input space:</p>
<p><div class="math display"><math display="block"><msub><mi>∇</mi><mrow><mi>𝐯</mi></mrow></msub><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mi>p</mi><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><mi>∇</mi><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mi>p</mi><mo symmetric="false" stretchy="false">)</mo><mo>⋅</mo><mrow><mi>𝐯</mi></mrow><mo>=</mo><munderover><mo movablelimits="false">∑</mo><mrow><mi>i</mi><mo>=</mo><mn>1</mn></mrow><mrow><mi>n</mi></mrow></munderover><mfrac><mrow><mi>∂</mi><mi>f</mi></mrow><mrow><mi>∂</mi><msub><mi>x</mi><mi>i</mi></msub></mrow></mfrac><mo symmetric="false" stretchy="false">(</mo><mi>p</mi><mo symmetric="false" stretchy="false">)</mo><msub><mi>v</mi><mi>i</mi></msub></math></div></p>
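<p>The dot-product formula can be checked directly in JAX: <code>jax.grad</code> returns ∇f(p) for a scalar-valued f, and <code>jnp.dot</code> multiplies it with v. This sketch assumes f takes its whole input as a single vector, a choice made here for illustration (the example later in the post uses two scalar arguments instead), and it uses the same point p = (1, 1) and direction v = (1, 2).</p>
<pre><code class="code lang-python">import jax
import jax.numpy as jnp


def f(x):
    # f(x, y) = x^2 * y, written with a single vector argument x = (x, y)
    return x[0] ** 2 * x[1]


p = jnp.array([1.0, 1.0])  # the point where we differentiate
v = jnp.array([1.0, 2.0])  # the direction

grad_p = jax.grad(f)(p)         # gradient at p: (2xy, x^2)
dir_deriv = jnp.dot(grad_p, v)  # sum over i of df/dx_i(p) * v_i
print(grad_p, dir_deriv)
</code></pre>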
<p>Let's compute the above with the function <code>jax.jvp</code>, where <code>jvp</code>
stands for <em>Jacobian-vector product</em>.</p>
<p>The function <code>jax.jvp</code> computes the directional derivative; its arguments are:</p>
<ol>
<li>A differentiable function <span class="math inline"><math display="inline"><mi>f</mi></math></span>, whose Jacobian <span class="math inline"><math display="inline"><mi>∇</mi><mi>f</mi></math></span> is multiplied by the tangent vector</li>
<li>A primal vector <span class="math inline"><math display="inline"><mrow><mi>𝐩</mi></mrow></math></span>: the point at which the Jacobian <span class="math inline"><math display="inline"><mi>∇</mi><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mi>p</mi><mo symmetric="false" stretchy="false">)</mo></math></span> is evaluated</li>
<li>A tangent vector <span class="math inline"><math display="inline"><mrow><mi>𝐯</mi></mrow></math></span>, which represents the direction in which we
want to calculate the derivative.</li>
</ol>
<p><code>jax.jvp</code> returns the tuple <span class="math inline"><math display="inline"><mo symmetric="false" stretchy="false">(</mo><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mi>p</mi><mo symmetric="false" stretchy="false">)</mo><mo>,</mo><msub><mi>∇</mi><mrow><mi>𝐯</mi></mrow></msub><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mi>p</mi><mo symmetric="false" stretchy="false">)</mo><mo symmetric="false" stretchy="false">)</mo></math></span>: the function value and the directional derivative.</p>
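<p>Here is the call pattern in isolation, with a version of f that takes one vector argument (an assumption for this sketch; the example below uses two scalar arguments). <code>jax.jvp</code> expects the primals and tangents as tuples or lists with one entry per positional argument of f, so here they are one-element tuples. Passing a standard basis vector e_i as the tangent recovers the partial derivative ∂f/∂x_i; <code>jnp.eye(2)[i]</code> is one convenient way to build e_i.</p>
<pre><code class="code lang-python">import jax
import jax.numpy as jnp


def f(x):
    # f(x, y) = x^2 * y with a single vector argument x = (x, y)
    return x[0] ** 2 * x[1]


p = jnp.array([1.0, 1.0])  # primal: the point where the derivative is evaluated
v = jnp.array([1.0, 2.0])  # tangent: the direction of the derivative

# One entry per positional argument of f, so one-element tuples here
value, dir_deriv = jax.jvp(f, (p,), (v,))
print(value, dir_deriv)  # f(p) and the directional derivative along v

# The rows of the identity matrix are e_1 and e_2; using them as tangents
# gives the partial derivatives df/dx and df/dy at p.
basis = jnp.eye(2)
partials = [jax.jvp(f, (p,), (basis[i],))[1] for i in range(2)]
print(partials)
</code></pre>
<p>Because <code>jax.jvp</code> works in forward mode, it evaluates the product with the tangent directly rather than building the full gradient first.</p>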
<h3>Example</h3>
<p>We compute the directional derivative of <span class="math inline"><math display="inline"><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo>,</mo><mi>y</mi><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><msup><mi>x</mi><mn>2</mn></msup><mi>y</mi></math></span> by hand-coding all the
necessary pieces and then checking the result against <code>jax.jvp</code>.</p>
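<pre><code class="code lang-python"><span class="source python"><span class="meta statement import python"><span class="keyword control import python">import</span> <span class="meta qualified-name python"><span class="meta generic-name python">jax</span></span></span>
<span class="meta statement import python"><span class="keyword control import python">import</span> <span class="meta qualified-name python"><span class="meta generic-name python">jax</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">numpy</span></span> <span class="keyword control import python">as</span> <span class="meta generic-name python">jnp</span></span>

<span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> f(x, y) = x**2 * y and its hand-coded partial derivatives
</span><span class="meta function python"><span class="storage type function python">def</span> <span class="entity name function python"><span class="meta generic-name python">fun</span></span></span><span class="meta function parameters python"><span class="punctuation section parameters begin python">(</span></span><span class="meta function parameters python"><span class="variable parameter python">x</span><span class="punctuation separator parameters python">,</span> <span class="variable parameter python">y</span><span class="punctuation section parameters end python">)</span></span><span class="meta function python"><span class="punctuation section function begin python">:</span></span> <span class="keyword control flow return python">return</span> <span class="meta qualified-name python"><span class="meta generic-name python">x</span></span><span class="keyword operator arithmetic python">*</span><span class="keyword operator arithmetic python">*</span><span class="constant numeric integer decimal python">2</span> <span class="keyword operator arithmetic python">*</span> <span class="meta qualified-name python"><span class="meta generic-name python">y</span></span>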
<span class="meta function python"><span class="storage type function python">def</span> <span class="entity name function python"><span class="meta generic-name python">fun_dx</span></span></span><span class="meta function parameters python"><span class="punctuation section parameters begin python">(</span></span><span class="meta function parameters python"><span class="variable parameter python">x</span><span class="punctuation separator parameters python">,</span> <span class="variable parameter python">y</span><span class="punctuation section parameters end python">)</span></span><span class="meta function python"><span class="punctuation section function begin python">:</span></span> <span class="keyword control flow return python">return</span> <span class="constant numeric integer decimal python">2</span><span class="keyword operator arithmetic python">*</span><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span><span class="keyword operator arithmetic python">*</span><span class="meta qualified-name python"><span class="meta generic-name python">y</span></span>
<span class="meta function python"><span class="storage type function python">def</span> <span class="entity name function python"><span class="meta generic-name python">fun_dy</span></span></span><span class="meta function parameters python"><span class="punctuation section parameters begin python">(</span></span><span class="meta function parameters python"><span class="variable parameter python">x</span><span class="punctuation separator parameters python">,</span> <span class="variable parameter python">y</span><span class="punctuation section parameters end python">)</span></span><span class="meta function python"><span class="punctuation section function begin python">:</span></span> <span class="keyword control flow return python">return</span> <span class="meta qualified-name python"><span class="meta generic-name python">x</span></span><span class="keyword operator arithmetic python">*</span><span class="keyword operator arithmetic python">*</span><span class="constant numeric integer decimal python">2</span>
</span></code></pre>
<p>We define the primal vector <span class="math inline"><math display="inline"><mrow><mi>𝐩</mi></mrow></math></span>, the point at which we evaluate the derivative, and the tangent vector <span class="math inline"><math display="inline"><mrow><mi>𝐯</mi></mrow></math></span>,
the direction along which we compute it.</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta qualified-name python"><span class="meta generic-name python">p</span></span> <span class="keyword operator assignment python">=</span> <span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span></span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span></span><span class="punctuation section list end python">]</span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">v</span></span> <span class="keyword operator assignment python">=</span> <span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span></span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span></span><span class="punctuation section list end python">]</span></span>
</span></code></pre>
<p>Evaluate <span class="math inline"><math display="inline"><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mi>p</mi><mo symmetric="false" stretchy="false">)</mo></math></span>:</p>
<pre><code class="code lang-python"><span class="source python"><span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> *p unpacks the list into positional arguments: fun(p[0], p[1])
</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">fun</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="keyword operator arithmetic python">*</span><span class="meta qualified-name python"><span class="meta generic-name python">p</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="keyword operator comparison python">&gt;</span> <span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>0</span>
</span></code></pre>
<p>Compute the directional derivative by hand using <code>fun_dx</code> and <code>fun_dy</code>:</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">fun_dx</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="keyword operator arithmetic python">*</span><span class="meta qualified-name python"><span class="meta generic-name python">p</span></span></span><span class="punctuation section arguments end python">)</span></span> <span class="keyword operator arithmetic python">*</span> <span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">v</span></span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="constant numeric integer decimal python">0</span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span> <span class="keyword operator arithmetic python">+</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">fun_dy</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="keyword operator arithmetic python">*</span><span class="meta qualified-name python"><span class="meta generic-name python">p</span></span></span><span class="punctuation section arguments end python">)</span></span> <span class="keyword operator arithmetic python">*</span> <span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">v</span></span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span 
class="constant numeric integer decimal python">1</span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span>
<span class="keyword operator comparison python">&gt;</span> <span class="constant numeric float python">4<span class="punctuation separator decimal python">.</span>0</span>
</span></code></pre>
<p>Now using <code>jax.jvp</code> we obtain the same results: <span class="math inline"><math display="inline"><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi>𝐩</mi></mrow><mo symmetric="false" stretchy="false">)</mo></math></span> and <span class="math inline"><math display="inline"><msub><mi>∇</mi><mrow><mi>𝐯</mi></mrow></msub><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi>𝐩</mi></mrow><mo symmetric="false" stretchy="false">)</mo></math></span>.</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">jvp</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">fun</span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">p</span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">v</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="keyword operator comparison python">&gt;</span> <span class="meta group python"><span class="punctuation section group begin python">(</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">DeviceArray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">dtype</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">float32</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">weak_type</span><span class="keyword operator assignment python">=</span><span class="constant language python">True</span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator tuple python">,</span>
   <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">DeviceArray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric float python">4<span class="punctuation separator decimal python">.</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">dtype</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">float32</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">weak_type</span><span class="keyword operator assignment python">=</span><span class="constant language python">True</span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation section group end python">)</span></span>
</span></code></pre>
<p>The surface plot below shows the output space and the contour plot shows the input space
of <span class="math inline"><math display="inline"><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo>,</mo><mi>y</mi><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><msup><mi>x</mi><mn>2</mn></msup><mi>y</mi></math></span>. We compute the directional derivatives at three points,
each with its own direction vector.</p>
<center>
<img src="/img/directional-derivative-post/directional_plot_surface.png">
</center>
<center>
<img src="/img/directional-derivative-post/directional_plot_contour.png">
</center>
<p>Look at the direction vectors in the plot (tangent vectors, as JAX calls
them): they have different lengths. If we want
the <a href="https://www.khanacademy.org/math/multivariable-calculus/multivariable-derivatives/partial-derivative-and-gradient-articles/a/directional-derivative-introduction?modal=1">"slope definition"</a> of the directional derivative, we
need to rescale <span class="math inline"><math display="inline"><mrow><mi>𝐯</mi></mrow></math></span> to a unit-length vector, which amounts to dividing the directional
derivative definition by <span class="math inline"><math display="inline"><mi>|</mi><mi>|</mi><mi>v</mi><mi>|</mi><mi>|</mi></math></span>. Remember that partial
derivatives are also computed with unit vectors (<span class="math inline"><math display="inline"><msub><mi>e</mi><mi>i</mi></msub></math></span>).</p>
<p><div class="math display"><math display="block"><msub><mi>∇</mi><mrow><mrow><mi>𝐯</mi></mrow></mrow></msub><mi>f</mi><mo>=</mo><mfrac><mrow><mi>∂</mi><mi>f</mi></mrow><mrow><mi>∂</mi><mrow><mi>𝐯</mi></mrow></mrow></mfrac><mo>=</mo><mi>l</mi><mi>i</mi><msub><mi>m</mi><mrow><mi>h</mi><mo>→</mo><mn>0</mn></mrow></msub><mfrac><mrow><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi>𝐱</mi></mrow><mo>+</mo><mi>h</mi><mrow><mi>𝐯</mi></mrow><mo symmetric="false" stretchy="false">)</mo><mo>−</mo><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mrow><mi>𝐱</mi></mrow><mo symmetric="false" stretchy="false">)</mo></mrow><mrow><mi>h</mi><mi>|</mi><mi>|</mi><mrow><mi>𝐯</mi></mrow><mi>|</mi><mi>|</mi></mrow></mfrac></math></div></p>
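<p>In code, the slope definition only requires rescaling the tangent to unit length before calling <code>jax.jvp</code>. Since the JVP is linear in the tangent, the result equals the unnormalized one divided by ||v||. This sketch reuses the point p = (1, 1) and direction v = (1, 2) from the example above, assumes a single-vector version of f, and uses <code>jnp.linalg.norm</code> for ||v||.</p>
<pre><code class="code lang-python">import jax
import jax.numpy as jnp


def f(x):
    # f(x, y) = x^2 * y with a single vector argument x = (x, y)
    return x[0] ** 2 * x[1]


p = jnp.array([1.0, 1.0])
v = jnp.array([1.0, 2.0])

# Raw JVP: its value scales with the length of the tangent vector
_, raw_deriv = jax.jvp(f, (p,), (v,))

# Slope definition: rescale the tangent to unit length before the JVP
unit_v = v / jnp.linalg.norm(v)
_, slope = jax.jvp(f, (p,), (unit_v,))

# By linearity in the tangent, both routes give the same slope
print(slope, raw_deriv / jnp.linalg.norm(v))
</code></pre>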
<pre><code class="code lang-python"><span class="source python"><span class="meta qualified-name python"><span class="meta generic-name python">primal_a</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">array</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="keyword operator arithmetic python">-</span><span class="constant numeric float python">5<span class="punctuation separator decimal python">.</span></span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">3<span class="punctuation separator decimal python">.</span>2</span><span class="punctuation section list end python">]</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">primal_b</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">array</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">5<span class="punctuation separator decimal python">.</span></span><span class="punctuation separator list python">,</span> <span class="keyword operator arithmetic python">-</span><span class="constant numeric float python">3<span class="punctuation separator decimal python">.</span>2</span><span class="punctuation section list end python">]</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">primal_c</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">array</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span></span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span></span><span class="punctuation section list end python">]</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">va</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">array</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="keyword operator arithmetic python">-</span><span class="constant numeric float python">7<span class="punctuation separator decimal python">.</span>5</span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">5<span class="punctuation separator decimal python">.</span>7</span><span class="punctuation section list end python">]</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">vb</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">array</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">7<span class="punctuation separator decimal python">.</span>5</span><span class="punctuation separator list python">,</span> <span class="keyword operator arithmetic python">-</span><span class="constant numeric float python">5<span class="punctuation separator decimal python">.</span>7</span><span class="punctuation section list end python">]</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">vc</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">array</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="keyword operator arithmetic python">-</span><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator list python">,</span> <span class="keyword operator arithmetic python">-</span><span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>7</span><span class="punctuation section list end python">]</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">unit_va</span></span> <span class="keyword operator assignment python">=</span> <span class="meta qualified-name python"><span class="meta generic-name python">va</span></span><span class="keyword operator arithmetic python">/</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">va</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">dot</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">va</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="keyword operator arithmetic python">*</span><span class="keyword operator arithmetic python">*</span><span class="constant numeric float python"><span class="punctuation separator decimal python">.</span>5</span>
<span class="meta qualified-name python"><span class="meta generic-name python">unit_vb</span></span> <span class="keyword operator assignment python">=</span> <span class="meta qualified-name python"><span class="meta generic-name python">vb</span></span><span class="keyword operator arithmetic python">/</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">vb</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">dot</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">vb</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="keyword operator arithmetic python">*</span><span class="keyword operator arithmetic python">*</span><span class="constant numeric float python"><span class="punctuation separator decimal python">.</span>5</span>
<span class="meta qualified-name python"><span class="meta generic-name python">unit_vc</span></span> <span class="keyword operator assignment python">=</span> <span class="meta qualified-name python"><span class="meta generic-name python">vc</span></span><span class="keyword operator arithmetic python">/</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">vc</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">dot</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">vc</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="keyword operator arithmetic python">*</span><span class="keyword operator arithmetic python">*</span><span class="constant numeric float python"><span class="punctuation separator decimal python">.</span>5</span>
<span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> Computing the directional derivatives along the unit-length vectors
</span><span class="meta qualified-name python"><span class="variable language python">_</span></span>, <span class="meta qualified-name python"><span class="meta generic-name python">slope_a</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">jvp</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">fun</span></span><span class="punctuation separator arguments python">,</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">primal_a</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">tolist</span></span><span class="punctuation section arguments begin python">(</span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator arguments python">,</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">unit_va</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">tolist</span></span><span class="punctuation section arguments begin python">(</span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta qualified-name python"><span class="variable language python">_</span></span>, <span class="meta qualified-name python"><span class="meta generic-name python">slope_b</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">jvp</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">fun</span></span><span class="punctuation separator arguments python">,</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">primal_b</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">tolist</span></span><span class="punctuation section arguments begin python">(</span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator arguments python">,</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">unit_vb</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">tolist</span></span><span class="punctuation section arguments begin python">(</span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta qualified-name python"><span class="variable language python">_</span></span>, <span class="meta qualified-name python"><span class="meta generic-name python">slope_c</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">jvp</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">fun</span></span><span class="punctuation separator arguments python">,</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">primal_c</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">tolist</span></span><span class="punctuation section arguments begin python">(</span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator arguments python">,</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">unit_vc</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">tolist</span></span><span class="punctuation section arguments begin python">(</span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">slope_a</span></span>, <span class="meta qualified-name python"><span class="meta generic-name python">slope_b</span></span>, <span class="meta qualified-name python"><span class="meta generic-name python">slope_c</span></span>
<span class="keyword operator comparison python">&gt;</span> <span class="meta group python"><span class="punctuation section group begin python">(</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">DeviceArray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric float python">40<span class="punctuation separator decimal python">.</span>60427</span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">dtype</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">float32</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">weak_type</span><span class="keyword operator assignment python">=</span><span class="constant language python">True</span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator tuple python">,</span>
   <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">DeviceArray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="keyword operator arithmetic python">-</span><span class="constant numeric float python">40<span class="punctuation separator decimal python">.</span>60427</span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">dtype</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">float32</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">weak_type</span><span class="keyword operator assignment python">=</span><span class="constant language python">True</span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator tuple python">,</span>
   <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">DeviceArray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="keyword operator arithmetic python">-</span><span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">dtype</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">float32</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">weak_type</span><span class="keyword operator assignment python">=</span><span class="constant language python">True</span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation section group end python">)</span></span>
</span></code></pre>
<p>A few observations about the points and their directional derivatives:</p>
<ul>
<li><span style="color: #4682b4">Point A:</span> the directional derivative is 40.6, which is consistent with the contour lines in front of A: the surface starts to rise in the direction of <span class="math inline"><math display="inline"><msub><mover><mrow><mi>v</mi></mrow><mo stretchy="true">→</mo></mover><mi>a</mi></msub></math></span>.</li>
<li><span style="color: #b22222">Point B:</span> the function 𝑓 decreases in the direction of <span class="math inline"><math display="inline"><msub><mover><mrow><mi>v</mi></mrow><mo stretchy="true">→</mo></mover><mi>b</mi></msub></math></span>. The directional derivative is −40.6, so for small steps along the unit directional vector, 𝑓 changes at a rate of −40.6. Notice that it has the same magnitude as the slope at point A but the opposite sign; the surface plot shows how the function increases and decreases in the same proportion along its diagonals.</li>
<li><span style="color: #000000">Point C:</span> the surface is practically flat around the point (0, 0), which is why the directional derivative along <span class="math inline"><math display="inline"><msub><mover><mrow><mi>v</mi></mrow><mo stretchy="true">→</mo></mover><mi>c</mi></msub></math></span> is 0.</li>
</ul>
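<p>To double-check what <code>jax.jvp</code> returns, here is a minimal sketch. It uses a hypothetical saddle-shaped function <span class="math inline"><math display="inline"><mi>g</mi><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo>,</mo><mi>y</mi><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><msup><mi>x</mi><mn>2</mn></msup><mo>−</mo><msup><mi>y</mi><mn>2</mn></msup></math></span>, which is not the <code>fun</code> used above. It relies on the fact that, for a differentiable function, the directional derivative along a unit vector equals the dot product of the gradient with that vector. That lets us compare the forward-mode result against <code>jax.grad</code>. The helper <code>directional_derivative</code> is illustrative only and is not part of JAX.</p>
<pre><code class="code lang-python">import jax
import jax.numpy as jnp


def g(x, y):
    # Hypothetical saddle-shaped surface, used only for illustration
    return x**2 - y**2


def directional_derivative(f, point, direction):
    """Slope of f at `point` along `direction`, via forward-mode autodiff."""
    point = jnp.asarray(point, dtype=jnp.float32)
    direction = jnp.asarray(direction, dtype=jnp.float32)
    unit = direction / jnp.linalg.norm(direction)  # rescale to unit length
    # jax.jvp expects one primal and one tangent per positional argument of f
    _, slope = jax.jvp(f, tuple(point), tuple(unit))
    return slope


point = (1.0, 2.0)
direction = (3.0, 4.0)  # its unit version is (0.6, 0.8)
slope = directional_derivative(g, point, direction)

# Same quantity from the gradient: grad g(point) . unit direction
grad = jnp.array(jax.grad(g, argnums=(0, 1))(*point))
expected = jnp.dot(grad, jnp.array(direction) / 5.0)
print(slope, expected)  # both close to -2.0

# At the origin the gradient vanishes, so every directional derivative is zero
print(directional_derivative(g, (0.0, 0.0), (1.0, 1.0)))
</code></pre>
<p>Normalizing inside the helper mirrors the <code>unit_va</code>, <code>unit_vb</code>, <code>unit_vc</code> step above. Without it, <code>jax.jvp</code> would return the slope scaled by the length of the direction vector.</p>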
]]></content>
  </entry>
  <entry>
    <title>Taylor Approximation and JAX</title>
    <link href="https://alkzar.cl/posts/taylor-approximation-and-jax/"/>
    <id>https://alkzar.cl/posts/taylor-approximation-and-jax/</id>
    <published>2022-01-28T00:00:00Z</published>
    <updated>2022-01-28T00:00:00Z</updated>
    <content type="html"><![CDATA[<a href="https://colab.research.google.com/drive/1KDAbU3eW-fOxAYmp0eiQFbuuRmZARdqq?usp=sharing" target="_blank">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>
<center>
<img src="/img/taylor-post/portrait.png">
</center>
<p><em><strong>TL;DR:</strong></em> <em>In this post, we review the concept of Taylor approximation, focusing on
computing derivatives with the automatic differentiation tools of the Python
JAX library. Taylor approximation is a powerful tool for analyzing non-linear systems
such as neural networks. We examine two examples distilled from chapter 5 of the book
Mathematics for Machine Learning and implement code that reproduces a quadratic
approximation and can be extended to other functions.</em></p>
<br>
<h3>1. Taylor Approximation Review</h3>
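<p>The Taylor series lets us approximate a function 𝑓 with a polynomial whose coefficients
are computed from the derivatives of 𝑓 at a point <span class="math inline"><math display="inline"><msub><mi>x</mi><mn>0</mn></msub></math></span>. In the extreme case, if 𝑓 is a polynomial of degree
<span class="math inline"><math display="inline"><mi>n</mi></math></span>, then <span class="math inline"><math display="inline"><msub><mi>T</mi><mi>n</mi></msub></math></span> reproduces it exactly. For functions such as sine and cosine, adding infinitely
many terms gives a series that converges to 𝑓 everywhere.</p>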
<p><div class="math display"><math display="block"><msub><mi>T</mi><mi>n</mi></msub><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo><mo>:</mo><mo>=</mo><munderover><mo movablelimits="false">∑</mo><mrow><mi>k</mi><mo>=</mo><mn>0</mn></mrow><mrow><mi>n</mi></mrow></munderover><mfrac><mrow><msup><mi>f</mi><mrow><mo symmetric="false" stretchy="false">(</mo><mi>k</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msup><mo symmetric="false" stretchy="false">(</mo><msub><mi>x</mi><mn>0</mn></msub><mo symmetric="false" stretchy="false">)</mo></mrow><mrow><mi>k</mi><mi>!</mi></mrow></mfrac><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo>−</mo><msub><mi>x</mi><mn>0</mn></msub><msup><mo symmetric="false" stretchy="false">)</mo><mi>k</mi></msup></math></div></p>
<p><div class="math display"><math display="block"><msub><mi>T</mi><mn>1</mn></msub><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo><mo>:</mo><mo>=</mo><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>x</mi><mn>0</mn></msub><mo symmetric="false" stretchy="false">)</mo><mo>+</mo><msup><mi>f</mi><mrow><mo symmetric="false" stretchy="false">(</mo><mn>1</mn><mo symmetric="false" stretchy="false">)</mo></mrow></msup><mo symmetric="false" stretchy="false">(</mo><msub><mi>x</mi><mn>0</mn></msub><mo symmetric="false" stretchy="false">)</mo><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo>−</mo><msub><mi>x</mi><mn>0</mn></msub><mo symmetric="false" stretchy="false">)</mo></math></div></p>
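<p>Note: <span class="math inline"><math display="inline"><msup><mi>f</mi><mrow><mo symmetric="false" stretchy="false">(</mo><mi>k</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msup></math></span> denotes 𝑓 differentiated 𝑘 times, and <span class="math inline"><math display="inline"><msup><mi>f</mi><mrow><mo symmetric="false" stretchy="false">(</mo><mn>0</mn><mo symmetric="false" stretchy="false">)</mo></mrow></msup></math></span> is 𝑓 itself.</p>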
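<p>Once JAX computes the derivatives, the formula for <span class="math inline"><math display="inline"><msub><mi>T</mi><mi>n</mi></msub></math></span> translates almost directly into code. Below is a minimal sketch that assumes a scalar function written with <code>jax.numpy</code> (or plain arithmetic). It applies <code>jax.grad</code> repeatedly to obtain each <span class="math inline"><math display="inline"><msup><mi>f</mi><mrow><mo symmetric="false" stretchy="false">(</mo><mi>k</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msup></math></span> and divides its value at <span class="math inline"><math display="inline"><msub><mi>x</mi><mn>0</mn></msub></math></span> by <span class="math inline"><math display="inline"><mi>k</mi><mi>!</mi></math></span>. The helper <code>taylor_polynomial</code> is hypothetical, for illustration only. It is not the <code>taylor_approx</code> function used below.</p>
<pre><code class="code lang-python">import math

import jax
import jax.numpy as jnp


def taylor_polynomial(f, x0, n):
    """Return a function that evaluates T_n(x) of a scalar f around x0."""
    x0 = float(x0)  # jax.grad needs a floating-point input
    coefs = []
    deriv = f
    for k in range(n + 1):
        # deriv is the k-th derivative of f; divide its value at x0 by k!
        coefs.append(deriv(x0) / math.factorial(k))
        deriv = jax.grad(deriv)

    def t_n(x):
        return sum(c * (x - x0) ** k for k, c in enumerate(coefs))

    return t_n


def cube(x):
    return x**3


# A cubic is reproduced exactly by its third-order Taylor polynomial
t3 = taylor_polynomial(cube, x0=1.0, n=3)
print(t3(2.0))  # 8.0, the same as 2.0 ** 3
</code></pre>
<p>The cubic example illustrates the polynomial case mentioned above: with <span class="math inline"><math display="inline"><mi>n</mi></math></span> equal to the degree, every higher derivative is zero and the approximation is exact.</p>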
<p>Let's code an example: I will replicate Figure 5.4 from the <a href="https://mml-book.github.io/book/mml-book.pdf" target="_blank">Mathematics for Machine Learning</a> book.</p>
<p>We want to approximate the following function around <span class="math inline"><math display="inline"><mi>x</mi><mo>=</mo><mn>0</mn></math></span>:
<br>
<br>
<div class="math display"><math display="block"><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><mi>sin</mi><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo><mo>+</mo><mi>cos</mi><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo></math></div></p>
<pre><code class="code lang-python"><span class="source python"><span class="meta function python"><span class="storage type function python">def</span> <span class="entity name function python"><span class="meta generic-name python">fun</span></span></span><span class="meta function parameters python"><span class="punctuation section parameters begin python">(</span></span><span class="meta function parameters python"><span class="variable parameter python">x</span><span class="punctuation section parameters end python">)</span></span><span class="meta function python"><span class="punctuation section function begin python">:</span></span> 
    <span class="keyword control flow return python">return</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">np</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">sin</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span></span><span class="punctuation section arguments end python">)</span></span> <span class="keyword operator arithmetic python">+</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">np</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">cos</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre><center>
<img src="/img/taylor-post/example_5-4_00.png">
</center>
<p>That is what <span class="math inline"><math display="inline"><mi>f</mi></math></span> looks like around <span class="math inline"><math display="inline"><mi>x</mi><mo>=</mo><mn>0</mn></math></span>. The task is to find an expression that
describes how <span class="math inline"><math display="inline"><mi>f</mi></math></span> varies in the neighbourhood of <span class="math inline"><math display="inline"><mi>x</mi><mo>=</mo><mn>0</mn></math></span>. The most straightforward
way to achieve this is to remember that <span class="math inline"><math display="inline"><mi>f</mi><mi>′</mi><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo></math></span> is another function that gives
us the slope of the tangent line at point <span class="math inline"><math display="inline"><mi>x</mi></math></span>. Once we know that slope at <span class="math inline"><math display="inline"><mi>x</mi><mo>=</mo><mn>0</mn></math></span>, the equation
of the tangent line there is already an approximation of <span class="math inline"><math display="inline"><mi>f</mi></math></span>.
Knowing the derivatives of <span class="math inline"><math display="inline"><mi>sin</mi></math></span> and <span class="math inline"><math display="inline"><mi>cos</mi></math></span>, plus the sum rule of
differentiation, we can compute <span class="math inline"><math display="inline"><mi>f</mi><mi>′</mi></math></span> by hand:</p>
<p><div class="math display"><math display="block"><mi>f</mi><mi>′</mi><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><mi>cos</mi><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo><mo>−</mo><mi>sin</mi><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo></math></div></p>
<pre><code class="code lang-python"><span class="source python"><span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> Evaluating f&#39; at x=0
</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">np</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">cos</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric integer decimal python">0</span></span><span class="punctuation section arguments end python">)</span></span> <span class="keyword operator arithmetic python">-</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">np</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">sin</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric integer decimal python">0</span></span><span class="punctuation section arguments end python">)</span></span>
<span class="keyword operator comparison python">&gt;</span> <span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>0</span>
</span></code></pre>
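<p>The same derivative can be obtained with automatic differentiation instead of by hand. Here is a minimal sketch. It assumes that <code>np</code> above refers to NumPy, in which case <code>f</code> has to be rewritten with <code>jax.numpy</code>, since <code>jax.grad</code> cannot trace plain NumPy calls such as <code>np.sin</code>:</p>
<pre><code class="code lang-python">import jax
import jax.numpy as jnp


def f(x):
    # Same f as above, written with jax.numpy so that JAX can differentiate it
    return jnp.sin(x) + jnp.cos(x)


df = jax.grad(f)  # should behave like cos(x) - sin(x)

print(df(0.0))  # 1.0, matching the manual computation
# Spot-check the analytic derivative at another point
print(df(1.0), jnp.cos(1.0) - jnp.sin(1.0))
</code></pre>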
<p>This gives us the two ingredients of the tangent line at <span class="math inline"><math display="inline"><msub><mi>x</mi><mn>0</mn></msub><mo>=</mo><mn>0</mn></math></span>:
the point <span class="math inline"><math display="inline"><msub><mi>y</mi><mn>0</mn></msub><mo>=</mo><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mn>0</mn><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><mn>1</mn></math></span> and the slope <span class="math inline"><math display="inline"><mi>m</mi><mo>=</mo><mi>f</mi><mi>′</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>x</mi><mn>0</mn></msub><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><mn>1</mn></math></span>:</p>
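<p><div class="math display"><math display="block"><mi>y</mi><mo>−</mo><msub><mi>y</mi><mn>0</mn></msub><mo>=</mo><mi>m</mi><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo>−</mo><msub><mi>x</mi><mn>0</mn></msub><mo symmetric="false" stretchy="false">)</mo></math></div></p>
<p><div class="math display"><math display="block"><mi>y</mi><mo>=</mo><mn>1</mn><mo>+</mo><mi>f</mi><mi>′</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>x</mi><mn>0</mn></msub><mo symmetric="false" stretchy="false">)</mo><mi>x</mi></math></div></p>
<p><div class="math display"><math display="block"><mi>y</mi><mo>=</mo><mn>1</mn><mo>+</mo><mi>x</mi></math></div></p>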
<pre><code class="code lang-python"><span class="source python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">taylor_approx</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">fun</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">approx_around</span><span class="keyword operator assignment python">=</span><span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">num_coef</span><span class="keyword operator assignment python">=</span><span class="constant numeric integer decimal python">2</span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre><center>
<img src="/img/taylor-post/example_5-4_01.png">
</center>
<p>Notice that the tangent line is a pretty good approximation in the immediate neighbourhood
of <span class="math inline"><math display="inline"><mi>x</mi><mo>=</mo><mn>0</mn></math></span>, but we want something that reaches beyond our block: the approximation
error grows quickly once we cross the street at the corner (look at <span class="math inline"><math display="inline"><mi>x</mi><mo>=</mo><mn>2</mn></math></span>!).</p>
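<p>A quick numeric check of that claim compares <span class="math inline"><math display="inline"><mi>f</mi></math></span> with the tangent line <span class="math inline"><math display="inline"><mi>y</mi><mo>=</mo><mn>1</mn><mo>+</mo><mi>x</mi></math></span> near the expansion point and further away. This is a small sketch that uses only Python's <code>math</code> module. The sample points are arbitrary choices:</p>
<pre><code class="code lang-python">import math


def f(x):
    return math.sin(x) + math.cos(x)


def tangent(x):
    # First-order Taylor approximation of f around x0 = 0
    return 1 + x


for x in (0.1, 0.5, 2.0):
    error = abs(f(x) - tangent(x))
    print(f"x = {x}: |f(x) - T1(x)| = {error:.4f}")
</code></pre>
<p>With these sample points, the error is about 0.005 at <span class="math inline"><math display="inline"><mi>x</mi><mo>=</mo><mn>0.1</mn></math></span> but about 2.5 at <span class="math inline"><math display="inline"><mi>x</mi><mo>=</mo><mn>2</mn></math></span>.</p>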
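<p>If we want to be famous on a larger scale, we need to improve our approximation. To do
that, we have to account for the curvature of <span class="math inline"><math display="inline"><mi>f</mi></math></span>, which a straight line cannot capture.</p>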
<pre><code class="code lang-python"><span class="source python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">taylor_approx</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">fun</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">approx_around</span><span class="keyword operator assignment python">=</span><span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">num_coef</span><span class="keyword operator assignment python">=</span><span class="constant numeric integer decimal python">3</span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">PLOT_COEF</span><span class="keyword operator assignment python">=</span><span class="meta group python"><span class="punctuation section group begin python">(</span><span class="constant numeric integer decimal python">0</span><span class="punctuation separator tuple python">,</span> <span class="constant numeric integer decimal python">1</span><span class="punctuation separator tuple python">,</span> <span class="constant numeric integer decimal python">2</span><span class="punctuation section group end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre><center>
<img src="/img/taylor-post/example_5-4_02.png">
</center>
<p>The green curve approximates <span class="math inline"><math display="inline"><mi>f</mi></math></span> between −1 and 1 better than the tangent line does.
Getting an expression for the tangent line was intuitive, but finding a quadratic one is less obvious.
How do we get the equation that describes the green curve?</p>
<p>This is where the Taylor polynomial comes in handy:</p>
<p><div class="math display"><math display="block"><msub><mi>T</mi><mn>2</mn></msub><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo><mo>:</mo><mo>=</mo><munderover><mo movablelimits="false">∑</mo><mrow><mi>k</mi><mo>=</mo><mn>0</mn></mrow><mrow><mn>2</mn></mrow></munderover><mfrac><mrow><msup><mi>f</mi><mrow><mo symmetric="false" stretchy="false">(</mo><mi>k</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msup><mo symmetric="false" stretchy="false">(</mo><msub><mi>x</mi><mn>0</mn></msub><mo symmetric="false" stretchy="false">)</mo></mrow><mrow><mi>k</mi><mi>!</mi></mrow></mfrac><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo>−</mo><msub><mi>x</mi><mn>0</mn></msub><msup><mo symmetric="false" stretchy="false">)</mo><mi>k</mi></msup></math></div></p>
<p>Compared with the tangent line, <span class="math inline"><math display="inline"><msub><mi>T</mi><mn>2</mn></msub></math></span> adds one more term: the <span class="math inline"><math display="inline"><mi>k</mi><mo>=</mo><mn>2</mn></math></span> term of the sum. To obtain
this term, we need to compute <span class="math inline"><math display="inline"><mi>f</mi><mi>′</mi><mi>′</mi></math></span>.</p>
<p>In this case, the second derivative is easy to compute because the derivatives of
<span class="math inline"><math display="inline"><mi>sin</mi></math></span> and <span class="math inline"><math display="inline"><mi>cos</mi></math></span> cycle: <span class="math inline"><math display="inline"><mi>sin</mi><mo>→</mo><mi>cos</mi><mo>→</mo><mo>−</mo><mi>sin</mi><mo>→</mo><mo>−</mo><mi>cos</mi><mo>→</mo><mi>sin</mi></math></span>.</p>
<p><span class="math inline"><math display="inline"><mi>f</mi><mi>′</mi><mi>′</mi><mo>=</mo><mfrac><mrow><msup><mi>d</mi><mn>2</mn></msup><mi>f</mi></mrow><mrow><mi>d</mi><msup><mi>x</mi><mn>2</mn></msup></mrow></mfrac><mo>=</mo><mfrac><mrow><msup><mi>d</mi><mn>2</mn></msup></mrow><mrow><mi>d</mi><msup><mi>x</mi><mn>2</mn></msup></mrow></mfrac><mo symmetric="true" stretchy="true" minsize="1.2em" maxsize="1.2em">(</mo><mi>sin</mi><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo><mo>+</mo><mi>cos</mi><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo><mo symmetric="true" stretchy="true" minsize="1.2em" maxsize="1.2em">)</mo></math></span></p>
<p><span class="math inline"><math display="inline"><mi>f</mi><mi>′</mi><mi>′</mi><mo>=</mo><mo>−</mo><mi>sin</mi><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo><mo>−</mo><mi>cos</mi><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo></math></span></p>
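<p>The quadratic approximation, or <em>second-order</em> Taylor approximation, around <span class="math inline"><math display="inline"><msub><mi>x</mi><mn>0</mn></msub><mo>=</mo><mn>0</mn></math></span> is:</p>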
<p><div class="math display"><math display="block"><msub><mi>T</mi><mn>2</mn></msub><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><mn>1</mn><mo>+</mo><mi>x</mi><mo>+</mo><mfrac><mrow><mi>f</mi><mi>′</mi><mi>′</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>x</mi><mn>0</mn></msub><mo symmetric="false" stretchy="false">)</mo></mrow><mrow><mn>2</mn></mrow></mfrac><msup><mi>x</mi><mn>2</mn></msup></math></div></p>
<pre><code class="code lang-python"><span class="source python"><span class="meta statement import python"><span class="keyword control import python">import</span> <span class="meta generic-name python">numpy</span> <span class="keyword control import as python">as</span> <span class="meta generic-name python">np</span></span>

<span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> Evaluating f&#39;&#39; at x=0
</span><span class="keyword operator arithmetic python">-</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">np</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">sin</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric integer decimal python">0</span></span><span class="punctuation section arguments end python">)</span></span> <span class="keyword operator arithmetic python">-</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">np</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">cos</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric integer decimal python">0</span></span><span class="punctuation section arguments end python">)</span></span>
<span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> -1.0</span>
</span></code></pre>
<p><div class="math display"><math display="block"><msub><mi>T</mi><mn>2</mn></msub><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><mn>1</mn><mo>+</mo><mi>x</mi><mo>−</mo><mfrac><mrow><mn>1</mn></mrow><mrow><mn>2</mn></mrow></mfrac><msup><mi>x</mi><mn>2</mn></msup></math></div>
This matches the pictures above: the coefficient of the quadratic term is negative,
so the approximation is a concave-down parabola, as the plot below confirms.</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">plt</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">plot</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">x_jnp</span></span><span class="punctuation separator arguments python">,</span> <span class="constant numeric integer decimal python">1</span> <span class="keyword operator arithmetic python">+</span> <span class="meta qualified-name python"><span class="meta generic-name python">x_jnp</span></span> <span class="keyword operator arithmetic python">-</span> <span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>5</span> <span class="keyword operator arithmetic python">*</span> <span class="meta qualified-name python"><span class="meta generic-name python">x_jnp</span></span> <span class="keyword operator arithmetic python">*</span><span class="keyword operator arithmetic python">*</span> <span class="constant numeric integer decimal python">2</span><span class="punctuation separator arguments python">,</span> 
         <span class="variable parameter python">color</span><span class="keyword operator assignment python">=</span><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">forestgreen<span class="punctuation definition string end python">&#39;</span></span></span><span class="punctuation separator arguments python">,</span>
         <span class="variable parameter python">linestyle</span><span class="keyword operator assignment python">=</span><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">-<span class="punctuation definition string end python">&#39;</span></span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre><center>
<img src="/img/taylor-post/example_5-4_03.png">
</center>
<p>We can keep repeating this process, adding higher-order terms to get an increasingly
accurate approximation around <span class="math inline"><math display="inline"><msub><mi>x</mi><mn>0</mn></msub></math></span>, at the cost of computing
higher-order derivatives.</p>
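<p>A minimal sketch of this idea, not the post's own <code>taylor_approx</code>: assuming the running example is <span class="math inline"><math display="inline"><mi>f</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mo>=</mo><mi>sin</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mo>+</mo><mi>cos</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></math></span> (which matches the derivatives above), its derivatives at 0 repeat the pattern 1, 1, -1, -1, so a Taylor polynomial of any order can be written down directly. The helpers <code>deriv_at_zero</code> and <code>taylor_poly</code> are hypothetical names chosen for illustration.</p>
<pre><code class="code lang-python">import math


def f(x):
    # Assumed running example: f(x) = sin(x) + cos(x)
    return math.sin(x) + math.cos(x)


def deriv_at_zero(k):
    # k-th derivative of sin + cos evaluated at 0; the values repeat every 4 steps
    return [1.0, 1.0, -1.0, -1.0][k % 4]


def taylor_poly(x, n):
    # n-th order Taylor polynomial around x0 = 0: sum over k of f^(k)(0) / k! * x^k
    return sum(deriv_at_zero(k) / math.factorial(k) * x**k for k in range(n + 1))


# Approximation error at a fixed point away from x0 = 0, for increasing order n
x = 2.0
for n in (2, 5, 10):
    print(n, abs(f(x) - taylor_poly(x, n)))
</code></pre>
<p>At <code>x = 2.0</code> the error shrinks as the order <code>n</code> grows, but every extra term requires one more derivative.</p>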
<p>The image below is the final reproduction of Figure 5.4. Notice how the tenth-order
Taylor approximation <span class="math inline"><math display="inline"><msub><mi>T</mi><mrow><mn>10</mn></mrow></msub></math></span> (red curve), built from 11 Taylor coefficients, approximates <span class="math inline"><math display="inline"><mi>f</mi></math></span> almost perfectly on the
interval [-4, 4]. Be cautious, though: just as the <em>fifth-order</em> Taylor approximation (green curve)
drifts away from <span class="math inline"><math display="inline"><mi>f</mi></math></span> at both edges of the plot, <span class="math inline"><math display="inline"><msub><mi>T</mi><mrow><mn>10</mn></mrow></msub></math></span> would also drift away
if we widened the x-range of the plot.</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">taylor_approx</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">fun</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">approx_around</span><span class="keyword operator assignment python">=</span><span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">num_coef</span><span class="keyword operator assignment python">=</span><span class="constant numeric integer decimal python">11</span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre><center>
<img src="/img/taylor-post/example_5-4_04.png">
</center>
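<p>To check the claim about <span class="math inline"><math display="inline"><msub><mi>T</mi><mrow><mn>10</mn></mrow></msub></math></span> numerically, the sketch below builds its 11 coefficients and measures the worst-case error on a grid. It again assumes <span class="math inline"><math display="inline"><mi>f</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mo>=</mo><mi>sin</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mo>+</mo><mi>cos</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></math></span> expanded around 0; <code>coefs</code> and <code>max_error</code> are illustrative names, not part of the post's <code>taylor_approx</code>. It relies on NumPy's <code>numpy.polynomial.polynomial.polyval</code>, which evaluates <code>c[0] + c[1]*x + c[2]*x**2 + ...</code> for coefficients in increasing degree.</p>
<pre><code class="code lang-python">import math

import numpy as np
from numpy.polynomial import polynomial as P

# Hypothetical coefficients c_k = f^(k)(0) / k! for the assumed f(x) = sin(x) + cos(x),
# with k = 0..10, so 11 coefficients give the tenth-order polynomial T10
coefs = np.array([[1.0, 1.0, -1.0, -1.0][k % 4] / math.factorial(k) for k in range(11)])


def max_error(lo, hi, num=2001):
    # Largest |f(x) - T10(x)| over an evenly spaced grid on [lo, hi]
    xs = np.linspace(lo, hi, num)
    return np.max(np.abs(np.sin(xs) + np.cos(xs) - P.polyval(xs, coefs)))


print(max_error(-4.0, 4.0))  # small: T10 stays close to f on this interval
print(max_error(-8.0, 8.0))  # much larger: T10 drifts away from f near the edges
</code></pre>
<p>On [-4, 4] the maximum error stays small, while on [-8, 8] it grows by orders of magnitude, which is the divergence at the edges described above.</p>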
<p>This section leaves us with some open questions:</p>
<ol>
<li>
<p>How can we differentiate any <span class="math inline"><math display="inline"><mi>f</mi></math></span>, regardless of its complexity, without relying on
manual computations?</p>
</li>
<li>
<p>How can we express the differentiation operations in code?</p>
</li>
<li>
<p>How can we extend Taylor approximation to multivariate functions
(i.e. <span class="math inline"><math display="inline"><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>x</mi><mn>1</mn></msub><mo>,</mo><mi>…</mi><mo>,</mo><msub><mi>x</mi><mi>n</mi></msub><mo symmetric="false" stretchy="false">)</mo></math></span>) and, more generally, to everything that involves gradients?</p>
</li>
</ol>
<h3>2. Introducing Automatic Differentiation with JAX</h3>
<p><a href="https://jax.readthedocs.io/en/latest/index.html" target="_blank">JAX</a> is a Python library that
combines a NumPy-like interface, automatic differentiation, and high-performance
execution, compiling code with XLA to run on CPUs, GPUs, and TPUs.</p>
<p>In this section, we will focus on the fundamentals of JAX to illustrate how to perform
automatic differentiation and understand how JAX operates at a high level.</p>
<ol>
<li><code>jax.grad()</code>: given a scalar-valued function <span class="math inline"><math display="inline"><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo></math></span> implemented in code, it returns a new function
that computes its gradient (for a scalar input, the derivative <span class="math inline"><math display="inline"><mi>f</mi><mi>′</mi><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo></math></span>).</li>
<li><code>jax.vmap()</code>: vectorizes a function, such as one returned by <code>jax.grad</code>, so that it is applied across a batch (array) of inputs.</li>
<li><code>jax.jit()</code>: compiles a function just-in-time with XLA to accelerate its computations.</li>
</ol>
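<p>As a minimal sketch of these three transformations (assuming the running example <span class="math inline"><math display="inline"><mi>f</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mo>=</mo><mi>sin</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo><mo>+</mo><mi>cos</mi><mo stretchy="false">(</mo><mi>x</mi><mo stretchy="false">)</mo></math></span> from the Taylor section; <code>df</code>, <code>d2f</code> and <code>fast_d2f</code> are illustrative names), JAX recovers the derivative values we computed by hand earlier, without any manual differentiation:</p>
<pre><code class="code lang-python">from jax import grad, jit, vmap
import jax.numpy as jnp


def f(x):
    # Assumed running example from the Taylor section: f(x) = sin(x) + cos(x)
    return jnp.sin(x) + jnp.cos(x)


df = grad(f)    # new function mapping x to f'(x); grad needs a scalar-valued f
d2f = grad(df)  # grad returns a function, so it can be applied again: f''(x)

print(df(0.0))   # f'(0) = cos(0) - sin(0) = 1
print(d2f(0.0))  # f''(0) = -sin(0) - cos(0) = -1, no manual derivation needed

# vmap maps the scalar second derivative over an array of inputs
xs = jnp.linspace(-4.0, 4.0, 5)
print(vmap(d2f)(xs))

# jit compiles the vectorized function with XLA; later calls reuse the compiled code
fast_d2f = jit(vmap(d2f))
print(fast_d2f(xs))
</code></pre>
<p>Because <code>jax.grad</code> returns an ordinary Python function, nesting it is all it takes to obtain higher-order derivatives, the same trick the tanh example relies on.</p>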
<p>Let's start with an <a href="https://github.com/hips/autograd" target="_blank">example</a> from the README of <code>autograd</code>,
the predecessor of <code>JAX</code>: differentiating the hyperbolic tangent function.</p>
<p>The example is illustrative because it makes clear that <code>jax.grad</code> transforms functions:
it takes a function and returns a new one, so we can apply it repeatedly to get higher-order derivatives. Look at the code!</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta statement import python"><span class="keyword control import python">import</span> <span class="meta generic-name python">jax</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">numpy</span> <span class="keyword control import as python">as</span> <span class="meta generic-name python">jnp</span></span>
<span class="meta statement import python"><span class="keyword control import python">import</span> <span class="meta generic-name python">matplotlib</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">pyplot</span> <span class="keyword control import as python">as</span> <span class="meta generic-name python">plt</span></span>
<span class="meta statement import python"><span class="keyword control import from python">from</span></span><span class="meta statement import python"><span class="meta import-source python"> <span class="meta import-path python"><span class="meta import-name python">jax</span></span> <span class="meta statement import python"><span class="keyword control import python">import</span></span></span></span><span class="meta statement import python"></span><span class="meta statement import python"> <span class="meta generic-name python">grad</span><span class="punctuation separator import-list python">,</span> <span class="meta generic-name python">vmap</span><span class="punctuation separator import-list python">,</span> <span class="meta generic-name python">jit</span></span>

<span class="meta annotation python"><span class="punctuation definition annotation python">@</span><span class="meta qualified-name python"><span class="variable annotation python">jit</span></span></span>
<span class="meta function python"><span class="storage type function python">def</span> <span class="entity name function python"><span class="meta generic-name python">tanh</span></span></span><span class="meta function parameters python"><span class="punctuation section parameters begin python">(</span></span><span class="meta function parameters python"><span class="variable parameter python">x</span><span class="punctuation section parameters end python">)</span></span><span class="meta function python"><span class="punctuation section function begin python">:</span></span>
    <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> tanh(x) = (1 - e^(-2x)) / (1 + e^(-2x))
</span>    <span class="meta qualified-name python"><span class="meta generic-name python">y</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">exp</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="keyword operator arithmetic python">-</span><span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span>0</span> <span class="keyword operator arithmetic python">*</span> <span class="meta qualified-name python"><span class="meta generic-name python">x</span></span></span><span class="punctuation section arguments end python">)</span></span>
    <span class="keyword control flow return python">return</span> <span class="meta group python"><span class="punctuation section group begin python">(</span><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>0</span> <span class="keyword operator arithmetic python">-</span> <span class="meta qualified-name python"><span class="meta generic-name python">y</span></span><span class="punctuation section group end python">)</span></span> <span class="keyword operator arithmetic python">/</span> <span class="meta group python"><span class="punctuation section group begin python">(</span><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>0</span> <span class="keyword operator arithmetic python">+</span> <span class="meta qualified-name python"><span class="meta generic-name python">y</span></span><span class="punctuation section group end python">)</span></span>

<span class="meta qualified-name python"><span class="meta generic-name python">x</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">linspace</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="keyword operator arithmetic python">-</span><span class="constant numeric integer decimal python">7</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric integer decimal python">7</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric integer decimal python">200</span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">fig</span></span>, <span class="meta qualified-name python"><span class="meta generic-name python">ax</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">plt</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">subplots</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric integer decimal python">1</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric integer decimal python">2</span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">sharey</span><span class="keyword operator assignment python">=</span><span class="constant language python">True</span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">figsize</span><span class="keyword operator assignment python">=</span><span class="meta group python"><span class="punctuation section group begin python">(</span><span class="constant numeric float python">12<span class="punctuation separator decimal python">.</span>5</span><span class="punctuation separator tuple python">,</span> <span class="constant numeric float python">4<span class="punctuation separator decimal python">.</span>5</span><span class="punctuation section group end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">ax</span></span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="constant numeric integer decimal python">0</span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span><span class="meta function-call python"><span class="meta qualified-name python"><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">plot</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span><span class="punctuation separator arguments python">,</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">tanh</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">linestyle</span><span class="keyword operator assignment python">=</span><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">-<span class="punctuation definition string end python">&#39;</span></span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter 
python">color</span><span class="keyword operator assignment python">=</span><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">black<span class="punctuation definition string end python">&#39;</span></span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">ax</span></span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="constant numeric integer decimal python">0</span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span><span class="meta function-call python"><span class="meta qualified-name python"><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">axis</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">off<span class="punctuation definition string end python">&#39;</span></span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">ax</span></span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="constant numeric integer decimal python">1</span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span><span class="meta function-call python"><span class="meta qualified-name python"><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">plot</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span><span class="punctuation separator arguments python">,</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">tanh</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator arguments python">,</span>
           <span class="meta qualified-name python"><span class="meta generic-name python">x</span></span><span class="punctuation separator arguments python">,</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">vmap</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">tanh</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="meta function-call python"><span class="punctuation section arguments begin python">(</span></span><span class="meta function-call python"><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator arguments python">,</span>                               <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> 1st derivative
</span>           <span class="meta qualified-name python"><span class="meta generic-name python">x</span></span><span class="punctuation separator arguments python">,</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">vmap</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">tanh</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="meta function-call python"><span class="punctuation section arguments begin python">(</span></span><span class="meta function-call python"><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator arguments python">,</span>                         <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> 2nd derivative
</span>           <span class="meta qualified-name python"><span class="meta generic-name python">x</span></span><span class="punctuation separator arguments python">,</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">vmap</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">tanh</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="meta function-call python"><span class="punctuation section arguments begin python">(</span></span><span class="meta function-call python"><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation 
separator arguments python">,</span>                   <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> 3rd derivative
</span>           <span class="meta qualified-name python"><span class="meta generic-name python">x</span></span><span class="punctuation separator arguments python">,</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">vmap</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">tanh</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="meta function-call python"><span class="punctuation section 
arguments begin python">(</span></span><span class="meta function-call python"><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator arguments python">,</span>             <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> 4th derivative
</span>           <span class="meta qualified-name python"><span class="meta generic-name python">x</span></span><span class="punctuation separator arguments python">,</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">vmap</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">tanh</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end 
python">)</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="meta function-call python"><span class="punctuation section arguments begin python">(</span></span><span class="meta function-call python"><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator arguments python">,</span>       <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> 5th derivative
</span>           <span class="meta qualified-name python"><span class="meta generic-name python">x</span></span><span class="punctuation separator arguments python">,</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">vmap</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span 
class="meta generic-name python">tanh</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="meta function-call python"><span class="punctuation section arguments begin python">(</span></span><span class="meta function-call python"><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span> <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> 6th derivative
</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">plt</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">suptitle</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">tanh and its higher-order derivatives (up to 6th)<span class="punctuation definition string end python">&#39;</span></span></span></span><span class="punctuation section arguments end python">)</span></span>           
<span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">fig</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">text</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>75</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric float python"><span class="punctuation separator decimal python">.</span>02</span><span class="punctuation separator arguments python">,</span> <span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">Source: Autograd README<span class="punctuation definition string end python">&quot;</span></span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">size</span><span class="keyword operator assignment python">=</span><span class="constant numeric integer decimal python">9</span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">style</span><span class="keyword operator assignment python">=</span><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">italic<span class="punctuation definition string end python">&#39;</span></span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">ax</span></span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="constant numeric integer decimal python">1</span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span><span class="meta function-call python"><span class="meta qualified-name python"><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">axis</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">off<span class="punctuation definition string end python">&#39;</span></span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre><center>
<img src="/img/taylor-post/autograd_tanh_example.png">
</center>
<p>As the code shows, <code>grad(tanh)</code> returns a new
function that computes the first derivative of <code>tanh</code>. In math notation,
<code>jax.grad</code> takes a function <span class="math inline"><math display="inline"><mi>f</mi></math></span> and returns its gradient, defined component-wise as follows.</p>
<p><div class="math display"><math display="block"><mo symmetric="false" stretchy="false">(</mo><mi>∇</mi><mi>f</mi><mo symmetric="false" stretchy="false">)</mo><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><msub><mo symmetric="false" stretchy="false">)</mo><mi>i</mi></msub><mo>=</mo><mfrac><mrow><mi>∂</mi><mi>f</mi></mrow><mrow><mi>∂</mi><msub><mi>x</mi><mi>i</mi></msub></mrow></mfrac><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo></math></div></p>
<p>Another interesting point is that <code>jax.grad</code> composes with itself and with other
transformations, so we can chain them, as in the nested <code>grad</code> calls that compute the
higher-order derivatives of <code>tanh</code>.</p>
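<p>Because every transformation takes a function and returns a function, they can be stacked in whatever order we need. A minimal sketch, using <code>jnp.sin</code> instead of <code>tanh</code> only because its second derivative, <code>-jnp.sin(x)</code>, is easy to check:</p>
<pre><code class="code lang-python">import jax
import jax.numpy as jnp

# Nesting grad twice gives the second derivative of sin, i.e. -sin(x)
d2_sin = jax.grad(jax.grad(jnp.sin))

# The result is an ordinary function, so it can be transformed again:
# vmap vectorizes it over a batch and jit compiles the whole pipeline
fast_d2_sin = jax.jit(jax.vmap(d2_sin))

xs = jnp.linspace(-2.0, 2.0, 5)
print(fast_d2_sin(xs))
print(-jnp.sin(xs))
</code></pre>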
<p>What is the purpose of <code>jax.vmap</code>? Suppose we want the function that
<code>jax.grad</code> returns to behave like <code>tanh</code> does on an array input:</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">tanh</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">array</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator list python">,</span><span class="constant numeric float python">3<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation section list end python">]</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="keyword operator comparison python">&gt;</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">DeviceArray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>46211717</span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>7615941</span> <span class="punctuation separator list python">,</span> <span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>90514827</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">dtype</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">float32</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre>
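<p>However, the function returned by <code>grad(tanh)</code> cannot simply be applied to an array: <code>jax.grad</code> only differentiates functions with a scalar output, so calling it on an array input raises an error. We need to vectorize the function.</p>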
<pre><code class="code lang-python"><span class="source python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">tanh</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="meta function-call python"><span class="punctuation section arguments begin python">(</span></span><span class="meta function-call python"><span class="meta function-call arguments python"><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>0</span></span><span class="punctuation section arguments end python">)</span></span>
<span class="comment line number-sign python"><span class="punctuation definition comment python">#</span>grad(tanh)(jnp.arange(10)) # throw an error
</span><span class="keyword operator comparison python">&gt;</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">DeviceArray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>39322388</span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">dtype</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">float32</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre>
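<p>The error comes from <code>jax.grad</code> requiring a scalar output. The sketch below uses <code>jnp.tanh</code>, so its numbers need not match the outputs above; it shows the error and one workaround that only holds for elementwise functions: differentiate the sum of the outputs.</p>
<pre><code class="code lang-python">import jax
import jax.numpy as jnp

xs = jnp.array([1.0, 2.0, 3.0])

# jax.grad only accepts functions with a scalar output; jnp.tanh(xs) has shape (3,)
try:
    jax.grad(jnp.tanh)(xs)
    raised = False
except TypeError as err:
    raised = True
    print("TypeError:", err)

# Workaround valid only for elementwise functions: the i-th partial derivative
# of sum_j tanh(x_j) is tanh'(x_i), so grad of the sum returns every derivative
d_tanh = jax.grad(lambda x: jnp.sum(jnp.tanh(x)))(xs)
print(d_tanh)
print(1.0 - jnp.tanh(xs) ** 2)  # closed form of tanh'
</code></pre>
<p>The trick relies on each output depending only on its own input; for general functions, <code>jax.vmap</code> is the tool to reach for.</p>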
<p>Therefore, if we want to evaluate the gradient at multiple values and receive an
array with the results, we can use <code>jax.vmap</code> to transform the function into a
vectorized version. Like <code>grad</code>, <code>vmap</code> is a function transformation: it takes a function and returns a new one.</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">vmap</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">tanh</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="meta function-call python"><span class="punctuation section arguments begin python">(</span></span><span class="meta function-call python"><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">array</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">3<span class="punctuation separator decimal python">.</span>0</span><span 
class="punctuation section list end python">]</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="keyword operator comparison python">&gt;</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">DeviceArray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>39322388</span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>20998716</span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>09035333</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">dtype</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">float32</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre>
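<p>By default, <code>jax.vmap</code> maps over the leading axis of every argument. Its <code>in_axes</code> parameter controls this per argument, which is handy when only some inputs are batched. A sketch with a hypothetical <code>scaled_tanh(x, a)</code> where only <code>x</code> is batched:</p>
<pre><code class="code lang-python">import jax
import jax.numpy as jnp

def scaled_tanh(x, a):
    return jnp.tanh(a * x)

# grad differentiates with respect to the first argument (x) by default
d_dx = jax.grad(scaled_tanh)

xs = jnp.array([1.0, 2.0, 3.0])
# in_axes=(0, None): map over axis 0 of x, pass the same a to every call
batched = jax.vmap(d_dx, in_axes=(0, None))
out = batched(xs, 0.5)

expected = 0.5 * (1.0 - jnp.tanh(0.5 * xs) ** 2)  # chain rule: a * tanh'(a * x)
print(out)
</code></pre>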
<p>We can code a naive implementation of <code>jax.vmap</code> to
understand what happens behind the scenes. Beware that the real function is far more complex:
rather than looping in Python, it pushes the batch dimension through each primitive operation. Still, this sketch is enough to illustrate the main functionality.</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta function python"><span class="storage type function python">def</span> <span class="entity name function python"><span class="meta generic-name python">my_vmap</span></span></span><span class="meta function parameters python"><span class="punctuation section parameters begin python">(</span></span><span class="meta function parameters python"><span class="variable parameter python">x</span><span class="punctuation separator parameters python">,</span> <span class="variable parameter python">grad</span><span class="punctuation section parameters end python">)</span></span><span class="meta function python"><span class="punctuation section function begin python">:</span></span>
  <span class="comment block documentation python"><span class="punctuation definition comment begin python">&quot;&quot;&quot;</span>A basic implementation of vmap to vectorize a function<span class="punctuation definition comment end python">&quot;&quot;&quot;</span></span> 
  <span class="meta qualified-name python"><span class="variable other constant python">FUN</span></span> <span class="keyword operator assignment python">=</span> <span class="meta qualified-name python"><span class="meta generic-name python">grad</span></span>
  <span class="meta qualified-name python"><span class="meta generic-name python">out</span></span> <span class="keyword operator assignment python">=</span> <span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="punctuation section list end python">]</span></span>
  <span class="meta statement for python"><span class="keyword control flow for python">for</span> <span class="meta generic-name python">i</span> <span class="meta statement for python"><span class="keyword control flow for in python">in</span></span></span><span class="meta statement for python"> <span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">range</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">shape</span></span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="constant numeric integer decimal python">0</span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation section block for python">:</span></span>
    <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">out</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">append</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">FUN</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">i</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
  <span class="keyword control flow return python">return</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">array</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">out</span></span></span><span class="punctuation section arguments end python">)</span></span>
  
<span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">my_vmap</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">array</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">3<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation section list end python">]</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator arguments python">,</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">tanh</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="keyword operator comparison python">&gt;</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">DeviceArray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>39322388</span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>20998716</span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>09035333</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">dtype</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">float32</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre>
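<p>Note that <code>jax.vmap</code> takes a function and returns a new one, whereas <code>my_vmap</code> takes the data and the function together. A variant with the same interface as <code>jax.vmap</code> (the name <code>my_vmap_fn</code> is hypothetical, and it still loops in Python) can be checked against the real thing:</p>
<pre><code class="code lang-python">import jax
import jax.numpy as jnp

def my_vmap_fn(fun):
    """Naive sketch: return a function that applies fun to each x[i] along axis 0."""
    def batched(x):
        # A Python loop with one call per slice; jax.vmap avoids this loop entirely
        return jnp.stack([fun(x[i]) for i in range(x.shape[0])])
    return batched

d_tanh = jax.grad(jnp.tanh)
xs = jnp.array([1.0, 2.0, 3.0])

ours = my_vmap_fn(d_tanh)(xs)
theirs = jax.vmap(d_tanh)(xs)
print(ours, theirs)
</code></pre>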
<p>Now you have an idea of how I replicated figure 5.4 from the previous sections, which
requires computing up to a <em>tenth-order</em> Taylor approximation. There is no need
to hand-code the derivatives: I just applied <code>jax.grad</code> ten times, starting from <span class="math inline"><math display="inline"><mi>f</mi></math></span> itself.</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta qualified-name python"><span class="variable other constant python">NABLA</span></span> <span class="keyword operator assignment python">=</span> <span class="meta qualified-name python"><span class="variable other constant python">FUN</span></span>
<span class="meta statement for python"><span class="keyword control flow for python">for</span> <span class="meta generic-name python">i</span> <span class="meta statement for python"><span class="keyword control flow for in python">in</span></span></span><span class="meta statement for python"> <span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">range</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="variable other constant python">NUM</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation section block for python">:</span></span>
  <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> Compute the ith derivative of FUN
</span>  <span class="meta qualified-name python"><span class="variable other constant python">NABLA</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="variable other constant python">NABLA</span></span></span><span class="punctuation section arguments end python">)</span></span>
  <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> Do something like computing the ith taylor coefficient
</span>  <span class="constant language python">...</span>
</span></code></pre>
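<p>The <code>...</code> in the loop above is where the actual work goes. One possible way to fill it in, assuming we want the Taylor coefficients <span class="math inline"><math display="inline"><msub><mi>c</mi><mi>k</mi></msub><mo>=</mo><msup><mi>f</mi><mrow><mo symmetric="false" stretchy="false">(</mo><mi>k</mi><mo symmetric="false" stretchy="false">)</mo></mrow></msup><mo symmetric="false" stretchy="false">(</mo><mi>a</mi><mo symmetric="false" stretchy="false">)</mo><mo>/</mo><mi>k</mi><mo>!</mo></math></span> around a point <em>a</em>, is sketched below. The helpers <code>taylor_coefficients</code> and <code>taylor_poly</code> are hypothetical, not the code behind figure 5.4.</p>
<pre><code class="code lang-python">import math
import jax
import jax.numpy as jnp

def taylor_coefficients(fun, a, order):
    """c_k = fun^(k)(a) / k! for k = 0..order, by applying jax.grad repeatedly."""
    coeffs = []
    nabla = fun  # the 0th derivative is fun itself
    for k in range(order + 1):
        coeffs.append(nabla(a) / math.factorial(k))
        nabla = jax.grad(nabla)  # move on to the next derivative
    return jnp.array(coeffs)

def taylor_poly(coeffs, a, x):
    """Evaluate the truncated series sum_k c_k * (x - a)**k."""
    k = jnp.arange(coeffs.shape[0])
    return jnp.sum(coeffs * (x - a) ** k)

c = taylor_coefficients(jnp.exp, 0.0, 8)
print(c[:4])  # approximately 1, 1, 1/2, 1/6
print(taylor_poly(c, 0.0, 1.0), jnp.exp(1.0))
</code></pre>
<p>Unlike the loop above, this version starts from the function itself, so the 0th coefficient is included.</p>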
<p>For instance, let's plot <code>tanh</code> and its derivatives again, but this time we apply <code>grad</code>
ten times with the loop pattern above, avoiding the boilerplate of the nested calls.</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta qualified-name python"><span class="variable other constant python">NABLA</span></span> <span class="keyword operator assignment python">=</span> <span class="meta qualified-name python"><span class="meta generic-name python">tanh</span></span>
<span class="meta statement for python"><span class="keyword control flow for python">for</span> <span class="meta generic-name python">i</span> <span class="meta statement for python"><span class="keyword control flow for in python">in</span></span></span><span class="meta statement for python"> <span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">range</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric integer decimal python">10</span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation section block for python">:</span></span>
  <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">plt</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">plot</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span><span class="punctuation separator arguments python">,</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">vmap</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="variable other constant python">NABLA</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="meta function-call python"><span class="punctuation section arguments begin python">(</span></span><span class="meta function-call python"><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
  <span class="meta qualified-name python"><span class="variable other constant python">NABLA</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="variable other constant python">NABLA</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">plt</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">axis</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">off<span class="punctuation definition string end python">&#39;</span></span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre><center>
<img src="/img/taylor-post/tanh_upto_10diff.png">
</center>
<p>Computing higher-order derivatives by nesting <code>jax.grad</code> can be computationally expensive. Read the paper <a href="https://openreview.net/pdf?id=SkxEF3FNPH" target="_blank">"Taylor-Mode Automatic Differentiation for Higher-Order Derivatives in JAX"</a>
to learn a more efficient way to compute them. You can find more context
about this problem and the paper's genesis in this
<a href="https://github.com/google/jax/issues/520" target="_blank">discussion</a>.</p>
<p><strong>How are the derivatives computed?</strong> <code>JAX</code> performs automatic differentiation by transforming numerical functions into a directed acyclic graph (DAG) of primitive operations:</p>
<ul>
<li>the leftmost nodes represent the input variables</li>
<li>the middle nodes represent intermediate variables</li>
<li>the rightmost node represents the output (a scalar)</li>
<li>as the name says, there are no cycles in the graph: data always flows from left to right, and although the graph can branch, no edge points back</li>
</ul>
<p>Differentiation is then just an application of the chain rule over the DAG.</p>
<br>
<center>
<img src="/img/taylor-post/autoDidf_internediateVar_diagram.png">
</center>
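<p>To make the chain rule over a DAG concrete, here is a hand-written forward and reverse pass for <span class="math inline"><math display="inline"><mi>y</mi><mo>=</mo><mi>sin</mi><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo><msup><mi>x</mi><mn>2</mn></msup></math></span>, a function picked only for illustration because <em>x</em> branches into two nodes. <code>jax.grad</code> should agree with it.</p>
<pre><code class="code lang-python">import math
import jax
import jax.numpy as jnp

def f_and_grad_by_hand(x):
    # Forward pass: one line per node of the DAG
    v1 = math.sin(x)   # intermediate node fed by x
    v2 = x * x         # x also feeds this node: the graph branches here
    y = v1 * v2        # output node
    # Reverse pass: chain rule from the output back to the input
    dy_dv1 = v2        # d(v1 * v2)/dv1
    dy_dv2 = v1        # d(v1 * v2)/dv2
    # x reaches y along two paths, so their contributions are summed
    dy_dx = dy_dv1 * math.cos(x) + dy_dv2 * 2.0 * x
    return y, dy_dx

y, dy_dx = f_and_grad_by_hand(1.3)
auto = jax.grad(lambda x: jnp.sin(x) * (x * x))(1.3)
print(dy_dx, auto)
</code></pre>
<p>The key step is the final sum: when a variable feeds several nodes, the contributions of all its outgoing paths add up.</p>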
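<p>Once we have all the local derivatives (of each node with respect to its inputs), we start multiplying them, but wait: the order matters.
Suppose we begin multiplying from the square "F", as the diagram above shows. Accumulating the products from the inputs towards the output is called forward mode, and from the output back towards the inputs, reverse mode.
Depending on the problem, mainly on how many inputs and outputs the function has, one order can be far more efficient than the other; for a scalar output with many inputs, reverse mode wins, which is why <code>jax.grad</code> uses it.</p>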
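<p>In JAX, both orders are available. The sketch below uses a small illustrative function from R³ to R² and computes the same Jacobian with <code>jax.jacfwd</code> (forward mode) and <code>jax.jacrev</code> (reverse mode), along with the single products each mode is built from, <code>jax.jvp</code> and <code>jax.vjp</code>:</p>
<pre><code class="code lang-python">import jax
import jax.numpy as jnp

def g(x):
    # An illustrative function from R^3 to R^2
    return jnp.array([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 3.0])

# Forward mode: Jacobian-vector products, one per input direction (batched with vmap)
J_fwd = jax.jacfwd(g)(x)
# Reverse mode: vector-Jacobian products, one per output (batched with vmap)
J_rev = jax.jacrev(g)(x)

# A single forward-mode product J @ v recovers one column of the Jacobian...
_, col0 = jax.jvp(g, (x,), (jnp.array([1.0, 0.0, 0.0]),))
# ...and a single reverse-mode product u @ J recovers one row
_, g_vjp = jax.vjp(g, x)
(row0,) = g_vjp(jnp.array([1.0, 0.0]))

print(J_fwd)
</code></pre>
<p>Forward mode yields one column per pass and reverse mode one row per pass, so the cheaper choice depends on whether the Jacobian has fewer columns (inputs) or fewer rows (outputs).</p>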
<p><code>jax.make_jaxpr</code> produces the jaxpr, JAX's intermediate representation of the traced computation, and it helps us visualise the diagram described above.</p>
<p>The intermediate variables are defined by equations (<code>jaxpr.eqns</code>): each equation applies a primitive operation to its inputs, which can be the function's inputs or other intermediate variables, and produces outputs.</p>
<p>You can read more about <code>jax.make_jaxpr</code> in the <a href="https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html" target="_blank">documentation</a>.</p>
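<p>As a small sketch of that structure, the snippet below traces an illustrative function <code>h(x) = x * x + 3.0 * x</code> (not from the original text) and walks over <code>closed.jaxpr.eqns</code>, printing each equation's primitive, input variables and output variables:</p>
<pre><code class="code lang-python">import jax

def h(x):
    return x * x + 3.0 * x

closed = jax.make_jaxpr(h)(2.0)  # a ClosedJaxpr; closed.jaxpr holds the equations
print(closed)

for eqn in closed.jaxpr.eqns:
    # Each equation applies one primitive to its inputs (variables or literals)
    print(eqn.primitive.name, eqn.invars, eqn.outvars)

names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
</code></pre>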
<p>For instance, we can inspect how JAX decomposes the function <span class="math inline"><math display="inline"><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><msup><mi>x</mi><mn>2</mn></msup><mo>+</mo><mi>exp</mi><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo></math></span>
into intermediate variables.</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta function python"><span class="storage type function python">def</span> <span class="entity name function python"><span class="meta generic-name python">f</span></span></span><span class="meta function parameters python"><span class="punctuation section parameters begin python">(</span></span><span class="meta function parameters python"><span class="variable parameter python">x</span><span class="punctuation section parameters end python">)</span></span><span class="meta function python"><span class="punctuation section function begin python">:</span></span>
  <span class="keyword control flow return python">return</span> <span class="meta qualified-name python"><span class="meta generic-name python">x</span></span><span class="keyword operator arithmetic python">*</span><span class="keyword operator arithmetic python">*</span><span class="constant numeric integer decimal python">2</span> <span class="keyword operator arithmetic python">+</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">exp</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre><pre><code class="code lang-python"><span class="source python"><span class="meta qualified-name python"><span class="meta generic-name python">jax_compu</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">make_jaxpr</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">f</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="meta function-call python"><span class="punctuation section arguments begin python">(</span></span><span class="meta function-call python"><span class="meta function-call arguments python"><span class="constant numeric float python">3<span class="punctuation separator decimal python">.</span>0</span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">jax_compu</span></span>
<span class="keyword operator comparison python">&gt;</span> 
 <span class="meta structure dictionary-or-set python"><span class="punctuation section dictionary-or-set begin python">{</span> <span class="meta function inline python"><span class="storage type function inline python">lambda</span></span><span class="meta function inline python"><span class="meta function inline parameters python"> </span></span>; <span class="meta qualified-name python"><span class="meta generic-name python">a</span></span><span class="punctuation separator key-value python">:</span><span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">f32</span></span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span><span class="punctuation accessor dot python">.</span> <span class="meta qualified-name python"><span class="meta generic-name python">let</span></span>
    <span class="meta qualified-name python"><span class="meta generic-name python">b</span></span><span class="punctuation separator key-value python">:</span><span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">f32</span></span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span> = <span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">integer_pow</span></span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">y</span></span>=<span class="constant numeric integer decimal python">2</span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">a</span></span>
    <span class="meta qualified-name python"><span class="meta generic-name python">c</span></span><span class="punctuation separator key-value python">:</span><span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">f32</span></span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span> = <span class="meta qualified-name python"><span class="meta generic-name python">exp</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">a</span></span>
    <span class="meta qualified-name python"><span class="meta generic-name python">d</span></span><span class="punctuation separator key-value python">:</span><span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">f32</span></span></span><span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span> = <span class="meta qualified-name python"><span class="meta generic-name python">add</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">b</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">c</span></span>
  <span class="keyword operator logical python">in</span> <span class="meta group python"><span class="punctuation section group begin python">(</span><span class="meta qualified-name python"><span class="meta generic-name python">d</span></span><span class="punctuation separator tuple python">,</span><span class="punctuation section group end python">)</span></span> <span class="punctuation section dictionary-or-set end python">}</span></span>
</span></code></pre>
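<p>The derivative itself is also just a graph of primitives. As a small sketch, we can trace <code>jax.grad(f)</code> instead of <code>f</code> and list the primitives it uses. The exact equations printed may differ between JAX versions, so only the presence of <code>exp</code> and the numerical value are relied on here.</p>

```python
import jax
import jax.numpy as jnp


def f(x):
    return x**2 + jnp.exp(x)


# Trace the derivative of f rather than f itself
grad_jaxpr = jax.make_jaxpr(jax.grad(f))(3.0)
print(grad_jaxpr)

# Names of the primitives applied by each equation of the derivative's graph
primitive_names = [eqn.primitive.name for eqn in grad_jaxpr.jaxpr.eqns]
print(primitive_names)

# Analytically, f'(x) = 2x + exp(x)
print(jax.grad(f)(3.0), 2 * 3.0 + jnp.exp(3.0))
```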
<p>We can write a function that extracts and prints each element of the jaxpr of <code>f</code> shown above.</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta function python"><span class="storage type function python">def</span> <span class="entity name function python"><span class="meta generic-name python">describe_jaxpr</span></span></span><span class="meta function parameters python"><span class="punctuation section parameters begin python">(</span></span><span class="meta function parameters python"><span class="variable parameter python">FUN</span><span class="punctuation section parameters end python">)</span></span><span class="meta function python"><span class="punctuation section function begin python">:</span></span>
  <span class="comment block documentation python"><span class="punctuation definition comment begin python">&quot;&quot;&quot;</span>Given a function, print each element of its jaxpr<span class="punctuation definition comment end python">&quot;&quot;&quot;</span></span>
  <span class="meta statement import python"><span class="keyword control import from python">from</span></span><span class="meta statement import python"><span class="meta import-source python"> <span class="meta import-path python"><span class="meta import-name python">inspect</span></span> <span class="meta statement import python"><span class="keyword control import python">import</span></span></span></span><span class="meta statement import python"></span><span class="meta statement import python"> <span class="meta generic-name python">getsource</span></span>
  <span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">print</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">Source function definition:<span class="punctuation definition string end python">&#39;</span></span></span></span><span class="punctuation section arguments end python">)</span></span>
  <span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">print</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">getsource</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="variable other constant python">FUN</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
  <span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">print</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">--------------------------------------------------------------<span class="punctuation definition string end python">&#39;</span></span></span></span><span class="punctuation section arguments end python">)</span></span>
  <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> Evaluate the expression on 0.0 (arbitrary) to get a jaxpr
</span>  <span class="meta qualified-name python"><span class="meta generic-name python">expr</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">make_jaxpr</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="variable other constant python">FUN</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="meta function-call python"><span class="punctuation section arguments begin python">(</span></span><span class="meta function-call python"><span class="meta function-call arguments python"><span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>0</span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">jaxpr</span>
  <span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">print</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">The function has the following inputs, represented as <span class="punctuation definition string end python">&#39;</span></span></span> <span class="keyword operator arithmetic python">+</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="support type python">str</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">expr</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">invars</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
  <span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">print</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">the function has the following constants, represented as <span class="punctuation definition string end python">&#39;</span></span></span> <span class="keyword operator arithmetic python">+</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="support type python">str</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">expr</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">constvars</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
  <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> Get the equation that describes each intermediate variable and extract its info
</span>  <span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">print</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python"><span class="constant character escape python">\n</span>These are the intermediate variables describe by the equations computed along the DAG: <span class="punctuation definition string end python">&#39;</span></span></span></span><span class="punctuation section arguments end python">)</span></span>
  <span class="meta statement for python"><span class="keyword control flow for python">for</span> <span class="meta generic-name python">i</span><span class="punctuation separator target-list python">,</span> <span class="meta generic-name python">eq</span> <span class="meta statement for python"><span class="keyword control flow for in python">in</span></span></span><span class="meta statement for python"> <span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">enumerate</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">expr</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">eqns</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation section block for python">:</span></span>
    <span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">print</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">   <span class="punctuation definition string end python">&#39;</span></span></span> <span class="keyword operator arithmetic python">+</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="support type python">str</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">i</span></span></span><span class="punctuation section arguments end python">)</span></span> <span class="keyword operator arithmetic python">+</span> <span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">. <span class="punctuation definition string end python">&#39;</span></span></span> <span class="keyword operator arithmetic python">+</span> <span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">Obtain <span class="punctuation definition string end python">&#39;</span></span></span> <span class="keyword operator arithmetic python">+</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="support type python">str</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">eq</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">outvars</span></span></span><span class="punctuation section arguments end python">)</span></span> <span class="keyword operator arithmetic python">+</span> <span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python"> applying the primitive <span class="punctuation definition string end python">&#39;</span></span></span> <span class="keyword operator arithmetic python">+</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="support type python">str</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">eq</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">primitive</span></span></span><span class="punctuation section arguments end python">)</span></span> <span class="keyword operator arithmetic python">+</span> <span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python"> with params <span class="punctuation definition string end python">&#39;</span></span></span> <span class="keyword operator arithmetic python">+</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="support type python">str</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">eq</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">params</span></span></span><span class="punctuation section arguments end python">)</span></span> <span class="keyword operator arithmetic python">+</span> <span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python"> on input/s <span class="punctuation definition string end python">&#39;</span></span></span> <span class="keyword operator arithmetic python">+</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="support type python">str</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">eq</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">invars</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
  <span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">print</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python"><span class="constant character escape python">\n</span> The output is: <span class="punctuation definition string end python">&#39;</span></span></span> <span class="keyword operator arithmetic python">+</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="support type python">str</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">expr</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">outvars</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span> 
</span></code></pre><pre><code class="code lang-python"><span class="source python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">describe_jaxpr</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">f</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="keyword operator comparison python">&gt;</span> 
<span class="meta qualified-name python"><span class="meta generic-name python">Source</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">function</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">definition</span></span><span class="punctuation separator annotation variable python">:</span>
<span class="meta function python"><span class="storage type function python">def</span> <span class="entity name function python"><span class="meta generic-name python">f</span></span></span><span class="meta function parameters python"><span class="punctuation section parameters begin python">(</span></span><span class="meta function parameters python"><span class="variable parameter python">x</span><span class="punctuation section parameters end python">)</span></span><span class="meta function python"><span class="punctuation section function begin python">:</span></span>
 <span class="keyword control flow return python">return</span> <span class="meta qualified-name python"><span class="meta generic-name python">x</span></span><span class="keyword operator arithmetic python">*</span><span class="keyword operator arithmetic python">*</span><span class="constant numeric integer decimal python">2</span> <span class="keyword operator arithmetic python">+</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">exp</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span></span><span class="punctuation section arguments end python">)</span></span>

<span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic 
python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span>
<span class="meta qualified-name python"><span class="meta generic-name python">The</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">function</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">has</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">the</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">following</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">inputs</span></span>, <span class="meta qualified-name python"><span class="meta generic-name python">represented</span></span> <span class="meta item-access python"><span class="meta qualified-name python"><span class="invalid illegal name python">as</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">a</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">the</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">function</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">has</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">the</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">following</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">constants</span></span>, <span class="meta qualified-name python"><span class="meta generic-name python">represented</span></span> <span class="meta item-access python"><span class="meta qualified-name python"><span class="invalid illegal name python">as</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span>

<span class="meta qualified-name python"><span class="meta generic-name python">These</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">are</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">the</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">intermediate</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">variables</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">describe</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">by</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">the</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">equations</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">computed</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">along</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">the</span></span> <span class="meta qualified-name python"><span class="variable other constant python">DAG</span></span><span class="punctuation separator annotation variable python">:</span> 
  <span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span></span> <span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">Obtain</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">b</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">applying</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">the</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">primitive</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">integer_pow</span></span> <span class="meta statement with python"><span class="keyword control flow with python">with</span> <span class="meta qualified-name python"><span class="meta generic-name python">params</span></span> <span class="meta structure dictionary-or-set python"><span class="punctuation section dictionary-or-set begin python">{</span><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">y<span class="punctuation definition string end python">&#39;</span></span></span><span class="punctuation separator key-value python">:</span> <span class="constant numeric integer decimal python">2</span><span class="punctuation section dictionary-or-set end python">}</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">on</span></span> <span class="meta 
qualified-name python"><span class="support function builtin python">input</span></span><span class="keyword operator arithmetic python">/</span><span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">s</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">a</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span></span>
  <span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span></span> <span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">Obtain</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">c</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">applying</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">the</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">primitive</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">exp</span></span> <span class="meta statement with python"><span class="keyword control flow with python">with</span> <span class="meta qualified-name python"><span class="meta generic-name python">params</span></span> <span class="meta structure dictionary python"><span class="punctuation section dictionary begin python">{</span><span class="meta empty-dictionary python"><span class="punctuation section dictionary end python">}</span></span></span> <span class="meta qualified-name python"><span class="meta generic-name python">on</span></span> <span class="meta qualified-name python"><span class="support function builtin python">input</span></span><span class="keyword operator arithmetic python">/</span><span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">s</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta 
item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">a</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span></span>
  <span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span></span> <span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">Obtain</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">d</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">applying</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">the</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">primitive</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">add</span></span> <span class="meta statement with python"><span class="keyword control flow with python">with</span> <span class="meta qualified-name python"><span class="meta generic-name python">params</span></span> <span class="meta structure dictionary python"><span class="punctuation section dictionary begin python">{</span><span class="meta empty-dictionary python"><span class="punctuation section dictionary end python">}</span></span></span> <span class="meta qualified-name python"><span class="meta generic-name python">on</span></span> <span class="meta qualified-name python"><span class="support function builtin python">input</span></span><span class="keyword operator arithmetic python">/</span><span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">s</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta 
item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">b</span></span>, <span class="meta qualified-name python"><span class="meta generic-name python">c</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span></span>

<span class="meta qualified-name python"><span class="meta generic-name python">The</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">output</span></span> <span class="keyword operator logical python">is</span><span class="punctuation separator annotation variable python">:</span> <span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="meta qualified-name python"><span class="meta generic-name python">d</span></span><span class="punctuation section list end python">]</span></span>
</span></code></pre>
<p>As in the diagram above, this example uses two intermediate variables to compute the output.</p>
<ul>
<li>The input <span class="math inline"><math display="inline"><mi>x</mi></math></span> is represented by <span class="math inline"><math display="inline"><mi>a</mi></math></span></li>
<li>The first intermediate variable is <span class="math inline"><math display="inline"><mi>b</mi><mo>=</mo><msup><mi>a</mi><mn>2</mn></msup></math></span></li>
<li>Then, the second intermediate variable, <span class="math inline"><math display="inline"><mi>c</mi><mo>=</mo><mi>exp</mi><mo symmetric="false" stretchy="false">(</mo><mi>a</mi><mo symmetric="false" stretchy="false">)</mo></math></span>, also takes <span class="math inline"><math display="inline"><mi>a</mi></math></span> as input</li>
<li>Finally, the output is computed by summing the two intermediate variables: <span class="math inline"><math display="inline"><mi>d</mi><mo>=</mo><mi>b</mi><mo>+</mo><mi>c</mi></math></span>.</li>
</ul>
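<p>The three equations can be made concrete with a minimal plain-Python sketch. It needs no JAX, and the helper name <code>forward_and_grad</code> is hypothetical. The sketch evaluates the DAG forward and then propagates the derivative of the output back to <code>a</code> by hand, visiting the equations in reverse order. It illustrates reverse-mode differentiation on this particular graph under those assumptions. It is not how JAX implements <code>jax.grad</code> internally.</p>

```python
import math

def forward_and_grad(x):
    """Hypothetical helper: evaluate d = a**2 + exp(a) and its derivative by hand."""
    # Forward pass: the same intermediate variables as in the jaxpr above
    a = x
    b = a ** 2        # integer_pow with y=2
    c = math.exp(a)   # exp
    d = b + c         # add
    # Reverse pass: seed the output adjoint with dd/dd = 1
    d_bar = 1.0
    b_bar = d_bar     # d(b + c)/db = 1
    c_bar = d_bar     # d(b + c)/dc = 1
    # a feeds both b and c, so its adjoint sums one term per branch:
    # d(a**2)/da = 2a and d(exp(a))/da = exp(a), which is the stored value c
    a_bar = b_bar * 2.0 * a + c_bar * c
    return d, a_bar

value, grad = forward_and_grad(1.0)
print(value, grad)
```

<p>Because <code>a</code> feeds both branches, its adjoint collects one contribution from each. The result is <code>2x + exp(x)</code>. If JAX is installed, you can cross-check it against <code>jax.grad(f)(1.0)</code>.</p>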
<p>Similarly, we can inspect the jaxpr of <code>jax.grad(f)</code>, the gradient function of <span class="math inline"><math display="inline"><mi>f</mi></math></span>:</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">describe_jaxpr</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">f</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="keyword operator comparison python">&gt;</span> 
<span class="meta qualified-name python"><span class="meta generic-name python">Source</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">function</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">definition</span></span><span class="punctuation separator annotation variable python">:</span>
<span class="meta function python"><span class="storage type function python">def</span> <span class="entity name function python"><span class="meta generic-name python">f</span></span></span><span class="meta function parameters python"><span class="punctuation section parameters begin python">(</span></span><span class="meta function parameters python"><span class="variable parameter python">x</span><span class="punctuation section parameters end python">)</span></span><span class="meta function python"><span class="punctuation section function begin python">:</span></span>
  <span class="keyword control flow return python">return</span> <span class="meta qualified-name python"><span class="meta generic-name python">x</span></span><span class="keyword operator arithmetic python">*</span><span class="keyword operator arithmetic python">*</span><span class="constant numeric integer decimal python">2</span> <span class="keyword operator arithmetic python">+</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">exp</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span></span><span class="punctuation section arguments end python">)</span></span>

<span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic 
python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span><span class="keyword operator arithmetic python">-</span>
<span class="meta qualified-name python"><span class="meta generic-name python">The</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">function</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">has</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">the</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">following</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">inputs</span></span>, <span class="meta qualified-name python"><span class="meta generic-name python">represented</span></span> <span class="meta item-access python"><span class="meta qualified-name python"><span class="invalid illegal name python">as</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">a</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">the</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">function</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">has</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">the</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">following</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">constants</span></span>, <span class="meta qualified-name python"><span class="meta generic-name python">represented</span></span> <span class="meta item-access python"><span class="meta qualified-name python"><span class="invalid illegal name python">as</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span>

<span class="meta qualified-name python"><span class="meta generic-name python">These</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">are</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">the</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">intermediate</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">variables</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">describe</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">by</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">the</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">equations</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">computed</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">along</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">the</span></span> <span class="meta qualified-name python"><span class="variable other constant python">DAG</span></span><span class="punctuation separator annotation variable python">:</span> 
   <span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span></span> <span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">Obtain</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">b</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">applying</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">the</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">primitive</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">integer_pow</span></span> <span class="meta statement with python"><span class="keyword control flow with python">with</span> <span class="meta qualified-name python"><span class="meta generic-name python">params</span></span> <span class="meta structure dictionary-or-set python"><span class="punctuation section dictionary-or-set begin python">{</span><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">y<span class="punctuation definition string end python">&#39;</span></span></span><span class="punctuation separator key-value python">:</span> <span class="constant numeric integer decimal python">2</span><span class="punctuation section dictionary-or-set end python">}</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">on</span></span> <span class="meta 
qualified-name python"><span class="support function builtin python">input</span></span><span class="keyword operator arithmetic python">/</span><span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">s</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">a</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span></span>
   <span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span></span> <span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">Obtain</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">c</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">applying</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">the</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">primitive</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">integer_pow</span></span> <span class="meta statement with python"><span class="keyword control flow with python">with</span> <span class="meta qualified-name python"><span class="meta generic-name python">params</span></span> <span class="meta structure dictionary-or-set python"><span class="punctuation section dictionary-or-set begin python">{</span><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">y<span class="punctuation definition string end python">&#39;</span></span></span><span class="punctuation separator key-value python">:</span> <span class="constant numeric integer decimal python">1</span><span class="punctuation section dictionary-or-set end python">}</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">on</span></span> <span class="meta 
qualified-name python"><span class="support function builtin python">input</span></span><span class="keyword operator arithmetic python">/</span><span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">s</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">a</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span></span>
   <span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span></span> <span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">Obtain</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">d</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">applying</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">the</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">primitive</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">mul</span></span> <span class="meta statement with python"><span class="keyword control flow with python">with</span> <span class="meta qualified-name python"><span class="meta generic-name python">params</span></span> <span class="meta structure dictionary python"><span class="punctuation section dictionary begin python">{</span><span class="meta empty-dictionary python"><span class="punctuation section dictionary end python">}</span></span></span> <span class="meta qualified-name python"><span class="meta generic-name python">on</span></span> <span class="meta qualified-name python"><span class="support function builtin python">input</span></span><span class="keyword operator arithmetic python">/</span><span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">s</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta 
item-access arguments python"><span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span>0</span>, <span class="meta qualified-name python"><span class="meta generic-name python">c</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span></span>
   <span class="constant numeric float python">3<span class="punctuation separator decimal python">.</span></span> <span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">Obtain</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">e</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">applying</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">the</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">primitive</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">exp</span></span> <span class="meta statement with python"><span class="keyword control flow with python">with</span> <span class="meta qualified-name python"><span class="meta generic-name python">params</span></span> <span class="meta structure dictionary python"><span class="punctuation section dictionary begin python">{</span><span class="meta empty-dictionary python"><span class="punctuation section dictionary end python">}</span></span></span> <span class="meta qualified-name python"><span class="meta generic-name python">on</span></span> <span class="meta qualified-name python"><span class="support function builtin python">input</span></span><span class="keyword operator arithmetic python">/</span><span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">s</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta 
item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">a</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span></span>
   <span class="constant numeric float python">4<span class="punctuation separator decimal python">.</span></span> <span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">Obtain</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta qualified-name python"><span class="variable language python">_</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">applying</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">the</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">primitive</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">add</span></span> <span class="meta statement with python"><span class="keyword control flow with python">with</span> <span class="meta qualified-name python"><span class="meta generic-name python">params</span></span> <span class="meta structure dictionary python"><span class="punctuation section dictionary begin python">{</span><span class="meta empty-dictionary python"><span class="punctuation section dictionary end python">}</span></span></span> <span class="meta qualified-name python"><span class="meta generic-name python">on</span></span> <span class="meta qualified-name python"><span class="support function builtin python">input</span></span><span class="keyword operator arithmetic python">/</span><span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">s</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta 
item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">b</span></span>, <span class="meta qualified-name python"><span class="meta generic-name python">e</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span></span>
   <span class="constant numeric float python">5<span class="punctuation separator decimal python">.</span></span> <span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">Obtain</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">f</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">applying</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">the</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">primitive</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">mul</span></span> <span class="meta statement with python"><span class="keyword control flow with python">with</span> <span class="meta qualified-name python"><span class="meta generic-name python">params</span></span> <span class="meta structure dictionary python"><span class="punctuation section dictionary begin python">{</span><span class="meta empty-dictionary python"><span class="punctuation section dictionary end python">}</span></span></span> <span class="meta qualified-name python"><span class="meta generic-name python">on</span></span> <span class="meta qualified-name python"><span class="support function builtin python">input</span></span><span class="keyword operator arithmetic python">/</span><span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">s</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta 
item-access arguments python"><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>0</span>, <span class="meta qualified-name python"><span class="meta generic-name python">e</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span></span>
   <span class="constant numeric float python">6<span class="punctuation separator decimal python">.</span></span> <span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">Obtain</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">g</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">applying</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">the</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">primitive</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">mul</span></span> <span class="meta statement with python"><span class="keyword control flow with python">with</span> <span class="meta qualified-name python"><span class="meta generic-name python">params</span></span> <span class="meta structure dictionary python"><span class="punctuation section dictionary begin python">{</span><span class="meta empty-dictionary python"><span class="punctuation section dictionary end python">}</span></span></span> <span class="meta qualified-name python"><span class="meta generic-name python">on</span></span> <span class="meta qualified-name python"><span class="support function builtin python">input</span></span><span class="keyword operator arithmetic python">/</span><span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">s</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta 
item-access arguments python"><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>0</span>, <span class="meta qualified-name python"><span class="meta generic-name python">d</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span></span>
   <span class="constant numeric float python">7<span class="punctuation separator decimal python">.</span></span> <span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">Obtain</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">h</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">applying</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">the</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">primitive</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">add_any</span></span> <span class="meta statement with python"><span class="keyword control flow with python">with</span> <span class="meta qualified-name python"><span class="meta generic-name python">params</span></span> <span class="meta structure dictionary python"><span class="punctuation section dictionary begin python">{</span><span class="meta empty-dictionary python"><span class="punctuation section dictionary end python">}</span></span></span> <span class="meta qualified-name python"><span class="meta generic-name python">on</span></span> <span class="meta qualified-name python"><span class="support function builtin python">input</span></span><span class="keyword operator arithmetic python">/</span><span class="meta item-access python"><span class="meta qualified-name python"><span class="meta generic-name python">s</span></span></span> <span class="meta item-access python"><span class="punctuation section brackets begin python">[</span></span><span class="meta 
item-access arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">f</span></span>, <span class="meta qualified-name python"><span class="meta generic-name python">g</span></span></span><span class="meta item-access python"><span class="punctuation section brackets end python">]</span></span></span>

 <span class="meta qualified-name python"><span class="meta generic-name python">The</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">output</span></span> <span class="keyword operator logical python">is</span><span class="punctuation separator annotation variable python">:</span> <span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="meta qualified-name python"><span class="meta generic-name python">h</span></span><span class="punctuation section list end python">]</span></span>
</span></code></pre>
<p>Notice that the number of intermediate variables increases. For instance, equation (2) is a primitive added by the differentiation itself, coming from the rule <span class="math inline"><math display="inline"><mi>∂</mi><mo>/</mo><mi>∂</mi><mi>x</mi><mo symmetric="false" stretchy="false">(</mo><msup><mi>x</mi><mn>2</mn></msup><mo symmetric="false" stretchy="false">)</mo><mo>→</mo><mn>2</mn><mi>x</mi></math></span>.</p>
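<p>One way to see these extra primitives yourself is to compare the jaxpr of a function with the jaxpr of its gradient. The sketch below uses <code>jax.make_jaxpr</code> on the same <code>x ** 2</code> example; the exact primitive names printed can vary between JAX versions, so treat the printed output as illustrative rather than as a fixed trace.</p>

```python
import jax


def f(x):
    # Same toy function as in the text: x^2
    return x ** 2


# Trace the function and its derivative into jaxprs (JAX's intermediate representation)
jaxpr_f = jax.make_jaxpr(f)(3.0)
jaxpr_df = jax.make_jaxpr(jax.grad(f))(3.0)

print(jaxpr_f)
print(jaxpr_df)

# The derivative's jaxpr contains more equations (intermediate variables)
print(len(jaxpr_f.jaxpr.eqns), len(jaxpr_df.jaxpr.eqns))
```

<p>The gradient evaluates to 2x, so at x = 3 it returns 6.0.</p>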
<p>Further resources on automatic differentiation and JAX:</p>
<ol>
<li><a href="https://www.youtube.com/watch?v=wG_nF1awSS" target="_blank">What is automatic differentiation? (video)</a></li>
<li><a href="http://matpalm.com/blog/ymxb_pod_slice" target="_blank">JAX tutorial by Mat Kelcey</a>, which shows more about
parallel computing with JAX</li>
<li><a href="http://videolectures.net/deeplearning2017_johnson_automatic_differentiation/" target="_blank">Automatic Differentiation, Deep Learning Summer School Montreal 2017 (Matthew Johnson)</a>; see also this <a href="https://www.youtube.com/watch?v=mVf3HJ6gND" target="_blank">JAX seminar</a></li>
</ol>
<h3>3. Taylor Approximation with two variables</h3>
<p>Now we consider the setting where the function is multivariate:</p>
<p><div class="math display"><math display="block"><mi>f</mi><mo>:</mo><msup><mi mathvariant="normal">ℝ</mi><mi>D</mi></msup><mo>⟶</mo><mi mathvariant="normal">ℝ</mi></math></div>
<div class="math display"><math display="block"><mspace width="1em" /><mspace width="1em" /><mspace width="1em" /><mspace width="1em" /><mspace width="1em" /><mspace width="1em" /><mi>x</mi><mo>↦</mo><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo><mo>,</mo><mspace width="1em" /><mi>x</mi><mo>∈</mo><msup><mi mathvariant="normal">ℝ</mi><mi>D</mi></msup></math></div></p>
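<p>By Definition 5.8 in MML, the Taylor polynomial of degree <span class="math inline"><math display="inline"><mi>n</mi></math></span> of <span class="math inline"><math display="inline"><mi>f</mi></math></span> at <span class="math inline"><math display="inline"><msub><mi>x</mi><mn>0</mn></msub></math></span> is defined as:</p>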
<p><div class="math display"><math display="block"><msub><mi>T</mi><mi>n</mi></msub><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><munderover><mo movablelimits="false">∑</mo><mrow><mi>k</mi><mo>=</mo><mn>0</mn></mrow><mrow><mi>n</mi></mrow></munderover><mfrac><mrow><msubsup><mi>D</mi><mi>x</mi><mi>k</mi></msubsup><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>x</mi><mn>0</mn></msub><mo symmetric="false" stretchy="false">)</mo></mrow><mrow><mi>k</mi><mi>!</mi></mrow></mfrac><msup><mi>𝜹</mi><mi>k</mi></msup></math></div></p>
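<p>The vector <span class="math inline"><math display="inline"><mi>𝜹</mi><mo>:</mo><mo>=</mo><mi>x</mi><mo>−</mo><msub><mi>x</mi><mn>0</mn></msub></math></span> is the difference between <span class="math inline"><math display="inline"><mi>x</mi></math></span> and <span class="math inline"><math display="inline"><msub><mi>x</mi><mn>0</mn></msub></math></span>, where <span class="math inline"><math display="inline"><msub><mi>x</mi><mn>0</mn></msub></math></span> is the point around which the approximation is made.</p>
<p><span class="math inline"><math display="inline"><msubsup><mi>D</mi><mi>x</mi><mi>k</mi></msubsup><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>x</mi><mn>0</mn></msub><mo symmetric="false" stretchy="false">)</mo></math></span> and <span class="math inline"><math display="inline"><msup><mi>𝜹</mi><mi>k</mi></msup></math></span> are tensors of order k, i.e., k-dimensional arrays; each term of the sum contracts the two over all k indices.</p>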
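<p>To make the tensor contraction concrete, here is a minimal second-order sketch in JAX. The function <code>f</code> below is a hypothetical quadratic (not taken from MML), chosen so that the degree-2 Taylor polynomial is exact. <code>jax.grad</code> and <code>jax.hessian</code> supply the derivative tensors at x<sub>0</sub>, and <code>jnp.einsum</code> contracts each with the matching power of δ before dividing by k!.</p>

```python
import jax
import jax.numpy as jnp


def f(v):
    # Hypothetical quadratic R^2 -> R, so its 2nd-order Taylor polynomial is exact
    return v[0] ** 2 + 2.0 * v[0] * v[1] + 3.0 * v[1] ** 2


def taylor2(f, x0, x):
    """Degree-2 Taylor polynomial of f around x0, evaluated at x."""
    delta = x - x0
    d1 = jax.grad(f)(x0)       # D^1 f(x0), shape (D,)
    d2 = jax.hessian(f)(x0)    # D^2 f(x0), shape (D, D)
    delta2 = jnp.einsum("i,j->ij", delta, delta)  # delta^2 = delta (x) delta
    # Contract each derivative tensor with delta^k and divide by k!
    return (
        f(x0)
        + jnp.einsum("i,i->", d1, delta)
        + jnp.einsum("ij,ij->", d2, delta2) / 2.0
    )


x0 = jnp.array([1.0, 2.0])
x = jnp.array([1.5, 1.0])
print(taylor2(f, x0, x), f(x))  # both 8.25
```

<p>For a non-polynomial f the two values would differ, and the gap shrinks as x approaches x<sub>0</sub>.</p>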
<br>
<blockquote class="twitter-tweet" data-theme="dark"><p lang="en" dir="ltr"><a href="https://twitter.com/hardmaru/status/1326054980134973442?s=21">November 10, 2020</a></blockquote> <script async src="https://platform.twitter.com/widgets.js" charset="utf-8"></script>
<br>
<p>If <span class="math inline"><math display="inline"><mi>𝜹</mi><mo>∈</mo><msup><mi mathvariant="normal">ℝ</mi><mn>4</mn></msup></math></span>, we obtain <span class="math inline"><math display="inline"><msup><mi>𝜹</mi><mn>2</mn></msup><mo>:</mo><mo>=</mo><mi>𝜹</mi><mo>⨂</mo><mi>𝜹</mi><mo>=</mo><mi>𝜹</mi><msup><mi>𝜹</mi><mi>T</mi></msup><mo>∈</mo><msup><mi mathvariant="normal">ℝ</mi><mrow><mn>4</mn><mo>×</mo><mn>4</mn></mrow></msup></math></span>:</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta qualified-name python"><span class="meta generic-name python">delta</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">arange</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric integer decimal python">4</span></span><span class="punctuation section arguments end python">)</span></span>   <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> this is [0, 1, 2, 3]
</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">einsum</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">i,j<span class="punctuation definition string end python">&#39;</span></span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">delta</span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">delta</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">DeviceArray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span> <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span> <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span> <span class="constant numeric integer decimal python">0</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>
             <span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span> <span class="constant numeric integer decimal python">1</span><span class="punctuation separator list python">,</span> <span class="constant numeric integer decimal python">2</span><span class="punctuation separator list python">,</span> <span class="constant numeric integer decimal python">3</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>
             <span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span> <span class="constant numeric integer decimal python">2</span><span class="punctuation separator list python">,</span> <span class="constant numeric integer decimal python">4</span><span class="punctuation separator list python">,</span> <span class="constant numeric integer decimal python">6</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>
             <span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span> <span class="constant numeric integer decimal python">3</span><span class="punctuation separator list python">,</span> <span class="constant numeric integer decimal python">6</span><span class="punctuation separator list python">,</span> <span class="constant numeric integer decimal python">9</span><span class="punctuation section list end python">]</span></span><span class="punctuation section list end python">]</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">dtype</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">int32</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre>
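<p>The same outer product can be written in a few equivalent ways. This sketch compares the explicit einsum form <code>"i,j->ij"</code> with <code>jnp.outer</code> and with plain broadcasting of a column vector against a row vector; all three should give δδ<sup>T</sup>.</p>

```python
import jax.numpy as jnp

delta = jnp.arange(4)  # [0, 1, 2, 3]

# Three equivalent ways to build delta^2 = delta (x) delta = delta delta^T
outer_einsum = jnp.einsum("i,j->ij", delta, delta)  # explicit output indices
outer_fn = jnp.outer(delta, delta)
outer_broadcast = delta[:, None] * delta[None, :]   # column vector times row vector

print(jnp.array_equal(outer_einsum, outer_fn))
print(jnp.array_equal(outer_einsum, outer_broadcast))
```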
<p>Similarly, <span class="math inline"><math display="inline"><msup><mi>𝜹</mi><mn>3</mn></msup><mo>:</mo><mo>=</mo><mi>𝜹</mi><mo>⨂</mo><mi>𝜹</mi><mo>⨂</mo><mi>𝜹</mi><mo>∈</mo><msup><mi mathvariant="normal">ℝ</mi><mrow><mn>4</mn><mo>×</mo><mn>4</mn><mo>×</mo><mn>4</mn></mrow></msup></math></span>:</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">einsum</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">i,j,k<span class="punctuation definition string end python">&#39;</span></span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">delta</span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">delta</span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">delta</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">DeviceArray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="meta structure list python"><span class="punctuation section list begin python">[</span> <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">0</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>
              <span class="meta structure list python"><span class="punctuation section list begin python">[</span> <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">0</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>
              <span class="meta structure list python"><span class="punctuation section list begin python">[</span> <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">0</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>
              <span class="meta structure list python"><span class="punctuation section list begin python">[</span> <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">0</span><span class="punctuation section list end python">]</span></span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>

             <span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="meta structure list python"><span class="punctuation section list begin python">[</span> <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">0</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>
              <span class="meta structure list python"><span class="punctuation section list begin python">[</span> <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">1</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">2</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">3</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>
              <span class="meta structure list python"><span class="punctuation section list begin python">[</span> <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">2</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">4</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">6</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>
              <span class="meta structure list python"><span class="punctuation section list begin python">[</span> <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">3</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">6</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">9</span><span class="punctuation section list end python">]</span></span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>

             <span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="meta structure list python"><span class="punctuation section list begin python">[</span> <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">0</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>
              <span class="meta structure list python"><span class="punctuation section list begin python">[</span> <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">2</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">4</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">6</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>
              <span class="meta structure list python"><span class="punctuation section list begin python">[</span> <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">4</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">8</span><span class="punctuation separator list python">,</span> <span class="constant numeric integer decimal python">12</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>
              <span class="meta structure list python"><span class="punctuation section list begin python">[</span> <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">6</span><span class="punctuation separator list python">,</span> <span class="constant numeric integer decimal python">12</span><span class="punctuation separator list python">,</span> <span class="constant numeric integer decimal python">18</span><span class="punctuation section list end python">]</span></span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>

             <span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="meta structure list python"><span class="punctuation section list begin python">[</span> <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">0</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>
              <span class="meta structure list python"><span class="punctuation section list begin python">[</span> <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">3</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">6</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">9</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>
              <span class="meta structure list python"><span class="punctuation section list begin python">[</span> <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">6</span><span class="punctuation separator list python">,</span> <span class="constant numeric integer decimal python">12</span><span class="punctuation separator list python">,</span> <span class="constant numeric integer decimal python">18</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>
              <span class="meta structure list python"><span class="punctuation section list begin python">[</span> <span class="constant numeric integer decimal python">0</span><span class="punctuation separator list python">,</span>  <span class="constant numeric integer decimal python">9</span><span class="punctuation separator list python">,</span> <span class="constant numeric integer decimal python">18</span><span class="punctuation separator list python">,</span> <span class="constant numeric integer decimal python">27</span><span class="punctuation section list end python">]</span></span><span class="punctuation section list end python">]</span></span><span class="punctuation section list end python">]</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">dtype</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">int32</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre>
<p>For instance, the last entry of the 4×4×4 array is 27, computed as
<code>delta[3]*delta[3]*delta[3]</code> (3×3×3). Likewise, the bottom-right element of the
third 4×4 slice is 18, obtained as <code>delta[2]*delta[3]*delta[3]</code> (2×3×3).</p>
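<p>A quick way to confirm this indexing rule is to read the entries directly. In the sketch below, entry <code>[i, j, k]</code> of δ<sup>3</sup> is <code>delta[i] * delta[j] * delta[k]</code>; it also checks that δ<sup>3</sup> is the outer product of δ<sup>2</sup> with δ, which is what the recursive definition suggests.</p>

```python
import jax.numpy as jnp

delta = jnp.arange(4)                               # [0, 1, 2, 3]
delta2 = jnp.einsum("i,j->ij", delta, delta)        # delta^2, shape (4, 4)
delta3 = jnp.einsum("i,j,k->ijk", delta, delta, delta)  # delta^3, shape (4, 4, 4)

# Entry [i, j, k] of delta^3 is delta[i] * delta[j] * delta[k]
last = delta3[3, 3, 3]                 # 3 * 3 * 3
third_slice_corner = delta3[2, 3, 3]   # bottom-right of the third 4x4 slice: 2 * 3 * 3
print(int(last), int(third_slice_corner))  # 27 18

# delta^3 is also the outer product of delta^2 with delta
print(jnp.array_equal(delta3, jnp.einsum("ij,k->ijk", delta2, delta)))
```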
<p>The Einstein summation convention, implemented in <code>jnp.einsum</code>, lets you
express many array operations in index notation. When no output indices are given, as in <code>'i,j'</code>, the output keeps the indices that appear only once, in alphabetical order, which is why <code>'i,j'</code> produces an outer product. See this <a href="https://www.youtube.com/watch?v=pkVwUVEHmfI" target="_blank">video</a> for a detailed explanation, as well as the
<a href="https://numpy.org/doc/stable/reference/generated/numpy.einsum.html" target="_blank">NumPy documentation</a>.</p>
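<p>Beyond outer products, a handful of einsum strings cover many everyday array operations. The arrays below are arbitrary small examples chosen for illustration; each line names the equivalent non-einsum operation in its comment.</p>

```python
import jax.numpy as jnp

A = jnp.arange(6.0).reshape(2, 3)
B = jnp.arange(12.0).reshape(3, 4)
v = jnp.arange(3.0)
S = jnp.arange(9.0).reshape(3, 3)

matmul = jnp.einsum("ij,jk->ik", A, B)  # sum over the shared index j: A @ B
matvec = jnp.einsum("ij,j->i", A, v)    # matrix-vector product: A @ v
transpose = jnp.einsum("ij->ji", A)     # swap the indices: A.T
total = jnp.einsum("ij->", A)           # sum over every index: A.sum()
trace = jnp.einsum("ii->", S)           # repeated index on one operand: trace of S

print(matmul.shape, matvec.shape, transpose.shape, total, trace)
```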
<p>Let's code Example 5.15 from MML: first we derive the expansion by hand, and then
we use JAX to check that we reach the same result.</p>
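<p>As a sketch of the manual step, we pick the expansion point (x<sub>0</sub>, y<sub>0</sub>) = (1, 2) for illustration and write the partial derivatives of x² + 2xy + y³ by hand: ∂/∂x = 2x + 2y, ∂/∂y = 2x + 3y², the second derivatives 2, 2 and 6y, and the single non-zero third derivative ∂³/∂y³ = 6. Because the function is a cubic polynomial, the third-order Taylor polynomial should reproduce it exactly, which gives us something to check. The helper <code>taylor3_g</code> is a hypothetical name introduced here, in plain Python with no JAX.</p>

```python
def g(x, y):
    """Function used in the example 5.15 in MML"""
    return x ** 2 + 2 * x * y + y ** 3


def taylor3_g(x, y, x0=1.0, y0=2.0):
    """Hypothetical helper: 3rd-order Taylor polynomial of g around (x0, y0),
    built from hand-derived partial derivatives."""
    dx, dy = x - x0, y - y0
    # Order 0
    t0 = g(x0, y0)
    # Order 1: dg/dx = 2x + 2y, dg/dy = 2x + 3y^2
    gx, gy = 2 * x0 + 2 * y0, 2 * x0 + 3 * y0 ** 2
    t1 = gx * dx + gy * dy
    # Order 2: d2g/dx2 = 2, d2g/dxdy = 2, d2g/dy2 = 6y (cross term appears twice)
    gxx, gxy, gyy = 2.0, 2.0, 6 * y0
    t2 = (gxx * dx ** 2 + 2 * gxy * dx * dy + gyy * dy ** 2) / 2
    # Order 3: only d3g/dy3 = 6 is non-zero
    gyyy = 6.0
    t3 = gyyy * dy ** 3 / 6
    return t0 + t1 + t2 + t3


print(taylor3_g(1.0, 2.0))                   # 13.0, the value g(1, 2)
print(taylor3_g(-2.5, 3.0), g(-2.5, 3.0))    # identical, since g is a cubic
```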
<pre><code class="code lang-python"><span class="source python"><span class="meta function python"><span class="storage type function python">def</span> <span class="entity name function python"><span class="meta generic-name python">g</span></span></span><span class="meta function parameters python"><span class="punctuation section parameters begin python">(</span></span><span class="meta function parameters python"><span class="variable parameter python">x</span><span class="punctuation separator parameters python">,</span> <span class="variable parameter python">y</span><span class="punctuation section parameters end python">)</span></span><span class="meta function python"><span class="punctuation section function begin python">:</span></span> 
    <span class="comment block documentation python"><span class="punctuation definition comment begin python">&quot;&quot;&quot;</span>Function used in the example 5.15 in MML<span class="punctuation definition comment end python">&quot;&quot;&quot;</span></span>  
    <span class="keyword control flow return python">return</span> <span class="meta qualified-name python"><span class="meta generic-name python">x</span></span> <span class="keyword operator arithmetic python">*</span><span class="keyword operator arithmetic python">*</span> <span class="constant numeric integer decimal python">2</span> <span class="keyword operator arithmetic python">+</span> <span class="constant numeric integer decimal python">2</span> <span class="keyword operator arithmetic python">*</span> <span class="meta qualified-name python"><span class="meta generic-name python">x</span></span> <span class="keyword operator arithmetic python">*</span> <span class="meta qualified-name python"><span class="meta generic-name python">y</span></span> <span class="keyword operator arithmetic python">+</span> <span class="meta qualified-name python"><span class="meta generic-name python">y</span></span> <span class="keyword operator arithmetic python">*</span><span class="keyword operator arithmetic python">*</span><span class="constant numeric integer decimal python">3</span>
</span></code></pre><pre><code class="code lang-python"><span class="source python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">linspace</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="keyword operator arithmetic python">-</span><span class="constant numeric integer decimal python">5</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric integer decimal python">5</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric integer decimal python">50</span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">y</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">linspace</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="keyword operator arithmetic python">-</span><span class="constant numeric integer decimal python">5</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric integer decimal python">5</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric integer decimal python">40</span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">X</span></span>, <span class="meta qualified-name python"><span class="meta generic-name python">Y</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">np</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">meshgrid</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">y</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">Z</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">g</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">X</span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">Y</span></span></span><span class="punctuation section arguments end python">)</span></span>

<span class="meta qualified-name python"><span class="meta generic-name python">fig</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">plt</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">figure</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="variable parameter python">figsize</span> <span class="keyword operator assignment python">=</span> <span class="meta group python"><span class="punctuation section group begin python">(</span><span class="constant numeric float python">7<span class="punctuation separator decimal python">.</span>2</span><span class="punctuation separator tuple python">,</span> <span class="constant numeric float python">4<span class="punctuation separator decimal python">.</span>3</span><span class="punctuation section group end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">ax</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">fig</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">add_subplot</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric integer decimal python">1</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric integer decimal python">1</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric integer decimal python">1</span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">projection</span><span class="keyword operator assignment python">=</span><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">3d<span class="punctuation definition string end python">&#39;</span></span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">ax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">set_title</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted double python"><span class="punctuation definition string begin python">&quot;</span></span></span><span class="meta string python"><span class="string quoted double python">$g(x,y)=x^2+2xy+y^3$<span class="punctuation definition string end python">&quot;</span></span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">ax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">plot_surface</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">X</span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">Y</span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">Z</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">rstride</span> <span class="keyword operator assignment python">=</span> <span class="constant numeric integer decimal python">3</span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">cstride</span> <span class="keyword operator assignment python">=</span> <span class="constant numeric integer decimal python">3</span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">cmap</span> <span class="keyword operator assignment python">=</span> <span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">cividis<span class="punctuation definition string end python">&#39;</span></span></span><span class="punctuation separator arguments python">,</span>
                <span class="variable parameter python">antialiased</span><span class="keyword operator assignment python">=</span><span class="constant language python">False</span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">alpha</span><span class="keyword operator assignment python">=</span><span class="constant numeric float python"><span class="punctuation separator decimal python">.</span>6</span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">ax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">set_xlabel</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">x<span class="punctuation definition string end python">&#39;</span></span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">ax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">set_ylabel</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">y<span class="punctuation definition string end python">&#39;</span></span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">ax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">set_zlabel</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">z<span class="punctuation definition string end python">&#39;</span></span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">ax</span><span class="punctuation accessor dot python">.</span><span class="meta generic-name python">zaxis</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">set_major_locator</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python">
<span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">plt</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">MultipleLocator</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric integer decimal python">60</span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">plt</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">subplots_adjust</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="variable parameter python">left</span><span class="keyword operator assignment python">=</span><span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>0</span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">ax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">view_init</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric integer decimal python">15</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric integer decimal python">45</span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre><center>
<img src="/img/taylor-post/taylor_2var_example2.png">
</center>
<p>We will start with the first-order Taylor approximation, which gives us a plane.</p>
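<p>We need <span class="math inline"><math display="inline"><mi>∂</mi><mi>g</mi><mo>/</mo><mi>∂</mi><mi>x</mi></math></span> and <span class="math inline"><math display="inline"><mi>∂</mi><mi>g</mi><mo>/</mo><mi>∂</mi><mi>y</mi></math></span>. We collect these partials into a row vector, the gradient
(which, for a scalar function like <span class="math inline"><math display="inline"><mi>g</mi></math></span>, is its Jacobian), and multiply it by <span class="math inline"><math display="inline"><mi>𝜹</mi><mo>=</mo><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo>−</mo><msub><mi>x</mi><mn>0</mn></msub><mo>,</mo><mi>y</mi><mo>−</mo><msub><mi>y</mi><mn>0</mn></msub><mo symmetric="false" stretchy="false">)</mo></math></span>.</p>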
<p><div class="math display"><math display="block"><mi>∂</mi><mi>g</mi><mo>/</mo><mi>∂</mi><mi>x</mi><mo>=</mo><mn>2</mn><mi>x</mi><mo>+</mo><mn>2</mn><mi>y</mi></math></div>
<div class="math display"><math display="block"><mi>∂</mi><mi>g</mi><mo>/</mo><mi>∂</mi><mi>y</mi><mo>=</mo><mn>2</mn><mi>x</mi><mo>+</mo><mn>3</mn><msup><mi>y</mi><mn>2</mn></msup></math></div></p>
<p>Collect the partials into a vector:</p>
<p><div class="math display"><math display="block"><msubsup><mi>D</mi><mi>x</mi><mn>1</mn></msubsup><mo>=</mo><msub><mi>∇</mi><mrow><mi>x</mi><mo>,</mo><mi>y</mi></mrow></msub><mo>=</mo><mo symmetric="true" stretchy="true"minsize="1.2em" maxsize="1.2em">[</mo><mn>2</mn><mi>x</mi><mo>+</mo><mn>2</mn><mi>y</mi><mspace width="1em" /><mn>2</mn><mi>x</mi><mo>+</mo><mn>3</mn><msup><mi>y</mi><mn>2</mn></msup><mo symmetric="true" stretchy="true"minsize="1.2em" maxsize="1.2em">]</mo></math></div></p>
<p>Following the example, we approximate <span class="math inline"><math display="inline"><mi>g</mi></math></span> around <span class="math inline"><math display="inline"><mo symmetric="false" stretchy="false">(</mo><msub><mi>x</mi><mn>0</mn></msub><mo>,</mo><msub><mi>y</mi><mn>0</mn></msub><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><mo symmetric="false" stretchy="false">(</mo><mn>1</mn><mo>,</mo><mn>2</mn><mo symmetric="false" stretchy="false">)</mo></math></span>.</p>
<p>Now we can evaluate these expressions at that point to complete the equation of the plane:</p>
<p><div class="math display"><math display="block"><msub><mi>T</mi><mn>1</mn></msub><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo>,</mo><mi>y</mi><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>x</mi><mn>0</mn></msub><mo>,</mo><msub><mi>y</mi><mn>0</mn></msub><mo symmetric="false" stretchy="false">)</mo><mo>+</mo><mfrac><mrow><msubsup><mi>D</mi><mi>x</mi><mn>1</mn></msubsup><mi>f</mi><mo symmetric="false" stretchy="false">(</mo><msub><mi>x</mi><mn>0</mn></msub><mo>,</mo><msub><mi>y</mi><mn>0</mn></msub><mo symmetric="false" stretchy="false">)</mo></mrow><mrow><mn>1</mn><mi>!</mi></mrow></mfrac><msup><mi>𝜹</mi><mn>1</mn></msup></math></div></p>
<pre><code class="code lang-python"><span class="source python"><span class="meta function python"><span class="storage type function python">def</span> <span class="entity name function python"><span class="meta generic-name python">dg_dx</span></span></span><span class="meta function parameters python"><span class="punctuation section parameters begin python">(</span></span><span class="meta function parameters python"><span class="variable parameter python">x</span><span class="punctuation separator parameters python">,</span><span class="variable parameter python">y</span><span class="punctuation section parameters end python">)</span></span><span class="meta function python"><span class="punctuation section function begin python">:</span></span> 
  <span class="comment block documentation python"><span class="punctuation definition comment begin python">&quot;&quot;&quot;</span>Derivative of g() w.r.t x hand-coded<span class="punctuation definition comment end python">&quot;&quot;&quot;</span></span>
  <span class="keyword control flow return python">return</span> <span class="constant numeric integer decimal python">2</span><span class="keyword operator arithmetic python">*</span><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span> <span class="keyword operator arithmetic python">+</span> <span class="constant numeric integer decimal python">2</span><span class="keyword operator arithmetic python">*</span><span class="meta qualified-name python"><span class="meta generic-name python">y</span></span>
  
<span class="meta function python"><span class="storage type function python">def</span> <span class="entity name function python"><span class="meta generic-name python">dg_dy</span></span></span><span class="meta function parameters python"><span class="punctuation section parameters begin python">(</span></span><span class="meta function parameters python"><span class="variable parameter python">x</span><span class="punctuation separator parameters python">,</span><span class="variable parameter python">y</span><span class="punctuation section parameters end python">)</span></span><span class="meta function python"><span class="punctuation section function begin python">:</span></span> 
  <span class="comment block documentation python"><span class="punctuation definition comment begin python">&quot;&quot;&quot;</span>Derivative of g() w.r.t y hand-coded<span class="punctuation definition comment end python">&quot;&quot;&quot;</span></span>
  <span class="keyword control flow return python">return</span> <span class="constant numeric integer decimal python">2</span><span class="keyword operator arithmetic python">*</span><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span> <span class="keyword operator arithmetic python">+</span> <span class="constant numeric integer decimal python">3</span><span class="keyword operator arithmetic python">*</span><span class="meta qualified-name python"><span class="meta generic-name python">y</span></span><span class="keyword operator arithmetic python">*</span><span class="keyword operator arithmetic python">*</span><span class="constant numeric integer decimal python">2</span>
</span></code></pre><pre><code class="code lang-python"><span class="source python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">print</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">g(1,2): <span class="punctuation definition string end python">&#39;</span></span></span> <span class="keyword operator arithmetic python">+</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="support type python">str</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">g</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric integer decimal python">1</span><span class="punctuation separator arguments python">,</span><span class="constant numeric integer decimal python">2</span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">print</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">dg/dx(1,2): <span class="punctuation definition string end python">&#39;</span></span></span> <span class="keyword operator arithmetic python">+</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="support type python">str</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">dg_dx</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric integer decimal python">1</span><span class="punctuation separator arguments python">,</span><span class="constant numeric integer decimal python">2</span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">print</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">dg/dx(1,2): <span class="punctuation definition string end python">&#39;</span></span></span> <span class="keyword operator arithmetic python">+</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="support type python">str</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">dg_dy</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric integer decimal python">1</span><span class="punctuation separator arguments python">,</span><span class="constant numeric integer decimal python">2</span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="keyword operator comparison python">&gt;</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">g</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric integer decimal python">1</span><span class="punctuation separator arguments python">,</span><span class="constant numeric integer decimal python">2</span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator annotation variable python">:</span> <span class="constant numeric integer decimal python">13</span>
<span class="keyword operator comparison python">&gt;</span> <span class="meta qualified-name python"><span class="meta generic-name python">dg</span></span><span class="keyword operator arithmetic python">/</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">dx</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric integer decimal python">1</span><span class="punctuation separator arguments python">,</span><span class="constant numeric integer decimal python">2</span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator annotation variable python">:</span> <span class="constant numeric integer decimal python">6</span>
<span class="keyword operator comparison python">&gt;</span> <span class="meta qualified-name python"><span class="meta generic-name python">dg</span></span><span class="keyword operator arithmetic python">/</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">dx</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric integer decimal python">1</span><span class="punctuation separator arguments python">,</span><span class="constant numeric integer decimal python">2</span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator annotation variable python">:</span> <span class="constant numeric integer decimal python">14</span>
</span></code></pre>
<p><div class="math display"><math display="block"><msub><mi>T</mi><mn>1</mn></msub><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo>,</mo><mi>y</mi><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><mn>13</mn><mo>+</mo><mo symmetric="false" stretchy="false">[</mo><mn>6</mn><mspace width="1em" /><mn>14</mn><mo symmetric="false" stretchy="false">]</mo><mrow><mo stretchy="true">[</mo><mtable class="menv-arraylike"><mtr><mtd><mi>x</mi><mo>−</mo><mn>1</mn></mtd></mtr><mtr><mtd><mi>y</mi><mo>−</mo><mn>2</mn></mtd></mtr></mtable><mo stretchy="true">]</mo></mrow></math></div></p>
<br>
<p><div class="math display"><math display="block"><msub><mi>T</mi><mn>1</mn></msub><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo>,</mo><mi>y</mi><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><mn>13</mn><mo>+</mo><mn>6</mn><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo>−</mo><mn>1</mn><mo symmetric="false" stretchy="false">)</mo><mo>+</mo><mn>14</mn><mo symmetric="false" stretchy="false">(</mo><mi>y</mi><mo>−</mo><mn>2</mn><mo symmetric="false" stretchy="false">)</mo></math></div></p>
<br>
<p><div class="math display"><math display="block"><msub><mi>T</mi><mn>1</mn></msub><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo>,</mo><mi>y</mi><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><mn>6</mn><mi>x</mi><mo>+</mo><mn>14</mn><mi>y</mi><mo>−</mo><mn>21</mn></math></div></p>
<pre><code class="code lang-python"><span class="source python"><span class="meta function python"><span class="storage type function python">def</span> <span class="entity name function python"><span class="meta generic-name python">g_plane_approx</span></span></span><span class="meta function parameters python"><span class="punctuation section parameters begin python">(</span></span><span class="meta function parameters python"><span class="variable parameter python">x</span><span class="punctuation separator parameters python">,</span> <span class="variable parameter python">y</span><span class="punctuation section parameters end python">)</span></span><span class="meta function python"><span class="punctuation section function begin python">:</span></span>
    <span class="comment block documentation python"><span class="punctuation definition comment begin python">&quot;&quot;&quot;</span>Equation that describe the tangent plane at g(1,2)<span class="punctuation definition comment end python">&quot;&quot;&quot;</span></span>
    <span class="keyword control flow return python">return</span> <span class="constant numeric integer decimal python">6</span><span class="keyword operator arithmetic python">*</span><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span> <span class="keyword operator arithmetic python">+</span> <span class="constant numeric integer decimal python">14</span><span class="keyword operator arithmetic python">*</span><span class="meta qualified-name python"><span class="meta generic-name python">y</span></span> <span class="keyword operator arithmetic python">-</span> <span class="constant numeric integer decimal python">21</span> 
</span></code></pre><center>
<img src="/img/taylor-post/linear_taylor_approx2.png">
</center>
<br>
<p>As in the 1D case, except that the tangent line is now a tangent plane. It is a good
approximation in a small neighbourhood of the point <span class="math inline"><math display="inline"><mo symmetric="false" stretchy="false">(</mo><msub><mi>x</mi><mn>0</mn></msub><mo>,</mo><msub><mi>y</mi><mn>0</mn></msub><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><mo symmetric="false" stretchy="false">(</mo><mn>1</mn><mo>,</mo><mn>2</mn><mo symmetric="false" stretchy="false">)</mo></math></span>.
However, a plane cannot capture the curvature of <span class="math inline"><math display="inline"><mi>g</mi></math></span>, so the error grows as we move away from that point.</p>
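<p>To make this concrete, the sketch below measures the gap between <code>g</code> and its tangent plane near and far from (1, 2). The definition of <code>g</code> is not part of this section; we assume g(x, y) = x² + 2xy + y³, which is consistent with the tangent plane above (value 13 and gradient (6, 14) at (1, 2)) and with the Hessian entries 2, 2 and 6y used below.</p>

```python
def g(x, y):
    # Assumed form of g, reconstructed from the derivatives in this section
    return x**2 + 2*x*y + y**3

def g_plane_approx(x, y):
    # Tangent plane of g at (1, 2), as above
    return 6*x + 14*y - 21

def plane_error(x, y):
    # Absolute gap between g and its first-order (linear) approximation
    return abs(g(x, y) - g_plane_approx(x, y))

err_near = plane_error(1.1, 2.1)  # close to (1, 2)
err_far = plane_error(2.0, 3.0)   # one unit away in each direction
print(err_near, err_far)          # about 0.091 and 10.0
```

<p>Near the expansion point the error is below 0.1, while one unit away in each direction it is already 10.</p>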
<p>How can we compute the Jacobian with autodiff instead? For a scalar-valued function such as g, the Jacobian is just the gradient, and
the function <code>jax.grad</code> spares us from hand-coding the derivatives.</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">asarray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">grad</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">g</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">argnums</span><span class="keyword operator assignment python">=</span><span class="meta group python"><span class="punctuation section group begin python">(</span><span class="constant numeric integer decimal python">0</span><span class="punctuation separator tuple python">,</span><span class="constant numeric integer decimal python">1</span><span class="punctuation section group end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="meta function-call python"><span class="punctuation section arguments begin python">(</span></span><span class="meta function-call python"><span class="meta function-call arguments python"><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span>0</span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="keyword operator comparison python">&gt;</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">DeviceArray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span> <span class="constant numeric float python">6<span class="punctuation separator decimal python">.</span></span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">14<span class="punctuation separator decimal python">.</span></span><span class="punctuation section list end python">]</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">dtype</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">float32</span></span>
</span><span class="punctuation section arguments end python">)</span></span></span></code></pre>
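<p>As a sanity check, the sketch below compares <code>jax.grad</code> with hand-coded partial derivatives at a few points. It again assumes g(x, y) = x² + 2xy + y³ (reconstructed from the derivatives in this section), whose partials are ∂g/∂x = 2x + 2y and ∂g/∂y = 2x + 3y².</p>

```python
import jax
import jax.numpy as jnp

def g(x, y):
    # Assumed form of g (reconstructed from the gradient and Hessian in this section)
    return x**2 + 2*x*y + y**3

def grad_g_by_hand(x, y):
    # Hand-coded partial derivatives of the assumed g
    return jnp.array([2*x + 2*y, 2*x + 3*y**2])

grad_g = jax.grad(g, argnums=(0, 1))  # returns a tuple (dg/dx, dg/dy)

for x, y in [(1.0, 2.0), (0.0, 0.0), (-1.5, 0.5)]:
    auto = jnp.asarray(grad_g(x, y))  # stack the tuple into a flat array
    print((x, y), auto, jnp.allclose(auto, grad_g_by_hand(x, y)))
```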
<p>Another way to get the Jacobian is <code>jax.jacfwd</code>:</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">asarray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">jacfwd</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">g</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">argnums</span><span class="keyword operator assignment python">=</span><span class="meta group python"><span class="punctuation section group begin python">(</span><span class="constant numeric integer decimal python">0</span><span class="punctuation separator tuple python">,</span><span class="constant numeric integer decimal python">1</span><span class="punctuation section group end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="meta function-call python"><span class="punctuation section arguments begin python">(</span></span><span class="meta function-call python"><span class="meta function-call arguments python"><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric float python">2<span 
class="punctuation separator decimal python">.</span>0</span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre>
<p><code>jax.jacfwd</code> stands for "Jacobian, forward": it computes the Jacobian with forward-mode
autodiff, applying the chain rule from the inputs towards the output. <code>jax.jacrev</code> returns the
same result using reverse-mode autodiff, which traverses the computation graph backwards. For a function
as simple as g the choice does not matter. It does matter for efficiency when many variables are involved:
both return a Jacobian of the same shape, but forward mode is typically cheaper when the Jacobian is "tall"
(more outputs than inputs) and reverse mode when it is "wide" (more inputs than outputs).</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">asarray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">jacrev</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">g</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">argnums</span><span class="keyword operator assignment python">=</span><span class="meta group python"><span class="punctuation section group begin python">(</span><span class="constant numeric integer decimal python">0</span><span class="punctuation separator tuple python">,</span><span class="constant numeric integer decimal python">1</span><span class="punctuation section group end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="meta function-call python"><span class="punctuation section arguments begin python">(</span></span><span class="meta function-call python"><span class="meta function-call arguments python"><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span></span><span class="punctuation separator arguments python">,</span> <span class="constant numeric float python">2<span 
class="punctuation separator decimal python">.</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre>
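<p>To see that both functions build the same matrix and differ only in cost, here is a sketch with two hypothetical functions (not from the post): one with a wide Jacobian (three inputs, one output) and one with a tall Jacobian (one input, three outputs). Roughly speaking, <code>jacfwd</code> builds the Jacobian one column per forward pass and <code>jacrev</code> one row per backward pass, which is why reverse mode suits the wide case and forward mode the tall one.</p>

```python
import jax
import jax.numpy as jnp

# Hypothetical example functions (not from the post), chosen only for their Jacobian shapes.
def wide(v):
    # Three inputs, one output: the Jacobian is 1 x 3 ("wide")
    return jnp.sum(v**2, keepdims=True)

def tall(s):
    # One input, three outputs: the Jacobian is 3 x 1 ("tall")
    return jnp.stack([s[0], s[0]**2, s[0]**3])

v = jnp.array([1.0, 2.0, 2.0])
s = jnp.array([2.0])

for f, x in [(wide, v), (tall, s)]:
    J_fwd = jax.jacfwd(f)(x)  # forward mode: one JVP per input column
    J_rev = jax.jacrev(f)(x)  # reverse mode: one VJP per output row
    # Same matrix and same shape; only the cost of building it differs.
    print(f.__name__, J_fwd.shape, jnp.allclose(J_fwd, J_rev))
```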
<p>The argument <code>argnums</code> specifies which positional arguments to differentiate with respect to.
We pass a tuple with both arguments of <code>g(x,y)</code>, <em>i.e. we want the full Jacobian, with the derivatives w.r.t. argument 0 (x) and argument 1 (y)</em>.</p>
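<p>A short sketch of how <code>argnums</code> changes the return value, using the same assumed g(x, y) = x² + 2xy + y³: an int selects a single argument and returns one derivative, while a tuple returns a tuple of derivatives, one per listed argument. That tuple is why the examples in this section wrap the result in <code>jnp.asarray</code>.</p>

```python
import jax

def g(x, y):
    # Assumed form of g, reconstructed from this section's derivatives
    return x**2 + 2*x*y + y**3

dg_dx = jax.grad(g)(1.0, 2.0)                 # default argnums=0: only dg/dx
dg_dy = jax.grad(g, argnums=1)(1.0, 2.0)      # an int selects a single argument
both = jax.grad(g, argnums=(0, 1))(1.0, 2.0)  # a tuple gives a tuple of partials

print(dg_dx, dg_dy, both)
```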
<p>For example, let's compute the Jacobian of <span class="math inline"><math display="inline"><msup><mi>x</mi><mn>2</mn></msup><mo>+</mo><mn>3</mn><mi>y</mi><mo>+</mo><msup><mi>z</mi><mn>2</mn></msup></math></span> and evaluate it
at the point <span class="math inline"><math display="inline"><mo symmetric="false" stretchy="false">(</mo><mn>1.0</mn><mo>,</mo><mn>2.0</mn><mo>,</mo><mn>2.0</mn><mo symmetric="false" stretchy="false">)</mo></math></span>.</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">asarray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python">
            <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">jacfwd</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function inline python"><span class="storage type function inline python">lambda</span></span><span class="meta function inline python"><span class="meta function inline parameters python"> <span class="variable parameter python">x</span><span class="punctuation separator parameters python">,</span> <span class="variable parameter python">y</span><span class="punctuation separator parameters python">,</span> <span class="variable parameter python">z</span></span><span class="punctuation section function begin python">:</span></span> <span class="meta qualified-name python"><span class="meta generic-name python">x</span></span><span class="keyword operator arithmetic python">*</span><span class="keyword operator arithmetic python">*</span><span class="constant numeric integer decimal python">2</span> <span class="keyword operator arithmetic python">+</span> <span class="constant numeric integer decimal python">3</span><span class="keyword operator arithmetic python">*</span><span class="meta qualified-name python"><span class="meta generic-name python">y</span></span> <span class="keyword operator arithmetic python">+</span> <span class="meta qualified-name python"><span class="meta generic-name python">z</span></span><span class="keyword operator arithmetic python">*</span><span class="keyword operator arithmetic python">*</span><span class="constant numeric integer decimal python">2</span><span class="punctuation separator arguments python">,</span> 
                                         <span class="variable parameter python">argnums</span><span class="keyword operator assignment python">=</span><span class="meta group python"><span class="punctuation section group begin python">(</span><span class="constant numeric integer decimal python">0</span><span class="punctuation separator tuple python">,</span><span class="constant numeric integer decimal python">1</span><span class="punctuation separator tuple python">,</span><span class="constant numeric integer decimal python">2</span><span class="punctuation section group end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="meta function-call python"><span class="punctuation section arguments begin python">(</span></span><span class="meta function-call python"><span class="meta function-call arguments python"><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span>0</span></span><span class="punctuation section arguments end python">)</span></span>
            </span><span class="punctuation section arguments end python">)</span></span>
<span class="keyword operator comparison python">&gt;</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">DeviceArray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span></span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">3<span class="punctuation separator decimal python">.</span></span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">4<span class="punctuation separator decimal python">.</span></span><span class="punctuation section list end python">]</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">dtype</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">float32</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre>
<p><em>Note: <code>jnp.asarray</code> stacks the tuple of partial derivatives (one per entry of <code>argnums</code>) into a single flat array.</em></p>
<p>Can we go further and compute the Hessian?</p>
<p>We compute the second-order derivatives of <span class="math inline"><math display="inline"><mi>g</mi></math></span> and collect them into the <span class="math inline"><math display="inline"><mi>H</mi></math></span> matrix.</p>
<br>
<p><div class="math display"><math display="block"><mi>H</mi><mo>=</mo><mrow><mo stretchy="true">(</mo><mtable class="menv-arraylike"><mtr><mtd><mfrac><mrow><msup><mi>∂</mi><mn>2</mn></msup><mi>g</mi></mrow><mrow><mi>∂</mi><msup><mi>x</mi><mn>2</mn></msup></mrow></mfrac><mo>=</mo><mn>2</mn></mtd><mtd><mfrac><mrow><msup><mi>∂</mi><mn>2</mn></msup><mi>g</mi></mrow><mrow><mi>∂</mi><mi>x</mi><mi>∂</mi><mi>y</mi></mrow></mfrac><mo>=</mo><mn>2</mn></mtd></mtr><mtr><mtd><mfrac><mrow><msup><mi>∂</mi><mn>2</mn></msup><mi>g</mi></mrow><mrow><mi>∂</mi><mi>y</mi><mi>∂</mi><mi>x</mi></mrow></mfrac><mo>=</mo><mn>2</mn></mtd><mtd><mfrac><mrow><msup><mi>∂</mi><mn>2</mn></msup><mi>g</mi></mrow><mrow><mi>∂</mi><msup><mi>y</mi><mn>2</mn></msup></mrow></mfrac><mo>=</mo><mn>6</mn><mi>y</mi></mtd></mtr></mtable><mo stretchy="true">)</mo></mrow></math></div></p>
<p>Three entries of <span class="math inline"><math display="inline"><mi>H</mi></math></span> are constants; only the lower-right one, <span class="math inline"><math display="inline"><mn>6</mn><mi>y</mi></math></span>, depends on the point. We can compute <span class="math inline"><math display="inline"><mi>H</mi></math></span> by applying <code>jacfwd</code> twice and evaluating at (1, 2), where the lower-right entry becomes 6 · 2 = 12.</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta qualified-name python"><span class="meta generic-name python">H</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">asarray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">jacfwd</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">jacfwd</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">g</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">argnums</span><span class="keyword operator assignment python">=</span><span class="meta group python"><span class="punctuation section group begin python">(</span><span class="constant numeric integer decimal python">0</span><span class="punctuation separator tuple python">,</span><span class="constant numeric integer decimal python">1</span><span 
class="punctuation section group end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">argnums</span><span class="keyword operator assignment python">=</span><span class="meta group python"><span class="punctuation section group begin python">(</span><span class="constant numeric integer decimal python">0</span><span class="punctuation separator tuple python">,</span><span class="constant numeric integer decimal python">1</span><span class="punctuation section group end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="meta function-call python"><span class="punctuation section arguments begin python">(</span></span><span class="meta function-call python"><span class="meta function-call arguments python"><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span></span><span class="punctuation separator arguments python">,</span> <span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta qualified-name python"><span class="meta generic-name python">H</span></span>
<span class="keyword operator comparison python">&gt;</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">DeviceArray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="meta structure list python"><span class="punctuation section list begin python">[</span> <span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span></span><span class="punctuation separator list python">,</span>  <span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span></span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>
               <span class="meta structure list python"><span class="punctuation section list begin python">[</span> <span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span></span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">12<span class="punctuation separator decimal python">.</span></span><span class="punctuation section list end python">]</span></span><span class="punctuation section list end python">]</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">dtype</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">float32</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre>
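<p>As a cross-check, JAX also provides <code>jax.hessian</code>, which is documented as forward-over-reverse differentiation (<code>jacfwd</code> of <code>jacrev</code>). The sketch below, again assuming g(x, y) = x² + 2xy + y³, compares it with the double <code>jacfwd</code> above and checks that the mixed partials agree, i.e. that H is symmetric.</p>

```python
import jax
import jax.numpy as jnp

def g(x, y):
    # Assumed form of g, reconstructed from this section's derivatives
    return x**2 + 2*x*y + y**3

# With a tuple argnums both calls return nested tuples of second derivatives;
# jnp.asarray stacks them into a 2 x 2 matrix.
H_hess = jnp.asarray(jax.hessian(g, argnums=(0, 1))(1.0, 2.0))
H_fwd = jnp.asarray(jax.jacfwd(jax.jacfwd(g, argnums=(0, 1)), argnums=(0, 1))(1.0, 2.0))

print(H_hess)
print(jnp.allclose(H_hess, H_fwd))     # both routes agree
print(jnp.allclose(H_hess, H_hess.T))  # mixed partials are equal
```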
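<p>There are multiple ways to compute the second-order term of the Taylor polynomial, <span class="math inline"><math display="inline"><mfrac><mn>1</mn><mn>2</mn></mfrac><msup><mi>Δ</mi><mi>T</mi></msup><mi>H</mi><mi>Δ</mi></math></span>, from the Hessian. Here <span class="math inline"><math display="inline"><mi>Δ</mi></math></span> is the step from the expansion point (1, 2) to the point we approximate, (1, 1).</p>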
<pre><code class="code lang-python"><span class="source python"><span class="meta qualified-name python"><span class="meta generic-name python">delta</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">array</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span></span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span></span><span class="punctuation section list end python">]</span></span></span><span class="punctuation section arguments end python">)</span></span> <span class="keyword operator arithmetic python">-</span>  <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">array</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">2<span class="punctuation separator decimal 
python">.</span>0</span><span class="punctuation section list end python">]</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">trace</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>5</span> <span class="keyword operator arithmetic python">*</span> <span class="meta qualified-name python"><span class="meta generic-name python">H</span></span><span class="keyword operator matrix python">@</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">einsum</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">i,j<span class="punctuation definition string end python">&#39;</span></span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">delta</span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">delta</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="keyword operator comparison python">&gt;</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">DeviceArray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric float python">6<span class="punctuation separator decimal python">.</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">dtype</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">float32</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre><pre><code class="code lang-python"><span class="source python"><span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>5</span> <span class="keyword operator arithmetic python">*</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">einsum</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">ij,i,j<span class="punctuation definition string end python">&#39;</span></span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">H</span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">delta</span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">delta</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="keyword operator comparison python">&gt;</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">DeviceArray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric float python">6<span class="punctuation separator decimal python">.</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">dtype</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">float32</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre>
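<p>The identities behind these one-liners are ΔᵀHΔ = Σᵢⱼ Hᵢⱼ Δᵢ Δⱼ = tr(H Δ Δᵀ). The sketch below checks them with NumPy (<code>jnp</code> offers the same <code>einsum</code>, <code>outer</code> and <code>trace</code> functions), hard-coding H and Δ from above, and then assembles the second-order Taylor value at (1, 1) from g(1, 2) = 13 (the value of the tangent plane at its point of tangency) and the gradient (6, 14).</p>

```python
import numpy as np

# Hessian of g at (1, 2) and the step delta = (1, 1) - (1, 2), as computed above
H = np.array([[2.0, 2.0], [2.0, 12.0]])
delta = np.array([1.0, 1.0]) - np.array([1.0, 2.0])

# Three equivalent ways to write the quadratic term 0.5 * delta^T H delta
q_matmul = 0.5 * delta @ H @ delta
q_trace = np.trace(0.5 * H @ np.outer(delta, delta))   # trace(H delta delta^T)
q_einsum = 0.5 * np.einsum('ij,i,j', H, delta, delta)  # sum_ij H_ij delta_i delta_j
print(q_matmul, q_trace, q_einsum)  # 6.0 6.0 6.0

# Second-order Taylor value at (1, 1): constant + linear + quadratic terms
g_at_x0 = 13.0             # g(1, 2), from T1: 6*1 + 14*2 - 21
J = np.array([6.0, 14.0])  # gradient of g at (1, 2)
taylor2 = g_at_x0 + J @ delta + q_matmul
print(taylor2)  # 5.0
```

<p>With the assumed g(x, y) = x² + 2xy + y³ the exact value is g(1, 1) = 4; the remaining gap of 1 comes from the cubic term of y³, which a quadratic polynomial cannot represent.</p>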
<p>Now we can write a function that computes the quadratic Taylor approximation using
these pieces.</p>
<pre><code class="code lang-python"><span class="source python"><span class="meta function python"><span class="storage type function python">def</span> <span class="entity name function python"><span class="meta generic-name python">quadratic_taylor_approx</span></span></span><span class="meta function parameters python"><span class="punctuation section parameters begin python">(</span></span><span class="meta function parameters python"><span class="variable parameter python">FUN</span><span class="punctuation separator parameters python">,</span> <span class="variable parameter python">approx</span><span class="punctuation separator parameters python">,</span> <span class="variable parameter python">around_to</span><span class="punctuation section parameters end python">)</span></span><span class="meta function python"><span class="punctuation section function begin python">:</span></span>
  <span class="comment block documentation python"><span class="punctuation definition comment begin python">&quot;&quot;&quot;</span>Compute the quadratic Taylor approximation of FUN around the point &#39;around_to&#39;, evaluated at the point(s) &#39;approx&#39;<span class="punctuation definition comment end python">&quot;&quot;&quot;</span></span>
  <span class="meta qualified-name python"><span class="meta generic-name python">delta</span></span> <span class="keyword operator assignment python">=</span> <span class="meta qualified-name python"><span class="meta generic-name python">approx</span></span> <span class="keyword operator arithmetic python">-</span> <span class="meta qualified-name python"><span class="meta generic-name python">around_to</span></span>
  <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> Compute the Jacobian and the linear component
</span>  <span class="meta qualified-name python"><span class="meta generic-name python">J</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">asarray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">jacfwd</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="variable other constant python">FUN</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">argnums</span><span class="keyword operator assignment python">=</span><span class="meta group python"><span class="punctuation section group begin python">(</span><span class="constant numeric integer decimal python">0</span><span class="punctuation separator tuple python">,</span><span class="constant numeric integer decimal python">1</span><span class="punctuation section group end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="meta function-call python"><span class="punctuation section arguments begin python">(</span></span><span class="meta function-call python"><span class="meta function-call arguments python"><span class="keyword operator arithmetic python">*</span><span class="meta qualified-name python"><span class="meta generic-name 
python">around_to</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
  <span class="meta qualified-name python"><span class="meta generic-name python">linear_component</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">J</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">dot</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">delta</span><span class="punctuation accessor dot python">.</span><span class="variable other constant python">T</span></span></span><span class="punctuation section arguments end python">)</span></span>
  <span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> Compute the Hessian and the quadratic component
</span>  <span class="meta qualified-name python"><span class="meta generic-name python">H</span></span> <span class="keyword operator assignment python">=</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">asarray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python">
                  <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">jacfwd</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python">
                            <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jax</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">jacfwd</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="variable other constant python">FUN</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">argnums</span><span class="keyword operator assignment python">=</span><span class="meta group python"><span class="punctuation section group begin python">(</span><span class="constant numeric integer decimal python">0</span><span class="punctuation separator tuple python">,</span><span class="constant numeric integer decimal python">1</span><span class="punctuation section group end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator arguments python">,</span> 
                            <span class="variable parameter python">argnums</span><span class="keyword operator assignment python">=</span><span class="meta group python"><span class="punctuation section group begin python">(</span><span class="constant numeric integer decimal python">0</span><span class="punctuation separator tuple python">,</span><span class="constant numeric integer decimal python">1</span><span class="punctuation section group end python">)</span></span>
                            </span><span class="punctuation section arguments end python">)</span></span><span class="meta function-call python"><span class="punctuation section arguments begin python">(</span></span><span class="meta function-call python"><span class="meta function-call arguments python"><span class="keyword operator arithmetic python">*</span><span class="meta qualified-name python"><span class="meta generic-name python">around_to</span></span></span><span class="punctuation section arguments end python">)</span></span>
                  </span><span class="punctuation section arguments end python">)</span></span>
  <span class="meta qualified-name python"><span class="meta generic-name python">quadratic_component</span></span> <span class="keyword operator assignment python">=</span> <span class="constant numeric float python">0<span class="punctuation separator decimal python">.</span>5</span> <span class="keyword operator arithmetic python">*</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">einsum</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">ij, ij-&gt;i<span class="punctuation definition string end python">&#39;</span></span></span><span class="punctuation separator arguments python">,</span> 
                                         <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">einsum</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta string python"><span class="string quoted single python"><span class="punctuation definition string begin python">&#39;</span></span></span><span class="meta string python"><span class="string quoted single python">ij,jk-&gt;ik<span class="punctuation definition string end python">&#39;</span></span></span><span class="punctuation separator arguments python">,</span> <span class="meta qualified-name python"><span class="meta generic-name python">delta</span></span><span class="punctuation separator arguments python">,</span><span class="meta qualified-name python"><span class="meta generic-name python">H</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator arguments python">,</span> 
                                         <span class="meta qualified-name python"><span class="meta generic-name python">delta</span></span></span><span class="punctuation section arguments end python">)</span></span>
  <span class="keyword control flow return python">return</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">FUN</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="keyword operator arithmetic python">*</span><span class="meta qualified-name python"><span class="meta generic-name python">around_to</span></span></span><span class="punctuation section arguments end python">)</span></span> <span class="keyword operator arithmetic python">+</span> <span class="meta qualified-name python"><span class="meta generic-name python">linear_component</span></span> <span class="keyword operator arithmetic python">+</span> <span class="meta qualified-name python"><span class="meta generic-name python">quadratic_component</span></span>
</span></code></pre><pre><code class="code lang-python"><span class="source python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">quadratic_taylor_approx</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">g</span></span><span class="punctuation separator arguments python">,</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">array</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span> <span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation section list end python">]</span></span><span class="punctuation 
separator list python">,</span> <span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">3<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">4<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation section list end python">]</span></span><span class="punctuation section list end python">]</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator arguments python">,</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">array</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation section list end python">]</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="keyword operator comparison python">&gt;</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">DeviceArray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span> <span class="constant numeric float python">5<span class="punctuation separator decimal python">.</span></span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">13<span class="punctuation separator decimal python">.</span></span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">89<span class="punctuation separator decimal python">.</span></span><span class="punctuation section list end python">]</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">dtype</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">float32</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre><center>
<img src="/img/taylor-post/quadratic_taylor_approx2.png">
</center>
<p>The quadratic component (the second-order derivatives) lets the approximation capture
the curvature of <span class="math inline"><math display="inline"><mi>g</mi></math></span>, which a linear approximation cannot.</p>
<p>The plot suggests the approximation is right, but we can confirm that
<code>quadratic_taylor_approx</code> is doing its job by comparing it with the closed-form quadratic
Taylor approximation around the point <span class="math inline"><math display="inline"><mo symmetric="false" stretchy="false">(</mo><mn>1</mn><mo>,</mo><mn>2</mn><mo symmetric="false" stretchy="false">)</mo></math></span>.</p>
<p><em>Note: You can derive the closed-form expression from equation 5.180c in MML by dropping the third-order partial derivatives.</em></p>
<p><div class="math display"><math display="block"><msub><mi>T</mi><mn>2</mn></msub><mo symmetric="false" stretchy="false">(</mo><mi>x</mi><mo>,</mo><mi>y</mi><mo symmetric="false" stretchy="false">)</mo><mo>=</mo><msup><mi>x</mi><mn>2</mn></msup><mo>+</mo><mn>6</mn><msup><mi>y</mi><mn>2</mn></msup><mo>−</mo><mn>12</mn><mi>y</mi><mo>+</mo><mn>2</mn><mi>x</mi><mi>y</mi><mo>+</mo><mn>8</mn></math></div></p>
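<p>For reference, the expression follows from the second-order Taylor formula around (1, 2). The value, gradient, and Hessian below are those of T<sub>2</sub> itself at (1, 2), which by construction equal those of g; they also match MML's worked example with g(x, y) = x² + 2xy + y³, which we assume is the function used here.</p>

```latex
\begin{aligned}
T_2(x, y) &= g(1,2) + \nabla g(1,2)\begin{bmatrix} x-1 \\ y-2 \end{bmatrix}
  + \tfrac{1}{2}\begin{bmatrix} x-1 & y-2 \end{bmatrix} H(1,2) \begin{bmatrix} x-1 \\ y-2 \end{bmatrix} \\
&= 13 + 6(x-1) + 14(y-2) + \tfrac{1}{2}\left[2(x-1)^2 + 4(x-1)(y-2) + 12(y-2)^2\right] \\
&= x^2 + 6y^2 - 12y + 2xy + 8,
\end{aligned}
\qquad
\nabla g(1,2) = \begin{bmatrix} 6 & 14 \end{bmatrix},\quad
H(1,2) = \begin{bmatrix} 2 & 2 \\ 2 & 12 \end{bmatrix}.
```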
<pre><code class="code lang-python"><span class="source python"><span class="meta function python"><span class="storage type function python">def</span> <span class="entity name function python"><span class="meta generic-name python">g_quadratic_approx</span></span></span><span class="meta function parameters python"><span class="punctuation section parameters begin python">(</span></span><span class="meta function parameters python"><span class="variable parameter python">x</span><span class="punctuation separator parameters python">,</span> <span class="variable parameter python">y</span><span class="punctuation section parameters end python">)</span></span><span class="meta function python"><span class="punctuation section function begin python">:</span></span>
  <span class="comment block documentation python"><span class="punctuation definition comment begin python">&quot;&quot;&quot;</span>Closed-form expression for the quadratic Taylor approximation of g() around (1, 2)<span class="punctuation definition comment end python">&quot;&quot;&quot;</span></span>
  <span class="keyword control flow return python">return</span> <span class="meta qualified-name python"><span class="meta generic-name python">x</span></span><span class="keyword operator arithmetic python">*</span><span class="keyword operator arithmetic python">*</span><span class="constant numeric integer decimal python">2</span> <span class="keyword operator arithmetic python">+</span> <span class="constant numeric integer decimal python">6</span><span class="keyword operator arithmetic python">*</span><span class="meta qualified-name python"><span class="meta generic-name python">y</span></span><span class="keyword operator arithmetic python">*</span><span class="keyword operator arithmetic python">*</span><span class="constant numeric integer decimal python">2</span> <span class="keyword operator arithmetic python">-</span> <span class="constant numeric integer decimal python">12</span><span class="keyword operator arithmetic python">*</span><span class="meta qualified-name python"><span class="meta generic-name python">y</span></span> <span class="keyword operator arithmetic python">+</span> <span class="constant numeric integer decimal python">2</span><span class="keyword operator arithmetic python">*</span><span class="meta qualified-name python"><span class="meta generic-name python">x</span></span><span class="keyword operator arithmetic python">*</span><span class="meta qualified-name python"><span class="meta generic-name python">y</span></span> <span class="keyword operator arithmetic python">+</span> <span class="constant numeric integer decimal python">8</span>
</span></code></pre><pre><code class="code lang-python"><span class="source python"><span class="comment line number-sign python"><span class="punctuation definition comment python">#</span> Some cases to test
</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">print</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">g_quadratic_approx</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span>0</span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">print</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">g_quadratic_approx</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric float python">3<span class="punctuation separator decimal python">.</span>0</span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">print</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">g_quadratic_approx</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric float python">4<span class="punctuation separator decimal python">.</span>2</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric float python">3<span class="punctuation separator decimal python">.</span>7</span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">print</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">g_quadratic_approx</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span>8</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>3</span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">print</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">g_quadratic_approx</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="constant numeric float python">10<span class="punctuation separator decimal python">.</span>2</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric integer decimal python">21</span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">print</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">g_quadratic_approx</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="keyword operator arithmetic python">-</span><span class="constant numeric float python">5<span class="punctuation separator decimal python">.</span>1</span><span class="punctuation separator arguments python">,</span> <span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span>3</span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="meta function-call python"><span class="meta qualified-name python"><span class="support function builtin python">print</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">g_quadratic_approx</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="keyword operator arithmetic python">-</span><span class="constant numeric float python">3<span class="punctuation separator decimal python">.</span>4</span><span class="punctuation separator arguments python">,</span> <span class="keyword operator arithmetic python">-</span><span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span>5</span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="keyword operator comparison python">&gt;</span> <span class="constant numeric float python">13<span class="punctuation separator decimal python">.</span>0</span>
<span class="keyword operator comparison python">&gt;</span> <span class="constant numeric float python">42<span class="punctuation separator decimal python">.</span>0</span>
<span class="keyword operator comparison python">&gt;</span> <span class="constant numeric float python">94<span class="punctuation separator decimal python">.</span>46000000000001</span>
<span class="keyword operator comparison python">&gt;</span> <span class="constant numeric float python">17<span class="punctuation separator decimal python">.</span>659999999999997</span>
<span class="keyword operator comparison python">&gt;</span> <span class="constant numeric float python">2934<span class="punctuation separator decimal python">.</span>44</span>
<span class="keyword operator comparison python">&gt;</span> <span class="constant numeric float python">14<span class="punctuation separator decimal python">.</span>689999999999998</span>
<span class="keyword operator comparison python">&gt;</span> <span class="constant numeric float python">104<span class="punctuation separator decimal python">.</span>06</span>
</span></code></pre><pre><code class="code lang-python"><span class="source python"><span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">quadratic_taylor_approx</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta qualified-name python"><span class="meta generic-name python">g</span></span><span class="punctuation separator arguments python">,</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">array</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span> 
                                      <span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">3<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>
                                      <span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">4<span class="punctuation separator decimal python">.</span>2</span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">3<span class="punctuation separator decimal python">.</span>7</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>
                                      <span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span>8</span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>3</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>
                                      <span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">10<span class="punctuation separator decimal python">.</span>2</span><span class="punctuation separator list python">,</span> <span class="constant numeric integer decimal python">21</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>
                                      <span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="keyword operator arithmetic python">-</span><span class="constant numeric float python">5<span class="punctuation separator decimal python">.</span>1</span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span>3</span><span class="punctuation section list end python">]</span></span><span class="punctuation separator list python">,</span>
                                      <span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="keyword operator arithmetic python">-</span><span class="constant numeric float python">3<span class="punctuation separator decimal python">.</span>4</span><span class="punctuation separator list python">,</span> <span class="keyword operator arithmetic python">-</span><span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span>5</span><span class="punctuation section list end python">]</span></span>
                                      <span class="punctuation section list end python">]</span></span></span><span class="punctuation section arguments end python">)</span></span><span class="punctuation separator arguments python">,</span> 
                        <span class="variable parameter python">around_to</span><span class="keyword operator assignment python">=</span><span class="meta function-call python"><span class="meta qualified-name python"><span class="meta generic-name python">jnp</span><span class="punctuation accessor dot python">.</span></span><span class="meta qualified-name python"><span class="variable function python">array</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span><span class="constant numeric float python">1<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation separator list python">,</span> <span class="constant numeric float python">2<span class="punctuation separator decimal python">.</span>0</span><span class="punctuation section list end python">]</span></span></span><span class="punctuation section arguments end python">)</span></span></span><span class="punctuation section arguments end python">)</span></span>
<span class="keyword operator comparison python">&gt;</span> <span class="meta function-call python"><span class="meta qualified-name python"><span class="variable function python">DeviceArray</span></span><span class="punctuation section arguments begin python">(</span><span class="meta function-call arguments python"><span class="meta structure list python"><span class="punctuation section list begin python">[</span>  <span class="constant numeric float python">13<span class="punctuation separator decimal python">.</span></span>      <span class="punctuation separator list python">,</span>   <span class="constant numeric float python">42<span class="punctuation separator decimal python">.</span></span>      <span class="punctuation separator list python">,</span>   <span class="constant numeric float python">94<span class="punctuation separator decimal python">.</span>46</span>    <span class="punctuation separator list python">,</span>   <span class="constant numeric float python">17<span class="punctuation separator decimal python">.</span>659998</span><span class="punctuation separator list python">,</span>
             <span class="constant numeric float python">2934<span class="punctuation separator decimal python">.</span>44</span>    <span class="punctuation separator list python">,</span>   <span class="constant numeric float python">14<span class="punctuation separator decimal python">.</span>690002</span><span class="punctuation separator list python">,</span>  <span class="constant numeric float python">104<span class="punctuation separator decimal python">.</span>05999</span> <span class="punctuation section list end python">]</span></span><span class="punctuation separator arguments python">,</span> <span class="variable parameter python">dtype</span><span class="keyword operator assignment python">=</span><span class="meta qualified-name python"><span class="meta generic-name python">float32</span></span></span><span class="punctuation section arguments end python">)</span></span>
</span></code></pre>
<p>The values are practically the same. In a few cases there is a small
approximation error, around the thousandth.</p>
]]></content>
  </entry>
  <entry>
    <title>Berkson&apos;s Paradox: why are handsome people such jerks?</title>
    <link href="https://alkzar.cl/posts/berkson-s-paradox/"/>
    <id>https://alkzar.cl/posts/berkson-s-paradox/</id>
    <published>2021-02-14T00:00:00Z</published>
    <updated>2021-02-14T00:00:00Z</updated>
    <content type="html"><![CDATA[<p>It is not difficult to find a relationship between two variables when there is none.
We may observe a negative correlation between two variables when the actual correlation
between them is positive, or vice versa. Often, this phenomenon can be explained
by Berkson's paradox.</p>
<p>Correlation is simply a measure of the relationship between two variables. A <a href="https://brilliant.org/wiki/correlation/">correlation</a>
coefficient measures the degree to which two variables are linearly associated. It takes
values from -1 to 1, where -1 means a perfect negative linear
relationship: more <span class="math inline"><math display="inline"><mi>X</mi></math></span> is associated with less <span class="math inline"><math display="inline"><mi>Y</mi></math></span>, and vice versa. Conversely, a
correlation equal to 1 means a perfect positive linear relationship: more <span class="math inline"><math display="inline"><mi>X</mi></math></span> is associated with more <span class="math inline"><math display="inline"><mi>Y</mi></math></span>. Finally,
a correlation coefficient of 0 between <span class="math inline"><math display="inline"><mi>X</mi></math></span> and <span class="math inline"><math display="inline"><mi>Y</mi></math></span> means there is no <em>linear</em> relationship between them, although they may still be related in a non-linear way.</p>
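<p>To make the definition concrete, here is a minimal sketch of the Pearson correlation coefficient in Python with NumPy (chosen only for illustration; the plots in this post were made in R). It divides the covariance of the two variables by the product of their standard deviations. The helper name <code>pearson</code> is hypothetical. The last case shows why a coefficient of 0 only rules out a <em>linear</em> relationship: Y is completely determined by X, yet the coefficient is 0.</p>

```python
import numpy as np


def pearson(x, y):
    """Pearson correlation: covariance of x and y over the product of their std devs."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    xc = x - x.mean()  # centre both variables around their means
    yc = y - y.mean()
    return float((xc * yc).sum() / np.sqrt((xc ** 2).sum() * (yc ** 2).sum()))


x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
print(pearson(x, 2 * x + 1))  # 1.0: more X, more Y (perfect positive linear)
print(pearson(x, -3 * x))     # -1.0: more X, less Y (perfect negative linear)

s = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
print(pearson(s, s ** 2))     # 0.0: Y = X^2 depends on X, but not linearly
```

<p>In practice you would call <code>np.corrcoef</code> (or <code>cor</code> in R), which computes the same quantity.</p>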
<p>It is essential to acknowledge the difference between what we observe and
what is actually true. In our daily lives, we face plenty of situations
in which we don't think about the sample size or whether groups are well represented
within the sample. The thing is that we look at the world through
our own kaleidoscope, full of patterns, full of colours.</p>
<p>Imagine that you start thinking about the relationship between attractiveness and
niceness... and, from your own dating experiences and the stories of your fine
selection of friends, you begin to infer that handsome people are jerks, while
pleasant people are not so good-looking.</p>
<p>Why is it so common to hear this? Did our creators put attractiveness and niceness
in conflict? This conundrum is widespread in popular wisdom, but we have
little reason to believe there is any relationship between these attributes.
Indeed, in reality both variables have a correlation close to 0. So
what could explain this apparent contradiction between what we observe and the
"truth"? Berkson's paradox!</p>
<p>Let's say that <span class="math inline"><math display="inline"><mi>X</mi></math></span> = "Niceness" and <span class="math inline"><math display="inline"><mi>Y</mi></math></span> = "Attractiveness". Then we collect data from
many people, measuring these attributes, which you can see in the animation below.
Each attribute has a score from 0 to 100, where 0 means the attribute is absent and 100
means it is fully present. Note that the animation title shows the
correlation coefficient (<span class="math inline"><math display="inline"><mi>ρ</mi></math></span>).</p>
<center>
<img src="img/berksonParadox.gif"/>
</center>
<p>It's reasonable, and honest, to assume that, on average, we restrict the pool of
candidates we are interested in dating. You accept a date with
someone who is not so good-looking if their eloquent and lovely behaviour
compensates for their lack of visual grace. Likewise, you might tolerate an
idiot if their handsomeness masks the whiff of their jerkiness.
As you can see in the animation above, your preferences remove from the
map the group of candidates who don't meet your minimum criteria. Let's call this
group "You would not date".</p>
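<p>The first filter can be sketched numerically. The snippet below is a Python/NumPy illustration, not the code used for the animation. It borrows the parameters of the R code further down: 5,000 draws from a bivariate normal with means of 50, variances of 200 and a covariance of 4 between the two scores. As with the R code's <code>lower_line</code>, everyone with niceness + attractiveness of 85 or less is treated as "You would not date" and removed. The seed and variable names are illustrative assumptions, and the exact numbers will differ from the animation.</p>

```python
import numpy as np

# Scores for 5,000 people: column 0 = niceness (X), column 1 = attractiveness (Y).
# Means and covariance follow the R code below; the seed is arbitrary.
rng = np.random.default_rng(323)
mean = [50.0, 50.0]
cov = [[200.0, 4.0], [4.0, 200.0]]
scores = rng.multivariate_normal(mean, cov, size=5000)
niceness, attractiveness = scores[:, 0], scores[:, 1]

# Everyone on or below the line Y = -X + 85 falls in the "You would not date" group,
# so we keep only the people strictly above that line.
dated = attractiveness > -niceness + 85

rho_all = np.corrcoef(niceness, attractiveness)[0, 1]
rho_dated = np.corrcoef(niceness[dated], attractiveness[dated])[0, 1]
print(f"everyone:        rho = {rho_all:.2f}")
print(f"after filtering: rho = {rho_dated:.2f}")
```

<p>The correlation in the whole population stays close to 0, while the correlation among the people you would date comes out negative: cutting off one corner of the cloud is already enough to tilt it.</p>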
<p>Now the ugly truth: another group, people who score high on both attractiveness and
niceness, would use similar arguments to turn us down; let's call this group "Would not date you".
This double restriction prevents us from "observing the world as it
is" and explains how we go from a correlation close to 0 between attractiveness and
niceness to a strong negative correlation of −0.7: <em>the more handsome you are,
the more jerkiness you spit out</em>. It is an illusion of our kaleidoscope that gives us the
impression of a false dichotomy between <span class="math inline"><math display="inline"><mi>X</mi></math></span> and <span class="math inline"><math display="inline"><mi>Y</mi></math></span>.</p>
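<p>Adding the second restriction gives the full Berkson setup. This is again a Python/NumPy sketch under assumptions, not the post's R code: the same kind of simulated population, the lower cut-off used by the R code's <code>lower_line</code> (X + Y above 85), plus a hypothetical upper cut-off (X + Y below 115) standing in for the "Would not date you" group. Only the people strictly between both lines remain in the observed sample.</p>

```python
import numpy as np

# Simulated population (same assumed parameters as before; seed is arbitrary).
rng = np.random.default_rng(323)
cov = [[200.0, 4.0], [4.0, 200.0]]
niceness, attractiveness = rng.multivariate_normal([50.0, 50.0], cov, size=5000).T

total = niceness + attractiveness
# Keep people above the "You would not date" line (cut-off 85)
# and below the "Would not date you" line (hypothetical cut-off 115).
observed = (total > 85) & (total < 115)

rho_all = np.corrcoef(niceness, attractiveness)[0, 1]
rho_observed = np.corrcoef(niceness[observed], attractiveness[observed])[0, 1]
print(f"population:      rho = {rho_all:.2f}")
print(f"observed sample: rho = {rho_observed:.2f}")
```

<p>Nothing links the two traits in the simulated population; the strongly negative correlation in the observed sample is produced entirely by the selection. Moving the two cut-offs closer together makes the band narrower and the observed correlation even more negative.</p>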
<h3>Additional resources</h3>
<ul>
<li><a href="https://brilliant.org/wiki/berksons-paradox/" target="_blank">Read more about Berkson's paradox</a></li>
<li>This example is mentioned in the book <a href="https://www.goodreads.com/book/show/48889983-calling-bullshit">"Calling Bullshit: The Art of Skepticism in a Data-Driven World"</a> and elaborated on in <a href="https://www.goodreads.com/book/show/18693884-how-not-to-be-wrong?ac=1&amp;from_search=true&amp;qid=1QgwnLyEqF&amp;rank=1">"How Not to Be Wrong: The Power of Mathematical Thinking"</a></li>
<li><a href="https://www.youtube.com/watch?v=eSVg_DqPkNM" target="_blank">Avoiding the Pitfalls of Selection Bias (Carl T. Bergstrom, 2021)</a></li>
</ul>
<h3>Code</h3>
<p>The GIF was made by combining three different plots created with ggplot2 (R), using
the <a href="https://ezgif.com/maker">Animated Gif Maker</a> tool.</p>
<p>Here is the code if you want to reproduce the example:</p>
<pre><code class="code lang-r"><span class="source r"><span class="meta function-call r"><span class="support function r">library</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r">ggplot2</span><span class="punctuation section parens end r">)</span></span>
<span class="meta function-call r"><span class="support function r">library</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r">dplyr</span><span class="punctuation section parens end r">)</span></span>
<span class="meta function-call r"><span class="support function r">library</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r">latex2exp</span><span class="punctuation section parens end r">)</span></span>

<span class="comment line number-sign r"><span class="punctuation definition comment r">#</span> Multivariate normal dist parameters:
</span>mu <span class="keyword operator assignment r">&lt;-</span> <span class="meta function-call r"><span class="support function r">c</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">50</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">50</span></span><span class="punctuation section parens end r">)</span></span>
Sigma <span class="keyword operator assignment r">&lt;-</span> <span class="meta function-call r"><span class="support function r">matrix</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="meta function-call r"><span class="support function r">c</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">200</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">4</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">4</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">200</span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">2</span></span><span class="punctuation section parens end r">)</span></span>

<span class="comment line number-sign r"><span class="punctuation definition comment r">#</span> Generate 5,000 data points
</span><span class="meta function-call r"><span class="support function r">set.seed</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">323</span></span><span class="punctuation section parens end r">)</span></span>
df <span class="keyword operator assignment r">&lt;-</span> <span class="meta function-call r"><span class="support function r">as.data.frame</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r">MASS<span class="keyword other r">:</span><span class="keyword other r">:</span><span class="meta function-call r"><span class="variable function r">mvrnorm</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">5e3</span><span class="punctuation separator parameters r">,</span> mu<span class="punctuation separator parameters r">,</span> Sigma</span><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span>

<span class="comment line number-sign r"><span class="punctuation definition comment r">#</span> Compute correlation between V1 and V2
</span>cor0 <span class="keyword operator assignment r">&lt;-</span> <span class="meta function-call r"><span class="support function r">round</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="meta function-call r"><span class="support function r">cor</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r">df<span class="keyword other r">$</span>V1<span class="punctuation separator parameters r">,</span> df<span class="keyword other r">$</span>V2</span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">2</span></span><span class="punctuation section parens end r">)</span></span>

<span class="comment line number-sign r"><span class="punctuation definition comment r">#</span> Initial plot
</span>p0 <span class="keyword operator assignment r">&lt;-</span> df <span class="keyword operator other r">%&gt;%</span> 
  <span class="meta function-call r"><span class="variable function r">ggplot</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">geom_point</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="meta function-call r"><span class="variable function r">aes</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r">V1<span class="punctuation separator parameters r">,</span> V2</span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">alpha</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">.3</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">size</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">.2</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">color</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>steelblue<span class="punctuation definition string end r">&quot;</span></span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">scale_x_continuous</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">breaks</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">seq</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">0</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">100</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">20</span></span><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">scale_y_continuous</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">breaks</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">seq</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">0</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">100</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">20</span></span><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">coord_cartesian</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">xlim</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">c</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">0</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">100</span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">ylim</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">c</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">0</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">100</span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">clip</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>off<span class="punctuation definition string end r">&quot;</span></span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>  
  <span class="meta function-call r"><span class="variable function r">labs</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">x</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>Niceness<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span>
       <span class="variable parameter r">y</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>Attractiveness<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span>
       <span class="variable parameter r">title</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="variable function r">TeX</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="meta function-call r"><span class="support function r">paste</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>$<span class="constant character escape r">\\</span>rho$ =<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="meta function-call r"><span class="support function r">as.character</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r">cor0</span><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">annotate</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">geom</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>text<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">x</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">c</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">10</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">90</span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">y</span> <span class="keyword operator assignment r">=</span> <span class="keyword operator arithmetic r">-</span><span class="constant numeric float decimal r">15</span><span class="punctuation separator parameters r">,</span>
           <span class="variable parameter r">label</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">c</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>JERK<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>NICE<span class="punctuation definition string end r">&quot;</span></span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">size</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">4</span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">annotate</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">geom</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>text<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">x</span> <span class="keyword operator assignment r">=</span> <span class="keyword operator arithmetic r">-</span><span class="constant numeric float decimal r">10</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">y</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">c</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">10</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">90</span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span>
           <span class="variable parameter r">label</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">c</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>NOT<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>HOT<span class="punctuation definition string end r">&quot;</span></span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">size</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">4</span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">theme_bw</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">base_size</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">8</span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">theme</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">plot.margin</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="variable function r">margin</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">1.5</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">1.5</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">1.5</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">1.5</span><span class="punctuation separator parameters r">,</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>cm<span class="punctuation definition string end r">&quot;</span></span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span>
        <span class="variable parameter r">plot.title</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="variable function r">element_text</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">hjust</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">0.5</span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span>
        <span class="variable parameter r">panel.grid.major</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="variable function r">element_blank</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span>
        <span class="variable parameter r">panel.grid.minor</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="variable function r">element_blank</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span>

<span class="comment line number-sign r"><span class="punctuation definition comment r">#</span> Lower line:
</span><span class="comment line number-sign r"><span class="punctuation definition comment r">#</span> ------------------------------------------------------------------------------
</span><span class="comment line number-sign r"><span class="punctuation definition comment r">#</span> Create a lower diagonal line and assign an identifier to each group (2 labels)
</span><span class="comment line number-sign r"><span class="punctuation definition comment r">#</span> C: {LA=Lower Area, MA = Middle Area}
</span><span class="meta function r"><span class="entity name function r">lower_line</span> <span class="keyword operator assignment r">&lt;-</span> </span><span class="meta function r"><span class="keyword control r">function</span><span class="punctuation section parens begin r">(</span></span><span class="meta function r"><span class="meta function parameters r"><span class="variable parameter r">x</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">m</span><span class="keyword operator assignment r">=</span><span class="keyword operator arithmetic r">-</span><span class="constant numeric float decimal r">1</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">b</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">85</span></span><span class="punctuation section parens end r">)</span></span> m <span class="keyword operator arithmetic r">*</span> x <span class="keyword operator arithmetic r">+</span> b

df <span class="keyword operator assignment r">&lt;-</span> df <span class="keyword operator other r">%&gt;%</span> 
  <span class="meta function-call r"><span class="variable function r">mutate</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">C</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="variable function r">case_when</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="meta function-call r"><span class="variable function r">lower_line</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r">V1</span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator comparison r">&gt;</span> V2 <span class="keyword other r">~</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>LA<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span>
                        <span class="constant language r">TRUE</span> <span class="keyword other r">~</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>MA<span class="punctuation definition string end r">&quot;</span></span></span><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span> 

<span class="comment line number-sign r"><span class="punctuation definition comment r">#</span> Compute cor1
</span>cor1 <span class="keyword operator assignment r">&lt;-</span> <span class="meta function-call r"><span class="support function r">round</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="meta function-call r"><span class="support function r">cor</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r">df<span class="meta item-access r"><span class="punctuation section brackets single begin r">[</span></span><span class="meta item-access r"><span class="meta item-access arguments r">df<span class="keyword other r">$</span>C <span class="keyword operator assignment r">=</span><span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>MA<span class="punctuation definition string end r">&quot;</span></span>, <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>V1<span class="punctuation definition string end r">&quot;</span></span></span></span><span class="meta item-access r"><span class="punctuation section brackets single end r">]</span></span><span class="punctuation separator parameters r">,</span> df<span class="meta item-access r"><span class="punctuation section brackets single begin r">[</span></span><span class="meta item-access r"><span class="meta item-access arguments r">df<span class="keyword other r">$</span>C <span class="keyword operator assignment r">=</span><span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>MA<span class="punctuation definition string end r">&quot;</span></span>, <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>V2<span class="punctuation definition string end r">&quot;</span></span></span></span><span class="meta item-access r"><span 
class="punctuation section brackets single end r">]</span></span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">2</span></span><span class="punctuation section parens end r">)</span></span>

<span class="comment line number-sign r"><span class="punctuation definition comment r">#</span> Create p1
</span>p1 <span class="keyword operator assignment r">&lt;-</span> df <span class="keyword operator other r">%&gt;%</span> 
  <span class="meta function-call r"><span class="variable function r">ggplot</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">geom_point</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="meta function-call r"><span class="variable function r">aes</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r">V1<span class="punctuation separator parameters r">,</span> V2<span class="punctuation separator parameters r">,</span> <span class="variable parameter r">colour</span> <span class="keyword operator assignment r">=</span> C</span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">alpha</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">.27</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">size</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">.2</span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">scale_x_continuous</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">breaks</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">seq</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">0</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">100</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">20</span></span><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">scale_y_continuous</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">breaks</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">seq</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">0</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">100</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">20</span></span><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">coord_cartesian</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">xlim</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">c</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">0</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">100</span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">ylim</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">c</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">0</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">100</span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">clip</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>off<span class="punctuation definition string end r">&quot;</span></span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">scale_colour_manual</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">guide</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>none<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">values</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">c</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>steelblue<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>grey45<span class="punctuation definition string end r">&quot;</span></span></span><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">labs</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">x</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>Niceness<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span>
       <span class="variable parameter r">y</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>Attractiveness<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span>
       <span class="variable parameter r">colour</span> <span class="keyword operator assignment r">=</span> <span class="constant language r">NULL</span><span class="punctuation separator parameters r">,</span>
       <span class="variable parameter r">title</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="variable function r">TeX</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="meta function-call r"><span class="support function r">paste</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>$<span class="constant character escape r">\\</span>rho$ =<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="meta function-call r"><span class="support function r">as.character</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r">cor1</span><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">annotate</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">geom</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>text<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">x</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">c</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">10</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">90</span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">y</span> <span class="keyword operator assignment r">=</span> <span class="keyword operator arithmetic r">-</span><span class="constant numeric float decimal r">15</span><span class="punctuation separator parameters r">,</span>
           <span class="variable parameter r">label</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">c</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>JERK<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>NICE<span class="punctuation definition string end r">&quot;</span></span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">size</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">4</span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">annotate</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">geom</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>text<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">x</span> <span class="keyword operator assignment r">=</span> <span class="keyword operator arithmetic r">-</span><span class="constant numeric float decimal r">10</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">y</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">c</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">10</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">90</span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span>
           <span class="variable parameter r">label</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">c</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>NOT<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>HOT<span class="punctuation definition string end r">&quot;</span></span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">size</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">4</span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">annotate</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>text<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">label</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>YOU WOULD<span class="constant character escape r">\n</span>NOT DATE<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">x</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">20</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">y</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">20</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">size</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">3.5</span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span> 
  <span class="meta function-call r"><span class="variable function r">annotate</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">geom</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>segment<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">x</span> <span class="keyword operator assignment r">=</span> <span class="keyword operator arithmetic r">-</span><span class="constant numeric float decimal r">5</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">y</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="variable function r">lower_line</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="keyword operator arithmetic r">-</span><span class="constant numeric float decimal r">5</span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">xend</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">90</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">yend</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="variable function r">lower_line</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">90</span></span><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">theme_bw</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">base_size</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">8</span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span> 
  <span class="meta function-call r"><span class="variable function r">theme</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">plot.margin</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="variable function r">margin</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">1.5</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">1.5</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">1.5</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">1.5</span><span class="punctuation separator parameters r">,</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>cm<span class="punctuation definition string end r">&quot;</span></span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span>
      <span class="variable parameter r">plot.title</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="variable function r">element_text</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">hjust</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">0.5</span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span>
      <span class="variable parameter r">panel.grid.major</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="variable function r">element_blank</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span>
      <span class="variable parameter r">panel.grid.minor</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="variable function r">element_blank</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span>

<span class="comment line number-sign r"><span class="punctuation definition comment r">#</span> Upper line:
</span><span class="comment line number-sign r"><span class="punctuation definition comment r">#</span> ------------------------------------------------------------------------------
</span><span class="comment line number-sign r"><span class="punctuation definition comment r">#</span> Create an upper diagonal line, then label each point by region (3 labels)
</span><span class="meta function r"><span class="entity name function r">upper_line</span> <span class="keyword operator assignment r">&lt;-</span> </span><span class="meta function r"><span class="keyword control r">function</span><span class="punctuation section parens begin r">(</span></span><span class="meta function r"><span class="meta function parameters r"><span class="variable parameter r">x</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">m</span><span class="keyword operator assignment r">=</span><span class="keyword operator arithmetic r">-</span><span class="constant numeric float decimal r">1</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">b</span><span class="keyword operator assignment r">=</span><span class="constant numeric float decimal r">115</span></span><span class="punctuation section parens end r">)</span></span> b <span class="keyword operator arithmetic r">+</span> x<span class="keyword operator arithmetic r">*</span>m


df <span class="keyword operator assignment r">&lt;-</span> df <span class="keyword operator other r">%&gt;%</span> 
  <span class="meta function-call r"><span class="variable function r">mutate</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">C</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="variable function r">case_when</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="punctuation section parens begin r">(</span><span class="meta function-call r"><span class="variable function r">lower_line</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r">V1</span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator comparison r">&lt;</span><span class="keyword operator assignment r">=</span> V2<span class="punctuation section parens end r">)</span> <span class="keyword operator logical r">&amp;</span> <span class="punctuation section parens begin r">(</span>V2 <span class="keyword operator comparison r">&lt;</span> <span class="meta function-call r"><span class="variable function r">upper_line</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r">V1</span><span class="punctuation section parens end r">)</span></span><span class="punctuation section parens end r">)</span> <span class="keyword other r">~</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>MA<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span>
                       <span class="meta function-call r"><span class="variable function r">upper_line</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r">V1</span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator comparison r">&gt;</span><span class="keyword operator assignment r">=</span> V2 <span class="keyword other r">~</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>UA<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span>
                       <span class="meta function-call r"><span class="variable function r">lower_line</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r">V1</span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator comparison r">&lt;</span> V2 <span class="keyword other r">~</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>LA<span class="punctuation definition string end r">&quot;</span></span></span><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span>

<span class="comment line number-sign r"><span class="punctuation definition comment r">#</span> Compute cor2
</span>cor2 <span class="keyword operator assignment r">&lt;-</span> <span class="meta function-call r"><span class="support function r">round</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="meta function-call r"><span class="support function r">cor</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r">df<span class="meta item-access r"><span class="punctuation section brackets single begin r">[</span></span><span class="meta item-access r"><span class="meta item-access arguments r">df<span class="keyword other r">$</span>C <span class="keyword operator assignment r">=</span><span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>MA<span class="punctuation definition string end r">&quot;</span></span>, <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>V1<span class="punctuation definition string end r">&quot;</span></span></span></span><span class="meta item-access r"><span class="punctuation section brackets single end r">]</span></span><span class="punctuation separator parameters r">,</span> df<span class="meta item-access r"><span class="punctuation section brackets single begin r">[</span></span><span class="meta item-access r"><span class="meta item-access arguments r">df<span class="keyword other r">$</span>C <span class="keyword operator assignment r">=</span><span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>MA<span class="punctuation definition string end r">&quot;</span></span>, <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>V2<span class="punctuation definition string end r">&quot;</span></span></span></span><span class="meta item-access r"><span 
class="punctuation section brackets single end r">]</span></span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">2</span></span><span class="punctuation section parens end r">)</span></span>

<span class="comment line number-sign r"><span class="punctuation definition comment r">#</span> Create p2
</span>p2 <span class="keyword operator assignment r">&lt;-</span> df <span class="keyword operator other r">%&gt;%</span> 
  <span class="meta function-call r"><span class="variable function r">ggplot</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">geom_point</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="meta function-call r"><span class="variable function r">aes</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r">V1<span class="punctuation separator parameters r">,</span> V2<span class="punctuation separator parameters r">,</span> <span class="variable parameter r">colour</span> <span class="keyword operator assignment r">=</span> C</span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">alpha</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">.27</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">size</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">.2</span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">scale_x_continuous</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">breaks</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">seq</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">0</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">100</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">20</span></span><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">scale_y_continuous</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">breaks</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">seq</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">0</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">100</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">20</span></span><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">coord_cartesian</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">xlim</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">c</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">0</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">100</span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">ylim</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">c</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">0</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">100</span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">clip</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>off<span class="punctuation definition string end r">&quot;</span></span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">scale_colour_manual</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">guide</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>none<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">values</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">c</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>grey45<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>steelblue<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>grey45<span class="punctuation definition string end r">&quot;</span></span></span><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">labs</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">x</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>Niceness<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span>
       <span class="variable parameter r">y</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>Attractiveness<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span>
       <span class="variable parameter r">colour</span> <span class="keyword operator assignment r">=</span> <span class="constant language r">NULL</span><span class="punctuation separator parameters r">,</span>
       <span class="variable parameter r">title</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="variable function r">TeX</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="meta function-call r"><span class="support function r">paste</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>$<span class="constant character escape r">\\</span>rho$ =<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="meta function-call r"><span class="support function r">as.character</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r">cor2</span><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">annotate</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">geom</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>text<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">x</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">c</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">10</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">90</span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">y</span> <span class="keyword operator assignment r">=</span> <span class="keyword operator arithmetic r">-</span><span class="constant numeric float decimal r">15</span><span class="punctuation separator parameters r">,</span>
           <span class="variable parameter r">label</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">c</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>JERK<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>NICE<span class="punctuation definition string end r">&quot;</span></span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">size</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">4</span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">annotate</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">geom</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>text<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">x</span> <span class="keyword operator assignment r">=</span> <span class="keyword operator arithmetic r">-</span><span class="constant numeric float decimal r">10</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">y</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">c</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">10</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">90</span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span>
           <span class="variable parameter r">label</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="support function r">c</span>(</span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>NOT<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>HOT<span class="punctuation definition string end r">&quot;</span></span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">size</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">4</span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">annotate</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>text<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">label</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>YOU WOULD<span class="constant character escape r">\n</span>NOT DATE<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">x</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">20</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">y</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">20</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">size</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">3.5</span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span> 
  <span class="meta function-call r"><span class="variable function r">annotate</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>segment<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">x</span> <span class="keyword operator assignment r">=</span> <span class="keyword operator arithmetic r">-</span><span class="constant numeric float decimal r">5</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">y</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="variable function r">lower_line</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="keyword operator arithmetic r">-</span><span class="constant numeric float decimal r">5</span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">xend</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">90</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">yend</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="variable function r">lower_line</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">90</span></span><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">annotate</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>text<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">label</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>WOULD NOT<span class="constant character escape r">\n</span>DATE YOU<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">x</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">80</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">y</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">80</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">size</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">3.5</span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span> 
  <span class="meta function-call r"><span class="variable function r">annotate</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>segment<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">x</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">10</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">y</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="variable function r">upper_line</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">10</span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">xend</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">105</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">yend</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="variable function r">upper_line</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">105</span></span><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">theme_bw</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">base_size</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">8</span></span><span class="punctuation section parens end r">)</span></span> <span class="keyword operator arithmetic r">+</span>
  <span class="meta function-call r"><span class="variable function r">theme</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">plot.margin</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="variable function r">margin</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="constant numeric float decimal r">1.5</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">1.5</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">1.5</span><span class="punctuation separator parameters r">,</span> <span class="constant numeric float decimal r">1.5</span><span class="punctuation separator parameters r">,</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>cm<span class="punctuation definition string end r">&quot;</span></span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span>
      <span class="variable parameter r">plot.title</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="variable function r">element_text</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="variable parameter r">hjust</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">0.5</span></span><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span>
      <span class="variable parameter r">panel.grid.major</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="variable function r">element_blank</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="punctuation section parens end r">)</span></span><span class="punctuation separator parameters r">,</span>
      <span class="variable parameter r">panel.grid.minor</span> <span class="keyword operator assignment r">=</span> <span class="meta function-call r"><span class="variable function r">element_blank</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="punctuation section parens end r">)</span></span></span><span class="punctuation section parens end r">)</span></span>


<span class="comment line number-sign r"><span class="punctuation definition comment r">#</span> Save plots
</span><span class="meta function-call r"><span class="variable function r">ggsave</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>/Users/YourUser/Desktop/bp0.png<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> p0<span class="punctuation separator parameters r">,</span> <span class="variable parameter r">width</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">7.5</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">height</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">7.5</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">dpi</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>retina<span class="punctuation definition string end r">&quot;</span></span></span><span class="punctuation section parens end r">)</span></span>
<span class="meta function-call r"><span class="variable function r">ggsave</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>/Users/YourUser/Desktop/bp1.png<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> p1<span class="punctuation separator parameters r">,</span> <span class="variable parameter r">width</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">7.5</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">height</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">7.5</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">dpi</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>retina<span class="punctuation definition string end r">&quot;</span></span></span><span class="punctuation section parens end r">)</span></span>
<span class="meta function-call r"><span class="variable function r">ggsave</span><span class="punctuation section parens begin r">(</span></span><span class="meta function-call r"><span class="meta function-call parameters r"><span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>/Users/YourUser/Desktop/bp2.png<span class="punctuation definition string end r">&quot;</span></span><span class="punctuation separator parameters r">,</span> p2<span class="punctuation separator parameters r">,</span> <span class="variable parameter r">width</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">7.5</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">height</span> <span class="keyword operator assignment r">=</span> <span class="constant numeric float decimal r">7.5</span><span class="punctuation separator parameters r">,</span> <span class="variable parameter r">dpi</span> <span class="keyword operator assignment r">=</span> <span class="string quoted double r"><span class="punctuation definition string begin r">&quot;</span>retina<span class="punctuation definition string end r">&quot;</span></span></span><span class="punctuation section parens end r">)</span></span>
</span></code></pre>]]></content>
  </entry>
  <entry>
    <title>Creating virtual environments in Python and associating them with a Jupyter Notebook kernel</title>
    <link href="https://alkzar.cl/posts/crear-entornos-virtuales-en-python-y-asociarlo-a-un-kernel-de-jupyter-notebook/"/>
    <id>https://alkzar.cl/posts/crear-entornos-virtuales-en-python-y-asociarlo-a-un-kernel-de-jupyter-notebook/</id>
    <published>2020-12-09T00:00:00Z</published>
    <updated>2020-12-09T00:00:00Z</updated>
    <content type="html"><![CDATA[<p>This post walks through the <em>workflow</em> for creating virtual environments with Python, installing packages, and
using them with <em>Jupyter notebooks</em>. More information about virtual environments is available in
the <a href="https://docs.python.org/3/tutorial/venv.html">official documentation</a>.</p>
<p>The following two sections are summarized in this sequence of commands:</p>
<p><img src="/posts/crear-entornos-virtuales-en-python-y-asociarlo-a-un-kernel-de-jupyter-notebook/img/venvjup.gif#center" alt="" /></p>
<h3>Create a virtual environment</h3>
<ol>
<li>
<p>Create a directory for the project and, inside it, initialize a virtual environment with <code>python3 -m venv &lt;name&gt;</code>.</p>
</li>
<li>
<p>Activate the environment with <code>source &lt;name&gt;/bin/activate</code>. The commands <code>which python</code> and
<code>which pip</code> show which environment you are working in: if the printed path points into the project
directory, you are using the virtual environment you created. A visual cue that the environment is active is
that the terminal prompt starts with <code>(&lt;name&gt;) ~/...</code>. To deactivate the environment, run
<code>deactivate</code>; the prefix then disappears from the prompt.</p>
</li>
<li>
<p>To install packages, run <code>pip install &lt;package_name&gt;</code>. You can also list the packages and
their required versions in a text file and install them all at once with <code>pip install -r requirements.txt</code>.
The command <code>pip list</code> shows the installed packages.</p>
</li>
</ol>
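<p>For readers who prefer to script steps 1 to 3, here is a minimal sketch using Python's standard-library <code>venv</code> module instead of the shell. It is an illustration, not part of the original workflow. The helper names (<code>venv_python</code>, <code>create_env</code>, <code>in_virtualenv</code>, <code>pip_cmd</code>, <code>run_pip</code>) are hypothetical. It assumes the usual layout: <code>bin/python</code> on POSIX and <code>Scripts\python.exe</code> on Windows. The check <code>sys.prefix != sys.base_prefix</code> is the programmatic counterpart of inspecting the output of <code>which python</code>.</p>

```python
from __future__ import annotations

import os
import subprocess
import sys
import venv
from pathlib import Path


def venv_python(env_dir: Path) -> Path:
    """Interpreter inside a venv: bin/python on POSIX, Scripts/python.exe on Windows."""
    if os.name == "nt":
        return env_dir / "Scripts" / "python.exe"
    return env_dir / "bin" / "python"


def create_env(env_dir: Path, with_pip: bool = True) -> Path:
    """Rough equivalent of `python3 -m venv ENV_DIR` (step 1); returns the env's interpreter."""
    # The venv CLI symlinks the interpreter on POSIX and copies it on Windows; mirror that.
    venv.create(env_dir, with_pip=with_pip, symlinks=(os.name != "nt"))
    return venv_python(env_dir)


def in_virtualenv() -> bool:
    """True if the *running* interpreter belongs to a venv (what `which python` hints at)."""
    return sys.prefix != sys.base_prefix


def pip_cmd(python: Path, *args: str) -> list[str]:
    """pip bound to one interpreter, e.g. pip_cmd(py, "install", "-r", "requirements.txt")."""
    return [str(python), "-m", "pip", *args]


def run_pip(python: Path, *args: str) -> str:
    """Run pip through the env's interpreter (no activation needed); return its stdout."""
    result = subprocess.run(pip_cmd(python, *args), capture_output=True, text=True, check=True)
    return result.stdout
```

<p>For example, <code>py = create_env(Path("myenv"))</code> followed by <code>run_pip(py, "list")</code> mirrors steps 1 and 3. Invoking <code>python -m pip</code> through the environment's own interpreter avoids depending on whichever <code>pip</code> comes first on <code>PATH</code>.</p>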
<h3>Create a Jupyter notebook kernel</h3>
<ol start="4">
<li>
<p>Install the package needed to create a kernel for Jupyter notebooks:<br />
<code>pip install ipykernel</code>.
Important: once the installation finishes, deactivate the environment (<code>deactivate</code>) and
activate it again before continuing. In my case, if I continue without restarting the environment, I run into
problems associating the kernel with the virtual environment, which in turn means the installed packages
fail to load.</p>
</li>
<li>
<p>Now create the kernel with <code>ipython kernel install --user --name=&lt;name&gt;</code>. It is good practice to use the
same name as the virtual environment created in step 1. To uninstall a kernel, run <code>jupyter kernelspec uninstall &lt;name&gt;</code>.</p>
</li>
<li>
<p>You can list all available kernels with <code>jupyter kernelspec list</code>. The new kernel created in step 5 should appear in the list.</p>
</li>
<li>
<p>Finally, when you launch <code>jupyter lab</code> or <code>jupyter notebook</code> and create a new notebook,
the kernel created in step 5 will be offered as an option. If you select it, all the packages installed in the
virtual environment will be available.</p>
</li>
</ol>
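<p>The kernel steps can also be checked from Python. This is a sketch under stated assumptions, not the post's method. It builds the <code>python -m ipykernel install</code> command, the module form IPython documents as an alternative to <code>ipython kernel install</code>, using the virtual environment's own interpreter. It also reads a kernel's <code>kernel.json</code>, whose <code>argv[0]</code> entry records the interpreter the kernel launches. The helper names and example paths are hypothetical. <code>jupyter kernelspec list</code> prints the directory that contains each kernel's <code>kernel.json</code>.</p>

```python
from __future__ import annotations

import json
from pathlib import Path


def kernel_install_cmd(python: Path, name: str, display_name: str | None = None) -> list[str]:
    """Equivalent of `ipython kernel install --user --name=NAME`, bound to `python`.

    Running the module with the venv's interpreter makes the new kernel launch
    that interpreter, so it sees the packages installed in the environment.
    """
    cmd = [str(python), "-m", "ipykernel", "install", "--user", f"--name={name}"]
    if display_name is not None:
        cmd.append(f"--display-name={display_name}")
    return cmd


def kernel_interpreter(kernel_json: Path) -> Path:
    """Interpreter a kernel spec launches: the first element of its "argv" list."""
    spec = json.loads(kernel_json.read_text())
    return Path(spec["argv"][0])


def kernel_uses_env(kernel_json: Path, env_dir: Path) -> bool:
    """True if the kernel's interpreter lives inside `env_dir`.

    The path is deliberately not resolved: a venv's bin/python is usually a
    symlink to the base interpreter, and resolving it would hide the venv.
    """
    return env_dir.absolute() in kernel_interpreter(kernel_json).parents
```

<p>A <code>False</code> result for the environment you expected could help diagnose a kernel/environment mismatch like the one described in step 4.</p>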
]]></content>
  </entry>
  <entry>
    <title>The missing semester of your CS education</title>
    <link href="https://alkzar.cl/posts/el-semestre-faltante-en-tu-educación-de-cs/"/>
    <id>https://alkzar.cl/posts/el-semestre-faltante-en-tu-educación-de-cs/</id>
    <published>2020-07-17T00:00:00Z</published>
    <updated>2020-07-17T00:00:00Z</updated>
    <content type="html"><![CDATA[<blockquote>
<p>Classes teach you all about advanced topics within CS, from operating systems to machine learning, but there is one critical subject that is rarely covered and is instead left for students to figure out on their own: proficiency with their tools. We will teach you how to master the command line, use a powerful text editor, use the fancy features of version control systems, and much more!</p>
</blockquote>
<p>Lately I have been reviewing and studying the material of the course <a href="https://missing.csail.mit.edu/">"The Missing Semester of Your CS Education"</a>.
The course has now run twice at MIT and is taught by
<a href="https://www.anishathalye.com/">Anish</a>, <a href="https://thesquareplanet.com/">Jon</a>,
and <a href="http://josejg.com/">Jose</a>. Its goal is the one described in the
opening quote, and you can read more about the <a href="https://missing.csail.mit.edu/about/">motivation behind the course here</a>;
in short, it aims to give you the tools to become more proficient with your
computing environment: from navigating your system from the command line, to
using a text editor efficiently, to knowing what Git is and how it works.</p>
<p>The website has a video for each lecture, along with its lecture notes. At the end
of the notes there are exercises to practice and explore the topics covered in class.</p>
<h1>Spanish version</h1>
<p>To follow the course, and to help make the material available
in Spanish, I have been translating the notes for each lecture. For now, only
two sets of notes are fully translated, but I will publish the remaining ones as I
finish translating them. The Spanish material is available from the official course site,
in the <em>Translations</em> section, or directly at the following link:
https://missing-semester-esp.github.io/.</p>
<p>Contributions to the translation are welcome; the project's GitHub repository is
<a href="https://github.com/missing-semester-esp/missing-semester-esp.github.io">here</a>.
If you want to edit the notes for a specific lecture, scroll to the bottom of its
page, where you will find the link "Editar esta página" ("Edit this page"), which takes you directly
to the Markdown file with that lecture's content in the project repository.</p>
]]></content>
  </entry>
</feed>
