Techno Blender
Digitally Yours.

Towards Stand-Alone Self-Attention in Vision | by Julian Hatzky | Apr, 2023



Image created by author using craiyon AI

While self-attention is already widely adopted in NLP and significantly contributes to the performance of state-of-the-art models (e.g. [2], [3]), more and more work is being done to achieve similar results in vision.

Even though there are hybrid approaches that, for example, combine CNNs with attention [4] or apply linear transformations to patches of the image [5], a pure attention-based model is harder to train effectively, for several reasons that we will investigate below.

The Stand-Alone Self-Attention in Vision Models [6] paper introduces the idea of such a pure attention-based model for vision. In the following, I will give an overview of the paper’s ideas and related follow-up work. Further, I assume that you are familiar with the workings of the transformer and have a basic knowledge of CNNs. An understanding of PyTorch is also beneficial for the coding parts but these can also be safely skipped.
If, on the other hand, you are only interested in the code, feel free to skip this article and go directly to this annotated Colab notebook.

The Case for Self-Attention in Vision

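CNNs are commonly used to build neural networks for image processing because of their powerful geometric prior of translation equivariance: shifting the input shifts the resulting feature maps by the same amount. A pattern is therefore detected in the same way no matter where it appears in the image, which makes CNNs robust to shifts of the input.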

Self-attention, in contrast, does not have this prior and is instead permutation equivariant: if the input tokens are reordered, the output tokens are reordered in exactly the same way. Although permutation equivariance is more general, it is not as useful for images as translation equivariance, because it ignores where pixels are located relative to each other.
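To make permutation equivariance concrete, here is a minimal sketch (not from the original article) with a single attention head, random hypothetical projection matrices `W_q`, `W_k`, `W_v`, and no positional information. Reordering the input tokens with a random permutation reorders the outputs in exactly the same way.

```python
import torch

torch.manual_seed(0)
t, d = 6, 4  # sequence length and embedding dim (arbitrary)
# hypothetical projection weights, for illustration only
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))

def self_attention(x):
    """Single-head scaled dot-product self-attention for x of shape (t, d), no positional info."""
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    weights = torch.softmax(q @ k.T / d ** 0.5, dim=-1)  # (t, t)
    return weights @ v                                   # (t, d)

x = torch.randn(t, d)
perm = torch.randperm(t)
# permuting the input tokens permutes the output tokens in the same way
print(torch.allclose(self_attention(x[perm]), self_attention(x)[perm], atol=1e-5))  # True
```

Nothing in this computation depends on a token's index, so the only way to tell positions apart is to add positional information.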

Fortunately, we can use positional encodings to constrain the self-attention operation; relative positional encodings in particular make it translation equivariant. Positional encoding (also called positional embedding when it has learnable parameters) allows us to build an architecture that is more flexible than a CNN while still incorporating certain priors.

Implementing Basic Self-Attention in 1D

For one-dimensional inputs like text and speech, the single-head self-attention operation is defined as

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$

Scaled Dot-Product Attention as proposed in [1]

which is essentially a scaled dot product between the queries Q and the keys K, normalized row-wise with a softmax, followed by a second matrix product between the resulting attention weights and the values V.

We can also write this product explicitly as a weighted sum to show how a specific output yᵢ is obtained. Keep this form in mind, because we will later generalize it to 2D images.

$$y_i = \sum_{j} \mathrm{softmax}_j\left(\frac{q_i^\top k_j}{\sqrt{d_k}}\right) v_j$$

Self-Attention for a specific output yᵢ

In PyTorch, this could look as follows.

import torch
import torch.nn as nn
import torch.nn.functional as F

# for some einsum magic
from einops import rearrange, einsum

# use gpu if possible
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

embedding_dim_k = 10

# linear projection of the input x
key = nn.Linear(embedding_dim_k, embedding_dim_k, bias=False)
query = nn.Linear(embedding_dim_k, embedding_dim_k, bias=False)
value = nn.Linear(embedding_dim_k, embedding_dim_k, bias=False)

# creating random vector of shape:
# batch_size (b), sequence length (t), embedding dim (k)
x = torch.randn(1, 12, embedding_dim_k) # b t k

d_b, d_t, d_k = x.size()

# linear projection of the input
q = query(x) # b, t, k
k = key(x) # b, t, k
v = value(x) # b, t, k
assert q.shape == (d_b, d_t, d_k)

# scaled dot-product self-attention
# dot_prod(Q, K)
scaling_factor = 1/torch.sqrt(torch.tensor(d_k))
scaled_dot_product = F.softmax(
einsum(q, k, "b t k, b l k -> b t l") * scaling_factor, dim=-1 )
assert scaled_dot_product.shape == (d_b, d_t, d_t)

# dot-prod(w, v)
self_attention = torch.einsum('b i j , b j d -> b i d', scaled_dot_product, v)

# remember that self-attention is a seq2seq operation:
# the size that goes in, also goes out
assert self_attention.shape == (d_b, d_t, d_k)
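As a sanity check of the per-output formula yᵢ = Σⱼ softmaxⱼ(qᵢᵀkⱼ / √d_k) vⱼ, the sketch below (an addition, using random stand-ins for the projected q, k and v) computes a single yᵢ with an explicit loop and compares it with the matrix form. It also compares against `torch.nn.functional.scaled_dot_product_attention`, which assumes PyTorch 2.0 or later.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
b, t, d_k = 1, 12, 10
# random stand-ins for the projected queries, keys and values
q, k, v = (torch.randn(b, t, d_k) for _ in range(3))

# matrix form: softmax(Q K^T / sqrt(d_k)) V
weights = torch.softmax(q @ k.transpose(-2, -1) / d_k ** 0.5, dim=-1)  # b t t
y = weights @ v                                                        # b t d_k

# per-output form for one index i: y_i = sum_j softmax_j(q_i . k_j / sqrt(d_k)) v_j
i = 3
scores_i = torch.stack([q[0, i] @ k[0, j] / d_k ** 0.5 for j in range(t)])  # (t,)
w_i = torch.softmax(scores_i, dim=0)
y_i = sum(w_i[j] * v[0, j] for j in range(t))

print(torch.allclose(y_i, y[0, i], atol=1e-5))  # True
# PyTorch >= 2.0 ships a built-in implementation of the same operation
print(torch.allclose(F.scaled_dot_product_attention(q, k, v), y, atol=1e-5))  # True
```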

Global vs Local Self-attention

When we talk about global and local self-attention in visual models, we are referring to how much of the image the model attends to. Global self-attention attends to the entire image at once, while local self-attention only attends to a neighborhood around each pixel. Generally, the larger the area the model attends to, the more computation and memory it requires.

Let’s take a closer look at the basic self-attention operation and how it scales with larger images. We use big O notation to express how the cost of the operation grows with the input sequence length n.

The self-attention operation involves three separate calculations:

  1. Calculating QKᵀ has a complexity of O(n² d_k)
  2. The softmax operation, which involves exponentiation, summation, and division, has a quadratic complexity of O(n²)
  3. Multiplying softmax(QKᵀ)V has a complexity of O(n² d_v)

In total, the basic self-attention operation scales quadratically with the length n of the input sequence. An image flattened into a sequence has n = h·w pixels, so global self-attention over it costs O((h·w)²) in time and memory; for a square image with side length s that is O(s⁴), and doubling the side length multiplies the cost by 16. This is one reason why global receptive fields are difficult to use on large images and why local receptive fields are an appealing alternative.
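A quick back-of-the-envelope sketch makes the difference tangible. The helper below is hypothetical and only counts query-key scores for a single head; it ignores the feature dimension and treats border pixels as if they had full k × k neighborhoods.

```python
def attention_score_entries(h, w, k=None):
    """Number of query-key scores for one head on an h x w image (hypothetical helper).

    k=None: global attention, every pixel attends to all h*w pixels.
    k=int:  local attention, every pixel attends to a k x k memory block.
    """
    n = h * w  # sequence length of the flattened image
    return n * n if k is None else n * k * k

for side in (32, 64, 128):
    glob = attention_score_entries(side, side)
    loc = attention_score_entries(side, side, k=7)
    print(f"{side}x{side}: global {glob:,} vs local(k=7) {loc:,} scores")
```

Going from 32×32 to 64×64 multiplies the global count by 16 (1,048,576 to 16,777,216), while the local count with k = 7 grows only by a factor of 4 (50,176 to 200,704).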

Revisiting CNNs

Figure 1 shows a small square, the kernel, that slides across the image. We choose a center point [i, j] on the image and a kernel size, which determines how much of the image the kernel covers. The same kernel weights are applied at every pixel position (weight sharing), so the number of parameters does not depend on the size of the image. Note that the figure draws multiple pixels in each square; in reality, each square corresponds to a single pixel unless we use pooling to group pixels together.

Figure 1: An example of a local convolutional window around a point [i, j] (red square) with spatial extent k = 3. ©J. Hatzky

The size of the kernel can vary between the layers of the network. This allows the network to learn local correlation structures within a particular layer. Recent work has introduced differentiable kernels of variable size [7], but we will focus on the basic approach used in traditional CNNs. Since the convolutional kernel is an important concept that we will build upon, I explain it using the notation of [6].

The input image is specified by its height h, width w, and number of channels d_in (e.g. 3 for an RGB image): x ∈ ℝʰˣʷˣᵈⁱⁿ. Around a pixel xᵢⱼ, we define a local neighborhood Nₖ with spatial extent k, i.e. the set of pixels covered by the kernel. For example, N₃(2, 2) is the set of pixels within the 3×3 square centered on the pixel at row 2, column 2. Formally, Nₖ(i, j) = {(a, b) ∣ |a − i| ≤ ⌊k/2⌋, |b − j| ≤ ⌊k/2⌋}. We optimize a weight tensor W ∈ ℝᵏˣᵏˣᵈᵒᵘᵗˣᵈⁱⁿ, i.e. one d_out × d_in matrix per kernel position, to calculate the output yᵢⱼ for each pixel.

$$y_{ij} = \sum_{a,b \in \mathcal{N}_k(i,j)} W_{i-a,\,j-b}\, x_{ab}$$

Weighted sum with spatial extent k and center [i, j]

To get this output, we multiply the channel vector x_ab of each pixel in the local neighborhood by the d_out × d_in matrix that belongs to its offset from the center, and sum the results. This operation is translation equivariant: shifting the input shifts the output by the same amount, so a pattern is recognized regardless of where it appears in the image.
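In [6], the convolution output is written as yᵢⱼ = Σ_{a,b ∈ Nₖ(i,j)} W_{i−a, j−b} x_ab. The sketch below (an illustration, not the article's code) evaluates that sum pixel by pixel and checks it against `torch.nn.functional.conv2d`. One assumption to be aware of: PyTorch's `conv2d` actually computes a cross-correlation, so the sketch indexes the kernel by the offset (a − i, b − j) instead of (i − a, j − b); the two conventions differ only by flipping the kernel. The helper names `neighborhood` and `conv_output` are hypothetical.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_in, d_out, k = 3, 4, 3
h, w = 6, 6
x = torch.randn(d_in, h, w)         # one image, channels first
W = torch.randn(d_out, d_in, k, k)  # one d_out x d_in matrix per kernel offset

def neighborhood(i, j, k, h, w):
    """N_k(i, j): in-bounds pixels (a, b) with |a - i| <= k // 2 and |b - j| <= k // 2."""
    r = k // 2
    return [(a, b)
            for a in range(i - r, i + r + 1)
            for b in range(j - r, j + r + 1)
            if 0 <= a < h and 0 <= b < w]

def conv_output(x, W, i, j):
    """y_ij as an explicit sum over N_k(i, j) of (d_out x d_in matrix) @ (d_in vector)."""
    k = W.shape[-1]
    r = k // 2
    y = torch.zeros(W.shape[0])
    for a, b in neighborhood(i, j, k, x.shape[1], x.shape[2]):
        # cross-correlation indexing: offset (a - i, b - j) shifted into [0, k)
        y += W[:, :, a - i + r, b - j + r] @ x[:, a, b]
    return y

# conv2d with zero padding k // 2 computes the same sum at every pixel
reference = F.conv2d(x.unsqueeze(0), W, padding=k // 2)[0]  # (d_out, h, w)
for i, j in [(2, 3), (0, 0)]:  # an interior pixel and a corner pixel
    print(torch.allclose(conv_output(x, W, i, j), reference[:, i, j], atol=1e-5))
```

Skipping out-of-bounds neighbors is equivalent to zero padding, which is why the corner pixel matches as well.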

The Memory Block as a 2D Local Receptive Field

To perform self-attention on 2D images, the authors of [6] introduce the memory block, a concept inspired by the way CNNs work. The memory block is essentially the receptive field of a CNN, but instead of applying a convolution, we apply the self-attention operation to the pixels in the receptive field Nₖ, which creates a learnable connection between any pair of pixels in the local memory block. To apply self-attention globally, we simply make the memory block as big as the entire image.
To define the single-head self-attention operation for this 2D case, we can use the following equation:

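$$y_{ij} = \sum_{a,b \in \mathcal{N}_k(i,j)} \mathrm{softmax}_{ab}\left(q_{ij}^\top k_{ab}\right) v_{ab}, \qquad q_{ij} = W_Q\, x_{ij},\; k_{ab} = W_K\, x_{ab},\; v_{ab} = W_V\, x_{ab}$$

Self-Attention for a specific output yᵢⱼ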

In exchange for the translation equivariance of the CNN, we gain the more general permutation equivariance of self-attention.
Let’s see how this would look in PyTorch.

import torch
import torch.nn as nn
import torch.nn.functional as F

# for some einsum magic
from einops import rearrange, einsum

# use gpu if possible
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# we create a random normal tensor as a placeholder for an RGB image with shapes:
# batch_size (b), channels (c), height (h), width (w)
img = torch.randn(1, 3, 28, 28) # b c h w

k = 3 # spatial extent of the memory block N

# we can extract memory blocks by using the pytorch unfold operation and rearranging the result
# we pad the image first to keep our old dimensions intact
stride = 1
padding = 1
memory_blocks = F.pad(img, [padding]*4).unfold(dimension=2, size=k,
step=stride).unfold(dimension=3, size=k, step=stride)
memory_blocks = rearrange(memory_blocks, "b c h w i j -> b h w c i j")
print(memory_blocks.shape)
print(f"We have {memory_blocks.shape[1]}x{memory_blocks.shape[2]} patches of shape: {memory_blocks.shape[3:]}")

# apply the self-attention for a specific ij:
i, j = (3, 4)
memory_block_ij = memory_blocks[:, i, j, :, :, :]  # b c k k
# flatten the memory block's height and width into a sequence of k*k pixels,
# each represented by its c channel values
x = rearrange(memory_block_ij, "b c h w -> b (h w) c")

# our input dimension is the channel size
d_in = x.shape[-1]
d_out = d_in

# linear transformations to embed the input x
key = nn.Linear(d_in, d_out, bias=False)
query = nn.Linear(d_in, d_out, bias=False)
value = nn.Linear(d_in, d_out, bias=False)

d_b, d_t, d_k = x.size()

# linear projection of the input
q = query(x) # b, t, k
k = key(x) # b, t, k
v = value(x) # b, t, k

assert q.shape == (d_b, d_t, d_out)

# scaled dot-product self-attention
# dot_prod(Q, K)
scaling_factor = 1/torch.sqrt(torch.tensor(d_k))
scaled_dot_product = F.softmax(
einsum(q, k, "b t k, b l k -> b t l") * scaling_factor, dim=-1 )
assert scaled_dot_product.shape == (d_b, d_t, d_t)

# dot-prod(w, v)
self_attention = torch.einsum('b i j , b j d -> b i d', scaled_dot_product, v)

# remember that self-attention is a seq2seq operation:
# the size that goes in, also goes out
assert self_attention.shape == (d_b, d_t, d_out)
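The snippet above computes the output for a single position (i, j). A vectorized sketch can compute every position at once by gathering all memory blocks with `torch.nn.functional.unfold`. It assumes single-head attention with 1×1 convolutions as the query/key/value projections, and, unlike the snippet above, it masks out neighbors that fall into the zero padding instead of attending to them. Both are my choices for illustration, not details taken from [6]; `local_self_attention` is a hypothetical helper name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def local_self_attention(x, query, key, value, k=3):
    """Single-head self-attention of every pixel over its k x k memory block.

    x: (b, c, h, w); query/key/value: 1x1 convolutions mapping c -> d channels.
    Memory-block positions that fall outside the image are masked out.
    """
    b, _, h, w = x.shape
    pad = k // 2
    q, keys, values = query(x), key(x), value(x)  # each (b, d, h, w)
    d = q.shape[1]
    # gather the k*k neighborhood of every pixel: (b, d*k*k, h*w) -> (b, d, k*k, h, w)
    keys = F.unfold(keys, kernel_size=k, padding=pad).reshape(b, d, k * k, h, w)
    values = F.unfold(values, kernel_size=k, padding=pad).reshape(b, d, k * k, h, w)
    # logits[b, n, i, j] = q_ij . k_n / sqrt(d) for each neighbor n of pixel (i, j)
    logits = torch.einsum("bdhw,bdnhw->bnhw", q, keys) / d ** 0.5
    # neighbors that come from the zero padding are excluded from the softmax
    inside = F.unfold(torch.ones(1, 1, h, w, device=x.device), kernel_size=k, padding=pad)
    logits = logits.masked_fill(~inside.reshape(1, k * k, h, w).bool(), float("-inf"))
    weights = logits.softmax(dim=1)  # normalize over the k*k neighbors
    return torch.einsum("bnhw,bdnhw->bdhw", weights, values)

torch.manual_seed(0)
c, d = 3, 8
query = nn.Conv2d(c, d, kernel_size=1, bias=False)
key = nn.Conv2d(c, d, kernel_size=1, bias=False)
value = nn.Conv2d(c, d, kernel_size=1, bias=False)
img = torch.randn(2, c, 28, 28)
out = local_self_attention(img, query, key, value, k=3)
print(out.shape)  # torch.Size([2, 8, 28, 28])
```

Because the center pixel is always inside the image, no row of the softmax is fully masked.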

This simple implementation has one big downside: we lose all spatial information when we apply self-attention to the flattened memory block. One way to resolve this is to add positional information, the subject of the next section.
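A tiny sketch illustrates the problem. Without a positional term, the output yᵢⱼ = Σ_ab softmax_ab(qᵢⱼᵀ k_ab / √d) v_ab is a sum over the memory block, so shuffling the pixels inside the block leaves it unchanged. The projection matrices are random, hypothetical stand-ins.

```python
import torch

torch.manual_seed(0)
n, d = 9, 3  # a flattened 3x3 memory block with 3 channels
# hypothetical projection weights, for illustration only
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))

block = torch.randn(n, d)  # pixels x_ab of the memory block, flattened
center = block[n // 2]     # the query pixel x_ij sits in the middle

def attend(query_pixel, pixels):
    """y_ij = sum_ab softmax_ab(q_ij . k_ab / sqrt(d)) v_ab, with no positional term."""
    q = query_pixel @ W_q
    k, v = pixels @ W_k, pixels @ W_v
    return torch.softmax(k @ q / d ** 0.5, dim=0) @ v

perm = torch.randperm(n)
# shuffling the neighbors' positions does not change the output at (i, j)
print(torch.allclose(attend(center, block), attend(center, block[perm]), atol=1e-5))  # True
```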

2D Relative Positional Embedding

In addition to 2D self-attention, [6] introduces a 2D version of relative embeddings. Relative embeddings for 1D were first introduced by [8] and later extended, e.g., by [9] and [10].

Relative embeddings offer two advantages. First, they give a powerful positional representation that has the potential to generalize better than absolute embeddings [8] to bigger images (or, in NLP, to longer sequences).
Second, they introduce a powerful inductive bias into the model, translation equivariance, which has already proven very helpful in CNNs.

In 2D, relative positional embeddings work by defining relative indices for the x (column) and y (row) directions. Relative here means that the indices are measured relative to the position (i, j) of the pixel whose output yᵢⱼ is computed (Figure 2).

Figure 2: Relative positional embedding for a specific pixel (a, b) ∈ Nₖ(i, j). ©J. Hatzky

As proposed in [6], the row offset a − i and the column offset b − j are each associated with an embedding, r_{a−i} and r_{b−j}, of dimension d_out/2. The row and column offset embeddings are then concatenated to form the spatial-relative embedding r_{a−i, b−j} ∈ ℝᵈᵒᵘᵗ.
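The sketch below shows how the concatenated embeddings r_{a−i, b−j} could be built for a single k × k memory block. It is a minimal illustration with arbitrary sizes (d_out = 8, k = 3) that uses two `nn.Embedding` tables indexed by the shifted offsets; the names `row_emb` and `col_emb` are hypothetical. The resulting table depends only on the offsets, not on (i, j), so the same table serves every query pixel.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_out, k = 8, 3  # output dim (must be even) and memory-block size (arbitrary)
r_max = k // 2   # offsets a - i and b - j lie in [-r_max, r_max]
row_emb = nn.Embedding(2 * r_max + 1, d_out // 2)  # r_{a-i}: one vector per row offset
col_emb = nn.Embedding(2 * r_max + 1, d_out // 2)  # r_{b-j}: one vector per column offset

offsets = torch.arange(-r_max, r_max + 1)
row_off, col_off = torch.meshgrid(offsets, offsets, indexing="ij")  # each (k, k)
# r_{a-i, b-j} = concat(r_{a-i}, r_{b-j}); shift offsets to valid indices [0, 2*r_max]
r = torch.cat([row_emb(row_off + r_max), col_emb(col_off + r_max)], dim=-1)  # (k, k, d_out)

q_ij = torch.randn(d_out)  # a query vector for some pixel (i, j)
position_logits = r @ q_ij  # q_ij^T r_{a-i, b-j} for each neighbor, shape (k, k)
print(r.shape, position_logits.shape)  # torch.Size([3, 3, 8]) torch.Size([3, 3])
```

The k × k matrix `position_logits` holds the positional terms qᵢⱼᵀ r_{a−i, b−j} that [6] adds to the content logits qᵢⱼᵀ k_ab inside the softmax.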

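$$y_{ij} = \sum_{a,b \in \mathcal{N}_k(i,j)} \mathrm{softmax}_{ab}\left(q_{ij}^\top k_{ab} + q_{ij}^\top r_{a-i,\,b-j}\right) v_{ab}$$

Relative positional embeddings are added within the self-attention operation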

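Essentially, the query qᵢⱼ is multiplied not only with the keys k_ab but also with the relative positional embeddings r_{a−i, b−j}, and these positional logits are added to the content logits qᵢⱼᵀk_ab inside the softmax.
See below how this could be done in PyTorch. The example attends globally, with the memory block covering the entire 4×4 input, and additionally divides the logits by a square-root normalizer as in scaled dot-product attention. Note that there are more efficient ways to implement this, which we will not cover here as we stick to the introduced formulation.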

import torch
import torch.nn as nn
import torch.nn.functional as F

# for some einsum magic
from einops import rearrange, einsum

# use gpu if possible
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# number of input channels (e.g. 3 for RGB)
in_channels = 3
# the embedding dim of the input projection (embedding_dim)
mid_channels = 22
# the number of attention heads
num_heads = 2
# the number of channels after projecting the heads together
out_channels = 8
# the maximum number of pixels along one side of the image (assuming square images)
max_pos_embedding = 4

# create embeddings. if we want to keep the 2D representation of the
# input, we can do this by using 2D convolution
query = nn.Conv2d(in_channels, mid_channels * num_heads, kernel_size=1, device=device)
key = nn.Conv2d(in_channels, mid_channels * num_heads, kernel_size=1, device=device)
value = nn.Conv2d(in_channels, mid_channels * num_heads, kernel_size=1, device=device)
wout = nn.Conv2d(mid_channels * num_heads, out_channels, kernel_size=1, device=device)

# Define positional embeddings
row_embedding = nn.Embedding(2 * max_pos_embedding - 1, mid_channels // 2, device=device)
col_embedding = nn.Embedding(2 * max_pos_embedding - 1, mid_channels // 2, device=device)

# create relative indices
deltas = torch.arange(max_pos_embedding).view(1, -1) - torch.arange(max_pos_embedding).view(
-1, 1
)
# -- shift the deltas to [0, 2 * max_pos_embedding - 2], the valid index range of the embeddings
relative_indices = (deltas + max_pos_embedding - 1).to(device)

# create an example image
x = torch.randn(4, 3, 4, 4, device=device) # b c h w

b, cin, h, w = x.size()
# normalize by the square root of the per-head query/key dimension
sqrt_normalizer = mid_channels ** 0.5

q = query(x)
k = key(x)
v = value(x)

# Compute attention scores based on position
# the relative indices produce the staircase pattern of offset vectors;
# row offsets range over the height, column offsets over the width.
# New names (row_emb, col_emb) keep the nn.Embedding modules from being overwritten.
row_emb = row_embedding(
    relative_indices[:h, :h].reshape(-1)
).transpose(0, 1)  # (mid_channels // 2, h*h)
col_emb = col_embedding(
    relative_indices[:w, :w].reshape(-1)
).transpose(0, 1)  # (mid_channels // 2, w*w)

# unfold heads
q = rearrange(
    q, "b (c heads) h w -> b c heads h w", heads=num_heads, c=mid_channels)
k = rearrange(
    k, "b (c heads) h w -> b c heads h w", heads=num_heads, c=mid_channels)
v = rearrange(
    v, "b (c heads) h w -> b c heads h w", heads=num_heads, c=mid_channels)

# now expand the rows and columns and concatenate them
expand_row = row_emb.unsqueeze(-1).expand(-1, -1, w*w)
expand_col = col_emb.unsqueeze(-2).expand(-1, h*h, -1)
positional_embedding = torch.cat((expand_row, expand_col), dim=0)  # (mid_channels, h*h, w*w)

# axes: (channel, query row, key row, query col, key col)
positional_embedding = rearrange(
    positional_embedding, "c (qi ki) (qj kj) -> c qi ki qj kj",
    c=mid_channels, qi=h, ki=h, qj=w, kj=w)

# dot-prod(q, r)
attention_scores = einsum(q, positional_embedding,
"b c h i j, c i k j l -> b h i j k l")
attention_scores = attention_scores / sqrt_normalizer

# Compute attention scores based on data
attention_content_scores = einsum(q, k, "b c h i j, b c h k l -> b h i j k l")
attention_content_scores = attention_content_scores / sqrt_normalizer

# Combine attention scores
attention_scores = attention_scores + attention_content_scores

# Normalize to obtain probabilities.
shape = attention_scores.shape
att_probs = nn.Softmax(dim=-1)(attention_scores.view(*shape[:-2], -1)).view(shape)

# Re-weight values via attention
v_f = einsum(att_probs, v, "b h i j k l, b c h k l -> b c h i j")

# linear project to output dimension
v_f = rearrange(v_f, "b c h i j -> b (c h) i j")
out = wout(v_f)

print(out.shape)  # torch.Size([4, 8, 4, 4])

Putting it all together

Now we can put all the parts together.
Figure 3 gives an overview of the data flow and the tensor shapes involved in self-attention.

Figure 3: Overview of the shapes throughout the self-attention. Inspired by this GitHub post. ©J. Hatzky

Let’s create a class that implements the whole model.

import torch
import torch.nn as nn
import torch.nn.functional as F

# for some einsum magic
from einops import rearrange, einsum

# use gpu if possible
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

class StandAloneSelfAttention(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels,
                 num_heads, max_pos_embedding):
        """
        Inputs:
            in_channels - Number of channels of the input image
            mid_channels - Embedding dim of the input projection (per head)
            out_channels - Output dim after projecting heads together
            num_heads - Number of heads to use in the Multi-Head Attention block
            max_pos_embedding - The max(height, width) of inputs that have to be embedded
        """
        super().__init__()
        # the row and column embeddings each get half of mid_channels
        assert mid_channels % 2 == 0, "mid_channels must be even"

        self.mid_channels = mid_channels
        self.num_heads = num_heads
        self.out_channels = out_channels

        # create embeddings. if we want to keep the 2D representation of the
        # input, we can do this by using 2D convolution
        self.query = nn.Conv2d(in_channels, mid_channels * num_heads, kernel_size=1, device=device)
        self.key = nn.Conv2d(in_channels, mid_channels * num_heads, kernel_size=1, device=device)
        self.value = nn.Conv2d(in_channels, mid_channels * num_heads, kernel_size=1, device=device)
        self.wout = nn.Conv2d(mid_channels * num_heads, out_channels, kernel_size=1, device=device)

        # Define positional embeddings
        self.row_embedding = nn.Embedding(2 * max_pos_embedding - 1, mid_channels // 2, device=device)
        self.col_embedding = nn.Embedding(2 * max_pos_embedding - 1, mid_channels // 2, device=device)

        # create relative indices: deltas[r, s] = s - r
        deltas = torch.arange(max_pos_embedding).view(1, -1) - torch.arange(max_pos_embedding).view(
            -1, 1
        )
        # -- shift the deltas to [0, 2 * max_pos_embedding - 2]
        self.relative_indices = (deltas + max_pos_embedding - 1).to(device)

        self.verbose = False

    def forward(self, x):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        if self.verbose:
            print(f"x: {x.shape}, q: {q.shape}, k: {k.shape}, v:{v.shape}")

        b, cin, h, w = x.size()
        # normalize by the square root of the per-head query/key dimension
        sqrt_normalizer = self.mid_channels ** 0.5

        # Compute attention scores based on position
        # the relative indices produce the staircase pattern of offset vectors;
        # row offsets range over the height, column offsets over the width
        row_embedding = self.row_embedding(
            self.relative_indices[:h, :h].reshape(-1)
        ).transpose(0, 1)  # (mid_channels // 2, h*h)
        col_embedding = self.col_embedding(
            self.relative_indices[:w, :w].reshape(-1)
        ).transpose(0, 1)  # (mid_channels // 2, w*w)

        # unfold heads
        q = rearrange(
            q, "b (c heads) h w -> b c heads h w",
            heads=self.num_heads, c=self.mid_channels)
        k = rearrange(
            k, "b (c heads) h w -> b c heads h w",
            heads=self.num_heads, c=self.mid_channels)
        v = rearrange(
            v, "b (c heads) h w -> b c heads h w",
            heads=self.num_heads, c=self.mid_channels)

        if self.verbose:
            print(f"q: {q.shape}, k: {k.shape}, v:{v.shape}")

        # now expand the rows and columns and concatenate them
        expand_row = row_embedding.unsqueeze(-1).expand(-1, -1, w*w)
        expand_col = col_embedding.unsqueeze(-2).expand(-1, h*h, -1)
        positional_embedding = torch.cat((expand_row, expand_col), dim=0)  # (mid_channels, h*h, w*w)

        # axes: (channel, query row, key row, query col, key col)
        positional_embedding = rearrange(
            positional_embedding, "c (qi ki) (qj kj) -> c qi ki qj kj",
            c=self.mid_channels, qi=h, ki=h, qj=w, kj=w)

        if self.verbose:
            print(f"row_encoding: {row_embedding.shape}, column_encoding: {col_embedding.shape}, pos_embedding: {positional_embedding.shape}")

        # dot-prod(q, r)
        attention_scores = einsum(q, positional_embedding,
                                  "b c h i j, c i k j l -> b h i j k l")
        attention_scores = attention_scores / sqrt_normalizer

        # Compute attention scores based on data dot-prod(q, k)
        attention_content_scores = einsum(q, k, "b c h i j, b c h k l -> b h i j k l")
        attention_content_scores = attention_content_scores / sqrt_normalizer

        # Combine attention scores
        attention_scores = attention_scores + attention_content_scores

        # Normalize over all key positions (k, l) to obtain probabilities.
        shape = attention_scores.shape
        att_probs = nn.Softmax(dim=-1)(attention_scores.view(*shape[:-2], -1)).view(shape)
        if self.verbose:
            print(f"attention_scores: {attention_scores.shape}, shaped scores: {attention_scores.view(*shape[:-2], -1).shape} att_probs: {att_probs.shape}")

        # Re-weight values via attention and map to output dimension.
        v_f = einsum(att_probs, v, "b h i j k l, b c h k l -> b c h i j")
        v_f = rearrange(v_f, "b c h i j -> b (c h) i j")
        if self.verbose:
            print(f"(qr + qk)V: {v_f.shape}")
        out = self.wout(v_f)

        return out

Epilogue

The Stand-Alone Self-Attention in Vision Models [6] paper presents a fascinating idea for applying pure self-attention models in vision. Despite the complexity of the self-attention operation, the paper demonstrates an effective approach that uses local receptive fields, also known as memory blocks, to reduce the computational cost. While the more recently published vision transformers may be stealing the limelight, this method has great potential to become a top contender among state-of-the-art vision architectures as software and hardware improve. It’s an exciting piece of work that could take vision models to the next level!


Image created by author using craiyon AI

While self-attention is already widely adopted in NLP and significantly contributes to the performance of state-of-the-art models (e.g. [2], [3]), more and more work is being done to achieve similar results in vision.

Even though, there are hybrid approaches that combine for example CNNs with attention [4] or apply linear transformations on patches of the image [5], a pure attention-based model is harder to train effectively due to various reasons that we will investigate further on.

The Stand-Alone Self-Attention in Vision Models [6] paper introduces the idea of such a pure attention-based model for vision. In the following, I will give an overview of the paper’s ideas and related follow-up work. Further, I assume that you are familiar with the workings of the transformer and have a basic knowledge of CNNs. An understanding of PyTorch is also beneficial for the coding parts but these can also be safely skipped.
If you are on the other hand only interested in the code, feel free to skip this article and directly take a look at this annotated colab notebook.

The Case for Self-Attention in Vision

CNNs are commonly used to build neural networks for image processing due to their powerful geometric prior of translation equivariance. This means that they can handle relative shifts of the input well, making them robust.

On the other hand, self-attention does not have this prior and is instead permutation equivariant. This means that if the input is rearranged, the output will be rearranged in an equivalent way. Although permutation equivariance is more general, it is not as useful for images as translation equivariance.

Fortunately, we can use different positional encodings to constrain the self-attention operation and achieve translation equivariance. Positional encoding — which is also called positional embedding when it has learnable parameters, allows us to have a more flexible architecture than CNNs while still being able to incorporate certain priors.

Implementing Basic Self-Attention in 1D

For one-dimensional inputs like text and speech, the single-head self-attention operation is defined as

Scaled Dot-Product Attention as proposed in [1]

which is essentially a scaled dot-product between the query Q and the key K followed by another dot-product between the resulting matrix and V.

We can also express the dot-product as a weighted sum explicitly and show how to get a specific output. Keep that in mind because later we will generalize this for 2D images.

Self-Attention for a specific output yᵢ

In PyTorch, this could look as follows.

import torch
import torch.nn as nn
import torch.nn.functional as F

# for some einsum magic
from einops import rearrange, einsum

# use gpu if possible
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

embedding_dim_k = 10

# linear projection of the input x
key = nn.Linear(embedding_dim_k, embedding_dim_k, bias=False)
query = nn.Linear(embedding_dim_k, embedding_dim_k, bias=False)
value = nn.Linear(embedding_dim_k, embedding_dim_k, bias=False)

# creating random vector of shape:
# batch_size (b), sequence lenght (t), embedding dim (k)
x = torch.randn(1, 12, embedding_dim_k) # b t k

d_b, d_t, d_k = x.size()

# linear projection of the input
q = query(x) # b, t, k
k = key(x) # b, t, k
v = value(x) # b, t, k
assert q.shape == (d_b, d_t, d_k)

# scaled dot-product self-attention
# dot_prod(Q, K)
scaling_factor = 1/torch.sqrt(torch.tensor(d_k))
scaled_dot_product = F.softmax(
einsum(q, k, "b t k, b l k -> b t l") * scaling_factor, dim=-1 )
assert scaled_dot_product.shape == (d_b, d_t, d_t)

# dot-prod(w, v)
self_attention = torch.einsum('b i j , b j d -> b i d', scaled_dot_product, v)

# remember that self-attention is a seq2seq operation:
# the size that goes in, also goes out
assert self_attention.shape == (d_b, d_t, d_k)

Global vs Local Self-attention

When we talk about global and local self-attention in visual models, we are referring to how much of the image the model is looking at. Global self-attention looks at the entire image at once, while local self-attention focuses only on certain parts. Generally, the larger the area the model is looking at, the more complex it becomes and the more memory it requires.

Let’s take a closer look at the basic self-attention operation and how it performs with larger image sizes. To do this, we’ll use a concept called big O notation to express the complexity of the operation as the input size n increases.

The self-attention operation involves three separate calculations:

  1. Calculating QKᵀ has a complexity of O(n² d_k)
  2. The softmax operation, which involves exponentiation, summation, and division, has a quadratic complexity of O(n²)
  3. Multiplying softmax(QKᵀ)V has a complexity of O(n² d_v)

In total, the basic self-attention operation scales quadratically as the length of the input sequence n increases. So, as we apply self-attention to larger and larger images — which are of roughly length n² = h*w themselves due to their 2D nature — the space and time complexity of the operation becomes increasingly higher. This is one reason why using global receptive fields on larger images can be difficult and why local receptive fields are an appealing solution.

Revisiting CNNs

In Figure 1, we can see that we use small squares called kernels that slide across the image. We choose a center point [i,j] on the image and kernel size, which determines how much of the image is included in the kernel. The kernel is applied to every pixel in the image and the values are fed into the same neural network, so we use fewer parameters. Note that in the figure, there are multiple pixels in each square, but in reality, we only have one pixel per square unless we use pooling to group them together.

Figure 1: An example of a local convolutional window around a point [i, j] (red square) with spatial extend k=3. ©J. Hatzky

The size of the kernel can vary between the layers of the network. This allows the network to learn local correlation structures within a particular layer. In recent work, differential kernels of variable size have been introduced [7], but we will focus on the basic approach used in traditional CNNs. Since the convolutional kernel is an important concept that we will build upon, I explain it using the notation used in [6].

The input image is specified by its height h, width w, and channel size din (e.g. 3 for RGB image): x ∈ ℝʰˣʷˣᵈⁱⁿ. We define a local neighborhood Nₖ around a pixel xᵢⱼ using a spatial extent k, which is the set of pixels within the kernel. For example, N₃(2,2) would be the set of pixels within a 3×3 square centered around the pixel at row 2, column 2. For completeness, we can define it as: Nₖ(i, j) = {a, b ∣ |a − i| ≤ k/2, |b − j| ≤ k/2}. We optimize a weight matrix W ∈ ℝᵏˣᵏˣᵈᵒᵘᵗˣᵈⁱⁿ to calculate a specific output yᵢⱼ for each pixel.

Weighted sum with spatial extend k and center [i, j]

To get this output, we sum up the product of depth-wise matrix multiplications for each pixel in the local neighborhood. This operation is translation equivariant, which means it’s designed to recognize patterns regardless of where they appear in the image.

The Memory Block as a 2D Local Receptive Field

To perform self-attention on a 2D image, researchers in [6] came up with a memory block concept that is inspired by the way CNNs work. If you want to apply self-attention globally, you just need to make the memory block as big as the entire image. The memory block is essentially the same as the receptive field used in CNNs, but instead of using a CNN, we apply the self-attention operation on the pixels in the receptive field Nₖ, which creates a learnable connection between any pair of pixels in the local memory block.
To define the single-head self-attention operation for this 2D case, we can use the following equation:

Self-Attention for a specific output yᵢⱼ

While losing the translation equivariance of the CNN, we now gained the more general permutation equivariance that is a property of self-attention.
Let’s see how this would look in PyTorch.

import torch
import torch.nn as nn
import torch.nn.functional as F

# for some einsum magic
from einops import rearrange, einsum

# use gpu if possible
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# we create a random normal tensor as a placeholder for an RGB image with shapes:
# batch_size (b), channels (c), height (h), width (w)
img = torch.randn(1, 3, 28, 28) # b c h w

k = 3 # spatial extend of the memory block N

# we can extract memory blocks by using the pytorch unfold operation and rearranging the result
# we pad the image first to keep our old dimensions intact
stride = 1
padding = 1
memory_blocks = F.pad(img, [padding]*4).unfold(dimension=2, size=k,
step=stride).unfold(dimension=3, size=k, step=stride)
memory_blocks = rearrange(memory_blocks, "b c h w i j -> b h w c i j")
print(memory_blocks.shape)
print(f"We have {memory_blocks.shape[1]}x{memory_blocks.shape[2]} patches of shape: {memory_blocks.shape[2:]}")

# apply the self-attention for a specific ij:
i, j = (3, 4)
memory_block_ij = memory_blocks[:, i, j, : , :, :]
# we can flatten the memory blocks height and width
x = rearrange(memory_block_ij, "b h w c -> b (h w) c")

# our input dimension is the channel size
d_in = x.shape[-1]
d_out = d_in

# linear transformations to embed the input x
key = nn.Linear(d_in, d_out, bias=False)
query = nn.Linear(d_in, d_out, bias=False)
value = nn.Linear(d_in, d_out, bias=False)

d_b, d_t, d_k = x.size()

# linear projection of the input
q = query(x) # b, t, k
k = key(x) # b, t, k
v = value(x) # b, t, k

assert q.shape == (d_b, d_t, d_out)

# scaled dot-product self-attention
# dot_prod(Q, K)
scaling_factor = 1/torch.sqrt(torch.tensor(d_k))
scaled_dot_product = F.softmax(
einsum(q, k, "b t k, b l k -> b t l") * scaling_factor, dim=-1 )
assert scaled_dot_product.shape == (d_b, d_t, d_t)

# dot-prod(w, v)
self_attention = torch.einsum('b i j , b j d -> b i d', scaled_dot_product, v)

# remember that self-attention is a seq2seq operation:
# the size that goes in, also goes out
assert self_attention.shape == (d_b, d_t, d_out)

This simple implementation has one big downside to it. We lose all the spatial information when we apply self-attention to the flattened memory block. One way to resolve this is by adding positional information — the subject of the next section.

2D Relative Positional Embedding

In addition to the 2D self-attention, [6] introduces the 2D application of relative embeddings. Relative embeddings for 1D were first introduced by [8] and then later on extended by e.g. [9] and [10].

Relative embeddings bring two advantages. First, they provide a positional representation that can potentially generalize better than absolute embeddings [8] to larger images (or, in NLP, to longer sequences).
Second, they introduce translation equivariance as an inductive bias into the model, a prior that has already proven very helpful for CNNs.
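Both points can be illustrated in 1D with a simplified stand-in for relative embeddings: one learnable scalar bias per offset instead of a vector that is dotted with the query. This simplification is our assumption for brevity, not the exact formulation of [6] or [8]. The bias between positions t and l depends only on l - t, so it is unchanged when both positions are shifted, and the same table of 2 * max_len - 1 entries serves every sequence length up to max_len.

```python
import torch

torch.manual_seed(0)

max_len = 5
# One scalar bias per relative offset l - t in [-(max_len - 1), max_len - 1]:
# 2 * max_len - 1 entries in total, independent of absolute positions.
rel_bias = torch.randn(2 * max_len - 1)

def relative_bias_matrix(n):
    """Bias B[t, l] = rel_bias[(l - t) + max_len - 1] for a length-n sequence (n <= max_len)."""
    pos = torch.arange(n)
    offsets = pos.view(1, -1) - pos.view(-1, 1)  # offsets[t, l] = l - t
    return rel_bias[offsets + max_len - 1]

B = relative_bias_matrix(max_len)
# Translation equivariance: shifting query and key by the same amount
# leaves the bias unchanged, so B is constant along its diagonals.
print(torch.equal(B[1:, 1:], B[:-1, :-1]))  # True

# The same table serves shorter inputs: a length-3 sequence reuses
# exactly the top-left block of the length-5 bias matrix.
print(torch.equal(relative_bias_matrix(3), B[:3, :3]))  # True
```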

In 2D, relative positional embeddings work by defining relative indices for the x (column) and y (row) directions. Relative here means that the indices are measured relative to the position (i, j) of the pixel that is being queried (Figure 2).

Figure 2: Relative positional embedding for a specific pixel ab ∈ Nₖ(i, j). ©J. Hatzky

As proposed in [6], the row offset (a - i) and the column offset (b - j) are each associated with an embedding r of dimension d_out / 2. The row and column offset embeddings are then concatenated to form the relative embedding r(a-i, b-j), which is used in the spatial-relative attention.
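The following is a minimal sketch of this construction for a single query pixel. The block size k = 3, d_out = 8 and the query position (5, 7) are illustrative values of ours, and the two nn.Embedding tables are simply indexed by the shifted offsets.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

k = 3        # memory block (neighborhood) size, assumed odd
d_out = 8    # must be even so it splits into two halves
i, j = 5, 7  # query pixel (any position works; only offsets matter)

# Absolute coordinates (a, b) of the k x k neighborhood centered on (i, j).
a = torch.arange(i - k // 2, i + k // 2 + 1)
b = torch.arange(j - k // 2, j + k // 2 + 1)
row_offsets = (a - i).view(k, 1).repeat(1, k)  # a - i, varies along rows
col_offsets = (b - j).view(1, k).repeat(k, 1)  # b - j, varies along columns

# One embedding table per direction; offsets lie in [-(k // 2), k // 2],
# so shift them by k // 2 to obtain valid (non-negative) indices.
row_table = nn.Embedding(k, d_out // 2)
col_table = nn.Embedding(k, d_out // 2)
r_row = row_table(row_offsets + k // 2)  # (k, k, d_out / 2)
r_col = col_table(col_offsets + k // 2)  # (k, k, d_out / 2)

# r[a, b] = concat(r_{a-i}, r_{b-j}), i.e. the relative embedding r_{a-i, b-j}
r = torch.cat([r_row, r_col], dim=-1)
print(r.shape)  # torch.Size([3, 3, 8])
```

Note that every pixel in the same row of the block shares the first half of its embedding, and every pixel in the same column shares the second half.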

Relative positional embeddings are added within the self-attention operation
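Written out, following the formulation in [6] (our transcription), the output for pixel (i, j) is given below. Here q_ij = W_Q x_ij, k_ab = W_K x_ab and v_ab = W_V x_ab are the query, key and value projections, and N_k(i, j) is the k x k memory block around (i, j). The code in this article additionally divides the logits by a square-root normalizer, as in standard scaled dot-product attention.

```latex
y_{ij} = \sum_{a,b \,\in\, \mathcal{N}_k(i,j)} \operatorname{softmax}_{ab}\!\left( q_{ij}^{\top} k_{ab} + q_{ij}^{\top} r_{a-i,\,b-j} \right) v_{ab}
```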

Essentially, we build a tensor of relative positional embeddings whose dot product with the query, qᵀr, is added to the query-key dot product qᵀk inside the softmax.
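For a single query pixel this boils down to a few lines. The sketch below uses random stand-ins for the query, the keys and values of one k x k memory block, and the relative embeddings r (in practice, the concatenated row/column embeddings). Like the equation, it omits the square-root scaling.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

d, k = 8, 3  # channel dim and memory-block size (illustrative values)
# Random stand-ins for one query pixel and its k x k memory block:
q_ij = torch.randn(d)          # query of pixel (i, j)
keys = torch.randn(k, k, d)    # k_ab for every (a, b) in the block
values = torch.randn(k, k, d)  # v_ab
r = torch.randn(k, k, d)       # r_{a-i, b-j}

# Content logits q^T k_ab plus positional logits q^T r_{a-i, b-j}
logits = torch.einsum("d,abd->ab", q_ij, keys) + torch.einsum("d,abd->ab", q_ij, r)

# The softmax runs jointly over all k * k positions of the block
weights = F.softmax(logits.flatten(), dim=0).view(k, k)
y_ij = torch.einsum("ab,abd->d", weights, values)
print(y_ij.shape)  # torch.Size([8])

# Adding the two dot products equals a single dot product with (k_ab + r)
alt = torch.einsum("d,abd->ab", q_ij, keys + r)
print(torch.allclose(logits, alt, atol=1e-5))  # True
```

The last check shows why the positional term can be viewed as a position-dependent correction of the keys.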
The following code shows how relative positional embeddings could be added for a whole image in PyTorch. Note that there are more efficient ways to implement this; we do not cover them here and stick to the formulation introduced above.

import torch
import torch.nn as nn
import torch.nn.functional as F

# for some einsum magic
from einops import rearrange, einsum

# use gpu if possible
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# number of input channels (e.g. 3 for RGB)
in_channels = 3
# the per-head embedding dim of the input projection (must be even, since
# half of it encodes row offsets and the other half column offsets)
mid_channels = 22
# the number of attention heads
num_heads = 2
# the number of channels after projecting the heads together
out_channels = 8
# the maximum number of pixels along one side of the image, i.e. max(height, width)
max_pos_embedding = 4

# create embeddings. if we want to keep the 2D representation of the
# input, we can do this by using 1x1 2D convolutions
query = nn.Conv2d(in_channels, mid_channels * num_heads, kernel_size=1, device=device)
key = nn.Conv2d(in_channels, mid_channels * num_heads, kernel_size=1, device=device)
value = nn.Conv2d(in_channels, mid_channels * num_heads, kernel_size=1, device=device)
wout = nn.Conv2d(mid_channels * num_heads, out_channels, kernel_size=1, device=device)

# Define positional embeddings, one table for row offsets and one for column offsets
row_embedding = nn.Embedding(2 * max_pos_embedding - 1, mid_channels // 2, device=device)
col_embedding = nn.Embedding(2 * max_pos_embedding - 1, mid_channels // 2, device=device)

# create relative indices: deltas[p, q] = q - p
deltas = torch.arange(max_pos_embedding).view(1, -1) - torch.arange(max_pos_embedding).view(
    -1, 1
)
# -- shift the deltas into [0, 2 * max_pos_embedding - 2]
relative_indices = (deltas + max_pos_embedding - 1).to(device)

# create an example image
x = torch.randn(4, in_channels, 4, 4, device=device)  # b c h w

_, _, h, w = x.shape
# scale the logits by the square root of the per-head key dimension
sqrt_normalizer = mid_channels ** 0.5

q = query(x)
k = key(x)
v = value(x)

# Compute attention scores based on position
# the relative indices select the correct embedding vectors in a staircase
# pattern: rows pair up (query row, key row), columns (query col, key col)
row_emb = row_embedding(
    relative_indices[:h, :h].reshape(-1)
).transpose(0, 1)  # (mid_channels / 2, h * h)
col_emb = col_embedding(
    relative_indices[:w, :w].reshape(-1)
).transpose(0, 1)  # (mid_channels / 2, w * w)

# unfold heads
q = rearrange(
    q, "b (c heads) h w -> b c heads h w", heads=num_heads, c=mid_channels)
k = rearrange(
    k, "b (c heads) h w -> b c heads h w", heads=num_heads, c=mid_channels)
v = rearrange(
    v, "b (c heads) h w -> b c heads h w", heads=num_heads, c=mid_channels)

# now expand the rows and columns and concatenate them
expand_row = row_emb.unsqueeze(-1).expand(-1, -1, w * w)  # (mid_channels / 2, h * h, w * w)
expand_col = col_emb.unsqueeze(-2).expand(-1, h * h, -1)  # (mid_channels / 2, h * h, w * w)
positional_embedding = torch.cat((expand_row, expand_col), dim=0)

# split the pairs: i / k = query / key row, j / l = query / key column
positional_embedding = rearrange(
    positional_embedding, "c (i k) (j l) -> c i k j l",
    c=mid_channels, i=h, k=h, j=w, l=w)

# dot-prod(q, r)
attention_scores = einsum(q, positional_embedding,
                          "b c h i j, c i k j l -> b h i j k l")
attention_scores = attention_scores / sqrt_normalizer

# Compute attention scores based on data: dot-prod(q, k)
attention_content_scores = einsum(q, k, "b c h i j, b c h k l -> b h i j k l")
attention_content_scores = attention_content_scores / sqrt_normalizer

# Combine attention scores
attention_scores = attention_scores + attention_content_scores

# Normalize over all key positions (k, l) to obtain probabilities.
shape = attention_scores.shape
att_probs = F.softmax(attention_scores.reshape(*shape[:-2], -1), dim=-1).reshape(shape)

# Re-weight values via attention
v_f = einsum(att_probs, v, "b h i j k l, b c h k l -> b c h i j")

# linear projection to the output dimension
v_f = rearrange(v_f, "b c h i j -> b (c h) i j")
out = wout(v_f)

print(out.shape)  # torch.Size([4, 8, 4, 4])

Putting it all together

Now we are at the point where we can put all the parts together.
To aid understanding, Figure 3 gives an overview of the data flow and the tensor shapes involved in self-attention.

Figure 3: Overview of the shapes throughout the self-attention. Inspired by this GitHub post. ©J. Hatzky

Let’s create a module that implements the complete stand-alone self-attention layer.

import torch
import torch.nn as nn
import torch.nn.functional as F

# for some einsum magic
from einops import rearrange, einsum

# use gpu if possible
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

class StandAloneSelfAttention(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels,
                 num_heads, max_pos_embedding):
        """
        Inputs:
            in_channels - Number of channels of the input feature map
            mid_channels - Per-head embedding dim of the input projection (must be even)
            out_channels - Output dim after projecting heads together
            num_heads - Number of heads to use in the Multi-Head Attention block
            max_pos_embedding - The max(height, width) of inputs that can be embedded
        """
        super().__init__()

        self.mid_channels = mid_channels
        self.num_heads = num_heads
        self.out_channels = out_channels

        # create embeddings. if we want to keep the 2D representation of the
        # input, we can do this by using 1x1 2D convolutions
        self.query = nn.Conv2d(in_channels, mid_channels * num_heads, kernel_size=1, device=device)
        self.key = nn.Conv2d(in_channels, mid_channels * num_heads, kernel_size=1, device=device)
        self.value = nn.Conv2d(in_channels, mid_channels * num_heads, kernel_size=1, device=device)
        self.wout = nn.Conv2d(mid_channels * num_heads, out_channels, kernel_size=1, device=device)

        # Define positional embeddings, one table for row offsets and one for column offsets
        self.row_embedding = nn.Embedding(2 * max_pos_embedding - 1, mid_channels // 2, device=device)
        self.col_embedding = nn.Embedding(2 * max_pos_embedding - 1, mid_channels // 2, device=device)

        # create relative indices: deltas[p, q] = q - p
        deltas = torch.arange(max_pos_embedding).view(1, -1) - torch.arange(max_pos_embedding).view(
            -1, 1
        )
        # -- shift the deltas into [0, 2 * max_pos_embedding - 2]; registered as a
        # buffer so that it follows the module when calling .to(device)
        self.register_buffer("relative_indices",
                             (deltas + max_pos_embedding - 1).to(device), persistent=False)

        self.verbose = False

    def forward(self, x):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        if self.verbose:
            print(f"x: {x.shape}, q: {q.shape}, k: {k.shape}, v:{v.shape}")

        _, _, h, w = x.shape
        # scale the logits by the square root of the per-head key dimension
        sqrt_normalizer = self.mid_channels ** 0.5

        # Compute attention scores based on position
        # the relative indices select the correct embedding vectors in a staircase
        # pattern: rows pair up (query row, key row), columns (query col, key col)
        row_emb = self.row_embedding(
            self.relative_indices[:h, :h].reshape(-1)
        ).transpose(0, 1)  # (mid_channels / 2, h * h)
        col_emb = self.col_embedding(
            self.relative_indices[:w, :w].reshape(-1)
        ).transpose(0, 1)  # (mid_channels / 2, w * w)

        # unfold heads
        q = rearrange(
            q, "b (c heads) h w -> b c heads h w",
            heads=self.num_heads, c=self.mid_channels)
        k = rearrange(
            k, "b (c heads) h w -> b c heads h w",
            heads=self.num_heads, c=self.mid_channels)
        v = rearrange(
            v, "b (c heads) h w -> b c heads h w",
            heads=self.num_heads, c=self.mid_channels)

        if self.verbose:
            print(f"q: {q.shape}, k: {k.shape}, v:{v.shape}")

        # now expand the rows and columns and concatenate them
        expand_row = row_emb.unsqueeze(-1).expand(-1, -1, w * w)  # (mid_channels / 2, h * h, w * w)
        expand_col = col_emb.unsqueeze(-2).expand(-1, h * h, -1)  # (mid_channels / 2, h * h, w * w)
        positional_embedding = torch.cat((expand_row, expand_col), dim=0)

        # split the pairs: i / k = query / key row, j / l = query / key column
        positional_embedding = rearrange(
            positional_embedding, "c (i k) (j l) -> c i k j l",
            c=self.mid_channels, i=h, k=h, j=w, l=w)

        if self.verbose:
            print(f"row_encoding: {row_emb.shape}, column_encoding: {col_emb.shape}, pos_embedding: {positional_embedding.shape}")

        # dot-prod(q, r)
        attention_scores = einsum(q, positional_embedding,
                                  "b c h i j, c i k j l -> b h i j k l")
        attention_scores = attention_scores / sqrt_normalizer

        # Compute attention scores based on data: dot-prod(q, k)
        attention_content_scores = einsum(q, k, "b c h i j, b c h k l -> b h i j k l")
        attention_content_scores = attention_content_scores / sqrt_normalizer

        # Combine attention scores
        attention_scores = attention_scores + attention_content_scores

        # Normalize over all key positions (k, l) to obtain probabilities.
        shape = attention_scores.shape
        flat_scores = attention_scores.reshape(*shape[:-2], -1)
        att_probs = F.softmax(flat_scores, dim=-1).reshape(shape)
        if self.verbose:
            print(f"attention_scores: {attention_scores.shape}, shaped scores: {flat_scores.shape} att_probs: {att_probs.shape}")

        # Re-weight values via attention and map to output dimension.
        v_f = einsum(att_probs, v, "b h i j k l, b c h k l -> b c h i j")
        v_f = rearrange(v_f, "b c h i j -> b (c h) i j")
        if self.verbose:
            print(f"(qr + qk)V: {v_f.shape}")
        out = self.wout(v_f)

        return out
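As written, the class attends from every pixel to every other pixel: its softmax runs over all h x w key positions rather than over a k x k memory block like the unfold-based one shown earlier. If you want to emulate the local memory blocks of [6] inside this class, one option is to mask out keys outside the block before the softmax. This is a sketch under our own assumptions, not the paper's implementation, and the helper local_attention_mask is a hypothetical name of ours.

```python
import torch

def local_attention_mask(h, w, kernel_size):
    """Boolean mask of shape (h, w, h, w): True where key pixel (k, l) lies in the
    kernel_size x kernel_size memory block centered on query pixel (i, j)."""
    radius = kernel_size // 2
    rows, cols = torch.arange(h), torch.arange(w)
    row_ok = (rows.view(-1, 1) - rows.view(1, -1)).abs() <= radius  # (h, h): |i - k| <= radius
    col_ok = (cols.view(-1, 1) - cols.view(1, -1)).abs() <= radius  # (w, w): |j - l| <= radius
    # broadcast into mask[i, j, k, l] = row_ok[i, k] & col_ok[j, l]
    return row_ok.view(h, 1, h, 1) & col_ok.view(1, w, 1, w)

mask = local_attention_mask(4, 5, kernel_size=3)
print(mask.shape)               # torch.Size([4, 5, 4, 5])
print(mask[0, 0].sum().item())  # 4: a corner pixel only sees a 2x2 part of its block
print(mask[1, 1].sum().item())  # 9: an interior pixel sees the full 3x3 block

# Applied to logits shaped (b, heads, h, w, h, w) right before the softmax:
scores = torch.randn(2, 2, 4, 5, 4, 5)
scores = scores.masked_fill(~mask, float("-inf"))
probs = scores.flatten(-2).softmax(dim=-1).reshape(scores.shape)
print(torch.all(probs.masked_select(~mask) == 0).item())  # True
```

Masking keeps the code simple but still materializes the full (h, w, h, w) score tensor, so unlike unfolding the image into memory blocks it does not reduce memory or compute.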

Epilogue

The Stand-Alone Self-Attention in Vision Models [6] paper presents a fascinating idea: building vision models from self-attention alone. Despite the high computational cost of the self-attention operation, the paper demonstrates an effective approach that uses local receptive fields, also known as memory blocks, to reduce the required resources. While the more recently published vision transformers may be stealing the limelight, with further software and hardware improvements this method has great potential to become a top contender for state-of-the-art vision architectures. It’s an exciting piece of work that could take vision models to the next level!
