#
# Quality Control for raw gravity data
#  - linear fit of data points; look for slope != 0
#    - fit w/ and w/o first 3 minutes; changes in the slope?
#    - detrend the data using the linear fit
#  - residuals of linear fit; large ==> resonance

from scipy import *
from scipy.linalg import *
import copy
import sys

import grav_util
from grav_util import Truth
import temp_correct

class Reading:
  """One gravity reading within a single station's time series.

  Attributes (copied straight from the constructor arguments):
    G     -- gravity value of the reading
    S     -- sigma (s.d.) of the reading
    T     -- day (time stamp) of the reading
    index -- position of the reading in the raw data array it came from
  """
  def __init__(self, grav, sigma, day, index):
    self.G, self.S = grav, sigma
    self.T, self.index = day, index

# Length, in seconds, of the start of each station time series that llsq()
# leaves out of its second fit.
SKIP_TIME = 180.0
# Default slope limit, in mGal/sec; roughly 50 uGal/30 min
default_slope_thresh = 2.7e-05
# Default limit on the s.d. of a reading, in mGal
default_sigma_thresh = 0.05

# rdata is the raw gravity data - a numerically indexed sequence, not a dict
# parms, names are dicts of station parameters and names, indexed by id
# options is the dictionary of processing options
# write is the function used to display messages
#
# returns the tuple (ndata, options), where ndata is a new copy of the data
# (updated for trend removal, etc.)
def check(rdata, parms, names, options, write):
  """Run the quality-control pass over the raw gravity time series.

  Steps, in order:
    1. normalise the "skip", "sigma_threshhold" and "slope_threshhold"
       options;
    2. optionally drop readings whose temperature is out of range, and/or
       let temp_correct.fix() correct them;
    3. group the readings by station id and flag every reading whose s.d.
       exceeds the sigma threshold;
    4. fit a line to each station series with llsq() and remove the trend
       wherever the slope of the fit that skips the first SKIP_TIME seconds
       exceeds the slope threshold; repeat until no station is detrended.

  Arguments:
    rdata   -- list of raw readings; the attributes station_id, gravity,
               sigma, jul_day and temp are read.  The temperature steps
               modify this list in place.
    parms   -- not used by this function
    names   -- not used by this function
    options -- dict of processing options.  Keys read: "skip" (minutes),
               "sigma_threshhold", "slope_threshhold", "temp_remove",
               "temp_threshhold_drop", "temp_threshhold", "temp_correct",
               "temp_correct_debug", "grav_samples", "detrend_skip".
               Updated in place: "skip" and both thresholds are stored as
               floats, and "detrend" records the detrended stations.
    write   -- callable taking one string; used for all messages

  Returns (ndata, options): ndata is a deep copy of rdata whose gravity
  values carry the trend removal; options is the same dict passed in.
  """
  # llsq() reads the module-level SKIP_TIME, so the requested skip time must
  # be stored there; a plain assignment would only create a local and the
  # fits would silently keep using the module default.
  global SKIP_TIME

  # local import keeps the module's import block untouched
  from functools import cmp_to_key

  write("\n*** QUALITY CONTROL CHECK OF TIME SERIES ***\n")
  try:
    SKIP_TIME = float(options["skip"]) * 60.0   # minutes -> seconds
  except (KeyError, TypeError, ValueError):
    write("!!! Invalid skip time; assuming 0.0 minutes\n")
    SKIP_TIME = 0.0
  options["skip"] = SKIP_TIME / 60.0

  SIGMA_THRESHHOLD = _qc_threshold_option(options, "sigma_threshhold",
    "sigma", default_sigma_thresh, write)
  SLOPE_THRESHHOLD = _qc_threshold_option(options, "slope_threshhold",
    "slope", default_slope_thresh, write)

  #
  # Remove/correct wonky temperature readings
  write(">>> Checking raw data for temperature problems\n")
  write("->> Remove points with temperature outside threshhold? ")
  if grav_util.Truth(options["temp_remove"]):
    write("yes.\n")
    if "temp_threshhold_drop" in options:
      thresh = float(options["temp_threshhold_drop"])
    else:
      thresh = float(options["temp_threshhold"])
    thresh = abs(thresh)
    write("--> removing points with temps outside +-%f..."%thresh)
    # Every reading is tested, including the last one (the former loop
    # stopped one element early).  The list is filtered in place so the
    # caller's object is still the one that gets updated.
    kept = [r for r in rdata if not (abs(r.temp) > thresh)]
    dropped = len(rdata) - len(kept)
    rdata[:] = kept
    write("done.\n")
    write("--> %d points dropped.\n"%dropped)
  else:
    write("no.\n")

  write("->> Estimate correct temperatures & update gravity values? ")
  if grav_util.Truth(options["temp_correct"]):
    write("yes.\n")
    # convert before formatting: the option may still be a string
    fix_thresh = abs(float(options["temp_threshhold"]))
    write("--> 'fixing' temperatures outside +-%f; updating gravity values..."%fix_thresh)
    temp_correct.fix(rdata, fix_thresh)
    write("done.\n")
    if grav_util.Truth(options["temp_correct_debug"]) and rdata:
      # NOTE(review): the reference day for this dump was never defined in
      # this module; the first reading's jul_day is assumed - confirm.
      start_jday = rdata[0].jul_day
      for r in rdata:
        if hasattr(r, "uncorrected_temp"):	# if corrected val
          grav = r.uncorrected_temp[0]
          temp = r.uncorrected_temp[1]
        else:
          temp = r.temp
          grav = r.gravity
        write("TEMPCO: %.5f\t%.3f %.3f\t%.3f %.3f\n"%(r.jul_day-start_jday,
          temp, grav, r.temp, r.gravity))
  else:
    write("no.\n")

  # break data array into series based on station id
  write(">>> Checking station time series for s.d. > threshhold=%f\n->>"%SIGMA_THRESHHOLD)
  samples_only = grav_util.Truth(options["grav_samples"])	# grav_samples ==> no s.d. per reading!
  data = {}
  cnt = 0
  dcnt = 0	# progress dots written since the last warning
  for idx, r in enumerate(rdata):
    cnt += 1
    dcnt += 1
    series = data.setdefault(r.station_id, [])
    series.append(Reading(r.gravity, r.sigma, r.jul_day, idx))
    if samples_only:
      continue
    if r.sigma > SIGMA_THRESHHOLD:
      write("\n->> Reading with s.d. > threshhold; check reading %d of station %s\n"%(
        len(series), r.station_id))
      write("->>")
      dcnt = 0
    else:
      write(".")
      if (dcnt%72) == 0:
        write("\n->>")
  if samples_only:
    write(" raw timeseries are samples, so no s.d.; skipping check")
  write("\n--> processed %d data points\n"%cnt)

  # grav_util.num_sort is a cmp-style function (it was passed positionally
  # to list.sort); cmp_to_key adapts it for key-based sorting.
  station_ids = sorted(data, key=cmp_to_key(grav_util.num_sort))

  # fit / detrend until a full pass removes no further trend
  while True:
    trend_flag = 0
    write(">>> Time Series Linear Fits (y=m*x+b), threshhold=%f mGal/sec (%f uGal/hr)\n"%(SLOPE_THRESHHOLD, SLOPE_THRESHHOLD*3.6e6))
    write(">>>            <------- all data --------> <---- w/o 1st %2d sec ----->\n"%SKIP_TIME)
    write(">>> Station ID Linear slope Zero intercept Linear slope Zero intercept Slope (uGal/hr)\n")
    write(">>> ---------- ------------ -------------- ------------ -------------- ---------------\n")
    for sid in station_ids:
      # linear fit to station time series
      series = data[sid]
      start_T = series[0].T
      G = [r.G for r in series]
      S = [r.S for r in series]
      # convert time to relative to start, and in seconds
      T = [(r.T - start_T)*86400.0 for r in series]
      (m1, b1, m2, b2, timeOffset) = llsq(G, S, T, write)
      if m1 is None:
        write("->> !!! Error performing linear least squares fit for station %s.\n"%sid)
        continue
      write("->> %10s %12.6f   %12.6f %12.6f   %12.6f %15.1f"%(sid, m1, b1, m2, b2, m2*3.6e6))

      if not (abs(m2) > SLOPE_THRESHHOLD):
        write("\n")
        continue
      if ("detrend_skip" in options and sid in options["detrend_skip"]
          and options["detrend_skip"][sid] != 0):
        write(" m>threshhold; skipped\n")
        continue

      trend_flag += 1
      # detrend the data, using the slope of the fit that skips the start;
      # timeOffset is the time origin llsq used for that fit
      for r in series:
        dt = (r.T - start_T)*86400 - timeOffset
        r.G = r.G - m2*dt
      options.setdefault("detrend", {})[sid] = 1
      write(" m>threshhold; detrended\n")
    write("--> %d stations detrended.\n"%trend_flag)
    if trend_flag == 0:
      break
    write("--> Rerunning with updated data.\n")

  # create copy of raw data, updated with trend removal
  ndata = copy.deepcopy(rdata)
  for series in data.values():
    for r in series:
      ndata[r.index].gravity = r.G

  write("--> Done with quality control checking\n")
  return (ndata, options)


