Introduction
In the previous post, we talked about maximum entropy inverse reinforcement learning (MaxEnt IRL) and introduced a practical sample-based algorithm, guided cost learning (GCL), that can handle high-dimensional state and action spaces and nonlinear reward functions. In GCL, we saw that reward learning in fact competes against policy learning, which makes the setup resemble a GAN. In this post, we make the connection between MaxEnt IRL and GANs explicit and introduce an algorithm called generative adversarial network guided cost learning (GAN-GCL), which uses the GAN training procedure to achieve, in theory, the same behavior as guided cost learning (GCL).
Preliminaries
Maximum Entropy Inverse Reinforcement Learning
In MaxEnt IRL, we have (near-)optimal trajectories \(\{\tau_i\}\) sampled from some underlying expert trajectory distribution \(p(\tau)\), which is generally unknown. We train a reward function \(r_\psi\) to maximize the log-likelihood of the demonstrations under the trajectory distribution it induces
\[\begin{align} \mathcal J_r(\psi)=\mathbb E_{\tau\sim p}[\log p_\psi(\tau)]&=\mathbb E_{\tau\sim p}[r_\psi(\tau)]-\log Z\\\ \mathrm{where}\quad Z&=\mathbb E_{\tau\sim \mu}\left[{\exp(r_\psi(\tau))\over {1\over 2}\tilde p(\tau)+{1\over 2}q_\theta(\tau)}\right] \end{align}\]where \(Z\) is the partition function, estimated by importance sampling with samples drawn from the mixture \(\mu={1\over 2}p+{1\over 2}q_\theta\). Here \(\tilde p(\tau)\) in the importance weights is a rough estimate of the expert trajectory distribution \(p(\tau)\), and \(q_\theta(\tau)\) (denoted \(\pi\) in our previous post) is trained to be the soft optimal policy w.r.t. the current reward function by maximizing the objective
\[\begin{align} \mathcal J_q(\theta)=\mathbb E_{\tau\sim q_\theta}[r_\psi(\tau)]+\mathcal H(q_\theta) \end{align}\]To align with the GAN convention of minimizing losses, we negate both objectives and obtain
\[\begin{align} \mathcal L_r(\psi)=-\mathbb E_{\tau\sim p}[\log p_\psi(\tau)]&=-\mathbb E_{\tau\sim p}[r_\psi(\tau)]+\log Z\tag {1}\\\ \mathrm{where}\quad Z&=\mathbb E_{\tau\sim \mu}\left[{\exp(r_\psi(\tau))\over {1\over 2}\tilde p(\tau)+{1\over 2}q_\theta(\tau)}\right]\\\ \mathcal L_q(\theta)&=\mathbb E_{\tau\sim q_\theta}[-r_\psi(\tau)+\log(q_\theta(\tau))]\tag {2} \end{align}\]

Generative Adversarial Networks
In GANs, two models are trained simultaneously: a generator \(G\) and a discriminator \(D\). The discriminator is responsible for classifying its inputs as either outputs of the generator or real samples drawn from the underlying data distribution \(p(x)\), which is generally unknown. The generator, on the other hand, aims to produce samples realistic enough to “fool” the discriminator.
Formally, we have the loss for the discriminator
\[\begin{align} \mathcal L_{D}(\psi)=\mathbb E_{x\sim p}[-\log D_\psi(x)]+\mathbb E_{x\sim G_\theta}[-\log(1-D_\psi(x))]\tag {3} \end{align}\]and the loss for the generator
\[\begin{align} \mathcal L_G(\theta)=\mathbb E_{x\sim G_\theta}[-\log D_\psi(x)]+\mathbb E_{x\sim G_\theta}[\log (1-D_\psi(x))]\tag{4} \end{align}\]where the second term is the negative of the corresponding term in the discriminator loss, and the first term is the log of the discriminator's confusion, a modification introduced in the original GAN paper because it provides stronger gradients early in training.
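To make the expectations concrete, here is a minimal NumPy sketch of Monte Carlo estimates of Eqs. \((3)\) and \((4)\); the arrays `d_real` and `d_fake` are made-up discriminator outputs standing in for \(D_\psi(x)\) on a real batch and a generated batch.

```python
import numpy as np

def discriminator_loss(d_real, d_fake):
    """Monte Carlo estimate of Eq. (3): -E_p[log D(x)] - E_G[log(1 - D(x))]."""
    return -np.mean(np.log(d_real)) - np.mean(np.log(1.0 - d_fake))

def generator_loss(d_fake):
    """Monte Carlo estimate of Eq. (4): -E_G[log D(x)] + E_G[log(1 - D(x))]."""
    return -np.mean(np.log(d_fake)) + np.mean(np.log(1.0 - d_fake))

# Hypothetical discriminator probabilities on a real batch and a generated batch.
d_real = np.array([0.9, 0.8, 0.95])   # D(x) for x ~ p
d_fake = np.array([0.2, 0.1, 0.3])    # D(x) for x ~ G
print(discriminator_loss(d_real, d_fake), generator_loss(d_fake))
```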
For the sake of consistency, we use \(q_\theta\) rather than \(G_\theta\) in the rest of the post.
Connection between MaxEnt IRL and GANs
For a fixed generator with density \(q(\tau)\), the optimal discriminator is
\[\begin{align} D^\*(\tau)={p(\tau)\over p(\tau)+q(\tau)} \end{align}\]where \(p(\tau)\) is the distribution of the real data.
When the generator density \(q(\tau)\) can be evaluated, the traditional GAN discriminator can be modified to incorporate this density information. Instead of having the discriminator directly estimate the probability that its input is real, we let it estimate the density of the real data \(p(\tau)\) and fill in the known value of \(q(\tau)\). In that case, the new form of the discriminator \(D_\psi\) is
\[\begin{align} D_\psi(\tau)={\tilde p_\psi(\tau)\over\tilde p_\psi(\tau)+q_\theta(\tau)} \end{align}\]This modest change allows the optimal discriminator to be completely independent of the generator: the discriminator is optimal when \(\tilde p_\psi(\tau)=p(\tau)\). Independence between the generator and the optimal discriminator may significantly improve the stability of training.[2]
In order to build a connection to MaxEnt IRL, we replace the expert trajectory distribution \(\tilde p_\psi(\tau)\) with \({1\over Z}\exp(r_\psi(\tau))\) and have
\[\begin{align} D_\psi(\tau)&={ {1\over Z}\exp(r_\psi(\tau))\over {1\over Z}\exp(r_\psi(\tau))+q_\theta(\tau)}\\\ &={\exp\big(r_\psi(\tau)-\log Z-\log q_\theta(\tau)\big)\over\exp\big(r_\psi(\tau)-\log Z-\log q_\theta(\tau)\big)+1}\\\ &=\sigma\big(r_\psi(\tau)-\log Z-\log q_\theta(\tau)\big) \end{align}\]This gives the discriminator a convenient architecture: an ordinary binary classifier with a sigmoid as the final layer, where the only adjustment is that \(\log Z\) and \(\log q_\theta(\tau)\) are subtracted from the input to the sigmoid.
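As a quick sanity check on this algebra, the following NumPy sketch (with made-up values for \(r_\psi(\tau)\), \(\log Z\), and \(\log q_\theta(\tau)\)) verifies that the ratio form and the sigmoid form of \(D_\psi\) agree.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Hypothetical values for a batch of trajectories.
r = np.array([1.5, -0.3, 0.7])        # r_psi(tau)
log_Z = 0.9                           # current estimate of log Z
log_q = np.array([-2.0, -0.5, -1.2])  # log q_theta(tau), the known sampler density

# Ratio form: D = (1/Z) exp(r) / ((1/Z) exp(r) + q)
p_tilde = np.exp(r - log_Z)
D_ratio = p_tilde / (p_tilde + np.exp(log_q))

# Sigmoid form: D = sigma(r - log Z - log q)
D_sigmoid = sigmoid(r - log_Z - log_q)

assert np.allclose(D_ratio, D_sigmoid)
```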
Plugging this discriminator into the discriminator loss, we obtain:
\[\begin{align} \mathcal L_{D}(\psi)=&\mathbb E_{\tau\sim p}[-\log D_\psi(\tau)]+\mathbb E_{\tau\sim q_\theta}[-\log(1-D_\psi(\tau))]\\\ =&\mathbb E_{\tau\sim p}\left[-\log { {1\over Z}\exp(r_\psi(\tau))\over 2\tilde\mu(\tau)}\right]+\mathbb E_{\tau\sim q_\theta}\left[-\log{q_\theta(\tau)\over 2\tilde\mu(\tau)}\right]\\\ =&\mathbb E_{\tau\sim p}[-r_\psi(\tau)+\log Z]+\mathbb E_{\tau\sim q_\theta}[-\log q_\theta(\tau)]+2\mathbb E_{\tau\sim\mu}[\log(2\tilde \mu(\tau))] \end{align}\]where we write \(\tilde\mu(\tau)={1\over 2}\left({1\over Z}\exp(r_\psi(\tau))+q_\theta(\tau)\right)\) and \(\mu={1\over 2}(p+q_\theta)\) to simplify notation, and use \(\mathbb E_{\tau\sim p}[f(\tau)]+\mathbb E_{\tau\sim q_\theta}[f(\tau)]=2\mathbb E_{\tau\sim\mu}[f(\tau)]\) in the last step.
Similarly, we derive the generator loss
\[\begin{align} \mathcal L_G(\theta)=&\mathbb E_{\tau\sim q_\theta}[-\log D_\psi(\tau)]+\mathbb E_{\tau\sim q_\theta}[\log (1-D_\psi(\tau))]\\\ =&\mathbb E_{\tau\sim q_\theta}\left[-\log { {1\over Z}\exp(r_\psi(\tau))\over 2\tilde\mu(\tau)}+\log{q_\theta(\tau)\over 2\tilde \mu(\tau)}\right]\\\ =&\log Z+\mathbb E_{\tau\sim q_\theta}\left[\log q_\theta(\tau)-r_\psi(\tau)\right] \end{align}\]For ease of reference, we rewrite the losses for MaxEnt IRL and GANs below
\[\begin{align} \mathcal L_r(\psi)=&-\mathbb E_{\tau\sim p}[r_\psi(\tau)]+\log Z\quad \mathrm{where}\ Z=\mathbb E_{\tau\sim \mu}\left[{\exp(r_\psi(\tau))\over \tilde \mu(\tau)}\right]\tag 1\\\ \mathcal L_q(\theta)=&\mathbb E_{\tau\sim q_\theta}[-r_\psi(\tau)+\log(q_\theta(\tau))]\tag 2\\\ \mathcal L_{D}(\psi)=&\mathbb E_{\tau\sim p}[-r_\psi(\tau)+\log Z]+\mathbb E_{\tau\sim q_\theta}[-\log q_\theta(\tau)]+2\mathbb E_{\tau\sim\mu}[\log(2\tilde \mu(\tau))]\tag 3\\\ \mathcal L_G(\theta)=&\log Z+\mathbb E_{\tau\sim q_\theta}\left[\log q_\theta(\tau)-r_\psi(\tau)\right]\tag 4 \end{align}\]Note that \(Z\) in Equation \((4)\) is constant w.r.t. \(\theta\). Therefore, optimizing Equation \((4)\) is equivalent to optimizing Equation \((2)\).
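Putting the pieces together, here is a hedged NumPy sketch of how one might estimate the discriminator loss \((3)\) (written in its equivalent binary cross-entropy form, using the sigmoid discriminator from above) and the generator loss \((4)\) from batches of expert and sampled trajectories. The function names and toy numbers are mine; \(r\), \(\log q\), and \(\log Z\) are assumed to be given by the learner.

```python
import numpy as np

def discriminator_loss(r_expert, log_q_expert, r_sample, log_q_sample, log_Z):
    """Eq. (3): binary cross-entropy with D(tau) = sigmoid(r(tau) - log Z - log q(tau))."""
    logit_expert = r_expert - log_Z - log_q_expert   # logits on expert trajectories
    logit_sample = r_sample - log_Z - log_q_sample   # logits on sampled trajectories
    # -E_p[log D] - E_q[log(1 - D)], with the log-sigmoids written via logaddexp for stability
    return np.mean(np.logaddexp(0.0, -logit_expert)) + np.mean(np.logaddexp(0.0, logit_sample))

def sampler_loss(r_sample, log_q_sample, log_Z):
    """Eq. (4): log Z + E_q[log q - r]; the log Z term is constant w.r.t. the sampler."""
    return log_Z + np.mean(log_q_sample - r_sample)

# Hypothetical batches: rewards and sampler log-densities on expert and sampled trajectories.
r_expert, log_q_expert = np.array([2.0, 1.5]), np.array([-3.0, -2.5])
r_sample, log_q_sample = np.array([0.5, 0.2]), np.array([-1.0, -0.8])
print(discriminator_loss(r_expert, log_q_expert, r_sample, log_q_sample, log_Z=1.0))
print(sampler_loss(r_sample, log_q_sample, log_Z=1.0))
```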
Next, we establish the equivalence between the MaxEnt IRL loss and the discriminator loss by proving the following two facts:
- The value of \(Z\) that minimizes the discriminator loss is an importance-sampling estimator for the partition function defined in Eq.\((1)\).
- For this value of \(Z\), the derivative of the discriminator’s loss w.r.t. \(\psi\) is equal to the derivative of the MaxEnt IRL objective.
\(Z\) estimates the partition function
We first show that minimizing \(\mathcal L_D(\psi)\) w.r.t. \(Z\) in Eq.\((3)\) gives exactly the importance-sampling estimator for the partition function in Eq.\((1)\):
\[\begin{align} \partial_Z\mathcal L_D(\psi)&=0\\\ \partial_Z\Big(\log Z+2\mathbb E_{\tau\sim \mu}[\log(2\tilde\mu(\tau))]\Big)&=0\\\ {1\over Z}&=\mathbb E_{\tau\sim\mu}\left[{ {1\over Z^2}\exp(r_\psi(\tau))\over \tilde \mu(\tau)}\right]\\\ Z&=\mathbb E_{\tau\sim \mu}\left[{\exp(r_\psi(\tau))\over \tilde \mu(\tau)}\right] \end{align}\]where in the second line we drop the terms of Eq.\((3)\) that do not depend on \(Z\). This result suggests that we no longer need to construct the importance-sampling estimate of the partition function by hand; instead, we can estimate \(Z\) directly by minimizing the discriminator loss w.r.t. \(Z\).
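This fixed point is easy to check numerically. The sketch below, with arbitrary toy values standing in for \(r_\psi(\tau)\) and \(q_\theta(\tau)\) on a batch drawn from \(\mu\) (in practice this batch would be half expert and half policy samples), minimizes the \(Z\)-dependent part of \(\mathcal L_D\) by gradient descent on \(\log Z\) and compares the result with the importance-sampling estimator.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-ins for rewards r_psi(tau) and sampler densities q_theta(tau) on a batch from mu.
r = rng.normal(size=1000) + 1.0
q = np.exp(rng.normal(size=1000))

log_Z = 0.0
for _ in range(2000):
    # d/d(log Z) of  log Z + 2 E_mu[log(exp(r)/Z + q)]
    w = np.exp(r - log_Z)
    grad = 1.0 - 2.0 * np.mean(w / (w + q))
    log_Z -= 0.05 * grad

# Compare with the importance-sampling estimator Z = E_mu[exp(r) / mu_tilde].
mu_tilde = 0.5 * np.exp(r - log_Z) + 0.5 * q
Z_is = np.mean(np.exp(r) / mu_tilde)
print(np.exp(log_Z), Z_is)   # the two estimates approximately agree at the minimum
```

At the minimum, the two printed values coincide up to optimization error, matching the derivation above.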
The discriminator has the same derivative as the MaxEnt IRL objective
Consider the derivative of the discriminator loss w.r.t. parameters \(\psi\)
\[\begin{align} \partial_\psi\mathcal L_D(\psi) &=\mathbb E_{\tau\sim p}[-\partial_\psi r_\psi(\tau)]+2\partial_\psi\mathbb E_{\tau\sim \mu}[\log(2\tilde\mu(\tau))]\\\ &=\mathbb E_{\tau\sim p}[-\partial_\psi r_\psi(\tau)]+\mathbb E_{\tau\sim\mu}\left[{ {1\over Z}\exp(r_\psi(\tau))\partial_\psi r_\psi(\tau)\over\tilde\mu(\tau)}\right]\\\ &=-\mathbb E_{\tau\sim p}[\partial_\psi r_\psi(\tau)]+{1\over Z}\mathbb E_{\tau\sim\mu}\left[{\exp(r_\psi(\tau))\partial_\psi r_\psi(\tau)\over\tilde\mu(\tau)}\right]\\\ &=\partial_\psi\mathcal L_r(\psi) \end{align}\]where in the first step we drop the terms that do not depend on \(\psi\), treating \(Z\) as a separate parameter of the discriminator. The last equality holds because \(Z\) in the discriminator equals the importance-sampling estimate of the partition function in the MaxEnt IRL objective, and because we do not differentiate through the importance weights when computing the derivative of the MaxEnt IRL objective. This result shows that performing gradient descent on the discriminator loss is equivalent to performing gradient descent on the MaxEnt IRL objective.
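This identity can also be checked with automatic differentiation. The sketch below (PyTorch, with a linear reward \(r_\psi(\tau)=\psi^\top\phi(\tau)\) on toy feature vectors, which is my own choice for illustration rather than anything prescribed by the method) computes \(\partial_\psi\mathcal L_r\) with fixed importance weights and \(\partial_\psi\mathcal L_D\) with \(Z\) fixed at the value that minimizes \(\mathcal L_D\), and confirms that the two gradients coincide.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy setup: each "trajectory" is a feature vector phi(tau), and r_psi(tau) = psi . phi(tau).
d = 4
psi = torch.randn(d, requires_grad=True)
phi_expert = torch.randn(64, d)          # features of expert trajectories, tau ~ p
phi_sample = torch.randn(64, d) + 0.5    # features of sampled trajectories, tau ~ q_theta
log_q_expert = torch.randn(64)           # stand-ins for log q_theta(tau) on each batch
log_q_sample = torch.randn(64)

phi_mu = torch.cat([phi_expert, phi_sample])      # samples from mu = 1/2 p + 1/2 q
log_q_mu = torch.cat([log_q_expert, log_q_sample])

def reward(phi):
    return phi @ psi

# Solve for the Z that minimizes L_D w.r.t. Z via fixed-point iteration of Z = E_mu[exp(r)/mu_tilde].
with torch.no_grad():
    r_mu = reward(phi_mu)
    Z = torch.tensor(1.0)
    for _ in range(200):
        mu_tilde = 0.5 * torch.exp(r_mu) / Z + 0.5 * torch.exp(log_q_mu)
        Z = torch.mean(torch.exp(r_mu) / mu_tilde)
    mu_tilde = 0.5 * torch.exp(r_mu) / Z + 0.5 * torch.exp(log_q_mu)

# MaxEnt IRL loss, Eq. (1); the importance weights 1/mu_tilde are held fixed (not differentiated).
Z_psi = torch.mean(torch.exp(reward(phi_mu)) / mu_tilde)
L_r = -torch.mean(reward(phi_expert)) + torch.log(Z_psi)
grad_r, = torch.autograd.grad(L_r, psi)

# Discriminator loss, Eq. (3), with D(tau) = sigmoid(r(tau) - log Z - log q(tau)) and Z fixed.
logit_expert = reward(phi_expert) - torch.log(Z) - log_q_expert
logit_sample = reward(phi_sample) - torch.log(Z) - log_q_sample
L_D = F.softplus(-logit_expert).mean() + F.softplus(logit_sample).mean()
grad_D, = torch.autograd.grad(L_D, psi)

print(torch.allclose(grad_r, grad_D, atol=1e-4))   # True: the two gradients coincide
```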
References
1. CS 294-112 at UC Berkeley. Deep Reinforcement Learning, Lecture 16.
2. Chelsea Finn et al. A Connection Between Generative Adversarial Networks, Inverse Reinforcement Learning, and Energy-Based Models.
3. Ian J. Goodfellow et al. Generative Adversarial Nets.
4. Andrew Ng et al. Policy Invariance under Reward Transformations: Theory and Application to Reward Shaping.