# grav_util.py - utility functions for gravity reduction

import sys, string, math
from types import *

import scipy as Numeric

import thiele
from grav_data import GravityValue

import jd # Julian Day routines

def jul_day(date, time):
  """Delegate to jd.jul_day(date, time) and return its result."""
  return jd.jul_day(date, time)
def str2jd(date_str, time_str):
  """Delegate to jd.str2jd(date_str, time_str) and return its result."""
  return jd.str2jd(date_str, time_str);
def str2jd2(date_str, time_str):
  """Delegate to jd.str2jd2(date_str, time_str) and return its result."""
  return jd.str2jd2(date_str, time_str);
def calc_jday(Y, M, D, h, m, s):
  """Delegate to jd.calc_jday(Y, M, D, h, m, s) and return its result."""
  return jd.calc_jday(Y, M, D, h, m, s);
def un_jday(jul_day):
  """Delegate to jd.un_jday(jul_day) and return its result.

  Note: the parameter name shadows the module-level jul_day() wrapper
  inside this function only.
  """
  return jd.un_jday(jul_day);
def datestr(jday):
  """Delegate to jd.datestr(jday) and return its result."""
  return jd.datestr(jday);
def datestr2(jday):
  """Delegate to jd.datestr2(jday) and return its result."""
  return jd.datestr2(jday);
def fract_day(time_str):
  """Delegate to jd.fract_day(time_str) and return its result."""
  return jd.fract_day(time_str);

def timestr(time):
  """Format a time given in days as an "HHH:MM:SS" string.

  time is multiplied by 86400 to obtain seconds, rounded to the nearest
  whole second, and then split into hours, minutes and seconds.  Rounding
  the total first (instead of rounding only the seconds field) lets a
  value such as 59.7 s carry into the minutes/hours instead of being
  printed as ":60".

  The hour field is not wrapped at 24, so times of one day or more give
  hour values >= 24.  Intended for non-negative times.

  Returns the string produced by "%3d:%02d:%02d".
  """
  total_seconds = int(round(time * 86400.0))
  hour, remainder = divmod(total_seconds, 3600)
  minute, second = divmod(remainder, 60)
  return "%3d:%02d:%02d" % (hour, minute, second)

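# Find the earliest jul_day, take the x.5 day boundary at or before it as
# the origin, store each record's offset from that origin in .date, and
# return the origin that was subtracted.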
def calc_relative_date(data):
  """Set .date on every record to its jul_day relative to a day origin.

  The origin is derived from the smallest jul_day in data: it is the
  largest value of the form (integer + 0.5) that does not exceed that
  smallest jul_day.

  data is a non-empty sequence of objects with a jul_day attribute; each
  object gets a .date attribute assigned in place.

  Returns the origin that was subtracted.
  """
  earliest = data[0].jul_day
  for rec in data:
    if rec.jul_day < earliest:
      earliest = rec.jul_day

  # A fractional part below 0.5 means the x.5 boundary of this integer
  # day lies after the reading (the "start in afternoon" case), so step
  # back one day before snapping to the boundary.
  if (earliest - math.floor(earliest)) < 0.5:
    earliest = earliest - 1.0
  origin = math.floor(earliest) + 0.5

  for rec in data:
    rec.date = rec.jul_day - origin
  return origin

def get_choice(prompt):
  """Ask a yes/no question on the console.

  prompt is shown followed by '(Y/N)? '.  The answer is read from standard
  input; surrounding whitespace and letter case are ignored.

  Returns 1 if the answer is 'Y', otherwise 0.
  """
  # raw_input exists only on Python 2; on Python 3 the equivalent is input.
  try:
    read_line = raw_input
  except NameError:
    read_line = input
  choice = read_line(prompt + '(Y/N)? ')
  if choice.strip().upper() == 'Y':
    return 1
  return 0

# Average data for each station in raw data set.
#
# Assumes that raw_data has the raw data imported from disk file, and
# that each station has the first reading first! The first entry in
# raw_data for each station ID must be the first station reading!
#
# All stations should be contiguous, so a change of ID indicates a
# change of station data; do not intersperse different station data
# together; each station ID change will be treated as a new, unique,
# station.
#
# skip_time is in minutes.
# max_recs can be used to truncate long time series, but is normally
# left at -1 (use all data in a time series)
# sample_flag is 1/true if raw data series is of samples, NOT averages
#   ==> true for Aliod data, false for CG-3 data!
def average_data(raw_data, parms, names, skip_time, max_recs=-1, sample_flag=0):
  """Reduce the raw readings of every station to one GravityValue each.

  Each unique station ID (from get_uniq_ids) is passed to
  average_station(); stations for which it reports no usable data are
  counted as dropped and left out of the result.

  Arguments:
    raw_data    - sequence of reading records, handed unchanged to
                  get_uniq_ids() and average_station()
    parms       - mapping station ID -> object with lat, lon and elevation
    names       - mapping station ID -> station name
    skip_time   - forwarded to average_station() as its skip argument
    max_recs    - forwarded to average_station()
    sample_flag - forwarded to average_station()

  Returns (data, skipped, dropped):
    data    - dict station ID -> GravityValue
    skipped - sum of the per-station rejected-point counts (fifth element
              of average_station()'s result)
    dropped - number of stations without usable data

  Raises KeyError if a station with usable data has no entry in parms.
  """
  skipped = 0  # count of rejected data points
  dropped = 0  # count of stations without usable data
  data = {}
  for sid in get_uniq_ids(raw_data).keys():
    bars = average_station(sid, raw_data, skip_time, max_recs, sample_flag)
    skipped = skipped + bars[4]
    if bars[0] is None:  # not enough data for this station
      dropped += 1
      continue

    incoming = GravityValue()
    incoming.station_id = sid
    incoming.raw_gravity = bars[0]
    incoming.raw_sigma = bars[1]
    incoming.time = bars[2]
    incoming.jul_day = incoming.time
    incoming.GMT_Diff = bars[3]

    # station name, in addition to the ID
    if sid in names:
      incoming.name = names[sid]
    else:
      incoming.name = "UNKNOWN"

    # take lat/lon/elevation from the parameter file - just in case
    station_parms = parms[sid]
    incoming.lat = station_parms.lat
    incoming.lon = station_parms.lon
    incoming.elevation = station_parms.elevation

    data[sid] = incoming

  return (data, skipped, dropped)

def average_station(station, data, skip, max_recs=-1, sample_flag=0):
  """Average the readings of one station.

  Arguments:
    station     - station ID to select from data
    data        - sequence of reading objects with station_id, jul_day,
                  gravity, sigma and GMT_Diff attributes
    skip        - time in minutes, measured from the station's first
                  reading; readings inside this window are discarded
                  (removes elastic effects of the spring)
    max_recs    - maximum number of records to use, counting the skipped
                  ones; zero or negative means use all records
    sample_flag - if true (as judged by Truth()), every reading is a
                  single sample and sampleAverage() is used; otherwise
                  weightedAverage() is used with the sigma values

  Readings with sigma == 0.0 are treated as bad points and counted.

  Returns (gbar, sbar, tbar, GMTbar, dropped):
    gbar, sbar - average and error estimate from the averaging routine
    tbar       - the latest jul_day among the readings used (the time of
                 the last measurement, not an average)
    GMTbar     - arithmetic mean of GMT_Diff over the readings used
    dropped    - number of bad (sigma == 0.0) readings
  If no reading survives, returns (None, 0, 0, 0, dropped).
  """
  dropped = 0               # bad data points
  skip_seconds = skip * 60.0
  G = []                    # gravity values handed to the averaging routine
  dG = []                   # matching sigma values
  T = []
  GMT = []
  nr = 0                    # records seen for this station, skipped or not
  start_time = None
  for rec in data:
    if rec.station_id != station:
      continue
    if nr == 0:
      # The skip window is measured from THIS station's first reading,
      # not from the first record of the whole data set.
      start_time = rec.jul_day
    nr = nr + 1
    # time since the station's first reading, in seconds
    read_time = (rec.jul_day - start_time) * 86400
    if read_time < skip_seconds:
      continue              # still inside the skip window
    if max_recs > 0 and nr > max_recs:
      continue              # beyond the requested number of records
    if rec.sigma == 0.0:    # sigma == 0.0 ==> bad point!
      dropped = dropped + 1
      continue
    G.append(rec.gravity)
    dG.append(rec.sigma)
    T.append(rec.jul_day)
    GMT.append(rec.GMT_Diff)

  if not G:
    # no valid data records
    return (None, 0, 0, 0, dropped)

  if Truth(sample_flag):
    (gbar, sbar) = sampleAverage(G)
  else:
    (gbar, sbar) = weightedAverage(G, dG)
  # Average most fields, but we actually want the last time
  # of a measurement, not the average!
  tbar = max(T)
  GMTbar = average(GMT)
  return (gbar, sbar, tbar, GMTbar, dropped)

def gref_average(base_id, data, repeats):
  """Return the unweighted mean gravity of a station and its repeats.

  data maps station ID -> object with a gravity attribute; repeats maps
  base_id -> sequence of IDs of the repeat occupations.  The mean is taken
  over data[base_id] and every ID listed in repeats[base_id].
  """
  total = data[base_id].gravity
  count = 1
  for repeat_id in repeats[base_id]:
    total = total + data[repeat_id].gravity
    count = count + 1
  return total / count

def average(y):
  """Return the arithmetic mean of sequence y, or 0 if y is empty."""
  count = len(y)
  if count < 1:
    return 0
  # start from 0.0 so the accumulation is done in floating point
  return sum(y, 0.0) / count

def weightedAverage(y, dy):
  """Return the inverse-variance weighted mean of y and an error estimate.

  y holds the values and dy the matching standard deviations; entries
  with dy == 0.0 are treated as bad points and ignored.  If no entry is
  usable the weight sum and count are forced to 1, giving (0.0, 2.0).

  Returns (ybar, se) where
    ybar = sum(y/dy^2) / sum(1/dy^2)
    se   = 2 * sum(1/dy^2)^-0.5 * n^-0.5, n being the number of points used
  """
  # NOTE(review): sum(1/dy^2)^-0.5 is used as the "s.d." and then divided
  # by sqrt(n) again -- confirm this is the intended error estimate.
  weighted_sum = 0.0
  weight_total = 0.0
  used = 0
  for idx, value in enumerate(y):
    sd = dy[idx]
    if sd == 0.0:  # dy = 0.0 ==> bad point!
      continue
    weighted_sum = weighted_sum + value / (sd ** 2)
    weight_total = weight_total + 1 / (sd ** 2)
    used = used + 1

  if used == 0:
    # no valid data records: keep the divisions below well defined
    weight_total = 1.0
    used = 1

  ybar = weighted_sum / weight_total
  # 2 standard errors of the mean
  se = 2 * (weight_total ** -0.5) * (used ** -0.5)
  return (ybar, se)

def sampleAverage(y):
  """Return the mean of sample array y and two standard errors of it.

  Returns (ybar, se) with se = 2 * s / sqrt(N), where s is the sample
  standard deviation (N-1 in the denominator).  With fewer than two
  samples no standard deviation exists and (0, -1) is returned.

  The variance is accumulated from deviations about the mean (two
  passes).  The one-pass form sum(y^2) - N*ybar^2 loses all precision when
  the values are large compared with their scatter and can even turn
  negative, which makes the square root fail.
  """
  N = len(y)
  if N < 2:
    return (0, -1)

  total = 0.0
  for value in y:
    total += value
  ybar = total / N

  sq_dev = 0.0
  for value in y:
    dev = value - ybar
    sq_dev += dev * dev
  sigma = (sq_dev / (N - 1)) ** 0.5

  # compute 2 standard errors of the mean:
  # s.e. = s.d. / sqrt(n) ==> return 2*s.e.
  se = 2 * sigma * (N ** -0.5)
  return (ybar, se)

def num_sort(a, b):
  """Comparison function ordering a and b by their numeric value.

  Each argument is converted with float(); an argument that is not a valid
  float string is given the value -0.01, so it sorts below 0.0 and before
  all positive numbers.

  Returns -1, 0 or 1 in the manner of a cmp-style comparison.
  """
  keys = []
  for text in (a, b):
    try:
      keys.append(float(text))
    except ValueError:
      keys.append(-0.01)
  left, right = keys
  # (greater) - (less) yields 1, -1, or 0 when neither holds
  return (left > right) - (left < right)

def print_matrix(M, prec=2, width=9):
  """Write the 2-D matrix M to standard output, one row per line.

  Every element is formatted with " %<width>.<prec>g".  The output opens
  with "[", each row is terminated by ";" and a newline followed by one
  space, and "]" is placed before the terminator of the last row.
  """
  cell_fmt = " %%%d.%dg" % (width, prec)
  (n_rows, n_cols) = Numeric.shape(M)
  out = sys.stdout
  out.write("[")
  last_row = n_rows - 1
  for row in range(n_rows):
    out.write("".join([cell_fmt % M[row, col] for col in range(n_cols)]))
    if row == last_row:
      out.write("]")
    out.write(";\n ")

def get_uniq_ids(data):
  """Return a dict whose keys are the unique station IDs found in data.

  data is a sequence of objects with a station_id attribute.  Every key
  maps to 1; only the keys carry information.
  """
  ids = {}
  for rec in data:
    # re-assigning an existing key is harmless, so no membership test
    ids[rec.station_id] = 1
  return ids

def findEarliest(rdata, list, dictFlag):
  """Return the station ID from list that has the earliest reading.

  Arguments:
    rdata    - records with station_id and jul_day attributes; a dict of
               such records if dictFlag is true, otherwise a sequence
    list     - candidate station IDs (the name shadows the builtin but is
               kept for callers)
    dictFlag - true if rdata is a dict

  Each record is paired with the candidate it belongs to, so a station
  with several records (or none) cannot shift the result onto a different
  ID.  On equal times the candidate listed first wins.

  Raises ValueError if no record matches any candidate.
  """
  if dictFlag:  # data in dict, not array!
    records = rdata.values()
  else:
    records = rdata

  # position of each candidate in list; the first occurrence wins
  rank = {}
  for pos, station_id in enumerate(list):
    rank.setdefault(station_id, pos)

  best = None  # (jul_day, position in list) of the earliest record so far
  for rec in records:
    pos = rank.get(rec.station_id)
    if pos is None:
      continue
    key = (rec.jul_day, pos)
    if best is None or key < best:
      best = key

  if best is None:
    raise ValueError("findEarliest: no data for any of the given stations")
  return list[best[1]]

def check_repeats(rdata, repeats):
  """Check the repeat information against the raw data.

  Collects the unique station IDs of rdata with get_uniq_ids() and hands
  them to update_repeats() with rdata treated as a sequence.

  Returns the (R, changes) pair produced by update_repeats().
  """
  station_ids = get_uniq_ids(rdata)
  return update_repeats(station_ids, rdata, repeats, False)

def check_proc_repeats(data, repeats):
  """Check the repeat information against the processed data.

  data is a dict; its keys are used as the set of unique station IDs and
  update_repeats() is called with data treated as a dict.

  Returns the (R, changes) pair produced by update_repeats().
  """
  station_ids = dict.fromkeys(data.keys(), 1)
  return update_repeats(station_ids, data, repeats, 1)

def update_repeats(uniqIDs, data, repeats, dictFlag):
  """Rebuild the repeat table so it only refers to stations present.

  Arguments:
    uniqIDs  - dict whose keys are the station IDs that have data
    data     - the records, passed on to findEarliest()
    repeats  - dict base station ID -> list of repeat station IDs
    dictFlag - passed on to findEarliest() (true if data is a dict)

  The repeat information is assumed to be a superset of the supplied
  data; no checking of station names etc. is done.  For every group in
  repeats the member with the earliest reading (per findEarliest) becomes
  the key of the rebuilt group, and the other members that have data are
  listed under it.

  Returns (R, changes): the rebuilt repeat dict and the number of changes
  counted (missing base station, each missing repeat, and each group left
  without repeats and therefore removed).
  """
  changes = 0
  R = {}  # new repeat dict

  for base_id in repeats.keys():
    repeat_ids = repeats[base_id]
    if base_id in uniqIDs:
      # every member of the group is a candidate for "earliest"
      candidates = [base_id]
      candidates.extend(repeat_ids)
    else:
      changes = changes + 1
      # base station is missing; see whether any repeats are present
      candidates = [rid for rid in repeat_ids if rid in uniqIDs]
      if not candidates:
        continue
    match = findEarliest(data, candidates, dictFlag)
    R[match] = []

    # add existing stations, except those that are themselves keys of R
    if base_id not in R and base_id in uniqIDs:
      R[match].append(base_id)
    for rid in repeat_ids:
      if rid in uniqIDs:
        if rid not in R:
          R[match].append(rid)
      else:
        changes = changes + 1

  # Drop stations in R that have no repeats left; that is, all repeated
  # occupations of the station were removed.  Iterate over a copy of the
  # keys because entries are deleted inside the loop.
  for key in list(R.keys()):
    if len(R[key]) < 1:
      del R[key]
      changes = changes + 1

  return (R, changes)

def remove_extra_parms(parms, rdata):
  """Return the parameter entries that have matching data records.

  parms is a dict keyed by station ID; rdata is a sequence of objects with
  a station_id attribute.  The result is a new dict holding parms[id] for
  every ID that occurs in rdata and is a key of parms; parms itself is
  not modified.
  """
  P = {}
  for rec in rdata:
    sid = rec.station_id
    if sid in parms and sid not in P:
      P[sid] = parms[sid]
  return P


def get_file(filename):
  """Read a text file, stripping out comment and blank lines.

  Leading and trailing whitespace is removed from every line.  Lines that
  are then empty, or whose first character is "#", are discarded.

  Returns a list of the remaining (stripped) lines.  The file is closed
  before returning, also when reading fails.
  """
  infile = open(filename, "rt")
  try:
    lines = []
    for raw_line in infile:
      text = raw_line.strip()
      if not text or text[0] == "#":
        continue
      lines.append(text)
  finally:
    infile.close()
  return lines

# Use Thiele extrapolation for each station in the data
# objects.
# options["grav_samples"] is interpreted by Truth(): true ==> data are raw
# samples, not averages, so use sampleAverage to find s.e. of time series,
# not weightedAverage
def thiele_extrapolate(data, parms, names, options):
  """Reduce each station's time series with thiele.extrapolate().

  Arguments:
    data    - sequence of reading objects with station_id, gravity, sigma,
              jul_day and GMT_Diff attributes
    parms   - mapping station ID -> object with lat, lon and elevation
    names   - mapping station ID -> station name
    options - mapping; "thiele_tolerance" (optional) is stored in
              thiele.THIELE_TOLERANCE; "grav_samples" selects, via Truth(),
              whether the series is filtered with thiele.seriesfilter()
              (radius options["thiele_filt_radius"]) and whether
              sampleAverage() or weightedAverage() supplies the sigma

  A station with fewer than two readings gets gravity -1.0, sigma 0.0 and
  time 0.0.  If thiele.extrapolate() returns None an error is written to
  stderr and the program exits via sys.exit().

  Returns (adata, 0, 0) where adata maps station ID -> GravityValue; the
  two zeros fill the positions average_data() uses for its counts.
  """
  adata = {}
  series = {}   # station ID -> gravity readings
  sseries = {}  # station ID -> sigma of each reading
  tseries = {}  # station ID -> jul_day of each reading
  GMT = {}      # station ID -> GMT_Diff of the station's last record
  if "thiele_tolerance" in options:
    thiele.THIELE_TOLERANCE = options["thiele_tolerance"]

  for rec in data:
    sid = rec.station_id
    series.setdefault(sid, []).append(rec.gravity)
    sseries.setdefault(sid, []).append(rec.sigma)
    tseries.setdefault(sid, []).append(rec.jul_day)
    GMT[sid] = rec.GMT_Diff

  for sid in series.keys():
    if len(series[sid]) < 2:
      grav = -1.0
      sigma = 0.0
      time = 0.0
    else:
      use_samples = Truth(options["grav_samples"])
      if use_samples:
        # filter the series to suppress noise
        series[sid] = thiele.seriesfilter(series[sid], options["thiele_filt_radius"])
      grav = thiele.extrapolate(series[sid])
      if grav is None:
        sys.stderr.write("Error on station %s\n"%sid)
        sys.exit()
      if use_samples:  # unweighted average for sigma
        (toss, sigma) = sampleAverage(series[sid])
      else:
        (toss, sigma) = weightedAverage(series[sid], sseries[sid])
      # don't average time, take last reading time
      time = max(tseries[sid])

    incoming = GravityValue()
    incoming.station_id = sid
    incoming.raw_gravity = grav
    incoming.raw_sigma = sigma
    incoming.time = time
    # Keep jul_day in step with time, as average_data() does; findEarliest()
    # reads .jul_day when it is given processed data.
    incoming.jul_day = incoming.time
    incoming.GMT_Diff = GMT[sid]

    # station name, in addition to the ID
    if sid in names:
      incoming.name = names[sid]
    else:
      incoming.name = "UNKNOWN"

    # take lat/lon/elevation from the parameter file
    station_parms = parms[sid]
    incoming.lat = station_parms.lat
    incoming.lon = station_parms.lon
    incoming.elevation = station_parms.elevation

    adata[sid] = incoming

  return (adata, 0, 0)

def ones(length):
  """Return a list holding the integer 1 repeated length times."""
  return [1 for _ in range(length)]

def Truth(V):
  """Interpret an option value as a truth value.

  Rules:
    - a string is true (1) only if it is "1" or, ignoring case, "YES";
      any other string is false (0)
    - a bool, int or float is returned unchanged, so its own truthiness
      applies
    - anything else (None included) is false (0)

  Booleans are handled together with the other numbers; an exact type
  comparison against int lets True fall through to the catch-all and be
  reported as false.
  """
  if isinstance(V, str):
    # test against "1", "yes"
    if V == "1" or V.upper() == "YES":
      return 1
    return 0
  if isinstance(V, (bool, int, float)):
    return V
  # catchall: the original code could only ever yield 0 here; preserved
  return 0

