## Spectral Regression Discriminant Analysis.

## This is an implementation of Spectral Regression Discriminant Analysis described in:
## 'SRDA: An Efficient Algorithm for Large Scale Discriminant Analysis' Deng Cai,
## Xiaofei He, Jiawei Han. 2008.

## This code is written by Roberto Visintainer, <visintainer@fbk.eu> and Davide Albanese, <albanese@fbk.eu>.
## (C) 2008 Fondazione Bruno Kessler - Via Santa Croce 77, 38100 Trento, ITALY.

## This program is free software: you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation, either version 3 of the License, or
## (at your option) any later version.

## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
## GNU General Public License for more details.

## You should have received a copy of the GNU General Public License
## along with this program.  If not, see <http://www.gnu.org/licenses/>.


__all__ = ['Srda']

from numpy import *
from numpy.linalg import inv


class Srda:
    """Spectral Regression Discriminant Analysis (SRDA).

    Two-class linear classifier.  :meth:`compute` fits a regularized
    regression on class-indicator responses and then tunes a decision
    threshold on the training data; :meth:`predict` applies the model.

    Example:

    >>> import numpy as np
    >>> import mlpy
    >>> xtr = np.array([[1.0, 2.0, 3.1, 1.0],  # first sample
    ...                 [1.0, 2.0, 3.0, 2.0],  # second sample
    ...                 [1.0, 2.0, 3.1, 1.0]]) # third sample
    >>> ytr = np.array([1, -1, 1])             # classes
    >>> mysrda = mlpy.Srda()                 # initialize srda class
    >>> mysrda.compute(xtr, ytr)             # compute srda
    1
    >>> mysrda.predict(xtr)                  # predict srda model on training data
    array([ 1, -1,  1])
    >>> xts = np.array([4.0, 5.0, 6.0, 7.0]) # test point
    >>> mysrda.predict(xts)                  # predict srda model on test point
    -1
    >>> mysrda.realpred                      # real-valued prediction
    -6.8283034257748758
    >>> mysrda.weights(xtr, ytr)             # compute weights on training data
    array([ 0.10766721,  0.21533442,  0.51386623,  1.69331158])
    """

    def __init__(self, alpha=1.0):
        """Initialize the Srda class.

        :Parameters:
          alpha : float(>=0.0)
            regularization parameter

        :Raises:
          ValueError
            if alpha is negative
        """

        if alpha < 0.0:
            raise ValueError("alpha (regularization parameter) must be >= 0.0")

        self.__alpha = alpha

        # Model state, filled in by compute().
        self.__classes = None    # the two sorted class labels
        self.__a = None          # per-feature coefficients
        self.__sumC = 0.0        # sum of the regression coefficients (offset)
        self.__th = 0.0          # decision threshold tuned on training data
        self.__computed = False

        # Real-valued output of the last predict() call.
        self.realpred = None

    def __scores(self, p, a, offset):
        """Return the un-thresholded decision value(s) of sample(s) p.

        p may be one sample (1d) or a samples x feats matrix (2d); the
        result is a scalar or a 1d array respectively.
        """

        return -dot(p, a) - offset

    def compute(self, x, y):
        """
        Compute Srda model.

        Fits the coefficient vector and then chooses the decision
        threshold that classifies the most training samples correctly.
        The model stored in the instance is replaced only if the whole
        computation succeeds.

        :Parameters:
          x : 2d ndarray float (samples x feats)
            training data
          y : 1d ndarray integer (-1 or 1)
            classes

        :Returns:
          1

        :Raises:
          ValueError
            if x is not 2d, if y does not hold one label per sample,
            or if y does not contain exactly two distinct classes
          LinAlgError
            if the regularized kernel matrix is singular
        """

        x = asarray(x)
        y = asarray(y)

        if x.ndim != 2:
            raise ValueError("x must be a 2d array (samples x feats)")
        if y.ndim != 1 or y.shape[0] != x.shape[0]:
            raise ValueError("y must be a 1d array with one class per sample")

        classes = unique(y)
        if classes.shape[0] != 2:
            raise ValueError("SRDA works only for two-classes problems")

        # See eq 19 and 24

        nsamples = x.shape[0]
        is_cl1 = (y == classes[1])
        cl0 = where(y == classes[0])[0]
        cl1 = where(is_cl1)[0]

        # Regression responses: n/n0 for the first class, -n/n1 for the
        # second one.  Both pieces are 1d, so they are joined along the
        # only existing axis.
        ym = concatenate((ones(cl0.shape[0]) * (nsamples / float(cl0.shape[0])),
                          ones(cl1.shape[0]) * (-nsamples / float(cl1.shape[0]))))

        # Samples reordered class by class, to match the layout of ym.
        xi = x[r_[cl0, cl1]]
        xiT = xi.transpose()

        # Adding 1.0 to every Gram entry is the same as appending a constant
        # feature of ones to each sample, which is why sum(c) acts as offset.
        kernel = dot(xi, xiT) + 1.0 + self.__alpha * identity(nsamples)
        c = dot(inv(kernel), ym)

        a = dot(xiT, c)
        offset = c.sum()

        ##### Threshold tuning ######
        # Candidate thresholds: the midpoints between consecutive sorted
        # training scores, followed by 0.0.  argmax keeps the first best
        # candidate, so the lowest midpoint wins ties.
        scores = self.__scores(x, a, offset)
        ordered = sort(scores)
        ths = concatenate(((ordered[:-1] + ordered[1:]) * 0.5, [0.0]))

        ncomptrue = empty(ths.shape[0], dtype=int)
        for t in range(ths.shape[0]):
            # A sample is assigned to classes[1] when its score is above
            # the threshold, to classes[0] otherwise.
            ncomptrue[t] = ((scores > ths[t]) == is_cl1).sum()
        #############################

        # Commit the model only now, so a failure above leaves any
        # previously computed model untouched.
        self.__classes = classes
        self.__a = a
        self.__sumC = offset
        self.__th = ths[argmax(ncomptrue)]
        self.__computed = True

        return 1

    def weights(self, x, y):
        """Return feature weights.

        The model is (re)computed on x, y; the weights are the absolute
        values of the fitted coefficients.

        :Parameters:
          x : 2d ndarray float (samples x feats)
            training data
          y : 1d ndarray integer (-1 or 1)
            classes

        :Returns:
          fw :  1d ndarray float
            feature weights
        """

        self.compute(x, y)
        return abs(self.__a)

    def predict(self, p):
        """Predict Srda model on test point(s).

        :Parameters:
          p : 1d or 2d ndarray float (sample(s) x feats)
            test sample(s)

        :Returns:
          cl : integer or 1d numpy array integer
            class(es) predicted

        :Attributes:
          self.realpred : float or 1d numpy array float
            real valued prediction

        :Raises:
          RuntimeError
            if no model has been computed yet
          ValueError
            if p is neither 1d nor 2d
        """

        if not self.__computed:
            # RuntimeError exists on both Python 2 and 3 (StandardError
            # does not exist on Python 3).
            raise RuntimeError("No SRDA model computed")

        p = asarray(p)
        if p.ndim not in (1, 2):
            raise ValueError("p must be a 1d or 2d array")

        self.realpred = self.__scores(p, self.__a, self.__sumC) - self.__th

        if p.ndim == 2:
            above = self.realpred > 0.0
            pred = empty(p.shape[0], dtype=int)
            pred[above] = self.__classes[1]
            pred[~above] = self.__classes[0]
            return pred

        if self.realpred > 0.0:
            return self.__classes[1]
        return self.__classes[0]
        
