Paying Attention to Attention
Fresh off the last experiment, I’ve decided to revisit an older one and dig a little deeper.
A while ago I was working on an experiment trying to figure out what attention layers are actually doing in transformers.
The traditional introduction to transformers goes something like this:
- You turn your text into tokens
- You create embeddings for each of those tokens
- Some additions and projections later, that token encoding goes into an attention layer, at a particular spot in the sequence based on the order of the text. Let’s assume our token is at position 3.
- The attention layer compares that token encoding to all the others in the sequence (queries and keys, yes yes), turns those similarities into weights with a softmax, and uses those weights to take a weighted sum of the encodings in the sequence (projected by a value matrix). (This won’t make sense if you don’t already know how transformers work, for which I apologize.)
- The output of that attention layer at position 3 is the original token, but with information from closely related tokens mashed into it. So it’s a wider, more conceptually complex representation of the original token at position 3.
- Do this 12 times with some standard fully-connected neural networks in between, and you end up with a new sequence where each encoding is a representation of that original token, but with “meaning” infused based on the overall context of every other token in the sequence.
- Use those fancy high-class output encodings to predict the next token, or classify your sequence, or whatever.
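To make the attention step a bit more concrete, here’s a minimal sketch of a single attention head in plain Python. It assumes one head, no causal mask, no layer norm, no output projection, and random toy matrices standing in for learned weights. The function names (`matmul`, `softmax`, `self_attention`, `random_matrix`) are hypothetical, made up for illustration, and don’t come from any library or from the original experiment’s code.

```python
import math
import random


def matmul(x, w):
    """Multiply row vectors x (n x d_in) by matrix w (d_in x d_out)."""
    return [
        [sum(row[k] * w[k][j] for k in range(len(w))) for j in range(len(w[0]))]
        for row in x
    ]


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention (no mask, no output projection)."""
    q, k, v = matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)
    d_k = len(k[0])
    outputs, all_weights = [], []
    for q_i in q:
        # Compare this position's query to every key, scaled by sqrt(d_k).
        scores = [sum(a * b for a, b in zip(q_i, k_j)) / math.sqrt(d_k) for k_j in k]
        # Softmax turns the similarities into weights that sum to 1.
        weights = softmax(scores)
        # The output is the weighted sum of the value vectors.
        out = [sum(w * v_j[c] for w, v_j in zip(weights, v)) for c in range(len(v[0]))]
        outputs.append(out)
        all_weights.append(weights)
    return outputs, all_weights


def random_matrix(rows, cols, rng):
    """Hypothetical stand-in for learned weights or embeddings."""
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


if __name__ == "__main__":
    rng = random.Random(0)
    seq_len, d_model = 5, 8
    # Toy encodings for a 5-token sequence (token embedding + position already added).
    x = random_matrix(seq_len, d_model, rng)
    w_q, w_k, w_v = (random_matrix(d_model, d_model, rng) for _ in range(3))
    out, weights = self_attention(x, w_q, w_k, w_v)

    pos = 3  # "position 3" from the list above, as a 0-based index here
    print("attention weights at position 3:", [round(w, 3) for w in weights[pos]])
    # Standard transformer blocks add the input back (a residual connection),
    # so the output at pos is the original encoding plus a weighted mix of context.
    residual_out = [a + b for a, b in zip(x[pos], out[pos])]
    print("output at position 3:", [round(val, 3) for val in residual_out])
```

The residual add at the end is why the output is so often described as “the original token, plus context”: the attention result is literally added on top of the input encoding before the next layer sees it.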
But that “output of the attention layer is another form of the original token” idea has always been interesting to me. What does an encoding like that look like? Can you do something else with it? Does it relate back to the original token in an interesting way?
The original experiment, though, showed pretty clearly:
- Even at the first layer, the encodings are all mixed up with one another. Some token encodings almost seem to switch places.
- After layer 1, the encodings don’t have any sort of relationship to the original tokens. They’re in a totally different conceptual space, and likely any information in them is distributed so far and wide among the tokens that it has no more structure than a typical fully-connected network.
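One way to check whether an output encoding still “relates back” to its original token is to ask which input encoding it is most similar to. The sketch below does that with cosine similarity. It’s an illustrative approach, not necessarily what the original experiment did, and the helper names (`cosine`, `nearest_inputs`, `swapped_positions`) and the toy vectors are hypothetical.

```python
import math


def cosine(a, b):
    """Cosine similarity between two vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na and nb else 0.0


def nearest_inputs(inputs, outputs):
    """For each output position, the index of the most cosine-similar input encoding."""
    return [max(range(len(inputs)), key=lambda j: cosine(out, inputs[j])) for out in outputs]


def swapped_positions(inputs, outputs):
    """Positions whose output is closer to some other position's input than to its own."""
    return [i for i, j in enumerate(nearest_inputs(inputs, outputs)) if i != j]


if __name__ == "__main__":
    inputs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
    # Hypothetical toy outputs: positions 0 and 1 have "traded places", 2 stayed put.
    outputs = [[0.1, 0.9, 0.0], [0.8, 0.2, 0.1], [0.0, 0.1, 1.0]]
    print(nearest_inputs(inputs, outputs))     # [1, 0, 2]
    print(swapped_positions(inputs, outputs))  # [0, 1]
```

Comparing each output against the full vocabulary embedding table, rather than just the inputs in the sequence, would instead tell you which token (if any) an output most resembles.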
This is quite different from how attention outputs are usually explained. They’re typically credited with some sort of semi-mystical ability to “create new regions of meaning by combining relevant concepts,” or something like that.
I have some hot takes about the situation, but I’ll hold off on them for now.
The original version of this experiment, while useful, is a bit hard to understand, both in the UI and, to be fair, in the code. So I’m going to add some things to make it easier to see what the point of all this is, and redo the code to make it easier to see what the point of all the code is.
Stay tuned.