Table of contents


Note: Jeff Donahue, Philipp Krähenbühl and Trevor Darrell at Berkeley published a paper independently from us on the same idea, which they call Bidirectional GAN, or BiGAN. You should also check out their awesome work at https://arxiv.org/abs/1605.09782.

What is ALI?

The adversarially learned inference (ALI) model is a deep directed generative model which jointly learns a generation network and an inference network using an adversarial process. This model constitutes a novel approach to integrating efficient inference with the generative adversarial networks (GAN) framework.

What makes ALI unique is that unlike other approaches to learning inference in deep directed generative models (like variational autoencoders (VAEs)), the objective function involves no explicit reconstruction loop. Instead of focusing on achieving a pixel-perfect reconstruction, ALI tends to produce believable reconstructions with interesting variations, albeit at the expense of making some mistakes in capturing exact object placement, color, style and (in extreme cases) object identity. This is a good thing, because 1) capacity is not wasted to model trivial factors of variation in the input, and 2) the learned features are more or less invariant to these trivial factors of variation, which is what is expected of good feature learning.

These strenghts are showcased via the semi-supervised learning tasks on SVHN and CIFAR10, where ALI achieves a performance competitive with state-of-the-art.

Quick introduction to GANs

GAN pits two neural networks against each other: a generator network \(G(\mathbf{z})\), and a discriminator network \(D(\mathbf{x})\).

The generator tries to mimic examples from a training dataset, which is sampled from the true data distribution \(q(\mathbf{x})\). It does so by transforming a random source of noise received as input into a synthetic sample.

The discriminator receives a sample, but it is not told where the sample comes from. Its job is to predict whether it is a data sample or a synthetic sample.

The discriminator is trained to make accurate predictions, and the generator is trained to output samples that fool the discriminator into thinking they came from the data distribution.

Generative adversarial networks

GAN: general idea

To better understand GANs, it is necessary to describe it using a probabilistic framework. Two marginal distributions are defined:

The generator operates by sampling \(\mathbf{z} \sim p(\mathbf{z})\) and then sampling \(\mathbf{x} \sim p(\mathbf{x} \mid \mathbf{z})\). The discriminator receives either the generator sample or \(\mathbf{x} \sim q(\mathbf{x})\) and outputs the probability that its input was sampled from \(q(\mathbf{x})\).

Generative adversarial networks

GAN: probabilistic view

The adversarial game played between the discriminator and the generator is formalized by the following value function:

On one hand, the discriminator is trained to maximize the probability of correctly classifying data samples and synthetic samples. On the other hand, the generator is trained to produce samples that fool the discriminator, i.e. that are unlikely to be synthetic according to the discriminator.

It can be shown that for a fixed generator, the optimal discriminator is

and that given an optimal discriminator, minimizing the value function with respect to the generator parameters is equivalent to minimizing the Jensen-Shannon divergence between \(p(\mathbf{x})\) and \(q(\mathbf{x})\).

In other words, as training progresses, the generator produces synthetic samples that look more and more like the training data.

Learning inference

Even though GANs are pretty good at producing realistic-looking synthetic samples, they lack something very important: the ability to do inference.

Inference can loosely be defined as the answer to the following question:

Given \(\mathbf{x}\), what \(\mathbf{z}\) is likely to have produced it?

This question is exactly what ALI is equipped to answer.

ALI augments GAN’s generator with an additional network. This network receives a data sample as input and produces a synthetic \(\mathbf{z}\) as output.

Expressed in probabilistic terms, ALI defines two joint distributions:

ALI also modifies the discriminator’s goal. Rather than examining \(\mathbf{x}\) samples marginally, it now receives joint pairs \((\mathbf{x}, \mathbf{z})\) as input and must predict whether they come from the encoder joint or the decoder joint.

Like before, the generator is trained to fool the discriminator, but this time it can also learn \(q(\mathbf{z} \mid \mathbf{x})\).

Adversarially learned inference

ALI: probabilistic view

The adversarial game played between the discriminator and the generator is formalized by the following value function:

In analogy to GAN, it can be shown that for a fixed generator, the optimal discriminator is

and that given an optimal discriminator, minimizing the value function with respect to the generator parameters is equivalent to minimizing the Jensen-Shannon divergence between \(p(\mathbf{x}, \mathbf{z})\) and \(q(\mathbf{x}, \mathbf{z})\).

