"""Simple image recognition library - SVD based.
Ben Herbst - U. Stellenbosch.
Fernando Perez - U. Colorado, Boulder."""
# Required packages
# Std lib
import os
# Third-party
import pylab as P
import numpy as N
import scipy as S
# Scipy has a special loading mechanism to import multiple subpackages into
# its own namespace for convenience; here only the linalg subpackage is
# requested.
# NOTE(review): scipy.pkgload does not appear to exist in current SciPy
# releases -- confirm which SciPy version this module targets.
S.pkgload('linalg')
# Classes and functions begin
def imshow2(m1,m2,labels=(None,None)):
    """Draw two images next to each other in a fresh figure.

    Inputs:
      - m1: image drawn in the left panel.
      - m2: image drawn in the right panel.
    Optional inputs:
      - labels: (left, right) pair of x-axis captions; a false entry such as
      None leaves that panel without a caption.
    Outputs:
      - the newly created matplotlib figure."""
    figure = P.figure()
    # Each panel: (image, [left, bottom, width, height] rect, caption)
    left = (m1, [0.025,0.1,0.45,0.775], labels[0])
    right = (m2, [0.525,0.1,0.45,0.775], labels[1])
    for picture, rect, caption in (left, right):
        axes = figure.add_axes(rect)
        axes.imshow(picture, cmap=P.cm.gray)
        if caption:
            axes.set_xlabel(caption)
        # pyplot-level tick calls; they are expected to act on the axes that
        # add_axes just made current
        P.xticks([], [])
        P.yticks([], [])
    P.draw_if_interactive()
    return figure
