[Linkpost] Towards a Theoretical Understanding of the 'Reversal Curse' via Training Dynamics

post by Bogdan Ionut Cirstea (bogdan-ionut-cirstea) · 2024-05-11T22:59:12.368Z

This is a link post for https://arxiv.org/abs/2405.04669

The excerpts below seem to me like a slight update towards arguments about the weakness of single forward passes in transformers, and towards agendas like externalized reasoning and translucent thoughts. They also suggest that out-of-context reasoning (OOCR) might remain hard for transformer-based LMs (they currently do very poorly on OOCR evals, and scaling doesn't seem to help much either):

We theoretically analyze the reversal curse, where training or test sequences have the form “𝐴→𝐵” or “𝐵←𝐴”, via the training dynamics of (stochastic) gradient descent under two auto-regressive models: a bilinear model (Section 3) and one-layer transformers under certain assumptions similar to Tian et al. (2023a) (Section 4). The analysis of the training dynamics of both models reveals a core reason why the reversal curse happens: the weights of the auto-regressive models are asymmetric, i.e., an increase in the weights from token 𝐴 to token 𝐵 during training does not necessarily cause an increase in the weights from 𝐵 to 𝐴. [...] Although the (effective) weights from 𝐴 to 𝐵 and from 𝐵 to 𝐴 might be related to some extent, since they are both computed using the same set of embeddings, their correlation is weak and they thus show asymmetry, as verified both theoretically (Sections 3 and 4) and empirically (Section 5).
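To make the asymmetry concrete, here is a minimal Python sketch. It is not the paper's setup: it assumes a hypothetical four-token vocabulary and one-hot (exactly orthogonal) embeddings, so a bilinear next-token logit of the form x_A^T W y_B (an assumed form; see Section 3 of the paper for the actual model) reduces to the single entry W[A][B]. Under that simplification, SGD on the sequence "A → B" only updates the row of W for input A, so the reverse entry W[B][A] never moves. With shared learned embeddings the paper finds the two directions only weakly coupled; this toy version exaggerates that to zero coupling. The names VOCAB, train and prob are illustrative only.

```python
import math

VOCAB = ["A", "B", "C", "D"]  # hypothetical toy vocabulary
IDX = {tok: i for i, tok in enumerate(VOCAB)}


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    s = sum(exps)
    return [e / s for e in exps]


def train(pairs, steps=200, lr=0.5):
    """SGD on next-token cross-entropy for a bilinear model with one-hot embeddings.

    Assumption: logit(next=j | input=i) = x_i^T W y_j = W[i][j] for one-hot x_i, y_j.
    """
    n = len(VOCAB)
    W = [[0.0] * n for _ in range(n)]
    for _ in range(steps):
        for src, tgt in pairs:
            i, j = IDX[src], IDX[tgt]
            p = softmax(W[i])
            # dLoss/dW[i][k] = p_k - 1[k == j]; only row i (the input token) gets gradient.
            for k in range(n):
                W[i][k] -= lr * (p[k] - (1.0 if k == j else 0.0))
    return W


def prob(W, src, tgt):
    return softmax(W[IDX[src]])[IDX[tgt]]


W = train([("A", "B")])  # the model only ever sees B following A
print(f"P(B | A) = {prob(W, 'A', 'B'):.3f}")  # close to 1 after training
print(f"P(A | B) = {prob(W, 'B', 'A'):.3f}")  # stays at uniform 1/4 = 0.250
```

Because row B is never touched, the "reverse" query B → A is answered at chance, which is the reversal curse in its most extreme form.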


We use the above framework to show the necessity of chain-of-thought (COT) (Wei et al., 2022b) via the training dynamics of one-layer transformers (Section 4.2): a model trained on “A implies B” and “B implies C” separately struggles to directly conclude “A implies C” without COT, which was also empirically observed by Allen-Zhu and Li (2023). This provides a new perspective, different from previous work by Feng et al. (2024) that focuses on the expressivity of transformers. Slightly different from the reason for the reversal curse, in the COT analysis the model weights show intransitivity, i.e., increasing the weights from token 𝐴 to 𝐵 and from 𝐵 to 𝐶 does not necessarily increase the weights from 𝐴 to 𝐶. We emphasize again that the weights refer to effective weights.
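The same kind of toy model can illustrate intransitivity and why chaining steps helps. This is a sketch under loud assumptions, not the paper's one-layer transformer analysis: one-hot embeddings, a hypothetical four-token vocabulary, hypothetical helper names, and greedy two-step decoding as a crude stand-in for chain-of-thought. Training on "A → B" and "B → C" as separate sequences raises W[A][B] and W[B][C] but actually pushes W[A][C] down, so the direct prediction of C from A ends up below uniform, while the chain A → B → C still reaches C.

```python
import math

TOKENS = ["A", "B", "C", "D"]  # hypothetical toy vocabulary
IDX = {t: i for i, t in enumerate(TOKENS)}


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def train(pairs, steps=200, lr=0.5):
    # Assumed one-hot bilinear model: logit(next=j | input=i) = W[i][j].
    n = len(TOKENS)
    W = [[0.0] * n for _ in range(n)]
    for _ in range(steps):
        for src, tgt in pairs:
            i, j = IDX[src], IDX[tgt]
            p = softmax(W[i])
            for k in range(n):
                # Cross-entropy gradient: raise the target logit, lower all others.
                W[i][k] -= lr * (p[k] - (1.0 if k == j else 0.0))
    return W


def next_probs(W, tok):
    return dict(zip(TOKENS, softmax(W[IDX[tok]])))


def greedy_next(W, tok):
    probs = next_probs(W, tok)
    return max(probs, key=probs.get)


# "A implies B" and "B implies C" only ever appear in separate sequences.
W = train([("A", "B"), ("B", "C")])

direct = next_probs(W, "A")["C"]  # direct A -> C: pushed below uniform 1/4
step1 = greedy_next(W, "A")       # chain step 1: A -> B
step2 = greedy_next(W, step1)     # chain step 2: B -> C
print(f"P(C | A) = {direct:.3f}; chained: A -> {step1} -> {step2}")
```

The direct entry W[A][C] only ever receives the "lower non-targets" part of the gradient, which is one simple way the weights can fail to compose even though each individual link was learned.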


We also empirically validate our theoretical results through the training dynamics of multi-layer transformers (Section 5).


The asymmetry and intransitivity of the weights of auto-regressive models indicate that auto-regressive LLMs might not automatically deduce indirect conclusions from separate pieces of knowledge learned during training: for a model to predict token 𝐵 when the input token is 𝐴, it needs to see 𝐵 following 𝐴 in the same sequence in the training set, due to the next-token prediction objective and the model architecture. This also highlights the importance of ICL, data augmentation, or planning for LLMs with the currently popular causal transformer-based architectures to solve complex reasoning tasks.

