Learning the Q value of multi-step actions

Assume that we have a policy $\pi(a_{0:k-1}|s_0)$ that is able to output a variable-length sequence of $k$ actions given any state $s_0$. We define a Q function $Q(s_0, a_{0:k-1})$ which computes the state-action value of any action sequence $a_0, a_1, \ldots, a_{k-1}$ given $s_0$. The intuitive meaning of this quantity is: what expected future return will we get if, starting from $s_0$, we take actions $a_0, a_1, \ldots, a_{k-1}$ regardless of the future states in the following $k$ steps, after which we follow our policy $\pi$? In effect, our $k$-step action sequence is open-loop because it doesn't depend on the environment feedback in the next $k$ steps.

We can use the Bellman equation to learn $Q(s_0, a_{0:k-1})$ from the data in a memory (e.g., a replay buffer or a search tree). The bootstrapping target for $Q(s_0, a_{0:k-1})$ is defined as

$$\mathcal{T}^\pi Q(s_0, a_{0:k-1}) = \sum_{i=0}^{k-1} \gamma^i r_i + \gamma^k\, \mathbb{E}_{s_k \sim P(\cdot|s_0, a_{0:k-1})}\, \mathbb{E}_{a_{k:k+k-1} \sim \pi(\cdot|s_k)} \big[ Q(s_k, a_{k:k+k-1}) \big],$$

where $P(s_k|s_0, a_{0:k-1})$ is the environment transition probability. In practice, each time we sample a $k$-step transition $s_0, a_{0:k-1}, s_k$ from the memory, we sample a new sequence of actions $a_{k:k+k-1} \sim \pi(\cdot|s_k)$ with the current policy, and use $\sum_{i=0}^{k-1} \gamma^i r_i + \gamma^k Q(s_k, a_{k:k+k-1})$ as a point estimate of the above target to fit $Q(s_0, a_{0:k-1})$. So the key question is: does the empirical frequency of $s_k$ (from the memory) correctly match the theoretical distribution of $s_k$ after we execute $a_{0:k-1}$ in an open-loop manner at $s_0$?
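To make the point estimate concrete, here is a minimal Python sketch of how it could be computed for one replayed $k$-step transition. The `policy` and `target_q` callables are hypothetical interfaces, not part of any particular library:

```python
def k_step_td_target(rewards, next_state, policy, target_q, gamma=0.99):
    """Point estimate of the k-step bootstrapping target for Q(s_0, a_{0:k-1}).

    rewards:    the k rewards [r_0, ..., r_{k-1}] stored with the replayed
                transition s_0, a_{0:k-1}, s_k.
    next_state: s_k, the state reached after the k replayed actions.
    policy:     callable s -> action sequence; samples a fresh sequence with
                the *current* policy (hypothetical interface).
    target_q:   callable (s, action_seq) -> Q value (hypothetical interface).
    """
    k = len(rewards)
    # Discounted sum of the k observed rewards: sum_i gamma^i * r_i.
    n_step_return = sum(gamma ** i * r for i, r in enumerate(rewards))
    # Bootstrap with a new action sequence sampled at s_k.
    next_actions = policy(next_state)
    return n_step_return + gamma ** k * target_q(next_state, next_actions)
```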

The answer is: not always. It depends on how the transition $s_0, a_{0:k-1}, s_k$ is obtained from the memory (e.g., replay buffer or search tree). In a typical case, the action sequence $a_{0:k-1}$ might come from stitching together two action sequence outputs during rollout, as the example below shows.

Here, the replayed transition $s_1, a_{2:6}, s_7$ actually spans two rollout policy samplings. This might cause issues for learning the state-action value in some MDPs.

Deterministic environment

For a deterministic environment, it's completely fine to learn $Q(s_0, a_{0:k-1})$ by arbitrarily stitching together temporally contiguous actions $a_{0:k-1}$ from a memory, even though these actions were not generated in one shot (i.e., generated all at once at $s_0$). The reason is that given $s_0$ and any subsequence $a_{0:m}$ ($m \le k-1$), $s_{m+1}$ is uniquely determined in a deterministic environment. So the final state $s_k$ correctly reflects the transition destination state with probability 1.

Stochastic environment

However, in a stochastic environment, arbitrarily stitching together contiguous actions will cause issues. To see this, consider the example above, where $a_1$ can be generated at two different states: $s_1$ (with prob 0.2) and $s_1'$ (with prob 1). These two states are the stochastic successors of $(s_0, a_0)$, each reached with a chance of 0.5. We assume that the episode always starts with $s_0$ and ends after either $s_1$ or $s_1'$ (suppose the end state is denoted by $s_2$).

By our definition, $Q(s_0, a_{0:1})$ (the state-action value of taking $a_0$ and then $a_1$ starting from $s_0$) should be $0.5 \times 2 - 0.5 \times 1 = 0.5$. Note that when we compute this value, we should not look at the action distribution at $s_1$ or $s_1'$. Even though $a_1$ has only a prob of 0.2 at $s_1$, we always take it after $a_0$.

Now suppose all data in the memory are generated step by step. When using TD learning to estimate $Q(s_0, a_0, a_1)$, we sample either $(s_0, a_0, s_1, a_1, s_2)$ or $(s_0, a_0, s_1', a_1, s_2)$ from the memory. The problem is, the second transition appears 5x as frequently as the first one (0.5 vs 0.1)! Thus the learned $Q(s_0, a_0, a_1)$ will be $\frac{5}{6} \times (-1) + \frac{1}{6} \times 2 = -0.5$. This is completely wrong given the correct value of 0.5.
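The bias can be reproduced with a small Monte Carlo sketch of this toy MDP. The rewards of $+2$ (taking $a_1$ at $s_1$) and $-1$ (taking $a_1$ at $s_1'$) are the ones implied by the arithmetic above; all names below are illustrative:

```python
import random

# Toy MDP from the example above:
#   s0 --a0--> s1 (prob 0.5) or s1' (prob 0.5); the episode then ends.
#   Taking a1 at s1 gives reward +2; taking a1 at s1' gives reward -1.
#   The rollout policy picks a1 with prob 0.2 at s1 and prob 1.0 at s1'.

def rollout():
    """One step-by-step episode; returns (mid_state, action, reward)."""
    mid = "s1" if random.random() < 0.5 else "s1'"
    if mid == "s1":
        action = "a1" if random.random() < 0.2 else "other"
    else:
        action = "a1"
    reward = {("s1", "a1"): 2.0, ("s1'", "a1"): -1.0}.get((mid, action), 0.0)
    return mid, action, reward

random.seed(0)
episodes = [rollout() for _ in range(100_000)]

# True open-loop value of (a0, a1): take a1 no matter which state we land in.
true_q = 0.5 * 2.0 + 0.5 * (-1.0)            # = 0.5

# "Stitched" TD estimate: average the returns of replayed transitions where the
# behavior policy happened to pick a1, weighted by how often they occur.
replayed = [r for (mid, a, r) in episodes if a == "a1"]
td_q = sum(replayed) / len(replayed)          # ~ -0.5

print(f"true Q(s0, a0, a1) = {true_q:.2f}, TD-learned Q = {td_q:.2f}")
```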

So what happened? In this case, we can only learn the correlation between $(s_0, a_0, a_1)$ and $Q(s_0, a_0, a_1)$, but not their causal relationship (i.e., what expected return will we get if we take $a_0, a_1$ starting from $s_0$?).

How to avoid the causal confusion

For a policy $\pi(a_{0:k-1}|s_0)$ that is able to output a variable-length sequence of $k$ actions given any state $s_0$, in general we need to honor the action sequence boundaries when sampling from the memory for TD learning. That is to say, in the first figure, we can only sample subsequences within $(a_0, \ldots, a_3)$ or $(a_4, \ldots, a_9)$, but not across the boundary $a_3|a_4$. A sketch of such boundary-respecting sampling is given below.
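Here is one possible way to implement this, under an assumed storage layout where each rollout episode also records the indices at which the policy was (re)queried; the function and field names are hypothetical:

```python
import random

def sample_within_boundaries(states, actions, rewards, starts, k):
    """Sample a k-step transition s_t, a_{t:t+k-1}, s_{t+k} that stays inside
    a single policy-sampling segment.

    states:  [s_0, ..., s_T]          actions: [a_0, ..., a_{T-1}]
    rewards: [r_0, ..., r_{T-1}]      starts:  indices where a new policy call
                                               began during rollout, e.g. [0, 4]
                                               for the segments a_{0:3}, a_{4:9}.
    """
    ends = starts[1:] + [len(actions)]  # exclusive end index of each segment
    # Valid start indices: the whole window [t, t+k) fits in one segment.
    candidates = [
        t
        for seg_start, seg_end in zip(starts, ends)
        for t in range(seg_start, seg_end - k + 1)
    ]
    if not candidates:
        return None  # no segment is long enough for a k-step sample
    t = random.choice(candidates)
    return states[t], actions[t : t + k], rewards[t : t + k], states[t + k]
```

For example, with `starts = [0, 4]` (the two policy calls producing $a_{0:3}$ and $a_{4:9}$) and $k = 5$, only windows inside the second segment are eligible, so a boundary-crossing sample such as $a_{2:6}$ is never replayed.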
