When I first learnt about GANs (generative adversarial networks)1 I followed the “alternative” objective (which I will refer to as $G_{alt}$), which is the most common GAN objective found in the wild at the time of writing. You can see an example of it in DCGAN2, which is available on GitHub.
$G_{alt}$ corresponds to the following update steps:

1. Update the discriminator using binary cross-entropy loss (e.g. `nn.BCECriterion` in Torch) with target 1 for "real" images and 0 for "fake" images.
2. Update the generator using binary cross-entropy loss with target 1 for "fake" images, which encourages the generator to produce images that the discriminator classifies as real.

With the Power of Mathematics™ we can express the loss functions used in the above update steps.
Let $V(x;\omega)$ denote the output of the discriminator network (before the final sigmoid activation $\sigma$), and let $G(z;\theta)$ denote the output of the generator network for noise input $z \sim Z$.
Discriminator update, “real” examples $x\sim X_{data}$:
$$ \mathcal{L}_{real}(x) = \mathcal{L}_{BCE}(\sigma(V(x;\omega)), 1) = -\log(\sigma(V(x;\omega))) $$
Discriminator update, “fake” examples $x\sim G(Z;\theta)$:
$$ \mathcal{L}_{fake}(x) = \mathcal{L}_{BCE}(\sigma(V(x;\omega)), 0) = -\log(1-\sigma(V(x;\omega))) $$
Generator update:
$$ \mathcal{L}_{gen}(z) = \mathcal{L}_{BCE}(\sigma(V(G(z;\theta);\omega)), 1) = -\log(\sigma(V(G(z;\theta);\omega))) $$
Imagine that $q(x)$ is the true probability distribution over all images, and that $p(x)$ is our approximation. As we train our GAN, the approximation $p(x)$ becomes closer to $q(x)$. There are multiple ways of measuring the distance of one probability distribution from another, and functions for doing so are called f-divergences. Prominent examples of f-divergences include KL divergence and JS divergence.
Somewhere in our practical formulation of the GAN objective we have implicitly specified a divergence to be minimised. This wouldn’t matter very much if our model had the capacity to model $q(x)$ perfectly, since the minimum would be achieved when $p(x)=q(x)$ regardless of which divergence is used. In reality this is not the case, and even after perfect training $p(x)$ will still be an approximation. The kicker is that the “best” approximation depends on the divergence used.
For example, consider a simplified case in one dimension where $q(x)$ is a bimodal distribution, but $p(x)$ only has the modelling capacity of a single Gaussian. Should $p(x)$ try to fit a single mode really well (mode-seeking), or should it attempt to cover both (mode-covering)? There is no “right answer” to this question, which is why multiple f-divergences exist and are useful.
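The trade-off can be seen numerically with two toy discrete distributions (the specific numbers below are made up purely for illustration): forward KL prefers the mode-covering approximation, while reverse KL prefers the mode-seeking one.

```python
import math

def kl(p, q):
    """Discrete KL divergence KL(p || q) = sum_i p_i * log(p_i / q_i)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

# A "bimodal" target: most mass at the two ends, a near-empty middle.
q_data = [0.495, 0.010, 0.495]

# Two limited-capacity candidate approximations:
p_cover = [0.35, 0.30, 0.35]  # spreads over both modes (and the gap between them)
p_seek = [0.90, 0.09, 0.01]   # commits to a single mode

# Forward KL(q || p) punishes p for missing mass where q has it -> mode-covering wins.
assert kl(q_data, p_cover) < kl(q_data, p_seek)

# Reverse KL(p || q) punishes p for placing mass where q has none -> mode-seeking wins.
assert kl(p_seek, q_data) < kl(p_cover, q_data)
```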
Fig 1. Which is the better approximation? The answer depends on the f-divergence you are using!
Poole et al.3 have worked backwards to find the f-divergence minimised by $G_{alt}$. It turns out that this divergence is not a named or well-known function. The authors argue that the GAN divergence lies at the mode-seeking end of the spectrum, which results in a tendency for the generator to produce less variety.
It would be nice to specify whichever divergence we wanted when training a GAN. Fortunately for us, f-GAN4 describes a way to explicitly specify the f-divergence you want in the GAN objective.
Essentially, the parts of the practical GAN objective specified earlier that imply the divergence are the sigmoid activation and the binary cross-entropy loss. By replacing these parts with generic functions, we reach a more general formulation of the loss functions:
$$ \mathcal{L}_{real}(x) = -g_f(V(x;\omega)) $$
$$ \mathcal{L}_{fake}(x) = f^*(g_f(V(x;\omega))) $$
where $g_f(v)$ = an activation function tailored to the f-divergence, and $f^*(t)$ = the Fenchel conjugate of the f-divergence. A table of these functions can be found in the f-GAN paper, and they are relatively straightforward to implement as part of a custom criterion in Torch.
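As a sketch (in plain Python rather than a Torch criterion), here are $(g_f, f^*)$ pairs for a few divergences, transcribed from the tables in the f-GAN paper, plugged into the two loss functions above. Worth double-checking against the paper before relying on them:

```python
import math

# (g_f, f*) pairs for a few f-divergences, as tabulated in the f-GAN paper.
# g_f maps the raw discriminator output v; f* is the Fenchel conjugate.
DIVERGENCES = {
    "GAN": (lambda v: -math.log(1.0 + math.exp(-v)),  # = log(sigmoid(v))
            lambda t: -math.log(1.0 - math.exp(t))),
    "KL":  (lambda v: v,
            lambda t: math.exp(t - 1.0)),
    "RKL": (lambda v: -math.exp(-v),                  # reverse KL
            lambda t: -1.0 - math.log(-t)),
    "JS":  (lambda v: math.log(2.0) - math.log(1.0 + math.exp(-v)),
            lambda t: -math.log(2.0 - math.exp(t))),
}

def loss_real(v, name):
    """Discriminator loss on a real example: -g_f(v)."""
    g, _ = DIVERGENCES[name]
    return -g(v)

def loss_fake(v, name):
    """Discriminator loss on a generated example: f*(g_f(v))."""
    g, conj = DIVERGENCES[name]
    return conj(g(v))
```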
By setting $g_f(v) = \log(\sigma(v))$ and $f^*(t) = -\log(1 - e^t)$ we get the same discriminator loss functions as $G_{alt}$.
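A quick numerical check in plain Python that this choice of $g_f$ and $f^*$ reproduces the $G_{alt}$ discriminator losses:

```python
import math

sigmoid = lambda v: 1.0 / (1.0 + math.exp(-v))
g_f = lambda v: math.log(sigmoid(v))              # = -log(1 + exp(-v))
f_star = lambda t: -math.log(1.0 - math.exp(t))   # Fenchel conjugate for the GAN divergence

for v in [-2.0, -0.5, 0.0, 0.5, 2.0]:
    # f-GAN losses with the GAN divergence...
    l_real = -g_f(v)
    l_fake = f_star(g_f(v))
    # ...match the sigmoid + BCE losses from earlier.
    assert abs(l_real - (-math.log(sigmoid(v)))) < 1e-9
    assert abs(l_fake - (-math.log(1.0 - sigmoid(v)))) < 1e-9
```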
In the f-GAN paper, the generator loss takes the same form as $\mathcal{L}_{real}$, but is evaluated on generated samples:

$$ \mathcal{L}_{gen}(z) = -g_f(V(G(z;\theta);\omega)) $$
Pretty simple stuff here, really.
Poole et al. propose an extension which allows the generator and discriminator to be trained with different f-divergences. Roughly speaking this involves undoing the effects of the discriminator f-divergence to recover the density ratio $\frac{q(x)}{p(x)}$, and then applying the generator f-divergence $f_G$.
$$ \mathcal{L}_{gen}(z) = f_G\left((f')^{-1}\left(g_f(V(G(z;\theta);\omega))\right)\right) $$

where $f$ is the discriminator's f-divergence.
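As a concrete (hypothetical) pairing, sketched in plain Python: train the discriminator with the GAN divergence, whose $f'(u) = \log\frac{u}{u+1}$ inverts to $(f')^{-1}(t) = \frac{e^t}{1 - e^t}$, and train the generator with KL, $f_G(u) = u \log u$. The derivative and its inverse here are my own working from the f-GAN definition of the GAN divergence, so treat this as an illustration rather than a reference implementation:

```python
import math

sigmoid = lambda v: 1.0 / (1.0 + math.exp(-v))
g_f = lambda v: math.log(sigmoid(v))                        # GAN-divergence activation
f_prime_inv = lambda t: math.exp(t) / (1.0 - math.exp(t))   # (f')^{-1} for the GAN divergence
f_G = lambda u: u * math.log(u)                             # f for KL divergence

def loss_gen(v):
    """Generator loss: recover the density ratio, then apply the generator's f."""
    ratio = f_prime_inv(g_f(v))  # estimated q(x)/p(x)
    return f_G(ratio)

# Note: f_prime_inv(g_f(v)) = sigmoid(v) / (1 - sigmoid(v)) = exp(v),
# so for this particular pairing loss_gen(v) simplifies to exp(v) * v.
```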
Here are some generated examples after training DCGAN on CIFAR-10 with different divergences, using the f-GAN generator loss.
| f-divergence | Generated output |
| --- | --- |
| GAN divergence | *(image)* |
| JS divergence | *(image)* |
| RKL divergence | *(image)* |
1. Generative Adversarial Networks. https://arxiv.org/abs/1406.2661
2. Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks. https://arxiv.org/abs/1511.06434
3. Improved generator objectives for GANs. https://arxiv.org/abs/1612.02780
4. f-GAN: Training Generative Neural Samplers using Variational Divergence Minimization. https://arxiv.org/abs/1606.00709