from __future__ import division
from functools import partial
from pathlib import Path
from menpo.base import MenpoMissingDependencyError
try:
import dlib
except ImportError:
raise MenpoMissingDependencyError("dlib")
from menpodetect.detect import detect
from .conversion import rect_to_pointgraph
class _dlib_detect(object):
r"""
A utility callable that allows the caching of a dlib detector.
This callable is important for presenting the correct parameters to the
user. It also marshalls the return type of the detector back to
`menpo.shape.PointDirectedGraph`.
Parameters
----------
model : `Path` or `str` or `dlib.simple_object_detector`
Either a path to a `dlib.simple_object_detector` or a
`dlib.fhog_object_detector` or the detector itself.
Raises
------
ValueError
If a path was provided and it does not exist.
"""
def __init__(self, model):
if isinstance(model, str) or isinstance(model, Path):
m_path = Path(model)
if not Path(m_path).exists():
raise ValueError("Model {} does not exist.".format(m_path))
# There are two different kinds of object detector, the
# simple_object_detector and the fhog_object_detector, but we
# can't tell which is which from the file name. Therefore, try one
# and then the other. Unfortunately, it throws a runtime error,
# which we have to catch.
try:
model = dlib.simple_object_detector(str(m_path))
except RuntimeError:
model = dlib.fhog_object_detector(str(m_path))
self._dlib_model = model
def __call__(self, uint8_image, n_upscales=0):
r"""
Perform a detection using the cached dlib detector.
Parameters
----------
uint8_image : `ndarray`
An RGB (3 Channels) or Greyscale (1 Channel) numpy array of uint8
n_upscales : `int`, optional
Number of times to upscale the image when performing the detection,
may increase the chances of detecting smaller objects.
Returns
------
bounding_boxes : `list` of `menpo.shape.PointDirectedGraph`
The detected objects.
"""
# Dlib doesn't handle the dead last axis
if uint8_image.shape[-1] == 1:
uint8_image = uint8_image[..., 0]
rects = self._dlib_model(uint8_image, n_upscales)
return [rect_to_pointgraph(r) for r in rects]
[docs]class DlibDetector(object):
r"""
A generic dlib detector.
Wraps a dlib object detector inside the menpodetect framework and provides
a clean interface to expose the dlib arguments.
"""
def __init__(self, model):
self._detector = _dlib_detect(model)
[docs] def __call__(
self,
image,
greyscale=False,
image_diagonal=None,
group_prefix="dlib",
n_upscales=0,
):
r"""
Perform a detection using the cached dlib detector.
The detections will also be attached to the image as landmarks.
Parameters
----------
image : `menpo.image.Image`
A Menpo image to detect. The bounding boxes of the detected objects
will be attached to this image.
greyscale : `bool`, optional
Convert the image to greyscale or not.
image_diagonal : `int`, optional
The total size of the diagonal of the image that should be used for
detection. This is useful for scaling images up and down for
detection.
group_prefix : `str`, optional
The prefix string to be appended to each each landmark group that is
stored on the image. Each detection will be stored as group_prefix_#
where # is a count starting from 0.
n_upscales : `int`, optional
Number of times to upscale the image when performing the detection,
may increase the chances of detecting smaller objects.
Returns
------
bounding_boxes : `list` of `menpo.shape.PointDirectedGraph`
The detected objects.
"""
detect_partial = partial(self._detector, n_upscales=n_upscales)
return detect(
detect_partial,
image,
greyscale=greyscale,
image_diagonal=image_diagonal,
group_prefix=group_prefix,
)
[docs]def load_dlib_frontal_face_detector():
r"""
Load the dlib frontal face detector.
Returns
-------
detector : `DlibDetector`
The frontal face detector.
"""
return DlibDetector(dlib.get_frontal_face_detector())