My impression of singular learning theory

post by Ege Erdil (ege-erdil) · 2023-06-18T15:34:27.249Z · 30 comments

Disclaimer: I'm by no means an expert on singular learning theory and what I present below is a simplification that experts might not endorse. Still, I think it might be more comprehensible for a general audience than going into digressions about blowing up singularities and birational invariants.

Here is my current understanding of what singular learning theory is about in a simplified (though perhaps more realistic?) discrete setting.

Suppose you represent a neural network architecture as a map $A: 2^N \to F$, where $N \in \mathbb{N}$ and $2^N$ (the set of $N$-bit strings) is the set of all possible parameters of $A$ (seen as floating point numbers, say), and $F$ is the set of all possible computable functions from the input and output space you're considering. In thermodynamic terms, we could identify elements of $2^N$ as "microstates" and the corresponding functions that the NN architecture maps them to as "macrostates".

Furthermore, suppose that $F$ comes together with a loss function $L: F \to \mathbb{R}$ evaluating how good or bad a particular function is. Assume you optimize $L$ using something like stochastic gradient descent on the function $L \circ A$ with a particular learning rate.

Then, in general, we have the following results:

  1. SGD defines a Markov chain structure on the space $2^N$ whose stationary distribution is proportional to $e^{-\beta L(A(p))}$ on parameters $p$, for some positive constant $\beta > 0$ that depends on the learning rate. This is just a basic fact about the Langevin dynamics that SGD would induce in such a system.
  2. In general $A$ is not injective, and we can define the "$A$-complexity" of any function $f \in \operatorname{Im}(A) \subset F$ as $c(f) = N \log 2 - \log |A^{-1}(f)|$. Then, the probability that we arrive at the macrostate $f$ is going to be proportional to $e^{-c(f) - \beta L(f)}$, since it is the sum of the Boltzmann weights of the $|A^{-1}(f)| = 2^N e^{-c(f)}$ parameters implementing $f$.
  3. When $L$ is some kind of negative log-likelihood, this approximates Solomonoff induction in a tempered Bayes paradigm (we raise likelihood ratios to a power $\beta$), insofar as the $A$-complexity $c(f)$ is a good approximation for the Kolmogorov complexity of the function $f$, which will happen if the function approximator defined by $A$ is sufficiently well-behaved.
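Claims (1) and (2) can be checked by brute force on a toy example. The sketch below assumes a hypothetical 6-bit "architecture" `A` (a single threshold unit whose two weights and bias are each the sum of two parameter bits), a hypothetical target function (logical OR), a hypothetical loss (fraction of misclassified inputs) and an arbitrary $\beta$. It computes the macrostate probabilities twice, once by summing the Boltzmann weights $e^{-\beta L(A(p))}$ over microstates $p$ and once from the closed form $e^{-c(f) - \beta L(f)}$, and the two agree. It does not simulate SGD; it only illustrates the counting step.

```python
import itertools
import math
from collections import Counter

N = 6        # number of parameter bits (hypothetical toy size)
BETA = 3.0   # inverse temperature standing in for the learning-rate-dependent constant
INPUTS = list(itertools.product([0, 1], repeat=2))
TARGET = tuple(int(x1 or x2) for x1, x2 in INPUTS)  # hypothetical target: OR

def A(p):
    """Hypothetical toy architecture: 6 bits -> truth table of a threshold unit."""
    w1, w2, b = p[0] + p[1], p[2] + p[3], p[4] + p[5]
    return tuple(int(w1 * x1 + w2 * x2 > b) for x1, x2 in INPUTS)

def loss(f):
    """Fraction of inputs on which f disagrees with the target."""
    return sum(a != t for a, t in zip(f, TARGET)) / len(TARGET)

params = list(itertools.product([0, 1], repeat=N))
preimage_sizes = Counter(A(p) for p in params)  # |A^{-1}(f)| for every f in Im(A)

def complexity(f):
    """A-complexity c(f) = N log 2 - log |A^{-1}(f)|, in nats (divide by log 2 for bits)."""
    return N * math.log(2) - math.log(preimage_sizes[f])

# Route 1: Boltzmann distribution over microstates, summed into macrostates.
weights = {p: math.exp(-BETA * loss(A(p))) for p in params}
Z = sum(weights.values())
p_macro_direct = Counter()
for p, w in weights.items():
    p_macro_direct[A(p)] += w / Z

# Route 2: closed form, P(f) proportional to exp(-c(f) - beta * L(f)).
unnormalized = {f: math.exp(-complexity(f) - BETA * loss(f)) for f in preimage_sizes}
Zf = sum(unnormalized.values())
p_macro_formula = {f: u / Zf for f, u in unnormalized.items()}

for f in sorted(preimage_sizes, key=complexity):
    print(f, preimage_sizes[f], round(complexity(f), 3), round(p_macro_formula[f], 4))
```

Because `A` is far from injective (64 parameter settings, only a handful of distinct truth tables), functions with large preimages, such as the constant zero function, get a low $c(f)$.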

The intuition for why we would expect (3) to be true in practice has to do with the nature of the function approximator $A$. When $c(f)$ is small, it probably means that we only need a small number of bits of information on top of the definition of $A$ itself to define $f$, because "many" of the possible parameter values for $A$ are implementing the function $f$. So $f$ is probably a simple function.

On the other hand, if $f$ is a simple function and $A$ is sufficiently flexible as a function approximator, we can probably implement the functionality of $f$ using only a small number of the $N$ bits in the domain of $A$, which leaves us the rest of the bits to vary as we wish. This makes $|A^{-1}(f)|$ quite large, and by extension the complexity $c(f)$ quite small.

The vague concept of "flexibility" mentioned in the paragraph above requires $A$ to have singularities of many effective dimensions, as this is just another way of saying that the image of $A$ has to contain functions with a wide range of $A$-complexities. If $A$ is a one-to-one function, this clean version of the theory no longer works, though if $A$ is still "close" to being singular (for instance, because many of the functions in its image are very similar) then we can still recover results like the one I mentioned above. The basic insights remain the same in this setting.

I'm wondering what singular learning theory experts have to say about this simplification of their theory. Is this explanation missing some important details that are visible in the full theory? Does the full theory make some predictions that this simplified story does not make?

30 comments

comment by Daniel Murfet (dmurfet) · 2023-06-18T16:44:07.166Z

I think this is a very nice way to present the key ideas. However, in practice I think the discretisation is actually harder to reason about than the continuous version. There are deeper problems, but I'd start by wondering how you would ever compute c(f) defined this way, since it seems to depend in an intricate way on the details of e.g. the floating point implementation.

I'll note that the volume codimension definition of the RLCT is essentially what you have written down here, and you don't need any mathematics beyond calculus to write that down. You only need things like resolutions of singularities if you actually want to compute that value, and the discretisation doesn't seem to offer any advantage there.

Replies from: ege-erdil
comment by Ege Erdil (ege-erdil) · 2023-06-18T16:58:11.268Z

I think this is a very nice way to present the key ideas. However, in practice I think the discretisation is actually harder to reason about than the continuous version. There are deeper problems, but I'd start by wondering how you would ever compute c(f) defined this way, since it seems to depend in an intricate way on the details of e.g. the floating point implementation.

I would say that the discretization is going to be easier for people with a computer science background to grasp, even though formally I agree it's going to be less pleasant to reason about or to do computations with. Still, if properties of NNs that only appeared when they are continuous functions on $\mathbb{R}^N$ were essential for their generalization, we might be in trouble as people keep lowering the precision of their floating point numbers. This explanation makes it clear that while assuming NNs are continuous (or even analytic!) might be useful for theoretical purposes, the claims about generalization hold just as well in a more realistic discrete setting.

I'll note that the volume codimension definition of the RLCT is essentially what you have written down here, and you don't need any mathematics beyond calculus to write that down. You only need things like resolutions of singularities if you actually want to compute that value, and the discretisation doesn't seem to offer any advantage there.

Yes, my definition is inspired by the volume codimension definition, though here we don't need to take a limit as some $\varepsilon \to 0$ because the counting measure makes our life easy. The problem you have in a smooth setting is that descending the Lebesgue measure in a dumb way to subspaces with positive codimension gives trivial results, so more care is necessary to recover and reason about the appropriate notions of volume.
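With a counting measure, the volume codimension idea can be illustrated directly: count lattice points in the sublevel set $\{K < \varepsilon\}$ at two thresholds and read off the slope of $\log V(\varepsilon)$ against $\log \varepsilon$. The sketch assumes two hypothetical two-parameter examples on $[-1, 1]^2$ with a lattice of mesh 0.002: a regular minimum $K = x^2 + y^2$ (exponent 1, i.e. half the dimension) and $K = (xy)^2$, whose minimum set is two crossing lines (RLCT $1/2$, with a logarithmic correction). The singular slope comes out well below 1 but only drifts toward $1/2$ slowly, and a fixed lattice stops resolving the set once it becomes thinner than the mesh, which is one way to see the scale dependence discussed in this thread.

```python
import numpy as np

grid = np.linspace(-1.0, 1.0, 1001)  # lattice with mesh 0.002 on [-1, 1] (hypothetical choice)
X, Y = np.meshgrid(grid, grid)

K_regular = X ** 2 + Y ** 2    # isolated, non-degenerate minimum at the origin
K_singular = (X * Y) ** 2      # minimum set is the two coordinate axes

def count_below(K, eps):
    """Counting-measure 'volume' of the sublevel set {K < eps}."""
    return np.count_nonzero(K < eps)

def volume_exponent(K, eps_big, eps_small):
    """Slope of log V(eps) against log eps between two thresholds."""
    ratio = count_below(K, eps_small) / count_below(K, eps_big)
    return float(np.log(ratio) / np.log(eps_small / eps_big))

lam_regular = volume_exponent(K_regular, 1e-2, 1e-3)
lam_singular = volume_exponent(K_singular, 1e-2, 1e-4)
print(round(lam_regular, 3), round(lam_singular, 3))
```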

comment by interstice · 2023-06-18T16:02:08.319Z

None of this is specific to singular learning theory. The basic idea that the parameter-function map might be degenerate and biased towards simple functions predates SLT (at least this most recent wave of interest in its application to neural nets anyway) and indeed goes back to the 90s, no algebraic geometry required. As far as I can tell, the non-trivial content of SLT is that the averaging over parameters with a given loss is dominated by singular points in the limit because volume clusters there as you take an ever-narrower interval around the minimum set. That's interesting, but I don't have a strong expectation it will end up being applicable to real neural nets since I don't see a mechanism by which SGD is supposed to be attracted to such points (I can see why SGD would be attracted to broad basins generally, but that's not SLT-specific; the SLT-specific part is attraction to weird points where many broad basins intersect).

Replies from: ege-erdil
comment by Ege Erdil (ege-erdil) · 2023-06-18T16:08:43.373Z

None of this is specific to singular learning theory. The basic idea that the parameter-function map might be degenerate and biased towards simple functions predates SLT (at least this most recent wave of interest in its application to neural nets anyway) and indeed goes back to the 90s, no algebraic geometry required.

Sure, I'm aware that people have expressed these ideas before, but I have trouble understanding what is added by singular learning theory on top of this description. To me, much of singular learning theory looks like trying to do these kinds of calculations in an analytic setting where things become quite a bit more complicated, for example because you no longer have the basic counting function to measure the effective dimensionality of a singularity, forcing you to reach for concepts like "real log canonical threshold" instead.

As far as I can tell, the non-trivial content of SLT is that the averaging over parameters with a given loss is dominated by singular points in the limit because volume clusters there as you take an ever-narrower interval around the minimum set.

I'm not sure why we should expect that beyond the argument I already give in the post. The geometry of the loss landscape is already fully accounted for by the Boltzmann factor; what else does singular learning theory add here?

Maybe this is also what you're confused about when you say "I don't see a mechanism by which SGD is supposed to be attracted to such points".

Replies from: interstice
comment by interstice · 2023-06-18T16:16:43.734Z

I’m not sure why we should expect that beyond the argument I already give in the post. The geometry of the loss landscape is already fully accounted for by the Boltzmann factor; what else does singular learning theory add here?

So I believe the point of SLT is that the Boltzmann-weighted integral over the state space simplifies in certain settings as the number of data points approaches infinity. That integral is going to be dominated by a narrow 'band' around the minimum set, and to evaluate it generally you have to consider the entire minimum set. But when there are singularities, places where there are cusps or intersections of the minimum set, the narrow band's effective dimensionality can go up (this is illustrated in the tweet I linked). This means that in this limit you can just consider the behavior near the 'cuspiest' singularity (I think this is what the RLCT measures) to understand the whole integral.

(...uh, I think. I actually haven't looked into the details enough to write with confidence, but the above is my impression from what reading I have done and Jesse's tweet)

Replies from: ege-erdil
comment by Ege Erdil (ege-erdil) · 2023-06-18T16:21:23.905Z

To me that just sounds like you're saying the integral is dominated by the contribution of the simplest functions that are of minimum loss, and the contribution factor scales like where is the effective dimensionality near the singularity representing this function, equivalently the complexity of said function. That's exactly what I'm saying in my post - where is the added content here?

Replies from: tgb, interstice
comment by tgb · 2023-06-19T11:38:07.404Z

Here's a concrete toy example where SLT and this post give different answers (SLT is more specific). Let .  And let . Then the minimal loss is achieved at the set of parameters where  or  (note that this looks like two intersecting lines, with the singularity being the intersection). Note that all  in that set also give the same exact . The theory in your post here doesn't say much beyond the standard point that gradient descent will (likely) select a minimal or near-minimal , but it can't distinguish between the different values of  within that minimal set.

SLT on the other hand says that gradient descent will be more likely to choose the specific singular value  .

Now I'm not sure this example is sufficiently realistic to demonstrate why you would care about SLT's extra specificity, since in this case I'm perfectly happy with any value of  in the minimal set - they all give the exact same . If I were to try to generalize this into a useful example, I would try to find a case where  has a minimal set that contains multiple different . For example,  only evaluates  on a subset of points (the 'training data') but different choices of minimal  give different values outside of that subset of training data. Then we can consider which  has the best generalization to out-of-training data - do the parameters predicted by SLT yield  that are best at generalizing?

Disclaimer: I have a very rudimentary understanding of SLT and may be misrepresenting it.

Replies from: ege-erdil
comment by Ege Erdil (ege-erdil) · 2023-06-19T12:27:47.689Z

I don't think this representation of the theory in my post is correct. The effective dimension of the singularity near the origin is much higher, e.g. because near every other minimal point of this loss function the Hessian doesn't vanish, while for the singularity at the origin it does vanish. If you discretized this setup by looking at it with a lattice of mesh $\varepsilon$, say, you would notice that the origin is surrounded by many parameters that give nearly identical loss, while near other parts of the space there are far fewer such parameters.

The reason you have to do some kind of "translation" between the two theories is that SLT can see not just exactly optimal points but also nearly optimal points, and bad singularities are surrounded by many more nearly optimal points than better-behaved singularities. You can interpret the discretized picture above as the SLT picture seen at some "resolution" or "scale" $\varepsilon$, i.e. if you discretized the loss function by evaluating it on a lattice with mesh $\varepsilon$ you get my picture. Of course, this loses the information of what happens as $\varepsilon \to 0$ and the dataset size goes to infinity in some thermodynamic limit, which is what you recover when you do SLT.

I just don't see what this thermodynamic limit tells you about the learning behavior of NNs that we didn't know before. We already know NNs approximate Solomonoff induction if the $A$-complexity is a good approximation to Kolmogorov complexity and so forth. What additional information is gained by knowing what $L$ looks like as a smooth function as opposed to a discrete function?

In addition, the strong dependence of SLT on $L$ being analytic is bad, because analytic functions are rigid: their value in a small open subset determines their value globally. I can see why you need this assumption because quantifying what happens near a singularity becomes incredibly difficult for general smooth functions, but because of the rigidity of analytic functions the approximation that "we can just pretend NNs are analytic" is more pernicious than e.g. "we can just pretend NNs are smooth". Typical approximation theorems like Stone-Weierstrass also fail to save you because they only work in the sup-norm, and that's completely useless for determining behavior at singularities. So I'm yet to be convinced that the additional details in SLT provide a more useful account of NN learning than my simple description above.

Replies from: tgb
comment by tgb · 2023-06-19T15:45:13.481Z

The effective dimension of the singularity near the origin is much higher, e.g. because near every other minimal point of this loss function the Hessian doesn't vanish, while for the singularity at the origin it does vanish. If you discretized this setup by looking at it with a lattice of mesh $\varepsilon$, say, you would notice that the origin is surrounded by many parameters that give nearly identical loss, while near other parts of the space the number of such parameters is far fewer.

As I read it, the arguments you make in the original post depend only on the macrostate, which is the same for both the singular and non-singular points of the minimal loss set (in my example), so they can't distinguish these points at all. I see that you're also applying the logic to points near the minimal set and arguing that the nearly-optimal points are more abundant near the singularities than near the non-singularities. I think that's a significant point not made at all in your original post that brings it closer to SLT, so I'd encourage you to add it to the post.

I think there's also terminology mismatch between your post and SLT. You refer to singularities of (i.e. its derivative is degenerate) while SLT refers to singularities of the set of minimal loss parameters. The point  in my example is not singular at all in SLT but  is singular. This terminology collision makes it sound like you've recreated SLT more than you actually have.

Replies from: ege-erdil
comment by Ege Erdil (ege-erdil) · 2023-06-19T17:07:45.861Z

I'm not too sure how to respond to this comment because it seems like you're not understanding what I'm trying to say.

I agree there's some terminology mismatch, but this is inevitable because SLT is a continuous model and my model is discrete. If you want to translate between them, you need to imagine discretizing SLT, which means you discretize both the codomain of the neural network and the space of functions you're trying to represent in some suitable way. If you do this, then you'll notice that the worse a singularity is, the lower the $A$-complexity of the corresponding discrete function will turn out to be, because many of the neighbors map to the same function after discretization.

The content that SLT adds on top of this is what happens in the limit where your discretization becomes infinitely fine and your dataset becomes infinitely large, but your model doesn't become infinitely large. In this case, SLT claims that the worst singularities dominate the equilibrium behavior of SGD, which I agree is an accurate claim. However, I'm not sure what this claim is supposed to tell us about how NNs learn. I can't make any novel predictions about NNs with this knowledge that I couldn't before.

Replies from: interstice, tgb
comment by interstice · 2023-06-19T21:31:46.922Z

In this case, SLT claims that the worst singularities dominate the equilibrium behavior of SGD, which I agree is an accurate claim. However, I'm not sure what this claim is supposed to tell us about how NNs learn

I think the implied claim is something like "analyzing the singularities of the model will also be helpful for understanding SGD in more realistic settings" or maybe just "investigating this area further will lead to insights which are applicable in more realistic settings". I mostly don't buy it myself.

comment by tgb · 2023-06-19T22:44:02.338Z

the worse a singularity is, the lower the $A$-complexity of the corresponding discrete function will turn out to be

This is where we diverge. Please let me know where you think my error is in the following. Returning to my explicit example (though I wrote  originally but will instead use  in this post since that matches your definitions).

1. Let   be the constant zero function and  

2. Observe that  is the minimal loss set under our loss function and also  is the set of parameters  where  or .

3. Let  . Then  by definition of . Therefore, 

4. SLT says that  is a singularity of  but that  is not a singularity.

5. Therefore, there exists a singularity (according to SLT) which has identical $A$-complexity (and also loss) as a non-singular point, contradicting your statement I quote.

Replies from: ege-erdil
comment by Ege Erdil (ege-erdil) · 2023-06-20T08:52:57.733Z

You need to discretize the function before taking preimages. If you just take preimages in the continuous setting, of course you're not going to see any of the interesting behavior SLT is capturing.

In your case, let's say that we discretize the function space by choosing which one of the functions you're closest to for some . In addition, we also discretize the codomain of by looking at the lattice for some . Now, you'll notice that there's a radius disk around the origin which contains only functions mapping to the zero function, and as our lattice has fundamental area this means the "relative weight" of the singularity at the origin is like .

In contrast, all other points mapping to the zero function only get a relative weight of where is the absolute value of their nonzero coordinate. Cutting off the domain somewhere to make it compact and summing over all to exclude the disk at the origin gives for the total contribution of all the other points in the minimum loss set. So in the limit the singularity at the origin accounts for almost everything in the preimage of . The origin is privileged in my picture just as it is in the SLT picture.
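The concentration claim can be checked numerically on a hypothetical stand-in for the two-intersecting-lines example: $K(x, y) = (xy)^2$ on $[-1, 1]^2$, whose minimum set is the two coordinate axes crossing at the origin. Uniform random samples play the role of the counting measure here (an assumption; a fixed lattice would start to alias once the near-optimal band is thinner than the mesh). For $\sqrt{\varepsilon} < r^2$, the share of the near-optimal set $\{K < \varepsilon\}$ inside a box of half-width $r$ around the origin works out to roughly $\frac{1 + \log(r^2/\sqrt{\varepsilon})}{1 + \log(1/\sqrt{\varepsilon})}$, which tends to 1, but only logarithmically.

```python
import numpy as np

rng = np.random.default_rng(0)
samples = rng.uniform(-1.0, 1.0, size=(2_000_000, 2))
x, y = samples[:, 0], samples[:, 1]
K = (x * y) ** 2                                      # hypothetical loss; minimum set = both axes
near_origin = (np.abs(x) < 0.2) & (np.abs(y) < 0.2)   # box of half-width r = 0.2 around the crossing

def fraction_near_origin(eps):
    """Share of the near-optimal set {K < eps} that lies in the box around the origin."""
    near_optimal = K < eps
    return np.count_nonzero(near_optimal & near_origin) / np.count_nonzero(near_optimal)

fractions = {eps: fraction_near_origin(eps) for eps in (1e-2, 1e-4, 1e-6, 1e-8)}
for eps, frac in fractions.items():
    print(f"eps={eps:.0e}  fraction near origin={frac:.3f}")
```

The box covers only 4% of the parameter square, yet it holds a growing majority of the near-optimal parameters as $\varepsilon$ shrinks.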

I think your mistake is that you're trying to translate between these two models too literally, when you should be thinking of my model as a discretization of the SLT model. Because it's a discretization at a particular scale, it doesn't capture what happens as the scale is changing. That's the main shortcoming relative to SLT, but it's not clear to me how important capturing this thermodynamic-like limit is to begin with.

Again, maybe I'm misrepresenting the actual content of SLT here, but it's not clear to me what SLT says aside from this, so...

Replies from: tgb
comment by tgb · 2023-06-20T11:37:51.956Z

Everything I wrote in steps 1-4 was done in a discrete setting (otherwise the preimage $A^{-1}(f)$ is not finite and the whole thing falls apart). I was intending the parameter space to be pairs of floating point numbers and $F$ to be functions from floats to floats.

However, using that, I think I see what you're trying to say: the product of the two parameters will equal zero in some cases where both are non-zero but very small, because they multiply down to zero due to the limits of floating point numbers. Therefore the preimage of the zero function is actually larger than I claimed, and specifically contains a small neighborhood of the origin.

That doesn't invalidate my calculation that shows that the singular point at the origin is equally likely as the non-singular point, though: they still have the same loss and $A$-complexity (since they have the same macrostate). On the other hand, you're saying that there are points in parameter space that are very close to the origin that are also in this same preimage and also equally likely. Therefore, even if the origin is just as likely as the non-singular point, being near the origin is more likely than being near the non-singular point. I think it's fair to say that this is at least qualitatively the same as what SLT gives in the continuous version of this.

However, I do think this result "happened" due to factors that weren't discussed in your original post, which makes it sound like it is "due to" $A$-complexity. $A$-complexity is a function of the macrostate, which is the same at all of these points and so does not distinguish between the singular and non-singular points at all. In other words, your post tells me which macrostate is likely while SLT tells me which microstate is likely; these are not the same thing. But you clearly have additional ideas not stated in the post that also help you figure out which microstate is likely. Until that is clarified, I think you have a mental theory of this which is very different from what you wrote.

Replies from: ege-erdil
comment by Ege Erdil (ege-erdil) · 2023-06-20T12:39:40.040Z

Sure, I agree that I didn't put this information into the post. However, why do you need to know which microstate is more likely to know anything about e.g. how neural networks generalize?

I understand that SLT has some additional content beyond what is in the post, and I've tried to explain how you could make that fit in this framework. I just don't understand why that additional content is relevant, which is why I left it out.

As an additional note, I wasn't really talking about floating point precision being the important variable here. I'm just saying that if you want $A$-complexity to match the notion of real log canonical threshold, you have to discretize SLT in a way that might not be obvious at first glance, and in a way where some conclusions end up being scale-dependent. This is why, if you're interested in studying this question of the relative contribution of singular points to the partition function, SLT is a better setting to be doing it in. At the risk of repeating myself, I just don't know why you would try to do that.

Replies from: tgb
comment by tgb · 2023-06-20T13:51:39.887Z

In my view, it's a significant philosophical difference between SLT and your post that your post talks only about choosing macrostates while SLT talks about choosing microstates. I'm much less qualified to know (let alone explain) the benefits of SLT, though I can speculate. If we stop training after a finite number of steps, then I think it's helpful to know where it's converging to. In my example, if you think it's converging to the non-singular point, then stopping close to that will get you a function that doesn't generalize too well. If you know it's converging to the origin, then stopping close to that will get you a much better function, possibly exactly equally as good, as you pointed out, due to discretization.

Now this logic is basically exactly what you're saying in these comments! But I think if someone read your post without prior knowledge of SLT, they wouldn't figure out that it's more likely to converge to a point near the origin than near the non-singular point. If they read an SLT post instead, they would figure that out. In that sense, SLT is more useful.

I am not confident that that is the intended benefit of SLT according to its proponents, though. And I wouldn't be surprised if you could write a simpler explanation of this in your framework than SLT gives, I just think that this post wasn't it.

comment by interstice · 2023-06-18T16:32:02.465Z

I'm explaining why singularities (places where the minimum-loss set has self-intersections) would also tend to have higher effective dimensionality (number of degrees of freedom which you can vary while obtaining similar loss). That's what's novel about SLT as compared with previous broad-basin theories.

Replies from: ege-erdil
comment by Ege Erdil (ege-erdil) · 2023-06-18T16:46:20.043Z

I don't think this is something that requires explanation, though. If you take an arbitrary geometric object in maths, a good definition of its singular points will be "points where the tangent space has higher dimension than expected". If this is the minimum set of a loss function and the tangent space has higher dimension than expected, that intuitively means that locally there are more directions you can move along without changing the loss function, probably suggesting that there are more directions you can move along without changing the function being implemented at all. So the function being implemented is simple, and the rest of the argument works as I outline it in the post.
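The "higher-dimensional tangent space" criterion can be made concrete for a set cut out by one equation $g = 0$, where the (Zariski) tangent space at a point is the kernel of the Jacobian of $g$. The sketch below assumes the hypothetical example $g(x, y) = xy$, whose zero set (the two coordinate axes) is the minimum set of $L = (xy)^2$. One equation in two variables should give dimension 1; the origin is singular because the tangent space there is 2-dimensional.

```python
import numpy as np

def jacobian(x, y):
    """Jacobian (here just the gradient) of the hypothetical defining equation g(x, y) = x * y."""
    return np.array([[y, x]], dtype=float)

def tangent_space_dim(x, y):
    """Dimension of the Zariski tangent space of {g = 0} at (x, y): 2 minus the Jacobian's rank."""
    return 2 - int(np.linalg.matrix_rank(jacobian(x, y)))

EXPECTED_DIM = 1  # one equation in two variables
for point in [(1.0, 0.0), (0.0, -0.5), (0.0, 0.0)]:
    dim = tangent_space_dim(*point)
    print(point, dim, "singular" if dim > EXPECTED_DIM else "smooth")
```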

I think I understand what you and Jesse are getting at, though: there's a particular behavior that only becomes visible in the smooth or analytic setting, which is that minima of the loss function that are more singular become more dominant as $\beta \to \infty$ in the Boltzmann integral, as opposed to maintaining just the same dominance factor of $e^{-c(f)}$. You don't see this in the discrete case because there's a finite nonzero gap in loss between first-best and second-best fits, and so the second-best fits are exponentially punished in the limit and become irrelevant, while in the singular case any first-best fit has some second-best "space" surrounding it whose volume is more concentrated towards the singularity point.
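The discrete half of this contrast fits in a few lines. The sketch uses three made-up macrostates: two exact minimizers with preimage sizes 40 and 10, and a second-best fit with 1000 preimages and a loss gap of 0.1. With $P(f) \propto |A^{-1}(f)| \, e^{-\beta L(f)}$, the second-best fit dominates at small $\beta$ but is suppressed like $e^{-\beta \cdot 0.1}$ as $\beta$ grows, while the ratio between the two minimizers stays at $e^{-(c(f_1) - c(f_2))} = 40/10 = 4$ for every $\beta$.

```python
import math

# Hypothetical macrostates: preimage sizes |A^{-1}(f)| and losses L(f).
counts = {"f_simple": 40, "f_complex": 10, "f_second_best": 1000}
losses = {"f_simple": 0.0, "f_complex": 0.0, "f_second_best": 0.1}

def macrostate_probs(beta):
    """P(f) proportional to |A^{-1}(f)| * exp(-beta * L(f)), i.e. exp(-c(f) - beta * L(f))."""
    weights = {f: counts[f] * math.exp(-beta * losses[f]) for f in counts}
    Z = sum(weights.values())
    return {f: w / Z for f, w in weights.items()}

for beta in (1, 10, 100, 1000):
    probs = macrostate_probs(beta)
    ratio = probs["f_simple"] / probs["f_complex"]
    print(beta, round(ratio, 6), probs["f_second_best"])
```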

While I understand that, I'm not too sure what predictions you would make about the behavior of neural networks on the basis of this observation. For instance, if this smooth behavior is really essential to the generalization of NNs, wouldn't we predict that generalization would become worse as people switch to lower precision floating point numbers? I don't think that prediction would have held up very well if someone had made it 5 years ago.

Replies from: interstice
comment by interstice · 2023-06-18T17:08:32.186Z

If this is the minimum set of a loss function and the tangent space has higher dimension than expected, that intuitively means that locally there are more directions you can move along without changing the loss function

I think it is pretty obvious in the case of valleys without self-intersections, but that's just the broad basin case. As for the self-intersection case, well, if it's obvious to you that singularities will be surrounded by narrow bands of larger dimensionality -- including in cases where that "dimensionality" is fractional -- then you have a better intuition for the geometry of singularities than me and, I suspect, most other readers, so it might be helpful to make that aspect explicit.

Replies from: ege-erdil
comment by Ege Erdil (ege-erdil) · 2023-06-18T17:46:40.583Z

Say that you have a loss function $L: \mathbb{R}^n \to \mathbb{R}$. The minimum loss set is probably not exactly $\{w : \nabla L(w) = 0\}$, but it has something to do with that, so let's pretend that it's exactly that for now.

This is a collection of $n$ equations that are generically independent and so should define a subset of dimension zero, i.e. a collection of points in $\mathbb{R}^n$. However, there might be points at which the partial derivatives vanishing don't define independent equations, so we get something of positive dimension.

In these cases, what happens is that the gradient itself has vanishing derivatives in some directions. In other words, the Hessian matrix fails to be of full rank. Say that this matrix has rank $r$ at a specific singular point $p$ and consider the set $\{w : L(w) - L(p) < \varepsilon\}$. Diagonalizing the Hessian $\nabla^2 L(p)$ will generically bring $L$ into a form where it's the linear combination of $r$ quadratic terms and higher-order cubic terms, and locally the volume contribution to this set around $p$ will be something of order $\varepsilon^{r/2 + (n-r)/3}$. The worse the singularity, the smaller the rank $r$ and the greater the volume contribution of the singularity to the set $\{w : L(w) - L(p) < \varepsilon\}$.

The worst singularities dominate the behavior at small $\varepsilon$ because you can move "much further" along vectors where $L$ scales in a cubic fashion than directions where it scales in a quadratic fashion, so those dimensions are the only ones that "count" in some calculation when you compare singularities. The tangent space intuition doesn't apply directly here, but something like it still applies, in the sense that the worse a singularity, the more directions you have to move away from it without changing the value of the loss very much.
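The rank heuristic can be checked on a hypothetical local model in $n = 3$ dimensions: $L$ grows quadratically in $r$ coordinates and like $|w_i|^3$ in the other $n - r$ (absolute values keep $L$ nonnegative; this is a toy stand-in for "higher-order terms", not a general RLCT computation). Rescaling the quadratic coordinates by $\varepsilon^{1/2}$ and the cubic ones by $\varepsilon^{1/3}$ shows $\operatorname{vol}\{L < \varepsilon\} \propto \varepsilon^{r/2 + (n-r)/3}$ exactly for this model, so lower rank means a smaller exponent and more volume as $\varepsilon \to 0$. The Monte Carlo sketch below recovers those exponents.

```python
import numpy as np

def volume_exponent(r, n=3, eps=(1e-1, 1e-2), samples=1_000_000, seed=0):
    """Estimate the exponent of vol{L < eps} for the hypothetical local model
    L(w) = sum of r squared coordinates + sum of (n - r) |coordinate|^3 terms."""
    rng = np.random.default_rng(seed)
    w = rng.uniform(-1.0, 1.0, size=(samples, n))
    L = (w[:, :r] ** 2).sum(axis=1) + (np.abs(w[:, r:]) ** 3).sum(axis=1)
    big, small = (np.count_nonzero(L < e) for e in eps)
    return float(np.log(small / big) / np.log(eps[1] / eps[0]))

for r in range(4):
    predicted = r / 2 + (3 - r) / 3
    print(f"rank {r}: estimated {volume_exponent(r):.3f}, predicted {predicted:.3f}")
```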

Is this intuitive now? I'm not sure what more to do to make the result intuitive.

Replies from: interstice
comment by interstice · 2023-06-18T18:05:35.067Z

Hmm, what you're describing is still in what I was referring to as "the broad basin regime". Sorry if I was unclear -- I was thinking of any case where there is no self-intersection of the minimum loss manifold as being a "broad basin". I think the main innovation of SLT occurs elsewhere.

Look at the image in the tweet I linked. At the point where the curves intersect, it's not just that the Hessian fails to be of full-rank, it's not even well-defined. The image illustrates how volume clusters around a single point where the singularity is, not merely around the minimal-loss manifold with the greatest dimensionality. That is what is novel about singular learning theory.

Replies from: ege-erdil
comment by Ege Erdil (ege-erdil) · 2023-06-19T10:21:32.324Z

Can you give an example of $L$ which has the mode of singularity you're talking about? I don't think I'm quite following what you're talking about here.

In SLT $L$ is assumed analytic, so I don't understand how the Hessian can fail to be well-defined anywhere. It's possible that the Hessian vanishes at some point, suggesting that the singularity there is even worse than quadratic, e.g. at the origin or something like that. But even in this regime essentially the same logic is going to apply: the worse the singularity, the further away you can move from it without changing the value of $L$ very much, and accordingly the singularity contributes more to the volume of the set $\{w : L(w) < \min L + \varepsilon\}$ as $\varepsilon \to 0$.
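As a concrete, hypothetical case of a Hessian that vanishes at a self-intersection: for $L(x, y) = (xy)^2$ the minimum set is the two coordinate axes, the Hessian $\begin{pmatrix} 2y^2 & 4xy \\ 4xy & 2x^2 \end{pmatrix}$ has rank 1 at every other point of that set, and it is identically zero at the origin where the axes cross. The sketch just evaluates that analytic Hessian.

```python
import numpy as np

def hessian(x, y):
    """Analytic Hessian of the hypothetical loss L(x, y) = (x * y) ** 2."""
    return np.array([[2.0 * y ** 2, 4.0 * x * y],
                     [4.0 * x * y, 2.0 * x ** 2]])

for point in [(1.0, 0.0), (0.0, 2.0), (0.0, 0.0)]:
    print(point, int(np.linalg.matrix_rank(hessian(*point))))
```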

Replies from: interstice
comment by interstice · 2023-06-19T21:15:36.404Z

In SLT L is assumed analytic, so I don't understand how the Hessian can fail to be well-defined

Yeah, sorry, that was probably needlessly confusing. I was just referencing the image in Jesse's tweet for ease of illustration (you're right that it's not analytic; I'm not sure what's going on there). The Hessian could also just be 0 at a self-intersection point like in the example you gave. That's the sort of case I had in mind. I was confused by your earlier comment because it sounded like you were just describing a valley of positive dimension, but as you say there could be isolated points like that also.

I still maintain that this behavior --- of volume clustering near singularities when considering a narrow band about the loss minimum --- is the main distinguishing feature of SLT and so could use a mention in the OP.

comment by martinkunev · 2023-10-01T23:20:38.624Z

To make this easier to parse on the first read, I would add that

N is the number of parameters of the NN and we assume each parameter is binary (instead of the usual float).

comment by Joar Skalse (Logical_Lunatic) · 2023-06-20T15:42:34.501Z

What is the exact derivation that gives you claim (1)?

comment by Ege Erdil (ege-erdil) · 2023-06-20T17:12:18.659Z

Check the Wikipedia section for the stationary distribution of the overdamped Langevin equation.

I should probably clarify that it's difficult to derive this claim rigorously for SGD in particular, because it's difficult to show the absence of heteroskedasticity in SGD residuals. Still, I believe the effect is probably negligible in practice, and in principle this is something that can be tested by experiment.
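As a quick sanity check of the stationary-distribution claim, here is a sketch that simulates the overdamped Langevin equation $d\theta = -L'(\theta)\,dt + \sqrt{2T}\,dW$ with an Euler-Maruyama discretisation and compares a sample average with the same average under the Gibbs density $\propto e^{-L/T}$. The one-dimensional double-well loss, the temperature $T$, the step size and the function names are arbitrary choices for illustration. The noise is homoskedastic by construction, which is exactly the assumption flagged above as hard to justify for real SGD.

```python
import math
import random


def loss(theta):
    # Double well with minima at theta = -1 and theta = +1
    return (theta * theta - 1.0) ** 2


def grad_loss(theta):
    return 4.0 * theta * (theta * theta - 1.0)


def langevin_samples(temperature, dt=2e-3, n_steps=500_000, burn_in=10_000, seed=0):
    """Euler-Maruyama steps of d(theta) = -L'(theta) dt + sqrt(2 T) dW (constant noise)."""
    rng = random.Random(seed)
    theta = 0.0
    noise_scale = math.sqrt(2.0 * temperature * dt)
    samples = []
    for step in range(n_steps):
        theta += -grad_loss(theta) * dt + noise_scale * rng.gauss(0.0, 1.0)
        if step >= burn_in:
            samples.append(theta)
    return samples


def gibbs_expectation(fn, temperature, lo=-3.0, hi=3.0, n=6001):
    """E[fn(theta)] under the density proportional to exp(-L(theta) / T), via the trapezoid rule."""
    h = (hi - lo) / (n - 1)
    num = den = 0.0
    for i in range(n):
        theta = lo + i * h
        w = math.exp(-loss(theta) / temperature)
        if i in (0, n - 1):
            w *= 0.5  # trapezoid end-point weights
        num += w * fn(theta)
        den += w
    return num / den


T = 0.5
samples = langevin_samples(T)
empirical = sum(s * s for s in samples) / len(samples)
predicted = gibbs_expectation(lambda th: th * th, T)
print(f"simulated E[theta^2] = {empirical:.3f}, Gibbs prediction = {predicted:.3f}")
```

The statistic $E[\theta^2]$ is used because it is symmetric between the two wells, so it converges even if the chain hops between them only occasionally.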

comment by interstice · 2023-06-21T04:33:13.881Z

In fact, this might not hold in practice; see this paper.

comment by Ege Erdil (ege-erdil) · 2023-06-21T09:43:25.011Z

That's useful to know, thanks. Is anything else known about the properties of the noise covariance beyond "it's not constant"?

Some comments on the paper itself: if the problem is that SGD with homoskedastic Gaussian noise fails to converge to a stationary distribution, why don't they define SGD over a torus instead? Seems like it would fix the problem they are talking about, and if it doesn't change the behavior it means their explanation of what's going on is incorrect.
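The reasoning behind the torus suggestion fits in a couple of lines, assuming homoskedastic noise at temperature $T$ and a continuous loss $L$: on a compact parameter space such as the torus $\mathbb{T}^n$ the Gibbs density always normalises, whereas on $\mathbb{R}^n$ it can fail to, for example when $L$ is bounded above (as happens along directions where the loss is constant).

```latex
Z_{\mathbb{T}^n} = \int_{\mathbb{T}^n} e^{-L(\theta)/T}\, d\theta
  \le \operatorname{vol}(\mathbb{T}^n)\, e^{-\min L / T} < \infty,
\qquad
Z_{\mathbb{R}^n} = \int_{\mathbb{R}^n} e^{-L(\theta)/T}\, d\theta
  \ge \int_{\mathbb{R}^n} e^{-\sup L / T}\, d\theta = \infty
  \quad \text{if } \sup L < \infty .
```

So on the torus an equilibrium $\propto e^{-L/T}$ exists whatever the shape of $L$, which is what would make it a useful control for whether non-convergence is really the issue.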

If the only problem is that, with homoskedastic Gaussian noise, convergence to a stationary distribution is slow (when a stationary distribution does exist), I could believe that. Similar algorithms such as Metropolis-Hastings also have pretty abysmal convergence rates in practice when applied to any kind of complicated problem. It's possible that SGD with batch noise has better regularization properties and therefore converges faster, but I don't think that changes the basic qualitative picture I present in the post.

comment by interstice · 2023-06-21T16:44:53.392Z

Some comments on the paper itself: if the problem is that SGD with homoskedastic Gaussian noise fails to converge to a stationary distribution, why don’t they define SGD over a torus instead?

Good question. I imagine that would work, but it would converge more slowly. I think a more important issue is that the homoskedastic and heteroskedastic noise cases would have different equilibrium distributions even if both existed (they don't say this, but it seems intuitively obvious, since in the heteroskedastic case there would be a pressure away from points with higher noise). I guess on the torus this would correspond to a large number of bad minima dominating the equilibrium in the homoskedastic case.
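The intuition about a pressure away from high-noise points can be made precise in one dimension. This is a sketch under simplifying assumptions: Itô dynamics $d\theta = -L'(\theta)\,dt + \sqrt{2D(\theta)}\,dW$ with a scalar, position-dependent noise level $D(\theta) > 0$, whereas real SGD noise is a full covariance matrix. Writing the Fokker-Planck equation and setting the stationary probability flux $J$ to zero gives:

```latex
\frac{\partial p}{\partial t}
  = \frac{\partial}{\partial \theta}\Big[ L'(\theta)\, p \Big]
  + \frac{\partial^2}{\partial \theta^2}\Big[ D(\theta)\, p \Big],
\qquad
J = -L'(\theta)\, p - \frac{d}{d\theta}\big[ D(\theta)\, p \big] = 0
\;\Longrightarrow\;
p(\theta) \propto \frac{1}{D(\theta)}
  \exp\!\left( -\int^{\theta} \frac{L'(s)}{D(s)}\, ds \right).
```

With constant $D = T$ this reduces to $p \propto e^{-L/T}$. Otherwise the $1/D$ prefactor lowers the density wherever the noise is large, and large $D$ also makes the effective potential shallower there, so the two equilibria genuinely differ.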

Generally speaking the SGD noise seems to provide a regularizing effect towards 'flatter' solutions. The beginning of this paper has a good overview.

comment by Ege Erdil (ege-erdil) · 2023-06-21T18:17:09.135Z

As an aside, I've tried to work out what the optimal learning rate for a large language model should be based on the theory in the post, and if I'm doing the calculations correctly (which is a pretty big if) it doesn't match actual practice very well, suggesting there is actually something important missing from this picture.

Essentially, the coefficient should be where is the variance of the per-parameter noise in SGD. If you have a learning rate , you scale the objective you're optimizing by a factor and the noise variance by a factor . Likewise, a bigger batch size lowers the noise variance by a linear factor. So the equilibrium distribution ends up proportional to

where is the per-token average loss and should be equal to the mean square of the partial derivative of the per-token loss function with respect to one of the neural network parameters. If the network is using some decent batch or layer normalization this should probably be where is the model size.
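One way to fill in this arithmetic, offered as a hedged sketch rather than the exact formulas intended here: model one SGD step as $\theta_{t+1} = \theta_t - \eta \nabla L(\theta_t) + \eta\, \xi_t$, with learning rate $\eta$, batch size $B$ in tokens and per-parameter noise variance $\operatorname{Var}(\xi_t) = \sigma^2 / B$. Then match it to the overdamped Langevin equation $d\theta = -\nabla L\, dt + \sqrt{2T}\, dW$ with time step $dt = \eta$. The symbols $\eta$, $B$, $\sigma^2$ and $T$ are labels chosen for this sketch, and the factor of 2 depends on this convention.

```latex
\underbrace{\eta^2 \frac{\sigma^2}{B}}_{\text{SGD noise per step}}
  = \underbrace{2T\,\eta}_{\text{Langevin noise per step}}
\;\Longrightarrow\;
T = \frac{\eta\, \sigma^2}{2B},
\qquad
p(\theta) \propto \exp\!\left( -\frac{L(\theta)}{T} \right)
  = \exp\!\left( -\frac{2B}{\eta\, \sigma^2}\, L(\theta) \right).
```

Under this convention the coefficient multiplying $L$ falls linearly with the learning rate and grows linearly with the batch size, matching the scalings described in the comment.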

We want what's inside the exponential to just be , because we want the learning to be equivalent to doing a Bayesian update over the whole data. This suggests we should pick

which is a pretty bad prediction. So there's probably something important that's being left out of this model. I'm guessing that a smaller learning rate just means you end up conditioning on minimum loss and that's all you need to do in practice, and larger learning rates cause problems with convergence.
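To see the scale of the mismatch, the condition "exponent equals the full-data exponent" can be evaluated with hypothetical numbers. This is a sketch under stated assumptions: a Langevin temperature $T = \eta\sigma^2/(2B)$ from matching one SGD step to overdamped Langevin dynamics with $dt = \eta$; a full-data exponent of $-D\,L$, where $D$ is the number of training tokens and $L$ the per-token loss; and $\sigma^2 \approx 1/N$ as an illustrative guess for a well-normalised network. The function name and all sizes are made up.

```python
def implied_learning_rate(batch_tokens, dataset_tokens, grad_var):
    """Hypothetical helper: the eta at which exp(-2 B L / (eta sigma^2)) equals exp(-D L).

    Assumes T = eta * sigma^2 / (2 B), where B is the batch size in tokens and
    sigma^2 is the per-parameter gradient-noise variance of the per-token loss.
    Because L is a per-token average, the full-data exponent is D * L, where
    D is the number of training tokens.
    """
    return 2.0 * batch_tokens / (dataset_tokens * grad_var)


# Hypothetical sizes, chosen only for illustration.
n_params = 1e9              # model size N
batch_tokens = 1e6          # batch size B, in tokens
dataset_tokens = 1e11       # training set size D, in tokens
grad_var = 1.0 / n_params   # illustrative assumption: sigma^2 ~ 1/N

eta = implied_learning_rate(batch_tokens, dataset_tokens, grad_var)
print(f"implied learning rate ~ {eta:.1e}")
```

With these sizes the formula gives a learning rate of about $2 \times 10^4$, many orders of magnitude above the values normally used to train language models, which is consistent with the conclusion that the model is leaving something important out.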