In this post we will examine the Visual Attention Network (VAN) for image classification and walk through its implementation in TensorFlow.

The paper introducing VAN proposed the following contributions:

  • A novel Large Kernel Attention (LKA) module, an attention mechanism that takes advantage of the 2D structure of images by capturing channel adaptability in addition to spatial adaptability.
  • A novel neural network based on LKA, called Visual Attention Network (VAN), that outperforms vision transformers and convolutional neural networks on many computer vision tasks.

The following chart from the paper highlights the results of different models on the ImageNet-1K validation set. We can see that VAN performs better while keeping the computational cost comparable to other models.

image.png

Implementation

Let's look at the different components of the VAN architecture and understand how to implement it in TensorFlow. We will be referencing the code from the original PyTorch implementation in this repository - VAN-Classification.

You can find another TensorFlow implementation in the following repository tfvan.

import math
import numpy as np
import tensorflow as tf
from tensorflow.keras import initializers, layers, Model

Multilayer Perceptrons

VAN relies on a variation of the Multilayer Perceptron (MLP) layer that decouples the standard MLP into a spatial MLP and a channel MLP to reduce its computational cost. Furthermore, it uses an attention mechanism similar to gMLP, but without the sensitivity to input size or the constraint of processing fixed-size images.

Below is the implementation of the MLP layer as proposed in the VAN architecture:

class MLP(layers.Layer):
    def __init__(self, in_features, hidden_features=None, out_features=None, activation="gelu", drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # channel MLP: 1x1 convolution expanding to the hidden dimension
        self.fc1 = layers.Conv2D(hidden_features, 1)
        # spatial MLP: 3x3 depth-wise convolution, with the activation applied to its output
        self.dwconv = layers.DepthwiseConv2D(
            kernel_size=3, strides=1, padding='same',
            use_bias=True, activation=activation
            )
        # channel MLP: 1x1 convolution projecting back to the output dimension
        self.fc2 = layers.Conv2D(out_features, 1)
        self.drop = layers.Dropout(drop)

    def call(self, x):
        x = self.fc1(x)
        x = self.dwconv(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
mlp = MLP(768)
y = mlp(tf.zeros((1, 224, 224, 3)))
y.shape
TensorShape([1, 224, 224, 768])

Large Kernel Attention

Attention mechanisms are used to capture the relationships between input features and to produce an attention map indicating the importance of the different features. To learn these relationships, we could use:

  • A self-attention mechanism to capture long-range dependencies. Self-attention was very successful in NLP tasks, but in computer vision it has the following drawbacks:
    • It treats images as 1D sequences, which neglects the 2D structure of images.
    • Its quadratic complexity is too expensive for high-resolution images.
    • It only achieves spatial adaptability but ignores the adaptability in the channel dimension.
  • A large kernel convolution to build relevance and produce an attention map. But this approach has its own limitation: a large-kernel convolution comes with a huge amount of computational overhead and additional parameters to learn.

image.png

The authors proposed a new attention mechanism that combines the pros of the previous approaches while overcoming their drawbacks. This is achieved by decomposing a large kernel convolution as depicted in the above picture.

For instance, a $K × K$ convolution with dilation rate $d$ can be decomposed into:

  • a spatial local convolution (DW-Conv, or depth-wise convolution) of $(2d − 1) × (2d − 1)$,
  • a spatial long-range convolution (DW-D-Conv, or depth-wise dilation convolution) of $⌈\frac{K}{d}⌉×⌈\frac{K}{d}⌉$ with dilation $d$, and
  • a channel convolution ($1×1$ convolution).

This decomposition can be written as the following formula:

$$ Attention = Conv_{1 \times 1} (\text{DW-D-Conv} (\text{DW-Conv} (F))) $$

$$ Output = Attention \otimes F $$

where $F \in \mathbb{R}^{C \times H \times W}$ is the input feature, $Attention \in \mathbb{R}^{C \times H \times W}$ is the attention map, and $\otimes$ denotes the element-wise product.
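
To get a feel for the savings, here is a quick, illustrative check (not a figure from the paper) that counts the parameters of a dense $21 \times 21$ convolution against its decomposition with $d = 3$ for 64 channels, using the same kind of Keras layers we will use below:

dim = 64
inp = layers.Input((32, 32, dim))  # spatial size is arbitrary, it does not affect parameter counts

# dense 21x21 convolution
dense = Model(inp, layers.Conv2D(dim, 21, padding="same")(inp))

# decomposition: 5x5 DW-Conv -> 7x7 DW-D-Conv (dilation 3) -> 1x1 conv
x = layers.Conv2D(dim, 5, padding="same", groups=dim)(inp)
x = layers.Conv2D(dim, 7, padding="same", groups=dim, dilation_rate=3)(x)
x = layers.Conv2D(dim, 1)(x)
decomposed = Model(inp, x)

print(dense.count_params(), decomposed.count_params())  # roughly 1.8M vs. about 9K parameters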

The architecture of a Large Kernel Attention layer would look like this:

image.png

In TensorFlow, the Large Kernel Attention layer can be implemented as follows:

class LKA(layers.Layer):
    def __init__(self, dim):
        super().__init__()
        # 5x5 depth-wise convolution (DW-Conv): spatial local information
        self.conv0 = layers.Conv2D(dim, kernel_size=5, padding="same", groups=dim)
        # 7x7 depth-wise convolution with dilation 3 (DW-D-Conv): spatial long-range information
        self.conv_spatial = layers.Conv2D(dim, kernel_size=7, strides=1, padding="same", groups=dim, dilation_rate=3)
        # 1x1 convolution: channel mixing
        self.conv1 = layers.Conv2D(dim, kernel_size=1)

    def call(self, x):
        attn = self.conv0(x)
        attn = self.conv_spatial(attn)
        attn = self.conv1(attn)

        # element-wise product of the attention map with the input features
        return x * attn
attn = LKA(4)
y = attn(tf.zeros((1, 224, 224, 4)))
y.shape
TensorShape([1, 224, 224, 4])
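
The LKA layer is then wrapped in a SpatialAttention layer, which surrounds it with two 1×1 projections (the first followed by a GELU activation) and adds a residual connection to the input:
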
class SpatialAttention(layers.Layer):
    def __init__(self, d_model, activation="gelu"):
        super().__init__()

        # 1x1 projection followed by the activation
        self.proj_1 = layers.Conv2D(d_model, kernel_size=1, activation=activation)
        # Large Kernel Attention as the spatial gating unit
        self.spatial_gating_unit = LKA(d_model)
        # 1x1 output projection
        self.proj_2 = layers.Conv2D(d_model, kernel_size=1)

    def call(self, x):
        attn = self.proj_1(x)
        attn = self.spatial_gating_unit(attn)
        attn = self.proj_2(attn)
        # residual connection
        attn = x + attn
        return attn
attn = SpatialAttention(4)
y = attn(tf.zeros((1, 224, 224, 4)))
y.shape
TensorShape([1, 224, 224, 4])

DropPath layer

DropPath is used in the VAN model as an alternative to the Dropout layer; it was originally proposed in the FractalNet paper. Below is an implementation of DropPath in TensorFlow. Instead of writing a custom layer, we could alternatively simply use this StochasticDepth layer.

class DropPath(layers.Layer):
    def __init__(self, rate, **kwargs):
        super().__init__(**kwargs)
        self.rate = rate

    def call(self, inputs, training=None, **kwargs):
        if 0. == self.rate:
            return inputs

        if training is None:
            training = tf.keras.backend.learning_phase()
        training = tf.constant(training, dtype=tf.bool)

        # drop paths only during training, pass inputs through unchanged at inference
        outputs = tf.cond(training, lambda: self.drop(inputs), lambda: tf.identity(inputs))

        return outputs

    def drop(self, inputs):
        keep = 1.0 - self.rate
        batch = tf.shape(inputs)[0]
        # one keep/drop decision per sample in the batch
        shape = [batch] + [1] * (inputs.shape.rank - 1)

        random = tf.random.uniform(shape, dtype=self.compute_dtype) <= keep
        # scale kept samples by 1/keep so the expected value is unchanged
        random = tf.cast(random, self.compute_dtype) / keep

        outputs = inputs * random

        return outputs
attn = DropPath(0.1)
y = attn(tf.zeros((1, 224, 224, 3)))
y.shape
TensorShape([1, 224, 224, 3])

VAN stage

VAN is built from multiple stages. Each stage downsamples the input tensor and passes it through a series of blocks that resemble Transformer encoder blocks, but with the self-attention replaced by the LKA-based spatial attention, as illustrated in the following diagram. In this section, we will examine each component of a stage and implement it in TensorFlow.

image.png

Downsampling layer

The patch embedding layer is used to downsample a tensor with a strided convolution followed by batch normalization:

class OverlapPatchEmbed(layers.Layer):

    def __init__(self, img_size=224, patch_size=7, patch_stride=4, in_chans=3, embed_dim=768, dilation_rate=1, **kwargs):
        super().__init__(**kwargs)
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        self.H, self.W = img_size[0] // patch_stride, img_size[1] // patch_stride
        self.num_patches = self.H * self.W

        # explicit 'same' padding so the overlapping patches cover the whole image
        dilation_rate = (dilation_rate, dilation_rate)
        total_pad = (patch_size[0] - 1) * dilation_rate[0], \
                    (patch_size[1] - 1) * dilation_rate[1]

        top_pad = total_pad[0] // 2
        bottom_pad = total_pad[0] - top_pad
        left_pad = total_pad[1] // 2
        right_pad = total_pad[1] - left_pad

        self.pad = layers.ZeroPadding2D(((top_pad, bottom_pad), (left_pad, right_pad)))

        # embedding: strided convolution followed by batch normalization
        self.proj = layers.Conv2D(embed_dim, kernel_size=patch_size, strides=patch_stride)
        self.norm = layers.BatchNormalization()

    def call(self, x):
        x = self.pad(x)
        x = self.proj(x)
        x = self.norm(x)
        return x
embed = OverlapPatchEmbed()
y = embed(tf.zeros((1, 224, 224, 3)))
y.shape
TensorShape([1, 56, 56, 768])

Block layer

The Block layer is repeated multiple times in a stage. It is composed of the following building blocks:

image.png

In TensorFlow, we can implement this layer as follows:

class Block(layers.Layer):
    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., activation="gelu", **kwargs):
        super().__init__(**kwargs)
        # attention branch: batch normalization + LKA-based spatial attention
        self.norm1 = layers.BatchNormalization()
        self.attn = SpatialAttention(dim)
        self.drop_path = DropPath(drop_path)

        # MLP branch: batch normalization + MLP
        self.norm2 = layers.BatchNormalization()
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=mlp_hidden_dim, activation=activation, drop=drop)

        # learnable per-channel scaling of each branch (layer scale)
        layer_scale_init_value = 1e-2
        self.layer_scale_1 = tf.Variable(
            initial_value=layer_scale_init_value * tf.ones((1, 1, 1, dim)), trainable=True)
        self.layer_scale_2 = tf.Variable(
            layer_scale_init_value * tf.ones((1, 1, 1, dim)), trainable=True)

    def call(self, x):
        x = x + self.drop_path(self.layer_scale_1 * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.layer_scale_2 * self.mlp(self.norm2(x)))
        return x
block = Block(3)
y = block(tf.zeros((1, 224, 224, 3)))
y.shape
TensorShape([1, 224, 224, 3])

Stage layer

The following layer implements a VAN stage using the OverlapPatchEmbed layer and a sequence of Block layers:

class Stage(layers.Layer):

    def __init__(self, i, embed_dims, mlp_ratios, depths, path_drops, drop_rate=0., **kwargs):
        super().__init__(**kwargs)
        # downsample
        self.patch_embed = OverlapPatchEmbed(
            patch_size = 7 if 0 == i else 3,
            patch_stride = 4 if 0 == i else 2, 
            embed_dim = embed_dims[i], 
            name = f'patch_embed{i + 1}'
            )
        # blocks
        self.blocks = [Block(
            dim=embed_dims[i],
            mlp_ratio=mlp_ratios[i],
            drop=drop_rate,
            drop_path=path_drops[sum(depths[:i]) + j],
            name=f'block{i + 1}.{j}'
            ) for j in range(depths[i])]
        # normalization
        self.norm = layers.LayerNormalization(name=f'norm{i + 1}')

    def call(self, x):
        x = self.patch_embed(x)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        return x

Putting it together

After defining all the major components of the VAN model, we can put them together to build the full model. This is fairly straightforward now, as we just need to pass the input image through a sequence of Stage layers, and then add a head consisting of a pooling layer followed by a dense layer that outputs the predictions (e.g. class probabilities if the task is classification).

def create_VAN(embed_dims, mlp_ratios, depths, drop_rate=0., path_drop=0.1, input_shape=(224, 224, 3), pooling=None, classes=2):
    # stochastic depth decay rule
    path_drops = np.linspace(0., path_drop, sum(depths))
    # input image
    inputs = layers.Input(shape=input_shape, name='image')
    x = inputs
    # create stages
    for i in range(len(depths)):
        stage = Stage(i, embed_dims, mlp_ratios, depths, path_drops, drop_rate, name = f'stage_{i}')
        x = stage(x)

    # pooling layer
    if pooling in {None, 'avg'}:
        x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
    elif pooling == 'max':
        x = layers.GlobalMaxPooling2D(name='max_pool')(x)
    else:
        raise ValueError(f'Expecting pooling to be one of None/avg/max. Found: {pooling}')

    # head classifier
    x = layers.Dense(classes, name='head')(x)
    outputs = layers.Activation('softmax', dtype='float32', name='pred')(x)

    # Create model.
    model = Model(inputs, outputs, name='van')

    return model

The smallest model in terms of parameters is VAN-Tiny, which we create as follows:

model = create_VAN(embed_dims=(32, 64, 160, 256), mlp_ratios=(8, 8, 4, 4), depths=(3, 3, 5, 2))
tf.keras.utils.plot_model(model, rankdir='LR', show_shapes=True)
model.summary()
Model: "van"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 image (InputLayer)          [(None, 224, 224, 3)]     0         
                                                                 
 stage_0 (Stage)             (None, 56, 56, 32)        80384     
                                                                 
 stage_1 (Stage)             (None, 28, 28, 64)        286528    
                                                                 
 stage_2 (Stage)             (None, 14, 14, 160)       1608480   
                                                                 
 stage_3 (Stage)             (None, 7, 7, 256)         1880832   
                                                                 
 avg_pool (GlobalAveragePool  (None, 256)              0         
 ing2D)                                                          
                                                                 
 head (Dense)                (None, 2)                 514       
                                                                 
 pred (Activation)           (None, 2)                 0         
                                                                 
=================================================================
Total params: 3,856,738
Trainable params: 3,849,314
Non-trainable params: 7,424
_________________________________________________________________

Training

The authors trained the VAN model for various vision tasks (e.g. classification or object detection). They trained the model for 310 epochs using the AdamW optimizer with momentum = 0.9, weight decay = $5 \times 10^{-2}$, and batch size = 1,024.

For the learning rate (LR), cosine scheduling and a warm-up strategy were used. The initial LR is set to $5 \times 10^{-4}$.
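
As a rough illustration (not the authors' training script), an optimizer with such a schedule could be set up in TensorFlow along these lines; the warm-up length and the step counts below are placeholder assumptions:

class WarmupCosine(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warm-up followed by cosine decay."""
    def __init__(self, base_lr, warmup_steps, total_steps):
        super().__init__()
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup_lr = self.base_lr * step / self.warmup_steps
        progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        cosine_lr = 0.5 * self.base_lr * (1.0 + tf.cos(np.pi * progress))
        return tf.where(step < self.warmup_steps, warmup_lr, cosine_lr)

steps_per_epoch = 1250  # placeholder, depends on the dataset size and batch size
schedule = WarmupCosine(base_lr=5e-4, warmup_steps=5 * steps_per_epoch, total_steps=310 * steps_per_epoch)
# AdamW is available in tf.keras from TF 2.11; on older versions use tensorflow_addons instead
optimizer = tf.keras.optimizers.AdamW(learning_rate=schedule, weight_decay=5e-2, beta_1=0.9)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])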

They used the following data augmentation techniques:

  • random cropping
  • random horizontal flipping
  • label-smoothing
  • mixup
  • cutmix
  • random erasing.
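
To make the first two concrete, here is a minimal sketch of random cropping and horizontal flipping with plain tf.image ops (the resize and crop sizes are assumptions; mixup, cutmix, label smoothing and random erasing need additional code or libraries):

def augment(image, label):
    # resize, then take a random 224x224 crop
    image = tf.image.resize(image, (256, 256))
    image = tf.image.random_crop(image, (224, 224, 3))
    # random horizontal flip
    image = tf.image.random_flip_left_right(image)
    return image, label

# train_ds = train_ds.map(augment).batch(1024)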

That's all folks

I hope you enjoyed this article. Feel free to leave a comment or reach out on Twitter @bachiirc.