# -*- coding: iso-8859-1 -*-

#
# Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
# Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License (included in file License.txt in the
# program package) for more details.
#

"""Bayes Blocks example models

   Just say 'python examples.py <model>' to run.
   <model> can be one of the following:

     meanvar - Estimation of mean and variance
     fa      - Factor analysis
     dynvar  - Nonstationary ICA by variance modelling
"""

from bblocks.PyNet import PyNet, PyNodeFactory
from bblocks.Helpers import GetMean, GetMeanV, GetVar, GetVarV, Orthogonalize
from bblocks.Label import Label

try:
    import numpy.random
    randn = numpy.random.randn
    rand = numpy.random.rand
    import numpy.oldnumeric
    import numpy.oldnumeric as Numeric
    from numpy.oldnumeric import exp, dot, sin, cos, arange, pi, ones, sqrt, \
         sum, array, NewAxis, Float, zeros
    import numpy.oldnumeric.mlab as MLab
except:
    import MLab
    from MLab import rand, randn
    import Numeric
    from Numeric import exp, dot, sin, cos, arange, pi, ones, sqrt, \
         sum, array, NewAxis, Float, zeros

import sys


# Names of the example models that can be selected on the command line.
MODELS = ["meanvar", "fa", "dynvar"]

# Usage text printed when the script is run without a valid model name.
helptext = """Bayes Blocks example models

   Just say 'python examples.py <model>' to run.
   <model> can be one of the following:

     meanvar - Estimation of mean and variance
     fa      - Factor analysis
     dynvar  - Nonstationary ICA by variance modelling
"""

# Pick the model from the first command-line argument, if there is one.
model = None
if len(sys.argv) > 1:
    model = sys.argv[1]

if model not in MODELS:
    # No usable model name was given.
    if __name__ == "__main__":
        # Run as a script: show the usage text and stop.  The parenthesised
        # single-argument form prints the same under Python 2 and Python 3.
        print(helptext)
        sys.exit(0)
    # Imported as a module: fall back to the factor analysis example.
    model = "fa"

## Utility functions

def snr(s, shat):
    """Return the signal-to-noise ratio of shat as an estimate of s, in dB.

    The signal power is the squared standard deviation of s and the noise
    power is the squared standard deviation of the error s - shat, both
    taken along the last axis.  The result is 10*log10 of their ratio.
    """
    signal_power = MLab.std(s, -1) ** 2
    noise_power = MLab.std(s - shat, -1) ** 2
    return 10 * Numeric.log10(signal_power / noise_power)

def norm(X):
    """Computes the Euclidean norm of the columns of X.

    The squared entries are summed along axis 0, so a two-dimensional X
    yields one norm per column and a one-dimensional X yields a single
    number.  The axis is passed explicitly so the result does not depend
    on the default axis of whichever `sum` is in scope.
    """
    return sqrt(sum(X**2, 0))

def evaluate_recons(S, A, X):
    """Return the SNR of the reconstruction dot(A, S) measured against X."""
    return snr(X, dot(A, S))

def evaluate_subspace(A, Ahat):
    """Compare the column spaces of A and Ahat column by column.

    Both matrices are first passed through Orthogonalize.  For every column
    index of the orthogonalized A, the matching column of the orthogonalized
    Ahat is reduced by a weighted combination of the columns of A and the
    norm of what is left is recorded.  Returns the list of these norms.

    NOTE(review): presumably this is the residual of projecting each Ahat
    column onto span(A), which assumes Orthogonalize returns orthonormal
    columns and that `sum` reduces along axis 0 by default -- confirm.
    """
    basis = Orthogonalize(A)
    estimate = Orthogonalize(Ahat)

    residuals = []
    for col in range(basis.shape[1]):
        # Inner products of the estimated column with every basis column
        coeffs = sum(estimate[:,col:col+1] * basis)
        # Combination of the basis columns weighted by those inner products
        projection = sum(coeffs * basis, 1)
        residuals.append(norm(estimate[:,col] - projection))

    return residuals

def corrcoeff(x, y):
    """Return the correlation coefficient between x and y.

    Both inputs are centred with MLab.mean; the sum of products of the
    centred values along the last axis is then divided by the square root
    of the product of the two sums of squares.
    """
    dx = x - MLab.mean(x)
    dy = y - MLab.mean(y)

    cross = Numeric.sum(dx * dy, -1)
    power = Numeric.sum(dx**2, -1) * Numeric.sum(dy**2, -1)

    return cross / Numeric.sqrt(power)