Matching the joints also has the effect of matching the marginals (i.e., \(p(\mathbf{x}) \sim q(\mathbf{x})\) and \(p(\mathbf{z}) \sim q(\mathbf{z})\)) as well as the conditionals / posteriors (i.e., \(p(\mathbf{z} \mid \mathbf{x}) \sim q(\mathbf{z} \mid \mathbf{x})\) and \(q(\mathbf{x} \mid \mathbf{z}) \sim p(\mathbf{x} \mid \mathbf{z})\)).

Experimental results

Regarding reconstructions: odd columns are validation set examples, even columns are their corresponding reconstructions (e.g., first column contains validation set examples, second column contains their corresponding reconstruction).

CIFAR10

The CIFAR10 dataset contains 60,000 32x32 colour images in 10 classes.

Samples
Reconstructions

SVHN

SVHN is a dataset of digit images obtained from house numbers in Google Street View images. It contains over 600,000 labeled examples.

Samples
Reconstructions

CelebA

CelebA is a dataset of celebrity faces with 40 attribute annotations. It contains over 200,000 labeled examples.

Samples
Reconstructions

Tiny ImageNet

The Tiny Imagenet dataset is a version of the ILSVRC2012 dataset that has been center-cropped and downsampled to \(64 \times 64\) pixels. It contains over 1,200,000 labeled examples.

Samples
Reconstructions

Latent space interpolations

As a sanity check for overfitting, we look at latent space interpolations between CelebA validation set examples. We sample pairs of validation set examples \(\mathbf{x}_1\) and \(\mathbf{x}_2\) and project them into \(\mathbf{z}_1\) and \(\mathbf{z}_2\) by sampling from the encoder. We then linearly interpolate between \(\mathbf{z}_1\) and \(\mathbf{z}_2\) and pass the intermediary points through the decoder to plot the input-space interpolations.

Latent space interpolations

We observe smooth transitions between pairs of example, and intermediary images remain believable. This is an indicator that ALI is not concentrating its probability mass exclusively around training examples, but rather has learned latent features that generalize well.

Conditional generation

ALI can also be used for conditional generative modeling by providing the encoder, decoder and discriminator networks with a conditioning variable \(\mathbf{y}\).

We apply the conditional version of ALI to CelebA using the dataset’s 40 binary attributes. The attributes are linearly embedded in the encoder, decoder and discriminator. We observe how a single element of the latent space z changes with respect to variations in the attributes vector \(\mathbf{y}\).

Conditional generation

The row attributes are

The column attributes are

Semi-supervised learning

We investigate the usefulness of the latent representation learned by ALI through semi-supervised benchmarks on SVHN and CIFAR10.

SVHN

Number of labeled examples 1000
VAE (M1 + M2) 1 36.02
SWWAE with dropout 2 23.56
DCGAN + L2-SVM 3 22.18
SDGM 4 16.61
GAN (feature matching) [^7] 8.11 ± 1.3
ALI (ours, L2-SVM) 19.14 ± 0.50
ALI (ours, no feature matching) 7.42 ± 0.65

CIFAR10

Number of labeled examples 1000 2000 4000 8000
Ladder network 5     20.40  
CatGAN [^6]     19.58  
GAN (feature matching) [^7] 21.83 ± 2.01 19.61 ± 2.09 18.63 ± 2.32 17.72 ± 1.82
ALI (ours, no feature matching) 19.98 ± 0.89 19.09 ± 0.44 17.99 ± 1.62 17.05 ± 1.49

Discussion

We first compare with GAN on SVHN by following the procedure outlined in Radford et al. (2015) 3. We train an L2-SVM on the learned representations of a model trained on SVHN. The last three hidden layers of the encoder as well as its output are concatenated to form a 8960-dimensional feature vector. A 10,000 example held-out validation set is taken from the training set and is used for model selection. The SVM is trained on 1000 examples taken at random from the remainder of the training set. The test error rate is measured for 100 different SVMs trained on different random 1000-example training sets, and the average error rate is measured along with its standard deviation.

Using ALI’s inference network as opposed to the discriminator to extract features, we achieve a misclassification rate that is roughly 3.00 ± 0.50% lower than reported in Radford et al. (2015) 3, which suggests that ALI’s inference mechanism is beneficial to the semi-supervised learning task.

