# Linear least-squares inversion for a polynomial drift function.
# The polynomial degree is a parameter so that an iterative search
# for the optimal degree can be built on top of it.

# REQUIRES SCIPY EXTENSIONS
from scipy import *
from scipy.linalg import *

from math import *

from quadrature import sinc_ftest

import sys

# Linear least-squares inversion to a polynomial of degree n,
# using pairs of repeated station occupations; a single set of
# polynomial coefficients a is fitted to all pairs.
# The algorithm is standard weighted linear least squares, as per the
# Inversion Theory class, Fall 1998 (Zhdanov).
# See the comments below for how the system Am=d is set up and solved.
# See the off-line write-up for the algebra of the problem.

# input variables:
# n	degree of poly to fit
# data	dictionary of GravityValue points, already corrected and averaged
# reps	dictionary of station reoccupations

# returns:
# a	list of polynomial coefficients ([] on failure)
# m	number of data pairs, i.e. the size of the system (-1 if a
#	reoccupation refers to a station missing from data)
def polyfit(n, data, reps, weighting=1):
  """Fit drift polynomial coefficients by weighted linear least squares.

  The drift function is f(t) = sum_{k=1..n} a[k-1] * t**k (no constant
  term).  Every pair of occupations (p, q) of the same station gives one
  equation  f(q.time) - f(p.time) = q.G - p.G,  i.e. one row of A m = d:

      A[i, k] = q.time**(k+1) - p.time**(k+1),    d[i] = q.G - p.G

  For each key i of reps, with r0, r1, r2, ... = reps[i], the pairs are
    * consecutive occupations:           (i, r0), (r0, r1), (r1, r2), ...
    * first occupation vs 3rd, 4th, ...: (i, r1), (i, r2), ...

  Row i is weighted by W_i = 1 / sqrt(s1**2 + s2**2) (propagation of error
  for the difference of two readings), and the normal equations
      (A^T W^2 A) m = A^T W^2 d
  are solved for m.

  Parameters:
    n          degree of the polynomial (number of coefficients).
    data       mapping station id -> record with .G, .sigma and .time.
    reps       mapping station id -> list of ids of its reoccupations.
    weighting  if false, every pair gets unit weight (ordinary least
               squares).

  Returns:
    (a, m): a is the list of n coefficients (a[k] multiplies t**(k+1)),
    m is the number of station pairs.  Returns ([], -1) if reps refers
    to a station that is missing from data.

  scipy.linalg.solve raises if the normal matrix is singular (e.g. when
  there are no pairs at all).
  """
  # Collect the station pairs (earlier record, later record) in the same
  # order the rows of the system are built.
  pairs = []
  for i in reps.keys():
    prev = i
    for j in reps[i]:
      try:
        pairs.append((data[prev], data[j]))
      except KeyError:
        # tried to use a repeat of a station that doesn't exist
        return ([], -1)
      prev = j                      # chain: next pair starts at this one
  # add entries for 1st to 3rd, 4th, etc. occupations
  # (all ids here were already validated by the loop above)
  for i in reps.keys():
    later = reps[i]
    for k in range(1, len(later)):
      pairs.append((data[i], data[later[k]]))

  m = len(pairs)
  A = zeros((m, n), float64)
  d = zeros((m,), float64)
  w2 = zeros((m,), float64)        # squared weights, W_i**2
  for row, (p, q) in enumerate(pairs):
    d[row] = q.G - p.G
    if weighting:
      w2[row] = 1.0 / (p.sigma**2 + q.sigma**2)
    else:
      # unweighted fit: every pair counts equally
      w2[row] = 1.0
    for k in range(n):
      A[row, k] = q.time**(k+1) - p.time**(k+1)

  # Solve the weighted normal equations directly rather than forming
  # inv(A^T W^2 A) explicitly: same solution, numerically more accurate,
  # and no dense m x m weight matrix is needed.
  AtW2 = A.T * w2                  # scales column i of A^T by W_i**2
  M = solve(dot(AtW2, A), dot(AtW2, d))

  return (list(M), m)



# compute weighted L1, L2 norm residuals of dG data
def residual(dG, S):
  """Return the weighted (L1, L2) misfit of the differences dG.

  Each difference dG[i] is scaled by 1/S[i] (S[i] is its standard
  deviation), and both measures are normalized by the total weight
  sum(1/S[i]):

      L1 = sum(|dG[i]| / S[i]) / sum(1/S[i])
      L2 = sqrt(sum((dG[i] / S[i])**2)) / sum(1/S[i])

  The absolute value in L1 matters: without it, residuals of opposite
  sign cancel and the value is a signed mean, not a norm.

  S must have at least len(dG) entries (IndexError otherwise); an empty
  dG raises ZeroDivisionError.
  """
  chisqL1 = 0.0; chisqL2 = 0.0; N = 0.0
  for i, dg in enumerate(dG):
    r = dg / S[i]
    chisqL1 = chisqL1 + abs(r)
    chisqL2 = chisqL2 + r * r
    N = N + (1.0 / S[i])

  L1 = 1.0/N * chisqL1
  L2 = 1.0/N * sqrt(chisqL2)

  return (L1, L2)