def corrmat(X, Y):
    """Return the matrix of correlation coefficients between rows of X and Y.

    Entry [i, j] of the result is corrcoeff applied to row i of X and
    row j of Y; the result has X.shape[0] rows and Y.shape[0] columns.
    """
    shape = (X.shape[0], Y.shape[0])
    C = zeros(shape, Float)

    for i in range(shape[0]):
        for j in range(shape[1]):
            C[i,j] = corrcoeff(X[i,:], Y[j,:])

    return C

def linmap(inputs, outdim, mask=None, prefix=""):
    """Constructs a linear mapping from the input nodes to outdim sum nodes.

    Relies on the module-level node factory `f` and constant node `c0`,
    which must have been created before this function is called.

    inputs -- sequence of input nodes
    outdim -- number of output (sum) nodes to create
    mask   -- optional object indexed as mask[i,j]; a false entry means
              output i gets no connection from input j.  With None every
              connection is created.
    prefix -- string prepended to the labels of all created nodes

    Returns a pair (sums, a): sums is the list of the outdim sum nodes and
    a[i][j] is the weight node connecting input j to output i, or None
    where the mask excluded that connection.
    """
    outputs = []
    weights = []
    for i in range(outdim):
        acc = f.GetSumNV(Label("%ssum" % prefix, i))
        row = []
        weights.append(row)
        for j, source in enumerate(inputs):
            if (mask is not None) and not mask[i,j]:
                # Masked out: keep the column position, but create no nodes
                row.append(None)
                continue
            weight = f.GetGaussian(Label("%sa" % prefix, i, j), c0, c0)
            row.append(weight)
            term = f.GetProdV(Label("%sprod" % prefix, i, j),
                              weight, source)
            acc.AddParent(term)
        outputs.append(acc)
    return outputs, weights


## Estimating the mean and variance of data

if __name__ == "__main__" and model == "meanvar":

    # Draw the observations: a constant offset of 1.0 plus scaled noise

    nsamples = 1000
    data = 1.0 + exp(-0.5) * randn(nsamples)

    # Build the network: one Gaussian vector node x whose two parents m and
    # v are themselves Gaussian nodes with constant parents

    net = PyNet(nsamples)
    f = PyNodeFactory(net)

    c0 = f.GetConstant("const+0", 0.0)
    cm5 = f.GetConstant("const-5", -5.0)
    m = f.GetGaussian("m", c0, cm5)
    v = f.GetGaussian("v", c0, cm5)
    x = f.GetGaussianV("x", m, v)

    # Fix x to the data and run a few update sweeps, reporting the negated
    # cost after each one

    x.Clamp(data)

    for sweep in range(5):
        net.UpdateAll()
        print("%i : %f" % (sweep, -net.Cost()))

    # Report what GetMean and GetVar return for m and v

    summary = (GetMean(m), GetVar(m), GetMean(v), GetVar(v))
    print("mean = %f +- %f\n var = %f +- %f" % summary)

## Factor analysis

if __name__ == "__main__" and model == "fa":

    # Generate some data: xdim noisy linear mixtures of sdim random sources,
    # tdim samples each

    sdim = 2
    xdim = 10
    tdim = 500

    sources = randn(sdim, tdim)
    lmap = randn(xdim, sdim)
    noise = randn(xdim, tdim) * 0.1
    data = dot(lmap, sources) + noise

    # Construct the model: sources s with their own variance nodes vs, a
    # linear mapping A from the sources, and observations x with variance
    # nodes vx

    net = PyNet(tdim)
    f = PyNodeFactory(net)

    c0 = f.GetConstant("const+0", 0.0)
    cm5 = f.GetConstant("const-5", -5.0)

    vs = [f.GetGaussian(Label("vs", i), c0, cm5) for i in range(sdim)]
    s = [f.GetGaussianV(Label("s", i), c0, vs[i]) for i in range(sdim)]

    # linmap uses the module-level f and c0 created just above
    Aout, A = linmap(s, xdim)

    vx = [f.GetGaussian(Label("vx", i), c0, cm5) for i in range(xdim)]
    x = [f.GetGaussianV(Label("x", i), Aout[i], vx[i]) for i in range(xdim)]

    # Learn the model

    # Loop over the declared dimensions rather than literal 10 and 2, so the
    # section keeps working when xdim or sdim is changed
    for i in range(xdim):
        x[i].Clamp(data[i])

    # NOTE(review): presumably a random initialisation of the sources, as in
    # the dynvar example -- confirm against EvidenceVNode
    for j in range(sdim):
        f.EvidenceVNode(s[j], mean=randn(tdim), decay=5)

    for i in range(100):
        net.UpdateAll()
        # Report every iteration at first, then every tenth one
        if (i < 20) or ((i % 10) == 0):
            e = evaluate_subspace(lmap, GetMean(A))
            print("%i : %f : %f - %f" % (
                i, -net.Cost(), MLab.min(e), MLab.max(e)))

    # Print the estimated noise std and the quality of the subspace

    print("Noise std estimates:")
    # NOTE(review): exp(-v/2) is a std only if vx holds a log-precision;
    # confirm against the Gaussian node's parametrisation
    print(exp(-0.5*GetMean(vx)))
    print("Quality of the estimated subspace:")
    print(evaluate_subspace(lmap, GetMean(A)))

