
Why WGANs beat GANs: A journey from KL divergence to Wasserstein loss | by shashank kumar | Jan, 2023



In 2014, Ian Goodfellow introduced the GAN, or vanilla GAN as we call it today. Although impressive, it was notoriously hard to train: vanilla GANs suffered from an inability to converge, vanishing gradients, and mode collapse. Subsequently, a great deal of research centred on addressing these issues. Researchers experimented with different model architectures, loss functions, and training methodologies. A particularly effective solution was the Wasserstein GAN, introduced in 2017 by Arjovsky et al.

Photo by Arnold Francisca on Unsplash

This post attempts to explain why Wasserstein GANs function better than vanilla GANs. It assumes readers have some familiarity with the math behind GANs and VAEs and their training process. Let’s begin with an introduction to generative networks.

1. Generative Networks: A brief overview

Generative networks strive to spawn new samples that resemble the real data. They do so by mimicking the data distribution. Popular frameworks like GANs and VAEs approach this by learning a mapping G that transforms a known/assumed distribution Z into the actual data distribution. Generators in GANs and decoders in VAEs handle this job. The neural weights of these networks parametrise G.

Networks minimise the difference between the actual and generated data distributions to learn the mapping. Three popular measures to quantify this difference are:

  1. KL Divergence
  2. JS Divergence
  3. Earth-Mover (EM) or Wasserstein-1 distance

Let’s peek into them.

2. Why do KL and JS divergences fail?

We’ll briefly introduce KL and JS divergences and understand why they fail.

  1. Kullback-Leibler (KL) divergence is asymmetric: swapping the two distributions changes its value, so it can be computed in two ways, forward and reverse. For two distributions P and Q, with sums replaced by integrals in the continuous case, the forward KL divergence is as follows
Forward KL Divergence
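In symbols, for discrete distributions (replace the sum with an integral over densities in the continuous case):

D_{KL}(P \,\|\, Q) = \sum_{x} P(x) \log \frac{P(x)}{Q(x)}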

We can calculate the reverse KL divergence by using Q, instead of P, to weight the log difference of distributions. VAEs operate on reverse KL.

Reverse KL divergence
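In symbols:

D_{KL}(Q \,\|\, P) = \sum_{x} Q(x) \log \frac{Q(x)}{P(x)}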

2. Unlike the former, Jensen-Shannon (JS) divergence is symmetric. It is the average of the KL divergences of the two distributions from their mixture M = (P + Q)/2. It is not conspicuous from their loss function, binary cross-entropy, but GANs effectively minimise the JS divergence when the discriminator attains optimality. I urge you to read this blog to understand why.

JS Divergence
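In symbols, with the mixture M = (P + Q)/2:

D_{JS}(P \,\|\, Q) = \tfrac{1}{2} D_{KL}(P \,\|\, M) + \tfrac{1}{2} D_{KL}(Q \,\|\, M)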

Arjovsky et al. used a simple example to show the pitfalls of KL and JS divergence. Consider two probability distributions described by the two parallel lines in the image below. Red is the actual distribution (P), while green is its estimate (Q). θ is the horizontal distance between them.

Image by author

A generative model would work to shift the green line closer to the red one, which is fixed at x = 0. Can you calculate the JS and KL divergences when θ = 1?

KL Divergences
JS Divergence
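Because the two lines share no common support at θ = 1, the standard results (worked out in the WGAN paper) are:

D_{KL}(P \,\|\, Q) = D_{KL}(Q \,\|\, P) = +\infty, \qquad D_{JS}(P \,\|\, Q) = \log 2

Both KL divergences blow up because each distribution places mass where the other has none, and the JS divergence sits at its maximum value.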

Now, how would these measures vary as a function of θ? If you observe closely, they stay constant for every θ ≠ 0 and change only when θ reaches 0.

I observe two main drawbacks with these measures.

  1. The difference between the distributions at θ = 0.1 and θ = 1 is the same. At θ = 0.1, the green line is much closer to the red one, so the measured difference should be smaller.
  2. θ can be loosely viewed as standing in for the estimated distribution. If our generative model is parametrised by φ, then θ is a function f(φ). During back-propagation, we calculate the gradient of the loss function, defined by one of the above measures, with respect to φ to tune the parameters. By the chain rule below, the first factor, the gradient of the loss with respect to θ, is always 0 for θ ≠ 0. Hence, owing to zero gradients, our poor model will not learn anything.
Back-propagation using chain rule
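Written out, the chain rule in question is:

\frac{\partial L}{\partial \phi} = \frac{\partial L}{\partial \theta} \cdot \frac{\partial \theta}{\partial \phi}

and the first factor is zero for every θ ≠ 0, since the divergence is constant there.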

From the above example, we can draw two crucial inferences.

  1. KL and JS divergences do not account for the horizontal distance between two probability distributions.
  2. In situations where the predicted and actual distributions don’t overlap, which is common during training, the gradients can be 0, leading to no learning.

So, how does the Wasserstein distance address these issues?

3. Why is the Wasserstein distance a better measure?

Frankly, the formal definition of the EM/Wasserstein-1 distance looks intimidating, so I’ll refrain from its mathematical details. Instead, let’s understand it intuitively.

EM/Wasserstein-1
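For completeness, the definition is:

W(P, Q) = \inf_{\gamma \in \Pi(P, Q)} \mathbb{E}_{(x, y) \sim \gamma} \big[ \lVert x - y \rVert \big]

where Π(P, Q) is the set of all joint distributions (transport plans) whose marginals are P and Q.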

Let’s go back to parallel lines. This time, four of them. The red lines constitute one probability distribution, and the blue ones another. The numbers on top are the probability masses at the corresponding points (x = 0, 1, 2, 3). We intend to modify the blue distribution so that it aligns with the red one. To do so, we’ll shift the probability masses as follows.

Image by author
  1. Shift 0.2 from x=2 to x=0
  2. Shift 0.1 from x=3 to x=0
  3. Shift remaining 0.7 from x=3 to x=1

This, however, is just one amongst many possible transport plans. We could instead

  1. Shift 0.1 from x=2 to x=0
  2. Shift 0.1 from x=2 to x=1
  3. Shift 0.2 from x=3 to x=0
  4. Shift 0.6 from x=3 to x=0

Of the two plans, which is optimal? To ascertain that, let’s borrow the notion of work from physics. Here, we define work as mass multiplied by the distance it is shifted. So, the work done by the two transport schemes is 2.1 and 2.7 respectively, as spelled out below.

Image by author
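Multiplying each shifted mass by the distance it travels, as listed above:

Plan 1: 0.2 × 2 + 0.1 × 3 + 0.7 × 2 = 0.4 + 0.3 + 1.4 = 2.1
Plan 2: 0.1 × 2 + 0.1 × 1 + 0.2 × 3 + 0.6 × 3 = 0.2 + 0.1 + 0.6 + 1.8 = 2.7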

The EM distance is thus 2.1, the work involved in the optimal transport plan. The same idea applies in the continuous realm: probability densities replace the point masses, and integration replaces summation.

In a nutshell, the intimidating equation above computes the work required by every transport plan that morphs one distribution into the other and selects the minimum (strictly, the infimum) amongst them. As opposed to the other measures, the EM distance accounts for the horizontal difference between two distributions while preserving their overall shape. Moreover, it also does away with the problem of vanishing gradients.

Reconsider the parallel-lines example in section 2. The minimum cost of aligning the two distributions is θ, the horizontal distance between the lines. Hence, we get stable gradients to tune the parameters even when the predicted and actual distributions don’t overlap.
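For that example, the distance works out to:

W(P, Q) = |\theta|, \qquad \frac{\partial W}{\partial \theta} = \pm 1 \;\; \text{for } \theta \neq 0

so the gradient with respect to θ neither vanishes nor explodes.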

Image by author

4. Using Python to visualize all measures

Now, let’s compute all the discussed measures in Python and visualize their behaviour.

import numpy as np
from scipy.stats import norm
from matplotlib import pyplot as plt
from scipy.stats import wasserstein_distance

We’ll define functions to calculate the KL and JS divergences; thankfully, for the EM distance, we can use SciPy’s wasserstein_distance.

def kld(d1, d2):
    # Forward KL divergence of d1 from d2 (unnormalised sum over grid values);
    # assumes both inputs are strictly positive
    return np.sum(d1 * np.log(d1 / d2))

def jsd(d1, d2):
    # JS divergence: average KL of each input from their mixture
    return 0.5 * kld(d1, (d1 + d2) / 2) + 0.5 * kld(d2, (d1 + d2) / 2)
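As a quick sanity check, here is a toy call on two made-up discrete distributions (the arrays are purely illustrative):

# two arbitrary two-point distributions
p_toy = np.array([0.4, 0.6])
q_toy = np.array([0.5, 0.5])
print(kld(p_toy, q_toy), jsd(p_toy, q_toy))  # small positive values; both are 0 when the inputs match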

Next, we’ll define two normal distributions and visualize how the measures change as the Gaussians are pulled apart, i.e., as their mean difference increases.

x = np.arange(-40, 40, 0.001)              # common support for both densities
q_mean_range = (2, 20)                     # range of means for the shifted gaussian Q
fkl, bkl, js, em = [], [], [], []
p = norm.pdf(x, 0, 2)                      # fixed gaussian P: mean 0, std 2
for i in range(*q_mean_range):
    q = norm.pdf(x, i, 2)                  # shifted gaussian Q: mean i, std 2
    fkl.append(kld(p, q))
    bkl.append(kld(q, p))
    js.append(jsd(p, q))
    # p and q act as weights over the shared support x
    em.append(wasserstein_distance(x, x, p, q))
means = np.arange(*q_mean_range, 1)
f, ax = plt.subplots(2, 2, figsize=(10, 10))
ax = ax.flatten()
for i, (vals, name) in enumerate(zip([fkl, bkl, js, em], ['FWD KL', 'REV KL', 'JS Divergence', 'EM Distance'])):
    ax[i].plot(means, vals)
    ax[i].set_xlabel('Mean difference')
    ax[i].set_title(name)
plt.show()
Plot of all measures vs mean difference

What do you note? As the mean difference increases, the KL divergences explode while the JS divergence saturates. The EM distance, however, increases linearly. Thus, amongst the four, the EM distance seems to be the best option for maintaining gradient flow during training.

The EM/Wasserstein distance also alleviates mode collapse. Roughly speaking, mode collapse happens when the generator learns to produce a particular mode that fools a discriminator stuck in a local minimum. As discussed in section 2, a GAN whose discriminator has settled at an optimum operates on the JS divergence, which leads to zero gradients. Consequently, the discriminator stays trapped, and the generator is disincentivised from creating varied samples.

With the Wasserstein distance, the discriminator can attain optimality without its gradients vanishing. It can escape local minima and keep rejecting poor generator outputs, impelling the generator not to overfit to a particular discriminator and to produce multiple modes.

Conclusion

With this, we come to the end of the post. I hope it summarised why Wasserstein GANs outperform vanilla GANs and VAEs. I’ve glossed over some mathematical intricacies surrounding the EM distance; if you’re interested, you should read the paper. Note that the EM distance equation discussed above is intractable, so a mathematical trick is applied to approximate it. You’re unlikely to ever need the derivation in practice, but if you’re itching to know, you can read about the approximation here or in the paper.
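For the curious, the trick used in the WGAN paper is the Kantorovich-Rubinstein duality, which rewrites the distance as a supremum over 1-Lipschitz functions:

W(P_r, P_g) = \sup_{\lVert f \rVert_L \le 1} \mathbb{E}_{x \sim P_r}[f(x)] - \mathbb{E}_{x \sim P_g}[f(x)]

The critic network plays the role of f, and the original paper enforces the Lipschitz constraint by clipping its weights.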

