/* Copyright 1996 by A. J. Krener and Dean Hickerson */

/******************************************************************************/
/* This file implements the function [h,nh]=mply(f,nf,df,g,ng,dg,d).          */
/*                                                                            */
/* (written by Dean Hickerson, 8/24/94)                                       */
/* (modified 8/20/2001 to work with complex numbers)                          */
/******************************************************************************/

/* Short and full names of this MEX function.  fullname is interpolated into  */
/* error messages built with sprintf in mexFunction; presumably both names    */
/* are also referenced by the err/check macros in mexdefines.h (confirm).     */
#define shortname   "mply"
#define fullname    "mply(f,nf,df,g,ng,dg,d)"

#include <stdio.h>
#include "mex.h"
#include "mexdefines.h"

/*************/
/* Externals */
/*************/
/* Helpers defined in other translation units of this toolbox.  ANSI          */
/* compilers (__STDC__) get full prototypes; pre-ANSI compilers get           */
/* old-style declarations.  The notes below describe only how mexFunction in  */
/* this file uses each helper -- confirm details against the definitions.     */
/*   crdsum            - returns the column count used to size f, g and h for */
/*                       a given input size and degree range [n0,n1].         */
/*   samesize          - size test used to validate f and g against nf/ng.    */
/*   initvecdesc       - called with subvector lengths immediately before     */
/*                       each red2lex/lex2red conversion; presumably sets up  */
/*                       state those conversions read.                        */
/*   red2lex, lex2red  - convert between reduced and lexicographic reduced    */
/*                       form; the *complex variants also carry the imaginary */
/*                       part.  Output arrays are returned through the        */
/*                       double** parameters.                                 */
/*   extendinputs      - widens a factor to the full input vector described   */
/*                       by varpresent; the caller frees its input when       */
/*                       *extending is set (presumably a new array was made). */
/*   addmatprod        - presumably accumulates the matrix product of a and b */
/*                       into prod (mxCalloc-zeroed by the caller here); the  */
/*                       subtract flag is always FALSE in this file.          */
#ifdef __STDC__
extern int crdsum(int m, int n0, int n1);
extern boolean samesize(int fr, int fc, int wantedfr, int wantedfc);
extern void initvecdesc(int numsubvecs, int subveclth[]);
extern void red2lex(double f[], int deg0,int deg1,int outsize, double *fl[]);
extern void lex2red(double fl[], int deg0,int deg1,int outsize, double *f[]);
extern void red2lexcomplex(double f[], double fi[], int deg0, int deg1,
		int outsize, double *fl[], double *fli[]);
extern void lex2redcomplex(double fl[], double fli[], int deg0, int deg1,
		int outsize, double *f[], double *fi[]);
extern void extendinputs(double f[], int insize, int deg0, int deg1,
        int outsize, boolean varpresent[], double *fext[], boolean *extending);
extern void extendinputscomplex(double f[], double fi[], int insize,
                         int deg0, int deg1, int outsize, boolean varpresent[],
                         double *fext[], double *fexti[], boolean *extending);
extern void addmatprod(double a[], int adeg0, int adeg1,
                       double b[], int bdeg0, int bdeg1,
                       int insize, int m1, int m2, int m3,
                       double prod[], int pdeg0, int pdeg1, boolean subtract);
extern void addmatprodcomplex(double a[], double ai[], int adeg0, int adeg1,
                       double b[], double bi[], int bdeg0, int bdeg1,
                       int insize, int m1, int m2, int m3,
                       double prod[], double prodi[], int pdeg0, int pdeg1,
					   boolean subtract);
#else
extern int crdsum();
extern boolean samesize();
extern void initvecdesc();
extern void red2lex();
extern void lex2red();
extern void red2lexcomplex();
extern void lex2redcomplex();
extern void extendinputs();
extern void extendinputscomplex();
extern void addmatprod();
extern void addmatprodcomplex();
#endif

/******************************************************************************/

