/* Python extension module "diag": multiplies a diagonal matrix, stored as the
   1-D vector of its diagonal entries, by an ordinary (dense) 2-D matrix. */
#include "Python.h"
#include "arrayobject.h"
#include <math.h>

/* Python interface functions */
void initdiag(void);
static PyObject *mm(PyObject *self, PyObject *args);


/* Method table for the "diag" module. A docstring is supplied so that
   diag.mm.__doc__ is meaningful; the last entry is the all-zero sentinel
   that terminates the table. */
static PyMethodDef diagMethods[] = {
  {"mm", mm, METH_VARARGS,
   "mm(a, b): multiply a diagonal matrix, given as the 1-D vector of its\n"
   "diagonal entries, by a 2-D matrix."},
  {NULL, NULL, 0, NULL}  /* sentinel */
};

void initdiag(void)
{
	(void) Py_InitModule("diag", diagMethods);
}

static PyObject *mm(PyObject *self, PyObject *args)
{
  int ok;

  long i, j;

  PyObject *i1, *i2;
  PyArrayObject *result;
  PyArrayObject *a, *b;
  /* parse Python args into C vars */
  ok = PyArg_ParseTuple(args, "OO:diag.mm", (PyObject *)&a, (PyObject *)&b);
  if (!ok) {
    return NULL;
    }
printf("DIAG: parsed args into objects\n");

  /* compute the result */
  /* test the dimensions on array a */
  if (a->nd > 1) {
printf("DIAG: a->nd > 1 ==> A is non-diag\n");
    /* a is not diagonal matrix */
    /* so we have a*diag(b) */
printf("%p\n", a->dimensions);
    result = (PyArrayObject *)PyArray_FromDims(2, a->dimensions, PyArray_DOUBLE);
printf("DIAG: created result array\n");
    for(i=0; i<a->dimensions[0]; i++) {
      for(j=0; j<a->dimensions[1]; j++) {
	result->data[i*result->strides[0] + j*result->strides[1]] =
	 a->data[i*a->strides[0] + j*a->strides[1]]*b->data[j*b->strides[0]];
        }
      }
    }
  else {
printf("DIAG: a->nd <= 1 ==> A is diag\n");
    /* a is diagonal matrix */
    /* so we have diag(a)*b */
    result = (PyArrayObject *)PyArray_FromDims(2, b->dimensions, PyArray_DOUBLE);
    for(i=0; i<b->dimensions[0]; i++) {
      for(j=0; j<b->dimensions[1]; j++) {
	result->data[i*result->strides[0] + j*result->strides[1]] =
	 b->data[i*b->strides[0] + j*b->strides[1]]*a->data[j*a->strides[0]];
        }
      }
    }

  Py_DECREF(a);
  Py_DECREF(b);
  return PyArray_Return(result);
}