class ImageCollection(object):
    """Class to hold a named collection of images.

    Images are accessible by index (.images / .names) or by name
    (.img_dict).  Use from_directory() to load every readable image file
    found in a directory."""
    def __init__(self,images,names=None):
        """Construct a collection from a list of images and names.
        Inputs:
          - images: a sequence of valid image objects.
        Optional inputs:
          - names: a sequence of strings to be used as names for the images.
          If not given, the images are simply named by their index.
        Raises:
          - ValueError if names has a different length than images."""
        num_images = len(images)
        if names is None:
            # A real list (not a lazy map object) so len() and indexing work
            names = [str(i) for i in range(num_images)]
        else:
            names = list(names)
        # Validate with an exception: a parenthesized `assert (cond, msg)`
        # asserts a non-empty tuple and can never fail.
        if len(names) != num_images:
            raise ValueError(
                'List of names must be of same length as image sequence')
        self.images = images
        self.names = names
        # make a dict for keyed access
        self.img_dict = dict(zip(names,images))
        self.num_images = num_images

        # Public attributes
        self.interpolation = 'nearest'

    @classmethod
    def from_directory(cls,dir_name,verbose=False):
        """Read all images in a given directory, return an ImageCollection.
        It reads all files in the directory (in sorted filename order) and
        tries to compute an image (using scipy's imread) for each of them;
        files that raise IOError are skipped.
        Inputs:
          - dir_name: a string containing the directory name to scan.
        Optional inputs:
          - verbose(False): print extra verbose information about skipped
          files while running.  Set to any True value for basic diagnostics of
          error conditions, and to higher integer values for further messages.
        Outputs:
          - an ImageCollection instance.
          """
        img = []
        names = []
        # NOTE(review): scipy.misc.pilutil.imread only exists in older SciPy
        # releases -- confirm the targeted SciPy version.
        imread = S.misc.pilutil.imread
        # os.listdir returns entries in arbitrary order; sort them so image
        # indices are reproducible between runs.
        for fname in sorted(os.listdir(dir_name)):
            full_fname = os.path.join(dir_name,fname)
            try:
                if verbose > 1:
                    print('Reading file: %s' % fname)
                img.append(imread(full_fname))
                names.append(fname)
            except IOError:
                if verbose:
                    print('Skipping non-image file: %s' % fname)
        # Safety warning
        if not img:
            print('WARNING: empty image collection, no valid images found.')
        return cls(img,names)

    def show_image(self,number,fignum=None,interpolation=None):
        """Show a single image, by its index.
        Inputs:
          - number: index (0-offset) of the image to display.

        Optional inputs:
          - fignum(None): a matplotlib figure number, to reuse for display.
          If not given, a new figure is automatically created.  A reused
          figure must already hold an image in its first axes (e.g. one
          created by a previous call to this method).
          - interpolation: interpolation mode for display.  If not given, the
          instance .interpolation attribute is used.
        Outputs:
          The number of the figure window used, so it can be reused."""
        if interpolation is None:
            interpolation = self.interpolation
        image = self.images[number]
        name = self.names[number]
        if fignum is None:
            # make a new figure from scratch.  pyplot.matshow returns the
            # image object rather than the Figure, so fetch the figure it
            # created (and made current) explicitly.
            P.matshow(image,cmap=P.cm.gray,interpolation=interpolation)
            fig = P.gcf()
        else:
            # draw into an existing figure directly
            fig = P.figure(fignum)
            ax_im = fig.axes[0].images[0]
            ax_im.set_data(image)
            ax_im.set_interpolation(interpolation)

        P.title('Image [%d]: %s - (%d/%d)' %
                (number,name,number+1,self.num_images))
        if fignum is None:
            P.draw_if_interactive()
        else:
            # redraw after the title update so the reused window is current
            P.draw()
        return fig.number

    def browse_images(self):
        """Browse the images interactively, one at a time, in one window."""
        if not self.num_images:
            print('No images to browse.')
            return
        # raw_input was renamed to input in Python 3
        try:
            read_answer = raw_input
        except NameError:
            read_answer = input
        # Show the first figure separately, so we can reuse the figure window
        fignum = self.show_image(0)
        count = 0
        while count < self.num_images:
            self.show_image(count,fignum)
            ans = read_answer('Enter for next, <p> for previous, <q> to quit: ')
            if ans=='p':
                if count>0:
                    count -= 1
            elif ans=='q':
                break
            else:
                count += 1

    def list_images(self):
        """Print a listing of all images in the collection"""
        print('Total number of images: %s' % self.num_images)
        for i,name in enumerate(self.names):
            print('%3d - %s' % (i,name))

    def normalized_flat(self):
        """Flatten and mean-center all images of the collection.

        Each image is flattened into one column of a float32 matrix (float so
        SVD and similar operations can be applied), and the per-pixel average
        over all images is subtracted from every column.

        Outputs:
          - (img_avg, img_flat_norm): img_avg is the 1-d per-pixel average;
          img_flat_norm is the (num_pixels, num_images) centered matrix.
        Raises:
          - ValueError if the collection is empty or the images do not all
          have the same number of pixels."""
        if not self.num_images:
            raise ValueError('cannot normalize an empty image collection')
        # Pixel count comes from the first image; .size also handles
        # images with more than two dimensions (e.g. color channels).
        npix = N.asarray(self.images[0]).size
        img_flat = N.empty((npix,self.num_images),N.float32)
        for col,im in enumerate(self.images):
            pixels = N.asarray(im).ravel()
            if pixels.size != npix:
                raise ValueError('image %d has %d pixels, expected %d' %
                                 (col,pixels.size,npix))
            img_flat[:,col] = pixels
        # remove average vectors to make a 'normalized' image
        img_avg = img_flat.mean(axis=1)
        # we need to broadcast the average as a column vector
        return img_avg, img_flat - img_avg[:,N.newaxis]
