import cv2
import numpy as np
from .utils import Hungarian
class MultiBoxAssertionError(AssertionError):
    """Assertion error raised by the multi-box overlap check."""

    def __init__(self, message):
        """
        Build the error from a human-readable description.

        Args:
            message (str): Text explaining which check failed.
        """
        AssertionError.__init__(self, message)
def nearly_contains(box_1, box_2, eps):
    """
    Check whether box_1 and box_2 overlap by more than eps along both axes.

    Despite its name, this is an overlap test (it is what ``assert_multibox``
    counts), not a strict containment test. Boxes that are disjoint, or that
    only touch, are not reported. A positive ``eps`` requires the overlap on
    each axis to exceed ``eps``. A negative ``eps`` tolerates a gap of up to
    ``-eps``. The test is symmetric in its two box arguments.

    Args:
        box_1 (tuple|list): The coordinates of the first box in the format (x_min, y_min, x_max, y_max).
        box_2 (tuple|list): The coordinates of the second box in the format (x_min, y_min, x_max, y_max).
        eps (float): The margin the overlap must exceed on each axis.
    Returns:
        bool: True if the boxes overlap by more than eps horizontally and vertically, False otherwise.
    """
    # Two intervals [a0, a1] and [b0, b1] overlap iff a0 < b1 and b0 < a1.
    overlap_x = box_1[0] + eps < box_2[2] and box_2[0] + eps < box_1[2]
    overlap_y = box_1[1] + eps < box_2[3] and box_2[1] + eps < box_1[3]
    return overlap_x and overlap_y
def assert_multibox(boxes, counter_threshold=2, eps=0):
    """
    Assert that no box overlaps too many other boxes.

    Every box is compared against every other box with ``nearly_contains``.
    As soon as one box has matched ``counter_threshold`` other boxes, an
    error is raised.

    Args:
        boxes (list): A list of boxes. The format of each box is a tuple or list with the format of (x_min, y_min, x_max, y_max)
        counter_threshold (int, optional): Number of matching other boxes at which a box
            fails the check (the check is ``>=``). Defaults to 2.
        eps (float, optional): The epsilon value for nearly_contains function. Defaults to 0.
    Raises:
        MultiBoxAssertionError: If any box matches at least counter_threshold other boxes.
    """
    for i, box in enumerate(boxes):
        matches = 0
        for j, other in enumerate(boxes):
            if i == j:
                continue
            if nearly_contains(box, other, eps):
                matches += 1
            if matches >= counter_threshold:
                raise MultiBoxAssertionError(
                    f"Box {i} overlaps at least {counter_threshold} other boxes."
                )
class FlickerException(Exception):
    """Exception raised when an object flickers across a window of frames."""

    def __init__(self, message):
        """
        Build the exception from a human-readable description.

        Args:
            message (str): Text naming the frames where flicker was found.
        """
        Exception.__init__(self, message)
def psnr(img_1, img_2, data_range=255, eps=1e-8):
    """
    Calculate the peak signal-to-noise ratio (PSNR) between two images.

    If the shapes differ, both images are first resized with ``cv2.resize``
    to the larger of the two heights and the larger of the two widths.

    Args:
        img_1 (numpy.ndarray): The first image.
        img_2 (numpy.ndarray): The second image.
        data_range (int, optional): The data range of the images. Defaults to 255.
        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-8.
    Returns:
        float: The PSNR value in decibels, 10 * log10(data_range**2 / (mse + eps)).
    """
    if img_1.shape != img_2.shape:
        height = max(img_1.shape[0], img_2.shape[0])
        width = max(img_1.shape[1], img_2.shape[1])
        # cv2.resize takes dsize as (width, height), i.e. (cols, rows),
        # which is the reverse of numpy's (rows, cols) shape order.
        dsize = (width, height)
        img_1 = cv2.resize(img_1, dsize)
        img_2 = cv2.resize(img_2, dsize)
    diff = img_1.astype(float) - img_2.astype(float)
    mse = np.mean(diff**2)
    return 10 * np.log10((data_range**2) / (mse + eps))