We then investigate ALI’s performance when label information is taken into account during training. We adapt the discriminative model proposed in Salimans et al. (2016) [^7]. The discriminator takes \(x\) and \(z\) as input and outputs a distribution over \(K + 1\) classes, where \(K\) is the number of categories. When label information is available for \(q(x, z)\) samples, the discriminator is expected to predict the label. When no label information is available, the discriminator is expected to predict \(K + 1\) for \(p(x, z)\) samples and \(k \in \{1,\ldots,K\}\) for \(q(x,z)\) samples.

Interestingly, Salimans et al. (2016) [^7] found that they required an alternative training strategy for the generator where it tries to match first-order statistics in the discriminator’s intermediate activations with respect to the data distribution (they refer to this as feature matching). We found that ALI did not require feature matching to achieve comparable results.

We are still investigating the differences between ALI and GAN with respect to feature matching, but we conjecture that the latent representation learned by ALI is better untangled with respect to the classification task and that it generalizes better.

Comparison with GAN on a toy dataset

Toy dataset

To highlight the role of the inference network during learning, we performed an experiment on a toy dataset for which \(q(x)\) is a 2D gaussian mixture with 25 mixture components laid out on a grid. The covariance matrices and centroids have been chosen such that the distribution exhibits lots of modes separated by large low-probability regions, which makes it a decently hard task despite the 2D nature of the dataset.

We trained ALI and GAN on 100,000 \(q(x)\) samples. The decoder and discriminator architectures are identical between ALI and GAN (except for the input of the discriminator, which receives the concatenation of x and z in the ALI case). Each model was trained 10 times using Adam with random learning rate and \(\beta_1\) values, and the weights were initialized by drawing from a gaussian distribution with a random standard deviation.

We measured the extent to which the trained models covered all 25 modes by drawing 10,000 samples from their \(p(x)\) distribution and assigning each sample to a \(q(x)\) mixture component according to the mixture responsibilities. We defined a dropped mode as one that wasn’t assigned to any sample. Using this definition, we found that ALI models covered 13.4 ± 5.8 modes on average (min: 8, max: 25) while GAN models covered 10.4 ± 9.2 modes on average (min: 1, max: 22).

We then selected the best-covering ALI and GAN models, and the GAN model was augmented with an encoder using the following procedures:

The encoders learned for GAN inference have the same architecture as ALI’s encoder. We also trained a VAE with the same encoder-decoder architecture as ALI to outline the qualitative differences between ALI and VAE models.

We then compared each model’s inference capabilities by reconstructing 10,000 held-out samples from \(q(x)\) (see figure above). We observe the following:

In summary, this experiment provides evidence that adversarial training benefits from learning an inference mechanism jointly with the decoder. Furthermore, it shows that our proposed approach for learning inference in an adversarial setting is superior to the other approaches investigated.

Conclusion

The adversarially learned inference (ALI) model jointly learns a generation network and an inference network using an adversarial process. The model learns mutually coherent inference and generation networks, as exhibited by its reconstructions. The induced latent variable mapping is shown to be useful, achieving results competitive with the state-of-the-art on the semi-supervised SVHN and CIFAR10 tasks.


Processing Systems_. [^6]: Springenberg, J. T. (2015). Unsupervised and semi-supervised learning with categorical generative adversarial networks. arXiv preprint arXiv:1511.06390. [^7]: Salimans, T., Goodfellow, I. J., Zaremba, W., Cheung, V., Radford, A., and Chen, X. (2016). Improved techniques for training GANs. arXiv preprint arXiv:1606.03498.

  1. Kingma, D. P., Mohamed, S., Rezende, D. J., and Welling, M. (2014). Semi-supervised learning with deep generative models. In Advances in Neural Information Processing Systems, pages 3581–3589. 

  2. Zhao, J., Mathieu, M., Goroshin, R., and Lecun, Y. (2015). Stacked what-where auto-encoders. arXiv preprint arXiv:1506.02351

  3. Radford, A., Metz, L., and Chintala, S. (2015). Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434 2 3

  4. Maaløe, L., Sønderby, C. K., Sønderby, S. K., and Winther, O. (2016). Auxiliary deep generative models. arXiv preprint arXiv:1602.05473

  5. Rasmus, A., Valpola, H., Honkala, M., Berglund, M., and Raiko, T. (2015). Semi-supervised learning with ladder network. In _Advances in Neural Information