In a previous article, we saw how to use TensorFlow's Object Detection API to run object detection on images using pre-trained models freely available for download from TF Hub - link. In this article, we will go one step further by training a model on our own custom object detection dataset using TensorFlow's Object Detection API.
First, let's install the TensorFlow Object Detection API:
%%capture
%%bash
git clone --depth 1 https://github.com/tensorflow/models
cd models/research/
protoc object_detection/protos/*.proto --python_out=.
cp object_detection/packages/tf2/setup.py .
python -m pip install .
We need to have opencv-python-headless version 4.1.2.30 installed for training with the Object Detection API to work.
%%capture
%%bash
yes | pip uninstall opencv-python-headless
pip install opencv-python-headless==4.1.2.30
The dataset we will use is the Fruit Images for Object Detection dataset from Kaggle. This is a very small dataset with images of three classes: apple, banana, and orange. It is more than enough to get started with training on a custom dataset, but you can use your own dataset instead.
We will use the Kaggle CLI to download the dataset, unzip and prepare the train/test datasets.
!pip install kaggle --upgrade -q
%%bash
mkdir -p ~/.kaggle
echo '{"username":"KAGGLE_USERNAME","key":"KAGGLE_KEY"}' > ~/.kaggle/kaggle.json
chmod 600 ~/.kaggle/kaggle.json
%%capture
%%bash
kaggle datasets download mbkinaci/fruit-images-for-object-detection
unzip fruit-images-for-object-detection
mkdir -p fruit-images
mv train_zip/* fruit-images/
mv test_zip/* fruit-images/
Note how we move the train and test files around after unzipping; this is for convenience because the original dataset zip file had images under train_zip/train/*.jpg and test_zip/test/*.jpg.
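As a quick sanity check, you can count the images and annotation files that ended up under fruit-images/train and fruit-images/test; the exact counts depend on the dataset version, so treat this as an optional verification step.
import glob
print(len(glob.glob('fruit-images/train/*.jpg')), 'training images')
print(len(glob.glob('fruit-images/train/*.xml')), 'training annotations')
print(len(glob.glob('fruit-images/test/*.jpg')), 'test images')
print(len(glob.glob('fruit-images/test/*.xml')), 'test annotations')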
First, import dependencies
import glob
import io
import os
from collections import namedtuple
from xml.etree import ElementTree as tree
import pandas as pd
import tensorflow.compat.v1 as tf
from PIL import Image
from object_detection.utils import dataset_util
from object_detection.protos import pipeline_pb2
from google.protobuf import text_format
We need to define a helper function to encode a label into its index
def encode_class(row_label):
    class_mapping = {'apple': 1, 'orange': 2, 'banana': 3}
    return class_mapping.get(row_label, None)
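For example, the mapping above behaves as follows (a tiny sanity check, not required for training):
assert encode_class('apple') == 1
assert encode_class('orange') == 2
assert encode_class('banana') == 3
assert encode_class('unknown') is None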
We also define a helper function that groups the annotation rows by image filename, so that all objects belonging to one image can be processed together
def split(df, group):
    Data = namedtuple('data', ['filename', 'object'])
    groups = df.groupby(group)
    return [Data(filename, groups.get_group(x)) for filename, x in zip(groups.groups.keys(), groups.groups)]
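To make the grouping concrete, here is a minimal example on a made-up DataFrame whose columns mirror the labels CSV we generate later; each returned tuple holds one filename and the DataFrame of its objects:
sample = pd.DataFrame({
    'filename': ['a.jpg', 'a.jpg', 'b.jpg'],
    'class': ['apple', 'banana', 'orange'],
})
for g in split(sample, 'filename'):
    print(g.filename, len(g.object), 'objects')
# a.jpg 2 objects
# b.jpg 1 objects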
The following function takes an image from the train/test set and converts it into a corresponding TF Example, where the image, the bounding boxes, and the ground-truth classes are grouped as features.
def create_tf_example(group, path):
    groups_path = os.path.join(path, f'{group.filename}')
    with tf.gfile.GFile(groups_path, 'rb') as f:
        encoded_jpg = f.read()
    image = Image.open(io.BytesIO(encoded_jpg))
    width, height = image.size
    filename = group.filename.encode('utf8')
    image_format = b'jpg'
    # Store the dimensions of the bounding boxes, along with the classes of each object contained in the image
    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    classes_text = []
    classes = []
    for index, row in group.object.iterrows():
        xmins.append(row['xmin'] / width)
        xmaxs.append(row['xmax'] / width)
        ymins.append(row['ymin'] / height)
        ymaxs.append(row['ymax'] / height)
        classes_text.append(row['class'].encode('utf8'))
        classes.append(encode_class(row['class']))
    # Create a tf.train.Features object that will contain relevant information about the image and its objects
    features = tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes)
    })
    # Return a tf.train.Example structure initialized with the features created previously
    return tf.train.Example(features=features)
The bounding boxes for each image in the dataset are defined in an XML file (based on the PASCAL VOC format - link). We need to parse each of those metadata files to extract the bounding boxes and labels.
def bboxes_to_csv(path):
    xml_list = []
    bboxes_pattern = os.path.sep.join([path, '*.xml'])
    for xml_file in glob.glob(bboxes_pattern):
        t = tree.parse(xml_file)
        root = t.getroot()
        for member in root.findall('object'):
            value = (root.find('filename').text,
                     int(root.find('size')[0].text),
                     int(root.find('size')[1].text),
                     member[0].text,
                     int(member[4][0].text),
                     int(member[4][1].text),
                     int(member[4][2].text),
                     int(member[4][3].text))
            xml_list.append(value)
    column_names = ['filename', 'width', 'height', 'class', 'xmin', 'ymin', 'xmax', 'ymax']
    df = pd.DataFrame(xml_list, columns=column_names)
    return df
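To make the positional indexing in bboxes_to_csv easier to follow, here is an illustrative, made-up annotation with the layout the parser expects: member[0] is the class name and member[4] is the bndbox element holding xmin, ymin, xmax and ymax.
sample_xml = """
<annotation>
  <filename>apple_1.jpg</filename>
  <size><width>640</width><height>480</height><depth>3</depth></size>
  <object>
    <name>apple</name>
    <pose>Unspecified</pose>
    <truncated>0</truncated>
    <difficult>0</difficult>
    <bndbox><xmin>10</xmin><ymin>20</ymin><xmax>100</xmax><ymax>120</ymax></bndbox>
  </object>
</annotation>
"""
obj = tree.fromstring(sample_xml).find('object')
print(obj[0].text)               # class name -> apple
print([b.text for b in obj[4]])  # bounding box -> ['10', '20', '100', '120']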
Now let's process every image in the train/test datasets along with its metadata to create the corresponding TF Record files.
base = 'fruit-images'
for subset in ['test', 'train']:
    labels_path = os.path.sep.join([base, f'{subset}_labels.csv'])
    bboxes_df = bboxes_to_csv(f'{base}/{subset}')
    bboxes_df.to_csv(labels_path, index=None)
    # Then, use the same labels to produce the tf.train.Examples corresponding to the current subset of data being processed
    writer = tf.io.TFRecordWriter(f'{base}/{subset}.record')
    examples = pd.read_csv(f'{base}/{subset}_labels.csv')
    grouped = split(examples, 'filename')
    path = os.path.join(f'{base}/{subset}')
    for group in grouped:
        tf_example = create_tf_example(group, path)
        writer.write(tf_example.SerializeToString())
    writer.close()
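Optionally, you can verify that the TF Record files were written by counting the serialized examples in each one (this assumes eager execution, which is the default in TensorFlow 2):
for subset in ['train', 'test']:
    count = sum(1 for _ in tf.data.TFRecordDataset(f'{base}/{subset}.record'))
    print(subset, count, 'examples')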
This is what the result of processing the metadata looks like:
|   | filename | width | height | class | xmin | ymin | xmax | ymax |
|---|---|---|---|---|---|---|---|---|
| 0 | mixed_18.jpg | 1023 | 682 | orange | 67 | 163 | 441 | 541 |
| 1 | mixed_18.jpg | 1023 | 682 | banana | 209 | 134 | 866 | 348 |
| 2 | mixed_18.jpg | 1023 | 682 | banana | 263 | 267 | 849 | 551 |
| 3 | apple_11.jpg | 652 | 436 | apple | 213 | 33 | 459 | 258 |
| 4 | apple_11.jpg | 652 | 436 | apple | 1 | 30 | 188 | 280 |
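A table like the one above can be reproduced at any time by reading back one of the generated label files, for example the training split:
pd.read_csv('fruit-images/train_labels.csv').head()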
Now the data is ready for training
!ls -l fruit-images
To run the training on our custom dataset, we will fine-tune EfficientDet, one of the models in the TensorFlow Object Detection API that was pre-trained on the COCO dataset. We will download a checkpoint of the model's weights from the TensorFlow 2 Detection Model Zoo. Specifically, we will download the weights of EfficientDet D0 512x512, but you can use smaller models like MobileNet v2 320x320 for faster training.
%%capture
%%bash
CHECKPOINT_DATE=20200711
MODEL_NAME=efficientdet_d0_coco17_tpu-32
curl -O http://download.tensorflow.org/models/object_detection/tf2/$CHECKPOINT_DATE/$MODEL_NAME.tar.gz
tar xzf $MODEL_NAME.tar.gz
!ls efficientdet_d0_coco17_tpu-32/checkpoint
We need to create the label_map.txt file to map the classes to integers.
%%writefile fruit-images/label_map.txt
item {
id: 1
name: 'apple'
}
item {
id: 2
name: 'orange'
}
item {
id: 3
name: 'banana'
}
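You can optionally check that the label map parses correctly using label_map_util, which ships with the Object Detection API; printing the returned proto should show the three items defined above.
from object_detection.utils import label_map_util
print(label_map_util.load_labelmap('fruit-images/label_map.txt'))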
Next, we need to change the configuration file for this network to fit our needs. This configuration file can be found locally at models/research/object_detection/configs/tf2/ssd_efficientdet_d0_512x512_coco17_tpu-8.config.
Note: those configuration files are Protocol Buffers objects described in the .proto files under models/research/object_detection/protos. The top-level object is a TrainEvalPipelineConfig defined in pipeline.proto. You can learn more about those configuration files by reading the documentation.
The following helper functions are used to load and save a configuration file; they are based on code borrowed from config_util.py.
def get_pipeline_config(path):
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    with tf.gfile.GFile(path, 'r') as f:
        text_format.Merge(f.read(), pipeline_config)
    return pipeline_config

def save_pipeline_config(pipeline_config, path):
    config_text = text_format.MessageToString(pipeline_config)
    with tf.gfile.Open(path, "wb") as f:
        tf.logging.info("Writing pipeline config file to %s", path)
        f.write(config_text)
Load the EfficientDet configuration and update it accordingly:
pipeline_config_path = 'models/research/object_detection/configs/tf2/ssd_efficientdet_d0_512x512_coco17_tpu-8.config'
pipeline_config = get_pipeline_config(pipeline_config_path)
Lower the batch size depending on how much memory your system has
pipeline_config.train_config.batch_size = 16
Update the number of classes to match the ones in our custom dataset
pipeline_config.model.ssd.num_classes = 3
Point to the checkpoint file of the EfficientDet weights we downloaded earlier
pipeline_config.train_config.fine_tune_checkpoint = '/content/efficientdet_d0_coco17_tpu-32/checkpoint/ckpt-0'
Change the checkpoint type to detection
pipeline_config.train_config.fine_tune_checkpoint_type = 'detection'
Point to the label/index mapping file
pipeline_config.train_input_reader.label_map_path = '/content/fruit-images/label_map.txt'
Point to the training TF Record file we created earlier
pipeline_config.train_input_reader.tf_record_input_reader.input_path[0] = '/content/fruit-images/train.record'
Point to the label/index mapping file
pipeline_config.eval_input_reader[0].label_map_path = '/content/fruit-images/label_map.txt'
Point to the test TF Record file we created earlier
pipeline_config.eval_input_reader[0].tf_record_input_reader.input_path[0] = '/content/fruit-images/test.record'
save_pipeline_config(pipeline_config, 'pipeline.config')
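As an optional check before launching the training, you can print back the fields we just modified to make sure the saved pipeline configuration contains the expected values:
print(pipeline_config.model.ssd.num_classes)
print(pipeline_config.train_config.batch_size)
print(pipeline_config.train_config.fine_tune_checkpoint)
print(pipeline_config.train_input_reader.tf_record_input_reader.input_path[:])
print(pipeline_config.eval_input_reader[0].tf_record_input_reader.input_path[:])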
Now we can start training the model using the model_main_tf2.py helper program and the configuration we just updated.
%%bash
cd models/research/object_detection
mkdir -p /content/training
CONFIG_PATH=/content/pipeline.config
MODEL_DIR=/content/training
python model_main_tf2.py --pipeline_config_path=$CONFIG_PATH --model_dir=$MODEL_DIR --num_train_steps=1000 --record_summaries
INFO:tensorflow:Step 1000 per-step time 1.237s
I0122 03:01:21.339013 139649250535296 model_lib_v2.py:707] Step 1000 per-step time 1.237s
INFO:tensorflow:{'Loss/classification_loss': 0.2277294,
'Loss/localization_loss': 0.1371322,
'Loss/regularization_loss': 0.028924104,
'Loss/total_loss': 0.3937857,
'learning_rate': 0.0326}
I0122 03:01:21.339331 139649250535296 model_lib_v2.py:708] {'Loss/classification_loss': 0.2277294,
'Loss/localization_loss': 0.1371322,
'Loss/regularization_loss': 0.028924104,
'Loss/total_loss': 0.3937857,
'learning_rate': 0.0326}
Once the training is finished, we can check the training logs using TensorBoard
%load_ext tensorboard
%tensorboard --logdir /content/training
To be able to use the newly trained model for inference, we need to use the exporter_main_v2.py program as follows:
%%bash
cd models/research/object_detection
mkdir -p /content/inference_graph
CHECKPOINT_DIR=/content/training
CONFIG_PATH=/content/pipeline.config
OUTPUT_DIR=/content/inference_graph
python exporter_main_v2.py --trained_checkpoint_dir=$CHECKPOINT_DIR --pipeline_config_path=$CONFIG_PATH --output_directory=$OUTPUT_DIR
...
INFO:tensorflow:Assets written to: /content/inference_graph/saved_model/assets
I0122 03:03:07.437493 140608935167872 builder_impl.py:784] Assets written to: /content/inference_graph/saved_model/assets
INFO:tensorflow:Writing pipeline config file to /content/inference_graph/pipeline.config
I0122 03:03:08.824219 140608935167872 config_util.py:254] Writing pipeline config file to /content/inference_graph/pipeline.config
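If you want to inspect the exported SavedModel, the saved_model_cli tool that ships with TensorFlow can print its serving signature (input and output tensors):
%%bash
saved_model_cli show --dir /content/inference_graph/saved_model --tag_set serve --signature_def serving_default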
Let's use the trained model on the test images and check its quality.
import glob
import random
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image
from object_detection.utils import ops
from object_detection.utils import visualization_utils as viz
from object_detection.utils.label_map_util import create_category_index_from_labelmap
%matplotlib inline
Define a helper function to load an image and prepare it for the model's expected input
def load_image(path):
    image_data = tf.io.gfile.GFile(path, 'rb').read()
    image = Image.open(BytesIO(image_data))
    width, height = image.size
    shape = (height, width, 3)
    image = np.array(image.getdata())
    image = image.reshape(shape).astype('uint8')
    return image
Define a helper function to run inference on an input image
def run_inference(net, image):
    image = np.asarray(image)
    input_tensor = tf.convert_to_tensor(image)
    input_tensor = input_tensor[tf.newaxis, ...]
    # forward pass
    model = net.signatures['serving_default']
    result = model(input_tensor)
    # extract detections
    num_detections = int(result.pop('num_detections'))
    result = {key: value[0, :num_detections].numpy() for key, value in result.items()}
    result['num_detections'] = num_detections
    result['detection_classes'] = result['detection_classes'].astype('int64')
    # use mask if available
    if 'detection_masks' in result:
        detection_masks_reframed = ops.reframe_box_masks_to_image_masks(result['detection_masks'], result['detection_boxes'], image.shape[0], image.shape[1])
        detection_masks_reframed = tf.cast(detection_masks_reframed > 0.5, tf.uint8)
        result['detection_masks_reframed'] = detection_masks_reframed.numpy()
    return result
Let's load the model we exported earlier and create CATEGORY_IDX based on the label/index mapping file
labels_path = '/content/fruit-images/label_map.txt'
CATEGORY_IDX = create_category_index_from_labelmap(labels_path, use_display_name=True)
model_path = '/content/inference_graph/saved_model'
model = tf.saved_model.load(model_path)
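The loaded SavedModel exposes its signatures as a dictionary; the run_inference helper above relies on the serving_default entry:
print(list(model.signatures.keys()))  # expected to include 'serving_default'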
Select random images from the test dataset
image_paths = list(glob.glob('fruit-images/test/*.jpg'))
image_paths = random.choices(image_paths, k=6)
Define a helper function to load an image, run inference on it, and draw the predicted bounding boxes:
def get_image_with_boxes(model, path):
    image = load_image(path)
    annotation = run_inference(model, image)
    masks = annotation.get('detection_masks_reframed', None)
    viz.visualize_boxes_and_labels_on_image_array(
        image,
        annotation['detection_boxes'],
        annotation['detection_classes'],
        annotation['detection_scores'],
        CATEGORY_IDX,
        instance_masks=masks,
        use_normalized_coordinates=True,
        line_thickness=5)
    return image
images = [get_image_with_boxes(model, path) for path in image_paths]
Display the images along with the bounding boxes
figure, axis = plt.subplots(2, 3, figsize=(15, 10))
for index, image in enumerate(images):
    row, col = index // 3, index % 3
    axis[row, col].imshow(image)
    axis[row, col].axis('off')
The model seems to perform quite well. You can try training it on a harder dataset.