Review 9: The Sensory Neuron as a Transformer: Permutation-Invariant Neural Networks for Reinforcement Learning

The Sensory Neuron as a Transformer: Permutation-Invariant Neural Networks for Reinforcement Learning by Dr. Ha and Yujin Tang

  • Citation: Tang, Yujin, and David Ha. “The sensory neuron as a transformer: Permutation-invariant neural networks for reinforcement learning.” Advances in Neural Information Processing Systems 34 (2021).

  • Today, I am reading a paper in NeurIPS by Dr. David Ha and Dr. Yujin Tang.
  • These reviews are not really reviews but more like casual reading notes. I am writing what pops into my head as I read.
  • The first thing that stands out to me just from reading the introduction is that shuffling around sensory input is generally fatal for most network problems. I doubt most image recognition modules would work at all if one just shuffled around 4x4 pixel blocks in an image.
    • I can think of a few problems that begin to crop up here.
    • Firstly, if we think about swapping the pixels in a picture of a cat one by one, how many swaps do we need to make until it is just not a picture of a cat?
      • This issue is rectified in the paper because they are trying to complete a task or play a game, which means the objective is still definitively the same regardless of input shuffling.
    • Secondly, how does one parameterize image shuffling? Say we decide to shuffle around $64 \times 64$ squares in an image. If we know the image is $256 \times 256$ then it makes sense to define $16$ networks. But we’ve cheated and used some a priori knowledge about the image permutations.
      • $\frac{256 \times 256}{64 \times 64} = 16$
    • But now I’m getting ahead of myself. I’ll keep reading; perhaps these issues aren’t relevant. Onward!
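    • To make the block-shuffling thought experiment concrete, here is a minimal NumPy sketch. It assumes a single-channel $256 \times 256$ image and $64 \times 64$ blocks (the numbers from my example above); `shuffle_blocks` is a hypothetical helper of mine, not something from the paper.

```python
import numpy as np

def shuffle_blocks(image, block, rng):
    """Cut a 2D image into block x block tiles and reassemble them in random order."""
    h, w = image.shape
    assert h % block == 0 and w % block == 0, "image must tile evenly"
    rows, cols = h // block, w // block
    # (rows, block, cols, block) -> (rows, cols, block, block) -> (rows*cols, block, block)
    tiles = (image.reshape(rows, block, cols, block)
                  .swapaxes(1, 2)
                  .reshape(rows * cols, block, block))
    perm = rng.permutation(rows * cols)  # new order of the tiles
    shuffled_tiles = tiles[perm]
    # Undo the tiling so the result has the original image shape.
    out = (shuffled_tiles.reshape(rows, cols, block, block)
                         .swapaxes(1, 2)
                         .reshape(h, w))
    return out, perm

rng = np.random.default_rng(0)
image = np.arange(256 * 256, dtype=np.float32).reshape(256, 256)  # stand-in image
shuffled, perm = shuffle_blocks(image, 64, rng)
print(len(perm))  # number of 64 x 64 blocks in a 256 x 256 image
```

    • The a priori knowledge I complained about shows up directly here: `block` has to be chosen by hand before any shuffling can happen.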
  • One more thing in the introduction is really surprising/exciting. All these references to cellular automata (CA) and emergent behavior make me really excited to see how these local networks are going to start forming a coherent policy.
  • Self-organizing network agents, while being a mouthful, do seem like a natural progression. I’m having trouble explaining why. I wanted to write something about NNs working better on local computation, but they do seem to do global computation pretty well too. I’m not sure this is the right distinction to make (whether NNs are better as local agents or singular global agents). Also there’s so much more research for the latter.
  • Meta-learning policies
    • fast weights: adapt weights to input.
    • associative weights: allow RNN weights to be attracted to recent hidden states.
    • hypernetworks: one network generates the weights for another.
    • Hebbian learning: See more here
  • Attention
    • similar to previous adaptive weight mechanisms in that it modifies weights based on inputs.
    • Authors cite several examples of attention learning inductive biases.
  • Method
    • A legit permutation invariant (PI) agent doesn’t need to be trained on permutations to recognize them.
    • PI formulation: need a function $f(x): \mathcal{R}^N \longrightarrow \mathcal{R}^M$ s.t. $f(x[s]) = f(x)$, where $s$ is any permutation of the indices $\{1, \dots, N\}$
    • Self-attention as permutation equivariant (PE).
      • Simplest self-attention: $\sigma(QK^T)V$ where $Q, K \in \mathcal{R}^{n \times d_q}$ and $V \in \mathcal{R}^{n \times d_v}$. $Q$, $V$ and $K$ are query, value and key matrices respectively.
      • Normally these (Q,K,V) are functions of the input.
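      • A quick sanity check I wrote for myself: when $Q$, $K$ and $V$ are all functions of the input, permuting the input rows just permutes the output rows (equivariance, not invariance). This is only a sketch; I am assuming $\sigma$ is a row-wise softmax and that $Q, K, V$ are plain linear projections with made-up weight names `w_q`, `w_k`, `w_v`.

```python
import numpy as np

def softmax(z):
    """Row-wise softmax, shifted for numerical stability."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def self_attention(x, w_q, w_k, w_v):
    """Simplest self-attention sigma(Q K^T) V with Q, K, V computed from the input x."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    return softmax(q @ k.T) @ v

rng = np.random.default_rng(0)
n, d_in, d_q, d_v = 6, 3, 4, 5  # arbitrary sizes chosen for illustration
x = rng.normal(size=(n, d_in))
w_q = rng.normal(size=(d_in, d_q))
w_k = rng.normal(size=(d_in, d_q))
w_v = rng.normal(size=(d_in, d_v))

perm = np.roll(np.arange(n), 1)  # a fixed, non-identity permutation of the rows
y = self_attention(x, w_q, w_k, w_v)
y_perm = self_attention(x[perm], w_q, w_k, w_v)

# Equivariance: the output of the shuffled input is the shuffled output.
equivariant = np.allclose(y_perm, y[perm])
print(equivariant)
```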
    • Contribution
      • The authors propose to add an extra layer (called AttentionNeuron) in front of the agent’s policy network $\pi$ that accepts observation $o_t$ and previous action $a_{t-1}$ as inputs.
      • The $i$th neuron only has access to $o_t[i]$.
      • Each sensory neuron computes messages $f_k(o_t[i], a_{t-1})$ and $f_v(o_t[i])$.
      • These messages are aggregated in global latent code $m_t$.
    • So I think it’s interesting that based off the figure, they don’t model this AttentionNeuron layer as a distinct part of each neuron but rather a layer that receives input from each neuron and aggregates input into $m_t$.
      • Computationally, it doesn’t matter if the transformation is affixed to the last layer of each neuron or if it stands alone. But I thought it was illustrative of how the Query, Key and Value matrices are formed.
    • Key matrix depends on $f_k(o_t, a_{t-1})$ and value matrix depends on $f_v(o_t)$
    • They also have projection matrices $W_k, W_v, W_q$ (which is what gives us PE).
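    • Here is how I picture the layer, as a rough NumPy sketch rather than the authors’ implementation. Assumptions of mine: $f_k$ is a single shared tanh layer, $f_v$ just passes $o_t[i]$ through, $\sigma$ is a softmax, the query is a fixed input-independent matrix (standing in for $QW_q$), and there is no value projection. The class name `AttentionNeuronSketch` and all sizes are hypothetical.

```python
import numpy as np

def softmax(z):
    """Row-wise softmax, shifted for numerical stability."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

class AttentionNeuronSketch:
    """Toy permutation-invariant layer; every sensory neuron shares the same weights."""

    def __init__(self, act_dim, m=4, d_fk=8, d_q=4, seed=0):
        rng = np.random.default_rng(seed)
        self.w_fk = rng.normal(size=(1 + act_dim, d_fk))  # weights of the assumed f_k
        self.w_k = rng.normal(size=(d_fk, d_q))           # key projection W_k
        self.q = rng.normal(size=(m, d_q))                # fixed query, does not depend on input

    def f_k(self, obs, prev_action):
        # Neuron i sees only obs[i] plus the previous action.
        n = obs.shape[0]
        inp = np.concatenate([obs[:, None], np.tile(prev_action, (n, 1))], axis=1)
        return np.tanh(inp @ self.w_fk)                   # (n, d_fk)

    def f_v(self, obs):
        return obs[:, None]                               # (n, 1), pass-through value

    def __call__(self, obs, prev_action):
        k = self.f_k(obs, prev_action) @ self.w_k         # (n, d_q)
        v = self.f_v(obs)                                 # (n, 1)
        return softmax(self.q @ k.T) @ v                  # global latent code m_t, shape (m, 1)

rng = np.random.default_rng(1)
layer = AttentionNeuronSketch(act_dim=1)
obs = rng.normal(size=5)          # e.g. a 5-dimensional observation
prev_action = np.zeros(1)
perm = np.roll(np.arange(5), 2)   # shuffle the sensory inputs

m_t = layer(obs, prev_action)
m_t_shuffled = layer(obs[perm], prev_action)
print(np.allclose(m_t, m_t_shuffled))
```

    • Because the query rows are fixed and the softmax sums over the neurons, shuffling the inputs only reorders the terms of that sum, so $m_t$ comes out the same.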
    • Experiments and their observation spaces:
      • Cartpole $\mathcal{R}^5$
      • Ant $\mathcal{R}^{28}$
      • Car Racing $\mathcal{R}^{96 \times 96 \times 4}$
      • Atari Pong $\mathcal{R}^{84 \times 84 \times 4}$
        • Car Racing and Atari Pong have a final dimension of $4$ because each observation is a stack of 4 grayscale frames.
    • They never use more than 196 activation neurons.
    • Also activation neurons do seem to be a function of observation space.
    • Action space ranges from 1 action, to 8 possible actions.
  • Brief discussion of design choices.
    • I wonder what the motivation is for making $QW_q$ learnable instead of input dependent.
    • They didn’t use a projection matrix for the Value matrix ($V$).
    • They don’t use RNNs for high-dimensional input, just a feed-forward neural network (FFN).
  • Results
    • CartPoleSwingUpHarder
      • They compare against a two layer FFN trained with CMA-ES
      • The benchmark agent is able to balance the pole faster under normal circumstances because the authors’ method requires a few burn-in steps.
      • But what is really notable here is that various perturbations, such as shuffling the input or concatenating a noisy vector to the input, don’t affect performance. No retraining required!
        • Might’ve liked to see a comparison against retrained benchmark.
        • Yes I know it’s comparing apples to oranges and it probably wouldn’t make sense to put in a paper … but I’m curious.
    • PyBullet Ant
      • Uses pre-trained models and behavior cloning for the benchmark.
      • From table 4 the most notable insights I picked up:
        • BC works better with larger subsequent layers.
        • ES shuffled works well on their agent as is. Slightly underperforms non-shuffled input version (FFN teacher). Shuffled FFN gets a really bad score.
    • Atari Pong
      • Shuffled Atari Pong :)
      • I like the idea of maintaining a spatial representation of the 2D grid in $m_t$ because it is more interpretable.
      • I see that this agent can take a subset of the screen and the authors suggest that would be interesting to conduct some kind of ablation experiment on. While I agree, I am also interested in what would happen if you blacked out part of the screen, thus forcing the agent to hallucinate what is happening in those shuffled blocks.
      • The occluded agent rarely won when it couldn’t see certain blocks, but when the blocks are added back it does well.
      • This is the generalization advantage.
      • Figure 6 is a great example of why keeping 2d spatial embeddings is a nice design. The separability of embedding states based on state space is cool to look at.
        • Sometimes when I read these papers, I wonder how they coded some things in. This is one I am particularly interested in.
    • Car racing
      • As I guessed in the beginning, attention-based mechanisms are crippled by shuffling inputs.
        • This is shown in AttentionAgent’s failure to do anything but the original task.
      • I have no idea how NetRand + AttentionAgent are combined, and I don’t understand either method very well.
      • That said, it’s clear that this combination is the only baseline that competes with the presented method on new, unseen and shuffled racing maps.
      • Ah I spoke too soon… it’s explained below the figure.
      • I see they just add AttentionAgent layer at the end but unfortunately, I don’t understand enough about these methods to say anything else.
  • Discussion + Conclusion
    • Further uses include dynamic input-output mappings in robots or cross-wired systems.
      • This isn’t an entirely broad application – it’s a bit specific to systems that receive shuffled inputs.
    • However, I think the most important impact of this paper is that it “provides a lens for understanding a transformer as a self-organizing network” in what I felt was a natural way.
      • When I read the title, I imagined the authors would need to make more substantial leaps in connecting transformers with a sensory neuron, but now that I’ve read the paper it feels ok.
  • One thing I wonder when reading this paper is: what is the best way to aggregate signals from the individual agents? This paper looked at aggregation through neural network layers, but I thought the reshape in the 2D spatial embeddings was an interesting twist.
    • Generally, this paper adds some interesting ingredients to a mixture of local -> global computation using distributed agents.
  • Another thing I really like about this paper is that it isn’t crazy hard to understand. So many ML papers are steeped in formulas and notation, and this paper truly tried to communicate its message clearly.