/******************************************************************************/
/* MEX gateway for [h,nh] = mply(f,nf,df,g,ng,dg,d).                          */
/*                                                                            */
/* Inputs:                                                                    */
/*   f, g   - coefficient matrices in reduced form, real or complex.  f must  */
/*            have sum(nf(:,1))*sum(nf(:,2)) rows, g sum(ng(:,1))*goutc rows  */
/*            (goutc = sum(ng(:,2)) if ng has 3 columns, else 1); column      */
/*            counts are crdsum(input size, deg0, deg1).                      */
/*   nf     - 3 columns: output row-block sizes, output column-block sizes,   */
/*            input subvector lengths of f (0 = f does not use that group).   */
/*   ng     - like nf for g, with 2 or 3 columns; its last column holds the   */
/*            input subvector lengths.  Nonzero entries of nf(:,3) and the    */
/*            last column of ng must agree row by row.                        */
/*   df, dg, d - degree ranges; only the first and last elements are read.   */
/* Outputs:                                                                   */
/*   h  - product computed by addmatprod from f (foutr x foutc) and g         */
/*        (foutc x goutc) over degrees d: foutr*goutc rows by                 */
/*        crdsum(hinsize,d0,d1) columns.  Complex if f or g is complex.       */
/*   nh - (only when 2 outputs requested) same shape as ng: column 1 from nf, */
/*        column 2 from ng when ng has 3 columns, last column = elementwise   */
/*        max of nf(:,3) and ng's last column.                                */
/*                                                                            */
/* Errors are raised via the err/err2 macros (mexdefines.h), which are        */
/* assumed not to return (mexErrMsgTxt-style).                                */
/******************************************************************************/
#ifdef __STDC__
void mexFunction(int nlhs, mxArray *plhs[], int nrhs, const mxArray *prhs[])
#else
mexFunction(nlhs, plhs, nrhs, prhs)
int nlhs, nrhs;
mxArray *plhs[], *prhs[];
#endif
{   double *f, *nf, *df, *g, *ng, *dg, *d, *h, *nh, *fl, *gl, *fle, *gle, *hl,
        *fi, *gi, *hi, *fli, *gli, *hli, *flei, *glei, *fzeroi, *gzeroi;
    int fr, fc, nfr, nfc, dfr, dfc, gr, gc, ngr, ngc, dgr, dgc, dr, dc,
        fdeg0, fdeg1, gdeg0, gdeg1, hdeg0, hdeg1,
        fnumsubvecs, gnumsubvecs, hnumsubvecs, foutr, foutc, goutr, goutc,
        i0, i1, i2, il, i, finsize, ginsize, hinsize, fcols, gcols, hcols;
    int fsubveclth[MAXSUBVECS], gsubveclth[MAXSUBVECS], hsubveclth[MAXSUBVECS];
    boolean fvarpresent[MAXINSIZE], gvarpresent[MAXINSIZE];
    boolean extending, complex;
    char outbuff[200];
    /* Not referenced directly below; presumably used by the checktypes and   */
    /* checkrange macros (as may dfr, dgr, dr) -- keep.                       */
    int inputargcounter;

    /***************************************************/
    /* Check that there are 7 inputs and <= 2 outputs. */
    /***************************************************/
    if (nrhs != 7)  err("ERROR: mply requires 7 inputs.")
    if (nlhs > 2)   err("ERROR: mply has only 2 outputs.")
    checktypes

    /* The product is complex if EITHER factor is complex.  Checking only f   */
    /* silently dropped a complex g's imaginary part, and passed a NULL       */
    /* imaginary pointer on when f was complex but g was real.                */
    complex = mxIsComplex(prhs[0]) || mxIsComplex(prhs[3]);

    /****************************************/
    /* Get sizes of and pointers to inputs. */
    /****************************************/
    f  = mxGetPr(prhs[0]);  fr  = mxGetM(prhs[0]);  fc  = mxGetN(prhs[0]);
    nf = mxGetPr(prhs[1]);  nfr = mxGetM(prhs[1]);  nfc = mxGetN(prhs[1]);
    df = mxGetPr(prhs[2]);  dfr = mxGetM(prhs[2]);  dfc = mxGetN(prhs[2]);
    g  = mxGetPr(prhs[3]);  gr  = mxGetM(prhs[3]);  gc  = mxGetN(prhs[3]);
    ng = mxGetPr(prhs[4]);  ngr = mxGetM(prhs[4]);  ngc = mxGetN(prhs[4]);
    dg = mxGetPr(prhs[5]);  dgr = mxGetM(prhs[5]);  dgc = mxGetN(prhs[5]);
    d  = mxGetPr(prhs[6]);  dr  = mxGetM(prhs[6]);  dc  = mxGetN(prhs[6]);

    /* mxGetPi returns NULL for a real array; such a factor gets a zero       */
    /* imaginary part allocated below, just before conversion.                */
    fi = gi = fzeroi = gzeroi = NULL;
    if (complex)
      { fi = mxGetPi(prhs[0]);
        gi = mxGetPi(prhs[3]);
      }

    if (nfc != 3)  err(
      "ERROR: In mply(f,nf,df,g,ng,dg,d), nf must have 3 columns.")

    if (ngc != 2 && ngc != 3)  err(
      "ERROR: In mply(f,nf,df,g,ng,dg,d), ng must have 2 or 3 columns.")

    if (nfr != ngr)  err(
"ERROR: In mply(f,nf,df,g,ng,dg,d), nf and ng must have same number of rows.")

    /* Each row of nf/ng may add one entry to the fixed-size *subveclth       */
    /* arrays, so the row count must not exceed their capacity.               */
    if (nfr > MAXSUBVECS)
      { sprintf(outbuff, "ERROR: In %s, nf and ng may have at most %d rows.",
                fullname, (int) MAXSUBVECS);
        err(outbuff)
      }

    checkrange(2, "df")
    checkrange(5, "dg")
    checkrange(6, "d")

    /************************************************/
    /* Get the min and max degrees for f, g, and h. */
    /************************************************/
    fdeg0 = df[0];  fdeg1 = df[dfc-1];
    gdeg0 = dg[0];  gdeg1 = dg[dgc-1];
    hdeg0 = d[0];   hdeg1 = d[dc-1];

    /******************************************************************/
    /* Scan nf and ng to determine input and output sizes of f and g. */
    /******************************************************************/
    finsize=ginsize=hinsize = fnumsubvecs=gnumsubvecs=hnumsubvecs = 0;
    foutr=foutc=goutr=goutc = 0;

    /***********************************************************/
    /* In the following loop:                                  */
    /* i0 is an index into the first columns of nf and ng,     */
    /* i1 is an index into the second columns of nf and ng,    */
    /* i2 is an index into the third column of nf,             */
    /* il is an index into the last (2nd or 3rd) column of ng. */
    /***********************************************************/
    for (i0=0, i1=nfr, i2=nfr<<1, il=nfr<<(ngc-2);  i0<nfr;
                                                        i0++, i1++, i2++, il++)
      { double grouplth;

        /**********************************************/
        /* Record info about output sizes of f and g. */
        /**********************************************/
        foutr += nf[i0];
        foutc += nf[i1];
        goutr += ng[i0];
        goutc += ng[i1];    /* Irrelevant if ngc=2 */

        /* Validate the group length BEFORE the loops below write hinsize    */
        /* entries into fvarpresent/gvarpresent: negative or NaN lengths     */
        /* would make those loops run away, and large ones would overflow    */
        /* the arrays.  The !(...) form also rejects NaN.                    */
        grouplth = nf[i2] ? nf[i2] : ng[il];
        if (!(nf[i2] >= 0 && ng[il] >= 0 && grouplth <= MAXINSIZE - hinsize))
          { sprintf(outbuff,
              "ERROR: In %s, 3rd column of nf and last column of ng must be "
              "nonnegative; total input size may not exceed %d.",
              fullname, (int) MAXINSIZE);
            err(outbuff)
          }

        /*********************************************/
        /* Record info about input sizes of f and g. */
        /*********************************************/
        if (nf[i2])
          { finsize += (fsubveclth[fnumsubvecs++] = nf[i2]);
            hsubveclth[hnumsubvecs++] = nf[i2];
            if (ng[il])         /* nf[i2] != 0,  ng[il] != 0 */
              { ginsize += (gsubveclth[gnumsubvecs++] = ng[il]);
                if (nf[i2] == ng[il])
                  for (i=nf[i2]; i; i--)
                    { fvarpresent[hinsize]   = TRUE;
                      gvarpresent[hinsize++] = TRUE;
                    }
#ifdef THINK_C
                else err2(
                  "ERROR: In mply(f,nf,df,g,ng,dg,d), corresponding elements ",
                  "of 3rd columns\rof nf and ng must be equal or zero.")
#else
                else err2(
                  "ERROR: In mply(f,nf,df,g,ng,dg,d), corresponding elements ",
                  "of 3rd columns\nof nf and ng must be equal or zero.")
#endif
              }
            else                /* nf[i2] != 0,  ng[il] = 0 */
              for (i=nf[i2]; i; i--)
                { fvarpresent[hinsize]   = TRUE;
                  gvarpresent[hinsize++] = FALSE;
                }
          }
        else
          { if (ng[il])         /* nf[i2] = 0,  ng[il] != 0 */
              { ginsize += (gsubveclth[gnumsubvecs++] = ng[il]);
                hsubveclth[hnumsubvecs++] = ng[il];
                for (i=ng[il]; i; i--)
                  { fvarpresent[hinsize]   = FALSE;
                    gvarpresent[hinsize++] = TRUE;
                  }
              }
          }
      }

    if (foutc != goutr)
      err2("ERROR: In mply(f,nf,df,g,ng,dg,d), 2nd col of nf and 1st col ",
      "of ng\nmust have same sum.")

    /* With a 2-column ng, g is a single column (its 2nd column holds the     */
    /* input sizes, so the sum accumulated into goutc above is meaningless).  */
    if (ngc == 2)  goutc=1;

    if (fdeg0 < 0 || fdeg1 < fdeg0) err(
      "ERROR: In mply(f,nf,[df0 df1],g,ng,dg,d), must have 0 <= df0 <= df1.")

    if (gdeg0 < 0 || gdeg1 < gdeg0) err(
      "ERROR: In mply(f,nf,df,g,ng,[dg0 dg1],d), must have 0 <= dg0 <= dg1.")

    if (hdeg0 < 0 || hdeg1 < hdeg0) err(
      "ERROR: In mply(f,nf,df,g,ng,dg,[d0 d1]), must have 0 <= d0 <= d1.")

    /*************************************************************/
    /* Check that f and g have the sizes specified by nf and ng. */
    /*************************************************************/
    fcols = crdsum(finsize, fdeg0, fdeg1);
    gcols = crdsum(ginsize, gdeg0, gdeg1);
    hcols = crdsum(hinsize, hdeg0, hdeg1);

    /* All size arguments are int, so %d (the old %ld was undefined          */
    /* behaviour and printed garbage on LP64 platforms).                      */
    if (!samesize(fr, fc, foutr*foutc, fcols))
      { sprintf(outbuff,
          "ERROR: In %s, f should be %d by %d, not %d by %d.",
          fullname, foutr*foutc, fcols, fr, fc);
        err(outbuff)
      }

    if (!samesize(gr, gc, goutr*goutc, gcols))
      { sprintf(outbuff,
          "ERROR: In %s, g should be %d by %d, not %d by %d.",
          fullname, goutr*goutc, gcols, gr, gc);
        err(outbuff)
      }

    /*********************************/
    /* If 2 output args, compute nh. */
    /*********************************/
    if (nlhs == 2)
      { plhs[1] = mxCreateDoubleMatrix(ngr,ngc,mxREAL);
        nh = mxGetPr(plhs[1]);
        for (i0=0, i1=nfr, i2=nfr<<1, il=nfr<<(ngc-2);  i0<nfr;
                                                        i0++, i1++, i2++, il++)
          { nh[i0] = nf[i0];
            nh[i1] = ng[i1];    /* Overwritten by next statement if ngc=2 */
            nh[il] = max(nf[i2],ng[il]);
          }
      }

    /**************************************************/
    /* Convert f and g to lexicographic reduced form. */
    /**************************************************/
    fl = gl = NULL;
    fli = gli = NULL;
    if (complex)
      { /* A real factor in a complex product gets an all-zero imaginary      */
        /* part (mxCalloc zero-fills).  The conversions write into fresh      */
        /* arrays (fli/gli are mxFree'd later), so the zero parts can be      */
        /* released right after.                                              */
        if (fi == NULL)  fi = fzeroi = mxCalloc(fr*fc, sizeof(double));
        if (gi == NULL)  gi = gzeroi = mxCalloc(gr*gc, sizeof(double));

        initvecdesc(fnumsubvecs, fsubveclth);
        red2lexcomplex(f, fi, fdeg0, fdeg1, foutr*foutc, &fl, &fli);
        initvecdesc(gnumsubvecs, gsubveclth);
        red2lexcomplex(g, gi, gdeg0, gdeg1, goutr*goutc, &gl, &gli);

        if (fzeroi != NULL)  mxFree(fzeroi);
        if (gzeroi != NULL)  mxFree(gzeroi);
      }
    else
      { initvecdesc(fnumsubvecs, fsubveclth);
        red2lex(f, fdeg0, fdeg1, foutr*foutc, &fl);
        initvecdesc(gnumsubvecs, gsubveclth);
        red2lex(g, gdeg0, gdeg1, goutr*goutc, &gl);
      }

    /*******************************************/
    /* If necessary, extend inputs of f and g. */
    /*******************************************/
    /* When no extension is needed the extended array is presumably the      */
    /* same as the lexicographic one, hence the lex array is freed only when  */
    /* `extending` is set and the extended array is always freed below.       */
    fle = gle = NULL;
    flei = glei = NULL;
    if (complex)
      { extendinputscomplex(fl,fli,hinsize,fdeg0,fdeg1,foutr*foutc,fvarpresent,
                                                         &fle,&flei,&extending);
        if (extending)  { mxFree(fl); mxFree(fli); }

        extendinputscomplex(gl,gli,hinsize,gdeg0,gdeg1,goutr*goutc,gvarpresent,
                                                         &gle,&glei,&extending);
        if (extending)  { mxFree(gl); mxFree(gli); }
      }
    else
      { extendinputs(fl,hinsize,fdeg0,fdeg1,foutr*foutc,fvarpresent,&fle,
                                                                    &extending);
        if (extending)  mxFree(fl);

        extendinputs(gl,hinsize,gdeg0,gdeg1,goutr*goutc,gvarpresent,&gle,
                                                                    &extending);
        if (extending)  mxFree(gl);
      }

    /*********************/
    /* Multiply f and g. */
    /*********************/
    /* The product buffers start zeroed (mxCalloc).  They are not freed      */
    /* explicitly; mxCalloc memory is released by MATLAB when the MEX call    */
    /* returns.                                                               */
    hl  = mxCalloc(foutr*goutc*hcols, sizeof(double));
    hli = NULL;
    if (complex)
      { hli = mxCalloc(foutr*goutc*hcols, sizeof(double));
        addmatprodcomplex(fle, flei, fdeg0, fdeg1, gle, glei, gdeg0, gdeg1,
              hinsize, foutr, foutc, goutc, hl, hli, hdeg0, hdeg1, FALSE);
      }
    else
      addmatprod(fle, fdeg0, fdeg1, gle, gdeg0, gdeg1,
            hinsize, foutr, foutc, goutc, hl, hdeg0, hdeg1, FALSE);

    mxFree(fle);
    mxFree(gle);
    if (complex)
      { mxFree(flei);
        mxFree(glei);
      }

    /***************************************/
    /* Allocate (and clear) output matrix. */
    /***************************************/
    plhs[0] = mxCreateDoubleMatrix(foutr*goutc, hcols,
                                   complex ? mxCOMPLEX : mxREAL);
    h = mxGetPr(plhs[0]);

    /************************************/
    /* Convert product to reduced form. */
    /************************************/
    initvecdesc(hnumsubvecs, hsubveclth);
    if (complex)
      { hi = mxGetPi(plhs[0]);
        lex2redcomplex(hl, hli, hdeg0, hdeg1, foutr*goutc, &h, &hi);
      }
    else
      lex2red(hl, hdeg0, hdeg1, foutr*goutc, &h);
}