# Classes for the actual face recognition (SVD / eigenimage based)
class TrainingImages(object):
    """Eigenimage (SVD) model built from a directory of training images.

    The training images are flattened and mean-centered, decomposed with an
    economy-size SVD, and every training image is projected onto the leading
    left singular vectors ('eigenimages').  Test images are compared against
    the training set in that reduced coefficient space."""
    def __init__(self,dir_name,verbose=False):
        """Load a collection of images from a directory and build the model.
        Inputs:
          - dir_name: directory scanned by ImageCollection.from_directory.
        Optional inputs:
          - verbose(False): passed through to ImageCollection.from_directory."""
        self.image_coll = ImageCollection.from_directory(dir_name,
                                                         verbose=verbose)
        self.num_images = self.image_coll.num_images
        img_avg,img_flat_norm = self.image_coll.normalized_flat()
        # Compute singular values and U matrix.  Economy-size SVD: umat has
        # min(num_pixels, num_images) columns.
        umat,sigma,vtmat = N.linalg.svd(img_flat_norm,0)
        # Array of 'eigenimages' in normal (not flattened) format: column i
        # of umat reshaped to the image shape.
        imshape = self.image_coll.images[0].shape
        eig_img = umat.T.reshape((umat.shape[1],) + tuple(imshape))
        # Store in object all these
        self.sigma = sigma
        self.umat = umat
        self.vtmat = vtmat
        self.imshape = imshape
        self.eig_img = ImageCollection(eig_img)
        self.img_avg = img_avg
        self.img_flat_norm = img_flat_norm
        # Default truncation keeps num_images-1 eigenimages: after subtracting
        # the mean image the data has rank at most num_images-1, so the last
        # mode carries (numerically) nothing.
        self.truncate_idx = self.num_images-1

    def plot_sigma(self):
        """Plot the singular values, normalized by the first"""
        P.figure()
        P.plot(self.sigma/self.sigma[0])
        P.title('Normalized singular values')
        P.grid()
        P.draw_if_interactive()

    def _get_truncate_idx(self):
        """Return the current truncation index."""
        return self._trunc_idx

    def _set_truncate_idx(self,idx):
        """Set the truncation index as an integer.
        This is typically read from the singular value plot, it should be an
        integer in the range [0, num_images).  The first idx eigenimages are
        kept and the projection coefficients of all normalized training
        images onto them are recomputed into .proj_c (one row per image)."""
        if idx<0 or idx>=self.num_images:
            raise ValueError('index must be >=0 and <%d' % self.num_images)

        self._trunc_idx = idx
        self.umat_trunc = self.umat[:,:idx]
        # compute projection coefficients for all the original input (but
        # normalized) images against the eigenimages in one product; column j
        # of the product holds image j's coefficients, so transpose to rows.
        utt = self.utt = self.umat_trunc.transpose().astype(N.float32)
        self.proj_c = N.dot(utt,self.img_flat_norm).T.astype(N.float32)

    truncate_idx = property(_get_truncate_idx,
                            _set_truncate_idx,
                            None,
                            _set_truncate_idx.__doc__)

    def _project(self,test_img):
        """Return the truncated eigen-coefficients of an unnormalized image.
        The average training image is subtracted before projecting."""
        norm_test = N.asarray(test_img).ravel() - self.img_avg
        return N.dot(self.utt,norm_test)

    def verify(self,test_img,key):
        """Verify that a provided test image corresponds to a given key.
        Assumes that the input is unnormalized (average not subtracted).
        Inputs:
          - test_img: image array with the training image shape.
          - key: integer index of the reference training image.
        Outputs:
          - the L2 distance between the test and reference coefficients,
          also shown as the title of a side-by-side comparison figure."""
        ref_coeffs = self.proj_c[key]
        test_coeffs = self._project(test_img)
        l2_err = N.linalg.norm(test_coeffs-ref_coeffs,2)
        imshow2(self.image_coll.images[key],test_img,
                labels = ('Reference Image','Test Image'))
        P.title('L2 error: %.2e' % l2_err)
        return l2_err

    def identify(self,test_img):
        """Find the training image closest to a provided test image.
        Assumes that the input is unnormalized (average not subtracted).
        The match is the training image whose projection coefficients have
        the smallest L2 distance to those of the test image.
        Outputs:
          - (index, l2_error) of the best match, also displayed side by side
          with the test image."""
        test_coeffs = self._project(test_img)
        # True Euclidean distance (not squared) so the reported error is on
        # the same scale as the one reported by verify().
        dist = N.sqrt(((self.proj_c - test_coeffs)**2).sum(axis=1))
        minidx = int(dist.argmin())
        best_err = dist[minidx]
        imshow2(test_img,self.image_coll.images[minidx],
                labels = ('Test Image','Best Match: %d' % minidx))
        P.title('L2 error: %.2e' % best_err)
        return minidx, best_err