def _qc_threshold_option(options, key, label, default, write):
  # Coerce options[key] to float in place, falling back to default when the
  # key is missing or its value is not a number; return the stored value.
  if key in options:
    try:
      options[key] = float(options[key])
    except (TypeError, ValueError):
      write("!!! Invalid %s threshhold %s; assuming standard of %f\n"%(
        label, options[key], default))
      options[key] = default
  else:
    options[key] = default
  return options[key]



def llsq(y, dy, x, write=sys.stdout):
  """Fit y = m*x + b twice: on the whole series and without its start.

  Weighted least squares inversion, in matrix form Y = A * M, where Y is the
  column vector of gravity readings, M is the model vector [m b], and A is
  the array [ x_i 1 ].  For W=diag(w_i)=diag(1/dy_i^2) the solution is
  M = Iw * Y, with the weighted generalized inverse
  Iw = (A'*W^2*A)^-1 * A'*W^2.  The algebra itself is done by matrixops().

  The second fit uses only the samples from the first x >= SKIP_TIME (the
  module-level value) onward, with x re-zeroed at that sample.  If no sample
  reaches SKIP_TIME, only the last sample is used.

  Arguments:
    y     -- data values
    dy    -- errors of the data values; same length as y
    x     -- independent values (times); same length as y
    write -- callable taking one string, or an object with a write()
             method (such as the default, sys.stdout)

  Returns (m1, b1, m2, b2, t0): slope/intercept of the full fit, slope/
  intercept of the shortened fit, and the x value the shortened fit starts
  at.  On a length mismatch, empty input or failure of the first fit a list
  of five None's is returned; if only the shortened fit fails, m1 is set to
  None as well.  Callers test m1 for None.
  """
  # The default argument is the stream itself, which cannot be called;
  # fall back to its write method.
  if not callable(write):
    write = write.write

  R = [None, None, None, None, None]
  n = len(y)
  if n != len(dy):
    write("--> !!! llsq: error vector not equal length to data vector.\n")
    return R
  if n != len(x):
    write("--> !!! llsq: independant vector not equal length to data vector.\n")
    return R
  if n == 0:
    # nothing to fit; reported the same way as a failed inversion
    write("--> !!! llsq: error during matrix operations.\n")
    return R

  try:
    (R[0], R[1]) = matrixops(y, dy, x)
  except Exception:
    write("--> !!! llsq: error during matrix operations.\n")
    return [None, None, None, None, None]

  # first sample at or after SKIP_TIME; the last sample if there is none
  start = n - 1
  for i, t in enumerate(x):
    if t >= SKIP_TIME:
      start = i
      break

  t0 = x[start]
  yp = list(y[start:])
  dyp = list(dy[start:])
  xp = [t - t0 for t in x[start:]]	# reset t0 to first time after SKIP_TIME
  try:
    (R[2], R[3]) = matrixops(yp, dyp, xp)
  except Exception:
    write("--> !!! llsq: error during matrix ops with shortened time series.\n")
    R[0] = None	# a None m1 is the failure signal callers test

  return (R[0], R[1], R[2], R[3], t0)

