GAN Discriminators Don't Generalize?

post by tryactions · 2020-06-08T20:36:08.069Z · 7 comments


Disclaimer: I just started reading about GANs, so am almost certainly missing some context here.

Something that surprised me from the BigGAN paper:

We also observe that D’s loss approaches zero during training, but undergoes a sharp upward jump at collapse (Appendix F). One possible explanation for this behavior is that D is overfitting to the training set, memorizing training examples rather than learning some meaningful boundary between real and generated images. As a simple test for D’s memorization (related to Gulrajani et al. (2017)), we evaluate uncollapsed discriminators on the ImageNet training and validation sets, and measure what percentage of samples are classified as real or generated. While the training accuracy is consistently above 98%, the validation accuracy falls in the range of 50-55%, no better than random guessing (regardless of regularization strategy). This confirms that D is indeed memorizing the training set; we deem this in line with D’s role, which is not explicitly to generalize, but to distill the training data and provide a useful learning signal for G.

I'm not sure how to interpret this. The validation accuracy being close to 50% seems strange -- if the discriminator has 'memorized' the training set and has only seen training set vs generated images, why would it not guess close to 0% on things in the test set? Presumably they are both 1. not-memorized and 2. not optimized to fool the discriminator like generated images are. Maybe the post title is misleading, and we should think of this as "discriminators generalize surprisingly well despite also 'memorizing' the training data." (EDIT: See the comment thread below for clarification.)
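(As a concrete picture of the check described in the quote, here is a minimal sketch of how one might measure the "percentage of real samples classified as real" -- it assumes a trained PyTorch discriminator that outputs a single real/fake logit per image, plus dataloaders over the training and validation sets; all names are placeholders, not the BigGAN authors' code.)

```python
import torch

@torch.no_grad()
def fraction_classified_real(discriminator, loader, device="cuda"):
    """Fraction of (real) images in `loader` that the discriminator scores as 'real' (logit > 0)."""
    discriminator.eval()
    n_real, n_total = 0, 0
    for images, _ in loader:
        logits = discriminator(images.to(device))  # one real/fake logit per image
        n_real += (logits > 0).sum().item()
        n_total += images.size(0)
    return n_real / n_total

# Hypothetical usage: a memorizing D would score ~98% of training images as real
# but only ~50-55% of held-out validation images.
# train_acc = fraction_classified_real(D, train_loader)
# val_acc   = fraction_classified_real(D, val_loader)
```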

Note that the discriminator has far fewer parameters than there are bytes to memorize, so it necessarily is performing some sort of (lossy) compression to do well on the training set. Could we think of the generator as succeeding by exploiting patterns in the discriminator's compression, which the discriminator then works to obfuscate? I would expect more obfuscation to put additional demands on the discriminator's capacity. Maybe good generator task performance then comes from defeating simpler compression schemes, and it so happens that simple compression schemes are exactly what our visual system and GAN metrics are measuring.

Does this indicate that datasets are still just too small? Later in the same paper, they train on the much larger JFT-300M dataset (as opposed to ImageNet above) and mention:

Interestingly, unlike models trained on ImageNet, where training tends to collapse without heavy regularization (Section 4), the models trained on JFT-300M remain stable over many hundreds of thousands of iterations. This suggests that moving beyond ImageNet to larger datasets may partially alleviate GAN stability issues.

They don't mention whether this also increases discriminator generalization or decreases training set accuracy, which I'd be interested to know. I'd also be interested in connecting this story to mode collapse somehow.

7 comments


comment by gwern · 2020-06-09T02:23:59.573Z

These are good questions, and they touch on some of the points that suggest we don't really understand what GANs do or why they work. I've previously highlighted them in my writeups: https://www.gwern.net/Faces#discriminator-ranking * & https://github.com/tensorfork/tensorfork/issues/28 respectively.

The D memorization is particularly puzzling when you look at improvements to GANs: most recently, BigGAN got (fixed) data augmentation & SimCLR losses. One can understand why spatial distortions & SimCLR might help D under the naive theory that D learns realism and structure of real images to penalize errors by G, but then how do we explain chance guessing on ImageNet validation...?

Further, how do we explain the JFT-300M stability, given that it seems unlikely that D is 'memorizing datapoints' when the batch sizes suggest the JFT-300M runs in question ran for only 4 or 5 epochs at most? (mooch generally runs minibatches of at most n=2048, so even 500k iterations is only ~3.4 epochs: 500,000 iterations × 2,048 images ≈ 1.02B images seen, against 300M images in JFT-300M.)

Note that the discriminator has far fewer parameters than there are bytes to memorize, so it necessarily is performing some sort of (lossy) compression to do well on the training set.

Eh. "compression" is not a helpful concept here because every single generative model trained in any way is "compressing". (Someone once put up a website for using GPT-2 as a text compressor, because any model that emits likelihoods conditional on a history can be plugged into an arithmetic encoder and is immediately a lossless compressor/decompressor.)
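(To make that concrete, a toy sketch of the likelihood-to-bits correspondence an arithmetic coder exploits: the coder spends about -log2 p(token | history) bits per token, so summing the model's per-token log-probabilities bounds the compressed size. Purely illustrative; the list of log-probabilities is a placeholder, not any particular model's API.)

```python
import math

def compressed_size_bits(token_log_probs):
    """Approximate size (in bits) an arithmetic coder would need, given the model's
    natural-log probability of each token conditional on its history."""
    return sum(-lp / math.log(2) for lp in token_log_probs)

# Toy example: a model assigning p=0.25 to each of 100 tokens
# yields ~2 bits/token, i.e. about 200 bits total.
print(compressed_size_bits([math.log(0.25)] * 100))  # ~200.0
```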

Based on some other papers I don't have handy now, I've hand-waved that perhaps what a GAN's D does is it learns fuzzy patterns in image-space 'around' each real datapoint, and G spirals around each point, trying to approach it and collapse down to emitting the exact datapoint, but is repelled by D; as training progresses, D repels G from increasingly smaller regions around each datapoint. Because G spends its time traversing the image manifold and neural networks are biased towards simplicity, G inadvertently learns a generalizable generative model, even though it 'wants' to do nothing but memorize & spit out the original data (as the most certain Nash equilibrium way to defeat the D - obviously, D cannot possibly discriminate beyond 50-50 if given two identical copies of a real image). This is similar to the view of decision forests and neural networks as adaptive nearest-neighbor interpolators.

They don’t mention whether this also increases discriminator generalization or decreases training set accuracy, which I’d be interested to know.

mooch is pretty good about answering questions. You can ask him on Twitter. (I would bet the answer is probably that the equivalent test was not done on the JFT-300M models. His writeup is very thorough and I would expect him to have mentioned it if that had been done; in general, my impression is that the JFT-300M runs were done with very little time to spare and not nearly as thoroughly, since he spent all his time trying tweaks on BigGAN to get it to work at all.)

* One caveat I haven't had time to update my writeup with: I found that D ranking worked in a weird way which I interpreted as consistent with D memorization; however, I was recently informed that I had implemented it wrong and it works much better when fixed; but on the gripping hand, they find that the D ranking still doesn't really match up with 'realism' so maybe my error didn't matter too much.

comment by tryactions · 2020-06-09T16:20:55.738Z

Thanks for sharing thoughts and links: discriminator ranking, SimCLR, CR, and BCR are all interesting and I hadn't run into them yet. My naive thought was that you'd have to use differentiable augmenters to fit in generator augmentation.

You can ask him on Twitter.

I'm averse to using Twitter, but I will consider being motivated enough to sign up and ask. Thanks for pointing this out.

"compression" is not a helpful concept here because every single generative model trained in any way is "compressing"

I am definitely using this concept too vaguely, although I was gesturing at compression in the discriminator instead of the generator. Thinking of the discriminator as a lossy compressor in this way would be... positing a mapping f: discriminator weights -> distributions, which for trained weights does not fully recapture the training distribution? We could see G as attempting to match this imperfect distribution (since it doesn't directly receive the training examples), and D as modifying weights to simultaneously 1. try to capture the training distribution as f(D), and 2. try to have f(D) avoid the output of G. That's why I was thinking D might be "obfuscating" -- in this picture, I think f(D) is pressured to be a more complicated manifold while sticking close to the training distribution, making it more difficult for G to fit it.

Is such an f implicit in the discriminator outputs? I think it is, just by normalizing D's outputs across the whole space, although that's computationally infeasible. I'd be interested in work that attempts to recover the training distribution from D alone.
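(For what it's worth, there is a standard idealized version of such an f, under assumptions much stronger than anything the BigGAN paper claims: for a fixed generator with density p_G, the Bayes-optimal discriminator satisfies the relation below, so in principle the data density can be read off from D's outputs.)

```latex
% Idealized assumption: D is the Bayes-optimal discriminator for a fixed generator density p_G.
D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_G(x)}
\quad\Longrightarrow\quad
p_{\text{data}}(x) = p_G(x)\,\frac{D^*(x)}{1 - D^*(x)}
```

A real D is presumably nowhere near optimal away from the training and generated images, which would be one reason the recovered f(D) fails to fully recapture the training distribution.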

I think it's decently likely I'm confused here.

comment by gwern · 2020-06-09T21:54:56.958Z

My naive thought was that you'd have to use differentiable augmenters to fit in generator augmentation.

I believe the data augmentations in question are all differentiable, so you can backprop from the augmented images to G. (Which is not to say they are easy: the reason that Zhao et al 2020 came out before we got SimCLR working on our own BigGAN is that lucidrains & Shawn Presser got SimCLR working - we think - except it only works on GPUs, which we don't have enough of to train BigGAN on, and TPU CPUs, where it memory-leaks. Very frustrating, especially now that Zhao shows that SimCLR would have worked for us.)
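(For concreteness, a toy sketch of what "differentiable augmentation" means here -- in the spirit of Zhao et al 2020 but not their actual code: the augmentation is built from gradient-preserving tensor ops and applied to both real and generated batches before D, so D's gradient flows back through the augmentation into G. Everything below is a placeholder.)

```python
import torch

def diff_augment(x):
    """Toy differentiable augmentation: random horizontal flip + random per-image
    brightness shift, both expressed as gradient-preserving tensor ops on (N, C, H, W)."""
    if torch.rand(1).item() < 0.5:
        x = torch.flip(x, dims=[3])  # flip along the width dimension
    x = x + (torch.rand(x.size(0), 1, 1, 1, device=x.device) - 0.5) * 0.2
    return x

# Sketch of a training step (D, G, real_batch, z are placeholders):
# d_real = D(diff_augment(real_batch))
# d_fake = D(diff_augment(G(z)))  # gradients reach G through the augmentation
```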

I'm averse to using Twitter, but I will consider being motivated enough to sign-up and ask.

I assume he has email; he also hangs out on our Discord and answers questions from time to time.

I think it's decently likely I'm confused here.

It's definitely a confusing topic. Most GAN researchers seem to sort of shrug and... something something the Nash equilibrium minimizes the Jensen–Shannon divergence something something converges with decreasing learning rate in the limit, well, it works in practice, OK? Nothing like likelihood or VAE or flow-based models, that's for sure. (On the other hand, nobody's ever trained those on something like JFT-300M, and the compute requirements for something like OpenAI Jukebox are hilarious - what is it, 17 hours on a V100 to generate a minute of audio?)

comment by seed · 2020-06-09T18:59:18.185Z

Getting a validation accuracy of 50% in a binary classification task isn't "surprisingly well". It means your model is as good as random guessing: if you flipped a coin, you would get the right answer half the time, too. Getting 0% validation accuracy would mean that you are always guessing wrong, and would get 100% accurate results by reversing your model's prediction. So, yes, just like the article says, the discriminator does not generalize.

comment by tryactions · 2020-06-09T19:32:09.751Z

Yes, I understand this point. I was saying that we'd expect it to get 0% if its algorithm is "guess yes for anything in the training set and no for anything outside of it".

It continues to be surprising (to me) even though we expect that it's trying to follow that algorithm but can't do so exactly. Presumably the generator is able to emulate whatever features the discriminator is using to inexactly match the training set. In this case, if those features were "looks like something from the training/test distribution", we'd expect it to guess closer to 100% on the test set. If those features were highly specific to the training set, we'd expect it to get closer to 0% on the test set (since the model should reject anything without those features). Instead it gets ~50%, which means whatever it's looking for is completely uncorrelated with what the test data looks like and present in about half of the examples -- that seems surprising to me.

I'd currently interpret this as "the discriminator network acts nonsensically outside the training set + generator distribution, so it gets close to chance just because that's what nonsensical networks do."

comment by seed · 2020-06-09T19:59:10.879Z

Oh, I see, sorry.

comment by tryactions · 2020-06-09T20:17:11.817Z

No worries, was worth clarifying. I edited the post to link this comment thread.