def Ftest(chi2n, chi2p, n, p, N):
  """Confidence that a p-parameter fit improves on an n-parameter fit.

  Computes the F statistic for nested least-squares models,

      F = ((chi2n - chi2p) / (p - n)) / (chi2p / (N - p)),

  and returns P(X <= F) for X ~ F(p - n, N - p): the confidence level at
  which the extra p - n parameters are justified.

  chi2n, chi2p  misfits of the lower (n) and higher (p) order fits.
  n, p          numbers of fitted parameters; requires n < p.
  N             total number of data points; requires N > p.

  Side effect: sets the module globals nu1, nu2 to the degrees of
  freedom (p - n, N - p); g() reads them as its default parameters.

  Raises ValueError unless n < p < N.
  """
  global nu1, nu2
  from scipy.special import fdtr

  if not (n < p < N):
    raise ValueError("Ftest requires n < p < N, got n=%r p=%r N=%r"
                     % (n, p, N))
  nu1 = float(p - n); nu2 = float(N - p)

  if chi2p == 0:
    # perfect higher-order fit: the improvement is certain unless the
    # lower-order fit was already perfect too
    return 1.0 if chi2n > 0 else 0.0

  F = ((chi2n - chi2p) / nu1) / (chi2p / nu2)
  # closed-form F cumulative distribution (regularized incomplete beta);
  # a higher-order fit that is no better (F <= 0) gives zero confidence
  return float(fdtr(nu1, nu2, max(F, 0.0)))

def g(x, n1=None, n2=None):
  """Unnormalized probability density of the F(n1, n2) distribution at x.

      g(x) = x**(n1/2 - 1) * (n2 + n1*x)**(-(n1 + n2)/2)

  Multiplying by Gamma((n1+n2)/2) / (Gamma(n1/2) * Gamma(n2/2))
  * n1**(n1/2) * n2**(n2/2) gives the normalized F density.

  n1, n2 default to the module globals nu1, nu2 (assigned elsewhere in
  this module, e.g. by Ftest), so existing one-argument calls still work.
  """
  if n1 is None:
    n1 = nu1
  if n2 is None:
    n2 = nu2
  return x**(n1/2.0 - 1.0) * (n2 + n1*x)**(-(n1 + n2)/2.0)



def correction(data, reps, n, weighting=1):
  """Fit and apply a degree-n drift correction; report misfit before/after.

  Fits f(t) = sum_{k=1..n} C[k-1] * t**k with polyfit() and stores f(time)
  on every record as data[i].drift_correction (records are mutated in
  place).  n <= 0 means "no correction": every drift_correction becomes 0.

  Misfit is measured on the pairs (i, j) for each i in reps and each j in
  reps[i], using sigma_ij = sqrt(sigma_i**2 + sigma_j**2):
    index 0 -> raw differences         G_j - G_i
    index 1 -> drift-corrected ones    (G_j - f_j) - (G_i - f_i)

  Returns (a, L1, L2):
    a   [0] + C.  f(t) has no constant term, so the offset a[0] is always
        0 (f(0) = 0).  NOTE(review): the drift is only zero at the start of
        the data if times are measured from the first reading - confirm.
    L1  [raw, corrected] weighted L1 misfits from residual().
    L2  [raw, corrected] weighted L2 misfits from residual().

  Raises KeyError if reps refers to a station missing from data.
  """
  # Compute polynomial fit
  if n > 0:
    (C, npairs) = polyfit(n, data, reps, weighting)
    if npairs < 0:
      # polyfit signals a missing station with ([], -1); fail clearly
      # instead of with an IndexError on the empty coefficient list
      raise KeyError("correction: reps refers to a station missing "
                     "from data")
  else:
    # order 0 ==> no corrections
    C = []

  # cycle through data points, computing f(t) = Sum(k=1...n){C[k-1]*t^k}
  for rec in data.values():
    f = 0.0
    for k in range(n):
      f = f + (C[k] * rec.time**(k+1))
    rec.drift_correction = f

  # create dG data (corrected and raw) for residual calculation
  dG = []; dG2 = []; S = []
  for i in reps.keys():
    first = data[i]
    for j in reps[i]:
      other = data[j]
      dG.append((other.G - other.drift_correction)
                - (first.G - first.drift_correction))
      dG2.append(other.G - first.G)
      S.append(sqrt(first.sigma**2 + other.sigma**2))

  # compute residuals: index 0 = uncorrected, index 1 = corrected
  (L1raw, L2raw) = residual(dG2, S)
  (L1corr, L2corr) = residual(dG, S)
  L1 = [L1raw, L1corr]
  L2 = [L2raw, L2corr]

  # prepend the (always zero) offset to the coefficient list
  a = [0] + list(C)

  return (a, L1, L2)


def printMatrix(A):
  # Write the 2-D array A to stdout, one matrix row per output line.
  # Every entry is formatted as "%.6e" and followed by one space, so each
  # line ends with a trailing space before the newline.
  nrows, ncols = A.shape
  for r in range(nrows):
    cells = ["%.6e " % (A[r, c],) for c in range(ncols)]
    sys.stdout.write("".join(cells))
    sys.stdout.write("\n")
