In this post we will examine the Visual Attention Network (VAN) for image classification and walk through its implementation in TensorFlow.
The paper introducing VAN proposed the following contributions:
- A novel Large Kernel Attention (LKA) module, an attention mechanism that takes advantage of the 2D structure of images and achieves channel adaptability in addition to spatial adaptability.
- A novel neural network based on LKA, called Visual Attention Network (VAN), that outperforms vision transformers and convolutional neural networks in many computer vision tasks.
The following charts from the paper highlight the results of different models on the ImageNet-1K validation set. We can see that VAN performs better while keeping the computation cost comparable to other models.
Implementation
Let's look at the different components of the VAN architecture and understand how to implement them in TensorFlow. We will be referencing the code from the original PyTorch implementation in this repository: VAN-Classification.
You can find another TensorFlow implementation in the tfvan repository.
import math
import numpy as np
import tensorflow as tf
from tensorflow.keras import initializers, layers, Model
Multilayer Perceptrons
VAN relies on a variation of the Multilayer Perceptron (MLP) layer that decouples the standard MLP into a spatial MLP and a channel MLP to reduce its computational cost. Furthermore, it uses an attention mechanism similar to gMLP, but without the sensitivity to input size or the constraint of processing fixed-size images.
Below is the implementation of the MLP layer as proposed in the VAN architecture:
class MLP(layers.Layer):
    def __init__(self, in_features, hidden_features=None, out_features=None, activation="gelu", drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = layers.Conv2D(hidden_features, 1)
        self.dwconv = layers.DepthwiseConv2D(
            kernel_size=3, strides=1, padding='same',
            use_bias=True, activation=activation
            )
        self.fc2 = layers.Conv2D(out_features, 1)
        self.drop = layers.Dropout(drop)
    def call(self, x):
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
mlp = MLP(768)
y = mlp(tf.zeros((1, 224, 224, 3)))
y.shape
Large Kernel Attention
Attention mechanisms capture the relationships between input features and produce an attention map indicating the importance of the different features. To learn these relationships, we could use:
- A self-attention mechanism to capture long-range dependence. Self-attention was successful in NLP tasks, but in computer vision it has the following drawbacks:
  - It treats images as 1D sequences, which neglects the 2D structure of images.
  - The quadratic complexity is too expensive for high-resolution images.
  - It only achieves spatial adaptability and ignores adaptability in the channel dimension.
- A large kernel convolution to build relevance and produce an attention map. This approach has its own limitations: a large kernel convolution brings a huge amount of computational overhead and additional parameters to learn.
The authors proposed a new attention mechanism that combines the pros of the previous approaches while overcoming their drawbacks. This is achieved by decomposing a large kernel convolution as depicted in the picture above.
For instance, given a $K \times K$ convolution and a dilation rate $d$, we decompose it into:
- a spatial local convolution (DW-Conv, or depth-wise convolution) of $(2d - 1) \times (2d - 1)$,
- a spatial long-range convolution (DW-D-Conv, or depth-wise dilation convolution) of $\lceil \frac{K}{d} \rceil \times \lceil \frac{K}{d} \rceil$ with dilation $d$, and
- a channel convolution ($1 \times 1$ convolution).
This decomposition can be written as the following formula:
$$ \text{Attention} = \text{Conv}_{1 \times 1} (\text{DW-D-Conv} (\text{DW-Conv} (F))) $$
$$ \text{Output} = \text{Attention} \otimes F $$
where $F \in \mathbb{R}^{C \times H \times W}$ is the input, $\text{Attention} \in \mathbb{R}^{C \times H \times W}$ is the attention map, and $\otimes$ stands for the element-wise product.
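As a quick sanity check of the decomposition, the sketch below computes the two depth-wise kernel sizes and compares weight counts (biases ignored) with a single dense $K \times K$ convolution. The helper names are hypothetical, and the values $K = 21$, $d = 3$ and $C = 32$ are assumptions chosen for illustration; $d = 3$ matches the dilation rate used by the LKA layer in this post.

```python
import math


def lka_kernel_sizes(K, d):
    """Kernel sizes of the two depth-wise convolutions replacing a K x K kernel."""
    dw_kernel = 2 * d - 1            # local depth-wise convolution
    dwd_kernel = math.ceil(K / d)    # dilated depth-wise convolution (dilation d)
    return dw_kernel, dwd_kernel


def lka_params(K, d, C):
    """Weights of DW-Conv + DW-D-Conv + 1x1 convolution for C channels (no biases)."""
    dw_kernel, dwd_kernel = lka_kernel_sizes(K, d)
    return dw_kernel ** 2 * C + dwd_kernel ** 2 * C + C * C


def dense_conv_params(K, C):
    """Weights of one standard K x K convolution mapping C channels to C channels."""
    return K * K * C * C


print(lka_kernel_sizes(21, 3))    # (5, 7)
print(lka_params(21, 3, 32))      # 3392
print(dense_conv_params(21, 32))  # 451584
```

Under these assumptions the decomposed version needs roughly two orders of magnitude fewer weights than the dense convolution, which is the motivation for the decomposition.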
The architecture of a Large Kernel Attention layer would look like this:
In TensorFlow, the Large Kernel Attention layer can be implemented as follows:
class LKA(layers.Layer):
    def __init__(self, dim):
        super().__init__()
        self.conv0 = layers.Conv2D(dim, kernel_size=5, padding="same", groups=dim)
        self.conv_spatial = layers.Conv2D(dim, kernel_size=7, strides=1, padding="same", groups=dim, dilation_rate=3)
        self.conv1 = layers.Conv2D(dim, kernel_size=1)
    def call(self, x):        
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)
        attn = self.conv1(attn)
        return x * attn
attn = LKA(4)
y = attn(tf.zeros((1, 224, 224, 4)))
y.shape
class SpatialAttention(layers.Layer):
    def __init__(self, d_model, activation="gelu"):
        super().__init__()
        self.proj_1 = layers.Conv2D(d_model, kernel_size=1, activation=activation)
        self.spatial_gating_unit = LKA(d_model)
        self.proj_2 = layers.Conv2D(d_model, kernel_size=1)
    def call(self, x):
        attn = self.proj_1(x)
        attn = self.spatial_gating_unit(attn)
        attn = self.proj_2(attn)
        attn = x + attn
        return attn
attn = SpatialAttention(4)
y = attn(tf.zeros((1, 224, 224, 4)))
y.shape
DropPath is used in the VAN model as an alternative to the Dropout layer. It was originally proposed in the FractalNet paper. Below is the implementation of DropPath in TensorFlow. Instead of a custom layer, we could alternatively use the StochasticDepth layer.
class DropPath(layers.Layer):
    def __init__(self, rate, **kwargs):
        super().__init__(**kwargs)
        self.rate = rate
    def call(self, inputs, training=None, **kwargs):
        if 0. == self.rate:
            return inputs
        if training is None:
            training = tf.keras.backend.learning_phase()
        # tf.cast accepts Python bools, ints and tensors alike
        training = tf.cast(training, tf.bool)
        outputs = tf.cond(training, lambda: self.drop(inputs), lambda: tf.identity(inputs))
        return outputs
    def drop(self, inputs):
        keep = 1.0 - self.rate
        batch = tf.shape(inputs)[0]
        shape = [batch] + [1] * (inputs.shape.rank - 1)
        random = tf.random.uniform(shape, dtype=self.compute_dtype) <= keep
        random = tf.cast(random, self.compute_dtype) / keep
        outputs = inputs * random
        return outputs
attn = DropPath(0.1)
y = attn(tf.zeros((1, 224, 224, 3)))
y.shape
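To make the behaviour of DropPath easier to see, here is a minimal pure-Python sketch of the same idea, assuming a batch represented as a list of flat samples. The function name `drop_path` and the list-based representation are illustrative assumptions, not part of the TensorFlow layer above.

```python
import random


def drop_path(batch, rate, rng):
    """Zero whole samples with probability `rate`; rescale kept ones by 1 / keep."""
    if rate == 0.0:
        return batch
    keep = 1.0 - rate
    out = []
    for sample in batch:
        # one random draw per sample, shared by all of its values
        scale = 1.0 / keep if rng.random() <= keep else 0.0
        out.append([value * scale for value in sample])
    return out


rng = random.Random(0)
batch = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
dropped = drop_path(batch, rate=0.5, rng=rng)
print(dropped)
```

Each sample is either dropped entirely or scaled by `1 / keep`, so the expected value of the output matches the input. This is what distinguishes DropPath from Dropout, which masks individual activations.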
VAN stage
Each VAN stage downsamples the input tensor and then passes it through a sequence of blocks, each combining an attention layer and an MLP layer, as illustrated in the following diagram. In this section, we will examine each component of a stage and implement it in TensorFlow.
class OverlapPatchEmbed(layers.Layer):
    def __init__(self, img_size=224, patch_size=7, patch_stride=4, in_chans=3, embed_dim=768, dilation_rate=1, **kwargs):
        super().__init__(**kwargs)
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        # same pad layer
        dilation_rate = (dilation_rate, dilation_rate)
        total_pad = (patch_size[0] - 1) * dilation_rate[0], \
                    (patch_size[1] - 1) * dilation_rate[1]
        top_pad = total_pad[0] // 2
        bottom_pad = total_pad[0] - top_pad
        left_pad = total_pad[1] // 2
        right_pad = total_pad[1] - left_pad
        # noinspection PyAttributeOutsideInit
        self.pad = layers.ZeroPadding2D(((top_pad, bottom_pad), (left_pad, right_pad)))
        # embedding
        self.proj = layers.Conv2D(embed_dim, kernel_size=patch_size, strides=patch_stride)
        self.norm = layers.BatchNormalization()
    def call(self, x):
        x = self.pad(x)
        x = self.proj(x)
        x = self.norm(x)        
        return x
embed = OverlapPatchEmbed()
y = embed(tf.zeros((1, 224, 224, 3)))
y.shape
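A VAN block applies the attention layer and the MLP layer in sequence, each preceded by batch normalization and wrapped in a residual connection with a learnable layer scale and DropPath. In TensorFlow, we can implement this layer as follows: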
class Block(layers.Layer):
    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., activation="gelu", **kwargs):
        super().__init__(**kwargs)
        self.norm1 = layers.BatchNormalization()
        self.attn = SpatialAttention(dim)
        self.drop_path = DropPath(drop_path)
        self.norm2 = layers.BatchNormalization()
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, activation=activation, drop=drop)
        layer_scale_init_value = 1e-2
        self.layer_scale_1 = tf.Variable(
            initial_value=layer_scale_init_value * tf.ones((1, 1, 1, dim)), trainable=True)  
        self.layer_scale_2 = tf.Variable(
            layer_scale_init_value * tf.ones((1, 1, 1, dim)), trainable=True)
    def call(self, x):
        x = x + self.drop_path(self.layer_scale_1 * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.layer_scale_2 * self.mlp(self.norm2(x)))
        return x
block = Block(3)
y = block(tf.zeros((1, 224, 224, 3)))
y.shape
The following layer implements a VAN stage by using the OverlapPatchEmbed layer followed by a sequence of Block layers:
class Stage(layers.Layer):
    def __init__(self, i, embed_dims, mlp_ratios, depths, path_drops, drop_rate=0., **kwargs):
        super().__init__(**kwargs)
        # downsample
        self.patch_embed = OverlapPatchEmbed(
            patch_size = 7 if 0 == i else 3,
            patch_stride = 4 if 0 == i else 2, 
            embed_dim = embed_dims[i], 
            name = f'patch_embed{i + 1}'
            )
        # blocks
        self.blocks = [Block(
            dim=embed_dims[i],
            mlp_ratio=mlp_ratios[i],
            drop=drop_rate,
            drop_path=path_drops[sum(depths[:i]) + j],
            name=f'block{i + 1}.{j}'
            ) for j in range(depths[i])]
        # normalization
        self.norm = layers.LayerNormalization(name=f'norm{i + 1}')
    def call(self, x):
        x = self.patch_embed(x)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        return x
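The downsampling performed by each stage can be checked by hand. The sketch below reproduces the padding and strided-convolution arithmetic of OverlapPatchEmbed with the patch sizes and strides that Stage passes to it. The helper names are hypothetical, and four stages with a 224-pixel input are assumed.

```python
def patch_embed_output_size(size, patch_size, stride):
    """Spatial size after the zero padding and strided convolution of OverlapPatchEmbed."""
    total_pad = patch_size - 1  # padding added by the ZeroPadding2D layer (dilation 1)
    return (size + total_pad - patch_size) // stride + 1


def stage_output_sizes(size, num_stages=4):
    """Spatial size after each stage, using the Stage defaults (7/4 first, then 3/2)."""
    sizes = []
    for i in range(num_stages):
        patch_size = 7 if i == 0 else 3
        stride = 4 if i == 0 else 2
        size = patch_embed_output_size(size, patch_size, stride)
        sizes.append(size)
    return sizes


print(stage_output_sizes(224))  # [56, 28, 14, 7]
```

The first stage reduces the resolution by a factor of 4 and every later stage by a factor of 2, so a 224 × 224 image ends up as a 7 × 7 feature map before pooling.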
Putting it together
After defining all the major components of the VAN model, we can put them together to build the model. This is fairly straightforward now, as we just need to pass the input image through a sequence of Stage layers. Then we add a head consisting of a pooling layer followed by a dense layer that outputs the result (e.g. class probabilities if the task is classification).
def create_VAN(embed_dims, mlp_ratios, depths, drop_rate=0., path_drop=0.1, input_shape=(224, 224, 3), pooling=None, classes=2):
    # stochastic depth decay rule
    path_drops = np.linspace(0., path_drop, sum(depths))
    # input image
    inputs = layers.Input(shape=input_shape, name='image')
    x = inputs
    # create stages
    for i in range(len(depths)):
        stage = Stage(i, embed_dims, mlp_ratios, depths, path_drops, drop_rate, name = f'stage_{i}')
        x = stage(x)
    # pooling layer
    if pooling in {None, 'avg'}:
        x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    elif pooling == 'max':
        x = layers.GlobalMaxPooling2D(name='max_pool')(x)
    else:
        raise ValueError(f'Expecting pooling to be one of None/avg/max. Found: {pooling}')
    # head classifier
    x = layers.Dense(classes, name='head')(x)
    outputs = layers.Activation('softmax', dtype='float32', name='pred')(x)
    # Create model.
    model = Model(inputs, outputs, name='van')
    return model
The smallest model in terms of parameters is TinyVAN, which we create as follows:
model = create_VAN(embed_dims=(32, 64, 160, 256), mlp_ratios=(8, 8, 4, 4), depths=(3, 3, 5, 2))
tf.keras.utils.plot_model(model, rankdir='LR', show_shapes=True)
model.summary()
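The stochastic depth decay rule in create_VAN spreads the DropPath rate linearly over all blocks, so deeper blocks are dropped more often. The sketch below shows how the rates are split across stages for the depths used above. It uses a pure-Python stand-in for `np.linspace`, and the helper names are hypothetical.

```python
def linspace(start, stop, num):
    """Pure-Python stand-in for np.linspace(start, stop, num)."""
    if num == 1:
        return [start]
    step = (stop - start) / (num - 1)
    return [start + step * k for k in range(num)]


def path_drops_per_stage(depths, path_drop):
    """Group the linearly increasing drop rates by stage, as Stage indexes them."""
    rates = linspace(0.0, path_drop, sum(depths))
    per_stage = []
    for i, depth in enumerate(depths):
        offset = sum(depths[:i])  # number of blocks in earlier stages
        per_stage.append(rates[offset:offset + depth])
    return per_stage


for i, rates in enumerate(path_drops_per_stage((3, 3, 5, 2), 0.1)):
    print(f"stage {i}:", [round(r, 4) for r in rates])
```

The first block always gets a rate of 0 and the last block gets the full `path_drop` value.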
Training
The authors trained the VAN model for various vision tasks (e.g. classification or object detection). They trained the model for 310 epochs using the AdamW optimizer with momentum = 0.9, weight decay = $5 \times 10^{-2}$ and batch size = 1,024.
For the learning rate (LR), cosine scheduling and a warm-up strategy were used. The initial LR is set to $5 \times 10^{-4}$.
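A cosine schedule with warm-up can be sketched as follows. The base LR and the number of epochs come from the settings above, while the linear warm-up shape, `warmup_epochs=5` and `min_lr=0.0` are assumptions made for illustration and are not taken from the paper.

```python
import math


def learning_rate(epoch, base_lr=5e-4, total_epochs=310, warmup_epochs=5, min_lr=0.0):
    """Linear warm-up to base_lr, then cosine decay towards min_lr."""
    if epoch < warmup_epochs:
        # assumed linear warm-up: reaches base_lr at the last warm-up epoch
        return base_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


for epoch in (0, 4, 5, 150, 309):
    print(epoch, learning_rate(epoch))
```

In Keras, a function like this could be plugged into a `tf.keras.callbacks.LearningRateScheduler` callback, which calls it with the epoch index at the start of every epoch.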
They used the following data augmentation techniques:
- random clipping
- random horizontal flipping
- label-smoothing
- mixup
- cutmix
- random erasing.
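Two of these techniques, label smoothing and mixup, operate on the targets and are easy to illustrate. The sketch below is a minimal pure-Python version with hypothetical helper names. The smoothing factor of 0.1 and the Beta parameter of 0.8 are assumptions chosen for illustration, not values stated in this post.

```python
import random


def smooth_labels(one_hot, smoothing=0.1):
    """Move a fraction `smoothing` of the probability mass to a uniform distribution."""
    n = len(one_hot)
    return [p * (1.0 - smoothing) + smoothing / n for p in one_hot]


def mixup(x1, y1, x2, y2, lam):
    """Blend two samples and their targets with the same mixing coefficient."""
    x = [lam * a + (1.0 - lam) * b for a, b in zip(x1, x2)]
    y = [lam * a + (1.0 - lam) * b for a, b in zip(y1, y2)]
    return x, y


rng = random.Random(0)
lam = rng.betavariate(0.8, 0.8)  # mixing coefficient drawn from Beta(alpha, alpha)
x, y = mixup([0.0, 1.0], smooth_labels([1.0, 0.0]),
             [1.0, 0.0], smooth_labels([0.0, 1.0]), lam)
print(lam, x, y)
```

The mixed target is still a valid probability distribution, so it can be used directly with a categorical cross-entropy loss.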
That's all folks
I hope you enjoyed this article. Feel free to leave a comment or reach out on Twitter @bachiirc.