## Nonstationary ICA by variance modelling

if __name__ == "__main__" and model == "dynvar":

    # Generate some data: two sources whose log-std varies sinusoidally
    # over time, mixed linearly into xdim observations plus a little noise

    sdim = 2
    xdim = 10
    tdim = 1000

    # Time axis running from 0 up to (but not including) 1
    t = 1.0*arange(tdim)/tdim
    logstd = array([2*sin(10*pi*t) - 2, 2*cos(7*pi*t) - 2], Float) 
    sources = exp(logstd)*randn(sdim, tdim)
    lmap = Orthogonalize(randn(xdim, sdim))
    noise = randn(xdim, tdim) * 1e-2
    data = dot(lmap, sources) + noise
    
    # Repeat the estimation with several random initialisations to
    # escape local minima 

    # Collects one (cost, net, run index) tuple per run
    nets = []

    # NOTE(review): the loop variable shadows the builtin iter()
    for iter in range(10):

        # Build the model
        
        net = PyNet(tdim)
        f = PyNodeFactory(net)
    
        c0 = f.GetConstant("const+0", 0.0)
        cm5 = f.GetConstant("const-5", -5.0)
    
        # The u nodes are created further down, so they are referred to by
        # label through proxies here; presumably net.ConnectProxies() below
        # resolves them -- confirm
        pu = [f.GetProxy(Label("pu", j), Label("u", j)) for j in range(sdim)]
        du = [f.GetDelayV(Label("du", j), c0, pu[j]) for j in range(sdim)]
    
        # Linear mapping from the delay nodes to the first parent of each u
        Bout, B = linmap(du, sdim, prefix="B_")
    
        vu = [f.GetGaussian(Label("vu", j), c0, cm5) for j in range(sdim)]
        u = [f.GetGaussianV(Label("u", j), Bout[j], vu[j]) 
             for j in range(sdim)]
       
        # Each source s[j] takes u[j] as its second parent
        s = [f.GetGaussianV(Label("s", j), c0, u[j]) 
             for j in range(sdim)]
    
        # Linear mapping from the sources to the first parent of each x
        Aout, A = linmap(s, xdim, prefix="A_")
    
        vx = [f.GetGaussian(Label("vx", i), c0, cm5) for i in range(xdim)]
        x = [f.GetGaussianV(Label("x", i), Aout[i], vx[i]) 
             for i in range(xdim)]
    
        net.ConnectProxies()
    
        # Learn the model

        print "Model %d" % iter
    
        for i in range(xdim):
            x[i].Clamp(data[i])
    
        # A fresh randn draw per source gives each run a different start
        for j in range(sdim):
            f.EvidenceVNode(s[j], mean=randn(tdim), decay=5)
    
        for i in range(200):
            net.UpdateAll()
            # Report every iteration at first, then every tenth one
            if (i < 20) or ((i % 10) == 0):
                print "%i : %f" % (i, -net.Cost())
    
        # Print results
    
        print "Noise std estimates:"
        print exp(-0.5*GetMean(vx))
        print "Correlations between estimated and original sources:"
        print corrmat(sources, GetMeanV(s))
        print ""

        # The cost comes first so that sorting the list orders runs by cost
        nets.append((net.Cost(), net, iter))

    # Use the evidence to select the best model

    # Ascending sort: nets[0] is the run with the lowest cost
    nets.sort()
    net = nets[0][1]

    # Print the final results
    
    # The log evidence is reported as the negated cost
    print "Final results (Model %d, Log evidence %g):" % (
        nets[0][2], -nets[0][0])
    print "Correlations between estimated and original sources:"
    print corrmat(sources, GetMeanV(net.GetVariableArray("s")))

