Introduction to Wasserstein GANs with Gradient Penalty

In this article, we discuss the Wasserstein loss function for Generative Adversarial Networks (GANs), which solves a common issue that arises during the training process.

3 years ago   •   7 min read

By Peter Foy

In our previous article on deep convolutional GANs we introduced a few components that improve the model's ability to generate images.

In this article, we'll discuss another issue that arises in GAN training: the model gets stuck generating the same thing over and over again. For example, a GAN trained on images of dogs of many different breeds might get stuck generating only golden retrievers.

This issue can occur when the discriminator improves but gets stuck classifying each image as either extremely fake or extremely real. The generator can then latch on to this feedback and produce only the images the discriminator rates as "extremely real", i.e. the golden retrievers in this example.

More specifically, this is linked to binary cross-entropy (BCE) loss, under which the discriminator is forced to output a value between 0 and 1. As the model "improves", the discriminator's outputs approach 0 or 1, which still leads to this undesirable outcome.

To address this, we'll look at a new loss function that lets the discriminator output any real number, from negative infinity to positive infinity. This can mitigate the problem above and allow both the generator and the discriminator to keep learning over time.

This article is based on notes from Week 3 of the first course in the Generative Adversarial Networks (GANs) Specialization and is organized as follows:

  • Mode Collapse
  • Vanishing Gradient and BCE Loss
  • Earth Mover's Distance
  • Wasserstein Loss Function
  • Condition on the Wasserstein Critic
  • Methods for Enforcing 1-Lipschitz Continuity

Mode Collapse

In this section and the next, we'll look at two issues faced by traditional GANs trained with binary cross-entropy loss: mode collapse and vanishing gradients.

We'll then look at a modification to the GAN architecture and a new loss function that can overcome these issues.

In a distribution of data, a mode is a region with a high concentration of observations.

In a normal distribution, for example, the mean value is the single mode of the distribution.

A distribution can also have multiple modes, in which case it is called bimodal (two modes) or multimodal, and the mean need not be one of them. You can find a visual example of the mode of a distribution here.

Intuitively, any peak in a probability distribution over a set of features is a mode of the distribution.

If we take the example of generating handwritten digits from 0 to 9, the probability density over two features $x_1$ and $x_2$ would have 10 modes, one for each digit.

Intuitively, mode collapse means the distribution of generated examples collapses to one or a few modes, while the other modes disappear altogether.

In the handwritten digit example, suppose the discriminator tends to misclassify images of 1's and 7's because they look similar. The generator may then pick up on this and produce only 1's and 7's, since they have a higher chance of fooling the discriminator.

This could continue until the generator collapses to a single mode and produces only 1's, because they have the highest chance of fooling the discriminator.

In summary:

  • Modes are peaks in the distribution of features.
  • Mode collapse occurs when the generator learns to fool the discriminator by producing examples from a single class of the training set; in other words, it gets stuck in a single mode.
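One simple way to see mode collapse in practice is to count how many modes a batch of generated samples covers. The sketch below assumes each generated digit has already been assigned a class label, for example by a separately trained digit classifier; the label lists and the helper `mode_coverage` are hypothetical.

```python
from collections import Counter


def mode_coverage(labels, num_modes=10):
    """Return the fraction of expected modes (e.g. digits 0-9) seen at least once,
    plus the per-class counts."""
    counts = Counter(labels)
    return len(counts) / num_modes, counts


# Hypothetical classifier labels for two batches of generated digits.
healthy = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 3, 7]
collapsed = [1, 1, 7, 1, 1, 1, 7, 1, 1, 1, 1, 1]

print(mode_coverage(healthy)[0])    # 1.0: all ten digits appear
print(mode_coverage(collapsed)[0])  # 0.2: only 1's and 7's appear
```

A coverage well below 1.0, combined with counts dominated by one class, is the signature of the collapse described above.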

Vanishing Gradient and BCE Loss

The traditional way to train GANs is with the binary cross-entropy loss, or BCE loss.

With BCE loss, however, training is prone to issues like mode collapse and vanishing gradients. In this section, we'll look at why BCE loss is susceptible to the vanishing gradient problem.

Recall that the BCE loss function is an average of the cost for the discriminator misclassifying real and fake observations.

The higher the cost function, the worse the discriminator is performing.

This means the generator wants to maximize this cost function and the discriminator wants to minimize it, a setup often referred to as the "minimax game".
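As a small illustration, the BCE cost can be computed directly from the discriminator's predictions, with real examples labelled 1 and generated examples labelled 0. The prediction values below are made up for the sketch; they only show that a discriminator that misclassifies more pays a higher cost.

```python
import math


def bce(predictions, labels, eps=1e-12):
    """Average binary cross-entropy; predictions are discriminator outputs in (0, 1)."""
    total = 0.0
    for p, y in zip(predictions, labels):
        p = min(max(p, eps), 1 - eps)  # clamp to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(predictions)


labels = [1, 1, 0, 0]             # two real examples, two generated examples
good_disc = [0.9, 0.8, 0.1, 0.2]  # mostly correct predictions
bad_disc = [0.2, 0.3, 0.7, 0.9]   # mostly wrong predictions

print(bce(good_disc, labels) < bce(bad_disc, labels))  # True
```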

Also recall that it is typically easier for the discriminator to determine if something is real or fake than it is for the generator to create a complex image. This means it's common for the discriminator to outperform the generator during training.

As the discriminator gets better during training, it can start to give less informative feedback to the generator.

In other words, the discriminator may provide gradients closer and closer to zero, which gives the generator very little information about how to improve.

This is how the vanishing gradient problem arises with the BCE loss.
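A small numeric sketch of this effect, assuming the generator minimizes $\log(1 - D(G(z)))$ as in the original minimax BCE formulation, and that the discriminator ends in a sigmoid $\sigma$ applied to a logit $l$. The derivative of $\log(1 - \sigma(l))$ with respect to $l$ is $-\sigma(l)$, so as the discriminator becomes confident that a generated example is fake ($\sigma(l) \to 0$), the gradient reaching the generator shrinks toward zero.

```python
import math


def sigmoid(l):
    return 1.0 / (1.0 + math.exp(-l))


def generator_grad(logit):
    """Analytic d/dl of log(1 - sigmoid(l)), the generator's loss w.r.t. the logit."""
    return -sigmoid(logit)


def numeric_grad(logit, h=1e-6):
    """Central-difference check of the same derivative."""
    def f(l):
        return math.log(1.0 - sigmoid(l))
    return (f(logit + h) - f(logit - h)) / (2 * h)


# A more negative logit means the discriminator is more confident the sample is fake.
for logit in [0.0, -2.0, -5.0, -10.0]:
    print(f"D(G(z))={sigmoid(logit):.5f}  gradient={generator_grad(logit):.5f}")
```

This is the saturating form of the generator loss; the original GAN paper proposed the non-saturating alternative of minimizing $-\log D(G(z))$ partly to avoid exactly this.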

Earth Mover's Distance

In this section we'll look at a different cost function for GANs called Earth Mover's Distance (EMD) that helps address the vanishing gradient problem found in BCE loss.

Suppose the real and generated distributions are both normal, with the same variance but different means.

The Earth Mover's Distance measures how different these two distributions are by estimating the amount of effort it takes to reshape the generated distribution into the real distribution.

This effort depends on both how much of the generated distribution needs to be moved and how far it needs to be moved.
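For one-dimensional samples of equal size, the EMD (also called the Wasserstein-1 distance) has a simple form: the optimal way to move the "earth" is to match the sorted values, so the distance is the average gap between them. The sketch below applies this to the example of two normal distributions with equal variance, where the EMD equals the gap between the means; the helper `emd_1d` and the chosen means are illustrative.

```python
import random


def emd_1d(a, b):
    """Earth Mover's Distance between two equal-size 1-D samples.

    In 1-D the optimal transport plan pairs the sorted values, so the cost is
    the average distance each unit of 'earth' has to travel.
    """
    assert len(a) == len(b)
    return sum(abs(x - y) for x, y in zip(sorted(a), sorted(b))) / len(a)


random.seed(0)
n = 20000
real = [random.gauss(3.0, 1.0) for _ in range(n)]  # real: mean 3, std 1
fake = [random.gauss(0.0, 1.0) for _ in range(n)]  # generated: mean 0, std 1
print(round(emd_1d(real, fake), 2))  # close to |3.0 - 0.0| = 3.0
```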

A key difference of the EMD cost function is that it has no "ceiling" like BCE loss, where the discriminator's output is confined between 0 and 1.

Instead, the cost keeps growing the further apart the distributions are.

This means its gradient doesn't approach zero even when the distributions are far apart, which makes the GAN less susceptible to the vanishing gradient problem and mode collapse.
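The toy comparison below illustrates this. When the discriminator is optimal, the original BCE-based GAN objective is known to reduce to a constant plus a multiple of the Jensen-Shannon (JS) divergence. The sketch places two point masses on a 1-D grid and moves them further apart: JS stays stuck at $\log 2$ however far apart they are, so it gives no signal about which way to move, while EMD grows with the distance. The grid size and helper names are illustrative.

```python
import math


def js_divergence(p, q):
    """Jensen-Shannon divergence between two discrete distributions."""
    m = [(pi + qi) / 2 for pi, qi in zip(p, q)]

    def kl(a, b):
        return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)

    return 0.5 * kl(p, m) + 0.5 * kl(q, m)


def emd_discrete(p, q):
    """EMD on the integer grid 0..K-1: sum of absolute differences of the CDFs."""
    total, cdf_p, cdf_q = 0.0, 0.0, 0.0
    for pi, qi in zip(p, q):
        cdf_p += pi
        cdf_q += qi
        total += abs(cdf_p - cdf_q)
    return total


def point_mass(position, size=10):
    return [1.0 if i == position else 0.0 for i in range(size)]


real = point_mass(0)
for shift in [1, 3, 6]:
    fake = point_mass(shift)
    print(shift, round(js_divergence(real, fake), 4), emd_discrete(real, fake))
```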

Wasserstein Loss Function

In this section, we'll look at the Wasserstein loss function, or W-Loss, which uses Earth Mover's Distance for training GANs.

W-Loss works by approximating the Earth Mover's Distance between the real and generated distributions.

The equation for the Wasserstein loss is shown below:

$$\min_g \max_c {\mathbb{E}}(c(x)) - {\mathbb{E}}(c(g(z)))$$

In this case, the function calculates the difference between the expected values of the critic's outputs on real and generated examples.

Here $c$ represents what's called the "critic": $c(x)$ on the left is its output for a real example $x$, and $c(g(z))$ on the right is its output for a generated example $g(z)$, where $g$ is the generator and $z$ is its input noise vector.

The critic wants to maximize the difference between its outputs on real and fake examples, whereas the generator wants to minimize this difference.
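A minimal sketch of how the two sides of this objective are usually turned into quantities to minimize, assuming the critic's scores for a batch are already computed. Since optimizers minimize, the critic's maximization is written as minimizing the negative; the function names and scores are illustrative.

```python
def mean(values):
    return sum(values) / len(values)


def critic_loss(critic_real, critic_fake):
    """The critic maximizes E[c(x)] - E[c(g(z))], i.e. minimizes the negative."""
    return mean(critic_fake) - mean(critic_real)


def generator_loss(critic_fake):
    """Only E[c(g(z))] depends on the generator, so minimizing the objective
    over g amounts to minimizing -E[c(g(z))]."""
    return -mean(critic_fake)


# Hypothetical critic scores: unbounded real numbers, with no sigmoid applied.
critic_real = [3.2, 2.8, 4.1]
critic_fake = [-1.5, 0.4, -2.2]
print(critic_loss(critic_real, critic_fake))  # about -4.47
print(generator_loss(critic_fake))            # about 1.1
```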

Recall that with BCE loss, the output of the discriminator is a prediction between 0 and 1, which is why it uses a sigmoid activation function in the output layer.

W-Loss doesn't have this requirement, so you can use a linear layer at the end of the discriminator's neural network, letting it produce any real value as output.

We now call the discriminator a critic since it is no longer classifying examples into two classes. Instead, it tries to maximize the difference between its evaluation of a real example and a fake example.

There is an additional condition for the Wasserstein cost function for it to work well, which we'll discuss in the next section.

Condition on the Wasserstein Critic

The continuity condition on the critic neural network is important to understand when using W-Loss to train GANs.

When training GANs with W-Loss, the critic has to satisfy a special condition: it needs to be 1-Lipschitz continuous, or 1-L continuous.

This condition says that the norm of the critic's gradient with respect to its input should be at most 1 at every point.

This condition is important for W-Loss because it keeps the critic's output from growing too quickly. It ensures that W-Loss is continuous and (almost everywhere) differentiable, that it is a valid approximation of Earth Mover's Distance, and that training remains stable.

The condition is imposed on the critic only, but since the generator learns from the critic's feedback, it helps keep training stable for both networks.

Methods for Enforcing 1-Lipschitz Continuity

In this section we'll look at two methods to enforce 1-L continuity of the critic neural network:

  • Weight clipping
  • Gradient penalty

Recall that for a function $f$ (here, the critic) to be 1-L continuous, the norm of its gradient must be at most 1 at every single point:

$$\|\nabla f(x)\|_2 \leq 1$$
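To make the inequality concrete, the sketch below spot-checks it numerically. It estimates $\|\nabla f(x)\|_2$ by central finite differences at random points and reports the largest value seen. The two test functions are hypothetical stand-ins for a critic, chosen so their gradients are known in closed form. A check like this can reveal a violation, but it cannot prove that a function is 1-L continuous.

```python
import math
import random


def grad_norm(f, x, h=1e-5):
    """Central-difference estimate of ||grad f(x)||_2 for f: R^n -> R."""
    sq = 0.0
    for i in range(len(x)):
        xp, xm = list(x), list(x)
        xp[i] += h
        xm[i] -= h
        sq += ((f(xp) - f(xm)) / (2 * h)) ** 2
    return math.sqrt(sq)


def max_grad_norm(f, dim, trials=1000, seed=0):
    """Largest gradient norm seen at random points in [-3, 3]^dim (a spot check only)."""
    rng = random.Random(seed)
    return max(grad_norm(f, [rng.uniform(-3, 3) for _ in range(dim)])
               for _ in range(trials))


def smooth(x):
    # Gradient (0.6 cos x1, 0.8 cos x2) has norm <= 1 everywhere: 1-L continuous.
    return 0.6 * math.sin(x[0]) + 0.8 * math.sin(x[1])


def steep(x):
    # Gradient (2, 0) has norm 2 everywhere: violates the 1-L condition.
    return 2.0 * x[0]


print(max_grad_norm(smooth, 2))  # at most 1
print(max_grad_norm(steep, 2))   # about 2
```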

Weight Clipping

Weight clipping tries to enforce 1-L continuity by forcing the weights of the critic to lie within a fixed interval, e.g. $[-0.01, 0.01]$ in the original WGAN paper.

Any weights outside of this interval are then clipped, or set to the maximum or minimum amount allowed.

The downside of this method is that clipping the weights too much limits the critic's ability to learn.
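A minimal sketch of weight clipping, assuming the critic's weights are stored as plain nested lists (one list per layer). The clip value $c = 0.01$ is the default used in the original WGAN paper; the function name and weights are illustrative. In PyTorch the same step is commonly written as a loop over `critic.parameters()` that calls `p.data.clamp_(-c, c)` after each critic update.

```python
def clip_weights(layers, c=0.01):
    """Clamp every critic weight into [-c, c] in place; applied after each critic update."""
    for layer in layers:
        for i, w in enumerate(layer):
            layer[i] = max(-c, min(c, w))
    return layers


# Hypothetical critic weights, one list per layer.
critic_weights = [[0.5, -0.003, 0.02], [-0.7, 0.009]]
clip_weights(critic_weights)
print(critic_weights)  # [[0.01, -0.003, 0.01], [-0.01, 0.009]]
```

The sketch also shows the downside mentioned above. Large weights such as 0.5 and -0.7 are squashed to the same boundary value, which throws away information the critic had learned.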

Gradient Penalty

The gradient penalty is another method to enforce 1-L continuity, which works by including a regularization term in the loss function:

$$\min_g \max_c \mathbb{E}(c(x)) - \mathbb{E}(c(g(z))) - \lambda\,\mathrm{reg}$$

Because the critic maximizes this objective, the regularization term is subtracted, so the critic is penalized when its gradient norm is greater than one.

We can't check the gradient at every possible point, so the penalty is evaluated at points sampled by randomly interpolating between real and fake examples. We call such a randomly interpolated image $\hat{x}$, and it's at these points that we want the critic's gradient norm to be at most 1.

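The regularization term is shown below:

$$(\|\nabla c(\hat{x})\|_2 - 1)^2$$

Because this term is squared, it also penalizes gradient norms smaller than 1, which in practice pushes the critic's gradient norm toward 1 at the sampled points.

The interpolated example $\hat{x}$ is defined as follows, where $\epsilon$ is a weight drawn uniformly at random from $[0, 1]$:

$$\hat{x} = \epsilon x + (1 - \epsilon)g(z)$$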
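A minimal sketch of the interpolation step, treating images as flat lists of pixel values. `interpolate` is an illustrative helper, not a library function. One $\epsilon$ is drawn per pair of real and generated examples, so every $\hat{x}$ lies on the straight line segment between them.

```python
import random


def interpolate(real, fake, rng=random):
    """Return x_hat = eps * x + (1 - eps) * g(z) and the eps used, with eps ~ U[0, 1)."""
    eps = rng.random()
    x_hat = [eps * r + (1 - eps) * f for r, f in zip(real, fake)]
    return x_hat, eps


rng = random.Random(42)
real_x = [1.0, 2.0, 3.0]  # a (tiny) real example x
fake_x = [0.0, 0.0, 0.0]  # a generated example g(z)
x_hat, eps = interpolate(real_x, fake_x, rng)
print(eps, x_hat)  # each entry of x_hat equals eps times the matching entry of real_x
```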

The complete expression for training a GAN with W-Loss and a gradient penalty is shown below:

$$\min_g \max_c \mathbb{E}(c(x)) - \mathbb{E}(c(g(z))) - \lambda\,\mathbb{E}\left((\|\nabla c(\hat{x})\|_2 - 1)^2\right)$$

The first part, $\mathbb{E}(c(x)) - \mathbb{E}(c(g(z)))$, is the main W-Loss component and estimates Earth Mover's Distance.

This makes the GAN less prone to mode collapse and vanishing gradients.

The penalty term, $\lambda\,\mathbb{E}\left((\|\nabla c(\hat{x})\|_2 - 1)^2\right)$, is the regularization term that pushes the critic toward being 1-L continuous, so that the loss function stays continuous and differentiable. It is subtracted from the objective the critic maximizes (equivalently, added to the loss the critic minimizes). The hyperparameter $\lambda$ controls its weight, and the paper that introduced the gradient penalty uses $\lambda = 10$.
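The sketch below puts the pieces together as a critic loss to minimize, i.e. the negated critic objective $\mathbb{E}(c(g(z))) - \mathbb{E}(c(x)) + \lambda\,\mathbb{E}\left((\|\nabla c(\hat{x})\|_2 - 1)^2\right)$. The critic is a hypothetical two-input toy function standing in for a neural network, and its input gradient is estimated with central finite differences. A deep learning framework would compute that gradient with automatic differentiation instead: in PyTorch, for example, `torch.autograd.grad` with `create_graph=True`, so that the penalty itself can be backpropagated. $\lambda = 10$ is used as the default.

```python
import math
import random


def critic(x, w=(1.5, -0.5)):
    """Hypothetical toy critic c(x) = w1 * tanh(x1) + w2 * x2 (stands in for a network)."""
    return w[0] * math.tanh(x[0]) + w[1] * x[1]


def input_grad_norm(f, x, h=1e-5):
    """||grad_x f(x)||_2 estimated by central differences (frameworks use autograd)."""
    sq = 0.0
    for i in range(len(x)):
        xp, xm = list(x), list(x)
        xp[i] += h
        xm[i] -= h
        sq += ((f(xp) - f(xm)) / (2 * h)) ** 2
    return math.sqrt(sq)


def wgan_gp_critic_loss(reals, fakes, lam=10.0, rng=random):
    """Critic loss to minimize: E[c(g(z))] - E[c(x)] + lam * E[(||grad c(x_hat)||_2 - 1)^2]."""
    n = len(reals)
    w_term = sum(critic(f) for f in fakes) / n - sum(critic(r) for r in reals) / n
    penalty = 0.0
    for r, f in zip(reals, fakes):
        eps = rng.random()  # one interpolation weight per (real, fake) pair
        x_hat = [eps * ri + (1 - eps) * fi for ri, fi in zip(r, f)]
        penalty += (input_grad_norm(critic, x_hat) - 1.0) ** 2
    gp = penalty / n
    return w_term + lam * gp, w_term, gp


rng = random.Random(0)
reals = [[rng.gauss(2.0, 1.0), rng.gauss(2.0, 1.0)] for _ in range(64)]
fakes = [[rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)] for _ in range(64)]
loss, w_term, gp = wgan_gp_critic_loss(reals, fakes, rng=rng)
print(f"loss={loss:.4f}  w_term={w_term:.4f}  gradient_penalty={gp:.4f}")
```

In a real training loop, this loss would be minimized with respect to the critic's parameters. The generator would then be updated separately to minimize $-\mathbb{E}(c(g(z)))$.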

Summary: Wasserstein GANs with Gradient Penalty

To summarize, the Wasserstein loss function helps address a common problem during GAN training: the generator getting stuck creating the same example over and over again.

To do this, W-Loss approximates the Earth Mover's Distance between the real and generated distributions.

In order to use W-Loss for training GANs, the critic must also satisfy a special condition called 1-Lipschitz continuity, or 1-L continuity, which says that the norm of its gradient should be at most 1 at every point. This condition can be enforced using two methods: weight clipping or a gradient penalty.

If you'd like to learn more, the original Wasserstein GAN paper can be found here, and the paper that proposes the gradient penalty as an alternative to weight clipping can be found here.

