A starting point for making sense of task structure (in machine learning)

post by Kaarel (kh), RP (Complex Bubble Tea), jake_mendel · 2024-02-24T01:51:49.227Z · LW · GW · 2 comments

Contents

  Introduction
    Why understanding task structure could be useful
      Interpretability
      Learning the abstractions
      Unlearning capabilities
      Quantifying generalization
      Learning how the world works
    Some Subtleties
      What is a task?
      Task decomposition in the dataset vs a particular system’s task decomposition
      Absolute vs relative metrics vs clusterings
  Methods for gauging task structure in ML
    Inspecting activations
    Inspecting learning
    Inspecting weights
  Analogues in humans
  A toy model for testing task decomposition techniques
  Acknowledgements

ML models can perform a range of tasks and subtasks, some of which are more closely related to one another than others. In this post, we set out two very initial starting points. First, we motivate reverse engineering models’ task decompositions. We think this can be helpful for interpretability and for understanding generalization. Second, we provide a (potentially non-exhaustive, initial) list of techniques that could be used to quantify the ‘distance’ between two tasks or inputs. We hope these distances might help us identify the task decomposition of a particular model. We close by briefly considering analogues in humans and by suggesting a toy model.

Epistemic status: We didn’t spend much time writing this post. Please let us know in the comments if you have other ideas for measuring task distance or if we are replicating work.

Introduction

It might be useful to think about computation in neural networks (and in LMs specifically) on sufficiently complex tasks as a combination of (a) simple algorithms or circuits for specific tasks[1] and (b) a classifier, or family of classifiers, that determines which simple circuits are to be run on a given input. (Think: an algorithm that captures (some of) how GPT-2 identifies indirect objects in certain cases combined with a method of identifying that indirect object identification is a thing that should be done.[2]) More concretely, some pairs of tasks might overlap in that they are computed together much more than are other pairs, and we might want to build a taxonomic tree of tasks performed by the model in which tree distance between tasks is a measure of how much computation they share.[3] For example, a particularly simple (but unlikely) task structure could be a tree of depth 1: the neural network has one algorithm for classifying tasks which is run on all inputs, and then a single simple task is identified and the corresponding algorithm is run.

Why understanding task structure could be useful

Interpretability

We might hope to interpret a model by (1) identifying the task decomposition, and (2) reverse-engineering both what circuit is implemented in the model for each task individually, and how the model computes this task decomposition. Crucially, (1) is valuable for understanding the internals and behavior of neural networks even without (2), and techniques for making progress at it could look quite different from standard interpretability methods. It could directly make the rest of mechanistic interpretability easier by giving us access to some ground truth about the model’s computation—we might insist that the reverse engineering of the computation respects the task decomposition, or we might be able to use task distance metrics to identify tasks that we want to understand mechanistically. Further, by arranging tasks into a hierarchy, we might be able to choose different levels of resolution on which to attempt to understand the behavior of a model for different applications.

Learning the abstractions

Task decomposition can give direct access to the abstractions learned by the model. Ambitiously, it may even turn out that task decomposition is ‘all you need’—that the hard part of language modeling is learning which atomic concepts to keep track of and how they are related to each other. In this case, it might be possible to achieve lots of the benefits of full reverse engineering, in the sense of understanding how to implement a similar algorithm to GPT-4, without needing good methods for identifying the particular way circuits are implemented in any particular language model. Realistically, a good method for measuring task similarity won’t be sufficient for this, but it could be a helpful step.

Unlearning capabilities

We’d like to be able to train models which have certain capabilities but not others. For example, we might want to train a model that can make advancements in vaccine design, but is incapable of designing a bioweapon, perhaps via unlearning bioweapons capabilities. Clearly if two tasks are similar enough, it is not possible to destroy performance at one without affecting the other. Access to the task hierarchy would allow us to understand which capability combinations are feasible.[4] In addition to helping technical researchers, this would be useful for helping policymakers understand the tradeoffs that must be made for good AI regulation. It may also be helpful for doing capability evaluations in models (perhaps after we attempt to remove a capability with unlearning) when we are worried they are sandbagging the eval for deceptive reasons — we can study the model’s performance on similar tasks and become suspicious if the model’s capabilities seem to not respect the task decomposition.

Quantifying generalization

Having a way to quantify the distance between tasks could lead to a way to measure the ability of a model to generalize[5] by providing a standard unit of ‘generalization distance’ that transfers across tasks and types of intelligent systems (e.g., humans and neural networks).[6] Other than being of object-level interest, this is helpful for evaluating and predicting capabilities. Indeed, the ability to generalize (which is intuitively related to the ability to learn quickly) is often cited as a key limitation of present ML models compared to humans. We think it'd be interesting to compare generalization distance in models and humans, e.g., for forecasting when model generalization performance will beat human generalization performance (i.e., maybe, when we'll have AGI). It may also be possible to use distance metrics to track model generalization in more fine-grained ways, e.g. by comparing the input clusterings of different Pythia checkpoints to see when certain inputs first come to be seen as similar by the model, or comparing the subtask clusters of GPT-3 to those of GPT-4, potentially seeing certain clusters merge or split.
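As one concrete sketch of the checkpoint-comparison idea: given input clusterings from two checkpoints, one could score their agreement with a pair-counting index such as the (unadjusted) Rand index. The clusterings below are hypothetical stand-ins, not real model outputs.

```python
import numpy as np

def rand_index(labels_a, labels_b):
    """Fraction of unordered input pairs on which two clusterings agree
    (both place the pair in the same cluster, or both place it apart)."""
    a = np.asarray(labels_a)
    b = np.asarray(labels_b)
    n = len(a)
    same_a = a[:, None] == a[None, :]
    same_b = b[:, None] == b[None, :]
    iu = np.triu_indices(n, k=1)          # each unordered pair once
    return (same_a[iu] == same_b[iu]).mean()

# Hypothetical input clusterings from an early and a late checkpoint;
# in the late one, two of the early clusters have merged.
ckpt_early = [0, 0, 1, 1, 2, 2]
ckpt_late = [0, 0, 0, 0, 1, 1]
print(rand_index(ckpt_early, ckpt_late))
```

Tracking such a score over training could surface the moments when the model's implicit clustering of inputs reorganizes. (The adjusted Rand index, which corrects for chance agreement, would likely be preferable in practice.)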

Learning how the world works

More generally, science is about identifying the structure and patterns in the world; the task taxonomy learned by powerful language models may be very convergent and could be a useful map for understanding the territory of the world we are in. What’s more, such a decomposition would itself be of scientifico-philosophical interest — it would tell us something about thinking.

Some Subtleties

What is a task?

For defining distance metrics between tasks, it is useful to have an operationalization of a ‘task’. In this post, when we speak about task similarity, a task is specified by providing a dataset (of inputs, or of input-output pairs).[7] For example, the Indirect Object Identification task is specified by providing a dataset of pieces of text and completions. There are some concerns to keep in mind here:

Some distance metrics can be applied to measure the task distance between individual data points rather than data sets, which could allow us to create a weighted graph between data points[8]. Clustering on this graph (or perhaps fuzzy clustering which puts nodes into multiple clusters, a bit analogous to sparse autoencoding, or hierarchical clustering which arranges clusters in a hierarchy) may allow us to identify tasks. Unfortunately, some distance metrics only work on datasets of several inputs.[9]
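As an illustration of the clustering step, here is a minimal sketch using agglomerative (hierarchical) clustering on a pairwise distance matrix. The matrix below is a hand-made hypothetical; in practice it would be produced by one of the metrics in the next section.

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.spatial.distance import squareform

# Hypothetical pairwise distance matrix over 6 data points:
# points 0-2 and points 3-5 form two tight groups.
D = np.full((6, 6), 5.0)
D[:3, :3] = 1.0
D[3:, 3:] = 1.0
np.fill_diagonal(D, 0.0)

# Agglomerative clustering on the condensed distance matrix; the linkage
# matrix Z is itself a hierarchy, i.e. a candidate taxonomic task tree.
Z = linkage(squareform(D), method="average")
labels = fcluster(Z, t=2, criterion="maxclust")
print(labels)  # points 0-2 share one label, points 3-5 the other
```

Cutting the same linkage tree at different heights gives coarser or finer candidate task decompositions, matching the hierarchy-of-resolutions idea above.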

Task decomposition in the dataset vs a particular system’s task decomposition

We will sometimes talk about the task decomposition (of, e.g., natural language) without referring to a particular reference system that is attempting to do the tasks[10]; and sometimes about how some particular model or another (or a human) (implicitly) decomposes natural language into tasks. Here are some ways in which these are related: (1) each can provide a helpful guess for the other; (2) alternatively, one could argue that the former really only makes sense as a case of the latter with the observer left implicit (though we think there's more to the former than that); (3) uniformity across observers of the latter (is worth investigating and) could help establish that the former is a sensible thing to consider. But we won't track this distinction.

Absolute vs relative metrics vs clusterings

Most of the metrics provided below are not intended to output individually meaningful/interesting numbers; indeed, most are not actually metrics in the precise mathematical sense. However, the numbers can become meaningful when compared to other outputs. For example, a similarity score for a single pair of inputs is hardly meaningful on its own, but it can begin to be meaningful in a context where we also have the scores of other pairs. And even if these similarities are somehow wrongly normalized — even if the ordering of these similarities is sometimes not that of the 'true similarities' — clusterings could still be meaningful. More generally, we won't be mathematically careful (that said, we will try not to get anything 'wrong'). For example, we will not discuss which clustering algorithm is most appropriate in a particular context. To be clear: we consider it obviously valuable to be mathematically careful — it's just outside the scope for now.

Methods for gauging task structure in ML

In this section, we specify a number of ways one can try to measure task similarity.[11]

Inspecting activations

The activation-based metrics below are trying to get at task similarity by measuring whether the representations computed on two tasks (or two inputs) are similar or otherwise related.
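One standard example of such a representation-comparison measure is linear CKA (centered kernel alignment, Kornblith et al.), sketched here on synthetic stand-in activations; nothing below is specific to any particular model.

```python
import numpy as np

def linear_cka(X, Y):
    """Linear CKA between two activation matrices (n_inputs x d).
    Invariant to orthogonal transforms and isotropic scaling."""
    X = X - X.mean(axis=0)
    Y = Y - Y.mean(axis=0)
    num = np.linalg.norm(X.T @ Y, "fro") ** 2
    den = np.linalg.norm(X.T @ X, "fro") * np.linalg.norm(Y.T @ Y, "fro")
    return num / den

rng = np.random.default_rng(0)
A = rng.normal(size=(50, 16))           # stand-in activations on one task
Q, _ = np.linalg.qr(rng.normal(size=(16, 16)))  # a random rotation
B = rng.normal(size=(50, 16))           # unrelated stand-in activations

print(linear_cka(A, A @ Q))  # 1.0 up to float error: same geometry
print(linear_cka(A, B))      # substantially lower (typically)
```

Computing such a score between the activations a model produces on two datasets (at one layer, or summed over layers[12]) gives one candidate task-similarity measure.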

Inspecting learning

Here, we discuss methods that gauge task structure via examining a model's learning. The first three similarity metrics below are supposed to track whether two tasks benefit from [the same things being learned] / [the same existing internal structures being reinforced]. The last metric below tries to get at whether two behaviors were learned from similar sets of examples. One way these might differ from the metrics above is that some of these metrics could already begin to be meaningful before the model has interesting fully-formed internal structures.

Inspecting weights

Here, we discuss methods that gauge task structure via inspecting/changing a model's weights. Note that the first three metrics in this section would have fit equally well under the above subsection on learning.
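As a minimal sketch of a weight-based probe (again our own toy construction): score two tasks by the overlap of the weight sets that matter most for each, here proxied by per-task gradient magnitudes on hand-made vectors. In a real model, importance could instead come from attribution or pruning scores.

```python
import numpy as np

def important_set(grad, k):
    """Indices of the k weights with largest per-task |gradient|."""
    return set(np.argsort(-np.abs(grad))[:k].tolist())

def jaccard(a, b):
    return len(a & b) / len(a | b)

# Hand-made per-task "gradients" over 6 weights (hypothetical numbers).
g_a = np.array([0.9, -0.8, 0.05, 0.01, 0.0, 0.02])
g_b = np.array([0.85, -0.7, 0.02, 0.0, 0.04, 0.01])  # similar task
g_c = np.array([0.01, 0.0, 0.6, -0.9, 0.03, 0.0])    # different task

k = 2
print(jaccard(important_set(g_a, k), important_set(g_b, k)))  # 1.0
print(jaccard(important_set(g_a, k), important_set(g_c, k)))  # 0.0
```

High overlap would suggest the unlearning tension discussed earlier: ablating the weights important for one task would necessarily damage the other.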

Analogues in humans

A toy model for testing task decomposition techniques

We briefly propose a family of toy data sets with custom chosen ‘ground truth’ task decompositions, in the sense that for each data set, there is a particular task decomposition that [a model which ends up getting low loss when trained on the data set] would plausibly learn.

For the toy model, we create an artificial set of tasks with relationships that we choose. We build on the toy multitask sparse parity (MSP) task from Michaud et al. In the MSP task, each input bitstring is split into control bits and task bits. The number of control bits is equal to the number of subtasks; the control bits are always set to 0s except for having a 1 in one token position, identifying which subtask is to be performed. The task bits can be 0 or 1 freely, and each subtask is to calculate the output of a particular boolean function on the task bits, with the particular choice of boolean function/subtask specified by which control bit is present (let’s say that the assignment of boolean functions to control bits is arbitrary). The suitability of the MSP task comes from us having access to the ground truth task decomposition in this case: it is a lookup table (or depth 1 tree) of disjoint subtasks. We can straightforwardly modify the set of subtasks[15] to give them interesting relational structure in a number of ways.
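The base MSP data-generating process described above can be sketched as follows; the particular sparse subsets here are an arbitrary example of ours, not from Michaud et al.

```python
import numpy as np

def msp_batch(n, n_tasks, n_task_bits, subsets, rng):
    """Multitask sparse parity batch.
    Each row: n_tasks one-hot control bits, then n_task_bits free bits.
    Label: parity of the task bits indexed by the active subtask's subset."""
    task = rng.integers(n_tasks, size=n)          # uniform over subtasks[15]
    control = np.eye(n_tasks, dtype=int)[task]    # one-hot control bits
    bits = rng.integers(2, size=(n, n_task_bits))
    labels = np.array(
        [bits[i, subsets[t]].sum() % 2 for i, t in enumerate(task)]
    )
    return np.hstack([control, bits]), labels

rng = np.random.default_rng(0)
# Arbitrary sparse subsets of the task bits, one boolean function per subtask.
subsets = [[0, 1], [2, 3, 4], [1, 4]]
X, y = msp_batch(8, n_tasks=3, n_task_bits=5, subsets=subsets, rng=rng)
print(X)
print(y)
```

Relational structure between subtasks could then be introduced by, e.g., choosing overlapping subsets, so that two subtasks share part of their computation.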

A point of studying these toy models is to give us some feedback on how good different distance metrics are and what it is precisely that each one measures[17] (although language models are of course likely to be different from the toy model in important ways).

Acknowledgements

Thanks to Andis Draguns and Lawrence Atkins for helpful discussions, including contributing a couple methods; to Clem von Stengel, Lucius Bushnaq, Nina Rimsky, Robert Avery, Dmitry Vaintrob, Caspar Oesterheld, and Hoagy Cunningham for discussions, comments, and edits; and potentially to people we've forgotten (feel free to message us).


  1. What we are calling a ‘task’ is similar to what Arora & Goyal call a ‘skill’. ‘Task’ takes the perspective of asking a model to do something; ‘skill’ takes the perspective of the model. This notion of ‘task’/’skill’ is also very similar to what Michaud et al. call a ‘quantum’. These authors also make certain assumptions about skills/quanta that we think of as providing interesting concrete cases. So, the picture we present here differs from the pictures proposed by these authors in that we have a looser notion of task decomposition having to do with 'how much computation is shared' / 'how similar the computation is' which could be made precise in various ways (including using ideas from these authors). It also differs in that, at least in the tree picture of tasks we present, the tasks have internal structure that can be shared with other tasks — they could be composed of shared circuits/skills/quanta. But the picture here is much inspired by Michaud et al. ↩︎

  2. A possible objection here: wouldn't the ideal indirect object identification circuit be more like a full description of what a model does to do IOI; i.e., isn't the task dictionary/classifier part of this decomposition unnecessary? So, couldn't it be more like: there's a bunch of circuits for various tasks that are always running, except perhaps not having some nodes activate because of looking for something that does not exist in the input, or something, and then the final answer is some aggregation of the outputs of the circuits that do activate? Well, maybe it could be like that (or, at least, we won't get into an extended analysis of whether it could be like that here), but as far as we can tell, Wang et al. is not significant evidence of that — the method used would plausibly not detect computations with outputs shared by everything on the dataset on which its mean ablations are computed. In particular, its method would not find a hypothetical task classifier (which could well be more complex than the circuit found) which always decides that the task is indeed IOI on the reference data set (this is also true for resampling ablations from the same data set). In any case, even if the correct hypothesis to entertain were that models are more like ensembles of unconditional circuits that always 'try to run', the present bullet point would still make sense, mutatis mutandis. ↩︎

  3. Besides conceptual reasons to think a taxonomy is appropriate, work like Saxe et al. provides motivation for a tree-like structure. ↩︎

  4. A slightly more nuanced model of unlearning is: almost certainly the path of things that must be learned for vaccine design and bioweapon design is very similar, with a fork at the end. Unlearning bioweapon design is not a binary thing, but a spectrum from just superficially not outputting bioweapon advice without a jailbreak to reinitialising all the parameters in the network. One way of quantifying the degree of unlearning is how many steps of fine-tuning are required to reintroduce the capability. If we want to unlearn bioweapon design without unlearning vaccine design then we can walk the model back up the bioweapons path until we hit the fork: the further the fork is from the end of the path (corresponding to higher task distance), the more deeply we are able to unlearn bioweapons without affecting vaccine capabilities. Equivalently, the more deeply we unlearn bioweapons, the more ‘collateral damage’ we necessarily pick up in terms of unlearning other things by accident, in order of increasing task distance from bioweapons. One problem with this picture is that it might not be a good way of describing the capabilities of a generally intelligent system which has learned how to learn about the world efficiently (e.g., a system capable of making research advances) because it may be impossible to unlearn bioweapon design in this system such that the system could not rediscover bioweapons on the fly without unlearning general reasoning capabilities. ↩︎

  5. It seems reasonable to operationalize generalization as applying understanding of a task (say, writing English poetry) to other subtasks (e.g. writing French poetry) of a certain natural "metatask" (writing poetry). ↩︎

  6. Generalization distance clearly depends on things like the amount of allowed fine-tuning, the number of few-shot examples, etc., and some of these things can be hard to compare to a human, but one might hope that we can fix an allowed amount of fine-tuning/prompting and still end up with something that makes sense. ↩︎

  7. Roughly equivalently, we can alternatively think of a task as being specified by a distribution (of inputs or input-output pairs). ↩︎

  8. Given a way to measure the similarity between two inputs, one can measure the similarity between two data sets as the expectation of the similarity between a random input from the first and a random input from the second. ↩︎

  9. Still, we think there are likely reasonable ways to go from certain task-wise metrics to task decompositions — for instance, minimizing the sum of distances of each proposed task to itself minus the sum of distances between different proposed tasks — but we haven’t thought about this carefully. ↩︎

  10. This is different from topic modeling, though we don’t rule out that approaches from topic modeling could be brought to bear here. When the domain is natural language, we are not looking for a partition of contexts/documents into topics here, nor quite a partition of words into topics (which topic modeling methods provide), at least in the sense of usual topic modeling methods. What we have in mind is more like a classification of which pattern(s)/rule(s) was(/were) used to generate each token, or more correctly, which pattern(s)/rule(s)/skill(s) might most naturally be used to predict each token, and (while admittedly not being very familiar with this literature) we don’t expect that standard topic modeling techniques would get at this with the level of sophistication we’d like. ↩︎

  11. We note that each method below could well turn out not to measure any reasonable kind of similarity. We also note that the methods would likely end up measuring distinct flavors of task similarity, but we do not provide a detailed analysis of these flavors. ↩︎

  12. We might similarly want to upweight contributions from middle layers in many scores below. ↩︎

  13. Here and later, it would also make sense to look at the change in a more fine-grained manner, i.e., to not just track this single parameter. ↩︎

  14. We add the small amount of fine-tuning at the end because we want the model to be able to make some connections between what it has learned and the new domain. ↩︎

  15. We also will probably want each task to be equally frequent in the dataset, unlike the original MSP task, which was designed with a different purpose in mind. ↩︎

  16. That is, if the subsets for that task are and , then the task is to compute . ↩︎

  17. One can make progress here by running an experiment or by just thinking through what each task decomposition method would capture when applied on a plausible NN-implementation of an algorithm solving the task. ↩︎

2 comments


comment by NicholasKees (nick_kees) · 2024-02-24T21:42:55.650Z · LW(p) · GW(p)

More generally, science is about identifying the structure and patterns in the world; the task taxonomy learned by powerful language models may be very convergent and could be a useful map for understanding the territory of the world we are in. What’s more, such a decomposition would itself be of scientifico-philosophical interest — it would tell us something about thinking.

I would love to see someone expand on the ways we could use interpretability to learn about the world, or the structure of tasks (or perhaps examples of how we've already done this?). Aside from being interesting scientifically, maybe this could also help us build economically valuable systems which are more explicit and predictable?

comment by Adam Shai (adam-shai) · 2024-02-29T20:07:11.952Z · LW(p) · GW(p)

I find this focus on task structure and task decomposition to be incredibly important when thinking about what neural networks are doing, what they could be doing in the future, and how they are doing it. The manner in which a system understands/represents/instantiates task structures and puts them in relation to one another is, as far as I can tell, just a more concrete way of asking "what is it that this neural network knows? what cognitive abilities does it have? what abstractions is it making? under what out of distribution inputs will it succeed/fail, etc."

This comment isn't saying anything that wasn't in the post, just wanted to express happiness and solidarity with this framing!

I do wonder if the tree-structure of which-task and then task algorithm is what we should expect, in general. I have nothing super concrete to say here, my feeling is just that the manners in which a neural network can represent structures and put them in relation to each other may be instantiated differently than a tree (with that specific ordering). The onus is probably on me here though - I should come up with a set of tasks in certain relations that aren't most naturally described with tree structures.

Another question that comes to mind is, is there a hard distinction between categorizing which sub-task one is in and the algorithm which carries out the computation for a specific subtask? Is it all just tasks all the way down?