# Multi-Class Classification
## Introduction
In fast-paced industrial environments where defects can vary significantly, the ability to accurately classify and segment different types of anomalies is crucial for quality assurance and operational efficiency. Traditional classification methods often struggle with the diversity and rarity of defect types, necessitating a robust solution that can handle multiple categories effectively. This notebook introduces the use of the YOLOv8 framework for multi-class classification with segmentation capabilities.
YOLOv8 is a powerful deep learning tool optimized for object detection and segmentation, making it particularly suited to scenarios where precise identification and localization of various defect types are required. By training on segmented datasets, YOLOv8 can discern and categorize multiple defect types from images, providing detailed insights into the nature and extent of the defects encountered. The code below fine-tunes the `yolov8s.pt` checkpoint and evaluates it at the image level (fault vs. background).
## Imports
import ultralytics
from ultralytics import YOLO
import os
import torch
from IPython import display
import sys
from PIL import Image
import supervision as sv
# NOTE(review): '...' looks like a placeholder, not a real directory; replace it
# with the folder that holds any local helper modules this notebook needs, or
# remove the line if none are used.
sys.path.append('...')
## Configurable parameters
# Yolo settings
window_size = 448  # image size, passed to model.train as `imgsz`
batch_size = 8
epochs = 200
# Data sources
# NOTE: no trailing commas here -- `x = 'path',` creates a 1-tuple, which
# model.train(data=...) and os.listdir() do not accept.
data_source_train = 'c:/Users/admin/Documents/Karel Debedts Thesis/thesis/dataset/yolo_dataset_subset2/fault_only/train/data.yml'
data_source_test = 'c:/Users/admin/Documents/Karel Debedts Thesis/thesis/dataset/yolo_dataset_subset2/fault_only/test/data.yml'
## Training
A full list of training arguments is available here: https://docs.ultralytics.com/modes/train/#arguments
# Load the model from the 'yolov8s.pt' checkpoint as the starting point.
model = YOLO('yolov8s.pt')
# Training. Run artifacts (weights, plots) are written to <project>/<name>,
# i.e. yolov8/results. exist_ok=True keeps that directory fixed across
# re-runs instead of incrementing to results2, results3, ..., so loading
# 'yolov8/results/weights/best.pt' afterwards picks up the latest training run.
results = model.train(
    data=data_source_train,
    imgsz=window_size,
    epochs=epochs,
    batch=batch_size,
    plots=True,
    resume=False,
    val=True,
    verbose=True,
    device=0,
    project="yolov8",
    name="results",
    exist_ok=True,
)
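## Test evaluation
The trained model is applied, in combination with a sliding window, to the test dataset. Each test image is scored at the image level: it counts as a fault prediction if the model returns any detection, and as background otherwise.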
def _ground_truth_label(label_path, names):
    """Return the ground-truth class name for one test image.

    Reads a YOLO-format label file (one annotation per line, class id as the
    first token) and maps the class id of the first non-empty line through
    ``names`` (class id -> class name). Returns "bg" when the label file is
    missing or holds no annotations, i.e. the image is treated as background.
    """
    try:
        with open(label_path, 'r') as file:
            for line in file:
                components = line.split()
                if components:
                    return names[int(components[0])]
    except FileNotFoundError:
        # YOLO format allows omitting the label file for images without objects.
        return "bg"
    return "bg"


model = YOLO('yolov8/results/weights/best.pt')

# data_source_test names the test split's data.yml; images are listed from the
# directory that contains it (or from data_source_test itself if it is a dir).
# NOTE(review): assumes each .png sits next to its same-named .txt label file.
test_dir = (
    data_source_test
    if os.path.isdir(data_source_test)
    else os.path.dirname(data_source_test)
)
png_files = sorted(f for f in os.listdir(test_dir) if f.endswith('.png'))

# metrics (image level: fault vs. background)
true_pos = 0
false_pos = 0
false_negatives = 0
total = 0
correct_guesses = 0
# ground-truth class name -> {"guess_bg": count, "guess_fault": count};
# populated on the first iteration, once the model's class names are known.
confusion_matrix_dict = {}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

for file_name in png_files:
    path = os.path.join(test_dir, file_name)
    # The context manager closes the image file once prediction is done.
    with Image.open(path) as source:
        results = model.predict(source=source, device=device)
    names = results[0].names  # class id -> class name
    detections = sv.Detections.from_ultralytics(results[0])

    # Prediction: class of the last detection, or "bg" if nothing was detected.
    pred_lbl = "bg"
    for class_idx in detections.class_id:
        pred_lbl = names[class_idx]

    # Ground truth from the label file with the same stem; splitext only
    # touches the extension, unlike str.replace on the whole path.
    label_path = os.path.splitext(path)[0] + ".txt"
    gnd_truth = _ground_truth_label(label_path, names)

    if not confusion_matrix_dict:
        for class_name in names.values():
            confusion_matrix_dict[class_name] = {"guess_bg": 0, "guess_fault": 0}
        confusion_matrix_dict["bg"] = {"guess_bg": 0, "guess_fault": 0}
        print(confusion_matrix_dict)

    # A fault image counts as correct when any fault is predicted, regardless
    # of which fault class was predicted.
    if gnd_truth != "bg":
        if pred_lbl != "bg":
            true_pos += 1
            correct_guesses += 1
            confusion_matrix_dict[gnd_truth]["guess_fault"] += 1
        else:
            false_negatives += 1
            confusion_matrix_dict[gnd_truth]["guess_bg"] += 1
    elif pred_lbl == "bg":
        confusion_matrix_dict["bg"]["guess_bg"] += 1
        correct_guesses += 1
    else:
        false_pos += 1
        confusion_matrix_dict["bg"]["guess_fault"] += 1
    total += 1
def _safe_ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when the denominator is 0.

    Avoids ZeroDivisionError when the test set is empty or contains no fault
    images / no fault predictions; NaN marks the metric as undefined.
    """
    return numerator / denominator if denominator else float("nan")


accuracy = _safe_ratio(correct_guesses, total)
recall = _safe_ratio(true_pos, true_pos + false_negatives)
precision = _safe_ratio(true_pos, true_pos + false_pos)
print(f"Total images: {total}")
print(f"Accuracy: {accuracy:.2f}")
print(f"Recall: {recall:.2f}")
print(f"Precision: {precision:.2f}")
# Confusion matrix: one row per ground-truth category (in the dict's insertion
# order); columns are the predicted "Background" and "Fault" counts.
data = confusion_matrix_dict
categories = list(data.keys())
confusion_matrix = [
    [data[category]["guess_bg"], data[category]["guess_fault"]]
    for category in categories
]

# Right-aligned text table: header row, then one line per category.
header = f"{'':>12} {'Background':>10} {'Fault':>5}"
print(header)
for category, (n_bg, n_fault) in zip(categories, confusion_matrix):
    print(f"{category:>12}: {n_bg:>10} {n_fault:>5}")