def matrixops(y, dy, x, write=sys.stdout):
  """Solve the weighted least squares problem y = m*x + b for [m, b].

  With A = [ x_i 1 ] and W = diag(1/dy_i^2) the model is
    [m, b] = (A'*W^2*A)^-1 * A'*W^2 * y
  An error of exactly 0 is treated as 0.001 so that its weight stays finite.

  NOTE(review): because W is squared, each point is effectively weighted by
  1/dy_i^4, whereas the usual weighted least squares weight is 1/dy_i^2.
  This matches what the formula above documents and is left unchanged -
  confirm that it is intended.

  Arguments:
    y     -- data values
    dy    -- errors of the data values; not modified
    x     -- independent values
    write -- unused; kept for interface compatibility

  Returns a length-2 array [m, b], or (None, None) if the three inputs do
  not have the same length.  A singular system raises whatever exception
  inv() raises; llsq() catches it.
  """
  n = len(y)
  if n != len(dy):
    return (None, None)
  if n != len(x):
    return (None, None)

  # double precision throughout: the weights span many orders of magnitude
  Y = array(y, float)
  A = ones((n, 2), float)
  A[:, 0] = array(x, float)

  # diagonal of W^2, kept as a vector instead of an n-by-n matrix; the
  # caller's dy list is only read, never written
  w2 = []
  for sigma in dy:
    if sigma == 0:
      sigma = 0.001	# can't have 0 error!!!
    weight = 1.0/(sigma ** 2)
    w2.append(weight*weight)
  w2 = array(w2, float)

  # compute model parameters
  AtW2 = transpose(A) * w2	# A'*W^2: scales column i of A' by w2[i]
  Q = dot(AtW2, A)
  Iw = dot(inv(Q), AtW2)
  return dot(Iw, Y)
