Proof Explained for "Robust Agents Learn Causal World Models"

post by Dalcy (Darcy) · 2024-12-22T15:06:16.880Z · LW · GW · 0 comments

Contents

  Recalling the Basic Setup
    Assumptions
  Proof of the Exact Case
    Load-Bearing part of Oracle Use
    Identification via Induction
      Constraint 1: σ will remove D's parents
      Constraint 2: σ fixes the environment variables other than those in the path.
      Constraint 3: Let σ contain a local intervention making Ck a binary variable
  Proof of the Approximate Case
    New Assumptions
    Load-Bearing part of Oracle Use
    Identification via Induction
    Identifying the Approximate Graph Structure

This post was written during Alex Altair's agent foundations fellowship program [LW · GW], funded by LTFF. Thank you Alex Altair [LW · GW], Alfred Harwood [LW · GW], Daniel C [LW · GW] for feedback and comments.

This is a post explaining the proof of the paper Robust Agents Learn Causal World Models in detail. Check the previous post [LW · GW] in the sequence for a higher-level summary and discussion of the paper, including an explanation of the basic setup [LW · GW] (like terminology and assumptions), which this post will assume from now on.

Recalling the Basic Setup

Let's recall the basic setup (again, check the previous post [LW · GW] for more explanation):

  1. [World [LW(p) · GW(p)]] The world is a Causal Bayesian Network over the set of variables corresponding to the environment, the utility node U, and the decision node D. The differences from a normal Causal Bayesian Network are that (1) U is a deterministic function of its parents, and (2) the conditional probability distribution for D is undetermined—it's something that our agent will select.
  2. [Agent as a Policy Oracle [LW · GW]] An agent is a policy oracle: a function that takes in an intervention σ (drawn from the set of all allowed interventions over the environment) and returns a policy.
  3. [Robustness as δ [LW · GW]-optimality under interventions [LW · GW]] We define a "robust" agent as a policy oracle whose regret is bounded by δ under some class of interventions over the environment.
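As a toy illustration of the setup (all names here—`PolicyOracle`, `Intervention`, the example world—are my own sketch, not from the paper), an agent can be rendered as a function from interventions to policies:

```python
from dataclasses import dataclass
from typing import Callable, Dict, Tuple

# Hypothetical minimal types: an intervention forces some environment
# variables to fixed values; a policy maps a decision context to a decision.
Intervention = Dict[str, int]               # variable name -> forced value
Policy = Callable[[Tuple[int, ...]], int]   # context -> decision

@dataclass
class PolicyOracle:
    """An agent, viewed as an oracle: query with an intervention, get a policy."""
    respond: Callable[[Intervention], Policy]

    def __call__(self, sigma: Intervention) -> Policy:
        return self.respond(sigma)

# Example world: utility is maximized by copying the (intervened) bit "C".
oracle = PolicyOracle(respond=lambda sigma: (lambda ctx: sigma.get("C", 0)))
pi = oracle({"C": 1})
assert pi(()) == 1  # under do(C=1), the optimal decision copies C
```

The point of the sketch is only the type signature: everything the identification procedure below does is done by querying such an oracle with different interventions.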

Assumptions

Also recall the following assumptions [LW(p) · GW(p)]:

1) Unmediated Decision Task states that . This is pretty major.

The left diagram doesn't satisfy the Unmediated Decision Task assumption, the right one does.

2) Domain dependence states that there exist distributions over the chance variables  and  (compatible with ) such that .

These together imply:

  1. There does not exist a decision  that is optimal, i.e.  across all .
  2. , i.e. there can't be any intermediate nodes between D and U, and all causal effects from D to U must be direct.

Now, with the basic setup of the paper recalled, let's prove the main theorems.

Proof of the Exact Case

We will first prove Theorem 1:

For almost all worlds  satisfying assumptions 1 and 2, we can identify the directed acyclic graph  and the joint distribution  over all variables upstream of , given that we have access to a -optimal policy oracle.

I will attempt to present the proof in a way that focuses on how one could've discovered the paper's results by themselves, emphasizing intuitions and how trying to formalize them naturally constrains the solutions or assumptions we must use.

Load-Bearing part of Oracle Use

Somehow we're going to have to elicit information out of the policy oracle.

Recall that the oracle is a function that maps an intervention to a policy, which is a conditional probability distribution . It can be shown that if the oracle is optimal, this distribution is (almost always) a deterministic map. The argument goes as follows:

Then, somehow our information-eliciting procedure is going to have to exploit the fact that, given a , as we change the intervention , the optimal decision prescribed changes from  to .

To make this possible, we want to rule out the existence of a decision that is universally optimal across all inputs to the utility node, because then no intervention would yield a change in the output of the oracle.

Recall the first consequence of the two assumptions we had: There does not exist a decision  that is optimal, i.e.  across all .

This implies that there is at least one  where the optimal decision  differs from !

But the worry is that such  will be incompatible with .

So, we will restrict ourselves to only considering σ that masks the inputs of  by only letting  depend on  such that .

Given , let  denote the optimal decision associated with some . By the earlier argument, if we consider any intervention  that does , then it would have a different optimal decision, call it .

And to operationalize "as we change the intervention," we define a mixed local intervention .

When , the policy oracle (under ) would prescribe , and for , it would prescribe . There may of course be other intermediate optimal decisions along the way as you slowly increase  from  to  - say,  and  for the current example.

Note that once you switch your decision from  to  as you increase , you will never encounter  again because of linearity of . Namely:

The diagram below makes it clearer why linearity implies the same decision never gets chosen twice. The oracle can be thought of as attaining the upper envelope, denoted by the dotted line.
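The upper-envelope picture can be checked numerically: each decision's expected utility is linear in the mixing weight, so the set of weights where a given decision attains the envelope is a single interval. A small sketch (the utility numbers are made up for illustration):

```python
import numpy as np

# Each decision d has expected utility  q * u1[d] + (1 - q) * u0[d],
# which is linear in the mixing weight q.  The optimal oracle picks the
# argmax, i.e. the upper envelope of these lines.
u0 = np.array([0.9, 0.5, 0.1])   # E[U | d] at q = 0 (illustrative numbers)
u1 = np.array([0.1, 0.6, 1.0])   # E[U | d] at q = 1

qs = np.linspace(0.0, 1.0, 10001)
values = np.outer(qs, u1) + np.outer(1 - qs, u0)   # shape (len(qs), 3)
choices = values.argmax(axis=1)

# Linearity implies each decision occupies one contiguous run of q values:
for d in set(choices):
    idx = np.flatnonzero(choices == d)
    assert np.all(np.diff(idx) == 1), "decision re-chosen after being dropped"
```

With these numbers the envelope passes through all three decisions in order (d0, then d1, then d2), and no decision ever reappears once dropped.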

Let  represent the value of  at which the optimal decision switches from some  (may or may not be ) to .

This switching value is interesting. It's a behavioral property of our oracle, meaning we can estimate it by repeatedly sampling the oracle (across random samples of ). But it can also probably be expressed in closed form in terms of some parameters of the environment (by just expanding out the definition of expected utility). So it is a bridge that lets us infer properties of the environment from the oracle.

Let's derive a closed-form expression of . Let .

Detailed Proof

 is a value that satisfies

Expanding out the left-hand side:

Expanding out the difference of both sides and setting it to zero:


Now we solve for :

Again,  can be estimated from sampling the oracle. We know the denominator because we assume the knowledge of .

Therefore, the oracle lets us calculate , the difference in expected utility of some two decisions given some context.

Restating our chain of inquiry as a lemma:

(Lemma 4) Given  that masks the inputs such that   such that we can approximate , where  and .
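The behavioral estimation of the switch point can be simulated end-to-end. In this sketch (the utility table and the two-decision/two-context setup are my own illustration, not the paper's notation), we only query the oracle's chosen decision, bisect for the switch point q*, and then recover the ratio of utility differences from the tie condition q*Δ + (1−q*)Δ' = 0:

```python
import numpy as np

# Ground-truth expected utilities (illustrative): rows = decisions d1, d2;
# columns = the two contexts c, c' being mixed with weight q.
u = np.array([[0.2, 0.9],    # decision d1
              [0.7, 0.3]])   # decision d2

def oracle_decision(q: float) -> int:
    """Optimal oracle under the mixed intervention: weight q on c, 1-q on c'."""
    mixed = q * u[:, 0] + (1 - q) * u[:, 1]
    return int(mixed.argmax())

# Bisection: find the switch point q* purely behaviorally, by querying the oracle.
lo, hi = 0.0, 1.0
for _ in range(60):
    mid = (lo + hi) / 2
    if oracle_decision(mid) == oracle_decision(0.0):
        lo = mid
    else:
        hi = mid
q_star = (lo + hi) / 2

# At q* the two decisions tie:  q* * Delta + (1 - q*) * Delta' = 0,  where
# Delta  = u[d1,c] - u[d2,c]   and   Delta' = u[d1,c'] - u[d2,c'].
# So the behavioral quantity q* pins down the ratio of utility differences:
delta, delta_p = u[0, 0] - u[1, 0], u[0, 1] - u[1, 1]
ratio_true = delta / delta_p
ratio_est = -(1 - q_star) / q_star
assert abs(ratio_est - ratio_true) < 1e-6
```

This is exactly the "bridge" role of the switch point: a purely behavioral measurement of the oracle yields a quantitative fact about the environment's expected utilities.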

Identification via Induction

By the Unmediated Decision Task assumption, we see that the ancestors of U look like the following. We notice that there are two types of paths to consider in our induction argument.

Let's first consider the first type, , where .

We first define the following variables:

Assume  are known, and  are known.

The claim to prove is that, given these are known, we can identify the conditional probability distribution .

Assume we have some . We want to identify , and we have to somehow use the policy oracle for that.

Recall from Lemma 4 that given an intervention  such that , for all values of , there exist two different decisions  and  such that  can be identified.

The trick is in setting  such that it makes this sum contain  terms for all , for arbitrary choices of  and .

Let's think through how we might discover the right constraints on .

Constraint 1: σ will remove D's parents

Since we want , let's just choose σ such that  (hence ).

Constraint 2: σ fixes the environment variables other than those in the path.

Note the following: since the value of  can be computed, expressing it in terms of  would let us solve for , since by the induction hypothesis all the terms except it are known. We also somehow have to figure out what the set  is.

Expanding out the above sum will give us some clue as to what further constraints we must impose on  in order for the sum to be expressed that simply:


How do we choose  such that

becomes

for arbitrary choices of ?

Note that setting  to a constant will:

Thus such an intervention immediately gets rid of the  terms as we sum across , while still letting us arbitrarily control the values of  and  (among other variables in ).

So constraint 2: σ contains  (such that the values of  are compatible with the values of  set earlier in constraint 1).

Then, we have the following expression:

So far, we haven't intervened in . So,  for  compatible with the value  (if applicable) and , further simplifying the expression:

But this isn't yet solvable. By induction hypothesis we know  for all values of  and , and we know the value of the left-hand side. This equation then involves  unknowns.

A fix, then, is obvious: an intervention that effectively sets  to , which brings us to the third constraint:

Constraint 3: Let σ contain a local intervention making Ck a binary variable

Specifically, let  contain , where 

This effectively makes  a binary variable. Precisely:

and now the equation can be solved.

Let , which can be written , where .

The earlier  lets us simplify  as . Thus . We know  (via the policy oracle), we know values of  (via the induction hypothesis).

But there's an important subtlety here: remember that we don't actually know  yet. The  in the above expression is to be understood as the implicit assignment of values to the (yet unknown to us)  by means of  in .

So, by performing a set of interventions that fixes all but one of the variables of , one can discover which variables  responds to (i.e., for which the values of  change), and hence figure out the set .

Then, by varying the choices of , and , we can identify  completely.

The base case of  is clear, since  where  and  are of the form , which is known, and so is  using the oracle.

To recap, our choice of  is a local intervention such that:

  1. masks all inputs to , i.e. 
  2. fixes the rest of the nodes in  to a constant
  3. does a local intervention  making Ck into a binary variable

and we have shown that this intervention lets us identify  for all  along the path , where .
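The final solving step can be illustrated with a tiny numerical sketch (all numbers and names are my own, under the assumption stated in constraint 3): once Ck is effectively binary and everything else is fixed, the expected utility is affine in the unknown conditional probability, so one division recovers it.

```python
# Once the local intervention makes C_k effectively binary and sigma fixes
# everything else, the expected utility is affine in the unknown
# p = P(C_k = c_k | pa):    Q = p * u_ck + (1 - p) * u_other.
# Q is elicited from the oracle (Lemma 4); the two utility terms are known
# by the induction hypothesis.  All numbers below are illustrative.

p_true = 0.37
u_ck, u_other = 0.8, 0.2                       # known by induction hypothesis
Q = p_true * u_ck + (1 - p_true) * u_other     # what the oracle lets us compute

p_recovered = (Q - u_other) / (u_ck - u_other)
assert abs(p_recovered - p_true) < 1e-12
```

Varying the fixed values of the parents and the choice of ck, and repeating this division, identifies the full conditional probability table, which is the content of the induction step.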

Similar arguments can be used to prove the same for paths of the second type, , where .

Proof of the Approximate Case

Now, we will extend the proof to the approximate case (Theorem 2):

For almost all worlds  satisfying assumptions 1 and 2 and some other new assumptions (explained below), we can identify the directed acyclic graph  and the joint distribution  over some subset of variables upstream of , and the quality of estimation for each of the conditional distributions scales linearly with δ.

New Assumptions

Unless I'm mistaken here and these can actually be derived from the earlier two assumptions (Unmediated Decision Task, Domain Dependence), here are the three new conditions that the paper implicitly assumes:

3) δ-optimal policies are (still) almost always deterministic

The earlier proof of determinism doesn't go through in the approximate case, but the paper implicitly assumes the policy oracle still (almost always) returns an output deterministically.

4) Uniform δ regret

We say  is δ-optimal if .

We say  is uniformly δ-optimal if  for all , where we define .

Note that uniform δ-optimality is a stronger condition than δ-optimality, in the sense that the former implies the latter.

5) Shape of the δ-optimal decision boundary

The left diagram is the ground truth for how the expected utility (under context ) of various decisions change as you increase  from  to . Then, the right diagram shows the decision boundary for the -optimal oracle, whose decision boundaries must exactly follow the intersection points of the left diagram's lines.

The paper then assumes that the δ-optimal oracle's decision boundaries must simply be a slightly shifted version of the 0-optimal oracle's decision boundaries, like the right diagram. A priori, there's no reason for the boundaries to look like this; e.g., they could look very complicated, like the left diagram. But the paper implicitly assumes this.

Let's now proceed to the proof. The subsections parallel that of the optimal oracle case.

Load-Bearing part of Oracle Use

Our goal is to derive , which we call .

Recall from earlier that in the optimal oracle case:

where we define  as follows:

Notice that  is the unique solution to .

Also note that . We know the value of . In the optimal oracle case, recall that  can be estimated by Monte Carlo sampling of the oracle.

But the problem with the δ-optimal oracle is that this only yields a biased estimate, which we call .

Using  but naively substituting  with the biased estimate, we get a biased estimate for  that we call .

Our aim, then, is to bound the quality of the estimate  with the bound only involving non-estimate terms, like , and .

Expanding out, . That  is an estimate term, which we want to remove by bounding it via a term involving non-estimate terms. How?

First, we have yet to exploit any behavioral properties of the oracle that are related to . What are those? By definition, the oracle chooses  for  and  for , for very small . Then, the uniform δ regret condition says:

Take  to , and assuming continuity, we can subtract the two to get

.

In other words, , or .

Because , we can rewrite the inequality as . Expanding out and rearranging the inequalities, we find that .

Substituting this into the expansion of Q − Q̃, we obtain the following simple bound: .

Finally,  in the denominator can be eliminated as follows:

where the last line is via Taylor expanding  around  (arbitrarily truncated at fourth order), valid for .

Or more simply, . The error term is linear with respect to  for small values of .
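The truncated-expansion step can be sanity-checked numerically. This sketch (my own check, not from the paper) verifies the geometric-series identity behind the Taylor expansion and that the substitution error is linear for small δ:

```python
import numpy as np

# Check of the truncated geometric series used to eliminate 1/(1 - q~):
#   1/(1-x) = 1 + x + x^2 + x^3 + x^4 + remainder,  remainder = x^5/(1-x),
# valid for |x| < 1.
xs = np.linspace(0.0, 0.5, 1001)
truncated = 1 + xs + xs**2 + xs**3 + xs**4
remainder = np.abs(1.0 / (1.0 - xs) - truncated)
assert np.all(remainder <= xs**5 / (1.0 - xs) + 1e-12)

# For a small delta, replacing 1/(1 - delta) by 1 costs at most ~2*delta,
# i.e. the error of the substitution is linear in delta:
delta = 1e-3
assert abs(1.0 / (1.0 - delta) - 1.0) <= 2 * delta
```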

Identification via Induction

The argument is basically the same as that of the optimal case, except needing to incorporate error terms.

The exact case's induction hypothesis was that  and  are known for . Then, we showed that using a specific choice of , we can derive the relation , in which all the terms on the right-hand side are known.

Then, for the approximate case's induction hypothesis, instead assume that  are known for , up to . We will show that this implies the knowledge of  up to . Let's denote the approximation we have , so .

Let . Because it is a sum of products of , it differs from  by  overall.

Hence . Because  and  by the earlier section, we see that .

Using the big-O fact that , we thus prove .

More simply,  as  goes to . That proves the induction step.

The base case of  is once again clear, since  where  and  are of the form , which is known, and so is  using the oracle. This shows that it can be computed, and the earlier paragraphs show that it is accurate up to .

Identifying the Approximate Graph Structure

The above showed that we can identify  up to , or more precisely,  up to , for all values of  and . In the optimal oracle case, this was sufficient for perfectly identifying , by holding all but one of the variables of  fixed and observing which changes in those variables cause a change in the values of .

What is the issue with the approximate case?

First of all, we'll have to use  instead of .

Say we want to test whether  is a parent of . So we have σ fix everything in  and vary the value of  across the elements of its domain. Let's denote the value of  where  is set to  as .

So the process is: Set , vary  and see if there is a change, set a new , repeat.

The problem is that, because  is only accurate up to , we can't tell if the change is due to actual differences in the underlying  or due to the error in approximation.

The solution is to use one of the earlier explicit bounds on  in terms of quantities that the algorithm has access to, i.e. . We can then use this bound to derive an explicit upper and lower bound for the values of , which we'll call  and .

And if there exists  such that there exist  and  whose intervals  and  don't overlap, then we can guarantee that the change is due to actual differences in the underlying .

This procedure lets us identify a subset of , hence a subgraph of .
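The overlap test can be sketched directly (the numbers and the error bound `eps` are illustrative placeholders for the explicit bound derived above):

```python
# Sketch of the parenthood test: we only have estimates p_hat with a known
# error bound eps (of order delta).  Declare "this variable is a parent"
# only when two estimates' uncertainty intervals fail to overlap, so the
# difference cannot be an artifact of estimation error.

def intervals_disjoint(a: float, b: float, eps: float) -> bool:
    """True iff [a-eps, a+eps] and [b-eps, b+eps] don't overlap."""
    return abs(a - b) > 2 * eps

eps = 0.05                        # explicit bound from the earlier section
p_hat_x0, p_hat_x1 = 0.30, 0.55   # estimates under two settings of the variable

# Gap 0.25 > 2*eps = 0.10, so the change must reflect the true distribution:
assert intervals_disjoint(p_hat_x0, p_hat_x1, eps)

# A gap within the error budget proves nothing either way:
assert not intervals_disjoint(0.30, 0.38, eps)
```

Note the asymmetry this creates: disjoint intervals certify a parent, but overlapping intervals are inconclusive, which is why the procedure only identifies a subset of the parents (a subgraph).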

Detailed Proof

Suppose that, given a context, two decisions had the same expected utility.

Recall the definition: literally taking expectation over all the values that the ancestor of  could take. Let  and .

where  goes to  if  is incompatible with .

Note  and , because  only has an effect on its descendants, which  isn't part of, and neither is .

Therefore, 

And we're curious about the case when the difference in expected utility is zero:

Suppose that  without loss of generality. Then the terms can be written as such:

Long story short, this is a polynomial constraint on the parameters  of the network, and the solution set of a polynomial equation has measure zero (intuitively, for a polynomial to equal exactly zero, the values must be precisely aligned, which is rare).

  • Specifically: given a Bayesian Network G over  nodes, its CPDs are as follows: . And since we're assuming all variables are discrete, these are a finite number of values that parameterize , each of which takes a value between 0 and 1.
  • Suppose we find that they should satisfy some polynomial constraint: .

Then we can reasonably claim that this is extremely unlikely to happen, because in the space of all possible parameterizations , solutions to a polynomial constraint happen in a measure zero set.
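The measure-zero intuition is easy to check empirically (a sanity check of my own, with an arbitrary polynomial standing in for the constraint):

```python
import random

# Sanity check of the measure-zero intuition: a nontrivial polynomial
# constraint on randomly drawn parameters essentially never holds exactly.
random.seed(0)

def poly(a: float, b: float, c: float) -> float:
    # An arbitrary nontrivial polynomial in the "parameters".
    return a * b - c**2 + 0.1 * a

hits = sum(
    1
    for _ in range(100_000)
    if poly(random.random(), random.random(), random.random()) == 0.0
)
assert hits == 0  # exact solutions form a measure-zero set
```

So "for almost all worlds" in the theorem statements is doing real work: the identification argument only fails on this measure-zero set of degenerate parameterizations.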