def get_simillar_boxes(
    frame_1,
    frame_2,
    boxes_1,
    boxes_2,
    similarity_threshold=35,
    data_range=255,
    eps=1e-8,
):
    """
    Get similar boxes between two frames based on the peak signal-to-noise ratio (PSNR).

    Each box crop of ``frame_1`` is compared with each box crop of
    ``frame_2``. The cost of a pair is 1 / (PSNR + eps). A one-to-one
    assignment over the cost matrix is computed with ``Hungarian().compute``
    (presumably a minimum-cost assignment returning (row, col) pairs; TODO
    confirm against ``.utils.Hungarian``). Matched pairs whose cost is below
    1 / similarity_threshold are kept.

    Args:
        frame_1 (numpy.ndarray): The first frame.
        frame_2 (numpy.ndarray): The second frame.
        boxes_1 (list): A list of boxes in frame_1, each (x_min, y_min, x_max, y_max).
        boxes_2 (list): A list of boxes in frame_2, each (x_min, y_min, x_max, y_max).
        similarity_threshold (float, optional): The similarity threshold for matching boxes. Defaults to 35.
        data_range (int, optional): The data range of the images. Defaults to 255.
        eps (float, optional): A small value to avoid division by zero. Defaults to 1e-8.
    Returns:
        list: The boxes from boxes_2 that were matched to a box in boxes_1 with
        sufficient similarity.
    """
    n1 = len(boxes_1)
    n2 = len(boxes_2)
    if n1 == 0 or n2 == 0:
        # No real pair exists, so nothing can match. This also avoids running
        # the assignment on a matrix made only of padding (or an empty one).
        return []

    def _crop(frame, box):
        # Boxes are (x_min, y_min, x_max, y_max): rows are y, columns are x.
        return frame[int(box[1]) : int(box[3]), int(box[0]) : int(box[2])]

    size = max(n1, n2)
    # The assignment works on a square matrix. Padding cells get the
    # prohibitive cost 1 / eps.
    sim_matrix = np.full((size, size), 1 / eps)
    for i, box_1 in enumerate(boxes_1):
        crop_1 = _crop(frame_1, box_1)
        for j, box_2 in enumerate(boxes_2):
            sim_matrix[i][j] = 1 / (
                psnr(crop_1, _crop(frame_2, box_2), data_range, eps) + eps
            )

    assignment = Hungarian().compute(sim_matrix)
    sim_thres_inv = 1 / similarity_threshold
    similar_boxes = []
    for ass in assignment:
        row, col = ass[0], ass[1]
        # Only real (box_1, box_2) pairs may match. Pairs involving a padding
        # row/column are skipped explicitly rather than relying on the
        # 1 / eps sentinel exceeding the threshold.
        if row < n1 and col < n2 and sim_matrix[row][col] < sim_thres_inv:
            similar_boxes.append(boxes_2[col])
    return similar_boxes
def except_flicker(
    frames,
    boxes_of_frames,
    cur_frame_id=-1,
    window_size=10,
    similarity_threshold=35,
    data_range=255,
):
    """
    Check for flickering between the current frame and previous frames.

    Boxes of the current frame are tracked backwards through up to
    ``window_size`` earlier frames using ``get_simillar_boxes``. Suppose frame
    i contains none of the tracked boxes, but some earlier frame j inside the
    window does. Then the object vanished only at frame i, and flickering is
    reported.

    Args:
        frames (list): A list of frames.
        boxes_of_frames (list): A list of boxes for each frame.
        cur_frame_id (int, optional): The index of the current frame. Negative values
            count from the end of ``frames``. Defaults to -1.
        window_size (int, optional): The number of previous frames to consider. Defaults to 10.
        similarity_threshold (float, optional): The similarity threshold for matching boxes. Defaults to 35.
        data_range (int, optional): The data range of the images. Defaults to 255.
    Raises:
        FlickerException: If flickering is detected between the current frame and a previous frame.
    """
    cur_frame = frames[cur_frame_id]
    cur_boxes = boxes_of_frames[cur_frame_id]

    # Oldest frame index (inclusive) inside the look-back window. Negative ids
    # are clamped at -len(frames). Non-negative ids are clamped at 0 so the
    # backward scan never wraps around to the end of the list.
    if cur_frame_id < 0:
        oldest = max(-len(frames), cur_frame_id - window_size)
    else:
        oldest = max(0, cur_frame_id - window_size)

    for i in range(cur_frame_id - 1, oldest - 1, -1):
        similar_boxes = get_simillar_boxes(
            cur_frame,
            frames[i],
            cur_boxes,
            boxes_of_frames[i],
            similarity_threshold,
            data_range,
        )
        if similar_boxes:
            # Keep tracking backwards from the boxes matched in frame i.
            cur_frame = frames[i]
            cur_boxes = similar_boxes
            continue

        # The tracked boxes are missing from frame i. A match in any earlier
        # frame of the window means they disappeared only temporarily.
        for j in range(i - 1, oldest - 1, -1):
            overlapping_boxes = get_simillar_boxes(
                cur_frame,
                frames[j],
                cur_boxes,
                boxes_of_frames[j],
                similarity_threshold,
                data_range,
            )
            if overlapping_boxes:
                raise FlickerException(
                    f"Flickering detected between frame {cur_frame_id} and frame {j}"
                )

        # The tracked state no longer changes. Every remaining outer iteration
        # would only repeat comparisons that just found no match, so stop.
        return